You've already forked DataMate
fix: 修复评估时模型输出json格式不对导致读取错误的问题 (#133)
* feature: add cot data evaluation function * fix: added verification to evaluation results * fix: fix the prompt for evaluating * fix: 修复当评估结果为空导致读取失败的问题
This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
import math
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.future import select
|
||||
from sqlalchemy import func
|
||||
@@ -14,11 +15,11 @@ logger = get_logger(__name__)
|
||||
|
||||
class Service:
|
||||
"""数据管理服务客户端 - 直接访问数据库"""
|
||||
|
||||
|
||||
def __init__(self, db: AsyncSession):
|
||||
"""
|
||||
初始化 DM 客户端
|
||||
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
"""
|
||||
@@ -29,16 +30,16 @@ class Service:
|
||||
"""获取数据集详情"""
|
||||
try:
|
||||
logger.debug(f"Getting dataset detail: {dataset_id} ...")
|
||||
|
||||
|
||||
result = await self.db.execute(
|
||||
select(Dataset).where(Dataset.id == dataset_id)
|
||||
)
|
||||
dataset = result.scalar_one_or_none()
|
||||
|
||||
|
||||
if not dataset:
|
||||
logger.error(f"Dataset not found: {dataset_id}")
|
||||
return None
|
||||
|
||||
|
||||
# 将数据库模型转换为响应模型
|
||||
# type: ignore 用于忽略 SQLAlchemy 的类型检查问题
|
||||
return DatasetResponse(
|
||||
@@ -56,11 +57,11 @@ class Service:
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get dataset {dataset_id}: {e}")
|
||||
return None
|
||||
|
||||
|
||||
async def get_dataset_files(
|
||||
self,
|
||||
dataset_id: str,
|
||||
page: int = 0,
|
||||
self,
|
||||
dataset_id: str,
|
||||
page: int = 0,
|
||||
size: int = 100,
|
||||
file_type: Optional[str] = None,
|
||||
status: Optional[str] = None
|
||||
@@ -68,16 +69,16 @@ class Service:
|
||||
"""获取数据集文件列表"""
|
||||
try:
|
||||
logger.debug(f"Get dataset files: dataset={dataset_id}, page={page}, size={size}")
|
||||
|
||||
|
||||
# 构建查询
|
||||
query = select(DatasetFiles).where(DatasetFiles.dataset_id == dataset_id)
|
||||
|
||||
|
||||
# 添加可选过滤条件
|
||||
if file_type:
|
||||
query = query.where(DatasetFiles.file_type == file_type)
|
||||
if status:
|
||||
query = query.where(DatasetFiles.status == status)
|
||||
|
||||
|
||||
# 获取总数
|
||||
count_query = select(func.count()).select_from(DatasetFiles).where(
|
||||
DatasetFiles.dataset_id == dataset_id
|
||||
@@ -86,15 +87,15 @@ class Service:
|
||||
count_query = count_query.where(DatasetFiles.file_type == file_type)
|
||||
if status:
|
||||
count_query = count_query.where(DatasetFiles.status == status)
|
||||
|
||||
|
||||
count_result = await self.db.execute(count_query)
|
||||
total = count_result.scalar_one()
|
||||
|
||||
|
||||
# 分页查询
|
||||
query = query.offset(page * size).limit(size).order_by(DatasetFiles.created_at.desc())
|
||||
result = await self.db.execute(query)
|
||||
files = result.scalars().all()
|
||||
|
||||
|
||||
# 转换为响应模型
|
||||
# type: ignore 用于忽略 SQLAlchemy 的类型检查问题
|
||||
content = [
|
||||
@@ -115,9 +116,9 @@ class Service:
|
||||
)
|
||||
for f in files
|
||||
]
|
||||
|
||||
total_pages = (total + size - 1) // size if size > 0 else 0
|
||||
|
||||
|
||||
total_pages = math.ceil(total / size) if total > 0 else 0
|
||||
|
||||
return PagedDatasetFileResponse(
|
||||
content=content,
|
||||
totalElements=total,
|
||||
@@ -128,7 +129,7 @@ class Service:
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get dataset files for {dataset_id}: {e}")
|
||||
return None
|
||||
|
||||
|
||||
async def download_file(self, dataset_id: str, file_id: str) -> Optional[bytes]:
|
||||
"""
|
||||
下载文件内容
|
||||
@@ -136,7 +137,7 @@ class Service:
|
||||
"""
|
||||
logger.warning(f"download_file is deprecated when using database mode. Use get_file_download_url instead.")
|
||||
return None
|
||||
|
||||
|
||||
async def get_file_download_url(self, dataset_id: str, file_id: str) -> Optional[str]:
|
||||
"""获取文件下载URL(或文件路径)"""
|
||||
try:
|
||||
@@ -147,60 +148,60 @@ class Service:
|
||||
)
|
||||
)
|
||||
file = result.scalar_one_or_none()
|
||||
|
||||
|
||||
if not file:
|
||||
logger.error(f"File not found: {file_id} in dataset {dataset_id}")
|
||||
return None
|
||||
|
||||
|
||||
# 返回文件路径(可以是本地路径或对象存储URL)
|
||||
return file.file_path # type: ignore
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get file path for {file_id}: {e}")
|
||||
return None
|
||||
|
||||
|
||||
async def close(self):
|
||||
"""关闭客户端连接(数据库模式下无需操作)"""
|
||||
logger.info("DM service client closed (Database mode)")
|
||||
|
||||
|
||||
async def update_file_tags_partial(
|
||||
self,
|
||||
file_id: str,
|
||||
self,
|
||||
file_id: str,
|
||||
new_tags: List[Dict[str, Any]],
|
||||
template_id: Optional[str] = None
|
||||
) -> tuple[bool, Optional[str], Optional[datetime]]:
|
||||
"""
|
||||
部分更新文件标签,支持自动格式转换
|
||||
|
||||
|
||||
如果提供了 template_id,会自动将简化格式的标签转换为完整格式。
|
||||
简化格式: {"from_name": "x", "to_name": "y", "values": [...]}
|
||||
完整格式: {"id": "...", "from_name": "x", "to_name": "y", "type": "...", "value": {"type": [...]}}
|
||||
|
||||
|
||||
Args:
|
||||
file_id: 文件ID
|
||||
new_tags: 新的标签列表(部分更新),可以是简化格式或完整格式
|
||||
template_id: 可选的模板ID,用于格式转换
|
||||
|
||||
|
||||
Returns:
|
||||
(成功标志, 错误信息, 更新时间)
|
||||
"""
|
||||
try:
|
||||
logger.info(f"Partial updating tags for file: {file_id}")
|
||||
|
||||
|
||||
# 获取文件记录
|
||||
result = await self.db.execute(
|
||||
select(DatasetFiles).where(DatasetFiles.id == file_id)
|
||||
)
|
||||
file_record = result.scalar_one_or_none()
|
||||
|
||||
|
||||
if not file_record:
|
||||
logger.error(f"File not found: {file_id}")
|
||||
return False, f"File not found: {file_id}", None
|
||||
|
||||
|
||||
# 如果提供了 template_id,尝试进行格式转换
|
||||
processed_tags = new_tags
|
||||
if template_id:
|
||||
logger.debug(f"Converting tags using template: {template_id}")
|
||||
|
||||
|
||||
try:
|
||||
# 获取模板配置
|
||||
from app.db.models import AnnotationTemplate
|
||||
@@ -211,29 +212,29 @@ class Service:
|
||||
)
|
||||
)
|
||||
template = template_result.scalar_one_or_none()
|
||||
|
||||
|
||||
if not template:
|
||||
logger.warning(f"Template {template_id} not found, skipping conversion")
|
||||
else:
|
||||
# 使用 converter 转换标签格式
|
||||
from app.module.annotation.utils import create_converter_from_template_config
|
||||
|
||||
|
||||
converter = create_converter_from_template_config(template.configuration) # type: ignore
|
||||
processed_tags = converter.convert_if_needed(new_tags)
|
||||
|
||||
|
||||
logger.info(f"Converted {len(new_tags)} tags to full format")
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to convert tags using template: {e}")
|
||||
# 继续使用原始标签格式
|
||||
logger.warning("Continuing with original tag format")
|
||||
|
||||
|
||||
# 获取现有标签
|
||||
existing_tags: List[Dict[str, Any]] = file_record.tags or [] # type: ignore
|
||||
|
||||
|
||||
# 创建标签ID到索引的映射
|
||||
tag_id_map = {tag.get('id'): idx for idx, tag in enumerate(existing_tags) if tag.get('id')}
|
||||
|
||||
|
||||
# 更新或追加标签
|
||||
for new_tag in processed_tags:
|
||||
tag_id = new_tag.get('id')
|
||||
@@ -246,19 +247,19 @@ class Service:
|
||||
# 追加新标签
|
||||
existing_tags.append(new_tag)
|
||||
logger.debug(f"Added new tag with id: {tag_id}")
|
||||
|
||||
|
||||
# 更新数据库
|
||||
update_time = datetime.utcnow()
|
||||
file_record.tags = existing_tags # type: ignore
|
||||
file_record.tags_updated_at = update_time # type: ignore
|
||||
|
||||
|
||||
await self.db.commit()
|
||||
await self.db.refresh(file_record)
|
||||
|
||||
|
||||
logger.info(f"Successfully updated tags for file: {file_id}")
|
||||
return True, None, update_time
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to update tags for file {file_id}: {e}")
|
||||
await self.db.rollback()
|
||||
return False, str(e), None
|
||||
return False, str(e), None
|
||||
|
||||
Reference in New Issue
Block a user