From 744d15ba2467d73a907095f69eff51b5c82c16a2 Mon Sep 17 00:00:00 2001 From: hefanli <76611805+hefanli@users.noreply.github.com> Date: Thu, 4 Dec 2025 18:49:50 +0800 Subject: [PATCH] =?UTF-8?q?fix:=20=E4=BF=AE=E5=A4=8D=E8=AF=84=E4=BC=B0?= =?UTF-8?q?=E6=97=B6=E6=A8=A1=E5=9E=8B=E8=BE=93=E5=87=BAjson=E6=A0=BC?= =?UTF-8?q?=E5=BC=8F=E4=B8=8D=E5=AF=B9=E5=AF=BC=E8=87=B4=E8=AF=BB=E5=8F=96?= =?UTF-8?q?=E9=94=99=E8=AF=AF=E7=9A=84=E9=97=AE=E9=A2=98=20(#133)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * feature: add cot data evaluation function * fix: added verification to evaluation results * fix: fix the prompt for evaluating * fix: 修复当评估结果为空导致读取失败的问题 --- .../DataEvaluation/Create/CreateTask.tsx | 76 ++++++++--- .../Detail/components/EvaluationItems.tsx | 122 ++++++++---------- .../DataEvaluation/Home/DataEvaluation.tsx | 2 +- .../app/db/models/data_evaluation.py | 1 + .../app/module/dataset/service/service.py | 91 ++++++------- .../module/evaluation/interface/evaluation.py | 16 ++- .../module/evaluation/schema/evaluation.py | 1 + .../app/module/evaluation/schema/prompt.py | 76 ++++++++++- .../evaluation/schema/prompt_template.py | 2 +- .../module/evaluation/service/evaluation.py | 51 +++++--- .../generation/service/generation_service.py | 42 +----- .../app/module/shared/util/model_chat.py | 38 ++++++ .../app/module/shared/util/structured_file.py | 72 +++++++---- scripts/db/data-evaluation-init.sql | 2 +- 14 files changed, 373 insertions(+), 219 deletions(-) diff --git a/frontend/src/pages/DataEvaluation/Create/CreateTask.tsx b/frontend/src/pages/DataEvaluation/Create/CreateTask.tsx index 4ad3823..4a52fb1 100644 --- a/frontend/src/pages/DataEvaluation/Create/CreateTask.tsx +++ b/frontend/src/pages/DataEvaluation/Create/CreateTask.tsx @@ -1,3 +1,4 @@ +// TypeScript import React, { useState, useEffect } from 'react'; import { Button, Form, Input, Select, message, Modal, Row, Col, Table, Space } from 'antd'; import { EyeOutlined } from '@ant-design/icons'; @@ -36,6 +37,7 @@ interface CreateTaskModalProps { const TASK_TYPES = [ { label: 'QA评估', value: 'QA' }, + { label: 'COT评估', value: 'COT' }, ]; const EVAL_METHODS = [ @@ -55,7 +57,7 @@ const CreateTaskModal: React.FC = ({ visible, onCancel, on dimension: '', description: '' }); - const [taskType, setTaskType] = useState("QA"); + const [taskType, setTaskType] = useState(DEFAULT_TASK_TYPE); const [promptTemplates, setPromptTemplates] = useState([]); const [previewVisible, setPreviewVisible] = useState(false); const [evaluationPrompt, setEvaluationPrompt] = useState(''); @@ -82,9 +84,24 @@ const CreateTaskModal: React.FC = ({ visible, onCancel, on fetchDatasets().then(); fetchModels().then(); fetchPromptTemplates().then(); + // sync form with local taskType default + form.setFieldsValue({ taskType: DEFAULT_TASK_TYPE }); } }, [visible]); + // when promptTemplates or taskType change, switch dimensions to template defaults (COT/QA) + useEffect(() => { + if (!promptTemplates || promptTemplates.length === 0) return; + const template = promptTemplates.find(t => t.evalType === taskType); + if (template && template.defaultDimensions) { + setDimensions(template.defaultDimensions.map((dim: any, index: number) => ({ + key: `dim-${index}`, + dimension: dim.dimension, + description: dim.description + }))); + } + }, [taskType, promptTemplates]); + const fetchDatasets = async () => { try { const { data } = await queryDatasetsUsingGet({ page: 1, size: 1000 }); @@ -106,31 +123,46 @@ const CreateTaskModal: React.FC = ({ visible, onCancel, on }; const formatDimensionsForPrompt = (dimensions: Dimension[]) => { - let result = "\n"; + let result = ""; dimensions.forEach((dim, index) => { - result += `### ${index + 1}. ${dim.dimension}\n**评估标准:**\n${dim.description}\n\n`; + if (index > 0) { + result += "\n"; + } + result += `### ${index + 1}. ${dim.dimension}\n**评估标准:**\n${dim.description}`; + if (index < dimensions.length - 1) { + result += "\n"; + } }); return result; }; const formatResultExample = (dimensions: Dimension[]) => { - return dimensions.map(dim => `\n "${dim.dimension}": "Y",`).join(''); + let result = ""; + dimensions.forEach((dim, index) => { + if (index > 0) { + result += "\n "; + } + result += `"${dim.dimension}": "Y"`; + if (index < dimensions.length - 1) { + result += ","; + } + }); + return result; }; const fetchPromptTemplates = async () => { try { const response = await queryPromptTemplatesUsingGet(); - const templates: PromptTemplate[] = response.data?.templates - setPromptTemplates(templates) - if (taskType) { - const template = templates.find(t => t.evalType === taskType); - if (template) { - setDimensions(template.defaultDimensions.map((dim: any, index: number) => ({ - key: `dim-${index}`, - dimension: dim.dimension, - description: dim.description - }))); - } + const templates: PromptTemplate[] = response.data?.templates || []; + setPromptTemplates(templates); + // if a template exists for current taskType, initialize dimensions (handled also by useEffect) + const template = templates.find(t => t.evalType === taskType); + if (template) { + setDimensions(template.defaultDimensions.map((dim: any, index: number) => ({ + key: `dim-${index}`, + dimension: dim.dimension, + description: dim.description + }))); } } catch (error) { console.error('Error fetching prompt templates:', error); @@ -144,8 +176,11 @@ const CreateTaskModal: React.FC = ({ visible, onCancel, on return; } const template = promptTemplates.find(t => t.evalType === taskType); - setEvaluationPrompt(template?.prompt.replace("{dimensions}", formatDimensionsForPrompt(dimensions)) - .replace('{result_example}', formatResultExample(dimensions))); + const basePrompt = template?.prompt || ''; + const filled = basePrompt + .replace('{dimensions}', formatDimensionsForPrompt(dimensions)) + .replace('{result_example}', formatResultExample(dimensions)); + setEvaluationPrompt(filled); setPreviewVisible(true); }; @@ -243,6 +278,13 @@ const CreateTaskModal: React.FC = ({ visible, onCancel, on evalMethod: DEFAULT_EVAL_METHOD, taskType: DEFAULT_TASK_TYPE, }} + onValuesChange={(changed) => { + if (changed.taskType) { + setTaskType(changed.taskType); + setEvaluationPrompt(''); + setPreviewVisible(false); + } + }} > diff --git a/frontend/src/pages/DataEvaluation/Detail/components/EvaluationItems.tsx b/frontend/src/pages/DataEvaluation/Detail/components/EvaluationItems.tsx index dd32a91..ad00af0 100644 --- a/frontend/src/pages/DataEvaluation/Detail/components/EvaluationItems.tsx +++ b/frontend/src/pages/DataEvaluation/Detail/components/EvaluationItems.tsx @@ -1,7 +1,8 @@ import { useEffect, useState } from 'react'; -import { Table, Typography, Button, Space, Spin, Empty, message, Tooltip } from 'antd'; +import { Table, Typography, Button, Space, Empty, Tooltip } from 'antd'; import { FolderOpen, FileText, ArrowLeft } from 'lucide-react'; import { queryEvaluationFilesUsingGet, queryEvaluationItemsUsingGet } from '../../evaluation.api'; +import useFetchData from '@/hooks/useFetchData'; const { Text } = Typography; @@ -39,63 +40,52 @@ type EvalItem = { }; export default function EvaluationItems({ task }: { task: any }) { - const [loadingFiles, setLoadingFiles] = useState(false); - const [files, setFiles] = useState([]); - const [filePagination, setFilePagination] = useState({ current: 1, pageSize: 10, total: 0 }); - const [selectedFile, setSelectedFile] = useState<{ fileId: string; fileName: string } | null>(null); - const [loadingItems, setLoadingItems] = useState(false); - const [items, setItems] = useState([]); - const [itemPagination, setItemPagination] = useState({ current: 1, pageSize: 10, total: 0 }); - // Fetch files list - useEffect(() => { - if (!task?.id || selectedFile) return; - const fetchFiles = async () => { - setLoadingFiles(true); - try { - const res = await queryEvaluationFilesUsingGet({ taskId: task.id, page: filePagination.current, size: filePagination.pageSize }); - const data = res?.data; - const list: EvalFile[] = data?.content || []; - setFiles(list); - setFilePagination((p) => ({ ...p, total: data?.totalElements || 0 })); - } catch (e) { - message.error('加载评估文件失败'); - console.error(e); - } finally { - setLoadingFiles(false); - } - }; - fetchFiles(); - // eslint-disable-next-line react-hooks/exhaustive-deps - }, [task?.id, filePagination.current, filePagination.pageSize, selectedFile]); + // 文件列表数据(使用 useFetchData),pageOffset=0 表示后端分页为 1 基 + const { + loading: loadingFiles, + tableData: files, + pagination: filePagination, + setSearchParams: setFileSearchParams, + } = useFetchData( + (params) => queryEvaluationFilesUsingGet({ taskId: task?.id, ...params }), + (d) => d as unknown as EvalFile, + 30000, + false, + [], + 0 + ); - // Fetch items of selected file - useEffect(() => { - if (!task?.id || !selectedFile) return; - const fetchItems = async () => { - setLoadingItems(true); - try { - const res = await queryEvaluationItemsUsingGet({ - taskId: task.id, - page: itemPagination.current, - size: itemPagination.pageSize, - file_id: selectedFile.fileId, - }); - const data = res?.data; - const list: EvalItem[] = data?.content || []; - setItems(list); - setItemPagination((p) => ({ ...p, total: data?.totalElements || 0 })); - } catch (e) { - message.error('加载评估条目失败'); - console.error(e); - } finally { - setLoadingItems(false); + // 评估条目数据(使用 useFetchData),依赖选中文件 + const { + loading: loadingItems, + tableData: items, + pagination: itemPagination, + setSearchParams: setItemSearchParams, + fetchData: fetchItems, + } = useFetchData( + (params) => { + if (!task?.id || !selectedFile?.fileId) { + return Promise.resolve({ data: { content: [], totalElements: 0 } }); } - }; - fetchItems(); - // eslint-disable-next-line react-hooks/exhaustive-deps - }, [task?.id, selectedFile?.fileId, itemPagination.current, itemPagination.pageSize]); + return queryEvaluationItemsUsingGet({ taskId: task.id, file_id: selectedFile.fileId, ...params }); + }, + (d) => d as unknown as EvalItem, + 30000, + false, + [], + 0 + ); + + // 当选择文件变化时,主动触发一次条目查询,避免仅依赖 searchParams 变更导致未触发 + useEffect(() => { + if (task?.id && selectedFile?.fileId) { + setItemSearchParams((prev: any) => ({ ...prev, current: 1 })); + // 立即拉取一次,保证点击后立刻出现数据 + fetchItems(); + } + }, [task?.id, selectedFile?.fileId]); const fileColumns = [ { @@ -228,19 +218,20 @@ export default function EvaluationItems({ task }: { task: any }) { dataSource={files} loading={loadingFiles} size="middle" - onRow={(record) => ({ onClick: () => setSelectedFile({ fileId: record.fileId, fileName: record.fileName }) })} - pagination={{ - current: filePagination.current, - pageSize: filePagination.pageSize, - total: filePagination.total, - onChange: (current, pageSize) => setFilePagination({ current, pageSize, total: filePagination.total }), - }} + onRow={(record) => ({ + onClick: () => { + setSelectedFile({ fileId: record.fileId, fileName: record.fileName }); + // 切换文件时,重置条目表到第一页 + setItemSearchParams((prev: any) => ({ ...prev, current: 1 })); + }, + })} + pagination={filePagination} /> ) : (
- @@ -257,12 +248,7 @@ export default function EvaluationItems({ task }: { task: any }) { dataSource={items} loading={loadingItems} size="middle" - pagination={{ - current: itemPagination.current, - pageSize: itemPagination.pageSize, - total: itemPagination.total, - onChange: (current, pageSize) => setItemPagination({ current, pageSize, total: itemPagination.total }), - }} + pagination={itemPagination} />
)} diff --git a/frontend/src/pages/DataEvaluation/Home/DataEvaluation.tsx b/frontend/src/pages/DataEvaluation/Home/DataEvaluation.tsx index 6e99f4f..7327608 100644 --- a/frontend/src/pages/DataEvaluation/Home/DataEvaluation.tsx +++ b/frontend/src/pages/DataEvaluation/Home/DataEvaluation.tsx @@ -82,6 +82,7 @@ export default function DataEvaluationPage() { label: '任务类型', options: [ { value: 'QA', label: 'QA评估' }, + { value: 'COT', label: 'COPT评估' }, ], }, { @@ -89,7 +90,6 @@ export default function DataEvaluationPage() { label: '评估方式', options: [ { value: 'AUTO', label: '自动评估' }, - { value: 'MANUAL', label: '人工评估' }, ], }, ]; diff --git a/runtime/datamate-python/app/db/models/data_evaluation.py b/runtime/datamate-python/app/db/models/data_evaluation.py index 5234064..d765187 100644 --- a/runtime/datamate-python/app/db/models/data_evaluation.py +++ b/runtime/datamate-python/app/db/models/data_evaluation.py @@ -32,6 +32,7 @@ class EvaluationTask(Base): source_id = Column(String(36), nullable=True, comment="待评估对象ID") source_name = Column(String(255), nullable=True, comment="待评估对象名称") status = Column(String(50), server_default="PENDING", nullable=False, comment="状态:PENDING/RUNNING/COMPLETED/STOPPED/FAILED") + eval_method = Column(String(50), server_default="AUTO", nullable=False, comment="评估方式:AUTO/MANUAL") eval_process = Column(Float, nullable=False, server_default="0", comment="评估进度") eval_prompt = Column(Text, nullable=True, comment="评估提示词") eval_config = Column(Text, nullable=True, comment="评估配置") diff --git a/runtime/datamate-python/app/module/dataset/service/service.py b/runtime/datamate-python/app/module/dataset/service/service.py index aff2f10..41e2c31 100644 --- a/runtime/datamate-python/app/module/dataset/service/service.py +++ b/runtime/datamate-python/app/module/dataset/service/service.py @@ -1,3 +1,4 @@ +import math from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.future import select from sqlalchemy import func @@ -14,11 +15,11 @@ logger = get_logger(__name__) class Service: """数据管理服务客户端 - 直接访问数据库""" - + def __init__(self, db: AsyncSession): """ 初始化 DM 客户端 - + Args: db: 数据库会话 """ @@ -29,16 +30,16 @@ class Service: """获取数据集详情""" try: logger.debug(f"Getting dataset detail: {dataset_id} ...") - + result = await self.db.execute( select(Dataset).where(Dataset.id == dataset_id) ) dataset = result.scalar_one_or_none() - + if not dataset: logger.error(f"Dataset not found: {dataset_id}") return None - + # 将数据库模型转换为响应模型 # type: ignore 用于忽略 SQLAlchemy 的类型检查问题 return DatasetResponse( @@ -56,11 +57,11 @@ class Service: except Exception as e: logger.error(f"Failed to get dataset {dataset_id}: {e}") return None - + async def get_dataset_files( - self, - dataset_id: str, - page: int = 0, + self, + dataset_id: str, + page: int = 0, size: int = 100, file_type: Optional[str] = None, status: Optional[str] = None @@ -68,16 +69,16 @@ class Service: """获取数据集文件列表""" try: logger.debug(f"Get dataset files: dataset={dataset_id}, page={page}, size={size}") - + # 构建查询 query = select(DatasetFiles).where(DatasetFiles.dataset_id == dataset_id) - + # 添加可选过滤条件 if file_type: query = query.where(DatasetFiles.file_type == file_type) if status: query = query.where(DatasetFiles.status == status) - + # 获取总数 count_query = select(func.count()).select_from(DatasetFiles).where( DatasetFiles.dataset_id == dataset_id @@ -86,15 +87,15 @@ class Service: count_query = count_query.where(DatasetFiles.file_type == file_type) if status: count_query = count_query.where(DatasetFiles.status == status) - + count_result = await self.db.execute(count_query) total = count_result.scalar_one() - + # 分页查询 query = query.offset(page * size).limit(size).order_by(DatasetFiles.created_at.desc()) result = await self.db.execute(query) files = result.scalars().all() - + # 转换为响应模型 # type: ignore 用于忽略 SQLAlchemy 的类型检查问题 content = [ @@ -115,9 +116,9 @@ class Service: ) for f in files ] - - total_pages = (total + size - 1) // size if size > 0 else 0 - + + total_pages = math.ceil(total / size) if total > 0 else 0 + return PagedDatasetFileResponse( content=content, totalElements=total, @@ -128,7 +129,7 @@ class Service: except Exception as e: logger.error(f"Failed to get dataset files for {dataset_id}: {e}") return None - + async def download_file(self, dataset_id: str, file_id: str) -> Optional[bytes]: """ 下载文件内容 @@ -136,7 +137,7 @@ class Service: """ logger.warning(f"download_file is deprecated when using database mode. Use get_file_download_url instead.") return None - + async def get_file_download_url(self, dataset_id: str, file_id: str) -> Optional[str]: """获取文件下载URL(或文件路径)""" try: @@ -147,60 +148,60 @@ class Service: ) ) file = result.scalar_one_or_none() - + if not file: logger.error(f"File not found: {file_id} in dataset {dataset_id}") return None - + # 返回文件路径(可以是本地路径或对象存储URL) return file.file_path # type: ignore except Exception as e: logger.error(f"Failed to get file path for {file_id}: {e}") return None - + async def close(self): """关闭客户端连接(数据库模式下无需操作)""" logger.info("DM service client closed (Database mode)") - + async def update_file_tags_partial( - self, - file_id: str, + self, + file_id: str, new_tags: List[Dict[str, Any]], template_id: Optional[str] = None ) -> tuple[bool, Optional[str], Optional[datetime]]: """ 部分更新文件标签,支持自动格式转换 - + 如果提供了 template_id,会自动将简化格式的标签转换为完整格式。 简化格式: {"from_name": "x", "to_name": "y", "values": [...]} 完整格式: {"id": "...", "from_name": "x", "to_name": "y", "type": "...", "value": {"type": [...]}} - + Args: file_id: 文件ID new_tags: 新的标签列表(部分更新),可以是简化格式或完整格式 template_id: 可选的模板ID,用于格式转换 - + Returns: (成功标志, 错误信息, 更新时间) """ try: logger.info(f"Partial updating tags for file: {file_id}") - + # 获取文件记录 result = await self.db.execute( select(DatasetFiles).where(DatasetFiles.id == file_id) ) file_record = result.scalar_one_or_none() - + if not file_record: logger.error(f"File not found: {file_id}") return False, f"File not found: {file_id}", None - + # 如果提供了 template_id,尝试进行格式转换 processed_tags = new_tags if template_id: logger.debug(f"Converting tags using template: {template_id}") - + try: # 获取模板配置 from app.db.models import AnnotationTemplate @@ -211,29 +212,29 @@ class Service: ) ) template = template_result.scalar_one_or_none() - + if not template: logger.warning(f"Template {template_id} not found, skipping conversion") else: # 使用 converter 转换标签格式 from app.module.annotation.utils import create_converter_from_template_config - + converter = create_converter_from_template_config(template.configuration) # type: ignore processed_tags = converter.convert_if_needed(new_tags) - + logger.info(f"Converted {len(new_tags)} tags to full format") - + except Exception as e: logger.error(f"Failed to convert tags using template: {e}") # 继续使用原始标签格式 logger.warning("Continuing with original tag format") - + # 获取现有标签 existing_tags: List[Dict[str, Any]] = file_record.tags or [] # type: ignore - + # 创建标签ID到索引的映射 tag_id_map = {tag.get('id'): idx for idx, tag in enumerate(existing_tags) if tag.get('id')} - + # 更新或追加标签 for new_tag in processed_tags: tag_id = new_tag.get('id') @@ -246,19 +247,19 @@ class Service: # 追加新标签 existing_tags.append(new_tag) logger.debug(f"Added new tag with id: {tag_id}") - + # 更新数据库 update_time = datetime.utcnow() file_record.tags = existing_tags # type: ignore file_record.tags_updated_at = update_time # type: ignore - + await self.db.commit() await self.db.refresh(file_record) - + logger.info(f"Successfully updated tags for file: {file_id}") return True, None, update_time - + except Exception as e: logger.error(f"Failed to update tags for file {file_id}: {e}") await self.db.rollback() - return False, str(e), None \ No newline at end of file + return False, str(e), None diff --git a/runtime/datamate-python/app/module/evaluation/interface/evaluation.py b/runtime/datamate-python/app/module/evaluation/interface/evaluation.py index 9b62bca..1de58c9 100644 --- a/runtime/datamate-python/app/module/evaluation/interface/evaluation.py +++ b/runtime/datamate-python/app/module/evaluation/interface/evaluation.py @@ -1,5 +1,6 @@ import asyncio import uuid +import math import json from typing import Optional from fastapi import APIRouter, Depends, HTTPException, Query, BackgroundTasks @@ -171,7 +172,7 @@ async def list_evaluation_tasks( # 转换为响应模型 items = [_map_to_task_detail_response(task) for task in tasks] - total_pages = (total + size - 1) // size if size > 0 else 0 + total_pages = math.ceil(total / size) if total > 0 else 0 return StandardResponse( code=200, @@ -217,7 +218,7 @@ async def list_evaluation_items( count_query = select(func.count()).select_from(query.subquery()) total = (await db.execute(count_query)).scalar_one() files = (await db.execute(query.offset(offset).limit(size))).scalars().all() - total_pages = (total + size - 1) // size if size > 0 else 0 + total_pages = math.ceil(total / size) if total > 0 else 0 file_responses = [ EvaluationFileResponse( taskId=file.task_id, @@ -298,7 +299,7 @@ async def list_evaluation_items( taskId=item.task_id, itemId=item.item_id, fileId=item.file_id, - evalContent=json.loads(item.eval_content), + evalContent=json.loads(item.eval_content) if item.eval_content else None, evalScore=float(item.eval_score) if item.eval_score else None, evalResult=json.loads(item.eval_result), status=item.status @@ -306,7 +307,7 @@ async def list_evaluation_items( for item in items ] - total_pages = (total + size - 1) // size if size > 0 else 0 + total_pages = math.ceil(total / size) if total > 0 else 0 return StandardResponse( code=200, @@ -387,6 +388,12 @@ async def delete_eval_tasks( .where(EvaluationItem.task_id == task_id) ) + # 删除评估文件 + await db.execute( + EvaluationFile.__table__.delete() + .where(EvaluationFile.task_id == task_id) + ) + # 删除任务 await db.delete(task) await db.commit() @@ -419,6 +426,7 @@ def _map_to_task_detail_response( sourceId=task.source_id, sourceName=task.source_name, status=task.status, + evalMethod=task.eval_method, evalProcess=task.eval_process, evalPrompt=task.eval_prompt, evalConfig=json.loads(task.eval_config), diff --git a/runtime/datamate-python/app/module/evaluation/schema/evaluation.py b/runtime/datamate-python/app/module/evaluation/schema/evaluation.py index ffdd41b..e40ab4c 100644 --- a/runtime/datamate-python/app/module/evaluation/schema/evaluation.py +++ b/runtime/datamate-python/app/module/evaluation/schema/evaluation.py @@ -36,6 +36,7 @@ class EvaluationTaskItem(BaseModel): source_id: Optional[str] = Field(..., alias="sourceId", description="数据源ID") source_name: Optional[str] = Field(None, alias="sourceName", description="数据源名称") status: TaskStatus = Field(..., description="任务状态") + eval_method: Optional[str] = Field(None, alias="evalMethod", description="评估方式") eval_process: Optional[float] = Field(0, alias="evalProcess", description="评估进度") created_at: Optional[str] = Field(None, alias="createdAt", description="创建时间") updated_at: Optional[str] = Field(None, alias="updatedAt", description="更新时间") diff --git a/runtime/datamate-python/app/module/evaluation/schema/prompt.py b/runtime/datamate-python/app/module/evaluation/schema/prompt.py index b151d71..8557ac3 100644 --- a/runtime/datamate-python/app/module/evaluation/schema/prompt.py +++ b/runtime/datamate-python/app/module/evaluation/schema/prompt.py @@ -1,3 +1,7 @@ +from app.core.logging import get_logger + +logger = get_logger(__name__) + EVALUATION_PROMPT_TEMPLATE = [ { "evalType": "QA", @@ -51,26 +55,90 @@ EVALUATION_PROMPT_TEMPLATE = [ 请按照以下JSON格式输出评估结果,评估结果为Y/N,符合标注输出Y,不符合标准输出N: { - "result": {{result_example} + "result": { + {result_example} }, "evaluation": "这是一个高质量的问答数据集。问题表述清晰具体,答案准确完整且逻辑性强,与原始文本高度相关。建议:可以进一步丰富答案的细节描述。" } +""" + }, + { + "evalType": "COT", + "defaultDimensions": [ + { + "dimension": "思维链逻辑是否连贯", + "description": "分析思维链中推理链条的连续性:步骤间有明确的逻辑连接词;每一步都是基于前置在步骤的结果;没有逻辑跳跃或断层;推理方向一致,不偏离目标。" + }, + { + "dimension": "推理步骤是否合理必要", + "description": "分析思维链中对于步骤分解的合理性和必要性:复杂问题被适当分解; 每个步骤都是解决整体问题的必要部分;步骤粒度适中(既不过细也不过粗);符合人类认知习惯。" + }, + { + "dimension": "内容是否准确", + "description": "分析整个COT数据内容是否准确:所有陈述的事实必须准确;展示每一步的计算结果(如何涉及数学计算,必须保证数学计算无错误);逻辑推导有效且合理,最终答案与推理过程一致。" + } + ], + "prompt": """ +# Role: COT数据质量评估专家 +## Profile: +- Description: 你是一名专业的Chain-of-Thought(CoT)推理数据质量评估专家,擅长从多个维度对COT数据进行质量评估,挑选出有助于模型学习如何分解问题、展示推理链条,提高模型对于复杂问题解决能力的COT数据。具备深度学习、自然语言处理和数据科学的专业背景。 + +## Skills: +1. 能够从多个维度对COT数据进行综合评估,保证客观、专业、细致 +2. 擅长识别COT数据中的潜在问题,如推包含事实性错误(关键信息错误),存在严重逻辑矛(无法自洽),包含有害、偏见或不当内容,完全偏离主题,抄袭或高度重复内容等 +3. 能够给出具体的改进建议和质量评分,并提供可操作的优化方案 + +## 评估维度: +{dimensions} + +## 问题或指令: +{question} + +## 思维链: +{chain_of_thought} + +## 结论: +{conclusion} + +## 注意事项: +- 评估结论要具体指出优点和不足,提供可操作的改进建议 +- 评估结论控制在150字以内,简洁明了但要涵盖关键信息 + +## 输出要求: +请按照以下JSON格式输出评估结果,评估结果为Y/N,符合标注输出Y,不符合标准输出N;将评估结论写到evaluation中: + +{ + "result": { + {result_example} + }, + "evaluation": "这是一个高质量的COT数据。思维链逻辑连贯,推理步骤合理,信息完整。建议:部分表达可以进一步优化,以及个别步骤的过渡可以更加平滑。" +} """ } ] def get_dimensions_for_qa(dimensions: list[dict]) -> str: - dimensions_str = "\n" + dimensions_str = "" index = 1 for dimension in dimensions: - dimensions_str += f"### {index}. {dimension.get("dimension")}\n**评估标准:**\n{dimension.get("description")}\n\n" + if index > 1: + dimensions_str += "\n" + dimensions_str += f"### {index}. {dimension.get("dimension")}\n**评估标准:**\n{dimension.get("description")}" + if index < len(dimensions): + dimensions_str += "\n" index += 1 return dimensions_str def get_result_example_for_qa(dimensions: list[dict]) -> str: result_example = "" + index = 1 for dimension in dimensions: - result_example += f'\n "{dimension.get("dimension")}": "Y",' + if index > 1: + result_example += "\n " + result_example += f'"{dimension.get("dimension")}": "Y"' + if index < len(dimensions): + result_example += "," + index += 1 return result_example def get_prompt(task_type: str, dimensions: list[dict]) -> str: diff --git a/runtime/datamate-python/app/module/evaluation/schema/prompt_template.py b/runtime/datamate-python/app/module/evaluation/schema/prompt_template.py index e25166a..66c7a5b 100644 --- a/runtime/datamate-python/app/module/evaluation/schema/prompt_template.py +++ b/runtime/datamate-python/app/module/evaluation/schema/prompt_template.py @@ -1,7 +1,7 @@ """ Schema for evaluation prompt templates. """ -from typing import List, Dict, Any +from typing import List from pydantic import BaseModel, Field diff --git a/runtime/datamate-python/app/module/evaluation/service/evaluation.py b/runtime/datamate-python/app/module/evaluation/service/evaluation.py index d187ae7..7af279f 100644 --- a/runtime/datamate-python/app/module/evaluation/service/evaluation.py +++ b/runtime/datamate-python/app/module/evaluation/service/evaluation.py @@ -2,7 +2,7 @@ import json import uuid import asyncio -from sqlalchemy import select +from sqlalchemy import select, func from sqlalchemy.ext.asyncio import AsyncSession from app.core.exception import BusinessErrorCodeEnum, BusinessException @@ -13,7 +13,7 @@ from app.db.models.data_synthesis import DataSynthesisFileInstance, SynthesisDat from app.db.session import AsyncSessionLocal from app.module.evaluation.schema.evaluation import SourceType from app.module.shared.schema import TaskStatus -from app.module.shared.util.model_chat import call_openai_style_model +from app.module.shared.util.model_chat import call_openai_style_model, _extract_json_substring from app.module.evaluation.schema.prompt import get_prompt from app.module.shared.util.structured_file import StructuredFileHandlerFactory from app.module.system.service.common_service import get_model_by_id @@ -35,6 +35,10 @@ class EvaluationExecutor: prompt_text = ((prompt_text.replace("{content}", eval_content.get("input")) .replace("{question}", eval_content.get("instruction"))) .replace("{answer}", eval_content.get("output"))) + if self.task.task_type == "COT": + prompt_text = ((prompt_text.replace("{question}", eval_content.get("question")) + .replace("{conclusion}", eval_content.get("conclusion"))) + .replace("{chain_of_thought}", eval_content.get("chain_of_thought"))) return prompt_text async def execute(self): @@ -44,29 +48,44 @@ class EvaluationExecutor: files = (await self.db.execute( select(EvaluationFile).where(EvaluationFile.task_id == self.task.id) )).scalars().all() + query = select(EvaluationItem).where(EvaluationItem.task_id == self.task.id) + count_query = select(func.count()).select_from(query.subquery()) + total = (await self.db.execute(count_query)).scalar_one() + evaluated_count = 0 for file in files: - items = (await self.db.execute( - select(EvaluationItem).where(EvaluationItem.task_id == self.task.id) - .where(EvaluationItem.file_id == file.file_id) - )).scalars().all() + items = (await self.db.execute(query.where(EvaluationItem.file_id == file.file_id))).scalars().all() tasks = [ self.evaluate_item(model_config, item, semaphore) for item in items ] await asyncio.gather(*tasks, return_exceptions=True) file.evaluated_count = len(items) + evaluated_count += file.evaluated_count + self.task.eval_process = evaluated_count / total await self.db.commit() async def evaluate_item(self, model_config, item: EvaluationItem, semaphore: asyncio.Semaphore): async with semaphore: - prompt_text = self.get_eval_prompt(item) - resp_text = await asyncio.to_thread( - call_openai_style_model, model_config.base_url, model_config.api_key, model_config.model_name, - prompt_text, - ) - item.eval_result = resp_text - item.status = TaskStatus.COMPLETED.value - await self.db.commit() + max_try = 3 + while max_try > 0: + prompt_text = self.get_eval_prompt(item) + resp_text = await asyncio.to_thread( + call_openai_style_model, model_config.base_url, model_config.api_key, model_config.model_name, + prompt_text, + ) + resp_text = _extract_json_substring(resp_text) + try: + json.loads(resp_text) + except Exception as e: + logger.error( + f"Failed to parse LLM answer as JSON for task={self.task.id}, file={item.file_id}: {e}. Raw answer: {resp_text!r}" + ) + max_try -= 1 + continue + item.eval_result = resp_text + item.status = TaskStatus.COMPLETED.value + await self.db.commit() + return def get_source_type(self) -> SourceType: @@ -119,7 +138,7 @@ class SynthesisEvaluationExecutor(EvaluationExecutor): async def save_eval_items(self): synthesis_files = ((await self.db.execute(select(DataSynthesisFileInstance) - .where(DataSynthesisFileInstance.task_id == self.task.source_id))) + .where(DataSynthesisFileInstance.synthesis_instance_id == self.task.source_id))) .scalars().all()) for synthesis_file in synthesis_files: synthesis_datas = ((await self.db.execute(select(SynthesisData) @@ -132,7 +151,7 @@ class SynthesisEvaluationExecutor(EvaluationExecutor): task_id=self.task.id, file_id=synthesis_file.id, item_id=synthesis_data.id, - eval_content=synthesis_data.data, + eval_content=json.dumps(synthesis_data.data), status=TaskStatus.PENDING.value, created_by=self.task.created_by, updated_by=self.task.updated_by, diff --git a/runtime/datamate-python/app/module/generation/service/generation_service.py b/runtime/datamate-python/app/module/generation/service/generation_service.py index 123ec3f..1c9996b 100644 --- a/runtime/datamate-python/app/module/generation/service/generation_service.py +++ b/runtime/datamate-python/app/module/generation/service/generation_service.py @@ -28,6 +28,7 @@ from app.db.models.data_synthesis import ( from app.db.models.dataset_management import DatasetFiles from app.db.models.model_config import get_model_by_id from app.db.session import logger +from app.module.shared.util.model_chat import _extract_json_substring from app.module.system.service.common_service import get_chat_client, chat @@ -365,7 +366,7 @@ class GenerationService: return # 1. 预处理原始回答:尝试从中截取出最可能的 JSON 片段 - cleaned = self._extract_json_substring(raw_answer) + cleaned = _extract_json_substring(raw_answer) # 2. 解析 JSON,统一成列表结构 try: @@ -426,45 +427,6 @@ class GenerationService: await self.db.commit() await self.db.refresh(file_instance) - @staticmethod - def _extract_json_substring(raw: str) -> str: - """从 LLM 的原始回答中提取最可能的 JSON 字符串片段。 - - 处理思路: - - 原始回答可能是:说明文字 + JSON + 说明文字,甚至带有 Markdown 代码块。 - - 优先在文本中查找第一个 '{' 或 '[' 作为 JSON 起始; - - 再从后向前找最后一个 '}' 或 ']' 作为结束; - - 如果找不到合适的边界,就退回原始字符串。 - 该方法不会保证截取的一定是合法 JSON,但能显著提高 json.loads 的成功率。 - """ - if not raw: - return raw - - start = None - end = None - - # 查找第一个 JSON 起始符号 - for i, ch in enumerate(raw): - if ch in "[{": - start = i - break - - # 查找最后一个 JSON 结束符号 - for i in range(len(raw) - 1, -1, -1): - if raw[i] in "]}": - end = i + 1 # 切片是左闭右开 - break - - if start is not None and end is not None and start < end: - return raw[start:end].strip() - - # 兜底:去掉常见 Markdown 包裹(```json ... ```) - stripped = raw.strip() - if stripped.startswith("```"): - # 去掉首尾 ``` 标记 - stripped = stripped.strip("`") - return stripped - async def _get_or_create_file_instance( self, synthesis_task_id: str, diff --git a/runtime/datamate-python/app/module/shared/util/model_chat.py b/runtime/datamate-python/app/module/shared/util/model_chat.py index 7b8fcbc..29161de 100644 --- a/runtime/datamate-python/app/module/shared/util/model_chat.py +++ b/runtime/datamate-python/app/module/shared/util/model_chat.py @@ -13,3 +13,41 @@ def call_openai_style_model(base_url, api_key, model_name, prompt, **kwargs): **kwargs ) return response.choices[0].message.content + +def _extract_json_substring(raw: str) -> str: + """从 LLM 的原始回答中提取最可能的 JSON 字符串片段。 + + 处理思路: + - 原始回答可能是:说明文字 + JSON + 说明文字,甚至带有 Markdown 代码块。 + - 优先在文本中查找第一个 '{' 或 '[' 作为 JSON 起始; + - 再从后向前找最后一个 '}' 或 ']' 作为结束; + - 如果找不到合适的边界,就退回原始字符串。 + 该方法不会保证截取的一定是合法 JSON,但能显著提高 json.loads 的成功率。 + """ + if not raw: + return raw + + start = None + end = None + + # 查找第一个 JSON 起始符号 + for i, ch in enumerate(raw): + if ch in "[{": + start = i + break + + # 查找最后一个 JSON 结束符号 + for i in range(len(raw) - 1, -1, -1): + if raw[i] in "]}": + end = i + 1 # 切片是左闭右开 + break + + if start is not None and end is not None and start < end: + return raw[start:end].strip() + + # 兜底:去掉常见 Markdown 包裹(```json ... ```) + stripped = raw.strip() + if stripped.startswith("```"): + # 去掉首尾 ``` 标记 + stripped = stripped.strip("`") + return stripped diff --git a/runtime/datamate-python/app/module/shared/util/structured_file.py b/runtime/datamate-python/app/module/shared/util/structured_file.py index dc990b5..040c2ba 100644 --- a/runtime/datamate-python/app/module/shared/util/structured_file.py +++ b/runtime/datamate-python/app/module/shared/util/structured_file.py @@ -5,6 +5,7 @@ from jsonschema import validate class ItemTypes(Enum): QA = "QA" + COT = "COT" class StructuredFileItemHandler: @@ -14,11 +15,26 @@ class StructuredFileItemHandler: def get_item_type(self) -> ItemTypes: pass - def get_items_from_file(self, file_path: str) -> list[dict]: + def validate_json(self, data): pass - def check_file(self) -> bool: - pass + def get_items_from_file(self, file_path: str) -> list[dict]: + file_type = file_path.split(".")[-1].upper() + items = [] + if file_type == "JSON": + with open(file_path, "r", encoding="utf-8") as f: + data = json.load(f) + if not self.validate_json(data): + return items + items = data + elif file_type == "JSONL": + with open(file_path, "r", encoding="utf-8") as f: + for line in f: + data = json.loads(line) + if not self.validate_json(data): + continue + items.append(data) + return items class QAItemHandler(StructuredFileItemHandler): def __init__(self): @@ -51,32 +67,44 @@ class QAItemHandler(StructuredFileItemHandler): except Exception as e: return False - def get_items_from_file(self, file_path: str) -> list[dict]: - file_type = file_path.split(".")[-1].upper() - items = [] - if file_type == "JSON": - with open(file_path, "r", encoding="utf-8") as f: - data = json.load(f) - if not self.validate_json(data): - return items - items = data - elif file_type == "JSONL": - with open(file_path, "r", encoding="utf-8") as f: - for line in f: - data = json.loads(line) - if not self.validate_json(data): - continue - items.append(data) - return items - def check_file(self) -> bool: - pass +class COTItemHandler(StructuredFileItemHandler): + def __init__(self): + self.schema = { + "type": "object", + "properties": { + "question": {"type": "string"}, + "conclusion": {"type": "string"}, + "chain_of_thought": {"type": "string"} + }, + "required": ["question", "conclusion", "chain_of_thought"], + } + self.schema_list = { + "type": "array", + "items": self.schema, + } + super().__init__() + + def get_item_type(self): + return ItemTypes.COT + + def validate_json(self, data): + try: + validate(instance=data, schema=self.schema) + return True + except Exception as e: + try: + validate(instance=data, schema=self.schema_list) + return True + except Exception as e: + return False class StructuredFileHandlerFactory: def __init__(self): self.handlers: list[StructuredFileItemHandler] = [] self.handlers.append(QAItemHandler()) + self.handlers.append(COTItemHandler()) def get_handler(self, item_type: str) -> StructuredFileItemHandler: for handler in self.handlers: diff --git a/scripts/db/data-evaluation-init.sql b/scripts/db/data-evaluation-init.sql index a2a1e72..2118de9 100644 --- a/scripts/db/data-evaluation-init.sql +++ b/scripts/db/data-evaluation-init.sql @@ -12,7 +12,7 @@ CREATE TABLE IF NOT EXISTS t_de_eval_task ( source_id VARCHAR(36) COMMENT '待评估对象ID', source_name VARCHAR(255) COMMENT '待评估对象名称', status VARCHAR(50) DEFAULT 'PENDING' COMMENT '状态:PENDING/RUNNING/COMPLETED/STOPPED/FAILED', - eval_method VARCHAR(50) DEFAULT 'AUTO' COMMENT '状态:AUTO/MANUAL', + eval_method VARCHAR(50) DEFAULT 'AUTO' COMMENT '评估方式:AUTO/MANUAL', eval_process DOUBLE PRECISION NOT NULL DEFAULT 0 COMMENT '评估进度', eval_prompt TEXT COMMENT '评估提示词', eval_config TEXT COMMENT '评估配置',