feature: add data-evaluation

* feature: add evaluation task management function
* feature: add evaluation task detail page
* fix: delete duplicate definition for table t_model_config
* refactor: rename package synthesis to ratio
* refactor: add eval file table and refactor related code
* fix: call large models in parallel during evaluation (see the sketch below)
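The parallel-model-call fix in this commit is implemented with bounded asyncio concurrency (a semaphore plus `asyncio.to_thread`), as seen in `EvaluationExecutor.execute` in the first hunk. A minimal standalone sketch of that pattern; `call_model_blocking`, `evaluate_one`, and `evaluate_all` are illustrative names, not part of the diff:

import asyncio

def call_model_blocking(prompt: str) -> str:
    # Hypothetical stand-in for a synchronous LLM client call.
    return f"score for: {prompt}"

async def evaluate_one(prompt: str, semaphore: asyncio.Semaphore) -> str:
    # The semaphore caps how many calls are in flight; to_thread keeps the
    # blocking client call off the event loop.
    async with semaphore:
        return await asyncio.to_thread(call_model_blocking, prompt)

async def evaluate_all(prompts: list[str]) -> list[str]:
    semaphore = asyncio.Semaphore(10)  # at most 10 model calls at once
    return await asyncio.gather(*(evaluate_one(p, semaphore) for p in prompts))

# asyncio.run(evaluate_all(["prompt 1", "prompt 2"]))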
@@ -0,0 +1,207 @@
import asyncio
import json
import uuid

from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession

from app.core.exception import BusinessErrorCodeEnum, BusinessException
from app.core.logging import get_logger
from app.db.models import DatasetFiles, EvaluationItem, EvaluationTask
from app.db.models.data_evaluation import EvaluationFile
from app.db.models.data_synthesis import DataSynthesisFileInstance, SynthesisData
from app.db.session import AsyncSessionLocal
from app.module.evaluation.schema.evaluation import SourceType
from app.module.evaluation.schema.prompt import get_prompt
from app.module.shared.schema import TaskStatus
from app.module.shared.util.model_chat import call_openai_style_model
from app.module.shared.util.structured_file import StructuredFileHandlerFactory
from app.module.system.service.common_service import get_model_by_id

logger = get_logger(__name__)


class EvaluationExecutor:
    def __init__(self, db: AsyncSession, task: EvaluationTask):
        self.db = db
        self.task = task

    async def save_eval_items(self):
        """Persist the items to be evaluated; implemented by the source-specific subclasses."""
        raise NotImplementedError

    def get_eval_prompt(self, item: EvaluationItem) -> str:
        """Build the evaluation prompt for a single item from the task's prompt template."""
        prompt_text = get_prompt(
            self.task.task_type, json.loads(self.task.eval_config).get("dimensions")
        )
        eval_content = json.loads(item.eval_content)
        if self.task.task_type == "QA":
            prompt_text = (prompt_text
                           .replace("{content}", eval_content.get("input"))
                           .replace("{question}", eval_content.get("instruction"))
                           .replace("{answer}", eval_content.get("output")))
        return prompt_text

    async def execute(self):
        """Run the evaluation for every file of the task, calling the model with bounded concurrency."""
        eval_config = json.loads(self.task.eval_config)
        model_config = await get_model_by_id(self.db, eval_config.get("model_id"))
        semaphore = asyncio.Semaphore(10)  # at most 10 model calls in flight
        files = (await self.db.execute(
            select(EvaluationFile).where(EvaluationFile.task_id == self.task.id)
        )).scalars().all()
        for file in files:
            items = (await self.db.execute(
                select(EvaluationItem)
                .where(EvaluationItem.task_id == self.task.id)
                .where(EvaluationItem.file_id == file.file_id)
            )).scalars().all()
            tasks = [
                self.evaluate_item(model_config, item, semaphore)
                for item in items
            ]
            results = await asyncio.gather(*tasks, return_exceptions=True)
            for result in results:
                if isinstance(result, Exception):
                    logger.error(f"evaluation item failed for task {self.task.id}: {result}")
            file.evaluated_count = len(items)
            # Commit once per file: the shared AsyncSession must not be committed
            # concurrently from the item coroutines.
            await self.db.commit()

    async def evaluate_item(self, model_config, item: EvaluationItem, semaphore: asyncio.Semaphore):
        """Evaluate a single item; the blocking model call runs in a worker thread."""
        async with semaphore:
            prompt_text = self.get_eval_prompt(item)
            resp_text = await asyncio.to_thread(
                call_openai_style_model,
                model_config.base_url, model_config.api_key, model_config.model_name,
                prompt_text,
            )
            item.eval_result = resp_text
            item.status = TaskStatus.COMPLETED.value

    def get_source_type(self) -> SourceType:
        """Return the source type handled by this executor; implemented by subclasses."""
        raise NotImplementedError


class DatasetEvaluationExecutor(EvaluationExecutor):
    def __init__(self, db: AsyncSession, task: EvaluationTask):
        super().__init__(db, task)

    async def save_eval_items(self):
        """Create evaluation items and file records from the structured files of a dataset."""
        dataset_files = (await self.db.execute(
            select(DatasetFiles).where(DatasetFiles.dataset_id == self.task.source_id)
        )).scalars().all()
        handler = StructuredFileHandlerFactory().get_handler(self.task.task_type)
        for dataset_file in dataset_files:
            # Only structured JSON/JSONL files can be parsed into evaluation items.
            if dataset_file.file_type.upper() not in ("JSON", "JSONL"):
                continue
            items = handler.get_items_from_file(dataset_file.file_path)
            logger.info(f"parsed {len(items)} items from file {dataset_file.file_name}")
            for item in items:
                self.db.add(EvaluationItem(
                    id=str(uuid.uuid4()),
                    task_id=self.task.id,
                    file_id=dataset_file.id,
                    item_id=item.get("id") or str(uuid.uuid4()),
                    eval_content=json.dumps(item, ensure_ascii=False),
                    status=TaskStatus.PENDING.value,
                    created_by=self.task.created_by,
                    updated_by=self.task.updated_by,
                ))
            self.db.add(EvaluationFile(
                id=str(uuid.uuid4()),
                task_id=self.task.id,
                file_id=dataset_file.id,
                file_name=dataset_file.file_name,
                total_count=len(items),
                evaluated_count=0,
                created_by=self.task.created_by,
                updated_by=self.task.updated_by,
            ))

    def get_source_type(self) -> SourceType:
        return SourceType.DATASET


class SynthesisEvaluationExecutor(EvaluationExecutor):
    def __init__(self, db: AsyncSession, task: EvaluationTask):
        super().__init__(db, task)

    async def save_eval_items(self):
        """Create evaluation items and file records from the output of a synthesis task."""
        synthesis_files = (await self.db.execute(
            select(DataSynthesisFileInstance)
            .where(DataSynthesisFileInstance.task_id == self.task.source_id)
        )).scalars().all()
        for synthesis_file in synthesis_files:
            synthesis_datas = (await self.db.execute(
                select(SynthesisData)
                .where(SynthesisData.synthesis_file_instance_id == synthesis_file.id)
            )).scalars().all()
            logger.info(f"got {len(synthesis_datas)} items from file {synthesis_file.file_name}")
            for synthesis_data in synthesis_datas:
                self.db.add(EvaluationItem(
                    id=str(uuid.uuid4()),
                    task_id=self.task.id,
                    file_id=synthesis_file.id,
                    item_id=synthesis_data.id,
                    eval_content=synthesis_data.data,
                    status=TaskStatus.PENDING.value,
                    created_by=self.task.created_by,
                    updated_by=self.task.updated_by,
                ))
            self.db.add(EvaluationFile(
                id=str(uuid.uuid4()),
                task_id=self.task.id,
                file_id=synthesis_file.id,
                file_name=synthesis_file.file_name,
                total_count=len(synthesis_datas),
                evaluated_count=0,
                created_by=self.task.created_by,
                updated_by=self.task.updated_by,
            ))

    def get_source_type(self) -> SourceType:
        return SourceType.SYNTHESIS


class EvaluationExecutorFactory:
    def __init__(self, db: AsyncSession, task: EvaluationTask):
        self.db = db
        self.executors: list[EvaluationExecutor] = [
            DatasetEvaluationExecutor(db, task),
            SynthesisEvaluationExecutor(db, task),
        ]

    def get_executor(self, source_type: str) -> EvaluationExecutor:
        for executor in self.executors:
            if executor.get_source_type().value == source_type:
                return executor
        raise BusinessException(BusinessErrorCodeEnum.TASK_TYPE_ERROR.value)


class EvaluationTaskService:

    @staticmethod
    async def run_evaluation_task(task_id: str):
        """
        Background worker that runs an evaluation task.

        - task_id: id of the EvaluationTask to run
        """
        # Phase 1: load the task and persist the items to evaluate.
        logger.info(f"Background evaluation worker started adding items for task {task_id}")
        async with AsyncSessionLocal() as session:
            task = (await session.execute(
                select(EvaluationTask).where(EvaluationTask.id == task_id)
            )).scalar_one_or_none()
            if task is None:
                logger.error(f"Evaluation task {task_id} not found")
                return
            try:
                factory = EvaluationExecutorFactory(session, task)
                executor = factory.get_executor(task.source_type)
                await executor.save_eval_items()
                task.status = TaskStatus.RUNNING.value
            except Exception as e:
                logger.error(f"Background worker encountered error for task {task_id}: {e}")
                task.status = TaskStatus.FAILED.value
            finally:
                await session.commit()

        # Phase 2: call the model for every item and record the results.
        logger.info(f"Background evaluation worker started for task {task_id}")
        async with AsyncSessionLocal() as session:
            task = (await session.execute(
                select(EvaluationTask).where(EvaluationTask.id == task_id)
            )).scalar_one_or_none()
            if task is None:
                logger.error(f"Evaluation task {task_id} not found")
                return
            try:
                factory = EvaluationExecutorFactory(session, task)
                executor = factory.get_executor(task.source_type)
                await executor.execute()
                logger.info(f"Background evaluation worker finished for task {task_id}")
                task.status = TaskStatus.COMPLETED.value
            except Exception as e:
                logger.error(f"Background worker encountered error for task {task_id}: {e}")
                task.status = TaskStatus.FAILED.value
            finally:
                await session.commit()
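
How `run_evaluation_task` gets scheduled is not part of this hunk; presumably the task-creation endpoint launches it without awaiting it. A hedged sketch of one way to wire that up (FastAPI `BackgroundTasks` and the route path are assumptions, not something this diff shows):

# Assumed wiring, for illustration only: launch the worker from a FastAPI
# endpoint via BackgroundTasks so the HTTP request returns immediately.
from fastapi import APIRouter, BackgroundTasks

router = APIRouter()

@router.post("/evaluation/tasks/{task_id}/run")
async def run_task(task_id: str, background_tasks: BackgroundTasks):
    # EvaluationTaskService.run_evaluation_task is the coroutine defined above.
    background_tasks.add_task(EvaluationTaskService.run_evaluation_task, task_id)
    return {"task_id": task_id, "status": "started"}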
@@ -0,0 +1,45 @@
"""
Service for managing evaluation prompt templates.
"""
from typing import List

from app.module.evaluation.schema.prompt import EVALUATION_PROMPT_TEMPLATE
from app.module.evaluation.schema.prompt_template import (
    PromptTemplateItem,
    PromptTemplateDimension,
    PromptTemplateResponse,
)


class PromptTemplateService:
    """Service for managing evaluation prompt templates"""

    @staticmethod
    def get_prompt_templates() -> PromptTemplateResponse:
        """
        Get all available prompt templates.

        Returns:
            PromptTemplateResponse containing all prompt templates
        """
        templates: List[PromptTemplateItem] = []

        for template in EVALUATION_PROMPT_TEMPLATE:
            # Convert the raw dimension dicts to the proper schema
            dimensions = [
                PromptTemplateDimension(
                    dimension=dim.get("dimension"),
                    description=dim.get("description", "")
                )
                for dim in template.get("defaultDimensions", [])
            ]

            # Create the template item
            template_item = PromptTemplateItem(
                evalType=template.get("evalType", ""),
                defaultDimensions=dimensions,
                prompt=template.get("prompt", "")
            )
            templates.append(template_item)

        return PromptTemplateResponse(templates=templates)
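
The contents of EVALUATION_PROMPT_TEMPLATE live in app.module.evaluation.schema.prompt and are not part of this commit; judging only by the keys the service reads (evalType, defaultDimensions, prompt), an entry presumably looks roughly like the following sketch (all values here are invented for illustration):

# Hypothetical entry, inferred from the keys read above; not copied from the repository.
EVALUATION_PROMPT_TEMPLATE = [
    {
        "evalType": "QA",
        "defaultDimensions": [
            {"dimension": "accuracy", "description": "Is the answer factually correct?"},
            {"dimension": "relevance", "description": "Does the answer address the question?"},
        ],
        "prompt": "Score the answer.\nContext: {content}\nQuestion: {question}\nAnswer: {answer}",
    },
]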