feature: add data-evaluation
* feature: add evaluation task management function
* feature: add evaluation task detail page
* fix: delete duplicate definition for table t_model_config
* refactor: rename package synthesis to ratio
* refactor: add eval file table and refactor related code
* fix: calling large models in parallel during evaluation
@@ -0,0 +1,11 @@
from fastapi import APIRouter

router = APIRouter(
    prefix="/evaluation",
    tags=["evaluation"],
)

# Include sub-routers
from .evaluation import router as evaluation_router

router.include_router(evaluation_router)
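For orientation, a minimal sketch of how this package-level router could be mounted on the application; the module path and the FastAPI instance below are illustrative assumptions, not part of this commit.

from fastapi import FastAPI

from app.module.evaluation.router import router as evaluation_router  # assumed package path

app = FastAPI()
app.include_router(evaluation_router)
# Routes then resolve to /evaluation/tasks, /evaluation/prompt-templates, etc.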
@@ -0,0 +1,429 @@
import asyncio
import uuid
import json
from typing import Optional
from fastapi import APIRouter, Depends, HTTPException, Query, BackgroundTasks
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, func, or_, text, and_
from pydantic import ValidationError

from app.core.logging import get_logger
from app.db.models.data_evaluation import EvaluationFile
from app.db.session import get_db
from app.db.models import EvaluationTask, EvaluationItem, DatasetFiles
from app.module.evaluation.schema.evaluation import (
    CreateEvaluationTaskRequest,
    PagedEvaluationTaskResponse,
    EvaluationTaskDetailResponse,
    PagedEvaluationItemsResponse,
    EvaluationItemResponse,
    PagedEvaluationFilesResponse,
    EvaluationFileResponse,
)
from app.module.evaluation.schema.prompt import get_prompt
from app.module.evaluation.schema.prompt_template import PromptTemplateResponse
from app.module.evaluation.service.prompt_template_service import PromptTemplateService
from app.module.evaluation.service.evaluation import EvaluationTaskService
from app.module.shared.schema.common import StandardResponse, TaskStatus

router = APIRouter(
    prefix="",
    tags=["evaluation"],
)

logger = get_logger(__name__)


@router.get("/prompt-templates", response_model=StandardResponse[PromptTemplateResponse])
|
||||
async def get_prompt_templates():
|
||||
"""
|
||||
Get all available evaluation prompt templates
|
||||
|
||||
Returns:
|
||||
StandardResponse with list of prompt templates
|
||||
"""
|
||||
try:
|
||||
templates = PromptTemplateService.get_prompt_templates()
|
||||
return StandardResponse(
|
||||
code=200,
|
||||
message="Success",
|
||||
data=templates
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get prompt templates: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Failed to retrieve prompt templates"
|
||||
)
|
||||
|
||||
|
||||
@router.post("/tasks", response_model=StandardResponse[EvaluationTaskDetailResponse], status_code=201)
|
||||
async def create_evaluation_task(
|
||||
request: CreateEvaluationTaskRequest,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
创建评估任务
|
||||
|
||||
Args:
|
||||
request: 创建评估任务请求
|
||||
db: 数据库会话
|
||||
|
||||
Returns:
|
||||
StandardResponse[EvaluationTaskDetailResponse]: 创建的任务详情
|
||||
"""
|
||||
try:
|
||||
# 检查任务名称是否已存在
|
||||
existing_task = await db.execute(
|
||||
select(EvaluationTask).where(EvaluationTask.name == request.name)
|
||||
)
|
||||
if existing_task.scalar_one_or_none():
|
||||
raise HTTPException(status_code=400, detail=f"Evaluation task with name '{request.name}' already exists")
|
||||
|
||||
# 创建评估任务
|
||||
task = EvaluationTask(
|
||||
id=str(uuid.uuid4()),
|
||||
name=request.name,
|
||||
description=request.description,
|
||||
task_type=request.task_type,
|
||||
source_type=request.source_type,
|
||||
source_id=request.source_id,
|
||||
source_name=request.source_name,
|
||||
eval_prompt=request.eval_prompt,
|
||||
eval_config=json.dumps({
|
||||
"model_id": request.eval_config.model_id,
|
||||
"dimensions": request.eval_config.dimensions,
|
||||
}),
|
||||
status=TaskStatus.PENDING.value,
|
||||
eval_process=0.0,
|
||||
)
|
||||
|
||||
db.add(task)
|
||||
# Commit first to persist the task before scheduling background work
|
||||
await db.commit()
|
||||
# Schedule background execution without blocking the current request
|
||||
asyncio.create_task(EvaluationTaskService.run_evaluation_task(task.id))
|
||||
|
||||
# Refresh the task to return latest state
|
||||
await db.refresh(task)
|
||||
|
||||
# 转换响应模型
|
||||
response = _map_to_task_detail_response(task)
|
||||
return StandardResponse(
|
||||
code=200,
|
||||
message="Evaluation task created successfully",
|
||||
data=response
|
||||
)
|
||||
|
||||
except ValidationError as e:
|
||||
await db.rollback()
|
||||
logger.error(f"Validation error: {str(e)}")
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(f"Failed to create evaluation task: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail="Internal server error")
|
||||
|
||||
|
||||
@router.get("/tasks", response_model=StandardResponse[PagedEvaluationTaskResponse])
|
||||
async def list_evaluation_tasks(
|
||||
page: int = Query(1, ge=1, description="页码,从1开始"),
|
||||
size: int = Query(10, ge=1, le=100, description="每页数量"),
|
||||
name: Optional[str] = Query(None, description="任务名称模糊查询"),
|
||||
status: Optional[str] = Query(None, description="任务状态过滤"),
|
||||
task_type: Optional[str] = Query(None, description="任务类型过滤"),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
分页查询评估任务
|
||||
|
||||
Args:
|
||||
page: 页码,从1开始
|
||||
size: 每页数量
|
||||
name: 任务名称模糊查询
|
||||
status: 任务状态过滤
|
||||
task_type: 任务类型过滤
|
||||
db: 数据库会话
|
||||
|
||||
Returns:
|
||||
StandardResponse[PagedEvaluationTaskResponse]: 分页的评估任务列表
|
||||
"""
|
||||
try:
|
||||
# 构建查询条件
|
||||
query = select(EvaluationTask)
|
||||
|
||||
if name:
|
||||
query = query.where(EvaluationTask.name.ilike(f"%{name}%"))
|
||||
if status:
|
||||
query = query.where(EvaluationTask.status == status)
|
||||
if task_type:
|
||||
query = query.where(EvaluationTask.task_type == task_type)
|
||||
|
||||
# 获取总数
|
||||
count_query = select(func.count()).select_from(query.subquery())
|
||||
total = (await db.execute(count_query)).scalar_one()
|
||||
|
||||
# 分页查询
|
||||
offset = (page - 1) * size
|
||||
tasks = (await db.execute(
|
||||
query.order_by(EvaluationTask.created_at.desc())
|
||||
.offset(offset)
|
||||
.limit(size)
|
||||
)).scalars().all()
|
||||
|
||||
# 转换为响应模型
|
||||
items = [_map_to_task_detail_response(task) for task in tasks]
|
||||
total_pages = (total + size - 1) // size if size > 0 else 0
|
||||
|
||||
return StandardResponse(
|
||||
code=200,
|
||||
message="Success",
|
||||
data=PagedEvaluationTaskResponse(
|
||||
content=items,
|
||||
totalElements=total,
|
||||
totalPages=total_pages,
|
||||
page=page,
|
||||
size=size,
|
||||
)
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to list evaluation tasks: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail="Internal server error")
|
||||
|
||||
@router.get("/tasks/{task_id}/files", response_model=StandardResponse[PagedEvaluationFilesResponse])
|
||||
async def list_evaluation_items(
|
||||
task_id: str,
|
||||
page: int = Query(1, ge=1, description="页码,从1开始"),
|
||||
size: int = Query(10, ge=1, le=100, description="每页数量"),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
分页查询评估文件
|
||||
|
||||
Args:
|
||||
task_id: 评估任务ID
|
||||
page: 页码,从1开始
|
||||
size: 每页数量
|
||||
db: 数据库会话
|
||||
|
||||
Returns:
|
||||
StandardResponse[PagedEvaluationFilesResponse]: 分页的评估文件列表
|
||||
"""
|
||||
try:
|
||||
task = await db.get(EvaluationTask, task_id)
|
||||
if not task:
|
||||
raise HTTPException(status_code=404, detail="Evaluation task not found")
|
||||
offset = (page - 1) * size
|
||||
query = select(EvaluationFile).where(EvaluationFile.task_id == task_id)
|
||||
count_query = select(func.count()).select_from(query.subquery())
|
||||
total = (await db.execute(count_query)).scalar_one()
|
||||
files = (await db.execute(query.offset(offset).limit(size))).scalars().all()
|
||||
total_pages = (total + size - 1) // size if size > 0 else 0
|
||||
file_responses = [
|
||||
EvaluationFileResponse(
|
||||
taskId=file.task_id,
|
||||
fileId=file.file_id,
|
||||
fileName=file.file_name,
|
||||
totalCount=file.total_count,
|
||||
evaluatedCount=file.evaluated_count,
|
||||
pendingCount=file.total_count - file.evaluated_count
|
||||
)
|
||||
for file in files
|
||||
]
|
||||
return StandardResponse(
|
||||
code=200,
|
||||
message="Success",
|
||||
data=PagedEvaluationFilesResponse(
|
||||
content=file_responses,
|
||||
totalElements=total,
|
||||
totalPages=total_pages,
|
||||
page=page,
|
||||
size=size,
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to list evaluation items: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail="Internal server error")
|
||||
|
||||
|
||||
@router.get("/tasks/{task_id}/items", response_model=StandardResponse[PagedEvaluationItemsResponse])
|
||||
async def list_evaluation_items(
|
||||
task_id: str,
|
||||
page: int = Query(1, ge=1, description="页码,从1开始"),
|
||||
size: int = Query(10, ge=1, le=100, description="每页数量"),
|
||||
status: Optional[str] = Query(None, description="状态过滤"),
|
||||
file_id: Optional[str] = Query(None, description="文件过滤"),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
分页查询评估条目
|
||||
|
||||
Args:
|
||||
task_id: 评估任务ID
|
||||
page: 页码,从1开始
|
||||
size: 每页数量
|
||||
status: 状态过滤
|
||||
file_id: 文件过滤
|
||||
db: 数据库会话
|
||||
|
||||
Returns:
|
||||
StandardResponse[PagedEvaluationItemsResponse]: 分页的评估条目列表
|
||||
"""
|
||||
try:
|
||||
# 检查任务是否存在
|
||||
task = await db.get(EvaluationTask, task_id)
|
||||
if not task:
|
||||
raise HTTPException(status_code=404, detail="Evaluation task not found")
|
||||
|
||||
# 构建查询条件
|
||||
query = select(EvaluationItem).where(EvaluationItem.task_id == task_id)
|
||||
|
||||
if status:
|
||||
query = query.where(EvaluationItem.status == status)
|
||||
|
||||
if file_id:
|
||||
query = query.where(EvaluationItem.file_id == file_id)
|
||||
|
||||
# 获取总数
|
||||
count_query = select(func.count()).select_from(query.subquery())
|
||||
total = (await db.execute(count_query)).scalar_one()
|
||||
|
||||
# 分页查询
|
||||
offset = (page - 1) * size
|
||||
items = (await db.execute(query.offset(offset).limit(size))).scalars().all()
|
||||
|
||||
# 转换为响应模型
|
||||
item_responses = [
|
||||
EvaluationItemResponse(
|
||||
id=item.id,
|
||||
taskId=item.task_id,
|
||||
itemId=item.item_id,
|
||||
fileId=item.file_id,
|
||||
evalContent=json.loads(item.eval_content),
|
||||
evalScore=float(item.eval_score) if item.eval_score else None,
|
||||
evalResult=json.loads(item.eval_result),
|
||||
status=item.status
|
||||
)
|
||||
for item in items
|
||||
]
|
||||
|
||||
total_pages = (total + size - 1) // size if size > 0 else 0
|
||||
|
||||
return StandardResponse(
|
||||
code=200,
|
||||
message="Success",
|
||||
data=PagedEvaluationItemsResponse(
|
||||
content=item_responses,
|
||||
totalElements=total,
|
||||
totalPages=total_pages,
|
||||
page=page,
|
||||
size=size,
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to list evaluation items: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail="Internal server error")
|
||||
|
||||
|
||||
@router.get("/tasks/{task_id}", response_model=StandardResponse[EvaluationTaskDetailResponse])
|
||||
async def get_evaluation_task(
|
||||
task_id: str,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
获取评估任务详情
|
||||
|
||||
Args:
|
||||
task_id: 任务ID
|
||||
db: 数据库会话
|
||||
|
||||
Returns:
|
||||
StandardResponse[EvaluationTaskDetailResponse]: 评估任务详情
|
||||
"""
|
||||
try:
|
||||
task = await db.get(EvaluationTask, task_id)
|
||||
if not task:
|
||||
raise HTTPException(status_code=404, detail="Evaluation task not found")
|
||||
|
||||
# 转换为响应模型
|
||||
response = _map_to_task_detail_response(task)
|
||||
return StandardResponse(
|
||||
code=200,
|
||||
message="Success",
|
||||
data=response
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get evaluation task: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail="Internal server error")
|
||||
|
||||
|
||||
@router.delete("/tasks", response_model=StandardResponse[str], status_code=200)
|
||||
async def delete_eval_tasks(
|
||||
ids: list[str] = Query(..., description="要删除的评估任务ID列表"),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
删除评估任务
|
||||
|
||||
Args:
|
||||
ids: 任务ID
|
||||
db: 数据库会话
|
||||
|
||||
Returns:
|
||||
StandardResponse[str]: 删除结果
|
||||
"""
|
||||
try:
|
||||
# 检查任务是否存在
|
||||
task_id = ids[0]
|
||||
task = await db.get(EvaluationTask, task_id)
|
||||
if not task:
|
||||
raise HTTPException(status_code=404, detail="Evaluation task not found")
|
||||
|
||||
# 删除评估项
|
||||
await db.execute(
|
||||
EvaluationItem.__table__.delete()
|
||||
.where(EvaluationItem.task_id == task_id)
|
||||
)
|
||||
|
||||
# 删除任务
|
||||
await db.delete(task)
|
||||
await db.commit()
|
||||
|
||||
return StandardResponse(
|
||||
code=200,
|
||||
message="Evaluation task deleted successfully",
|
||||
data="success"
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
await db.rollback()
|
||||
raise
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(f"Failed to delete evaluation task: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail="Internal server error")
|
||||
|
||||
|
||||
def _map_to_task_detail_response(
    task: EvaluationTask
) -> EvaluationTaskDetailResponse:
    """Map a database model to the task detail response model"""
    task_response = EvaluationTaskDetailResponse(
        id=task.id,
        name=task.name,
        description=task.description,
        taskType=task.task_type,
        sourceType=task.source_type,
        sourceId=task.source_id,
        sourceName=task.source_name,
        status=task.status,
        evalProcess=task.eval_process,
        evalPrompt=task.eval_prompt,
        evalConfig=json.loads(task.eval_config),
        createdAt=task.created_at.isoformat() if task.created_at else None,
        updatedAt=task.updated_at.isoformat() if task.updated_at else None,
    )
    task_response.eval_prompt = get_prompt(task_response.task_type, task_response.eval_config.get("dimensions"))
    return task_response

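One caveat for create_evaluation_task above: the event loop keeps only weak references to tasks created with asyncio.create_task, so a fire-and-forget task can be garbage-collected mid-run. A minimal sketch of holding a reference; the set and helper names are illustrative assumptions, not part of this commit.

import asyncio

_background_tasks: set[asyncio.Task] = set()

def schedule_evaluation(task_id: str) -> None:
    # Keep a strong reference until the task completes, then drop it.
    task = asyncio.create_task(EvaluationTaskService.run_evaluation_task(task_id))
    _background_tasks.add(task)
    task.add_done_callback(_background_tasks.discard)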
@@ -0,0 +1,101 @@
from typing import List, Optional, Dict, Any
from pydantic import BaseModel, Field, field_validator
from enum import Enum

from app.core.logging import get_logger
from app.module.shared.schema.common import TaskStatus

logger = get_logger(__name__)


class EvaluationConfig(BaseModel):
    """Evaluation configuration"""
    model_id: str = Field(..., alias="modelId", description="Model ID")
    dimensions: list[dict] = Field(..., alias="dimensions", description="Evaluation dimensions")


class CreateEvaluationTaskRequest(BaseModel):
    """Request to create an evaluation task"""
    name: str = Field(..., description="Evaluation task name")
    description: Optional[str] = Field(None, description="Evaluation task description")
    task_type: str = Field(..., alias="taskType", description="Evaluation task type: QA/QUALITY/COMPATIBILITY/VALUE")
    source_type: str = Field(..., alias="sourceType", description="Type of the object to evaluate: DATASET/SYNTHESIS")
    source_id: str = Field(..., alias="sourceId", description="ID of the object to evaluate")
    source_name: str = Field(..., alias="sourceName", description="Name of the object to evaluate")
    eval_method: str = Field("AUTO", alias="evalMethod", description="Evaluation method")
    eval_prompt: Optional[str] = Field(None, alias="evalPrompt", description="Evaluation prompt")
    eval_config: EvaluationConfig = Field(..., alias="evalConfig", description="Evaluation configuration")


class EvaluationTaskItem(BaseModel):
    """Evaluation task list item"""
    id: str = Field(..., description="Task ID")
    name: str = Field(..., description="Task name")
    description: Optional[str] = Field(None, description="Task description")
    task_type: Optional[str] = Field(..., alias="taskType", description="Task type")
    source_type: Optional[str] = Field(..., alias="sourceType", description="Data source type")
    source_id: Optional[str] = Field(..., alias="sourceId", description="Data source ID")
    source_name: Optional[str] = Field(None, alias="sourceName", description="Data source name")
    status: TaskStatus = Field(..., description="Task status")
    eval_process: Optional[float] = Field(0, alias="evalProcess", description="Evaluation progress")
    created_at: Optional[str] = Field(None, alias="createdAt", description="Creation time")
    updated_at: Optional[str] = Field(None, alias="updatedAt", description="Update time")


class PagedEvaluationTaskResponse(BaseModel):
    """Paged evaluation task response"""
    content: List[EvaluationTaskItem]
    total_elements: int = Field(..., alias="totalElements")
    total_pages: int = Field(..., alias="totalPages")
    page: int
    size: int


class EvaluationTaskDetailResponse(EvaluationTaskItem):
    """Evaluation task detail response"""
    eval_prompt: Optional[str] = Field(None, alias="evalPrompt", description="Evaluation prompt")
    eval_config: Optional[Dict[str, Any]] = Field(None, alias="evalConfig", description="Evaluation configuration")
    eval_result: Optional[Dict[str, Any]] = Field(None, alias="evalResult", description="Evaluation result")


class EvaluationItemResponse(BaseModel):
    """Evaluation item response"""
    id: str = Field(..., description="Item ID")
    task_id: str = Field(..., alias="taskId", description="Task ID")
    file_id: str = Field(..., alias="fileId", description="File ID")
    item_id: str = Field(..., alias="itemId", description="Evaluated item ID")
    eval_content: Optional[Dict[str, Any]] = Field(None, alias="evalContent", description="Evaluated content")
    eval_score: Optional[float] = Field(None, alias="evalScore", description="Evaluation score")
    eval_result: Optional[Dict[str, Any]] = Field(None, alias="evalResult", description="Evaluation result detail")
    status: str = Field(..., description="Evaluation status")


class EvaluationFileResponse(BaseModel):
    """Evaluation file response"""
    task_id: str = Field(..., alias="taskId", description="Task ID")
    file_id: str = Field(..., alias="fileId", description="File ID")
    file_name: str = Field(..., alias="fileName", description="File name")
    total_count: int = Field(..., alias="totalCount", description="Total item count")
    evaluated_count: int = Field(..., alias="evaluatedCount", description="Evaluated item count")
    pending_count: int = Field(..., alias="pendingCount", description="Pending item count")


class PagedEvaluationItemsResponse(BaseModel):
    """Paged evaluation item response"""
    content: List[EvaluationItemResponse]
    total_elements: int = Field(..., alias="totalElements")
    total_pages: int = Field(..., alias="totalPages")
    page: int
    size: int


class PagedEvaluationFilesResponse(BaseModel):
    """Paged evaluation file response"""
    content: List[EvaluationFileResponse]
    total_elements: int = Field(..., alias="totalElements")
    total_pages: int = Field(..., alias="totalPages")
    page: int
    size: int


class SourceType(Enum):
    DATASET = "DATASET"
    SYNTHESIS = "SYNTHESIS"

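Since the request and response models declare camelCase aliases on snake_case fields, clients send camelCase keys while the service code reads snake_case attributes. A minimal sketch, assuming Pydantic v2 (which the field_validator import implies); all values are illustrative.

payload = {
    "name": "demo-task",
    "taskType": "QA",
    "sourceType": "DATASET",
    "sourceId": "ds-001",
    "sourceName": "demo-dataset",
    "evalConfig": {
        "modelId": "model-001",
        "dimensions": [{"dimension": "问题是否独立", "description": "问题无需补充信息即可理解"}],
    },
}
req = CreateEvaluationTaskRequest.model_validate(payload)   # parsed via the aliases
print(req.task_type)                                        # "QA" (snake_case attribute)
print(req.model_dump(by_alias=True))                        # camelCase keys again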
@@ -0,0 +1,87 @@
EVALUATION_PROMPT_TEMPLATE = [
    {
        "evalType": "QA",
        "defaultDimensions": [
            {
                "dimension": "问题是否独立",
                "description": "仅分析问题,问题的主体和客体都比较明确,即使有省略,也符合语言习惯。在不需要补充其他信息的情况下不会引起疑惑。"
            },
            {
                "dimension": "语法是否错误",
                "description": "问题为疑问句,答案为陈述句; 不存在词语搭配不当的情况;连接词和标点符号不存在错用情况;逻辑混乱的情况不存在;语法结构都正确且完整。"
            },
            {
                "dimension": "回答是否有针对性",
                "description": "回答应对问题中的所有疑问点提供正面、直接的回答,不应引起疑惑。同时,答案不应有任何内容的遗漏,需构成一个完整的陈述。"
            }
        ],
        "prompt": """
# Role: 问答对质量评估专家

## Profile:
- Description: 你是一名专业的对话文本质量评估专家,擅长从多个维度对问答对进行质量评估,为机器学习模型训练提供高质量的数据筛选建议。具备深度学习、自然语言处理和数据科学的专业背景。

## Skills:
1. 能够从多个维度对问答对进行综合评估
2. 擅长识别问答对中的潜在问题,如答案不准确、问题模糊、文本不匹配、逻辑错误等
3. 能够给出具体的改进建议和质量评分,并提供可操作的优化方案
4. 熟悉机器学习训练数据的质量标准和最佳实践
5. 能够区分不同类型的问题(事实性、推理性、创造性)并采用相应的评估标准

## 评估维度:
{dimensions}

## 原始文本块内容:
{content}

## 问题:
{question}

## 答案:
{answer}

## 评估说明:
1. **数据集类型识别**:如果原始文本块内容为空或显示"Distilled Content",说明这是一个蒸馏数据集,没有原始文本参考。请重点评估问题的质量、答案的合理性和逻辑性,以及问答的一致性。
2. **评估原则**:采用严格的评估标准,确保筛选出的数据集能够有效提升模型性能。

## 注意事项:
- 评估结论要具体指出优点和不足,提供可操作的改进建议
- 评估结论控制在150字以内,简洁明了但要涵盖关键信息

## 输出要求:
请按照以下JSON格式输出评估结果,评估结果为Y/N,符合标准输出Y,不符合标准输出N:

{
    "result": {{result_example}
    },
    "evaluation": "这是一个高质量的问答数据集。问题表述清晰具体,答案准确完整且逻辑性强,与原始文本高度相关。建议:可以进一步丰富答案的细节描述。"
}
"""
    }
]


def get_dimensions_for_qa(dimensions: list[dict]) -> str:
    dimensions_str = "\n"
    index = 1
    for dimension in dimensions:
        dimensions_str += f"### {index}. {dimension.get('dimension')}\n**评估标准:**\n{dimension.get('description')}\n\n"
        index += 1
    return dimensions_str


def get_result_example_for_qa(dimensions: list[dict]) -> str:
    result_example = ""
    for dimension in dimensions:
        result_example += f'\n "{dimension.get("dimension")}": "Y",'
    return result_example


def get_prompt(task_type: str, dimensions: list[dict]) -> str:
    template = None
    for t in EVALUATION_PROMPT_TEMPLATE:
        if t.get("evalType") == task_type:
            template = t.get("prompt")
            break
    if not template:
        template = EVALUATION_PROMPT_TEMPLATE[0].get("prompt", "")
    if not dimensions:
        return template
    return (template.replace("{dimensions}", get_dimensions_for_qa(dimensions))
            .replace("{result_example}", get_result_example_for_qa(dimensions)))

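A short usage sketch of the helpers above; the dimension reuses one of the QA defaults and the output is truncated for illustration.

dimensions = [
    {"dimension": "回答是否有针对性", "description": "回答应对问题中的所有疑问点提供正面、直接的回答"},
]
prompt_text = get_prompt("QA", dimensions)
# {dimensions} and {result_example} are now filled in; {content}, {question} and {answer}
# remain as placeholders until EvaluationExecutor.get_eval_prompt() substitutes them per item.
print(prompt_text[:200])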
@@ -0,0 +1,29 @@
"""
Schema for evaluation prompt templates.
"""
from typing import List, Dict, Any
from pydantic import BaseModel, Field


class PromptTemplateDimension(BaseModel):
    """A single dimension in the prompt template"""
    dimension: str = Field(..., description="Dimension name")
    description: str = Field(..., description="Description of the dimension")


class PromptTemplateItem(BaseModel):
    """A single prompt template item"""
    evalType: str = Field(..., description="Evaluation type")
    defaultDimensions: List[PromptTemplateDimension] = Field(
        default_factory=list,
        description="List of default dimensions for this evaluation type"
    )
    prompt: str = Field(..., description="The prompt template string")


class PromptTemplateResponse(BaseModel):
    """Response model for getting prompt templates"""
    templates: List[PromptTemplateItem] = Field(
        ...,
        description="List of available prompt templates"
    )

@@ -0,0 +1,207 @@
import json
import uuid
import asyncio

from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession

from app.core.exception import BusinessErrorCodeEnum, BusinessException
from app.core.logging import get_logger
from app.db.models import EvaluationItem, EvaluationTask, DatasetFiles
from app.db.models.data_evaluation import EvaluationFile
from app.db.models.data_synthesis import DataSynthesisFileInstance, SynthesisData
from app.db.session import AsyncSessionLocal
from app.module.evaluation.schema.evaluation import SourceType
from app.module.shared.schema import TaskStatus
from app.module.shared.util.model_chat import call_openai_style_model
from app.module.evaluation.schema.prompt import get_prompt
from app.module.shared.util.structured_file import StructuredFileHandlerFactory
from app.module.system.service.common_service import get_model_by_id

logger = get_logger(__name__)


class EvaluationExecutor:
    def __init__(self, db: AsyncSession, task: EvaluationTask):
        self.db = db
        self.task = task

    async def save_eval_items(self):
        pass

    def get_eval_prompt(self, item: EvaluationItem) -> str:
        # Render the task-level template, then fill in the per-item placeholders
        prompt_text = get_prompt(self.task.task_type, json.loads(self.task.eval_config).get("dimensions"))
        eval_content = json.loads(item.eval_content)
        if self.task.task_type == "QA":
            prompt_text = ((prompt_text.replace("{content}", eval_content.get("input") or "")
                            .replace("{question}", eval_content.get("instruction") or ""))
                           .replace("{answer}", eval_content.get("output") or ""))
        return prompt_text

    async def execute(self):
        eval_config = json.loads(self.task.eval_config)
        model_config = await get_model_by_id(self.db, eval_config.get("model_id"))
        # Cap the number of model calls in flight at any one time
        semaphore = asyncio.Semaphore(10)
        files = (await self.db.execute(
            select(EvaluationFile).where(EvaluationFile.task_id == self.task.id)
        )).scalars().all()
        for file in files:
            items = (await self.db.execute(
                select(EvaluationItem).where(EvaluationItem.task_id == self.task.id)
                .where(EvaluationItem.file_id == file.file_id)
            )).scalars().all()
            tasks = [
                self.evaluate_item(model_config, item, semaphore)
                for item in items
            ]
            await asyncio.gather(*tasks, return_exceptions=True)
            file.evaluated_count = len(items)
            await self.db.commit()

    async def evaluate_item(self, model_config, item: EvaluationItem, semaphore: asyncio.Semaphore):
        async with semaphore:
            prompt_text = self.get_eval_prompt(item)
            # Run the blocking OpenAI-style client off the event loop
            resp_text = await asyncio.to_thread(
                call_openai_style_model, model_config.base_url, model_config.api_key, model_config.model_name,
                prompt_text,
            )
            item.eval_result = resp_text
            item.status = TaskStatus.COMPLETED.value
            await self.db.commit()

    def get_source_type(self) -> SourceType:
        pass


class DatasetEvaluationExecutor(EvaluationExecutor):
    def __init__(self, db: AsyncSession, task: EvaluationTask):
        super().__init__(db, task)

    async def save_eval_items(self):
        dataset_files = ((await self.db.execute(select(DatasetFiles)
                                                .where(DatasetFiles.dataset_id == self.task.source_id)))
                         .scalars().all())
        handler = StructuredFileHandlerFactory().get_handler(self.task.task_type)
        for dataset_file in dataset_files:
            # Only structured JSON/JSONL files can be evaluated
            if dataset_file.file_type.upper() not in ("JSON", "JSONL"):
                continue
            items = handler.get_items_from_file(dataset_file.file_path)
            logger.info(f"parsed {len(items)} items from file {dataset_file.file_name}")
            for item in items:
                self.db.add(EvaluationItem(
                    id=str(uuid.uuid4()),
                    task_id=self.task.id,
                    file_id=dataset_file.id,
                    item_id=item.get("id") or str(uuid.uuid4()),
                    eval_content=json.dumps(item, ensure_ascii=False),
                    status=TaskStatus.PENDING.value,
                    created_by=self.task.created_by,
                    updated_by=self.task.updated_by,
                ))
            self.db.add(EvaluationFile(
                id=str(uuid.uuid4()),
                task_id=self.task.id,
                file_id=dataset_file.id,
                file_name=dataset_file.file_name,
                total_count=len(items),
                evaluated_count=0,
                created_by=self.task.created_by,
                updated_by=self.task.updated_by,
            ))

    def get_source_type(self) -> SourceType:
        return SourceType.DATASET


class SynthesisEvaluationExecutor(EvaluationExecutor):
    def __init__(self, db: AsyncSession, task: EvaluationTask):
        super().__init__(db, task)

    async def save_eval_items(self):
        synthesis_files = ((await self.db.execute(select(DataSynthesisFileInstance)
                                                  .where(DataSynthesisFileInstance.task_id == self.task.source_id)))
                           .scalars().all())
        for synthesis_file in synthesis_files:
            synthesis_datas = ((await self.db.execute(select(SynthesisData)
                                                      .where(SynthesisData.synthesis_file_instance_id == synthesis_file.id)))
                               .scalars().all())
            logger.info(f"got {len(synthesis_datas)} items from file {synthesis_file.file_name}")
            for synthesis_data in synthesis_datas:
                self.db.add(EvaluationItem(
                    id=str(uuid.uuid4()),
                    task_id=self.task.id,
                    file_id=synthesis_file.id,
                    item_id=synthesis_data.id,
                    eval_content=synthesis_data.data,
                    status=TaskStatus.PENDING.value,
                    created_by=self.task.created_by,
                    updated_by=self.task.updated_by,
                ))
            self.db.add(EvaluationFile(
                id=str(uuid.uuid4()),
                task_id=self.task.id,
                file_id=synthesis_file.id,
                file_name=synthesis_file.file_name,
                total_count=len(synthesis_datas),
                evaluated_count=0,
                created_by=self.task.created_by,
                updated_by=self.task.updated_by,
            ))

    def get_source_type(self) -> SourceType:
        return SourceType.SYNTHESIS


class EvaluationExecutorFactory:
    def __init__(self, db: AsyncSession, task: EvaluationTask):
        self.db = db
        self.executors: list[EvaluationExecutor] = []
        self.executors.append(DatasetEvaluationExecutor(db, task))
        self.executors.append(SynthesisEvaluationExecutor(db, task))

    def get_executor(self, source_type: str) -> EvaluationExecutor:
        for executor in self.executors:
            if executor.get_source_type().value == source_type:
                return executor
        raise BusinessException(BusinessErrorCodeEnum.TASK_TYPE_ERROR.value)


class EvaluationTaskService:

    @staticmethod
    async def run_evaluation_task(task_id: str):
        """
        Background worker to run evaluations.
        - task_id: id of the EvaluationTask record
        """
        # Phase 1: collect the items to evaluate
        logger.info(f"Background evaluation worker started adding items for task {task_id}")
        async with AsyncSessionLocal() as session:
            task = (await session.execute(
                select(EvaluationTask).where(EvaluationTask.id == task_id)
            )).scalar_one_or_none()
            if not task:
                logger.error(f"Evaluation task {task_id} not found")
                return
            try:
                factory = EvaluationExecutorFactory(session, task)
                executor = factory.get_executor(task.source_type)
                await executor.save_eval_items()
                task.status = TaskStatus.RUNNING.value
            except Exception as e:
                logger.error(f"Background worker encountered error for task {task_id}: {e}")
                task.status = TaskStatus.FAILED.value
                return
            finally:
                await session.commit()

        # Phase 2: evaluate the collected items
        logger.info(f"Background evaluation worker started evaluating items for task {task_id}")
        async with AsyncSessionLocal() as session:
            task = (await session.execute(
                select(EvaluationTask).where(EvaluationTask.id == task_id)
            )).scalar_one_or_none()
            if not task:
                logger.error(f"Evaluation task {task_id} not found")
                return
            try:
                factory = EvaluationExecutorFactory(session, task)
                executor = factory.get_executor(task.source_type)
                await executor.execute()
                logger.info(f"Background evaluation worker finished for task {task_id}")
                task.status = TaskStatus.COMPLETED.value
            except Exception as e:
                logger.error(f"Background worker encountered error for task {task_id}: {e}")
                task.status = TaskStatus.FAILED.value
            finally:
                await session.commit()

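A stripped-down sketch of the bounded-concurrency pattern used in EvaluationExecutor.execute(): an asyncio.Semaphore caps how many model calls are in flight while asyncio.to_thread keeps the blocking client off the event loop. The call_model stub below is an illustrative stand-in for call_openai_style_model, not the project's real client.

import asyncio


def call_model(prompt: str) -> str:
    # Stand-in for a blocking, OpenAI-style HTTP call.
    return f"evaluated: {prompt}"


async def evaluate_all(prompts: list[str], max_in_flight: int = 10) -> list[str]:
    semaphore = asyncio.Semaphore(max_in_flight)

    async def evaluate_one(prompt: str) -> str:
        async with semaphore:  # at most max_in_flight calls run concurrently
            return await asyncio.to_thread(call_model, prompt)

    return await asyncio.gather(*(evaluate_one(p) for p in prompts))


# asyncio.run(evaluate_all(["prompt-1", "prompt-2"]))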
@@ -0,0 +1,45 @@
"""
Service for managing evaluation prompt templates.
"""
from typing import List, Dict, Any

from app.module.evaluation.schema.prompt import EVALUATION_PROMPT_TEMPLATE
from app.module.evaluation.schema.prompt_template import (
    PromptTemplateItem,
    PromptTemplateDimension,
    PromptTemplateResponse
)


class PromptTemplateService:
    """Service for managing evaluation prompt templates"""

    @staticmethod
    def get_prompt_templates() -> PromptTemplateResponse:
        """
        Get all available prompt templates

        Returns:
            PromptTemplateResponse containing all prompt templates
        """
        templates = []

        for template in EVALUATION_PROMPT_TEMPLATE:
            # Convert dimensions to the proper schema
            dimensions = [
                PromptTemplateDimension(
                    dimension=dim.get("dimension"),
                    description=dim.get("description", "")
                )
                for dim in template.get("defaultDimensions", [])
            ]

            # Create template item
            template_item = PromptTemplateItem(
                evalType=template.get("evalType", ""),
                defaultDimensions=dimensions,
                prompt=template.get("prompt", "")
            )
            templates.append(template_item)

        return PromptTemplateResponse(templates=templates)