fix: fix read errors caused by the model returning malformed JSON during evaluation (#133)

* feature: add COT data evaluation function
* fix: added verification of evaluation results
* fix: fix the prompt used for evaluation
* fix: fix read failures when the evaluation result is empty
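The core of the fix, as the diff below shows, is that the backend now sanitises and validates the model's reply before storing an evaluation result: `_extract_json_substring` moves into the shared `model_chat` util, and `EvaluationExecutor.evaluate_item` extracts the JSON span, checks it with `json.loads`, and retries up to three times before giving up. The standalone sketch below only illustrates that extract-validate-retry pattern under stated assumptions; `call_model` is a hypothetical placeholder for the real LLM call (`call_openai_style_model` in the diff), and `extract_json_substring` here is a simplified stand-in for the project's helper, not the exact implementation.

import json

def extract_json_substring(raw: str) -> str:
    # Keep the span from the first '{' or '[' to the last '}' or ']';
    # fall back to stripping Markdown ``` fences if no brackets are found.
    if not raw:
        return raw
    start = next((i for i, ch in enumerate(raw) if ch in "[{"), None)
    end = next((i + 1 for i in range(len(raw) - 1, -1, -1) if raw[i] in "]}"), None)
    if start is not None and end is not None and start < end:
        return raw[start:end].strip()
    return raw.strip().strip("`")

def evaluate_with_retry(call_model, prompt: str, max_try: int = 3):
    # call_model(prompt) -> str is a hypothetical stand-in for the LLM call.
    while max_try > 0:
        cleaned = extract_json_substring(call_model(prompt))
        try:
            return json.loads(cleaned)  # only a valid JSON answer is accepted
        except (json.JSONDecodeError, TypeError):
            max_try -= 1  # malformed output: retry instead of storing an unreadable result
    return None

In the diff itself, the cleaned text is written to `item.eval_result` and the item is marked COMPLETED only after `json.loads` succeeds, which is what keeps the later `json.loads(item.eval_result)` in the read path from failing.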
@@ -1,3 +1,4 @@
+// TypeScript
 import React, { useState, useEffect } from 'react';
 import { Button, Form, Input, Select, message, Modal, Row, Col, Table, Space } from 'antd';
 import { EyeOutlined } from '@ant-design/icons';
@@ -36,6 +37,7 @@ interface CreateTaskModalProps {

 const TASK_TYPES = [
   { label: 'QA评估', value: 'QA' },
+  { label: 'COT评估', value: 'COT' },
 ];

 const EVAL_METHODS = [
@@ -55,7 +57,7 @@ const CreateTaskModal: React.FC<CreateTaskModalProps> = ({ visible, onCancel, on
     dimension: '',
     description: ''
   });
-  const [taskType, setTaskType] = useState<string>("QA");
+  const [taskType, setTaskType] = useState<string>(DEFAULT_TASK_TYPE);
   const [promptTemplates, setPromptTemplates] = useState<PromptTemplate[]>([]);
   const [previewVisible, setPreviewVisible] = useState(false);
   const [evaluationPrompt, setEvaluationPrompt] = useState('');
@@ -82,9 +84,24 @@ const CreateTaskModal: React.FC<CreateTaskModalProps> = ({ visible, onCancel, on
       fetchDatasets().then();
       fetchModels().then();
       fetchPromptTemplates().then();
+      // sync form with local taskType default
+      form.setFieldsValue({ taskType: DEFAULT_TASK_TYPE });
     }
   }, [visible]);

+  // when promptTemplates or taskType change, switch dimensions to template defaults (COT/QA)
+  useEffect(() => {
+    if (!promptTemplates || promptTemplates.length === 0) return;
+    const template = promptTemplates.find(t => t.evalType === taskType);
+    if (template && template.defaultDimensions) {
+      setDimensions(template.defaultDimensions.map((dim: any, index: number) => ({
+        key: `dim-${index}`,
+        dimension: dim.dimension,
+        description: dim.description
+      })));
+    }
+  }, [taskType, promptTemplates]);
+
   const fetchDatasets = async () => {
     try {
       const { data } = await queryDatasetsUsingGet({ page: 1, size: 1000 });
@@ -106,23 +123,39 @@ const CreateTaskModal: React.FC<CreateTaskModalProps> = ({ visible, onCancel, on
   };

   const formatDimensionsForPrompt = (dimensions: Dimension[]) => {
-    let result = "\n";
+    let result = "";
     dimensions.forEach((dim, index) => {
-      result += `### ${index + 1}. ${dim.dimension}\n**评估标准:**\n${dim.description}\n\n`;
+      if (index > 0) {
+        result += "\n";
+      }
+      result += `### ${index + 1}. ${dim.dimension}\n**评估标准:**\n${dim.description}`;
+      if (index < dimensions.length - 1) {
+        result += "\n";
+      }
     });
     return result;
   };

   const formatResultExample = (dimensions: Dimension[]) => {
-    return dimensions.map(dim => `\n "${dim.dimension}": "Y",`).join('');
+    let result = "";
+    dimensions.forEach((dim, index) => {
+      if (index > 0) {
+        result += "\n ";
+      }
+      result += `"${dim.dimension}": "Y"`;
+      if (index < dimensions.length - 1) {
+        result += ",";
+      }
+    });
+    return result;
   };

   const fetchPromptTemplates = async () => {
     try {
       const response = await queryPromptTemplatesUsingGet();
-      const templates: PromptTemplate[] = response.data?.templates
-      setPromptTemplates(templates)
-      if (taskType) {
+      const templates: PromptTemplate[] = response.data?.templates || [];
+      setPromptTemplates(templates);
+      // if a template exists for current taskType, initialize dimensions (handled also by useEffect)
       const template = templates.find(t => t.evalType === taskType);
       if (template) {
         setDimensions(template.defaultDimensions.map((dim: any, index: number) => ({
@@ -131,7 +164,6 @@ const CreateTaskModal: React.FC<CreateTaskModalProps> = ({ visible, onCancel, on
           description: dim.description
         })));
       }
-      }
     } catch (error) {
       console.error('Error fetching prompt templates:', error);
       message.error('获取评估维度失败');
@@ -144,8 +176,11 @@ const CreateTaskModal: React.FC<CreateTaskModalProps> = ({ visible, onCancel, on
       return;
     }
     const template = promptTemplates.find(t => t.evalType === taskType);
-    setEvaluationPrompt(template?.prompt.replace("{dimensions}", formatDimensionsForPrompt(dimensions))
-      .replace('{result_example}', formatResultExample(dimensions)));
+    const basePrompt = template?.prompt || '';
+    const filled = basePrompt
+      .replace('{dimensions}', formatDimensionsForPrompt(dimensions))
+      .replace('{result_example}', formatResultExample(dimensions));
+    setEvaluationPrompt(filled);
     setPreviewVisible(true);
   };

@@ -243,6 +278,13 @@ const CreateTaskModal: React.FC<CreateTaskModalProps> = ({ visible, onCancel, on
           evalMethod: DEFAULT_EVAL_METHOD,
           taskType: DEFAULT_TASK_TYPE,
         }}
+        onValuesChange={(changed) => {
+          if (changed.taskType) {
+            setTaskType(changed.taskType);
+            setEvaluationPrompt('');
+            setPreviewVisible(false);
+          }
+        }}
       >
         <Row gutter={16}>
           <Col span={12}>

@@ -1,7 +1,8 @@
 import { useEffect, useState } from 'react';
-import { Table, Typography, Button, Space, Spin, Empty, message, Tooltip } from 'antd';
+import { Table, Typography, Button, Space, Empty, Tooltip } from 'antd';
 import { FolderOpen, FileText, ArrowLeft } from 'lucide-react';
 import { queryEvaluationFilesUsingGet, queryEvaluationItemsUsingGet } from '../../evaluation.api';
+import useFetchData from '@/hooks/useFetchData';

 const { Text } = Typography;

@@ -39,63 +40,52 @@ type EvalItem = {
 };

 export default function EvaluationItems({ task }: { task: any }) {
-  const [loadingFiles, setLoadingFiles] = useState<boolean>(false);
-  const [files, setFiles] = useState<EvalFile[]>([]);
-  const [filePagination, setFilePagination] = useState({ current: 1, pageSize: 10, total: 0 });
-
   const [selectedFile, setSelectedFile] = useState<{ fileId: string; fileName: string } | null>(null);
-  const [loadingItems, setLoadingItems] = useState<boolean>(false);
-  const [items, setItems] = useState<EvalItem[]>([]);
-  const [itemPagination, setItemPagination] = useState({ current: 1, pageSize: 10, total: 0 });

-  // Fetch files list
-  useEffect(() => {
-    if (!task?.id || selectedFile) return;
-    const fetchFiles = async () => {
-      setLoadingFiles(true);
-      try {
-        const res = await queryEvaluationFilesUsingGet({ taskId: task.id, page: filePagination.current, size: filePagination.pageSize });
-        const data = res?.data;
-        const list: EvalFile[] = data?.content || [];
-        setFiles(list);
-        setFilePagination((p) => ({ ...p, total: data?.totalElements || 0 }));
-      } catch (e) {
-        message.error('加载评估文件失败');
-        console.error(e);
-      } finally {
-        setLoadingFiles(false);
-      }
-    };
-    fetchFiles();
-    // eslint-disable-next-line react-hooks/exhaustive-deps
-  }, [task?.id, filePagination.current, filePagination.pageSize, selectedFile]);
+  // 文件列表数据(使用 useFetchData),pageOffset=0 表示后端分页为 1 基
+  const {
+    loading: loadingFiles,
+    tableData: files,
+    pagination: filePagination,
+    setSearchParams: setFileSearchParams,
+  } = useFetchData<EvalFile>(
+    (params) => queryEvaluationFilesUsingGet({ taskId: task?.id, ...params }),
+    (d) => d as unknown as EvalFile,
+    30000,
+    false,
+    [],
+    0
+  );

-  // Fetch items of selected file
-  useEffect(() => {
-    if (!task?.id || !selectedFile) return;
-    const fetchItems = async () => {
-      setLoadingItems(true);
-      try {
-        const res = await queryEvaluationItemsUsingGet({
-          taskId: task.id,
-          page: itemPagination.current,
-          size: itemPagination.pageSize,
-          file_id: selectedFile.fileId,
-        });
-        const data = res?.data;
-        const list: EvalItem[] = data?.content || [];
-        setItems(list);
-        setItemPagination((p) => ({ ...p, total: data?.totalElements || 0 }));
-      } catch (e) {
-        message.error('加载评估条目失败');
-        console.error(e);
-      } finally {
-        setLoadingItems(false);
+  // 评估条目数据(使用 useFetchData),依赖选中文件
+  const {
+    loading: loadingItems,
+    tableData: items,
+    pagination: itemPagination,
+    setSearchParams: setItemSearchParams,
+    fetchData: fetchItems,
+  } = useFetchData<EvalItem>(
+    (params) => {
+      if (!task?.id || !selectedFile?.fileId) {
+        return Promise.resolve({ data: { content: [], totalElements: 0 } });
       }
-    };
+      return queryEvaluationItemsUsingGet({ taskId: task.id, file_id: selectedFile.fileId, ...params });
+    },
+    (d) => d as unknown as EvalItem,
+    30000,
+    false,
+    [],
+    0
+  );
+
+  // 当选择文件变化时,主动触发一次条目查询,避免仅依赖 searchParams 变更导致未触发
+  useEffect(() => {
+    if (task?.id && selectedFile?.fileId) {
+      setItemSearchParams((prev: any) => ({ ...prev, current: 1 }));
+      // 立即拉取一次,保证点击后立刻出现数据
       fetchItems();
-    // eslint-disable-next-line react-hooks/exhaustive-deps
-  }, [task?.id, selectedFile?.fileId, itemPagination.current, itemPagination.pageSize]);
+    }
+  }, [task?.id, selectedFile?.fileId]);

   const fileColumns = [
     {
@@ -228,19 +218,20 @@ export default function EvaluationItems({ task }: { task: any }) {
             dataSource={files}
             loading={loadingFiles}
             size="middle"
-            onRow={(record) => ({ onClick: () => setSelectedFile({ fileId: record.fileId, fileName: record.fileName }) })}
-            pagination={{
-              current: filePagination.current,
-              pageSize: filePagination.pageSize,
-              total: filePagination.total,
-              onChange: (current, pageSize) => setFilePagination({ current, pageSize, total: filePagination.total }),
-            }}
+            onRow={(record) => ({
+              onClick: () => {
+                setSelectedFile({ fileId: record.fileId, fileName: record.fileName });
+                // 切换文件时,重置条目表到第一页
+                setItemSearchParams((prev: any) => ({ ...prev, current: 1 }));
+              },
+            })}
+            pagination={filePagination}
           />
         ) : (
           <div className="flex flex-col gap-3">
             <div className="sticky top-0 z-10 bg-white py-2" style={{ borderBottom: '1px solid #f0f0f0' }}>
               <Space wrap>
-                <Button icon={<ArrowLeft size={16} />} onClick={() => { setSelectedFile(null); setItems([]); }}>
+                <Button icon={<ArrowLeft size={16} />} onClick={() => { setSelectedFile(null); }}>
                   返回文件列表
                 </Button>
                 <Space>
@@ -257,12 +248,7 @@ export default function EvaluationItems({ task }: { task: any }) {
             dataSource={items}
             loading={loadingItems}
             size="middle"
-            pagination={{
-              current: itemPagination.current,
-              pageSize: itemPagination.pageSize,
-              total: itemPagination.total,
-              onChange: (current, pageSize) => setItemPagination({ current, pageSize, total: itemPagination.total }),
-            }}
+            pagination={itemPagination}
           />
         </div>
       )}

@@ -82,6 +82,7 @@ export default function DataEvaluationPage() {
       label: '任务类型',
       options: [
         { value: 'QA', label: 'QA评估' },
+        { value: 'COT', label: 'COPT评估' },
       ],
     },
     {
@@ -89,7 +90,6 @@ export default function DataEvaluationPage() {
       label: '评估方式',
       options: [
         { value: 'AUTO', label: '自动评估' },
-        { value: 'MANUAL', label: '人工评估' },
       ],
     },
   ];

@@ -32,6 +32,7 @@ class EvaluationTask(Base):
     source_id = Column(String(36), nullable=True, comment="待评估对象ID")
     source_name = Column(String(255), nullable=True, comment="待评估对象名称")
     status = Column(String(50), server_default="PENDING", nullable=False, comment="状态:PENDING/RUNNING/COMPLETED/STOPPED/FAILED")
+    eval_method = Column(String(50), server_default="AUTO", nullable=False, comment="评估方式:AUTO/MANUAL")
     eval_process = Column(Float, nullable=False, server_default="0", comment="评估进度")
     eval_prompt = Column(Text, nullable=True, comment="评估提示词")
     eval_config = Column(Text, nullable=True, comment="评估配置")

@@ -1,3 +1,4 @@
+import math
 from sqlalchemy.ext.asyncio import AsyncSession
 from sqlalchemy.future import select
 from sqlalchemy import func
@@ -116,7 +117,7 @@ class Service:
             for f in files
         ]

-        total_pages = (total + size - 1) // size if size > 0 else 0
+        total_pages = math.ceil(total / size) if total > 0 else 0

         return PagedDatasetFileResponse(
             content=content,

@@ -1,5 +1,6 @@
 import asyncio
 import uuid
+import math
 import json
 from typing import Optional
 from fastapi import APIRouter, Depends, HTTPException, Query, BackgroundTasks
@@ -171,7 +172,7 @@ async def list_evaluation_tasks(

     # 转换为响应模型
     items = [_map_to_task_detail_response(task) for task in tasks]
-    total_pages = (total + size - 1) // size if size > 0 else 0
+    total_pages = math.ceil(total / size) if total > 0 else 0

     return StandardResponse(
         code=200,
@@ -217,7 +218,7 @@ async def list_evaluation_items(
     count_query = select(func.count()).select_from(query.subquery())
     total = (await db.execute(count_query)).scalar_one()
     files = (await db.execute(query.offset(offset).limit(size))).scalars().all()
-    total_pages = (total + size - 1) // size if size > 0 else 0
+    total_pages = math.ceil(total / size) if total > 0 else 0
     file_responses = [
         EvaluationFileResponse(
             taskId=file.task_id,
@@ -298,7 +299,7 @@ async def list_evaluation_items(
             taskId=item.task_id,
             itemId=item.item_id,
             fileId=item.file_id,
-            evalContent=json.loads(item.eval_content),
+            evalContent=json.loads(item.eval_content) if item.eval_content else None,
             evalScore=float(item.eval_score) if item.eval_score else None,
             evalResult=json.loads(item.eval_result),
             status=item.status
@@ -306,7 +307,7 @@ async def list_evaluation_items(
         for item in items
     ]

-    total_pages = (total + size - 1) // size if size > 0 else 0
+    total_pages = math.ceil(total / size) if total > 0 else 0

     return StandardResponse(
         code=200,
@@ -387,6 +388,12 @@ async def delete_eval_tasks(
         .where(EvaluationItem.task_id == task_id)
     )

+    # 删除评估文件
+    await db.execute(
+        EvaluationFile.__table__.delete()
+        .where(EvaluationFile.task_id == task_id)
+    )
+
     # 删除任务
     await db.delete(task)
     await db.commit()
@@ -419,6 +426,7 @@ def _map_to_task_detail_response(
         sourceId=task.source_id,
         sourceName=task.source_name,
         status=task.status,
+        evalMethod=task.eval_method,
         evalProcess=task.eval_process,
         evalPrompt=task.eval_prompt,
         evalConfig=json.loads(task.eval_config),

@@ -36,6 +36,7 @@ class EvaluationTaskItem(BaseModel):
     source_id: Optional[str] = Field(..., alias="sourceId", description="数据源ID")
     source_name: Optional[str] = Field(None, alias="sourceName", description="数据源名称")
     status: TaskStatus = Field(..., description="任务状态")
+    eval_method: Optional[str] = Field(None, alias="evalMethod", description="评估方式")
     eval_process: Optional[float] = Field(0, alias="evalProcess", description="评估进度")
     created_at: Optional[str] = Field(None, alias="createdAt", description="创建时间")
     updated_at: Optional[str] = Field(None, alias="updatedAt", description="更新时间")

@@ -1,3 +1,7 @@
+from app.core.logging import get_logger
+
+logger = get_logger(__name__)
+
 EVALUATION_PROMPT_TEMPLATE = [
     {
         "evalType": "QA",
@@ -51,26 +55,90 @@ EVALUATION_PROMPT_TEMPLATE = [
 请按照以下JSON格式输出评估结果,评估结果为Y/N,符合标注输出Y,不符合标准输出N:

 {
-    "result": {{result_example}
+    "result": {
+        {result_example}
     },
     "evaluation": "这是一个高质量的问答数据集。问题表述清晰具体,答案准确完整且逻辑性强,与原始文本高度相关。建议:可以进一步丰富答案的细节描述。"
 }
+"""
+    },
+    {
+        "evalType": "COT",
+        "defaultDimensions": [
+            {
+                "dimension": "思维链逻辑是否连贯",
+                "description": "分析思维链中推理链条的连续性:步骤间有明确的逻辑连接词;每一步都是基于前置在步骤的结果;没有逻辑跳跃或断层;推理方向一致,不偏离目标。"
+            },
+            {
+                "dimension": "推理步骤是否合理必要",
+                "description": "分析思维链中对于步骤分解的合理性和必要性:复杂问题被适当分解; 每个步骤都是解决整体问题的必要部分;步骤粒度适中(既不过细也不过粗);符合人类认知习惯。"
+            },
+            {
+                "dimension": "内容是否准确",
+                "description": "分析整个COT数据内容是否准确:所有陈述的事实必须准确;展示每一步的计算结果(如何涉及数学计算,必须保证数学计算无错误);逻辑推导有效且合理,最终答案与推理过程一致。"
+            }
+        ],
+        "prompt": """
+# Role: COT数据质量评估专家
+## Profile:
+- Description: 你是一名专业的Chain-of-Thought(CoT)推理数据质量评估专家,擅长从多个维度对COT数据进行质量评估,挑选出有助于模型学习如何分解问题、展示推理链条,提高模型对于复杂问题解决能力的COT数据。具备深度学习、自然语言处理和数据科学的专业背景。
+
+## Skills:
+1. 能够从多个维度对COT数据进行综合评估,保证客观、专业、细致
+2. 擅长识别COT数据中的潜在问题,如推包含事实性错误(关键信息错误),存在严重逻辑矛(无法自洽),包含有害、偏见或不当内容,完全偏离主题,抄袭或高度重复内容等
+3. 能够给出具体的改进建议和质量评分,并提供可操作的优化方案
+
+## 评估维度:
+{dimensions}
+
+## 问题或指令:
+{question}
+
+## 思维链:
+{chain_of_thought}
+
+## 结论:
+{conclusion}
+
+## 注意事项:
+- 评估结论要具体指出优点和不足,提供可操作的改进建议
+- 评估结论控制在150字以内,简洁明了但要涵盖关键信息
+
+## 输出要求:
+请按照以下JSON格式输出评估结果,评估结果为Y/N,符合标注输出Y,不符合标准输出N;将评估结论写到evaluation中:
+
+{
+    "result": {
+        {result_example}
+    },
+    "evaluation": "这是一个高质量的COT数据。思维链逻辑连贯,推理步骤合理,信息完整。建议:部分表达可以进一步优化,以及个别步骤的过渡可以更加平滑。"
+}
 """
     }
 ]

 def get_dimensions_for_qa(dimensions: list[dict]) -> str:
-    dimensions_str = "\n"
+    dimensions_str = ""
     index = 1
     for dimension in dimensions:
-        dimensions_str += f"### {index}. {dimension.get("dimension")}\n**评估标准:**\n{dimension.get("description")}\n\n"
+        if index > 1:
+            dimensions_str += "\n"
+        dimensions_str += f"### {index}. {dimension.get("dimension")}\n**评估标准:**\n{dimension.get("description")}"
+        if index < len(dimensions):
+            dimensions_str += "\n"
         index += 1
     return dimensions_str

 def get_result_example_for_qa(dimensions: list[dict]) -> str:
     result_example = ""
+    index = 1
     for dimension in dimensions:
-        result_example += f'\n "{dimension.get("dimension")}": "Y",'
+        if index > 1:
+            result_example += "\n "
+        result_example += f'"{dimension.get("dimension")}": "Y"'
+        if index < len(dimensions):
+            result_example += ","
+        index += 1
     return result_example

 def get_prompt(task_type: str, dimensions: list[dict]) -> str:

@@ -1,7 +1,7 @@
 """
 Schema for evaluation prompt templates.
 """
-from typing import List, Dict, Any
+from typing import List
 from pydantic import BaseModel, Field



@@ -2,7 +2,7 @@ import json
 import uuid
 import asyncio

-from sqlalchemy import select
+from sqlalchemy import select, func
 from sqlalchemy.ext.asyncio import AsyncSession

 from app.core.exception import BusinessErrorCodeEnum, BusinessException
@@ -13,7 +13,7 @@ from app.db.models.data_synthesis import DataSynthesisFileInstance, SynthesisDat
 from app.db.session import AsyncSessionLocal
 from app.module.evaluation.schema.evaluation import SourceType
 from app.module.shared.schema import TaskStatus
-from app.module.shared.util.model_chat import call_openai_style_model
+from app.module.shared.util.model_chat import call_openai_style_model, _extract_json_substring
 from app.module.evaluation.schema.prompt import get_prompt
 from app.module.shared.util.structured_file import StructuredFileHandlerFactory
 from app.module.system.service.common_service import get_model_by_id
@@ -35,6 +35,10 @@ class EvaluationExecutor:
         prompt_text = ((prompt_text.replace("{content}", eval_content.get("input"))
                         .replace("{question}", eval_content.get("instruction")))
                        .replace("{answer}", eval_content.get("output")))
+        if self.task.task_type == "COT":
+            prompt_text = ((prompt_text.replace("{question}", eval_content.get("question"))
+                            .replace("{conclusion}", eval_content.get("conclusion")))
+                           .replace("{chain_of_thought}", eval_content.get("chain_of_thought")))
         return prompt_text

     async def execute(self):
@@ -44,29 +48,44 @@ class EvaluationExecutor:
         files = (await self.db.execute(
             select(EvaluationFile).where(EvaluationFile.task_id == self.task.id)
         )).scalars().all()
+        query = select(EvaluationItem).where(EvaluationItem.task_id == self.task.id)
+        count_query = select(func.count()).select_from(query.subquery())
+        total = (await self.db.execute(count_query)).scalar_one()
+        evaluated_count = 0
         for file in files:
-            items = (await self.db.execute(
-                select(EvaluationItem).where(EvaluationItem.task_id == self.task.id)
-                .where(EvaluationItem.file_id == file.file_id)
-            )).scalars().all()
+            items = (await self.db.execute(query.where(EvaluationItem.file_id == file.file_id))).scalars().all()
             tasks = [
                 self.evaluate_item(model_config, item, semaphore)
                 for item in items
             ]
             await asyncio.gather(*tasks, return_exceptions=True)
             file.evaluated_count = len(items)
+            evaluated_count += file.evaluated_count
+            self.task.eval_process = evaluated_count / total
             await self.db.commit()

     async def evaluate_item(self, model_config, item: EvaluationItem, semaphore: asyncio.Semaphore):
         async with semaphore:
+            max_try = 3
+            while max_try > 0:
                prompt_text = self.get_eval_prompt(item)
                resp_text = await asyncio.to_thread(
                    call_openai_style_model, model_config.base_url, model_config.api_key, model_config.model_name,
                    prompt_text,
                )
+                resp_text = _extract_json_substring(resp_text)
+                try:
+                    json.loads(resp_text)
+                except Exception as e:
+                    logger.error(
+                        f"Failed to parse LLM answer as JSON for task={self.task.id}, file={item.file_id}: {e}. Raw answer: {resp_text!r}"
+                    )
+                    max_try -= 1
+                    continue
                item.eval_result = resp_text
                item.status = TaskStatus.COMPLETED.value
                await self.db.commit()
+                return


     def get_source_type(self) -> SourceType:
@@ -119,7 +138,7 @@ class SynthesisEvaluationExecutor(EvaluationExecutor):

     async def save_eval_items(self):
         synthesis_files = ((await self.db.execute(select(DataSynthesisFileInstance)
-                            .where(DataSynthesisFileInstance.task_id == self.task.source_id)))
+                            .where(DataSynthesisFileInstance.synthesis_instance_id == self.task.source_id)))
                            .scalars().all())
         for synthesis_file in synthesis_files:
             synthesis_datas = ((await self.db.execute(select(SynthesisData)
@@ -132,7 +151,7 @@ class SynthesisEvaluationExecutor(EvaluationExecutor):
                     task_id=self.task.id,
                     file_id=synthesis_file.id,
                     item_id=synthesis_data.id,
-                    eval_content=synthesis_data.data,
+                    eval_content=json.dumps(synthesis_data.data),
                     status=TaskStatus.PENDING.value,
                     created_by=self.task.created_by,
                     updated_by=self.task.updated_by,

@@ -28,6 +28,7 @@ from app.db.models.data_synthesis import (
 from app.db.models.dataset_management import DatasetFiles
 from app.db.models.model_config import get_model_by_id
 from app.db.session import logger
+from app.module.shared.util.model_chat import _extract_json_substring
 from app.module.system.service.common_service import get_chat_client, chat


@@ -365,7 +366,7 @@ class GenerationService:
             return

         # 1. 预处理原始回答:尝试从中截取出最可能的 JSON 片段
-        cleaned = self._extract_json_substring(raw_answer)
+        cleaned = _extract_json_substring(raw_answer)

         # 2. 解析 JSON,统一成列表结构
         try:
@@ -426,45 +427,6 @@ class GenerationService:
         await self.db.commit()
         await self.db.refresh(file_instance)

-    @staticmethod
-    def _extract_json_substring(raw: str) -> str:
-        """从 LLM 的原始回答中提取最可能的 JSON 字符串片段。
-
-        处理思路:
-        - 原始回答可能是:说明文字 + JSON + 说明文字,甚至带有 Markdown 代码块。
-        - 优先在文本中查找第一个 '{' 或 '[' 作为 JSON 起始;
-        - 再从后向前找最后一个 '}' 或 ']' 作为结束;
-        - 如果找不到合适的边界,就退回原始字符串。
-        该方法不会保证截取的一定是合法 JSON,但能显著提高 json.loads 的成功率。
-        """
-        if not raw:
-            return raw
-
-        start = None
-        end = None
-
-        # 查找第一个 JSON 起始符号
-        for i, ch in enumerate(raw):
-            if ch in "[{":
-                start = i
-                break
-
-        # 查找最后一个 JSON 结束符号
-        for i in range(len(raw) - 1, -1, -1):
-            if raw[i] in "]}":
-                end = i + 1  # 切片是左闭右开
-                break
-
-        if start is not None and end is not None and start < end:
-            return raw[start:end].strip()
-
-        # 兜底:去掉常见 Markdown 包裹(```json ... ```)
-        stripped = raw.strip()
-        if stripped.startswith("```"):
-            # 去掉首尾 ``` 标记
-            stripped = stripped.strip("`")
-        return stripped
-
     async def _get_or_create_file_instance(
         self,
         synthesis_task_id: str,

@@ -13,3 +13,41 @@ def call_openai_style_model(base_url, api_key, model_name, prompt, **kwargs):
         **kwargs
     )
     return response.choices[0].message.content
+
+def _extract_json_substring(raw: str) -> str:
+    """从 LLM 的原始回答中提取最可能的 JSON 字符串片段。
+
+    处理思路:
+    - 原始回答可能是:说明文字 + JSON + 说明文字,甚至带有 Markdown 代码块。
+    - 优先在文本中查找第一个 '{' 或 '[' 作为 JSON 起始;
+    - 再从后向前找最后一个 '}' 或 ']' 作为结束;
+    - 如果找不到合适的边界,就退回原始字符串。
+    该方法不会保证截取的一定是合法 JSON,但能显著提高 json.loads 的成功率。
+    """
+    if not raw:
+        return raw
+
+    start = None
+    end = None
+
+    # 查找第一个 JSON 起始符号
+    for i, ch in enumerate(raw):
+        if ch in "[{":
+            start = i
+            break
+
+    # 查找最后一个 JSON 结束符号
+    for i in range(len(raw) - 1, -1, -1):
+        if raw[i] in "]}":
+            end = i + 1  # 切片是左闭右开
+            break
+
+    if start is not None and end is not None and start < end:
+        return raw[start:end].strip()
+
+    # 兜底:去掉常见 Markdown 包裹(```json ... ```)
+    stripped = raw.strip()
+    if stripped.startswith("```"):
+        # 去掉首尾 ``` 标记
+        stripped = stripped.strip("`")
+    return stripped

@@ -5,6 +5,7 @@ from jsonschema import validate

 class ItemTypes(Enum):
     QA = "QA"
+    COT = "COT"


 class StructuredFileItemHandler:
@@ -14,11 +15,26 @@ class StructuredFileItemHandler:
     def get_item_type(self) -> ItemTypes:
         pass

-    def get_items_from_file(self, file_path: str) -> list[dict]:
+    def validate_json(self, data):
         pass

-    def check_file(self) -> bool:
-        pass
+    def get_items_from_file(self, file_path: str) -> list[dict]:
+        file_type = file_path.split(".")[-1].upper()
+        items = []
+        if file_type == "JSON":
+            with open(file_path, "r", encoding="utf-8") as f:
+                data = json.load(f)
+                if not self.validate_json(data):
+                    return items
+                items = data
+        elif file_type == "JSONL":
+            with open(file_path, "r", encoding="utf-8") as f:
+                for line in f:
+                    data = json.loads(line)
+                    if not self.validate_json(data):
+                        continue
+                    items.append(data)
+        return items

 class QAItemHandler(StructuredFileItemHandler):
     def __init__(self):
@@ -51,32 +67,44 @@ class QAItemHandler(StructuredFileItemHandler):
         except Exception as e:
             return False

-    def get_items_from_file(self, file_path: str) -> list[dict]:
-        file_type = file_path.split(".")[-1].upper()
-        items = []
-        if file_type == "JSON":
-            with open(file_path, "r", encoding="utf-8") as f:
-                data = json.load(f)
-                if not self.validate_json(data):
-                    return items
-                items = data
-        elif file_type == "JSONL":
-            with open(file_path, "r", encoding="utf-8") as f:
-                for line in f:
-                    data = json.loads(line)
-                    if not self.validate_json(data):
-                        continue
-                    items.append(data)
-        return items
-
-    def check_file(self) -> bool:
-        pass
+class COTItemHandler(StructuredFileItemHandler):
+    def __init__(self):
+        self.schema = {
+            "type": "object",
+            "properties": {
+                "question": {"type": "string"},
+                "conclusion": {"type": "string"},
+                "chain_of_thought": {"type": "string"}
+            },
+            "required": ["question", "conclusion", "chain_of_thought"],
+        }
+        self.schema_list = {
+            "type": "array",
+            "items": self.schema,
+        }
+        super().__init__()
+
+    def get_item_type(self):
+        return ItemTypes.COT
+
+    def validate_json(self, data):
+        try:
+            validate(instance=data, schema=self.schema)
+            return True
+        except Exception as e:
+            try:
+                validate(instance=data, schema=self.schema_list)
+                return True
+            except Exception as e:
+                return False


 class StructuredFileHandlerFactory:
     def __init__(self):
         self.handlers: list[StructuredFileItemHandler] = []
         self.handlers.append(QAItemHandler())
+        self.handlers.append(COTItemHandler())

     def get_handler(self, item_type: str) -> StructuredFileItemHandler:
         for handler in self.handlers:

@@ -12,7 +12,7 @@ CREATE TABLE IF NOT EXISTS t_de_eval_task (
     source_id VARCHAR(36) COMMENT '待评估对象ID',
    source_name VARCHAR(255) COMMENT '待评估对象名称',
    status VARCHAR(50) DEFAULT 'PENDING' COMMENT '状态:PENDING/RUNNING/COMPLETED/STOPPED/FAILED',
-    eval_method VARCHAR(50) DEFAULT 'AUTO' COMMENT '状态:AUTO/MANUAL',
+    eval_method VARCHAR(50) DEFAULT 'AUTO' COMMENT '评估方式:AUTO/MANUAL',
    eval_process DOUBLE PRECISION NOT NULL DEFAULT 0 COMMENT '评估进度',
    eval_prompt TEXT COMMENT '评估提示词',
    eval_config TEXT COMMENT '评估配置',