Files
DataMate/runtime/datamate-python/app/module/evaluation/service/evaluation.py
hefanli 744d15ba24 fix: 修复评估时模型输出json格式不对导致读取错误的问题 (#133)
* feature: add cot data evaluation function

* fix: added verification to evaluation results

* fix: fix the prompt for evaluating

* fix: 修复当评估结果为空导致读取失败的问题
2025-12-04 18:49:50 +08:00

227 lines
10 KiB
Python

import json
import uuid
import asyncio
from sqlalchemy import select, func
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.exception import BusinessErrorCodeEnum, BusinessException
from app.core.logging import get_logger
from app.db.models import EvaluationItem, EvaluationTask, DatasetFiles
from app.db.models.data_evaluation import EvaluationFile
from app.db.models.data_synthesis import DataSynthesisFileInstance, SynthesisData
from app.db.session import AsyncSessionLocal
from app.module.evaluation.schema.evaluation import SourceType
from app.module.shared.schema import TaskStatus
from app.module.shared.util.model_chat import call_openai_style_model, _extract_json_substring
from app.module.evaluation.schema.prompt import get_prompt
from app.module.shared.util.structured_file import StructuredFileHandlerFactory
from app.module.system.service.common_service import get_model_by_id
# Module-level logger used by the evaluation executors and the background worker below.
logger = get_logger(__name__)
class EvaluationExecutor:
    """Base class that runs LLM-based evaluation over the items of one evaluation task.

    Subclasses stage ``EvaluationItem`` / ``EvaluationFile`` rows for a particular
    source (``save_eval_items``) and report which ``SourceType`` they handle
    (``get_source_type``). The shared ``execute`` method sends each item to the model
    configured in the task's ``eval_config`` and stores the JSON answer on the item.
    """

    # Maximum number of items evaluated concurrently against the model.
    max_concurrency = 10
    # Model calls attempted per item before the item is marked FAILED.
    max_attempts = 3

    def __init__(self, db: AsyncSession, task: EvaluationTask):
        self.db = db
        self.task = task

    async def save_eval_items(self):
        """Stage the evaluation items for this task's source; a no-op in the base class."""

    @staticmethod
    def _as_text(value) -> str:
        """Render a field of the item's content as prompt text.

        A missing field (``None``) becomes an empty string instead of making
        ``str.replace`` raise TypeError; non-string values are JSON-encoded.
        """
        if value is None:
            return ""
        if isinstance(value, str):
            return value
        return json.dumps(value, ensure_ascii=False)

    def get_eval_prompt(self, item: EvaluationItem) -> str:
        """Build the evaluation prompt for *item* from the task-type template.

        The template comes from ``get_prompt`` using the task type and the configured
        ``dimensions``. Placeholders are filled from the item's JSON ``eval_content``:
        QA uses input/instruction/output, COT uses question/conclusion/chain_of_thought.
        Other task types get the template unchanged.
        """
        prompt_text = get_prompt(self.task.task_type, json.loads(self.task.eval_config).get("dimensions"))
        eval_content = json.loads(item.eval_content)
        if self.task.task_type == "QA":
            replacements = (
                ("{content}", eval_content.get("input")),
                ("{question}", eval_content.get("instruction")),
                ("{answer}", eval_content.get("output")),
            )
        elif self.task.task_type == "COT":
            replacements = (
                ("{question}", eval_content.get("question")),
                ("{conclusion}", eval_content.get("conclusion")),
                ("{chain_of_thought}", eval_content.get("chain_of_thought")),
            )
        else:
            replacements = ()
        # NOTE(review): substitution is sequential (same order as before), so a placeholder
        # that appears inside an earlier substituted value is replaced as well.
        for placeholder, value in replacements:
            prompt_text = prompt_text.replace(placeholder, self._as_text(value))
        return prompt_text

    async def execute(self):
        """Evaluate every item of the task, file by file, recording progress after each file."""
        eval_config = json.loads(self.task.eval_config)
        model_config = await get_model_by_id(self.db, eval_config.get("model_id"))
        semaphore = asyncio.Semaphore(self.max_concurrency)
        files = (await self.db.execute(
            select(EvaluationFile).where(EvaluationFile.task_id == self.task.id)
        )).scalars().all()
        query = select(EvaluationItem).where(EvaluationItem.task_id == self.task.id)
        count_query = select(func.count()).select_from(query.subquery())
        total = (await self.db.execute(count_query)).scalar_one()
        evaluated_count = 0
        for file in files:
            items = (await self.db.execute(
                query.where(EvaluationItem.file_id == file.file_id)
            )).scalars().all()
            results = await asyncio.gather(
                *(self.evaluate_item(model_config, item, semaphore) for item in items),
                return_exceptions=True,
            )
            # gather() returns exceptions instead of raising; surface them rather than
            # silently leaving the affected items PENDING.
            for item, result in zip(items, results):
                if isinstance(result, Exception):
                    logger.error(
                        f"Evaluation of item {item.id} failed for task={self.task.id}, "
                        f"file={item.file_id}: {result!r}"
                    )
                    item.status = TaskStatus.FAILED.value
            file.evaluated_count = len(items)
            evaluated_count += file.evaluated_count
            # Guard against a task whose files contain no items at all.
            self.task.eval_process = evaluated_count / total if total else 1.0
            # The only commit for this file's results: evaluate_item just mutates ORM
            # objects, so the shared session is never used by concurrent coroutines.
            await self.db.commit()

    async def evaluate_item(self, model_config, item: EvaluationItem, semaphore: asyncio.Semaphore):
        """Ask the model to evaluate *item* and store its JSON answer on the item.

        Up to ``max_attempts`` calls are made while holding *semaphore*. An attempt fails
        if the model call raises or its answer contains no parseable JSON. On success the
        item is marked COMPLETED with ``eval_result`` set; otherwise it is marked FAILED.
        Nothing is committed here; ``execute`` commits once all items of a file finish,
        because the AsyncSession must not be used by concurrent tasks.
        """
        prompt_text = self.get_eval_prompt(item)
        async with semaphore:
            for attempt in range(1, self.max_attempts + 1):
                try:
                    raw_answer = await asyncio.to_thread(
                        call_openai_style_model, model_config.base_url, model_config.api_key,
                        model_config.model_name, prompt_text,
                    )
                except Exception as e:
                    # Model/network errors are transient often enough to be worth a retry.
                    logger.error(
                        f"LLM call failed for task={self.task.id}, file={item.file_id} "
                        f"(attempt {attempt}/{self.max_attempts}): {e}"
                    )
                    continue
                resp_text = _extract_json_substring(raw_answer)
                try:
                    json.loads(resp_text)
                except (TypeError, ValueError) as e:
                    logger.error(
                        f"Failed to parse LLM answer as JSON for task={self.task.id}, file={item.file_id}: {e}. Raw answer: {resp_text!r}"
                    )
                    continue
                item.eval_result = resp_text
                item.status = TaskStatus.COMPLETED.value
                return
            # Every attempt failed: record that instead of leaving the item PENDING.
            item.status = TaskStatus.FAILED.value

    def get_source_type(self) -> SourceType:
        """Return the SourceType handled by this executor; None in the base class."""
class DatasetEvaluationExecutor(EvaluationExecutor):
    """Evaluation executor whose items come from the JSON/JSONL files of a dataset."""

    # File types (compared upper-cased) whose records are parsed into evaluation items.
    SUPPORTED_FILE_TYPES = frozenset({"JSON", "JSONL"})

    def __init__(self, db: AsyncSession, task: EvaluationTask):
        super().__init__(db, task)

    async def save_eval_items(self):
        """Stage one EvaluationItem per parsed record and one EvaluationFile per parsed file.

        Reads the files of dataset ``self.task.source_id``, skipping any whose
        ``file_type`` is not JSON/JSONL. Records are parsed with the structured-file
        handler for the task type. Rows are added to the session but not committed here.
        """
        dataset_files = (await self.db.execute(
            select(DatasetFiles).where(DatasetFiles.dataset_id == self.task.source_id)
        )).scalars().all()
        handler = StructuredFileHandlerFactory().get_handler(self.task.task_type)
        for dataset_file in dataset_files:
            # Treat a missing file_type as unsupported instead of crashing the whole task.
            if (dataset_file.file_type or "").upper() not in self.SUPPORTED_FILE_TYPES:
                continue
            items = handler.get_items_from_file(dataset_file.file_path)
            logger.info(f"parse {len(items)} items from file {dataset_file.file_name}")
            for item in items:
                self.db.add(EvaluationItem(
                    id=str(uuid.uuid4()),
                    task_id=self.task.id,
                    file_id=dataset_file.id,
                    # Keep the record's own id when it has a truthy one; otherwise generate one.
                    item_id=item.get("id") or str(uuid.uuid4()),
                    eval_content=json.dumps(item, ensure_ascii=False),
                    status=TaskStatus.PENDING.value,
                    created_by=self.task.created_by,
                    updated_by=self.task.updated_by,
                ))
            self.db.add(EvaluationFile(
                id=str(uuid.uuid4()),
                task_id=self.task.id,
                file_id=dataset_file.id,
                file_name=dataset_file.file_name,
                total_count=len(items),
                evaluated_count=0,
                created_by=self.task.created_by,
                updated_by=self.task.updated_by,
            ))

    def get_source_type(self) -> SourceType:
        """This executor handles dataset-backed evaluation tasks."""
        return SourceType.DATASET
class SynthesisEvaluationExecutor(EvaluationExecutor):
    """Evaluation executor whose items come from the records of a data-synthesis instance."""

    def __init__(self, db: AsyncSession, task: EvaluationTask):
        super().__init__(db, task)

    async def save_eval_items(self):
        """Stage one EvaluationItem per synthesized record and one EvaluationFile per file.

        Reads every DataSynthesisFileInstance of synthesis instance ``self.task.source_id``
        and the SynthesisData rows of each. Rows are added to the session but not
        committed here.
        """
        synthesis_files = (await self.db.execute(
            select(DataSynthesisFileInstance)
            .where(DataSynthesisFileInstance.synthesis_instance_id == self.task.source_id)
        )).scalars().all()
        for synthesis_file in synthesis_files:
            synthesis_datas = (await self.db.execute(
                select(SynthesisData)
                .where(SynthesisData.synthesis_file_instance_id == synthesis_file.id)
            )).scalars().all()
            logger.info(f"get {len(synthesis_datas)} items from file {synthesis_file.file_name}")
            for synthesis_data in synthesis_datas:
                self.db.add(EvaluationItem(
                    id=str(uuid.uuid4()),
                    task_id=self.task.id,
                    file_id=synthesis_file.id,
                    item_id=synthesis_data.id,
                    # ensure_ascii=False, as for dataset items, keeps non-ASCII text readable
                    # instead of storing it as \uXXXX escapes (json.loads yields the same value).
                    eval_content=json.dumps(synthesis_data.data, ensure_ascii=False),
                    status=TaskStatus.PENDING.value,
                    created_by=self.task.created_by,
                    updated_by=self.task.updated_by,
                ))
            self.db.add(EvaluationFile(
                id=str(uuid.uuid4()),
                task_id=self.task.id,
                file_id=synthesis_file.id,
                file_name=synthesis_file.file_name,
                total_count=len(synthesis_datas),
                evaluated_count=0,
                created_by=self.task.created_by,
                updated_by=self.task.updated_by,
            ))

    def get_source_type(self) -> SourceType:
        """This executor handles synthesis-backed evaluation tasks."""
        return SourceType.SYNTHESIS
class EvaluationExecutorFactory:
    """Select the evaluation executor that matches a task's source type."""

    def __init__(self, db: AsyncSession, task: EvaluationTask):
        self.db = db
        # One candidate per supported source, all bound to the same session and task.
        self.executors: list[EvaluationExecutor] = [
            DatasetEvaluationExecutor(db, task),
            SynthesisEvaluationExecutor(db, task),
        ]

    def get_executor(self, source_type: str) -> EvaluationExecutor:
        """Return the first executor whose SourceType value equals *source_type*.

        Raises BusinessException (TASK_TYPE_ERROR) when no executor matches.
        """
        selected = next(
            (
                candidate
                for candidate in self.executors
                if candidate.get_source_type().value == source_type
            ),
            None,
        )
        if selected is None:
            raise BusinessException(BusinessErrorCodeEnum.TASK_TYPE_ERROR.value)
        return selected
class EvaluationTaskService:
    """Entry points for running evaluation tasks in the background."""

    @staticmethod
    async def run_evaluation_task(task_id: str):
        """
        Background worker to run evaluations.

        The work runs in two phases, each in its own session:

        1. Stage the evaluation items for the task's source and mark the task RUNNING.
        2. Evaluate the staged items and mark the task COMPLETED.

        If the task does not exist, nothing is done. If a phase raises, its
        uncommitted changes are rolled back and the task is marked FAILED in a fresh
        session. A failed phase 1 also skips phase 2, so a failure is never
        overwritten with COMPLETED.

        - task_id: id of EvaluationTaskModel
        """
        logger.info(f"Background evaluation worker started add items for task {task_id}")
        if not await EvaluationTaskService._prepare_items(task_id):
            return
        logger.info(f"Background evaluation worker started for task {task_id}")
        await EvaluationTaskService._evaluate_items(task_id)

    @staticmethod
    async def _load_task(session: AsyncSession, task_id: str):
        """Return the EvaluationTask with id *task_id*, or None if there is none."""
        result = await session.execute(select(EvaluationTask).where(EvaluationTask.id == task_id))
        return result.scalar_one_or_none()

    @staticmethod
    async def _mark_failed(task_id: str):
        """Set the task's status to FAILED using a fresh session.

        A fresh session is used because the failing one may be unusable
        (e.g. after a flush error).
        """
        async with AsyncSessionLocal() as session:
            task = await EvaluationTaskService._load_task(session, task_id)
            if task is not None:
                task.status = TaskStatus.FAILED.value
                await session.commit()

    @staticmethod
    async def _prepare_items(task_id: str) -> bool:
        """Phase 1: stage evaluation items and mark the task RUNNING; return True on success."""
        async with AsyncSessionLocal() as session:
            try:
                task = await EvaluationTaskService._load_task(session, task_id)
                if task is None:
                    logger.error(f"Evaluation task {task_id} not found")
                    return False
                executor = EvaluationExecutorFactory(session, task).get_executor(task.source_type)
                await executor.save_eval_items()
                task.status = TaskStatus.RUNNING.value
                await session.commit()
                return True
            except Exception as e:
                logger.error(f"Background worker encountered error for task {task_id}: {e}")
                # Discard partially staged items instead of committing them.
                await session.rollback()
        await EvaluationTaskService._mark_failed(task_id)
        return False

    @staticmethod
    async def _evaluate_items(task_id: str):
        """Phase 2: evaluate the staged items and mark the task COMPLETED."""
        async with AsyncSessionLocal() as session:
            try:
                task = await EvaluationTaskService._load_task(session, task_id)
                if task is None:
                    logger.error(f"Evaluation task {task_id} not found")
                    return
                executor = EvaluationExecutorFactory(session, task).get_executor(task.source_type)
                await executor.execute()
                logger.info(f"Background evaluation worker finished for task {task_id}")
                task.status = TaskStatus.COMPLETED.value
                await session.commit()
                return
            except Exception as e:
                logger.error(f"Background worker encountered error for task {task_id}: {e}")
                await session.rollback()
        await EvaluationTaskService._mark_failed(task_id)