fix: handle read errors caused by malformed JSON in model output during evaluation (#133)

* feat: add COT data evaluation

* fix: add validation of evaluation results

* fix: correct the evaluation prompt

* fix: handle read failures when the evaluation result is empty
hefanli
2025-12-04 18:49:50 +08:00
committed by GitHub
parent 31c4966608
commit 744d15ba24
14 changed files with 373 additions and 219 deletions
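
The gist of the fix, as a minimal standalone sketch (condensed from `_extract_json_substring` and the retry loop in `evaluate_item` in the diff below; names and control flow are simplified, not verbatim):

import json

def extract_json_substring(raw: str) -> str:
    # Keep the span between the first '{' / '[' and the last '}' / ']'; fall back to the raw text.
    start = next((i for i, ch in enumerate(raw) if ch in "[{"), None)
    end = next((i + 1 for i in range(len(raw) - 1, -1, -1) if raw[i] in "]}"), None)
    if start is not None and end is not None and start < end:
        return raw[start:end].strip()
    return raw.strip()

def parse_model_answer(call_model, prompt: str, max_try: int = 3):
    # Retry until the answer parses as JSON, mirroring the new evaluate_item() loop.
    while max_try > 0:
        cleaned = extract_json_substring(call_model(prompt))
        try:
            return json.loads(cleaned)
        except json.JSONDecodeError:
            max_try -= 1  # malformed output: ask again rather than storing broken JSON
    return None  # caller keeps the item unscored; readers guard with `... if eval_content else None`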

View File

@@ -1,3 +1,4 @@
import math
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlalchemy import func
@@ -14,11 +15,11 @@ logger = get_logger(__name__)
class Service:
"""数据管理服务客户端 - 直接访问数据库"""
def __init__(self, db: AsyncSession):
"""
初始化 DM 客户端
Args:
db: 数据库会话
"""
@@ -29,16 +30,16 @@ class Service:
"""获取数据集详情"""
try:
logger.debug(f"Getting dataset detail: {dataset_id} ...")
result = await self.db.execute(
select(Dataset).where(Dataset.id == dataset_id)
)
dataset = result.scalar_one_or_none()
if not dataset:
logger.error(f"Dataset not found: {dataset_id}")
return None
# 将数据库模型转换为响应模型
# type: ignore 用于忽略 SQLAlchemy 的类型检查问题
return DatasetResponse(
@@ -56,11 +57,11 @@ class Service:
except Exception as e:
logger.error(f"Failed to get dataset {dataset_id}: {e}")
return None
async def get_dataset_files(
self,
dataset_id: str,
page: int = 0,
self,
dataset_id: str,
page: int = 0,
size: int = 100,
file_type: Optional[str] = None,
status: Optional[str] = None
@@ -68,16 +69,16 @@ class Service:
"""获取数据集文件列表"""
try:
logger.debug(f"Get dataset files: dataset={dataset_id}, page={page}, size={size}")
# 构建查询
query = select(DatasetFiles).where(DatasetFiles.dataset_id == dataset_id)
# 添加可选过滤条件
if file_type:
query = query.where(DatasetFiles.file_type == file_type)
if status:
query = query.where(DatasetFiles.status == status)
# 获取总数
count_query = select(func.count()).select_from(DatasetFiles).where(
DatasetFiles.dataset_id == dataset_id
@@ -86,15 +87,15 @@ class Service:
count_query = count_query.where(DatasetFiles.file_type == file_type)
if status:
count_query = count_query.where(DatasetFiles.status == status)
count_result = await self.db.execute(count_query)
total = count_result.scalar_one()
# 分页查询
query = query.offset(page * size).limit(size).order_by(DatasetFiles.created_at.desc())
result = await self.db.execute(query)
files = result.scalars().all()
# 转换为响应模型
# type: ignore 用于忽略 SQLAlchemy 的类型检查问题
content = [
@@ -115,9 +116,9 @@ class Service:
)
for f in files
]
total_pages = (total + size - 1) // size if size > 0 else 0
total_pages = math.ceil(total / size) if total > 0 else 0
return PagedDatasetFileResponse(
content=content,
totalElements=total,
@@ -128,7 +129,7 @@ class Service:
except Exception as e:
logger.error(f"Failed to get dataset files for {dataset_id}: {e}")
return None
async def download_file(self, dataset_id: str, file_id: str) -> Optional[bytes]:
"""
下载文件内容
@@ -136,7 +137,7 @@ class Service:
"""
logger.warning(f"download_file is deprecated when using database mode. Use get_file_download_url instead.")
return None
async def get_file_download_url(self, dataset_id: str, file_id: str) -> Optional[str]:
"""获取文件下载URL(或文件路径)"""
try:
@@ -147,60 +148,60 @@ class Service:
)
)
file = result.scalar_one_or_none()
if not file:
logger.error(f"File not found: {file_id} in dataset {dataset_id}")
return None
# 返回文件路径(可以是本地路径或对象存储URL)
return file.file_path # type: ignore
except Exception as e:
logger.error(f"Failed to get file path for {file_id}: {e}")
return None
async def close(self):
"""关闭客户端连接(数据库模式下无需操作)"""
logger.info("DM service client closed (Database mode)")
async def update_file_tags_partial(
self,
file_id: str,
self,
file_id: str,
new_tags: List[Dict[str, Any]],
template_id: Optional[str] = None
) -> tuple[bool, Optional[str], Optional[datetime]]:
"""
部分更新文件标签,支持自动格式转换
如果提供了 template_id,会自动将简化格式的标签转换为完整格式。
简化格式: {"from_name": "x", "to_name": "y", "values": [...]}
完整格式: {"id": "...", "from_name": "x", "to_name": "y", "type": "...", "value": {"type": [...]}}
Args:
file_id: 文件ID
new_tags: 新的标签列表(部分更新),可以是简化格式或完整格式
template_id: 可选的模板ID,用于格式转换
Returns:
(成功标志, 错误信息, 更新时间)
"""
try:
logger.info(f"Partial updating tags for file: {file_id}")
# 获取文件记录
result = await self.db.execute(
select(DatasetFiles).where(DatasetFiles.id == file_id)
)
file_record = result.scalar_one_or_none()
if not file_record:
logger.error(f"File not found: {file_id}")
return False, f"File not found: {file_id}", None
# 如果提供了 template_id,尝试进行格式转换
processed_tags = new_tags
if template_id:
logger.debug(f"Converting tags using template: {template_id}")
try:
# 获取模板配置
from app.db.models import AnnotationTemplate
@@ -211,29 +212,29 @@ class Service:
)
)
template = template_result.scalar_one_or_none()
if not template:
logger.warning(f"Template {template_id} not found, skipping conversion")
else:
# 使用 converter 转换标签格式
from app.module.annotation.utils import create_converter_from_template_config
converter = create_converter_from_template_config(template.configuration) # type: ignore
processed_tags = converter.convert_if_needed(new_tags)
logger.info(f"Converted {len(new_tags)} tags to full format")
except Exception as e:
logger.error(f"Failed to convert tags using template: {e}")
# 继续使用原始标签格式
logger.warning("Continuing with original tag format")
# 获取现有标签
existing_tags: List[Dict[str, Any]] = file_record.tags or [] # type: ignore
# 创建标签ID到索引的映射
tag_id_map = {tag.get('id'): idx for idx, tag in enumerate(existing_tags) if tag.get('id')}
# 更新或追加标签
for new_tag in processed_tags:
tag_id = new_tag.get('id')
@@ -246,19 +247,19 @@ class Service:
# 追加新标签
existing_tags.append(new_tag)
logger.debug(f"Added new tag with id: {tag_id}")
# 更新数据库
update_time = datetime.utcnow()
file_record.tags = existing_tags # type: ignore
file_record.tags_updated_at = update_time # type: ignore
await self.db.commit()
await self.db.refresh(file_record)
logger.info(f"Successfully updated tags for file: {file_id}")
return True, None, update_time
except Exception as e:
logger.error(f"Failed to update tags for file {file_id}: {e}")
await self.db.rollback()
return False, str(e), None
return False, str(e), None
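
For context, the update-or-append merge that `update_file_tags_partial` performs can be sketched without the DB session (the in-place-update branch is inferred from `tag_id_map`; illustrative only):

from typing import Any, Dict, List

def merge_tags(existing: List[Dict[str, Any]], new_tags: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    # Index existing tags by id, then overwrite matches and append the rest.
    tag_id_map = {tag.get("id"): idx for idx, tag in enumerate(existing) if tag.get("id")}
    for new_tag in new_tags:
        tag_id = new_tag.get("id")
        if tag_id in tag_id_map:
            existing[tag_id_map[tag_id]] = new_tag  # replace the tag that has the same id
        else:
            existing.append(new_tag)                # new or id-less tag: append
    return existing

# merge_tags([{"id": "t1", "v": 1}], [{"id": "t1", "v": 2}, {"id": "t2", "v": 3}])
# -> [{"id": "t1", "v": 2}, {"id": "t2", "v": 3}]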

View File

@@ -1,5 +1,6 @@
import asyncio
import uuid
import math
import json
from typing import Optional
from fastapi import APIRouter, Depends, HTTPException, Query, BackgroundTasks
@@ -171,7 +172,7 @@ async def list_evaluation_tasks(
# 转换为响应模型
items = [_map_to_task_detail_response(task) for task in tasks]
total_pages = (total + size - 1) // size if size > 0 else 0
total_pages = math.ceil(total / size) if total > 0 else 0
return StandardResponse(
code=200,
@@ -217,7 +218,7 @@ async def list_evaluation_items(
count_query = select(func.count()).select_from(query.subquery())
total = (await db.execute(count_query)).scalar_one()
files = (await db.execute(query.offset(offset).limit(size))).scalars().all()
total_pages = (total + size - 1) // size if size > 0 else 0
total_pages = math.ceil(total / size) if total > 0 else 0
file_responses = [
EvaluationFileResponse(
taskId=file.task_id,
@@ -298,7 +299,7 @@ async def list_evaluation_items(
taskId=item.task_id,
itemId=item.item_id,
fileId=item.file_id,
evalContent=json.loads(item.eval_content),
evalContent=json.loads(item.eval_content) if item.eval_content else None,
evalScore=float(item.eval_score) if item.eval_score else None,
evalResult=json.loads(item.eval_result),
status=item.status
@@ -306,7 +307,7 @@ async def list_evaluation_items(
for item in items
]
total_pages = (total + size - 1) // size if size > 0 else 0
total_pages = math.ceil(total / size) if total > 0 else 0
return StandardResponse(
code=200,
@@ -387,6 +388,12 @@ async def delete_eval_tasks(
.where(EvaluationItem.task_id == task_id)
)
# 删除评估文件
await db.execute(
EvaluationFile.__table__.delete()
.where(EvaluationFile.task_id == task_id)
)
# 删除任务
await db.delete(task)
await db.commit()
@@ -419,6 +426,7 @@ def _map_to_task_detail_response(
sourceId=task.source_id,
sourceName=task.source_name,
status=task.status,
evalMethod=task.eval_method,
evalProcess=task.eval_process,
evalPrompt=task.eval_prompt,
evalConfig=json.loads(task.eval_config),

View File

@@ -36,6 +36,7 @@ class EvaluationTaskItem(BaseModel):
source_id: Optional[str] = Field(..., alias="sourceId", description="数据源ID")
source_name: Optional[str] = Field(None, alias="sourceName", description="数据源名称")
status: TaskStatus = Field(..., description="任务状态")
eval_method: Optional[str] = Field(None, alias="evalMethod", description="评估方式")
eval_process: Optional[float] = Field(0, alias="evalProcess", description="评估进度")
created_at: Optional[str] = Field(None, alias="createdAt", description="创建时间")
updated_at: Optional[str] = Field(None, alias="updatedAt", description="更新时间")

View File

@@ -1,3 +1,7 @@
from app.core.logging import get_logger
logger = get_logger(__name__)
EVALUATION_PROMPT_TEMPLATE = [
{
"evalType": "QA",
@@ -51,26 +55,90 @@ EVALUATION_PROMPT_TEMPLATE = [
请按照以下JSON格式输出评估结果,评估结果为Y/N,符合标准输出Y,不符合标准输出N:
{
"result": {{result_example}
"result": {
{result_example}
},
"evaluation": "这是一个高质量的问答数据集。问题表述清晰具体,答案准确完整且逻辑性强,与原始文本高度相关。建议:可以进一步丰富答案的细节描述。"
}
"""
},
{
"evalType": "COT",
"defaultDimensions": [
{
"dimension": "思维链逻辑是否连贯",
"description": "分析思维链中推理链条的连续性:步骤间有明确的逻辑连接词;每一步都是基于前置在步骤的结果;没有逻辑跳跃或断层;推理方向一致,不偏离目标。"
},
{
"dimension": "推理步骤是否合理必要",
"description": "分析思维链中对于步骤分解的合理性和必要性:复杂问题被适当分解; 每个步骤都是解决整体问题的必要部分;步骤粒度适中(既不过细也不过粗);符合人类认知习惯。"
},
{
"dimension": "内容是否准确",
"description": "分析整个COT数据内容是否准确:所有陈述的事实必须准确;展示每一步的计算结果(如何涉及数学计算,必须保证数学计算无错误);逻辑推导有效且合理,最终答案与推理过程一致。"
}
],
"prompt": """
# Role: COT数据质量评估专家
## Profile:
- Description: 你是一名专业的Chain-of-Thought(CoT)推理数据质量评估专家,擅长从多个维度对COT数据进行质量评估,挑选出有助于模型学习如何分解问题、展示推理链条,提高模型对于复杂问题解决能力的COT数据。具备深度学习、自然语言处理和数据科学的专业背景。
## Skills:
1. 能够从多个维度对COT数据进行综合评估,保证客观、专业、细致
2. 擅长识别COT数据中的潜在问题,如包含事实性错误(关键信息错误),存在严重逻辑矛盾(无法自洽),包含有害、偏见或不当内容,完全偏离主题,抄袭或高度重复内容等
3. 能够给出具体的改进建议和质量评分,并提供可操作的优化方案
## 评估维度:
{dimensions}
## 问题或指令:
{question}
## 思维链:
{chain_of_thought}
## 结论:
{conclusion}
## 注意事项:
- 评估结论要具体指出优点和不足,提供可操作的改进建议
- 评估结论控制在150字以内,简洁明了但要涵盖关键信息
## 输出要求:
请按照以下JSON格式输出评估结果,评估结果为Y/N,符合标准输出Y,不符合标准输出N;将评估结论写到evaluation中:
{
"result": {
{result_example}
},
"evaluation": "这是一个高质量的COT数据。思维链逻辑连贯,推理步骤合理,信息完整。建议:部分表达可以进一步优化,以及个别步骤的过渡可以更加平滑。"
}
"""
}
]
def get_dimensions_for_qa(dimensions: list[dict]) -> str:
dimensions_str = "\n"
dimensions_str = ""
index = 1
for dimension in dimensions:
dimensions_str += f"### {index}. {dimension.get("dimension")}\n**评估标准:**\n{dimension.get("description")}\n\n"
if index > 1:
dimensions_str += "\n"
dimensions_str += f"### {index}. {dimension.get("dimension")}\n**评估标准:**\n{dimension.get("description")}"
if index < len(dimensions):
dimensions_str += "\n"
index += 1
return dimensions_str
def get_result_example_for_qa(dimensions: list[dict]) -> str:
result_example = ""
index = 1
for dimension in dimensions:
result_example += f'\n "{dimension.get("dimension")}": "Y",'
if index > 1:
result_example += "\n "
result_example += f'"{dimension.get("dimension")}": "Y"'
if index < len(dimensions):
result_example += ","
index += 1
return result_example
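
For reference, the fixed helper renders the prompt placeholder like this (two dimensions taken from the COT defaults above; descriptions elided and indentation approximate):

dims = [
    {"dimension": "思维链逻辑是否连贯", "description": "..."},
    {"dimension": "内容是否准确", "description": "..."},
]
print(get_result_example_for_qa(dims))
# "思维链逻辑是否连贯": "Y",
#         "内容是否准确": "Y"
# A comma after every entry except the last, so the {result_example} block embedded in the
# prompt's "result": { ... } example is now valid JSON.
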
def get_prompt(task_type: str, dimensions: list[dict]) -> str:

View File

@@ -1,7 +1,7 @@
"""
Schema for evaluation prompt templates.
"""
from typing import List, Dict, Any
from typing import List
from pydantic import BaseModel, Field

View File

@@ -2,7 +2,7 @@ import json
import uuid
import asyncio
from sqlalchemy import select
from sqlalchemy import select, func
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.exception import BusinessErrorCodeEnum, BusinessException
@@ -13,7 +13,7 @@ from app.db.models.data_synthesis import DataSynthesisFileInstance, SynthesisDat
from app.db.session import AsyncSessionLocal
from app.module.evaluation.schema.evaluation import SourceType
from app.module.shared.schema import TaskStatus
from app.module.shared.util.model_chat import call_openai_style_model
from app.module.shared.util.model_chat import call_openai_style_model, _extract_json_substring
from app.module.evaluation.schema.prompt import get_prompt
from app.module.shared.util.structured_file import StructuredFileHandlerFactory
from app.module.system.service.common_service import get_model_by_id
@@ -35,6 +35,10 @@ class EvaluationExecutor:
prompt_text = ((prompt_text.replace("{content}", eval_content.get("input"))
.replace("{question}", eval_content.get("instruction")))
.replace("{answer}", eval_content.get("output")))
if self.task.task_type == "COT":
prompt_text = ((prompt_text.replace("{question}", eval_content.get("question"))
.replace("{conclusion}", eval_content.get("conclusion")))
.replace("{chain_of_thought}", eval_content.get("chain_of_thought")))
return prompt_text
async def execute(self):
@@ -44,29 +48,44 @@ class EvaluationExecutor:
files = (await self.db.execute(
select(EvaluationFile).where(EvaluationFile.task_id == self.task.id)
)).scalars().all()
query = select(EvaluationItem).where(EvaluationItem.task_id == self.task.id)
count_query = select(func.count()).select_from(query.subquery())
total = (await self.db.execute(count_query)).scalar_one()
evaluated_count = 0
for file in files:
items = (await self.db.execute(
select(EvaluationItem).where(EvaluationItem.task_id == self.task.id)
.where(EvaluationItem.file_id == file.file_id)
)).scalars().all()
items = (await self.db.execute(query.where(EvaluationItem.file_id == file.file_id))).scalars().all()
tasks = [
self.evaluate_item(model_config, item, semaphore)
for item in items
]
await asyncio.gather(*tasks, return_exceptions=True)
file.evaluated_count = len(items)
evaluated_count += file.evaluated_count
self.task.eval_process = evaluated_count / total
await self.db.commit()
async def evaluate_item(self, model_config, item: EvaluationItem, semaphore: asyncio.Semaphore):
async with semaphore:
prompt_text = self.get_eval_prompt(item)
resp_text = await asyncio.to_thread(
call_openai_style_model, model_config.base_url, model_config.api_key, model_config.model_name,
prompt_text,
)
item.eval_result = resp_text
item.status = TaskStatus.COMPLETED.value
await self.db.commit()
max_try = 3
while max_try > 0:
prompt_text = self.get_eval_prompt(item)
resp_text = await asyncio.to_thread(
call_openai_style_model, model_config.base_url, model_config.api_key, model_config.model_name,
prompt_text,
)
resp_text = _extract_json_substring(resp_text)
try:
json.loads(resp_text)
except Exception as e:
logger.error(
f"Failed to parse LLM answer as JSON for task={self.task.id}, file={item.file_id}: {e}. Raw answer: {resp_text!r}"
)
max_try -= 1
continue
item.eval_result = resp_text
item.status = TaskStatus.COMPLETED.value
await self.db.commit()
return
def get_source_type(self) -> SourceType:
@@ -119,7 +138,7 @@ class SynthesisEvaluationExecutor(EvaluationExecutor):
async def save_eval_items(self):
synthesis_files = ((await self.db.execute(select(DataSynthesisFileInstance)
.where(DataSynthesisFileInstance.task_id == self.task.source_id)))
.where(DataSynthesisFileInstance.synthesis_instance_id == self.task.source_id)))
.scalars().all())
for synthesis_file in synthesis_files:
synthesis_datas = ((await self.db.execute(select(SynthesisData)
@@ -132,7 +151,7 @@ class SynthesisEvaluationExecutor(EvaluationExecutor):
task_id=self.task.id,
file_id=synthesis_file.id,
item_id=synthesis_data.id,
eval_content=synthesis_data.data,
eval_content=json.dumps(synthesis_data.data),
status=TaskStatus.PENDING.value,
created_by=self.task.created_by,
updated_by=self.task.updated_by,

View File

@@ -28,6 +28,7 @@ from app.db.models.data_synthesis import (
from app.db.models.dataset_management import DatasetFiles
from app.db.models.model_config import get_model_by_id
from app.db.session import logger
from app.module.shared.util.model_chat import _extract_json_substring
from app.module.system.service.common_service import get_chat_client, chat
@@ -365,7 +366,7 @@ class GenerationService:
return
# 1. 预处理原始回答:尝试从中截取出最可能的 JSON 片段
cleaned = self._extract_json_substring(raw_answer)
cleaned = _extract_json_substring(raw_answer)
# 2. 解析 JSON,统一成列表结构
try:
@@ -426,45 +427,6 @@ class GenerationService:
await self.db.commit()
await self.db.refresh(file_instance)
@staticmethod
def _extract_json_substring(raw: str) -> str:
"""从 LLM 的原始回答中提取最可能的 JSON 字符串片段。
处理思路:
- 原始回答可能是:说明文字 + JSON + 说明文字,甚至带有 Markdown 代码块。
- 优先在文本中查找第一个 '{' 或 '[' 作为 JSON 起始;
- 再从后向前找最后一个 '}' 或 ']' 作为结束;
- 如果找不到合适的边界,就退回原始字符串。
该方法不会保证截取的一定是合法 JSON,但能显著提高 json.loads 的成功率。
"""
if not raw:
return raw
start = None
end = None
# 查找第一个 JSON 起始符号
for i, ch in enumerate(raw):
if ch in "[{":
start = i
break
# 查找最后一个 JSON 结束符号
for i in range(len(raw) - 1, -1, -1):
if raw[i] in "]}":
end = i + 1 # 切片是左闭右开
break
if start is not None and end is not None and start < end:
return raw[start:end].strip()
# 兜底:去掉常见 Markdown 包裹(```json ... ```)
stripped = raw.strip()
if stripped.startswith("```"):
# 去掉首尾 ``` 标记
stripped = stripped.strip("`")
return stripped
async def _get_or_create_file_instance(
self,
synthesis_task_id: str,

View File

@@ -13,3 +13,41 @@ def call_openai_style_model(base_url, api_key, model_name, prompt, **kwargs):
**kwargs
)
return response.choices[0].message.content
def _extract_json_substring(raw: str) -> str:
"""从 LLM 的原始回答中提取最可能的 JSON 字符串片段。
处理思路:
- 原始回答可能是:说明文字 + JSON + 说明文字,甚至带有 Markdown 代码块。
- 优先在文本中查找第一个 '{' 或 '[' 作为 JSON 起始;
- 再从后向前找最后一个 '}' 或 ']' 作为结束;
- 如果找不到合适的边界,就退回原始字符串。
该方法不会保证截取的一定是合法 JSON,但能显著提高 json.loads 的成功率。
"""
if not raw:
return raw
start = None
end = None
# 查找第一个 JSON 起始符号
for i, ch in enumerate(raw):
if ch in "[{":
start = i
break
# 查找最后一个 JSON 结束符号
for i in range(len(raw) - 1, -1, -1):
if raw[i] in "]}":
end = i + 1 # 切片是左闭右开
break
if start is not None and end is not None and start < end:
return raw[start:end].strip()
# 兜底:去掉常见 Markdown 包裹(```json ... ```)
stripped = raw.strip()
if stripped.startswith("```"):
# 去掉首尾 ``` 标记
stripped = stripped.strip("`")
return stripped
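
A quick illustration of what the helper tolerates (hypothetical model output):

raw = "模型输出示例:\n```json\n{\"result\": {\"内容是否准确\": \"Y\"}, \"evaluation\": \"推理连贯。\"}\n```\n以上。"
print(_extract_json_substring(raw))
# -> {"result": {"内容是否准确": "Y"}, "evaluation": "推理连贯。"}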

View File

@@ -5,6 +5,7 @@ from jsonschema import validate
class ItemTypes(Enum):
QA = "QA"
COT = "COT"
class StructuredFileItemHandler:
@@ -14,11 +15,26 @@ class StructuredFileItemHandler:
def get_item_type(self) -> ItemTypes:
pass
def get_items_from_file(self, file_path: str) -> list[dict]:
def validate_json(self, data):
pass
def check_file(self) -> bool:
pass
def get_items_from_file(self, file_path: str) -> list[dict]:
file_type = file_path.split(".")[-1].upper()
items = []
if file_type == "JSON":
with open(file_path, "r", encoding="utf-8") as f:
data = json.load(f)
if not self.validate_json(data):
return items
items = data
elif file_type == "JSONL":
with open(file_path, "r", encoding="utf-8") as f:
for line in f:
data = json.loads(line)
if not self.validate_json(data):
continue
items.append(data)
return items
class QAItemHandler(StructuredFileItemHandler):
def __init__(self):
@@ -51,32 +67,44 @@ class QAItemHandler(StructuredFileItemHandler):
except Exception as e:
return False
def get_items_from_file(self, file_path: str) -> list[dict]:
file_type = file_path.split(".")[-1].upper()
items = []
if file_type == "JSON":
with open(file_path, "r", encoding="utf-8") as f:
data = json.load(f)
if not self.validate_json(data):
return items
items = data
elif file_type == "JSONL":
with open(file_path, "r", encoding="utf-8") as f:
for line in f:
data = json.loads(line)
if not self.validate_json(data):
continue
items.append(data)
return items
def check_file(self) -> bool:
pass
class COTItemHandler(StructuredFileItemHandler):
def __init__(self):
self.schema = {
"type": "object",
"properties": {
"question": {"type": "string"},
"conclusion": {"type": "string"},
"chain_of_thought": {"type": "string"}
},
"required": ["question", "conclusion", "chain_of_thought"],
}
self.schema_list = {
"type": "array",
"items": self.schema,
}
super().__init__()
def get_item_type(self):
return ItemTypes.COT
def validate_json(self, data):
try:
validate(instance=data, schema=self.schema)
return True
except Exception as e:
try:
validate(instance=data, schema=self.schema_list)
return True
except Exception as e:
return False
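
Given the two schemas above, both a single COT object and a list of them pass validation (illustrative items):

handler = COTItemHandler()
item = {"question": "1+1等于几?", "chain_of_thought": "1 加 1,逐位相加得 2。", "conclusion": "2"}
print(handler.validate_json(item))                    # True  — matches self.schema
print(handler.validate_json([item, item]))            # True  — matches self.schema_list
print(handler.validate_json({"question": "缺字段"}))   # False — required keys missing
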
class StructuredFileHandlerFactory:
def __init__(self):
self.handlers: list[StructuredFileItemHandler] = []
self.handlers.append(QAItemHandler())
self.handlers.append(COTItemHandler())
def get_handler(self, item_type: str) -> StructuredFileItemHandler:
for handler in self.handlers: