Files
DataMate/runtime/datamate-python/app/module/evaluation/service/evaluation.py
Dallas98 e0e9b1d94d feat:问题生成过程优化及COT数据生成优化 (#169)
* fix(chart): update Helm chart helpers and values for improved configuration

* feat(SynthesisTaskTab): enhance task table with tooltip support and improved column widths

* feat(CreateTask, SynthFileTask): improve task creation and detail view with enhanced payload handling and UI updates

* feat(SynthFileTask): enhance file display with progress tracking and delete action

* feat(SynthFileTask): enhance file display with progress tracking and delete action

* feat(SynthDataDetail): add delete action for chunks with confirmation prompt

* feat(SynthDataDetail): update edit and delete buttons to icon-only format

* feat(SynthDataDetail): add confirmation modals for chunk and synthesis data deletion

* feat(DocumentSplitter): add enhanced document splitting functionality with CJK support and metadata detection

* feat(DataSynthesis): refactor data synthesis models and update task handling logic

* feat(DataSynthesis): streamline synthesis task handling and enhance chunk processing logic

* feat(DataSynthesis): refactor data synthesis models and update task handling logic

* fix(generation_service): ensure processed chunks are incremented regardless of question generation success

* feat(CreateTask): enhance task creation with new synthesis templates and improved configuration options

* feat(CreateTask): enhance task creation with new synthesis templates and improved configuration options

* feat(CreateTask): enhance task creation with new synthesis templates and improved configuration options

* feat(CreateTask): enhance task creation with new synthesis templates and improved configuration options
2025-12-18 16:51:18 +08:00

229 lines
10 KiB
Python

import json
import uuid
import asyncio
from sqlalchemy import select, func
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.exception import BusinessErrorCodeEnum, BusinessException
from app.core.logging import get_logger
from app.db.models import EvaluationItem, EvaluationTask, DatasetFiles
from app.db.models.data_evaluation import EvaluationFile
from app.db.models.data_synthesis import DataSynthesisFileInstance, SynthesisData
from app.db.session import AsyncSessionLocal
from app.module.evaluation.schema.evaluation import SourceType
from app.module.shared.schema import TaskStatus
from app.module.shared.util.model_chat import call_openai_style_model, extract_json_substring
from app.module.evaluation.schema.prompt import get_prompt
from app.module.shared.util.structured_file import StructuredFileHandlerFactory
from app.module.system.service.common_service import get_model_by_id
logger = get_logger(__name__)
class EvaluationExecutor:
def __init__(self, db: AsyncSession, task: EvaluationTask):
self.db = db
self.task = task
async def save_eval_items(self):
pass
def get_eval_prompt(self, item: EvaluationItem) -> str:
prompt_text = get_prompt(self.task.task_type, json.loads(self.task.eval_config).get("dimensions"))
eval_content = json.loads(item.eval_content)
if self.task.task_type == "QA":
prompt_text = ((prompt_text.replace("{content}", eval_content.get("input"))
.replace("{question}", eval_content.get("instruction")))
.replace("{answer}", eval_content.get("output")))
if self.task.task_type == "COT":
prompt_text = ((prompt_text.replace("{question}", eval_content.get("instruction"))
.replace("{conclusion}", eval_content.get("output")))
.replace("{chain_of_thought}", eval_content.get("chain_of_thought")))
return prompt_text
async def execute(self):
eval_config = json.loads(self.task.eval_config)
model_config = await get_model_by_id(self.db, eval_config.get("modelId"))
semaphore = asyncio.Semaphore(10)
files = (await self.db.execute(
select(EvaluationFile).where(EvaluationFile.task_id == self.task.id)
)).scalars().all()
query = select(EvaluationItem).where(EvaluationItem.task_id == self.task.id)
count_query = select(func.count()).select_from(query.subquery())
total = (await self.db.execute(count_query)).scalar_one()
evaluated_count = 0
for file in files:
items = (await self.db.execute(query.where(EvaluationItem.file_id == file.file_id))).scalars().all()
tasks = [
self.evaluate_item(model_config, item, semaphore)
for item in items
]
await asyncio.gather(*tasks, return_exceptions=True)
file.evaluated_count = len(items)
evaluated_count += file.evaluated_count
self.task.eval_process = evaluated_count / total
await self.db.commit()
async def evaluate_item(self, model_config, item: EvaluationItem, semaphore: asyncio.Semaphore):
async with semaphore:
max_try = 3
while max_try > 0:
prompt_text = self.get_eval_prompt(item)
resp_text = await asyncio.to_thread(
call_openai_style_model, model_config.base_url, model_config.api_key, model_config.model_name,
prompt_text,
)
resp_text = extract_json_substring(resp_text)
try:
json.loads(resp_text)
except Exception as e:
logger.error(
f"Failed to parse LLM answer as JSON for task={self.task.id}, file={item.file_id}: {e}. Raw answer: {resp_text!r}"
)
max_try -= 1
continue
item.eval_result = resp_text
item.status = TaskStatus.COMPLETED.value
await self.db.commit()
return
def get_source_type(self) -> SourceType:
pass
class DatasetEvaluationExecutor(EvaluationExecutor):
def __init__(self, db: AsyncSession, task: EvaluationTask):
super().__init__(db, task)
async def save_eval_items(self):
dataset_files = ((await self.db.execute(select(DatasetFiles)
.where(DatasetFiles.dataset_id == self.task.source_id)))
.scalars().all())
handler = StructuredFileHandlerFactory().get_handler(self.task.task_type)
for dataset_file in dataset_files:
if dataset_file.file_type.upper() != "JSON" and dataset_file.file_type.upper() != "JSONL":
continue
items = handler.get_items_from_file(dataset_file.file_path)
logger.info(f"parse {len(items)} items from file {dataset_file.file_name}")
for item in items:
self.db.add(EvaluationItem(
id=str(uuid.uuid4()),
task_id=self.task.id,
file_id=dataset_file.id,
item_id=item.get("id") if item.get("id") else str(uuid.uuid4()),
eval_content=json.dumps(item, ensure_ascii=False),
eval_result="{}",
status=TaskStatus.PENDING.value,
created_by=self.task.created_by,
updated_by=self.task.updated_by,
))
self.db.add(EvaluationFile(
id=str(uuid.uuid4()),
task_id=self.task.id,
file_id=dataset_file.id,
file_name=dataset_file.file_name,
total_count=len(items),
evaluated_count=0,
created_by=self.task.created_by,
updated_by=self.task.updated_by,
))
def get_source_type(self) -> SourceType:
return SourceType.DATASET
class SynthesisEvaluationExecutor(EvaluationExecutor):
def __init__(self, db: AsyncSession, task: EvaluationTask):
super().__init__(db, task)
async def save_eval_items(self):
synthesis_files = ((await self.db.execute(select(DataSynthesisFileInstance)
.where(DataSynthesisFileInstance.synthesis_instance_id == self.task.source_id)))
.scalars().all())
for synthesis_file in synthesis_files:
synthesis_datas = ((await self.db.execute(select(SynthesisData)
.where(SynthesisData.synthesis_file_instance_id == synthesis_file.id)))
.scalars().all())
logger.info(f"get {len(synthesis_datas)} items from file {synthesis_file.file_name}")
for synthesis_data in synthesis_datas:
self.db.add(EvaluationItem(
id=str(uuid.uuid4()),
task_id=self.task.id,
file_id=synthesis_file.id,
item_id=synthesis_data.id,
eval_content=json.dumps(synthesis_data.data),
eval_result="{}",
status=TaskStatus.PENDING.value,
created_by=self.task.created_by,
updated_by=self.task.updated_by,
))
self.db.add(EvaluationFile(
id=str(uuid.uuid4()),
task_id=self.task.id,
file_id=synthesis_file.id,
file_name=synthesis_file.file_name,
total_count=len(synthesis_datas),
evaluated_count=0,
created_by=self.task.created_by,
updated_by=self.task.updated_by,
))
pass
def get_source_type(self) -> SourceType:
return SourceType.SYNTHESIS
class EvaluationExecutorFactory:
def __init__(self, db: AsyncSession, task: EvaluationTask):
self.db = db
self.executors: list[EvaluationExecutor] = []
self.executors.append(DatasetEvaluationExecutor(db, task))
self.executors.append(SynthesisEvaluationExecutor(db, task))
def get_executor(self, source_type: str) -> EvaluationExecutor:
for executor in self.executors:
if executor.get_source_type().value == source_type:
return executor
raise BusinessException(BusinessErrorCodeEnum.TASK_TYPE_ERROR.value)
class EvaluationTaskService:
@staticmethod
async def run_evaluation_task(task_id: str):
"""
Background worker to run evaluations.
- task_id: id of EvaluationTaskModel
"""
logger.info(f"Background evaluation worker started add items for task {task_id}")
async with AsyncSessionLocal() as session:
try:
task = await session.execute(select(EvaluationTask).where(EvaluationTask.id == task_id))
task = task.scalar_one_or_none()
factory = EvaluationExecutorFactory(session, task)
executor = factory.get_executor(task.source_type)
await executor.save_eval_items()
task.status = TaskStatus.RUNNING.value
except Exception as e:
logger.error(f"Background worker encountered error for task {task_id}: {e}")
task.status = TaskStatus.FAILED.value
finally:
await session.commit()
logger.info(f"Background evaluation worker started for task {task_id}")
async with AsyncSessionLocal() as session:
try:
task = await session.execute(select(EvaluationTask).where(EvaluationTask.id == task_id))
task = task.scalar_one_or_none()
factory = EvaluationExecutorFactory(session, task)
executor = factory.get_executor(task.source_type)
await executor.execute()
logger.info(f"Background evaluation worker finished for task {task_id}")
task.status = TaskStatus.COMPLETED.value
except Exception as e:
logger.error(f"Background worker encountered error for task {task_id}: {e}")
task.status = TaskStatus.FAILED.value
finally:
await session.commit()