fix: fix read errors caused by malformed JSON in the model output during evaluation (#133)

* feature: add COT data evaluation function
* fix: add verification of evaluation results
* fix: fix the evaluation prompt
* fix: fix read failures when the evaluation result is empty
@@ -2,7 +2,7 @@ import json
 import uuid
 import asyncio
 
-from sqlalchemy import select
+from sqlalchemy import select, func
 from sqlalchemy.ext.asyncio import AsyncSession
 
 from app.core.exception import BusinessErrorCodeEnum, BusinessException
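The func import added above feeds a new progress calculation in execute() further down (select(func.count()).select_from(query.subquery())). Below is a minimal standalone sketch of that counting pattern, assuming a synchronous SQLAlchemy session and a hypothetical items table for brevity; the project itself uses AsyncSession.

# Standalone sketch of the counting pattern; table and session setup are illustrative.
from sqlalchemy import Column, Integer, String, create_engine, func, select
from sqlalchemy.orm import Session, declarative_base

Base = declarative_base()

class Item(Base):  # hypothetical stand-in for EvaluationItem
    __tablename__ = "items"
    id = Column(Integer, primary_key=True)
    task_id = Column(String)

engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as db:
    db.add_all([Item(task_id="t1"), Item(task_id="t1"), Item(task_id="t2")])
    db.commit()
    query = select(Item).where(Item.task_id == "t1")
    count_query = select(func.count()).select_from(query.subquery())
    total = db.execute(count_query).scalar_one()
    print(total)  # 2 -- the denominator later used for eval_process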
@@ -13,7 +13,7 @@ from app.db.models.data_synthesis import DataSynthesisFileInstance, SynthesisData
 from app.db.session import AsyncSessionLocal
 from app.module.evaluation.schema.evaluation import SourceType
 from app.module.shared.schema import TaskStatus
-from app.module.shared.util.model_chat import call_openai_style_model
+from app.module.shared.util.model_chat import call_openai_style_model, _extract_json_substring
 from app.module.evaluation.schema.prompt import get_prompt
 from app.module.shared.util.structured_file import StructuredFileHandlerFactory
 from app.module.system.service.common_service import get_model_by_id
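Note that _extract_json_substring is imported here but its body is not part of this diff. The sketch below is only a guess at what such a helper typically does, namely trimming everything outside the outermost JSON braces or brackets so that answers wrapped in markdown fences or prose can still be parsed; the real function in app.module.shared.util.model_chat may differ.

# Illustrative sketch only -- not the repository's actual implementation.
def _extract_json_substring(text: str) -> str:
    """Return the first balanced JSON object/array found in text, or text unchanged."""
    starts = [i for i in (text.find("{"), text.find("[")) if i != -1]
    if not starts:
        return text
    start = min(starts)
    opener = text[start]
    closer = "}" if opener == "{" else "]"
    depth = 0
    for i in range(start, len(text)):
        if text[i] == opener:
            depth += 1
        elif text[i] == closer:
            depth -= 1
            if depth == 0:
                return text[start:i + 1]
    return text  # unbalanced; let json.loads fail and trigger a retry

print(_extract_json_substring('Answer: ```json\n{"score": 5}\n```'))  # {"score": 5}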
@@ -35,6 +35,10 @@ class EvaluationExecutor:
         prompt_text = ((prompt_text.replace("{content}", eval_content.get("input"))
                         .replace("{question}", eval_content.get("instruction")))
                        .replace("{answer}", eval_content.get("output")))
+        if self.task.task_type == "COT":
+            prompt_text = ((prompt_text.replace("{question}", eval_content.get("question"))
+                            .replace("{conclusion}", eval_content.get("conclusion")))
+                           .replace("{chain_of_thought}", eval_content.get("chain_of_thought")))
         return prompt_text
 
     async def execute(self):
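For reference, a tiny illustration of how the new COT branch fills its placeholders; the template text and eval_content values below are made up, and the real template comes from get_prompt.

# Hypothetical template and item, for illustration only.
template = "Question: {question}\nReasoning: {chain_of_thought}\nConclusion: {conclusion}"
eval_content = {
    "question": "What is 2 + 2?",
    "chain_of_thought": "Adding the two pairs gives 4.",
    "conclusion": "4",
}
prompt_text = (template
               .replace("{question}", eval_content.get("question"))
               .replace("{conclusion}", eval_content.get("conclusion"))
               .replace("{chain_of_thought}", eval_content.get("chain_of_thought")))
print(prompt_text)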
@@ -44,29 +48,44 @@ class EvaluationExecutor:
         files = (await self.db.execute(
             select(EvaluationFile).where(EvaluationFile.task_id == self.task.id)
         )).scalars().all()
+        query = select(EvaluationItem).where(EvaluationItem.task_id == self.task.id)
+        count_query = select(func.count()).select_from(query.subquery())
+        total = (await self.db.execute(count_query)).scalar_one()
+        evaluated_count = 0
         for file in files:
-            items = (await self.db.execute(
-                select(EvaluationItem).where(EvaluationItem.task_id == self.task.id)
-                .where(EvaluationItem.file_id == file.file_id)
-            )).scalars().all()
+            items = (await self.db.execute(query.where(EvaluationItem.file_id == file.file_id))).scalars().all()
             tasks = [
                 self.evaluate_item(model_config, item, semaphore)
                 for item in items
             ]
             await asyncio.gather(*tasks, return_exceptions=True)
             file.evaluated_count = len(items)
+            evaluated_count += file.evaluated_count
+            self.task.eval_process = evaluated_count / total
             await self.db.commit()
 
     async def evaluate_item(self, model_config, item: EvaluationItem, semaphore: asyncio.Semaphore):
         async with semaphore:
-            prompt_text = self.get_eval_prompt(item)
-            resp_text = await asyncio.to_thread(
-                call_openai_style_model, model_config.base_url, model_config.api_key, model_config.model_name,
-                prompt_text,
-            )
-            item.eval_result = resp_text
-            item.status = TaskStatus.COMPLETED.value
-            await self.db.commit()
+            max_try = 3
+            while max_try > 0:
+                prompt_text = self.get_eval_prompt(item)
+                resp_text = await asyncio.to_thread(
+                    call_openai_style_model, model_config.base_url, model_config.api_key, model_config.model_name,
+                    prompt_text,
+                )
+                resp_text = _extract_json_substring(resp_text)
+                try:
+                    json.loads(resp_text)
+                except Exception as e:
+                    logger.error(
+                        f"Failed to parse LLM answer as JSON for task={self.task.id}, file={item.file_id}: {e}. Raw answer: {resp_text!r}"
+                    )
+                    max_try -= 1
+                    continue
+                item.eval_result = resp_text
+                item.status = TaskStatus.COMPLETED.value
+                await self.db.commit()
+                return
 
 
     def get_source_type(self) -> SourceType:
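The retry loop above is the core of this fix: call the model, narrow the reply to its JSON substring, and only persist the result once json.loads accepts it; otherwise log the raw answer and retry, up to three times. A minimal synchronous sketch of the same pattern, with a fake model call standing in for call_openai_style_model (everything below is illustrative):

import json
import logging

logger = logging.getLogger(__name__)

def fake_model_call(prompt: str) -> str:
    # Stand-in for call_openai_style_model; real replies may wrap JSON in prose or code fences.
    return 'Sure! Here is the evaluation:\n```json\n{"score": 4, "reason": "clear reasoning"}\n```'

def extract_json_substring(text: str) -> str:
    # Simplified stand-in for _extract_json_substring.
    start, end = text.find("{"), text.rfind("}")
    return text[start:end + 1] if start != -1 and end > start else text

def evaluate_once(prompt: str, max_try: int = 3):
    while max_try > 0:
        raw = fake_model_call(prompt)
        candidate = extract_json_substring(raw)
        try:
            return json.loads(candidate)  # only well-formed JSON is accepted
        except Exception as e:
            logger.error("Failed to parse LLM answer as JSON: %s. Raw answer: %r", e, raw)
            max_try -= 1
    return None  # caller decides how to mark the item after all retries fail

print(evaluate_once("rate this answer"))  # {'score': 4, 'reason': 'clear reasoning'}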
@@ -119,7 +138,7 @@ class SynthesisEvaluationExecutor(EvaluationExecutor):
 
     async def save_eval_items(self):
         synthesis_files = ((await self.db.execute(select(DataSynthesisFileInstance)
-                            .where(DataSynthesisFileInstance.task_id == self.task.source_id)))
+                            .where(DataSynthesisFileInstance.synthesis_instance_id == self.task.source_id)))
                            .scalars().all())
         for synthesis_file in synthesis_files:
             synthesis_datas = ((await self.db.execute(select(SynthesisData)
@@ -132,7 +151,7 @@ class SynthesisEvaluationExecutor(EvaluationExecutor):
                     task_id=self.task.id,
                     file_id=synthesis_file.id,
                     item_id=synthesis_data.id,
-                    eval_content=synthesis_data.data,
+                    eval_content=json.dumps(synthesis_data.data),
                     status=TaskStatus.PENDING.value,
                     created_by=self.task.created_by,
                     updated_by=self.task.updated_by,
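eval_content is now persisted as a JSON string rather than a raw dict. A small illustration of the round trip this enables; the read side shown here is assumed for illustration and is not part of this diff.

import json

synthesis_row = {"question": "What is 2 + 2?", "chain_of_thought": "Add the pairs.", "conclusion": "4"}

# Write path, as in the diff: serialize the dict before storing it in a text column.
eval_content = json.dumps(synthesis_row, ensure_ascii=False)

# Assumed read path: parse it back before filling the prompt template.
restored = json.loads(eval_content)
print(restored["conclusion"])  # 4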