fix: 修复评估时模型输出json格式不对导致读取错误的问题 (#133)

* feature: add cot data evaluation function

* fix: added verification to evaluation results

* fix: fix the prompt for evaluating

* fix: 修复当评估结果为空导致读取失败的问题
This commit is contained in:
hefanli
2025-12-04 18:49:50 +08:00
committed by GitHub
parent 31c4966608
commit 744d15ba24
14 changed files with 373 additions and 219 deletions

View File

@@ -2,7 +2,7 @@ import json
import uuid
import asyncio
from sqlalchemy import select
from sqlalchemy import select, func
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.exception import BusinessErrorCodeEnum, BusinessException
@@ -13,7 +13,7 @@ from app.db.models.data_synthesis import DataSynthesisFileInstance, SynthesisDat
from app.db.session import AsyncSessionLocal
from app.module.evaluation.schema.evaluation import SourceType
from app.module.shared.schema import TaskStatus
from app.module.shared.util.model_chat import call_openai_style_model
from app.module.shared.util.model_chat import call_openai_style_model, _extract_json_substring
from app.module.evaluation.schema.prompt import get_prompt
from app.module.shared.util.structured_file import StructuredFileHandlerFactory
from app.module.system.service.common_service import get_model_by_id
@@ -35,6 +35,10 @@ class EvaluationExecutor:
prompt_text = ((prompt_text.replace("{content}", eval_content.get("input"))
.replace("{question}", eval_content.get("instruction")))
.replace("{answer}", eval_content.get("output")))
if self.task.task_type == "COT":
prompt_text = ((prompt_text.replace("{question}", eval_content.get("question"))
.replace("{conclusion}", eval_content.get("conclusion")))
.replace("{chain_of_thought}", eval_content.get("chain_of_thought")))
return prompt_text
async def execute(self):
@@ -44,29 +48,44 @@ class EvaluationExecutor:
files = (await self.db.execute(
select(EvaluationFile).where(EvaluationFile.task_id == self.task.id)
)).scalars().all()
query = select(EvaluationItem).where(EvaluationItem.task_id == self.task.id)
count_query = select(func.count()).select_from(query.subquery())
total = (await self.db.execute(count_query)).scalar_one()
evaluated_count = 0
for file in files:
items = (await self.db.execute(
select(EvaluationItem).where(EvaluationItem.task_id == self.task.id)
.where(EvaluationItem.file_id == file.file_id)
)).scalars().all()
items = (await self.db.execute(query.where(EvaluationItem.file_id == file.file_id))).scalars().all()
tasks = [
self.evaluate_item(model_config, item, semaphore)
for item in items
]
await asyncio.gather(*tasks, return_exceptions=True)
file.evaluated_count = len(items)
evaluated_count += file.evaluated_count
self.task.eval_process = evaluated_count / total
await self.db.commit()
async def evaluate_item(self, model_config, item: EvaluationItem, semaphore: asyncio.Semaphore):
async with semaphore:
prompt_text = self.get_eval_prompt(item)
resp_text = await asyncio.to_thread(
call_openai_style_model, model_config.base_url, model_config.api_key, model_config.model_name,
prompt_text,
)
item.eval_result = resp_text
item.status = TaskStatus.COMPLETED.value
await self.db.commit()
max_try = 3
while max_try > 0:
prompt_text = self.get_eval_prompt(item)
resp_text = await asyncio.to_thread(
call_openai_style_model, model_config.base_url, model_config.api_key, model_config.model_name,
prompt_text,
)
resp_text = _extract_json_substring(resp_text)
try:
json.loads(resp_text)
except Exception as e:
logger.error(
f"Failed to parse LLM answer as JSON for task={self.task.id}, file={item.file_id}: {e}. Raw answer: {resp_text!r}"
)
max_try -= 1
continue
item.eval_result = resp_text
item.status = TaskStatus.COMPLETED.value
await self.db.commit()
return
def get_source_type(self) -> SourceType:
@@ -119,7 +138,7 @@ class SynthesisEvaluationExecutor(EvaluationExecutor):
async def save_eval_items(self):
synthesis_files = ((await self.db.execute(select(DataSynthesisFileInstance)
.where(DataSynthesisFileInstance.task_id == self.task.source_id)))
.where(DataSynthesisFileInstance.synthesis_instance_id == self.task.source_id)))
.scalars().all())
for synthesis_file in synthesis_files:
synthesis_datas = ((await self.db.execute(select(SynthesisData)
@@ -132,7 +151,7 @@ class SynthesisEvaluationExecutor(EvaluationExecutor):
task_id=self.task.id,
file_id=synthesis_file.id,
item_id=synthesis_data.id,
eval_content=synthesis_data.data,
eval_content=json.dumps(synthesis_data.data),
status=TaskStatus.PENDING.value,
created_by=self.task.created_by,
updated_by=self.task.updated_by,