fix: correct read errors caused by malformed JSON output from the model during evaluation (#133)
* feature: add COT data evaluation function
* fix: add verification of evaluation results
* fix: fix the evaluation prompt
* fix: fix read failures when the evaluation result is empty
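In short, the change makes the evaluation executor defensive about model output: the raw reply is first narrowed to its JSON-looking substring, then parsed, and the call is retried (up to three times) when parsing fails; only a reply that parses cleanly is stored as the evaluation result. Below is a minimal sketch of that retry-and-parse loop; the `extract_json_substring` helper mirrors the `_extract_json_substring` utility moved into `app/module/shared/util/model_chat.py`, while `evaluate_with_retry` and the `call_model` callable are illustrative stand-ins for `EvaluationExecutor.evaluate_item` and `call_openai_style_model`, not the actual code from this commit.

```python
import json


def extract_json_substring(raw: str) -> str:
    """Best-effort slice of the JSON-looking part of a model reply."""
    if not raw:
        return raw
    start = next((i for i, ch in enumerate(raw) if ch in "[{"), None)
    end = next((i + 1 for i in range(len(raw) - 1, -1, -1) if raw[i] in "]}"), None)
    if start is not None and end is not None and start < end:
        return raw[start:end].strip()
    # Fallback: strip common Markdown code-fence wrapping.
    return raw.strip().strip("`")


def evaluate_with_retry(call_model, prompt: str, max_try: int = 3):
    """Call the model until its reply parses as JSON, retrying up to max_try times."""
    for _ in range(max_try):
        raw = call_model(prompt)               # e.g. an OpenAI-style chat completion
        cleaned = extract_json_substring(raw)  # drop surrounding prose / code fences
        try:
            return json.loads(cleaned)         # accept only well-formed JSON
        except json.JSONDecodeError:
            continue                           # malformed output: try again
    return None                                # caller decides how to handle failure
```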
@@ -1,3 +1,4 @@
// TypeScript
import React, { useState, useEffect } from 'react';
import { Button, Form, Input, Select, message, Modal, Row, Col, Table, Space } from 'antd';
import { EyeOutlined } from '@ant-design/icons';
@@ -36,6 +37,7 @@ interface CreateTaskModalProps {

const TASK_TYPES = [
{ label: 'QA评估', value: 'QA' },
{ label: 'COT评估', value: 'COT' },
];

const EVAL_METHODS = [
@@ -55,7 +57,7 @@ const CreateTaskModal: React.FC<CreateTaskModalProps> = ({ visible, onCancel, on
dimension: '',
description: ''
});
const [taskType, setTaskType] = useState<string>("QA");
const [taskType, setTaskType] = useState<string>(DEFAULT_TASK_TYPE);
const [promptTemplates, setPromptTemplates] = useState<PromptTemplate[]>([]);
const [previewVisible, setPreviewVisible] = useState(false);
const [evaluationPrompt, setEvaluationPrompt] = useState('');
@@ -82,9 +84,24 @@ const CreateTaskModal: React.FC<CreateTaskModalProps> = ({ visible, onCancel, on
fetchDatasets().then();
fetchModels().then();
fetchPromptTemplates().then();
// sync form with local taskType default
form.setFieldsValue({ taskType: DEFAULT_TASK_TYPE });
}
}, [visible]);

// when promptTemplates or taskType change, switch dimensions to template defaults (COT/QA)
useEffect(() => {
if (!promptTemplates || promptTemplates.length === 0) return;
const template = promptTemplates.find(t => t.evalType === taskType);
if (template && template.defaultDimensions) {
setDimensions(template.defaultDimensions.map((dim: any, index: number) => ({
key: `dim-${index}`,
dimension: dim.dimension,
description: dim.description
})));
}
}, [taskType, promptTemplates]);

const fetchDatasets = async () => {
try {
const { data } = await queryDatasetsUsingGet({ page: 1, size: 1000 });
@@ -106,31 +123,46 @@ const CreateTaskModal: React.FC<CreateTaskModalProps> = ({ visible, onCancel, on
};

const formatDimensionsForPrompt = (dimensions: Dimension[]) => {
let result = "\n";
let result = "";
dimensions.forEach((dim, index) => {
result += `### ${index + 1}. ${dim.dimension}\n**评估标准:**\n${dim.description}\n\n`;
if (index > 0) {
result += "\n";
}
result += `### ${index + 1}. ${dim.dimension}\n**评估标准:**\n${dim.description}`;
if (index < dimensions.length - 1) {
result += "\n";
}
});
return result;
};

const formatResultExample = (dimensions: Dimension[]) => {
return dimensions.map(dim => `\n    "${dim.dimension}": "Y",`).join('');
let result = "";
dimensions.forEach((dim, index) => {
if (index > 0) {
result += "\n    ";
}
result += `"${dim.dimension}": "Y"`;
if (index < dimensions.length - 1) {
result += ",";
}
});
return result;
};

const fetchPromptTemplates = async () => {
try {
const response = await queryPromptTemplatesUsingGet();
const templates: PromptTemplate[] = response.data?.templates
setPromptTemplates(templates)
if (taskType) {
const template = templates.find(t => t.evalType === taskType);
if (template) {
setDimensions(template.defaultDimensions.map((dim: any, index: number) => ({
key: `dim-${index}`,
dimension: dim.dimension,
description: dim.description
})));
}
const templates: PromptTemplate[] = response.data?.templates || [];
setPromptTemplates(templates);
// if a template exists for current taskType, initialize dimensions (handled also by useEffect)
const template = templates.find(t => t.evalType === taskType);
if (template) {
setDimensions(template.defaultDimensions.map((dim: any, index: number) => ({
key: `dim-${index}`,
dimension: dim.dimension,
description: dim.description
})));
}
} catch (error) {
console.error('Error fetching prompt templates:', error);
@@ -144,8 +176,11 @@ const CreateTaskModal: React.FC<CreateTaskModalProps> = ({ visible, onCancel, on
return;
}
const template = promptTemplates.find(t => t.evalType === taskType);
setEvaluationPrompt(template?.prompt.replace("{dimensions}", formatDimensionsForPrompt(dimensions))
.replace('{result_example}', formatResultExample(dimensions)));
const basePrompt = template?.prompt || '';
const filled = basePrompt
.replace('{dimensions}', formatDimensionsForPrompt(dimensions))
.replace('{result_example}', formatResultExample(dimensions));
setEvaluationPrompt(filled);
setPreviewVisible(true);
};

@@ -243,6 +278,13 @@ const CreateTaskModal: React.FC<CreateTaskModalProps> = ({ visible, onCancel, on
evalMethod: DEFAULT_EVAL_METHOD,
taskType: DEFAULT_TASK_TYPE,
}}
onValuesChange={(changed) => {
if (changed.taskType) {
setTaskType(changed.taskType);
setEvaluationPrompt('');
setPreviewVisible(false);
}
}}
>
<Row gutter={16}>
<Col span={12}>

@@ -1,7 +1,8 @@
import { useEffect, useState } from 'react';
import { Table, Typography, Button, Space, Spin, Empty, message, Tooltip } from 'antd';
import { Table, Typography, Button, Space, Empty, Tooltip } from 'antd';
import { FolderOpen, FileText, ArrowLeft } from 'lucide-react';
import { queryEvaluationFilesUsingGet, queryEvaluationItemsUsingGet } from '../../evaluation.api';
import useFetchData from '@/hooks/useFetchData';

const { Text } = Typography;

@@ -39,63 +40,52 @@ type EvalItem = {
};

export default function EvaluationItems({ task }: { task: any }) {
const [loadingFiles, setLoadingFiles] = useState<boolean>(false);
const [files, setFiles] = useState<EvalFile[]>([]);
const [filePagination, setFilePagination] = useState({ current: 1, pageSize: 10, total: 0 });

const [selectedFile, setSelectedFile] = useState<{ fileId: string; fileName: string } | null>(null);
const [loadingItems, setLoadingItems] = useState<boolean>(false);
const [items, setItems] = useState<EvalItem[]>([]);
const [itemPagination, setItemPagination] = useState({ current: 1, pageSize: 10, total: 0 });

// Fetch files list
useEffect(() => {
if (!task?.id || selectedFile) return;
const fetchFiles = async () => {
setLoadingFiles(true);
try {
const res = await queryEvaluationFilesUsingGet({ taskId: task.id, page: filePagination.current, size: filePagination.pageSize });
const data = res?.data;
const list: EvalFile[] = data?.content || [];
setFiles(list);
setFilePagination((p) => ({ ...p, total: data?.totalElements || 0 }));
} catch (e) {
message.error('加载评估文件失败');
console.error(e);
} finally {
setLoadingFiles(false);
}
};
fetchFiles();
// eslint-disable-next-line react-hooks/exhaustive-deps
}, [task?.id, filePagination.current, filePagination.pageSize, selectedFile]);
// 文件列表数据(使用 useFetchData),pageOffset=0 表示后端分页为 1 基
const {
loading: loadingFiles,
tableData: files,
pagination: filePagination,
setSearchParams: setFileSearchParams,
} = useFetchData<EvalFile>(
(params) => queryEvaluationFilesUsingGet({ taskId: task?.id, ...params }),
(d) => d as unknown as EvalFile,
30000,
false,
[],
0
);

// Fetch items of selected file
useEffect(() => {
if (!task?.id || !selectedFile) return;
const fetchItems = async () => {
setLoadingItems(true);
try {
const res = await queryEvaluationItemsUsingGet({
taskId: task.id,
page: itemPagination.current,
size: itemPagination.pageSize,
file_id: selectedFile.fileId,
});
const data = res?.data;
const list: EvalItem[] = data?.content || [];
setItems(list);
setItemPagination((p) => ({ ...p, total: data?.totalElements || 0 }));
} catch (e) {
message.error('加载评估条目失败');
console.error(e);
} finally {
setLoadingItems(false);
// 评估条目数据(使用 useFetchData),依赖选中文件
const {
loading: loadingItems,
tableData: items,
pagination: itemPagination,
setSearchParams: setItemSearchParams,
fetchData: fetchItems,
} = useFetchData<EvalItem>(
(params) => {
if (!task?.id || !selectedFile?.fileId) {
return Promise.resolve({ data: { content: [], totalElements: 0 } });
}
};
fetchItems();
// eslint-disable-next-line react-hooks/exhaustive-deps
}, [task?.id, selectedFile?.fileId, itemPagination.current, itemPagination.pageSize]);
return queryEvaluationItemsUsingGet({ taskId: task.id, file_id: selectedFile.fileId, ...params });
},
(d) => d as unknown as EvalItem,
30000,
false,
[],
0
);

// 当选择文件变化时,主动触发一次条目查询,避免仅依赖 searchParams 变更导致未触发
useEffect(() => {
if (task?.id && selectedFile?.fileId) {
setItemSearchParams((prev: any) => ({ ...prev, current: 1 }));
// 立即拉取一次,保证点击后立刻出现数据
fetchItems();
}
}, [task?.id, selectedFile?.fileId]);

const fileColumns = [
{
@@ -228,19 +218,20 @@ export default function EvaluationItems({ task }: { task: any }) {
dataSource={files}
loading={loadingFiles}
size="middle"
onRow={(record) => ({ onClick: () => setSelectedFile({ fileId: record.fileId, fileName: record.fileName }) })}
pagination={{
current: filePagination.current,
pageSize: filePagination.pageSize,
total: filePagination.total,
onChange: (current, pageSize) => setFilePagination({ current, pageSize, total: filePagination.total }),
}}
onRow={(record) => ({
onClick: () => {
setSelectedFile({ fileId: record.fileId, fileName: record.fileName });
// 切换文件时,重置条目表到第一页
setItemSearchParams((prev: any) => ({ ...prev, current: 1 }));
},
})}
pagination={filePagination}
/>
) : (
<div className="flex flex-col gap-3">
<div className="sticky top-0 z-10 bg-white py-2" style={{ borderBottom: '1px solid #f0f0f0' }}>
<Space wrap>
<Button icon={<ArrowLeft size={16} />} onClick={() => { setSelectedFile(null); setItems([]); }}>
<Button icon={<ArrowLeft size={16} />} onClick={() => { setSelectedFile(null); }}>
返回文件列表
</Button>
<Space>
@@ -257,12 +248,7 @@ export default function EvaluationItems({ task }: { task: any }) {
dataSource={items}
loading={loadingItems}
size="middle"
pagination={{
current: itemPagination.current,
pageSize: itemPagination.pageSize,
total: itemPagination.total,
onChange: (current, pageSize) => setItemPagination({ current, pageSize, total: itemPagination.total }),
}}
pagination={itemPagination}
/>
</div>
)}

@@ -82,6 +82,7 @@ export default function DataEvaluationPage() {
label: '任务类型',
options: [
{ value: 'QA', label: 'QA评估' },
{ value: 'COT', label: 'COT评估' },
],
},
{
@@ -89,7 +90,6 @@ export default function DataEvaluationPage() {
label: '评估方式',
options: [
{ value: 'AUTO', label: '自动评估' },
{ value: 'MANUAL', label: '人工评估' },
],
},
];

@@ -32,6 +32,7 @@ class EvaluationTask(Base):
source_id = Column(String(36), nullable=True, comment="待评估对象ID")
source_name = Column(String(255), nullable=True, comment="待评估对象名称")
status = Column(String(50), server_default="PENDING", nullable=False, comment="状态:PENDING/RUNNING/COMPLETED/STOPPED/FAILED")
eval_method = Column(String(50), server_default="AUTO", nullable=False, comment="评估方式:AUTO/MANUAL")
eval_process = Column(Float, nullable=False, server_default="0", comment="评估进度")
eval_prompt = Column(Text, nullable=True, comment="评估提示词")
eval_config = Column(Text, nullable=True, comment="评估配置")

@@ -1,3 +1,4 @@
import math
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlalchemy import func
@@ -116,7 +117,7 @@ class Service:
for f in files
]

total_pages = (total + size - 1) // size if size > 0 else 0
total_pages = math.ceil(total / size) if total > 0 else 0

return PagedDatasetFileResponse(
content=content,

@@ -1,5 +1,6 @@
import asyncio
import uuid
import math
import json
from typing import Optional
from fastapi import APIRouter, Depends, HTTPException, Query, BackgroundTasks
@@ -171,7 +172,7 @@ async def list_evaluation_tasks(

# 转换为响应模型
items = [_map_to_task_detail_response(task) for task in tasks]
total_pages = (total + size - 1) // size if size > 0 else 0
total_pages = math.ceil(total / size) if total > 0 else 0

return StandardResponse(
code=200,
@@ -217,7 +218,7 @@ async def list_evaluation_items(
count_query = select(func.count()).select_from(query.subquery())
total = (await db.execute(count_query)).scalar_one()
files = (await db.execute(query.offset(offset).limit(size))).scalars().all()
total_pages = (total + size - 1) // size if size > 0 else 0
total_pages = math.ceil(total / size) if total > 0 else 0
file_responses = [
EvaluationFileResponse(
taskId=file.task_id,
@@ -298,7 +299,7 @@ async def list_evaluation_items(
taskId=item.task_id,
itemId=item.item_id,
fileId=item.file_id,
evalContent=json.loads(item.eval_content),
evalContent=json.loads(item.eval_content) if item.eval_content else None,
evalScore=float(item.eval_score) if item.eval_score else None,
evalResult=json.loads(item.eval_result),
status=item.status
@@ -306,7 +307,7 @@ async def list_evaluation_items(
for item in items
]

total_pages = (total + size - 1) // size if size > 0 else 0
total_pages = math.ceil(total / size) if total > 0 else 0

return StandardResponse(
code=200,
@@ -387,6 +388,12 @@ async def delete_eval_tasks(
.where(EvaluationItem.task_id == task_id)
)

# 删除评估文件
await db.execute(
EvaluationFile.__table__.delete()
.where(EvaluationFile.task_id == task_id)
)

# 删除任务
await db.delete(task)
await db.commit()
@@ -419,6 +426,7 @@ def _map_to_task_detail_response(
sourceId=task.source_id,
sourceName=task.source_name,
status=task.status,
evalMethod=task.eval_method,
evalProcess=task.eval_process,
evalPrompt=task.eval_prompt,
evalConfig=json.loads(task.eval_config),

@@ -36,6 +36,7 @@ class EvaluationTaskItem(BaseModel):
source_id: Optional[str] = Field(..., alias="sourceId", description="数据源ID")
source_name: Optional[str] = Field(None, alias="sourceName", description="数据源名称")
status: TaskStatus = Field(..., description="任务状态")
eval_method: Optional[str] = Field(None, alias="evalMethod", description="评估方式")
eval_process: Optional[float] = Field(0, alias="evalProcess", description="评估进度")
created_at: Optional[str] = Field(None, alias="createdAt", description="创建时间")
updated_at: Optional[str] = Field(None, alias="updatedAt", description="更新时间")

@@ -1,3 +1,7 @@
from app.core.logging import get_logger

logger = get_logger(__name__)

EVALUATION_PROMPT_TEMPLATE = [
{
"evalType": "QA",
@@ -51,26 +55,90 @@ EVALUATION_PROMPT_TEMPLATE = [
请按照以下JSON格式输出评估结果,评估结果为Y/N,符合标注输出Y,不符合标准输出N:

{
"result": {{result_example}
"result": {
{result_example}
},
"evaluation": "这是一个高质量的问答数据集。问题表述清晰具体,答案准确完整且逻辑性强,与原始文本高度相关。建议:可以进一步丰富答案的细节描述。"
}
"""
},
{
"evalType": "COT",
"defaultDimensions": [
{
"dimension": "思维链逻辑是否连贯",
"description": "分析思维链中推理链条的连续性:步骤间有明确的逻辑连接词;每一步都是基于前置在步骤的结果;没有逻辑跳跃或断层;推理方向一致,不偏离目标。"
},
{
"dimension": "推理步骤是否合理必要",
"description": "分析思维链中对于步骤分解的合理性和必要性:复杂问题被适当分解; 每个步骤都是解决整体问题的必要部分;步骤粒度适中(既不过细也不过粗);符合人类认知习惯。"
},
{
"dimension": "内容是否准确",
"description": "分析整个COT数据内容是否准确:所有陈述的事实必须准确;展示每一步的计算结果(如何涉及数学计算,必须保证数学计算无错误);逻辑推导有效且合理,最终答案与推理过程一致。"
}
],
"prompt": """
# Role: COT数据质量评估专家
## Profile:
- Description: 你是一名专业的Chain-of-Thought(CoT)推理数据质量评估专家,擅长从多个维度对COT数据进行质量评估,挑选出有助于模型学习如何分解问题、展示推理链条,提高模型对于复杂问题解决能力的COT数据。具备深度学习、自然语言处理和数据科学的专业背景。

## Skills:
1. 能够从多个维度对COT数据进行综合评估,保证客观、专业、细致
2. 擅长识别COT数据中的潜在问题,如推包含事实性错误(关键信息错误),存在严重逻辑矛(无法自洽),包含有害、偏见或不当内容,完全偏离主题,抄袭或高度重复内容等
3. 能够给出具体的改进建议和质量评分,并提供可操作的优化方案

## 评估维度:
{dimensions}

## 问题或指令:
{question}

## 思维链:
{chain_of_thought}

## 结论:
{conclusion}

## 注意事项:
- 评估结论要具体指出优点和不足,提供可操作的改进建议
- 评估结论控制在150字以内,简洁明了但要涵盖关键信息

## 输出要求:
请按照以下JSON格式输出评估结果,评估结果为Y/N,符合标注输出Y,不符合标准输出N;将评估结论写到evaluation中:

{
"result": {
{result_example}
},
"evaluation": "这是一个高质量的COT数据。思维链逻辑连贯,推理步骤合理,信息完整。建议:部分表达可以进一步优化,以及个别步骤的过渡可以更加平滑。"
}
"""
}
]

def get_dimensions_for_qa(dimensions: list[dict]) -> str:
dimensions_str = "\n"
dimensions_str = ""
index = 1
for dimension in dimensions:
dimensions_str += f"### {index}. {dimension.get("dimension")}\n**评估标准:**\n{dimension.get("description")}\n\n"
if index > 1:
dimensions_str += "\n"
dimensions_str += f"### {index}. {dimension.get("dimension")}\n**评估标准:**\n{dimension.get("description")}"
if index < len(dimensions):
dimensions_str += "\n"
index += 1
return dimensions_str

def get_result_example_for_qa(dimensions: list[dict]) -> str:
result_example = ""
index = 1
for dimension in dimensions:
result_example += f'\n    "{dimension.get("dimension")}": "Y",'
if index > 1:
result_example += "\n    "
result_example += f'"{dimension.get("dimension")}": "Y"'
if index < len(dimensions):
result_example += ","
index += 1
return result_example

def get_prompt(task_type: str, dimensions: list[dict]) -> str:

@@ -1,7 +1,7 @@
"""
Schema for evaluation prompt templates.
"""
from typing import List, Dict, Any
from typing import List
from pydantic import BaseModel, Field

@@ -2,7 +2,7 @@ import json
import uuid
import asyncio

from sqlalchemy import select
from sqlalchemy import select, func
from sqlalchemy.ext.asyncio import AsyncSession

from app.core.exception import BusinessErrorCodeEnum, BusinessException
@@ -13,7 +13,7 @@ from app.db.models.data_synthesis import DataSynthesisFileInstance, SynthesisDat
from app.db.session import AsyncSessionLocal
from app.module.evaluation.schema.evaluation import SourceType
from app.module.shared.schema import TaskStatus
from app.module.shared.util.model_chat import call_openai_style_model
from app.module.shared.util.model_chat import call_openai_style_model, _extract_json_substring
from app.module.evaluation.schema.prompt import get_prompt
from app.module.shared.util.structured_file import StructuredFileHandlerFactory
from app.module.system.service.common_service import get_model_by_id
@@ -35,6 +35,10 @@ class EvaluationExecutor:
prompt_text = ((prompt_text.replace("{content}", eval_content.get("input"))
.replace("{question}", eval_content.get("instruction")))
.replace("{answer}", eval_content.get("output")))
if self.task.task_type == "COT":
prompt_text = ((prompt_text.replace("{question}", eval_content.get("question"))
.replace("{conclusion}", eval_content.get("conclusion")))
.replace("{chain_of_thought}", eval_content.get("chain_of_thought")))
return prompt_text

async def execute(self):
@@ -44,29 +48,44 @@ class EvaluationExecutor:
files = (await self.db.execute(
select(EvaluationFile).where(EvaluationFile.task_id == self.task.id)
)).scalars().all()
query = select(EvaluationItem).where(EvaluationItem.task_id == self.task.id)
count_query = select(func.count()).select_from(query.subquery())
total = (await self.db.execute(count_query)).scalar_one()
evaluated_count = 0
for file in files:
items = (await self.db.execute(
select(EvaluationItem).where(EvaluationItem.task_id == self.task.id)
.where(EvaluationItem.file_id == file.file_id)
)).scalars().all()
items = (await self.db.execute(query.where(EvaluationItem.file_id == file.file_id))).scalars().all()
tasks = [
self.evaluate_item(model_config, item, semaphore)
for item in items
]
await asyncio.gather(*tasks, return_exceptions=True)
file.evaluated_count = len(items)
evaluated_count += file.evaluated_count
self.task.eval_process = evaluated_count / total
await self.db.commit()

async def evaluate_item(self, model_config, item: EvaluationItem, semaphore: asyncio.Semaphore):
async with semaphore:
prompt_text = self.get_eval_prompt(item)
resp_text = await asyncio.to_thread(
call_openai_style_model, model_config.base_url, model_config.api_key, model_config.model_name,
prompt_text,
)
item.eval_result = resp_text
item.status = TaskStatus.COMPLETED.value
await self.db.commit()
max_try = 3
while max_try > 0:
prompt_text = self.get_eval_prompt(item)
resp_text = await asyncio.to_thread(
call_openai_style_model, model_config.base_url, model_config.api_key, model_config.model_name,
prompt_text,
)
resp_text = _extract_json_substring(resp_text)
try:
json.loads(resp_text)
except Exception as e:
logger.error(
f"Failed to parse LLM answer as JSON for task={self.task.id}, file={item.file_id}: {e}. Raw answer: {resp_text!r}"
)
max_try -= 1
continue
item.eval_result = resp_text
item.status = TaskStatus.COMPLETED.value
await self.db.commit()
return


def get_source_type(self) -> SourceType:
@@ -119,7 +138,7 @@ class SynthesisEvaluationExecutor(EvaluationExecutor):

async def save_eval_items(self):
synthesis_files = ((await self.db.execute(select(DataSynthesisFileInstance)
.where(DataSynthesisFileInstance.task_id == self.task.source_id)))
.where(DataSynthesisFileInstance.synthesis_instance_id == self.task.source_id)))
.scalars().all())
for synthesis_file in synthesis_files:
synthesis_datas = ((await self.db.execute(select(SynthesisData)
@@ -132,7 +151,7 @@ class SynthesisEvaluationExecutor(EvaluationExecutor):
task_id=self.task.id,
file_id=synthesis_file.id,
item_id=synthesis_data.id,
eval_content=synthesis_data.data,
eval_content=json.dumps(synthesis_data.data),
status=TaskStatus.PENDING.value,
created_by=self.task.created_by,
updated_by=self.task.updated_by,

@@ -28,6 +28,7 @@ from app.db.models.data_synthesis import (
from app.db.models.dataset_management import DatasetFiles
from app.db.models.model_config import get_model_by_id
from app.db.session import logger
from app.module.shared.util.model_chat import _extract_json_substring
from app.module.system.service.common_service import get_chat_client, chat


@@ -365,7 +366,7 @@ class GenerationService:
return

# 1. 预处理原始回答:尝试从中截取出最可能的 JSON 片段
cleaned = self._extract_json_substring(raw_answer)
cleaned = _extract_json_substring(raw_answer)

# 2. 解析 JSON,统一成列表结构
try:
@@ -426,45 +427,6 @@ class GenerationService:
await self.db.commit()
await self.db.refresh(file_instance)

@staticmethod
def _extract_json_substring(raw: str) -> str:
"""从 LLM 的原始回答中提取最可能的 JSON 字符串片段。

处理思路:
- 原始回答可能是:说明文字 + JSON + 说明文字,甚至带有 Markdown 代码块。
- 优先在文本中查找第一个 '{' 或 '[' 作为 JSON 起始;
- 再从后向前找最后一个 '}' 或 ']' 作为结束;
- 如果找不到合适的边界,就退回原始字符串。
该方法不会保证截取的一定是合法 JSON,但能显著提高 json.loads 的成功率。
"""
if not raw:
return raw

start = None
end = None

# 查找第一个 JSON 起始符号
for i, ch in enumerate(raw):
if ch in "[{":
start = i
break

# 查找最后一个 JSON 结束符号
for i in range(len(raw) - 1, -1, -1):
if raw[i] in "]}":
end = i + 1  # 切片是左闭右开
break

if start is not None and end is not None and start < end:
return raw[start:end].strip()

# 兜底:去掉常见 Markdown 包裹(```json ... ```)
stripped = raw.strip()
if stripped.startswith("```"):
# 去掉首尾 ``` 标记
stripped = stripped.strip("`")
return stripped

async def _get_or_create_file_instance(
self,
synthesis_task_id: str,

@@ -13,3 +13,41 @@ def call_openai_style_model(base_url, api_key, model_name, prompt, **kwargs):
**kwargs
)
return response.choices[0].message.content

def _extract_json_substring(raw: str) -> str:
"""从 LLM 的原始回答中提取最可能的 JSON 字符串片段。

处理思路:
- 原始回答可能是:说明文字 + JSON + 说明文字,甚至带有 Markdown 代码块。
- 优先在文本中查找第一个 '{' 或 '[' 作为 JSON 起始;
- 再从后向前找最后一个 '}' 或 ']' 作为结束;
- 如果找不到合适的边界,就退回原始字符串。
该方法不会保证截取的一定是合法 JSON,但能显著提高 json.loads 的成功率。
"""
if not raw:
return raw

start = None
end = None

# 查找第一个 JSON 起始符号
for i, ch in enumerate(raw):
if ch in "[{":
start = i
break

# 查找最后一个 JSON 结束符号
for i in range(len(raw) - 1, -1, -1):
if raw[i] in "]}":
end = i + 1  # 切片是左闭右开
break

if start is not None and end is not None and start < end:
return raw[start:end].strip()

# 兜底:去掉常见 Markdown 包裹(```json ... ```)
stripped = raw.strip()
if stripped.startswith("```"):
# 去掉首尾 ``` 标记
stripped = stripped.strip("`")
return stripped

@@ -5,6 +5,7 @@ from jsonschema import validate

class ItemTypes(Enum):
QA = "QA"
COT = "COT"


class StructuredFileItemHandler:
@@ -14,11 +15,26 @@ class StructuredFileItemHandler:
def get_item_type(self) -> ItemTypes:
pass

def get_items_from_file(self, file_path: str) -> list[dict]:
def validate_json(self, data):
pass

def check_file(self) -> bool:
pass
def get_items_from_file(self, file_path: str) -> list[dict]:
file_type = file_path.split(".")[-1].upper()
items = []
if file_type == "JSON":
with open(file_path, "r", encoding="utf-8") as f:
data = json.load(f)
if not self.validate_json(data):
return items
items = data
elif file_type == "JSONL":
with open(file_path, "r", encoding="utf-8") as f:
for line in f:
data = json.loads(line)
if not self.validate_json(data):
continue
items.append(data)
return items

class QAItemHandler(StructuredFileItemHandler):
def __init__(self):
@@ -51,32 +67,44 @@ class QAItemHandler(StructuredFileItemHandler):
except Exception as e:
return False

def get_items_from_file(self, file_path: str) -> list[dict]:
file_type = file_path.split(".")[-1].upper()
items = []
if file_type == "JSON":
with open(file_path, "r", encoding="utf-8") as f:
data = json.load(f)
if not self.validate_json(data):
return items
items = data
elif file_type == "JSONL":
with open(file_path, "r", encoding="utf-8") as f:
for line in f:
data = json.loads(line)
if not self.validate_json(data):
continue
items.append(data)
return items

def check_file(self) -> bool:
pass
class COTItemHandler(StructuredFileItemHandler):
def __init__(self):
self.schema = {
"type": "object",
"properties": {
"question": {"type": "string"},
"conclusion": {"type": "string"},
"chain_of_thought": {"type": "string"}
},
"required": ["question", "conclusion", "chain_of_thought"],
}
self.schema_list = {
"type": "array",
"items": self.schema,
}
super().__init__()

def get_item_type(self):
return ItemTypes.COT

def validate_json(self, data):
try:
validate(instance=data, schema=self.schema)
return True
except Exception as e:
try:
validate(instance=data, schema=self.schema_list)
return True
except Exception as e:
return False


class StructuredFileHandlerFactory:
def __init__(self):
self.handlers: list[StructuredFileItemHandler] = []
self.handlers.append(QAItemHandler())
self.handlers.append(COTItemHandler())

def get_handler(self, item_type: str) -> StructuredFileItemHandler:
for handler in self.handlers:

@@ -12,7 +12,7 @@ CREATE TABLE IF NOT EXISTS t_de_eval_task (
source_id VARCHAR(36) COMMENT '待评估对象ID',
source_name VARCHAR(255) COMMENT '待评估对象名称',
status VARCHAR(50) DEFAULT 'PENDING' COMMENT '状态:PENDING/RUNNING/COMPLETED/STOPPED/FAILED',
eval_method VARCHAR(50) DEFAULT 'AUTO' COMMENT '状态:AUTO/MANUAL',
eval_method VARCHAR(50) DEFAULT 'AUTO' COMMENT '评估方式:AUTO/MANUAL',
eval_process DOUBLE PRECISION NOT NULL DEFAULT 0 COMMENT '评估进度',
eval_prompt TEXT COMMENT '评估提示词',
eval_config TEXT COMMENT '评估配置',