You've already forked DataMate
fix: 修复知识库同步的并发控制、数据清理、文件事务和COCO导出问题
问题1 - 并发控制缺失: - 在 _ensure_knowledge_set 方法中添加数据库行锁(with_for_update) - 修改 _update_project_config 方法,使用行锁保护配置更新 问题3 - 数据清理机制缺失: - 添加 _cleanup_knowledge_set_for_project 方法,项目删除时清理知识集 - 添加 _cleanup_knowledge_item_for_file 方法,文件删除时清理知识条目 - 在 delete_mapping 接口中调用清理方法 问题4 - 文件操作事务问题: - 修改 uploadKnowledgeItems,添加事务失败后的文件清理逻辑 - 修改 deleteKnowledgeItem,删除记录前先删除关联文件 - 新增 deleteKnowledgeItemFile 辅助方法 问题5 - COCO导出格式问题: - 添加 _get_image_dimensions 方法读取图片实际宽高 - 将百分比坐标转换为像素坐标 - 在 AnnotationExportItem 中添加 file_path 字段 涉及文件: - knowledge_sync.py - project.py - KnowledgeItemApplicationService.java - export.py - export schema.py
This commit is contained in:
@@ -14,6 +14,7 @@ from app.core.logging import get_logger
|
||||
|
||||
from ..service.mapping import DatasetMappingService
|
||||
from ..service.template import AnnotationTemplateService
|
||||
from ..service.knowledge_sync import KnowledgeSyncService
|
||||
from ..schema import (
|
||||
DatasetMappingCreateRequest,
|
||||
DatasetMappingCreateResponse,
|
||||
@@ -22,26 +23,26 @@ from ..schema import (
|
||||
DatasetMappingResponse,
|
||||
)
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/project",
|
||||
tags=["annotation/project"]
|
||||
)
|
||||
router = APIRouter(prefix="/project", tags=["annotation/project"])
|
||||
logger = get_logger(__name__)
|
||||
TEXT_DATASET_TYPE = "TEXT"
|
||||
SOURCE_DOCUMENT_FILE_TYPES = {"pdf", "doc", "docx", "xls", "xlsx"}
|
||||
LABELING_TYPE_CONFIG_KEY = "labeling_type"
|
||||
|
||||
@router.get("/{mapping_id}/login")
|
||||
async def login_label_studio(
|
||||
mapping_id: str,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
raise HTTPException(status_code=410, detail="当前为内嵌编辑器模式,不再支持 Label Studio 登录代理接口")
|
||||
|
||||
@router.post("", response_model=StandardResponse[DatasetMappingCreateResponse], status_code=201)
|
||||
@router.get("/{mapping_id}/login")
|
||||
async def login_label_studio(mapping_id: str, db: AsyncSession = Depends(get_db)):
|
||||
raise HTTPException(
|
||||
status_code=410,
|
||||
detail="当前为内嵌编辑器模式,不再支持 Label Studio 登录代理接口",
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"", response_model=StandardResponse[DatasetMappingCreateResponse], status_code=201
|
||||
)
|
||||
async def create_mapping(
|
||||
request: DatasetMappingCreateRequest,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
request: DatasetMappingCreateRequest, db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
创建数据集映射
|
||||
@@ -64,7 +65,7 @@ async def create_mapping(
|
||||
if not dataset_info:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"Dataset not found in DM service: {request.dataset_id}"
|
||||
detail=f"Dataset not found in DM service: {request.dataset_id}",
|
||||
)
|
||||
|
||||
dataset_type = (
|
||||
@@ -73,13 +74,15 @@ async def create_mapping(
|
||||
or ""
|
||||
).upper()
|
||||
|
||||
project_name = request.name or \
|
||||
dataset_info.name or \
|
||||
"A new project from DataMate"
|
||||
project_name = (
|
||||
request.name or dataset_info.name or "A new project from DataMate"
|
||||
)
|
||||
|
||||
project_description = request.description or \
|
||||
dataset_info.description or \
|
||||
f"Imported from DM dataset {dataset_info.name} ({dataset_info.id})"
|
||||
project_description = (
|
||||
request.description
|
||||
or dataset_info.description
|
||||
or f"Imported from DM dataset {dataset_info.name} ({dataset_info.id})"
|
||||
)
|
||||
|
||||
# 如果提供了模板ID,获取模板配置
|
||||
label_config = None
|
||||
@@ -89,8 +92,7 @@ async def create_mapping(
|
||||
template = await template_service.get_template(db, request.template_id)
|
||||
if not template:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"Template not found: {request.template_id}"
|
||||
status_code=404, detail=f"Template not found: {request.template_id}"
|
||||
)
|
||||
label_config = template.label_config
|
||||
template_labeling_type = getattr(template, "labeling_type", None)
|
||||
@@ -110,19 +112,24 @@ async def create_mapping(
|
||||
project_configuration["label_config"] = label_config
|
||||
if project_description:
|
||||
project_configuration["description"] = project_description
|
||||
if dataset_type == TEXT_DATASET_TYPE and request.segmentation_enabled is not None:
|
||||
project_configuration["segmentation_enabled"] = bool(request.segmentation_enabled)
|
||||
if (
|
||||
dataset_type == TEXT_DATASET_TYPE
|
||||
and request.segmentation_enabled is not None
|
||||
):
|
||||
project_configuration["segmentation_enabled"] = bool(
|
||||
request.segmentation_enabled
|
||||
)
|
||||
if template_labeling_type:
|
||||
project_configuration[LABELING_TYPE_CONFIG_KEY] = template_labeling_type
|
||||
|
||||
labeling_project = LabelingProject(
|
||||
id=str(uuid.uuid4()), # Generate UUID here
|
||||
dataset_id=request.dataset_id,
|
||||
labeling_project_id=labeling_project_id,
|
||||
name=project_name,
|
||||
template_id=request.template_id, # Save template_id to database
|
||||
configuration=project_configuration or None,
|
||||
)
|
||||
id=str(uuid.uuid4()), # Generate UUID here
|
||||
dataset_id=request.dataset_id,
|
||||
labeling_project_id=labeling_project_id,
|
||||
name=project_name,
|
||||
template_id=request.template_id, # Save template_id to database
|
||||
configuration=project_configuration or None,
|
||||
)
|
||||
|
||||
file_result = await db.execute(
|
||||
select(DatasetFiles).where(
|
||||
@@ -143,9 +150,7 @@ async def create_mapping(
|
||||
snapshot_file_ids.append(str(file_record.id))
|
||||
else:
|
||||
snapshot_file_ids = [
|
||||
str(file_record.id)
|
||||
for file_record in file_records
|
||||
if file_record.id
|
||||
str(file_record.id) for file_record in file_records if file_record.id
|
||||
]
|
||||
|
||||
# 创建映射关系并写入快照
|
||||
@@ -157,25 +162,30 @@ async def create_mapping(
|
||||
if dataset_type == TEXT_DATASET_TYPE and request.segmentation_enabled:
|
||||
try:
|
||||
from ..service.editor import AnnotationEditorService
|
||||
|
||||
editor_service = AnnotationEditorService(db)
|
||||
# 异步预计算切片(不阻塞创建响应)
|
||||
segmentation_result = await editor_service.precompute_segmentation_for_project(labeling_project.id)
|
||||
logger.info(f"Precomputed segmentation for project {labeling_project.id}: {segmentation_result}")
|
||||
segmentation_result = (
|
||||
await editor_service.precompute_segmentation_for_project(
|
||||
labeling_project.id
|
||||
)
|
||||
)
|
||||
logger.info(
|
||||
f"Precomputed segmentation for project {labeling_project.id}: {segmentation_result}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to precompute segmentation for project {labeling_project.id}: {e}")
|
||||
logger.warning(
|
||||
f"Failed to precompute segmentation for project {labeling_project.id}: {e}"
|
||||
)
|
||||
# 不影响项目创建,只记录警告
|
||||
|
||||
response_data = DatasetMappingCreateResponse(
|
||||
id=mapping.id,
|
||||
labeling_project_id=str(mapping.labeling_project_id),
|
||||
labeling_project_name=mapping.name or project_name
|
||||
labeling_project_name=mapping.name or project_name,
|
||||
)
|
||||
|
||||
return StandardResponse(
|
||||
code=201,
|
||||
message="success",
|
||||
data=response_data
|
||||
)
|
||||
return StandardResponse(code=201, message="success", data=response_data)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
@@ -183,12 +193,15 @@ async def create_mapping(
|
||||
logger.error(f"Error while creating dataset mapping: {e}")
|
||||
raise HTTPException(status_code=500, detail="Internal server error")
|
||||
|
||||
|
||||
@router.get("", response_model=StandardResponse[PaginatedData[DatasetMappingResponse]])
|
||||
async def list_mappings(
|
||||
page: int = Query(1, ge=1, description="页码(从1开始)"),
|
||||
size: int = Query(20, ge=1, le=100, description="每页记录数"),
|
||||
include_template: bool = Query(False, description="是否包含模板详情", alias="includeTemplate"),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
include_template: bool = Query(
|
||||
False, description="是否包含模板详情", alias="includeTemplate"
|
||||
),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
查询所有映射关系(分页)
|
||||
@@ -207,14 +220,16 @@ async def list_mappings(
|
||||
# 计算 skip
|
||||
skip = (page - 1) * size
|
||||
|
||||
logger.info(f"List mappings: page={page}, size={size}, include_template={include_template}")
|
||||
logger.info(
|
||||
f"List mappings: page={page}, size={size}, include_template={include_template}"
|
||||
)
|
||||
|
||||
# 获取数据和总数
|
||||
mappings, total = await service.get_all_mappings_with_count(
|
||||
skip=skip,
|
||||
limit=size,
|
||||
include_deleted=False,
|
||||
include_template=include_template
|
||||
include_template=include_template,
|
||||
)
|
||||
|
||||
# 计算总页数
|
||||
@@ -226,26 +241,22 @@ async def list_mappings(
|
||||
size=size,
|
||||
total_elements=total,
|
||||
total_pages=total_pages,
|
||||
content=mappings
|
||||
content=mappings,
|
||||
)
|
||||
|
||||
logger.info(f"List mappings: page={page}, returned {len(mappings)}/{total}, templates_included: {include_template}")
|
||||
|
||||
return StandardResponse(
|
||||
code=200,
|
||||
message="success",
|
||||
data=paginated_data
|
||||
logger.info(
|
||||
f"List mappings: page={page}, returned {len(mappings)}/{total}, templates_included: {include_template}"
|
||||
)
|
||||
|
||||
return StandardResponse(code=200, message="success", data=paginated_data)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error listing mappings: {e}")
|
||||
raise HTTPException(status_code=500, detail="Internal server error")
|
||||
|
||||
|
||||
@router.get("/{mapping_id}", response_model=StandardResponse[DatasetMappingResponse])
|
||||
async def get_mapping(
|
||||
mapping_id: str,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
async def get_mapping(mapping_id: str, db: AsyncSession = Depends(get_db)):
|
||||
"""
|
||||
根据 UUID 查询单个映射关系(包含关联的标注模板详情)
|
||||
|
||||
@@ -265,31 +276,34 @@ async def get_mapping(
|
||||
|
||||
if not mapping:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"Mapping not found: {mapping_id}"
|
||||
status_code=404, detail=f"Mapping not found: {mapping_id}"
|
||||
)
|
||||
|
||||
logger.info(f"Found mapping: {mapping.id}, template_included: {mapping.template is not None}")
|
||||
|
||||
return StandardResponse(
|
||||
code=200,
|
||||
message="success",
|
||||
data=mapping
|
||||
logger.info(
|
||||
f"Found mapping: {mapping.id}, template_included: {mapping.template is not None}"
|
||||
)
|
||||
|
||||
return StandardResponse(code=200, message="success", data=mapping)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting mapping: {e}")
|
||||
raise HTTPException(status_code=500, detail="Internal server error")
|
||||
|
||||
@router.get("/by-source/{dataset_id}", response_model=StandardResponse[PaginatedData[DatasetMappingResponse]])
|
||||
|
||||
@router.get(
|
||||
"/by-source/{dataset_id}",
|
||||
response_model=StandardResponse[PaginatedData[DatasetMappingResponse]],
|
||||
)
|
||||
async def get_mappings_by_source(
|
||||
dataset_id: str,
|
||||
page: int = Query(1, ge=1, description="页码(从1开始)"),
|
||||
size: int = Query(20, ge=1, le=100, description="每页记录数"),
|
||||
include_template: bool = Query(True, description="是否包含模板详情", alias="includeTemplate"),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
include_template: bool = Query(
|
||||
True, description="是否包含模板详情", alias="includeTemplate"
|
||||
),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
根据源数据集 ID 查询所有映射关系(分页,包含模板详情)
|
||||
@@ -309,14 +323,16 @@ async def get_mappings_by_source(
|
||||
# 计算 skip
|
||||
skip = (page - 1) * size
|
||||
|
||||
logger.info(f"Get mappings by source dataset id: {dataset_id}, page={page}, size={size}, include_template={include_template}")
|
||||
logger.info(
|
||||
f"Get mappings by source dataset id: {dataset_id}, page={page}, size={size}, include_template={include_template}"
|
||||
)
|
||||
|
||||
# 获取数据和总数(包含模板信息)
|
||||
mappings, total = await service.get_mappings_by_source_with_count(
|
||||
dataset_id=dataset_id,
|
||||
skip=skip,
|
||||
limit=size,
|
||||
include_template=include_template
|
||||
include_template=include_template,
|
||||
)
|
||||
|
||||
# 计算总页数
|
||||
@@ -328,27 +344,26 @@ async def get_mappings_by_source(
|
||||
size=size,
|
||||
total_elements=total,
|
||||
total_pages=total_pages,
|
||||
content=mappings
|
||||
content=mappings,
|
||||
)
|
||||
|
||||
logger.info(f"Found {len(mappings)} mappings on page {page}, total: {total}, templates_included: {include_template}")
|
||||
|
||||
return StandardResponse(
|
||||
code=200,
|
||||
message="success",
|
||||
data=paginated_data
|
||||
logger.info(
|
||||
f"Found {len(mappings)} mappings on page {page}, total: {total}, templates_included: {include_template}"
|
||||
)
|
||||
|
||||
return StandardResponse(code=200, message="success", data=paginated_data)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting mappings: {e}")
|
||||
raise HTTPException(status_code=500, detail="Internal server error")
|
||||
|
||||
|
||||
@router.delete("/{project_id}", response_model=StandardResponse[DeleteDatasetResponse])
|
||||
async def delete_mapping(
|
||||
project_id: str = Path(..., description="映射UUID(path param)"),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
删除映射关系(软删除)
|
||||
@@ -370,12 +385,12 @@ async def delete_mapping(
|
||||
|
||||
if not mapping:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"Mapping either not found or not specified."
|
||||
status_code=404, detail=f"Mapping either not found or not specified."
|
||||
)
|
||||
|
||||
id = mapping.id
|
||||
logger.debug(f"Found mapping: {id}")
|
||||
dataset_id = mapping.dataset_id
|
||||
logger.debug(f"Found mapping: {id}, dataset_id: {dataset_id}")
|
||||
|
||||
# 软删除映射记录
|
||||
soft_delete_success = await service.soft_delete_mapping(id)
|
||||
@@ -383,19 +398,22 @@ async def delete_mapping(
|
||||
|
||||
if not soft_delete_success:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Failed to delete mapping record"
|
||||
status_code=500, detail="Failed to delete mapping record"
|
||||
)
|
||||
|
||||
# 清理知识集中的关联数据
|
||||
try:
|
||||
knowledge_sync = KnowledgeSyncService(db)
|
||||
await knowledge_sync._cleanup_knowledge_set_for_project(id)
|
||||
except Exception as exc:
|
||||
logger.warning(f"清理知识集失败:project_id={id} error={exc}")
|
||||
|
||||
logger.info(f"Successfully deleted mapping: {id}")
|
||||
|
||||
return StandardResponse(
|
||||
code=200,
|
||||
message="success",
|
||||
data=DeleteDatasetResponse(
|
||||
id=id,
|
||||
status="success"
|
||||
)
|
||||
data=DeleteDatasetResponse(id=id, status="success"),
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
@@ -409,7 +427,7 @@ async def delete_mapping(
|
||||
async def update_mapping(
|
||||
project_id: str = Path(..., description="映射UUID(path param)"),
|
||||
request: DatasetMappingUpdateRequest = None,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
更新标注项目信息
|
||||
@@ -429,16 +447,14 @@ async def update_mapping(
|
||||
# 直接查询 ORM 模型获取原始数据
|
||||
result = await db.execute(
|
||||
select(LabelingProject).where(
|
||||
LabelingProject.id == project_id,
|
||||
LabelingProject.deleted_at.is_(None)
|
||||
LabelingProject.id == project_id, LabelingProject.deleted_at.is_(None)
|
||||
)
|
||||
)
|
||||
mapping_orm = result.scalar_one_or_none()
|
||||
|
||||
if not mapping_orm:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"Mapping not found: {project_id}"
|
||||
status_code=404, detail=f"Mapping not found: {project_id}"
|
||||
)
|
||||
|
||||
# 构建更新数据
|
||||
@@ -449,7 +465,11 @@ async def update_mapping(
|
||||
# 从 configuration 字段中读取和更新 description 和 label_config
|
||||
configuration = {}
|
||||
if mapping_orm.configuration:
|
||||
configuration = mapping_orm.configuration.copy() if isinstance(mapping_orm.configuration, dict) else {}
|
||||
configuration = (
|
||||
mapping_orm.configuration.copy()
|
||||
if isinstance(mapping_orm.configuration, dict)
|
||||
else {}
|
||||
)
|
||||
|
||||
if request.description is not None:
|
||||
configuration["description"] = request.description
|
||||
@@ -468,7 +488,7 @@ async def update_mapping(
|
||||
if not template:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"Template not found: {request.template_id}"
|
||||
detail=f"Template not found: {request.template_id}",
|
||||
)
|
||||
template_labeling_type = getattr(template, "labeling_type", None)
|
||||
if template_labeling_type:
|
||||
@@ -477,14 +497,11 @@ async def update_mapping(
|
||||
if not update_values:
|
||||
# 没有要更新的字段,直接返回当前数据
|
||||
response_data = await service.get_mapping_by_uuid(project_id)
|
||||
return StandardResponse(
|
||||
code=200,
|
||||
message="success",
|
||||
data=response_data
|
||||
)
|
||||
return StandardResponse(code=200, message="success", data=response_data)
|
||||
|
||||
# 执行更新
|
||||
from datetime import datetime
|
||||
|
||||
update_values["updated_at"] = datetime.now()
|
||||
|
||||
result = await db.execute(
|
||||
@@ -495,21 +512,14 @@ async def update_mapping(
|
||||
await db.commit()
|
||||
|
||||
if result.rowcount == 0:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Failed to update mapping"
|
||||
)
|
||||
raise HTTPException(status_code=500, detail="Failed to update mapping")
|
||||
|
||||
# 重新获取更新后的数据
|
||||
updated_mapping = await service.get_mapping_by_uuid(project_id)
|
||||
|
||||
logger.info(f"Successfully updated mapping: {project_id}")
|
||||
|
||||
return StandardResponse(
|
||||
code=200,
|
||||
message="success",
|
||||
data=updated_mapping
|
||||
)
|
||||
return StandardResponse(code=200, message="success", data=updated_mapping)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
|
||||
Reference in New Issue
Block a user