fix: fix read errors during evaluation caused by malformed JSON in the model output (#133)

* feature: add cot data evaluation function

* fix: added verification to evaluation results

* fix: fix the prompt for evaluating

* fix: fix read failures when the evaluation result is empty
Author: hefanli
Date: 2025-12-04 18:49:50 +08:00
Committed by: GitHub
Parent: 31c4966608
Commit: 744d15ba24
14 changed files with 373 additions and 219 deletions
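Note: the substance of this fix, spread across the backend diffs below, is twofold: the raw LLM reply is first trimmed down to its most likely JSON fragment (_extract_json_substring, now shared from model_chat.py), and the evaluation executor retries the model call up to three times when the trimmed reply still fails to parse. A condensed, synchronous sketch of that flow follows; evaluate_once_with_retry is a hypothetical wrapper written for illustration, not code from this commit (the real executor runs async under a semaphore and persists the result on the item).

import json
from app.module.shared.util.model_chat import call_openai_style_model, _extract_json_substring

def evaluate_once_with_retry(base_url, api_key, model_name, prompt, max_try=3):
    # Mirrors the executor's new loop: trim the reply to its JSON payload,
    # validate it with json.loads, and retry on parse failure.
    while max_try > 0:
        raw = call_openai_style_model(base_url, api_key, model_name, prompt)
        cleaned = _extract_json_substring(raw)
        try:
            json.loads(cleaned)   # only well-formed JSON is accepted
            return cleaned        # what the executor stores into item.eval_result
        except Exception:
            max_try -= 1          # malformed output: ask the model again
    return None                   # after the retries the item is left un-completed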

View File

@@ -1,3 +1,4 @@
import React, { useState, useEffect } from 'react';
import { Button, Form, Input, Select, message, Modal, Row, Col, Table, Space } from 'antd';
import { EyeOutlined } from '@ant-design/icons';
@@ -36,6 +37,7 @@ interface CreateTaskModalProps {
const TASK_TYPES = [
  { label: 'QA评估', value: 'QA' },
+ { label: 'COT评估', value: 'COT' },
];

const EVAL_METHODS = [
@@ -55,7 +57,7 @@ const CreateTaskModal: React.FC<CreateTaskModalProps> = ({ visible, onCancel, on
    dimension: '',
    description: ''
  });
- const [taskType, setTaskType] = useState<string>("QA");
+ const [taskType, setTaskType] = useState<string>(DEFAULT_TASK_TYPE);
  const [promptTemplates, setPromptTemplates] = useState<PromptTemplate[]>([]);
  const [previewVisible, setPreviewVisible] = useState(false);
  const [evaluationPrompt, setEvaluationPrompt] = useState('');
@@ -82,9 +84,24 @@ const CreateTaskModal: React.FC<CreateTaskModalProps> = ({ visible, onCancel, on
      fetchDatasets().then();
      fetchModels().then();
      fetchPromptTemplates().then();
+     // sync form with local taskType default
+     form.setFieldsValue({ taskType: DEFAULT_TASK_TYPE });
    }
  }, [visible]);

+ // when promptTemplates or taskType change, switch dimensions to template defaults (COT/QA)
+ useEffect(() => {
+   if (!promptTemplates || promptTemplates.length === 0) return;
+   const template = promptTemplates.find(t => t.evalType === taskType);
+   if (template && template.defaultDimensions) {
+     setDimensions(template.defaultDimensions.map((dim: any, index: number) => ({
+       key: `dim-${index}`,
+       dimension: dim.dimension,
+       description: dim.description
+     })));
+   }
+ }, [taskType, promptTemplates]);

  const fetchDatasets = async () => {
    try {
      const { data } = await queryDatasetsUsingGet({ page: 1, size: 1000 });
@@ -106,31 +123,46 @@ const CreateTaskModal: React.FC<CreateTaskModalProps> = ({ visible, onCancel, on
  };

  const formatDimensionsForPrompt = (dimensions: Dimension[]) => {
-   let result = "\n";
+   let result = "";
    dimensions.forEach((dim, index) => {
-     result += `### ${index + 1}. ${dim.dimension}\n**评估标准:**\n${dim.description}\n\n`;
+     if (index > 0) {
+       result += "\n";
+     }
+     result += `### ${index + 1}. ${dim.dimension}\n**评估标准:**\n${dim.description}`;
+     if (index < dimensions.length - 1) {
+       result += "\n";
+     }
    });
    return result;
  };

  const formatResultExample = (dimensions: Dimension[]) => {
-   return dimensions.map(dim => `\n "${dim.dimension}": "Y",`).join('');
+   let result = "";
+   dimensions.forEach((dim, index) => {
+     if (index > 0) {
+       result += "\n ";
+     }
+     result += `"${dim.dimension}": "Y"`;
+     if (index < dimensions.length - 1) {
+       result += ",";
+     }
+   });
+   return result;
  };

  const fetchPromptTemplates = async () => {
    try {
      const response = await queryPromptTemplatesUsingGet();
-     const templates: PromptTemplate[] = response.data?.templates
-     setPromptTemplates(templates)
-     if (taskType) {
+     const templates: PromptTemplate[] = response.data?.templates || [];
+     setPromptTemplates(templates);
+     // if a template exists for current taskType, initialize dimensions (handled also by useEffect)
      const template = templates.find(t => t.evalType === taskType);
      if (template) {
        setDimensions(template.defaultDimensions.map((dim: any, index: number) => ({
          key: `dim-${index}`,
          dimension: dim.dimension,
          description: dim.description
        })));
      }
-     }
    } catch (error) {
      console.error('Error fetching prompt templates:', error);
@@ -144,8 +176,11 @@ const CreateTaskModal: React.FC<CreateTaskModalProps> = ({ visible, onCancel, on
      return;
    }
    const template = promptTemplates.find(t => t.evalType === taskType);
-   setEvaluationPrompt(template?.prompt.replace("{dimensions}", formatDimensionsForPrompt(dimensions))
-     .replace('{result_example}', formatResultExample(dimensions)));
+   const basePrompt = template?.prompt || '';
+   const filled = basePrompt
+     .replace('{dimensions}', formatDimensionsForPrompt(dimensions))
+     .replace('{result_example}', formatResultExample(dimensions));
+   setEvaluationPrompt(filled);
    setPreviewVisible(true);
  };
@@ -243,6 +278,13 @@ const CreateTaskModal: React.FC<CreateTaskModalProps> = ({ visible, onCancel, on
          evalMethod: DEFAULT_EVAL_METHOD,
          taskType: DEFAULT_TASK_TYPE,
        }}
+       onValuesChange={(changed) => {
+         if (changed.taskType) {
+           setTaskType(changed.taskType);
+           setEvaluationPrompt('');
+           setPreviewVisible(false);
+         }
+       }}
      >
        <Row gutter={16}>
          <Col span={12}>

View File

@@ -1,7 +1,8 @@
import { useEffect, useState } from 'react';
- import { Table, Typography, Button, Space, Spin, Empty, message, Tooltip } from 'antd';
+ import { Table, Typography, Button, Space, Empty, Tooltip } from 'antd';
import { FolderOpen, FileText, ArrowLeft } from 'lucide-react';
import { queryEvaluationFilesUsingGet, queryEvaluationItemsUsingGet } from '../../evaluation.api';
+ import useFetchData from '@/hooks/useFetchData';

const { Text } = Typography;
@@ -39,63 +40,52 @@ type EvalItem = {
};

export default function EvaluationItems({ task }: { task: any }) {
- const [loadingFiles, setLoadingFiles] = useState<boolean>(false);
- const [files, setFiles] = useState<EvalFile[]>([]);
- const [filePagination, setFilePagination] = useState({ current: 1, pageSize: 10, total: 0 });
  const [selectedFile, setSelectedFile] = useState<{ fileId: string; fileName: string } | null>(null);
- const [loadingItems, setLoadingItems] = useState<boolean>(false);
- const [items, setItems] = useState<EvalItem[]>([]);
- const [itemPagination, setItemPagination] = useState({ current: 1, pageSize: 10, total: 0 });

- // Fetch files list
- useEffect(() => {
-   if (!task?.id || selectedFile) return;
-   const fetchFiles = async () => {
-     setLoadingFiles(true);
-     try {
-       const res = await queryEvaluationFilesUsingGet({ taskId: task.id, page: filePagination.current, size: filePagination.pageSize });
-       const data = res?.data;
-       const list: EvalFile[] = data?.content || [];
-       setFiles(list);
-       setFilePagination((p) => ({ ...p, total: data?.totalElements || 0 }));
-     } catch (e) {
-       message.error('加载评估文件失败');
-       console.error(e);
-     } finally {
-       setLoadingFiles(false);
-     }
-   };
-   fetchFiles();
-   // eslint-disable-next-line react-hooks/exhaustive-deps
- }, [task?.id, filePagination.current, filePagination.pageSize, selectedFile]);
+ // 文件列表数据(使用 useFetchData),pageOffset=0 表示后端分页为 1 基
+ const {
+   loading: loadingFiles,
+   tableData: files,
+   pagination: filePagination,
+   setSearchParams: setFileSearchParams,
+ } = useFetchData<EvalFile>(
+   (params) => queryEvaluationFilesUsingGet({ taskId: task?.id, ...params }),
+   (d) => d as unknown as EvalFile,
+   30000,
+   false,
+   [],
+   0
+ );

- // Fetch items of selected file
- useEffect(() => {
-   if (!task?.id || !selectedFile) return;
-   const fetchItems = async () => {
-     setLoadingItems(true);
-     try {
-       const res = await queryEvaluationItemsUsingGet({
-         taskId: task.id,
-         page: itemPagination.current,
-         size: itemPagination.pageSize,
-         file_id: selectedFile.fileId,
-       });
-       const data = res?.data;
-       const list: EvalItem[] = data?.content || [];
-       setItems(list);
-       setItemPagination((p) => ({ ...p, total: data?.totalElements || 0 }));
-     } catch (e) {
-       message.error('加载评估条目失败');
-       console.error(e);
-     } finally {
-       setLoadingItems(false);
-     }
-   };
-   fetchItems();
-   // eslint-disable-next-line react-hooks/exhaustive-deps
- }, [task?.id, selectedFile?.fileId, itemPagination.current, itemPagination.pageSize]);
+ // 评估条目数据(使用 useFetchData),依赖选中文件
+ const {
+   loading: loadingItems,
+   tableData: items,
+   pagination: itemPagination,
+   setSearchParams: setItemSearchParams,
+   fetchData: fetchItems,
+ } = useFetchData<EvalItem>(
+   (params) => {
+     if (!task?.id || !selectedFile?.fileId) {
+       return Promise.resolve({ data: { content: [], totalElements: 0 } });
+     }
+     return queryEvaluationItemsUsingGet({ taskId: task.id, file_id: selectedFile.fileId, ...params });
+   },
+   (d) => d as unknown as EvalItem,
+   30000,
+   false,
+   [],
+   0
+ );
+
+ // 当选择文件变化时,主动触发一次条目查询,避免仅依赖 searchParams 变更导致未触发
+ useEffect(() => {
+   if (task?.id && selectedFile?.fileId) {
+     setItemSearchParams((prev: any) => ({ ...prev, current: 1 }));
+     // 立即拉取一次,保证点击后立刻出现数据
+     fetchItems();
+   }
+ }, [task?.id, selectedFile?.fileId]);

  const fileColumns = [
    {
@@ -228,19 +218,20 @@ export default function EvaluationItems({ task }: { task: any }) {
          dataSource={files}
          loading={loadingFiles}
          size="middle"
-         onRow={(record) => ({ onClick: () => setSelectedFile({ fileId: record.fileId, fileName: record.fileName }) })}
-         pagination={{
-           current: filePagination.current,
-           pageSize: filePagination.pageSize,
-           total: filePagination.total,
-           onChange: (current, pageSize) => setFilePagination({ current, pageSize, total: filePagination.total }),
-         }}
+         onRow={(record) => ({
+           onClick: () => {
+             setSelectedFile({ fileId: record.fileId, fileName: record.fileName });
+             // 切换文件时,重置条目表到第一页
+             setItemSearchParams((prev: any) => ({ ...prev, current: 1 }));
+           },
+         })}
+         pagination={filePagination}
        />
      ) : (
        <div className="flex flex-col gap-3">
          <div className="sticky top-0 z-10 bg-white py-2" style={{ borderBottom: '1px solid #f0f0f0' }}>
            <Space wrap>
-             <Button icon={<ArrowLeft size={16} />} onClick={() => { setSelectedFile(null); setItems([]); }}>
+             <Button icon={<ArrowLeft size={16} />} onClick={() => { setSelectedFile(null); }}>
              </Button>
              <Space>
@@ -257,12 +248,7 @@ export default function EvaluationItems({ task }: { task: any }) {
            dataSource={items}
            loading={loadingItems}
            size="middle"
-           pagination={{
-             current: itemPagination.current,
-             pageSize: itemPagination.pageSize,
-             total: itemPagination.total,
-             onChange: (current, pageSize) => setItemPagination({ current, pageSize, total: itemPagination.total }),
-           }}
+           pagination={itemPagination}
          />
        </div>
      )}

View File

@@ -82,6 +82,7 @@ export default function DataEvaluationPage() {
      label: '任务类型',
      options: [
        { value: 'QA', label: 'QA评估' },
+       { value: 'COT', label: 'COPT评估' },
      ],
    },
    {
@@ -89,7 +90,6 @@ export default function DataEvaluationPage() {
      label: '评估方式',
      options: [
        { value: 'AUTO', label: '自动评估' },
-       { value: 'MANUAL', label: '人工评估' },
      ],
    },
  ];

View File

@@ -32,6 +32,7 @@ class EvaluationTask(Base):
    source_id = Column(String(36), nullable=True, comment="待评估对象ID")
    source_name = Column(String(255), nullable=True, comment="待评估对象名称")
    status = Column(String(50), server_default="PENDING", nullable=False, comment="状态:PENDING/RUNNING/COMPLETED/STOPPED/FAILED")
+   eval_method = Column(String(50), server_default="AUTO", nullable=False, comment="评估方式:AUTO/MANUAL")
    eval_process = Column(Float, nullable=False, server_default="0", comment="评估进度")
    eval_prompt = Column(Text, nullable=True, comment="评估提示词")
    eval_config = Column(Text, nullable=True, comment="评估配置")

View File

@@ -1,3 +1,4 @@
+ import math
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlalchemy import func
@@ -116,7 +117,7 @@ class Service:
            for f in files
        ]
-       total_pages = (total + size - 1) // size if size > 0 else 0
+       total_pages = math.ceil(total / size) if total > 0 else 0
        return PagedDatasetFileResponse(
            content=content,

View File

@@ -1,5 +1,6 @@
import asyncio
import uuid
+ import math
import json
from typing import Optional
from fastapi import APIRouter, Depends, HTTPException, Query, BackgroundTasks
@@ -171,7 +172,7 @@ async def list_evaluation_tasks(
    # 转换为响应模型
    items = [_map_to_task_detail_response(task) for task in tasks]
-   total_pages = (total + size - 1) // size if size > 0 else 0
+   total_pages = math.ceil(total / size) if total > 0 else 0

    return StandardResponse(
        code=200,
@@ -217,7 +218,7 @@ async def list_evaluation_items(
    count_query = select(func.count()).select_from(query.subquery())
    total = (await db.execute(count_query)).scalar_one()
    files = (await db.execute(query.offset(offset).limit(size))).scalars().all()
-   total_pages = (total + size - 1) // size if size > 0 else 0
+   total_pages = math.ceil(total / size) if total > 0 else 0
    file_responses = [
        EvaluationFileResponse(
            taskId=file.task_id,
@@ -298,7 +299,7 @@ async def list_evaluation_items(
            taskId=item.task_id,
            itemId=item.item_id,
            fileId=item.file_id,
-           evalContent=json.loads(item.eval_content),
+           evalContent=json.loads(item.eval_content) if item.eval_content else None,
            evalScore=float(item.eval_score) if item.eval_score else None,
            evalResult=json.loads(item.eval_result),
            status=item.status
@@ -306,7 +307,7 @@ async def list_evaluation_items(
        for item in items
    ]
-   total_pages = (total + size - 1) // size if size > 0 else 0
+   total_pages = math.ceil(total / size) if total > 0 else 0

    return StandardResponse(
        code=200,
@@ -387,6 +388,12 @@ async def delete_eval_tasks(
        .where(EvaluationItem.task_id == task_id)
    )
+   # 删除评估文件
+   await db.execute(
+       EvaluationFile.__table__.delete()
+       .where(EvaluationFile.task_id == task_id)
+   )
    # 删除任务
    await db.delete(task)
    await db.commit()
@@ -419,6 +426,7 @@ def _map_to_task_detail_response(
        sourceId=task.source_id,
        sourceName=task.source_name,
        status=task.status,
+       evalMethod=task.eval_method,
        evalProcess=task.eval_process,
        evalPrompt=task.eval_prompt,
        evalConfig=json.loads(task.eval_config),

View File

@@ -36,6 +36,7 @@ class EvaluationTaskItem(BaseModel):
    source_id: Optional[str] = Field(..., alias="sourceId", description="数据源ID")
    source_name: Optional[str] = Field(None, alias="sourceName", description="数据源名称")
    status: TaskStatus = Field(..., description="任务状态")
+   eval_method: Optional[str] = Field(None, alias="evalMethod", description="评估方式")
    eval_process: Optional[float] = Field(0, alias="evalProcess", description="评估进度")
    created_at: Optional[str] = Field(None, alias="createdAt", description="创建时间")
    updated_at: Optional[str] = Field(None, alias="updatedAt", description="更新时间")

View File

@@ -1,3 +1,7 @@
+ from app.core.logging import get_logger
+ logger = get_logger(__name__)

EVALUATION_PROMPT_TEMPLATE = [
    {
        "evalType": "QA",
@@ -51,26 +55,90 @@ EVALUATION_PROMPT_TEMPLATE = [
请按照以下JSON格式输出评估结果,评估结果为Y/N,符合标注输出Y,不符合标准输出N:
{
-   "result": {{result_example}
+   "result": {
+       {result_example}
    },
    "evaluation": "这是一个高质量的问答数据集。问题表述清晰具体,答案准确完整且逻辑性强,与原始文本高度相关。建议:可以进一步丰富答案的细节描述。"
}
+ """
+     },
{
"evalType": "COT",
"defaultDimensions": [
{
"dimension": "思维链逻辑是否连贯",
"description": "分析思维链中推理链条的连续性:步骤间有明确的逻辑连接词;每一步都是基于前置在步骤的结果;没有逻辑跳跃或断层;推理方向一致,不偏离目标。"
},
{
"dimension": "推理步骤是否合理必要",
"description": "分析思维链中对于步骤分解的合理性和必要性:复杂问题被适当分解; 每个步骤都是解决整体问题的必要部分;步骤粒度适中(既不过细也不过粗);符合人类认知习惯。"
},
{
"dimension": "内容是否准确",
"description": "分析整个COT数据内容是否准确:所有陈述的事实必须准确;展示每一步的计算结果(如何涉及数学计算,必须保证数学计算无错误);逻辑推导有效且合理,最终答案与推理过程一致。"
}
],
"prompt": """
# Role: COT数据质量评估专家
## Profile:
- Description: 你是一名专业的Chain-of-Thought(CoT)推理数据质量评估专家,擅长从多个维度对COT数据进行质量评估,挑选出有助于模型学习如何分解问题、展示推理链条,提高模型对于复杂问题解决能力的COT数据。具备深度学习、自然语言处理和数据科学的专业背景。
## Skills:
1. 能够从多个维度对COT数据进行综合评估,保证客观、专业、细致
2. 擅长识别COT数据中的潜在问题,如推包含事实性错误(关键信息错误),存在严重逻辑矛(无法自洽),包含有害、偏见或不当内容,完全偏离主题,抄袭或高度重复内容等
3. 能够给出具体的改进建议和质量评分,并提供可操作的优化方案
## 评估维度:
{dimensions}
## 问题或指令:
{question}
## 思维链:
{chain_of_thought}
## 结论:
{conclusion}
## 注意事项:
- 评估结论要具体指出优点和不足,提供可操作的改进建议
- 评估结论控制在150字以内,简洁明了但要涵盖关键信息
## 输出要求:
请按照以下JSON格式输出评估结果,评估结果为Y/N,符合标注输出Y,不符合标准输出N;将评估结论写到evaluation中:
{
"result": {
{result_example}
},
"evaluation": "这是一个高质量的COT数据。思维链逻辑连贯,推理步骤合理,信息完整。建议:部分表达可以进一步优化,以及个别步骤的过渡可以更加平滑。"
}
""" """
} }
] ]
def get_dimensions_for_qa(dimensions: list[dict]) -> str:
-   dimensions_str = "\n"
+   dimensions_str = ""
    index = 1
    for dimension in dimensions:
-       dimensions_str += f"### {index}. {dimension.get("dimension")}\n**评估标准:**\n{dimension.get("description")}\n\n"
+       if index > 1:
+           dimensions_str += "\n"
+       dimensions_str += f"### {index}. {dimension.get("dimension")}\n**评估标准:**\n{dimension.get("description")}"
+       if index < len(dimensions):
+           dimensions_str += "\n"
        index += 1
    return dimensions_str

def get_result_example_for_qa(dimensions: list[dict]) -> str:
    result_example = ""
+   index = 1
    for dimension in dimensions:
-       result_example += f'\n "{dimension.get("dimension")}": "Y",'
+       if index > 1:
+           result_example += "\n "
+       result_example += f'"{dimension.get("dimension")}": "Y"'
+       if index < len(dimensions):
+           result_example += ","
+       index += 1
    return result_example

def get_prompt(task_type: str, dimensions: list[dict]) -> str:

View File

@@ -1,7 +1,7 @@
""" """
Schema for evaluation prompt templates. Schema for evaluation prompt templates.
""" """
from typing import List, Dict, Any from typing import List
from pydantic import BaseModel, Field from pydantic import BaseModel, Field

View File

@@ -2,7 +2,7 @@ import json
import uuid
import asyncio
- from sqlalchemy import select
+ from sqlalchemy import select, func
from sqlalchemy.ext.asyncio import AsyncSession

from app.core.exception import BusinessErrorCodeEnum, BusinessException
@@ -13,7 +13,7 @@ from app.db.models.data_synthesis import DataSynthesisFileInstance, SynthesisDat
from app.db.session import AsyncSessionLocal
from app.module.evaluation.schema.evaluation import SourceType
from app.module.shared.schema import TaskStatus
- from app.module.shared.util.model_chat import call_openai_style_model
+ from app.module.shared.util.model_chat import call_openai_style_model, _extract_json_substring
from app.module.evaluation.schema.prompt import get_prompt
from app.module.shared.util.structured_file import StructuredFileHandlerFactory
from app.module.system.service.common_service import get_model_by_id
@@ -35,6 +35,10 @@ class EvaluationExecutor:
            prompt_text = ((prompt_text.replace("{content}", eval_content.get("input"))
                            .replace("{question}", eval_content.get("instruction")))
                           .replace("{answer}", eval_content.get("output")))
+       if self.task.task_type == "COT":
+           prompt_text = ((prompt_text.replace("{question}", eval_content.get("question"))
+                           .replace("{conclusion}", eval_content.get("conclusion")))
+                          .replace("{chain_of_thought}", eval_content.get("chain_of_thought")))
        return prompt_text

    async def execute(self):
@@ -44,29 +48,44 @@ class EvaluationExecutor:
        files = (await self.db.execute(
            select(EvaluationFile).where(EvaluationFile.task_id == self.task.id)
        )).scalars().all()
+       query = select(EvaluationItem).where(EvaluationItem.task_id == self.task.id)
+       count_query = select(func.count()).select_from(query.subquery())
+       total = (await self.db.execute(count_query)).scalar_one()
+       evaluated_count = 0
        for file in files:
-           items = (await self.db.execute(
-               select(EvaluationItem).where(EvaluationItem.task_id == self.task.id)
-               .where(EvaluationItem.file_id == file.file_id)
-           )).scalars().all()
+           items = (await self.db.execute(query.where(EvaluationItem.file_id == file.file_id))).scalars().all()
            tasks = [
                self.evaluate_item(model_config, item, semaphore)
                for item in items
            ]
            await asyncio.gather(*tasks, return_exceptions=True)
            file.evaluated_count = len(items)
+           evaluated_count += file.evaluated_count
+           self.task.eval_process = evaluated_count / total
        await self.db.commit()
    async def evaluate_item(self, model_config, item: EvaluationItem, semaphore: asyncio.Semaphore):
        async with semaphore:
-           prompt_text = self.get_eval_prompt(item)
-           resp_text = await asyncio.to_thread(
-               call_openai_style_model, model_config.base_url, model_config.api_key, model_config.model_name,
-               prompt_text,
-           )
-           item.eval_result = resp_text
-           item.status = TaskStatus.COMPLETED.value
-           await self.db.commit()
+           max_try = 3
+           while max_try > 0:
+               prompt_text = self.get_eval_prompt(item)
+               resp_text = await asyncio.to_thread(
+                   call_openai_style_model, model_config.base_url, model_config.api_key, model_config.model_name,
+                   prompt_text,
+               )
+               resp_text = _extract_json_substring(resp_text)
+               try:
+                   json.loads(resp_text)
+               except Exception as e:
+                   logger.error(
+                       f"Failed to parse LLM answer as JSON for task={self.task.id}, file={item.file_id}: {e}. Raw answer: {resp_text!r}"
+                   )
+                   max_try -= 1
+                   continue
+               item.eval_result = resp_text
+               item.status = TaskStatus.COMPLETED.value
+               await self.db.commit()
+               return

    def get_source_type(self) -> SourceType:
@@ -119,7 +138,7 @@ class SynthesisEvaluationExecutor(EvaluationExecutor):
    async def save_eval_items(self):
        synthesis_files = ((await self.db.execute(select(DataSynthesisFileInstance)
-                           .where(DataSynthesisFileInstance.task_id == self.task.source_id)))
+                           .where(DataSynthesisFileInstance.synthesis_instance_id == self.task.source_id)))
                           .scalars().all())
        for synthesis_file in synthesis_files:
            synthesis_datas = ((await self.db.execute(select(SynthesisData)
@@ -132,7 +151,7 @@ class SynthesisEvaluationExecutor(EvaluationExecutor):
                    task_id=self.task.id,
                    file_id=synthesis_file.id,
                    item_id=synthesis_data.id,
-                   eval_content=synthesis_data.data,
+                   eval_content=json.dumps(synthesis_data.data),
                    status=TaskStatus.PENDING.value,
                    created_by=self.task.created_by,
                    updated_by=self.task.updated_by,

View File

@@ -28,6 +28,7 @@ from app.db.models.data_synthesis import (
from app.db.models.dataset_management import DatasetFiles
from app.db.models.model_config import get_model_by_id
from app.db.session import logger
+ from app.module.shared.util.model_chat import _extract_json_substring
from app.module.system.service.common_service import get_chat_client, chat
@@ -365,7 +366,7 @@ class GenerationService:
            return

        # 1. 预处理原始回答:尝试从中截取出最可能的 JSON 片段
-       cleaned = self._extract_json_substring(raw_answer)
+       cleaned = _extract_json_substring(raw_answer)

        # 2. 解析 JSON,统一成列表结构
        try:
@@ -426,45 +427,6 @@ class GenerationService:
        await self.db.commit()
        await self.db.refresh(file_instance)
@staticmethod
def _extract_json_substring(raw: str) -> str:
"""从 LLM 的原始回答中提取最可能的 JSON 字符串片段。
处理思路:
- 原始回答可能是:说明文字 + JSON + 说明文字,甚至带有 Markdown 代码块。
- 优先在文本中查找第一个 '{''[' 作为 JSON 起始;
- 再从后向前找最后一个 '}'']' 作为结束;
- 如果找不到合适的边界,就退回原始字符串。
该方法不会保证截取的一定是合法 JSON,但能显著提高 json.loads 的成功率。
"""
if not raw:
return raw
start = None
end = None
# 查找第一个 JSON 起始符号
for i, ch in enumerate(raw):
if ch in "[{":
start = i
break
# 查找最后一个 JSON 结束符号
for i in range(len(raw) - 1, -1, -1):
if raw[i] in "]}":
end = i + 1 # 切片是左闭右开
break
if start is not None and end is not None and start < end:
return raw[start:end].strip()
# 兜底:去掉常见 Markdown 包裹(```json ... ```)
stripped = raw.strip()
if stripped.startswith("```"):
# 去掉首尾 ``` 标记
stripped = stripped.strip("`")
return stripped
    async def _get_or_create_file_instance(
        self,
        synthesis_task_id: str,

View File

@@ -13,3 +13,41 @@ def call_openai_style_model(base_url, api_key, model_name, prompt, **kwargs):
        **kwargs
    )
    return response.choices[0].message.content
def _extract_json_substring(raw: str) -> str:
"""从 LLM 的原始回答中提取最可能的 JSON 字符串片段。
处理思路:
- 原始回答可能是:说明文字 + JSON + 说明文字,甚至带有 Markdown 代码块。
- 优先在文本中查找第一个 '{''[' 作为 JSON 起始;
- 再从后向前找最后一个 '}'']' 作为结束;
- 如果找不到合适的边界,就退回原始字符串。
该方法不会保证截取的一定是合法 JSON,但能显著提高 json.loads 的成功率。
"""
if not raw:
return raw
start = None
end = None
# 查找第一个 JSON 起始符号
for i, ch in enumerate(raw):
if ch in "[{":
start = i
break
# 查找最后一个 JSON 结束符号
for i in range(len(raw) - 1, -1, -1):
if raw[i] in "]}":
end = i + 1 # 切片是左闭右开
break
if start is not None and end is not None and start < end:
return raw[start:end].strip()
# 兜底:去掉常见 Markdown 包裹(```json ... ```)
stripped = raw.strip()
if stripped.startswith("```"):
# 去掉首尾 ``` 标记
stripped = stripped.strip("`")
return stripped
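As a quick illustration of what the relocated helper tolerates (hypothetical input, not part of the commit): a reply wrapped in a Markdown code fence is reduced to the bare JSON object, so a subsequent json.loads succeeds where parsing the raw reply would raise.

raw = '```json\n{"result": {"accuracy": "Y"}, "evaluation": "OK"}\n```'
cleaned = _extract_json_substring(raw)
# cleaned == '{"result": {"accuracy": "Y"}, "evaluation": "OK"}'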

View File

@@ -5,6 +5,7 @@ from jsonschema import validate
class ItemTypes(Enum):
    QA = "QA"
+   COT = "COT"

class StructuredFileItemHandler:
@@ -14,11 +15,26 @@ class StructuredFileItemHandler:
    def get_item_type(self) -> ItemTypes:
        pass

-   def get_items_from_file(self, file_path: str) -> list[dict]:
+   def validate_json(self, data):
        pass

-   def check_file(self) -> bool:
-       pass
+   def get_items_from_file(self, file_path: str) -> list[dict]:
+       file_type = file_path.split(".")[-1].upper()
items = []
if file_type == "JSON":
with open(file_path, "r", encoding="utf-8") as f:
data = json.load(f)
if not self.validate_json(data):
return items
items = data
elif file_type == "JSONL":
with open(file_path, "r", encoding="utf-8") as f:
for line in f:
data = json.loads(line)
if not self.validate_json(data):
continue
items.append(data)
return items
class QAItemHandler(StructuredFileItemHandler):
    def __init__(self):
@@ -51,32 +67,44 @@ class QAItemHandler(StructuredFileItemHandler):
        except Exception as e:
            return False

-   def get_items_from_file(self, file_path: str) -> list[dict]:
-       file_type = file_path.split(".")[-1].upper()
-       items = []
-       if file_type == "JSON":
-           with open(file_path, "r", encoding="utf-8") as f:
-               data = json.load(f)
-               if not self.validate_json(data):
-                   return items
-               items = data
-       elif file_type == "JSONL":
-           with open(file_path, "r", encoding="utf-8") as f:
-               for line in f:
-                   data = json.loads(line)
-                   if not self.validate_json(data):
-                       continue
-                   items.append(data)
-       return items

-   def check_file(self) -> bool:
-       pass
+ class COTItemHandler(StructuredFileItemHandler):
+     def __init__(self):
self.schema = {
"type": "object",
"properties": {
"question": {"type": "string"},
"conclusion": {"type": "string"},
"chain_of_thought": {"type": "string"}
},
"required": ["question", "conclusion", "chain_of_thought"],
}
self.schema_list = {
"type": "array",
"items": self.schema,
}
super().__init__()
def get_item_type(self):
return ItemTypes.COT
def validate_json(self, data):
try:
validate(instance=data, schema=self.schema)
return True
except Exception as e:
try:
validate(instance=data, schema=self.schema_list)
return True
except Exception as e:
return False
class StructuredFileHandlerFactory:
    def __init__(self):
        self.handlers: list[StructuredFileItemHandler] = []
        self.handlers.append(QAItemHandler())
+       self.handlers.append(COTItemHandler())

    def get_handler(self, item_type: str) -> StructuredFileItemHandler:
        for handler in self.handlers:
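For reference, an item accepted by the new COTItemHandler schema is a JSON object (or a list of such objects) with three string fields; a minimal, hypothetical sample and check, assuming the module path shown in the imports above:

from jsonschema import validate
from app.module.shared.util.structured_file import COTItemHandler

sample = {
    "question": "A train covers 60 km in 1.5 hours. What is its average speed?",
    "chain_of_thought": "Average speed = distance / time = 60 km / 1.5 h = 40 km/h.",
    "conclusion": "The average speed is 40 km/h.",
}
validate(instance=sample, schema=COTItemHandler().schema)  # raises nothing: the item is valid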

View File

@@ -12,7 +12,7 @@ CREATE TABLE IF NOT EXISTS t_de_eval_task (
    source_id VARCHAR(36) COMMENT '待评估对象ID',
    source_name VARCHAR(255) COMMENT '待评估对象名称',
    status VARCHAR(50) DEFAULT 'PENDING' COMMENT '状态:PENDING/RUNNING/COMPLETED/STOPPED/FAILED',
-   eval_method VARCHAR(50) DEFAULT 'AUTO' COMMENT '状态:AUTO/MANUAL',
+   eval_method VARCHAR(50) DEFAULT 'AUTO' COMMENT '评估方式:AUTO/MANUAL',
    eval_process DOUBLE PRECISION NOT NULL DEFAULT 0 COMMENT '评估进度',
    eval_prompt TEXT COMMENT '评估提示词',
    eval_config TEXT COMMENT '评估配置',