fix: 修复评估时模型输出json格式不对导致读取错误的问题 (#133)

* feature: add cot data evaluation function

* fix: added verification to evaluation results

* fix: fix the prompt for evaluating

* fix: 修复当评估结果为空导致读取失败的问题
This commit is contained in:
hefanli
2025-12-04 18:49:50 +08:00
committed by GitHub
parent 31c4966608
commit 744d15ba24
14 changed files with 373 additions and 219 deletions

View File

@@ -1,3 +1,4 @@
import math
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlalchemy import func
@@ -14,11 +15,11 @@ logger = get_logger(__name__)
class Service:
"""数据管理服务客户端 - 直接访问数据库"""
def __init__(self, db: AsyncSession):
"""
初始化 DM 客户端
Args:
db: 数据库会话
"""
@@ -29,16 +30,16 @@ class Service:
"""获取数据集详情"""
try:
logger.debug(f"Getting dataset detail: {dataset_id} ...")
result = await self.db.execute(
select(Dataset).where(Dataset.id == dataset_id)
)
dataset = result.scalar_one_or_none()
if not dataset:
logger.error(f"Dataset not found: {dataset_id}")
return None
# 将数据库模型转换为响应模型
# type: ignore 用于忽略 SQLAlchemy 的类型检查问题
return DatasetResponse(
@@ -56,11 +57,11 @@ class Service:
except Exception as e:
logger.error(f"Failed to get dataset {dataset_id}: {e}")
return None
async def get_dataset_files(
self,
dataset_id: str,
page: int = 0,
self,
dataset_id: str,
page: int = 0,
size: int = 100,
file_type: Optional[str] = None,
status: Optional[str] = None
@@ -68,16 +69,16 @@ class Service:
"""获取数据集文件列表"""
try:
logger.debug(f"Get dataset files: dataset={dataset_id}, page={page}, size={size}")
# 构建查询
query = select(DatasetFiles).where(DatasetFiles.dataset_id == dataset_id)
# 添加可选过滤条件
if file_type:
query = query.where(DatasetFiles.file_type == file_type)
if status:
query = query.where(DatasetFiles.status == status)
# 获取总数
count_query = select(func.count()).select_from(DatasetFiles).where(
DatasetFiles.dataset_id == dataset_id
@@ -86,15 +87,15 @@ class Service:
count_query = count_query.where(DatasetFiles.file_type == file_type)
if status:
count_query = count_query.where(DatasetFiles.status == status)
count_result = await self.db.execute(count_query)
total = count_result.scalar_one()
# 分页查询
query = query.offset(page * size).limit(size).order_by(DatasetFiles.created_at.desc())
result = await self.db.execute(query)
files = result.scalars().all()
# 转换为响应模型
# type: ignore 用于忽略 SQLAlchemy 的类型检查问题
content = [
@@ -115,9 +116,9 @@ class Service:
)
for f in files
]
total_pages = (total + size - 1) // size if size > 0 else 0
total_pages = math.ceil(total / size) if total > 0 else 0
return PagedDatasetFileResponse(
content=content,
totalElements=total,
@@ -128,7 +129,7 @@ class Service:
except Exception as e:
logger.error(f"Failed to get dataset files for {dataset_id}: {e}")
return None
async def download_file(self, dataset_id: str, file_id: str) -> Optional[bytes]:
"""
下载文件内容
@@ -136,7 +137,7 @@ class Service:
"""
logger.warning(f"download_file is deprecated when using database mode. Use get_file_download_url instead.")
return None
async def get_file_download_url(self, dataset_id: str, file_id: str) -> Optional[str]:
"""获取文件下载URL(或文件路径)"""
try:
@@ -147,60 +148,60 @@ class Service:
)
)
file = result.scalar_one_or_none()
if not file:
logger.error(f"File not found: {file_id} in dataset {dataset_id}")
return None
# 返回文件路径(可以是本地路径或对象存储URL)
return file.file_path # type: ignore
except Exception as e:
logger.error(f"Failed to get file path for {file_id}: {e}")
return None
async def close(self):
"""关闭客户端连接(数据库模式下无需操作)"""
logger.info("DM service client closed (Database mode)")
async def update_file_tags_partial(
self,
file_id: str,
self,
file_id: str,
new_tags: List[Dict[str, Any]],
template_id: Optional[str] = None
) -> tuple[bool, Optional[str], Optional[datetime]]:
"""
部分更新文件标签,支持自动格式转换
如果提供了 template_id,会自动将简化格式的标签转换为完整格式。
简化格式: {"from_name": "x", "to_name": "y", "values": [...]}
完整格式: {"id": "...", "from_name": "x", "to_name": "y", "type": "...", "value": {"type": [...]}}
Args:
file_id: 文件ID
new_tags: 新的标签列表(部分更新),可以是简化格式或完整格式
template_id: 可选的模板ID,用于格式转换
Returns:
(成功标志, 错误信息, 更新时间)
"""
try:
logger.info(f"Partial updating tags for file: {file_id}")
# 获取文件记录
result = await self.db.execute(
select(DatasetFiles).where(DatasetFiles.id == file_id)
)
file_record = result.scalar_one_or_none()
if not file_record:
logger.error(f"File not found: {file_id}")
return False, f"File not found: {file_id}", None
# 如果提供了 template_id,尝试进行格式转换
processed_tags = new_tags
if template_id:
logger.debug(f"Converting tags using template: {template_id}")
try:
# 获取模板配置
from app.db.models import AnnotationTemplate
@@ -211,29 +212,29 @@ class Service:
)
)
template = template_result.scalar_one_or_none()
if not template:
logger.warning(f"Template {template_id} not found, skipping conversion")
else:
# 使用 converter 转换标签格式
from app.module.annotation.utils import create_converter_from_template_config
converter = create_converter_from_template_config(template.configuration) # type: ignore
processed_tags = converter.convert_if_needed(new_tags)
logger.info(f"Converted {len(new_tags)} tags to full format")
except Exception as e:
logger.error(f"Failed to convert tags using template: {e}")
# 继续使用原始标签格式
logger.warning("Continuing with original tag format")
# 获取现有标签
existing_tags: List[Dict[str, Any]] = file_record.tags or [] # type: ignore
# 创建标签ID到索引的映射
tag_id_map = {tag.get('id'): idx for idx, tag in enumerate(existing_tags) if tag.get('id')}
# 更新或追加标签
for new_tag in processed_tags:
tag_id = new_tag.get('id')
@@ -246,19 +247,19 @@ class Service:
# 追加新标签
existing_tags.append(new_tag)
logger.debug(f"Added new tag with id: {tag_id}")
# 更新数据库
update_time = datetime.utcnow()
file_record.tags = existing_tags # type: ignore
file_record.tags_updated_at = update_time # type: ignore
await self.db.commit()
await self.db.refresh(file_record)
logger.info(f"Successfully updated tags for file: {file_id}")
return True, None, update_time
except Exception as e:
logger.error(f"Failed to update tags for file {file_id}: {e}")
await self.db.rollback()
return False, str(e), None
return False, str(e), None