feature: add data-evaluation

* feature: add evaluation task management function

* feature: add evaluation task detail page

* fix: delete duplicate definition for table t_model_config

* refactor: rename package synthesis to ratio

* refactor: add eval file table and  refactor related code

* fix: calling large models in parallel during evaluation (see the concurrency sketch just after this list)
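The parallel-call fix boils down to bounding concurrency with an asyncio.Semaphore and fanning requests out with asyncio.gather, which is how the new evaluation executor further down works. A minimal sketch of that pattern, with call_model as a hypothetical stand-in for the real model call:

import asyncio

async def call_model(prompt: str) -> str:
    # Hypothetical stand-in for the real model call (which is run via asyncio.to_thread).
    await asyncio.sleep(0.1)
    return f"result for: {prompt}"

async def evaluate_all(prompts: list[str], max_parallel: int = 10) -> list[str]:
    semaphore = asyncio.Semaphore(max_parallel)  # cap the number of in-flight model calls

    async def bounded(prompt: str) -> str:
        async with semaphore:
            return await call_model(prompt)

    # Fan out every prompt, but never run more than max_parallel calls at once.
    return await asyncio.gather(*(bounded(p) for p in prompts))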
Author: hefanli
Date: 2025-12-04 09:23:54 +08:00
Committed by: GitHub
Parent: 265e284fb8
Commit: 1d19cd3a62
52 changed files with 2882 additions and 1244 deletions

View File

@@ -2,8 +2,9 @@ from fastapi import APIRouter
from .system.interface import router as system_router
from .annotation.interface import router as annotation_router
-from .synthesis.interface import router as ratio_router
+from .ratio.interface import router as ratio_router
from .generation.interface import router as generation_router
from .evaluation.interface import router as evaluation_router
router = APIRouter(
prefix="/api"
@@ -13,5 +14,6 @@ router.include_router(system_router)
router.include_router(annotation_router)
router.include_router(ratio_router)
router.include_router(generation_router)
router.include_router(evaluation_router)
__all__ = ["router"]

View File

@@ -46,8 +46,7 @@ class DatasetFileTag(BaseModel):
tags.append(tag_values)
        # If from_name is not empty, prepend it to each tag
if self.from_name:
tags = [f"{self.from_name} {tag}" for tag in tags]
tags = [f"{self.from_name}@{tag}" for tag in tags]
return tags

View File

@@ -0,0 +1,11 @@
from fastapi import APIRouter
router = APIRouter(
prefix="/evaluation",
tags = ["evaluation"]
)
# Include sub-routers
from .evaluation import router as evaluation_router
router.include_router(evaluation_router)

View File

@@ -0,0 +1,429 @@
import asyncio
import uuid
import json
from typing import Optional
from fastapi import APIRouter, Depends, HTTPException, Query, BackgroundTasks
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, func, or_, text, and_
from pydantic import ValidationError
from app.core.logging import get_logger
from app.db.models.data_evaluation import EvaluationFile
from app.db.session import get_db
from app.db.models import EvaluationTask, EvaluationItem, DatasetFiles
from app.module.evaluation.schema.evaluation import (
CreateEvaluationTaskRequest,
PagedEvaluationTaskResponse,
EvaluationTaskDetailResponse,
PagedEvaluationItemsResponse,
EvaluationItemResponse, PagedEvaluationFilesResponse, EvaluationFileResponse
)
from app.module.evaluation.schema.prompt import get_prompt
from app.module.evaluation.schema.prompt_template import PromptTemplateResponse
from app.module.evaluation.service.prompt_template_service import PromptTemplateService
from app.module.evaluation.service.evaluation import EvaluationTaskService
from app.module.shared.schema.common import StandardResponse, TaskStatus
router = APIRouter(
prefix="",
tags=["evaluation"],
)
logger = get_logger(__name__)
@router.get("/prompt-templates", response_model=StandardResponse[PromptTemplateResponse])
async def get_prompt_templates():
"""
Get all available evaluation prompt templates
Returns:
StandardResponse with list of prompt templates
"""
try:
templates = PromptTemplateService.get_prompt_templates()
return StandardResponse(
code=200,
message="Success",
data=templates
)
except Exception as e:
logger.error(f"Failed to get prompt templates: {str(e)}")
raise HTTPException(
status_code=500,
detail="Failed to retrieve prompt templates"
)
@router.post("/tasks", response_model=StandardResponse[EvaluationTaskDetailResponse], status_code=201)
async def create_evaluation_task(
request: CreateEvaluationTaskRequest,
db: AsyncSession = Depends(get_db)
):
"""
创建评估任务
Args:
request: 创建评估任务请求
db: 数据库会话
Returns:
StandardResponse[EvaluationTaskDetailResponse]: 创建的任务详情
"""
try:
        # Check whether a task with this name already exists
existing_task = await db.execute(
select(EvaluationTask).where(EvaluationTask.name == request.name)
)
if existing_task.scalar_one_or_none():
raise HTTPException(status_code=400, detail=f"Evaluation task with name '{request.name}' already exists")
        # Create the evaluation task
task = EvaluationTask(
id=str(uuid.uuid4()),
name=request.name,
description=request.description,
task_type=request.task_type,
source_type=request.source_type,
source_id=request.source_id,
source_name=request.source_name,
eval_prompt=request.eval_prompt,
eval_config=json.dumps({
"model_id": request.eval_config.model_id,
"dimensions": request.eval_config.dimensions,
}),
status=TaskStatus.PENDING.value,
eval_process=0.0,
)
db.add(task)
# Commit first to persist the task before scheduling background work
await db.commit()
# Schedule background execution without blocking the current request
asyncio.create_task(EvaluationTaskService.run_evaluation_task(task.id))
# Refresh the task to return latest state
await db.refresh(task)
        # Convert to the response model
response = _map_to_task_detail_response(task)
return StandardResponse(
code=200,
message="Evaluation task created successfully",
data=response
)
    except HTTPException:
        # Re-raise expected HTTP errors (e.g. duplicate task name) instead of masking them as 500s
        await db.rollback()
        raise
    except ValidationError as e:
        await db.rollback()
        logger.error(f"Validation error: {str(e)}")
        raise HTTPException(status_code=400, detail=str(e))
except Exception as e:
await db.rollback()
logger.error(f"Failed to create evaluation task: {str(e)}")
raise HTTPException(status_code=500, detail="Internal server error")
@router.get("/tasks", response_model=StandardResponse[PagedEvaluationTaskResponse])
async def list_evaluation_tasks(
    page: int = Query(1, ge=1, description="Page number, starting at 1"),
    size: int = Query(10, ge=1, le=100, description="Page size"),
    name: Optional[str] = Query(None, description="Fuzzy filter by task name"),
    status: Optional[str] = Query(None, description="Filter by task status"),
    task_type: Optional[str] = Query(None, description="Filter by task type"),
    db: AsyncSession = Depends(get_db),
):
    """
    List evaluation tasks with pagination
    Args:
        page: page number, starting at 1
        size: page size
        name: fuzzy filter by task name
        status: filter by task status
        task_type: filter by task type
        db: database session
    Returns:
        StandardResponse[PagedEvaluationTaskResponse]: paged list of evaluation tasks
    """
try:
        # Build query filters
query = select(EvaluationTask)
if name:
query = query.where(EvaluationTask.name.ilike(f"%{name}%"))
if status:
query = query.where(EvaluationTask.status == status)
if task_type:
query = query.where(EvaluationTask.task_type == task_type)
        # Get total count
count_query = select(func.count()).select_from(query.subquery())
total = (await db.execute(count_query)).scalar_one()
        # Paginated query
offset = (page - 1) * size
tasks = (await db.execute(
query.order_by(EvaluationTask.created_at.desc())
.offset(offset)
.limit(size)
)).scalars().all()
        # Convert to response models
items = [_map_to_task_detail_response(task) for task in tasks]
total_pages = (total + size - 1) // size if size > 0 else 0
return StandardResponse(
code=200,
message="Success",
data=PagedEvaluationTaskResponse(
content=items,
totalElements=total,
totalPages=total_pages,
page=page,
size=size,
)
)
except Exception as e:
logger.error(f"Failed to list evaluation tasks: {str(e)}")
raise HTTPException(status_code=500, detail="Internal server error")
@router.get("/tasks/{task_id}/files", response_model=StandardResponse[PagedEvaluationFilesResponse])
async def list_evaluation_files(
    task_id: str,
    page: int = Query(1, ge=1, description="Page number, starting at 1"),
    size: int = Query(10, ge=1, le=100, description="Page size"),
    db: AsyncSession = Depends(get_db),
):
    """
    List evaluation files with pagination
    Args:
        task_id: evaluation task ID
        page: page number, starting at 1
        size: page size
        db: database session
    Returns:
        StandardResponse[PagedEvaluationFilesResponse]: paged list of evaluation files
    """
try:
task = await db.get(EvaluationTask, task_id)
if not task:
raise HTTPException(status_code=404, detail="Evaluation task not found")
offset = (page - 1) * size
query = select(EvaluationFile).where(EvaluationFile.task_id == task_id)
count_query = select(func.count()).select_from(query.subquery())
total = (await db.execute(count_query)).scalar_one()
files = (await db.execute(query.offset(offset).limit(size))).scalars().all()
total_pages = (total + size - 1) // size if size > 0 else 0
file_responses = [
EvaluationFileResponse(
taskId=file.task_id,
fileId=file.file_id,
fileName=file.file_name,
totalCount=file.total_count,
evaluatedCount=file.evaluated_count,
pendingCount=file.total_count - file.evaluated_count
)
for file in files
]
return StandardResponse(
code=200,
message="Success",
data=PagedEvaluationFilesResponse(
content=file_responses,
totalElements=total,
totalPages=total_pages,
page=page,
size=size,
)
)
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Failed to list evaluation files: {str(e)}")
        raise HTTPException(status_code=500, detail="Internal server error")
@router.get("/tasks/{task_id}/items", response_model=StandardResponse[PagedEvaluationItemsResponse])
async def list_evaluation_items(
task_id: str,
    page: int = Query(1, ge=1, description="Page number, starting at 1"),
    size: int = Query(10, ge=1, le=100, description="Page size"),
    status: Optional[str] = Query(None, description="Filter by status"),
    file_id: Optional[str] = Query(None, description="Filter by file"),
    db: AsyncSession = Depends(get_db),
):
    """
    List evaluation items with pagination
    Args:
        task_id: evaluation task ID
        page: page number, starting at 1
        size: page size
        status: filter by status
        file_id: filter by file
        db: database session
    Returns:
        StandardResponse[PagedEvaluationItemsResponse]: paged list of evaluation items
    """
try:
        # Check that the task exists
task = await db.get(EvaluationTask, task_id)
if not task:
raise HTTPException(status_code=404, detail="Evaluation task not found")
        # Build query filters
query = select(EvaluationItem).where(EvaluationItem.task_id == task_id)
if status:
query = query.where(EvaluationItem.status == status)
if file_id:
query = query.where(EvaluationItem.file_id == file_id)
        # Get total count
count_query = select(func.count()).select_from(query.subquery())
total = (await db.execute(count_query)).scalar_one()
        # Paginated query
offset = (page - 1) * size
items = (await db.execute(query.offset(offset).limit(size))).scalars().all()
        # Convert to response models
item_responses = [
EvaluationItemResponse(
id=item.id,
taskId=item.task_id,
itemId=item.item_id,
fileId=item.file_id,
evalContent=json.loads(item.eval_content),
evalScore=float(item.eval_score) if item.eval_score else None,
                evalResult=json.loads(item.eval_result) if item.eval_result else None,
status=item.status
)
for item in items
]
total_pages = (total + size - 1) // size if size > 0 else 0
return StandardResponse(
code=200,
message="Success",
data=PagedEvaluationItemsResponse(
content=item_responses,
totalElements=total,
totalPages=total_pages,
page=page,
size=size,
)
)
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Failed to list evaluation items: {str(e)}")
        raise HTTPException(status_code=500, detail="Internal server error")
@router.get("/tasks/{task_id}", response_model=StandardResponse[EvaluationTaskDetailResponse])
async def get_evaluation_task(
task_id: str,
db: AsyncSession = Depends(get_db),
):
"""
获取评估任务详情
Args:
task_id: 任务ID
db: 数据库会话
Returns:
StandardResponse[EvaluationTaskDetailResponse]: 评估任务详情
"""
try:
task = await db.get(EvaluationTask, task_id)
if not task:
raise HTTPException(status_code=404, detail="Evaluation task not found")
        # Convert to the response model
response = _map_to_task_detail_response(task)
return StandardResponse(
code=200,
message="Success",
data=response
)
except HTTPException:
raise
except Exception as e:
logger.error(f"Failed to get evaluation task: {str(e)}")
raise HTTPException(status_code=500, detail="Internal server error")
@router.delete("/tasks", response_model=StandardResponse[str], status_code=200)
async def delete_eval_tasks(
    ids: list[str] = Query(..., description="List of evaluation task IDs to delete"),
db: AsyncSession = Depends(get_db),
):
"""
删除评估任务
Args:
ids: 任务ID
db: 数据库会话
Returns:
StandardResponse[str]: 删除结果
"""
    try:
        for task_id in ids:
            # Check that the task exists
            task = await db.get(EvaluationTask, task_id)
            if not task:
                raise HTTPException(status_code=404, detail="Evaluation task not found")
            # Delete the evaluation items of this task
            await db.execute(
                EvaluationItem.__table__.delete()
                .where(EvaluationItem.task_id == task_id)
            )
            # Delete the task itself
            await db.delete(task)
        await db.commit()
return StandardResponse(
code=200,
message="Evaluation task deleted successfully",
data="success"
)
except HTTPException:
await db.rollback()
raise
except Exception as e:
await db.rollback()
logger.error(f"Failed to delete evaluation task: {str(e)}")
raise HTTPException(status_code=500, detail="Internal server error")
def _map_to_task_detail_response(
task: EvaluationTask
) -> EvaluationTaskDetailResponse:
"""将数据库模型转换为任务详情响应模型"""
task_response = EvaluationTaskDetailResponse(
id=task.id,
name=task.name,
description=task.description,
taskType=task.task_type,
sourceType=task.source_type,
sourceId=task.source_id,
sourceName=task.source_name,
status=task.status,
evalProcess=task.eval_process,
evalPrompt=task.eval_prompt,
evalConfig=json.loads(task.eval_config),
createdAt=task.created_at.isoformat() if task.created_at else None,
updatedAt=task.updated_at.isoformat() if task.updated_at else None,
)
task_response.eval_prompt = get_prompt(task_response.task_type, task_response.eval_config.get("dimensions"))
return task_response
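Putting the router prefixes together (/api from the top-level router, /evaluation from the package router, and the empty prefix above), these endpoints are served under /api/evaluation. A hedged example against a locally running instance; the host and port are assumptions:

import httpx

BASE = "http://localhost:8000/api/evaluation"  # assumed local dev host/port

resp = httpx.get(f"{BASE}/tasks", params={"page": 1, "size": 10})
body = resp.json()
# StandardResponse wraps the paged payload; fields are serialized by their camelCase aliases.
print(body["code"], body["data"]["totalElements"])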

View File

@@ -0,0 +1,101 @@
from typing import List, Optional, Dict, Any
from pydantic import BaseModel, Field, field_validator
from enum import Enum
from app.core.logging import get_logger
from app.module.shared.schema.common import TaskStatus
logger = get_logger(__name__)
class EvaluationConfig(BaseModel):
"""评估配置项"""
model_id: str = Field(..., alias="modelId", description="模型id")
dimensions: list[dict] = Field(..., alias="dimensions", description="评估维度")
class CreateEvaluationTaskRequest(BaseModel):
"""创建评估任务请求"""
name: str = Field(..., description="评估任务名称")
description: Optional[str] = Field(None, description="评估任务描述")
task_type: str = Field(..., alias="taskType", description="评估任务类型:QA/QUALITY/COMPATIBILITY/VALUE")
source_type: str = Field(..., alias="sourceType", description="待评估对象类型:DATASET/SYNTHESIS")
source_id: str = Field(..., alias="sourceId", description="待评估对象ID")
source_name: str = Field(..., alias="sourceName", description="待评估对象名称")
eval_method: str = Field("AUTO", alias="evalMethod", description="评估提示词")
eval_prompt: Optional[str] = Field(None, alias="evalPrompt", description="评估提示词")
eval_config: EvaluationConfig = Field(..., alias="evalConfig", description="评估配置项列表")
class EvaluationTaskItem(BaseModel):
"""评估任务列表项"""
id: str = Field(..., description="任务ID")
name: str = Field(..., description="任务名称")
description: Optional[str] = Field(None, description="任务描述")
task_type: Optional[str] = Field(..., alias="taskType", description="任务类型")
source_type: Optional[str] = Field(..., alias="sourceType", description="数据源类型")
source_id: Optional[str] = Field(..., alias="sourceId", description="数据源ID")
source_name: Optional[str] = Field(None, alias="sourceName", description="数据源名称")
status: TaskStatus = Field(..., description="任务状态")
eval_process: Optional[float] = Field(0, alias="evalProcess", description="评估进度")
created_at: Optional[str] = Field(None, alias="createdAt", description="创建时间")
updated_at: Optional[str] = Field(None, alias="updatedAt", description="更新时间")
class PagedEvaluationTaskResponse(BaseModel):
"""分页评估任务响应"""
content: List[EvaluationTaskItem]
total_elements: int = Field(..., alias="totalElements")
total_pages: int = Field(..., alias="totalPages")
page: int
size: int
class EvaluationTaskDetailResponse(EvaluationTaskItem):
"""评估任务详情响应"""
eval_prompt: Optional[str] = Field(None, alias="evalPrompt", description="评估提示词")
eval_config: Optional[Dict[str, Any]] = Field(None, alias="evalConfig", description="评估配置")
eval_result: Optional[Dict[str, Any]] = Field(None, alias="evalResult", description="评估结果")
class EvaluationItemResponse(BaseModel):
"""评估条目响应"""
id: str = Field(..., description="条目ID")
task_id: str = Field(..., alias="taskId", description="任务ID")
file_id: str = Field(..., alias="fileId", description="文件ID")
item_id: str = Field(..., alias="itemId", description="评估项ID")
eval_content: Optional[Dict[str, Any]] = Field(None, alias="evalContent", description="评估内容")
eval_score: Optional[float] = Field(None, alias="evalScore", description="评估分数")
eval_result: Optional[Dict[str, Any]] = Field(None, alias="evalResult", description="评估结果详情")
status: str = Field(..., description="评估状态")
class EvaluationFileResponse(BaseModel):
"""评估文件响应"""
task_id: str = Field(..., alias="taskId", description="任务ID")
file_id: str = Field(..., alias="fileId", description="文件ID")
file_name: str = Field(..., alias="fileName", description="文件名")
total_count: int = Field(..., alias="totalCount", description="总数")
evaluated_count: int = Field(..., alias="evaluatedCount", description="已评估数")
pending_count: int = Field(..., alias="pendingCount", description="待评估数")
class PagedEvaluationItemsResponse(BaseModel):
"""分页评估任务响应"""
content: List[EvaluationItemResponse]
total_elements: int = Field(..., alias="totalElements")
total_pages: int = Field(..., alias="totalPages")
page: int
size: int
class PagedEvaluationFilesResponse(BaseModel):
"""分页评估任务响应"""
content: List[EvaluationFileResponse]
total_elements: int = Field(..., alias="totalElements")
total_pages: int = Field(..., alias="totalPages")
page: int
size: int
class SourceType(Enum):
DATASET = "DATASET"
SYNTHESIS = "SYNTHESIS"

View File

@@ -0,0 +1,87 @@
EVALUATION_PROMPT_TEMPLATE = [
{
"evalType": "QA",
"defaultDimensions": [
{
"dimension": "问题是否独立",
"description": "仅分析问题,问题的主体和客体都比较明确,即使有省略,也符合语言习惯。在不需要补充其他信息的情况下不会引起疑惑。"
},
{
"dimension": "语法是否错误",
"description": "问题为疑问句,答案为陈述句; 不存在词语搭配不当的情况;连接词和标点符号不存在错用情况;逻辑混乱的情况不存在;语法结构都正确且完整。"
},
{
"dimension": "回答是否有针对性",
"description": "回答应对问题中的所有疑问点提供正面、直接的回答,不应引起疑惑。同时,答案不应有任何内容的遗漏,需构成一个完整的陈述。"
}
],
"prompt": """
# Role: 问答对质量评估专家
## Profile:
- Description: 你是一名专业的对话文本质量评估专家,擅长从多个维度对问答对进行质量评估,为机器学习模型训练提供高质量的数据筛选建议。具备深度学习、自然语言处理和数据科学的专业背景。
## Skills:
1. 能够从多个维度对问答对进行综合评估
2. 擅长识别问答对中的潜在问题,如答案不准确、问题模糊、文本不匹配、逻辑错误等
3. 能够给出具体的改进建议和质量评分,并提供可操作的优化方案
4. 熟悉机器学习训练数据的质量标准和最佳实践
5. 能够区分不同类型的问题(事实性、推理性、创造性)并采用相应的评估标准
## 评估维度:
{dimensions}
## 原始文本块内容:
{content}
## 问题:
{question}
## 答案:
{answer}
## 评估说明:
1. **数据集类型识别**:如果原始文本块内容为空或显示"Distilled Content",说明这是一个蒸馏数据集,没有原始文本参考。请重点评估问题的质量、答案的合理性和逻辑性,以及问答的一致性。
2. **评估原则**:采用严格的评估标准,确保筛选出的数据集能够有效提升模型性能。
## 注意事项:
- 评估结论要具体指出优点和不足,提供可操作的改进建议
- 评估结论控制在150字以内,简洁明了但要涵盖关键信息
## 输出要求:
请按照以下JSON格式输出评估结果,评估结果为Y/N,符合标注输出Y,不符合标准输出N:
{
"result": {{result_example}
},
"evaluation": "这是一个高质量的问答数据集。问题表述清晰具体,答案准确完整且逻辑性强,与原始文本高度相关。建议:可以进一步丰富答案的细节描述。"
}
"""
}
]
def get_dimensions_for_qa(dimensions: list[dict]) -> str:
dimensions_str = "\n"
index = 1
for dimension in dimensions:
dimensions_str += f"### {index}. {dimension.get("dimension")}\n**评估标准:**\n{dimension.get("description")}\n\n"
index += 1
return dimensions_str
def get_result_example_for_qa(dimensions: list[dict]) -> str:
result_example = ""
for dimension in dimensions:
result_example += f'\n "{dimension.get("dimension")}": "Y",'
return result_example
def get_prompt(task_type: str, dimensions: list[dict]) -> str:
template = None
for t in EVALUATION_PROMPT_TEMPLATE:
if t.get("evalType") == task_type:
template = t.get("prompt")
break
if not template:
template = EVALUATION_PROMPT_TEMPLATE[0].get("prompt", "")
if not dimensions or len(dimensions) == 0:
return template
return (template.replace("{dimensions}", get_dimensions_for_qa(dimensions))
.replace("{result_example}", get_result_example_for_qa(dimensions)))

View File

@@ -0,0 +1,29 @@
"""
Schema for evaluation prompt templates.
"""
from typing import List, Dict, Any
from pydantic import BaseModel, Field
class PromptTemplateDimension(BaseModel):
"""A single dimension in the prompt template"""
dimension: str = Field(..., description="Dimension name")
description: str = Field(..., description="Description of the dimension")
class PromptTemplateItem(BaseModel):
"""A single prompt template item"""
evalType: str = Field(..., description="Evaluation type")
defaultDimensions: List[PromptTemplateDimension] = Field(
default_factory=list,
description="List of default dimensions for this evaluation type"
)
prompt: str = Field(..., description="The prompt template string")
class PromptTemplateResponse(BaseModel):
"""Response model for getting prompt templates"""
templates: List[PromptTemplateItem] = Field(
...,
description="List of available prompt templates"
)

View File

@@ -0,0 +1,207 @@
import json
import uuid
import asyncio
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.exception import BusinessErrorCodeEnum, BusinessException
from app.core.logging import get_logger
from app.db.models import EvaluationItem, EvaluationTask, DatasetFiles
from app.db.models.data_evaluation import EvaluationFile
from app.db.models.data_synthesis import DataSynthesisFileInstance, SynthesisData
from app.db.session import AsyncSessionLocal
from app.module.evaluation.schema.evaluation import SourceType
from app.module.shared.schema import TaskStatus
from app.module.shared.util.model_chat import call_openai_style_model
from app.module.evaluation.schema.prompt import get_prompt
from app.module.shared.util.structured_file import StructuredFileHandlerFactory
from app.module.system.service.common_service import get_model_by_id
logger = get_logger(__name__)
class EvaluationExecutor:
def __init__(self, db: AsyncSession, task: EvaluationTask):
self.db = db
self.task = task
async def save_eval_items(self):
pass
def get_eval_prompt(self, item: EvaluationItem) -> str:
prompt_text = get_prompt(self.task.task_type, json.loads(self.task.eval_config).get("dimensions"))
eval_content = json.loads(item.eval_content)
if self.task.task_type == "QA":
prompt_text = ((prompt_text.replace("{content}", eval_content.get("input"))
.replace("{question}", eval_content.get("instruction")))
.replace("{answer}", eval_content.get("output")))
return prompt_text
async def execute(self):
eval_config = json.loads(self.task.eval_config)
model_config = await get_model_by_id(self.db, eval_config.get("model_id"))
semaphore = asyncio.Semaphore(10)
files = (await self.db.execute(
select(EvaluationFile).where(EvaluationFile.task_id == self.task.id)
)).scalars().all()
for file in files:
items = (await self.db.execute(
select(EvaluationItem).where(EvaluationItem.task_id == self.task.id)
.where(EvaluationItem.file_id == file.file_id)
)).scalars().all()
tasks = [
self.evaluate_item(model_config, item, semaphore)
for item in items
]
await asyncio.gather(*tasks, return_exceptions=True)
file.evaluated_count = len(items)
await self.db.commit()
async def evaluate_item(self, model_config, item: EvaluationItem, semaphore: asyncio.Semaphore):
async with semaphore:
prompt_text = self.get_eval_prompt(item)
resp_text = await asyncio.to_thread(
call_openai_style_model, model_config.base_url, model_config.api_key, model_config.model_name,
prompt_text,
)
item.eval_result = resp_text
item.status = TaskStatus.COMPLETED.value
await self.db.commit()
def get_source_type(self) -> SourceType:
pass
class DatasetEvaluationExecutor(EvaluationExecutor):
def __init__(self, db: AsyncSession, task: EvaluationTask):
super().__init__(db, task)
async def save_eval_items(self):
dataset_files = ((await self.db.execute(select(DatasetFiles)
.where(DatasetFiles.dataset_id == self.task.source_id)))
.scalars().all())
handler = StructuredFileHandlerFactory().get_handler(self.task.task_type)
for dataset_file in dataset_files:
if dataset_file.file_type.upper() != "JSON" and dataset_file.file_type.upper() != "JSONL":
continue
items = handler.get_items_from_file(dataset_file.file_path)
logger.info(f"parse {len(items)} items from file {dataset_file.file_name}")
for item in items:
self.db.add(EvaluationItem(
id=str(uuid.uuid4()),
task_id=self.task.id,
file_id=dataset_file.id,
item_id=item.get("id") if item.get("id") else str(uuid.uuid4()),
eval_content=json.dumps(item, ensure_ascii=False),
status=TaskStatus.PENDING.value,
created_by=self.task.created_by,
updated_by=self.task.updated_by,
))
self.db.add(EvaluationFile(
id=str(uuid.uuid4()),
task_id=self.task.id,
file_id=dataset_file.id,
file_name=dataset_file.file_name,
total_count=len(items),
evaluated_count=0,
created_by=self.task.created_by,
updated_by=self.task.updated_by,
))
def get_source_type(self) -> SourceType:
return SourceType.DATASET
class SynthesisEvaluationExecutor(EvaluationExecutor):
def __init__(self, db: AsyncSession, task: EvaluationTask):
super().__init__(db, task)
async def save_eval_items(self):
synthesis_files = ((await self.db.execute(select(DataSynthesisFileInstance)
.where(DataSynthesisFileInstance.task_id == self.task.source_id)))
.scalars().all())
for synthesis_file in synthesis_files:
synthesis_datas = ((await self.db.execute(select(SynthesisData)
.where(SynthesisData.synthesis_file_instance_id == synthesis_file.id)))
.scalars().all())
logger.info(f"get {len(synthesis_datas)} items from file {synthesis_file.file_name}")
for synthesis_data in synthesis_datas:
self.db.add(EvaluationItem(
id=str(uuid.uuid4()),
task_id=self.task.id,
file_id=synthesis_file.id,
item_id=synthesis_data.id,
eval_content=synthesis_data.data,
status=TaskStatus.PENDING.value,
created_by=self.task.created_by,
updated_by=self.task.updated_by,
))
self.db.add(EvaluationFile(
id=str(uuid.uuid4()),
task_id=self.task.id,
file_id=synthesis_file.id,
file_name=synthesis_file.file_name,
total_count=len(synthesis_datas),
evaluated_count=0,
created_by=self.task.created_by,
updated_by=self.task.updated_by,
))
pass
def get_source_type(self) -> SourceType:
return SourceType.SYNTHESIS
class EvaluationExecutorFactory:
def __init__(self, db: AsyncSession, task: EvaluationTask):
self.db = db
self.executors: list[EvaluationExecutor] = []
self.executors.append(DatasetEvaluationExecutor(db, task))
self.executors.append(SynthesisEvaluationExecutor(db, task))
def get_executor(self, source_type: str) -> EvaluationExecutor:
for executor in self.executors:
if executor.get_source_type().value == source_type:
return executor
raise BusinessException(BusinessErrorCodeEnum.TASK_TYPE_ERROR.value)
class EvaluationTaskService:
@staticmethod
async def run_evaluation_task(task_id: str):
"""
Background worker to run evaluations.
- task_id: id of EvaluationTaskModel
"""
logger.info(f"Background evaluation worker started add items for task {task_id}")
async with AsyncSessionLocal() as session:
try:
                task = await session.execute(select(EvaluationTask).where(EvaluationTask.id == task_id))
                task = task.scalar_one_or_none()
                if not task:
                    logger.error(f"Evaluation task {task_id} not found, aborting background worker")
                    return
                factory = EvaluationExecutorFactory(session, task)
executor = factory.get_executor(task.source_type)
await executor.save_eval_items()
task.status = TaskStatus.RUNNING.value
except Exception as e:
logger.error(f"Background worker encountered error for task {task_id}: {e}")
task.status = TaskStatus.FAILED.value
finally:
await session.commit()
logger.info(f"Background evaluation worker started for task {task_id}")
async with AsyncSessionLocal() as session:
try:
                task = await session.execute(select(EvaluationTask).where(EvaluationTask.id == task_id))
                task = task.scalar_one_or_none()
                if not task:
                    logger.error(f"Evaluation task {task_id} not found, skipping execution")
                    return
                factory = EvaluationExecutorFactory(session, task)
executor = factory.get_executor(task.source_type)
await executor.execute()
logger.info(f"Background evaluation worker finished for task {task_id}")
task.status = TaskStatus.COMPLETED.value
except Exception as e:
logger.error(f"Background worker encountered error for task {task_id}: {e}")
task.status = TaskStatus.FAILED.value
finally:
await session.commit()

View File

@@ -0,0 +1,45 @@
"""
Service for managing evaluation prompt templates.
"""
from typing import List, Dict, Any
from app.module.evaluation.schema.prompt import EVALUATION_PROMPT_TEMPLATE
from app.module.evaluation.schema.prompt_template import (
PromptTemplateItem,
PromptTemplateDimension,
PromptTemplateResponse
)
class PromptTemplateService:
"""Service for managing evaluation prompt templates"""
@staticmethod
def get_prompt_templates() -> PromptTemplateResponse:
"""
Get all available prompt templates
Returns:
PromptTemplateResponse containing all prompt templates
"""
templates = []
for template in EVALUATION_PROMPT_TEMPLATE:
# Convert dimensions to the proper schema
dimensions = [
PromptTemplateDimension(
dimension=dim.get("dimension"),
description=dim.get("description", "")
)
for dim in template.get("defaultDimensions", [])
]
# Create template item
template_item = PromptTemplateItem(
evalType=template.get("evalType", ""),
defaultDimensions=dimensions,
prompt=template.get("prompt", "")
)
templates.append(template_item)
return PromptTemplateResponse(templates=templates)

View File

@@ -13,7 +13,7 @@ from app.db.models import Dataset
from app.db.session import get_db
from app.module.dataset import DatasetManagementService
from app.module.shared.schema import StandardResponse, TaskStatus
-from app.module.synthesis.schema.ratio_task import (
+from app.module.ratio.schema.ratio_task import (
CreateRatioTaskResponse,
CreateRatioTaskRequest,
PagedRatioTaskResponse,
@@ -21,7 +21,7 @@ from app.module.synthesis.schema.ratio_task import (
TargetDatasetInfo,
RatioTaskDetailResponse,
)
-from app.module.synthesis.service.ratio_task import RatioTaskService
+from app.module.ratio.service.ratio_task import RatioTaskService
from app.db.models.ratio_task import RatioInstance, RatioRelation, RatioRelation as RatioRelationModel
router = APIRouter(

View File

@@ -7,9 +7,13 @@ from app.module.shared.schema.common import TaskStatus
logger = get_logger(__name__)
class LabelFilter(BaseModel):
    label: Optional[str] = Field(..., description="Label")
    value: Optional[str] = Field(None, description="Label value")
class FilterCondition(BaseModel):
    date_range: Optional[str] = Field(None, description="Date range", alias="dateRange")
-    label: Optional[str] = Field(None, description="Label")
+    label: Optional[LabelFilter] = Field(None, description="Label")
@field_validator("date_range")
@classmethod

View File

@@ -7,7 +7,6 @@ import shutil
import asyncio
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.logging import get_logger
@@ -16,7 +15,7 @@ from app.db.models import Dataset, DatasetFiles
from app.db.session import AsyncSessionLocal
from app.module.dataset.schema.dataset_file import DatasetFileTag
from app.module.shared.schema import TaskStatus
-from app.module.synthesis.schema.ratio_task import FilterCondition
+from app.module.ratio.schema.ratio_task import FilterCondition
logger = get_logger(__name__)
@@ -59,7 +58,10 @@ class RatioTaskService:
counts=int(item.get("counts", 0)),
filter_conditions=json.dumps({
'date_range': item.get("filter_conditions").date_range,
'label': item.get("filter_conditions").label,
'label': {
"label":item.get("filter_conditions").label.label,
"value":item.get("filter_conditions").label.value,
},
})
)
logger.info(f"Relation created: {relation.id}, {relation}, {item}, {config}")
@@ -285,7 +287,7 @@ class RatioTaskService:
try:
# tags could be a list of strings or list of objects with 'name'
tag_names = RatioTaskService.get_all_tags(tags)
-            return conditions.label in tag_names
+            return f"{conditions.label.label}@{conditions.label.value}" in tag_names
except Exception as e:
logger.exception(f"Failed to get tags for {file}", e)
return False

View File

@@ -0,0 +1,15 @@
from openai import OpenAI
def call_openai_style_model(base_url, api_key, model_name, prompt, **kwargs):
client = OpenAI(
base_url=base_url,
api_key=api_key
)
response = client.chat.completions.create(
model=model_name,
messages=[{"role": "user", "content": prompt}],
**kwargs
)
return response.choices[0].message.content
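Because call_openai_style_model is synchronous, the evaluation executor runs it off the event loop with asyncio.to_thread. A minimal usage sketch; the endpoint, key and model name are placeholders, and extra keyword arguments are passed straight through to chat.completions.create:

import asyncio

from app.module.shared.util.model_chat import call_openai_style_model

async def main():
    text = await asyncio.to_thread(
        call_openai_style_model,
        "https://example.com/v1",  # placeholder OpenAI-compatible base URL
        "sk-placeholder",          # placeholder API key
        "demo-model",              # placeholder model name
        "Reply with the single word: pong",
        temperature=0.0,
    )
    print(text)

asyncio.run(main())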

View File

@@ -0,0 +1,85 @@
import json
from enum import Enum
from jsonschema import validate
class ItemTypes(Enum):
QA = "QA"
class StructuredFileItemHandler:
def __init__(self):
pass
def get_item_type(self) -> ItemTypes:
pass
def get_items_from_file(self, file_path: str) -> list[dict]:
pass
def check_file(self) -> bool:
pass
class QAItemHandler(StructuredFileItemHandler):
def __init__(self):
self.schema_alpaca = {
"type": "object",
"properties": {
"instruction": {"type": "string"},
"input": {"type": "string"},
"output": {"type": "string"}
},
"required": ["instruction", "output"],
}
self.schema_alpaca_list = {
"type": "array",
"items": self.schema_alpaca,
}
super().__init__()
def get_item_type(self):
return ItemTypes.QA
def validate_json(self, data):
try:
validate(instance=data, schema=self.schema_alpaca)
return True
except Exception as e:
try:
validate(instance=data, schema=self.schema_alpaca_list)
return True
except Exception as e:
return False
def get_items_from_file(self, file_path: str) -> list[dict]:
file_type = file_path.split(".")[-1].upper()
items = []
if file_type == "JSON":
with open(file_path, "r", encoding="utf-8") as f:
data = json.load(f)
if not self.validate_json(data):
return items
items = data
elif file_type == "JSONL":
with open(file_path, "r", encoding="utf-8") as f:
                for line in f:
                    line = line.strip()
                    if not line:
                        # skip blank lines commonly found at the end of JSONL files
                        continue
                    data = json.loads(line)
if not self.validate_json(data):
continue
items.append(data)
return items
def check_file(self) -> bool:
pass
class StructuredFileHandlerFactory:
def __init__(self):
self.handlers: list[StructuredFileItemHandler] = []
self.handlers.append(QAItemHandler())
def get_handler(self, item_type: str) -> StructuredFileItemHandler:
for handler in self.handlers:
if handler.get_item_type().value == item_type:
return handler
raise ValueError(f"Unsupported item type: {item_type}")