diff --git a/backend/services/data-management-service/src/main/java/com/datamate/datamanagement/application/KnowledgeItemApplicationService.java b/backend/services/data-management-service/src/main/java/com/datamate/datamanagement/application/KnowledgeItemApplicationService.java
index 5e32193..919e535 100644
--- a/backend/services/data-management-service/src/main/java/com/datamate/datamanagement/application/KnowledgeItemApplicationService.java
+++ b/backend/services/data-management-service/src/main/java/com/datamate/datamanagement/application/KnowledgeItemApplicationService.java
@@ -126,41 +126,53 @@ public class KnowledgeItemApplicationService {
         createDirectories(setDir);
 
         List<KnowledgeItem> items = new ArrayList<>();
+        List<Path> savedFilePaths = new ArrayList<>();
 
-        for (MultipartFile file : files) {
-            BusinessAssert.notNull(file, CommonErrorCode.PARAM_ERROR);
-            BusinessAssert.isTrue(!file.isEmpty(), CommonErrorCode.PARAM_ERROR);
+        try {
+            for (MultipartFile file : files) {
+                BusinessAssert.notNull(file, CommonErrorCode.PARAM_ERROR);
+                BusinessAssert.isTrue(!file.isEmpty(), CommonErrorCode.PARAM_ERROR);
 
-            String originalName = resolveOriginalFileName(file);
-            String safeOriginalName = sanitizeFileName(originalName);
-            if (StringUtils.isBlank(safeOriginalName)) {
-                safeOriginalName = "file";
+                String originalName = resolveOriginalFileName(file);
+                String safeOriginalName = sanitizeFileName(originalName);
+                if (StringUtils.isBlank(safeOriginalName)) {
+                    safeOriginalName = "file";
+                }
+
+                String extension = getFileExtension(safeOriginalName);
+                String storedName = UUID.randomUUID().toString() +
+                        (StringUtils.isBlank(extension) ? "" : "." + extension);
+                Path targetPath = setDir.resolve(storedName).normalize();
+                BusinessAssert.isTrue(targetPath.startsWith(setDir), CommonErrorCode.PARAM_ERROR);
+
+                saveMultipartFile(file, targetPath);
+                savedFilePaths.add(targetPath);
+
+                KnowledgeItem knowledgeItem = new KnowledgeItem();
+                knowledgeItem.setId(UUID.randomUUID().toString());
+                knowledgeItem.setSetId(setId);
+                knowledgeItem.setContent(buildRelativeFilePath(setId, storedName));
+                knowledgeItem.setContentType(KnowledgeContentType.FILE);
+                knowledgeItem.setSourceType(KnowledgeSourceType.FILE_UPLOAD);
+                knowledgeItem.setSourceFileId(trimToLength(safeOriginalName, MAX_TITLE_LENGTH));
+                knowledgeItem.setRelativePath(buildRelativePath(parentPrefix, safeOriginalName));
+
+                items.add(knowledgeItem);
             }
 
-            String extension = getFileExtension(safeOriginalName);
-            String storedName = UUID.randomUUID().toString() +
-                    (StringUtils.isBlank(extension) ? "" : "." + extension);
-            Path targetPath = setDir.resolve(storedName).normalize();
-            BusinessAssert.isTrue(targetPath.startsWith(setDir), CommonErrorCode.PARAM_ERROR);
-
-            saveMultipartFile(file, targetPath);
-
-            KnowledgeItem knowledgeItem = new KnowledgeItem();
-            knowledgeItem.setId(UUID.randomUUID().toString());
-            knowledgeItem.setSetId(setId);
-            knowledgeItem.setContent(buildRelativeFilePath(setId, storedName));
-            knowledgeItem.setContentType(KnowledgeContentType.FILE);
-            knowledgeItem.setSourceType(KnowledgeSourceType.FILE_UPLOAD);
-            knowledgeItem.setSourceFileId(trimToLength(safeOriginalName, MAX_TITLE_LENGTH));
-            knowledgeItem.setRelativePath(buildRelativePath(parentPrefix, safeOriginalName));
-
-            items.add(knowledgeItem);
+            if (CollectionUtils.isNotEmpty(items)) {
+                knowledgeItemRepository.saveBatch(items, items.size());
+            }
+            return items;
+        } catch (Exception e) {
+            for (Path filePath : savedFilePaths) {
+                deleteFileQuietly(filePath);
+            }
+            if (e instanceof BusinessException) {
+                throw (BusinessException) e;
+            }
+            throw BusinessException.of(SystemErrorCode.FILE_SYSTEM_ERROR);
         }
-
-        if (CollectionUtils.isNotEmpty(items)) {
-            knowledgeItemRepository.saveBatch(items, items.size());
-        }
-        return items;
     }
 
     public KnowledgeItem updateKnowledgeItem(String setId, String itemId, UpdateKnowledgeItemRequest request) {
@@ -190,6 +202,9 @@ public class KnowledgeItemApplicationService {
         KnowledgeItem knowledgeItem = knowledgeItemRepository.getById(itemId);
         BusinessAssert.notNull(knowledgeItem, DataManagementErrorCode.KNOWLEDGE_ITEM_NOT_FOUND);
         BusinessAssert.isTrue(Objects.equals(knowledgeItem.getSetId(), setId), CommonErrorCode.PARAM_ERROR);
+
+        deleteKnowledgeItemFile(knowledgeItem);
+        knowledgeItemPreviewService.deletePreviewFileQuietly(setId, itemId);
         knowledgeItemRepository.removeById(itemId);
     }
 
@@ -205,6 +220,11 @@ public class KnowledgeItemApplicationService {
         boolean allMatch = items.stream().allMatch(item -> Objects.equals(item.getSetId(), setId));
         BusinessAssert.isTrue(allMatch, CommonErrorCode.PARAM_ERROR);
 
+        for (KnowledgeItem item : items) {
+            deleteKnowledgeItemFile(item);
+            knowledgeItemPreviewService.deletePreviewFileQuietly(setId, item.getId());
+        }
+
         List<String> deleteIds = items.stream().map(KnowledgeItem::getId).toList();
         knowledgeItemRepository.removeByIds(deleteIds);
     }
 
@@ -785,6 +805,24 @@ public class KnowledgeItemApplicationService {
         }
     }
 
+    private void deleteKnowledgeItemFile(KnowledgeItem knowledgeItem) {
+        if (knowledgeItem == null) {
+            return;
+        }
+        if (knowledgeItem.getSourceType() == KnowledgeSourceType.FILE_UPLOAD
+                || knowledgeItem.getContentType() == KnowledgeContentType.FILE) {
+            String relativePath = knowledgeItem.getContent();
+            if (StringUtils.isNotBlank(relativePath)) {
+                try {
+                    Path filePath = resolveKnowledgeItemStoragePath(relativePath);
+                    deleteFileQuietly(filePath);
+                } catch (Exception e) {
+                    log.warn("delete knowledge item file error, itemId: {}, path: {}", knowledgeItem.getId(), relativePath, e);
+                }
+            }
+        }
+    }
+
     private String resolveOriginalFileName(MultipartFile file) {
         String originalName = file.getOriginalFilename();
         if (StringUtils.isBlank(originalName)) {
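The Java change above tracks every file it writes (`savedFilePaths`) and deletes them all if the batch insert fails. A minimal sketch of the same save-then-roll-back pattern, in Python for consistency with the rest of this patch (all names here are hypothetical, not part of the codebase):

```python
import uuid
from pathlib import Path


def save_uploads(files: list[tuple[str, bytes]], set_dir: Path) -> list[Path]:
    """Write each upload under a random name; on any failure, remove
    everything written so far and re-raise (mirrors the Java try/catch)."""
    set_dir = set_dir.resolve()
    set_dir.mkdir(parents=True, exist_ok=True)
    saved: list[Path] = []
    try:
        for original_name, payload in files:
            suffix = Path(original_name).suffix  # keep only the extension
            target = (set_dir / f"{uuid.uuid4()}{suffix}").resolve()
            # same traversal guard as targetPath.startsWith(setDir)
            if set_dir not in target.parents:
                raise ValueError("path escapes the storage directory")
            target.write_bytes(payload)
            saved.append(target)
        return saved
    except Exception:
        for path in saved:
            path.unlink(missing_ok=True)  # best-effort, like deleteFileQuietly
        raise
```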
diff --git a/runtime/datamate-python/app/module/annotation/interface/project.py b/runtime/datamate-python/app/module/annotation/interface/project.py
index 44f0159..99760bb 100644
--- a/runtime/datamate-python/app/module/annotation/interface/project.py
+++ b/runtime/datamate-python/app/module/annotation/interface/project.py
@@ -14,6 +14,7 @@ from app.core.logging import get_logger
 from ..service.mapping import DatasetMappingService
 from ..service.template import AnnotationTemplateService
+from ..service.knowledge_sync import KnowledgeSyncService
 from ..schema import (
     DatasetMappingCreateRequest,
     DatasetMappingCreateResponse,
@@ -22,26 +23,26 @@ from ..schema import (
     DatasetMappingResponse,
 )
 
-router = APIRouter(
-    prefix="/project",
-    tags=["annotation/project"]
-)
+router = APIRouter(prefix="/project", tags=["annotation/project"])
 
 logger = get_logger(__name__)
 
 TEXT_DATASET_TYPE = "TEXT"
 SOURCE_DOCUMENT_FILE_TYPES = {"pdf", "doc", "docx", "xls", "xlsx"}
 LABELING_TYPE_CONFIG_KEY = "labeling_type"
 
-@router.get("/{mapping_id}/login")
-async def login_label_studio(
-    mapping_id: str,
-    db: AsyncSession = Depends(get_db)
-):
-    raise HTTPException(status_code=410, detail="Embedded editor mode is in use; the Label Studio login proxy endpoint is no longer supported")
 
-@router.post("", response_model=StandardResponse[DatasetMappingCreateResponse], status_code=201)
+@router.get("/{mapping_id}/login")
+async def login_label_studio(mapping_id: str, db: AsyncSession = Depends(get_db)):
+    raise HTTPException(
+        status_code=410,
+        detail="Embedded editor mode is in use; the Label Studio login proxy endpoint is no longer supported",
+    )
+
+
+@router.post(
+    "", response_model=StandardResponse[DatasetMappingCreateResponse], status_code=201
+)
 async def create_mapping(
-    request: DatasetMappingCreateRequest,
-    db: AsyncSession = Depends(get_db)
+    request: DatasetMappingCreateRequest, db: AsyncSession = Depends(get_db)
 ):
     """
     Create a dataset mapping
@@ -64,7 +65,7 @@ async def create_mapping(
         if not dataset_info:
             raise HTTPException(
                 status_code=404,
-                detail=f"Dataset not found in DM service: {request.dataset_id}"
+                detail=f"Dataset not found in DM service: {request.dataset_id}",
             )
 
         dataset_type = (
@@ -73,13 +74,15 @@ async def create_mapping(
             or ""
         ).upper()
 
-        project_name = request.name or \
-            dataset_info.name or \
-            "A new project from DataMate"
+        project_name = (
+            request.name or dataset_info.name or "A new project from DataMate"
+        )
 
-        project_description = request.description or \
-            dataset_info.description or \
-            f"Imported from DM dataset {dataset_info.name} ({dataset_info.id})"
+        project_description = (
+            request.description
+            or dataset_info.description
+            or f"Imported from DM dataset {dataset_info.name} ({dataset_info.id})"
+        )
 
         # If a template ID was provided, fetch the template configuration
         label_config = None
@@ -89,8 +92,7 @@ async def create_mapping(
             template = await template_service.get_template(db, request.template_id)
             if not template:
                 raise HTTPException(
-                    status_code=404,
-                    detail=f"Template not found: {request.template_id}"
+                    status_code=404, detail=f"Template not found: {request.template_id}"
                 )
             label_config = template.label_config
             template_labeling_type = getattr(template, "labeling_type", None)
@@ -110,19 +112,24 @@ async def create_mapping(
             project_configuration["label_config"] = label_config
         if project_description:
             project_configuration["description"] = project_description
-        if dataset_type == TEXT_DATASET_TYPE and request.segmentation_enabled is not None:
-            project_configuration["segmentation_enabled"] = bool(request.segmentation_enabled)
+        if (
+            dataset_type == TEXT_DATASET_TYPE
+            and request.segmentation_enabled is not None
+        ):
+            project_configuration["segmentation_enabled"] = bool(
+                request.segmentation_enabled
+            )
         if template_labeling_type:
             project_configuration[LABELING_TYPE_CONFIG_KEY] = template_labeling_type
 
         labeling_project = LabelingProject(
-            id=str(uuid.uuid4()),  # Generate UUID here
-            dataset_id=request.dataset_id,
-            labeling_project_id=labeling_project_id,
-            name=project_name,
-            template_id=request.template_id,  # Save template_id to database
-            configuration=project_configuration or None,
-        )
+            id=str(uuid.uuid4()),  # Generate UUID here
+            dataset_id=request.dataset_id,
+            labeling_project_id=labeling_project_id,
+            name=project_name,
+            template_id=request.template_id,  # Save template_id to database
+            configuration=project_configuration or None,
+        )
 
         file_result = await db.execute(
             select(DatasetFiles).where(
@@ -143,9 +150,7 @@ async def create_mapping(
                     snapshot_file_ids.append(str(file_record.id))
         else:
             snapshot_file_ids = [
-                str(file_record.id)
-                for file_record in file_records
-                if file_record.id
+                str(file_record.id) for file_record in file_records if file_record.id
             ]
 
         # Create the mapping and write the snapshot
@@ -157,25 +162,30 @@ async def create_mapping(
         if dataset_type == TEXT_DATASET_TYPE and request.segmentation_enabled:
             try:
                 from ..service.editor import AnnotationEditorService
+
                 editor_service = AnnotationEditorService(db)
                 # Precompute segmentation asynchronously (does not block the create response)
-                segmentation_result = await editor_service.precompute_segmentation_for_project(labeling_project.id)
-                logger.info(f"Precomputed segmentation for project {labeling_project.id}: {segmentation_result}")
+                segmentation_result = (
+                    await editor_service.precompute_segmentation_for_project(
+                        labeling_project.id
+                    )
+                )
+                logger.info(
+                    f"Precomputed segmentation for project {labeling_project.id}: {segmentation_result}"
+                )
             except Exception as e:
-                logger.warning(f"Failed to precompute segmentation for project {labeling_project.id}: {e}")
+                logger.warning(
+                    f"Failed to precompute segmentation for project {labeling_project.id}: {e}"
+                )
                 # Does not affect project creation; only log a warning
 
         response_data = DatasetMappingCreateResponse(
             id=mapping.id,
             labeling_project_id=str(mapping.labeling_project_id),
-            labeling_project_name=mapping.name or project_name
+            labeling_project_name=mapping.name or project_name,
         )
 
-        return StandardResponse(
-            code=201,
-            message="success",
-            data=response_data
-        )
+        return StandardResponse(code=201, message="success", data=response_data)
 
     except HTTPException:
         raise
@@ -183,12 +193,15 @@ async def create_mapping(
         logger.error(f"Error while creating dataset mapping: {e}")
         raise HTTPException(status_code=500, detail="Internal server error")
 
+
 @router.get("", response_model=StandardResponse[PaginatedData[DatasetMappingResponse]])
 async def list_mappings(
     page: int = Query(1, ge=1, description="Page number (1-based)"),
     size: int = Query(20, ge=1, le=100, description="Records per page"),
-    include_template: bool = Query(False, description="Whether to include template details", alias="includeTemplate"),
-    db: AsyncSession = Depends(get_db)
+    include_template: bool = Query(
+        False, description="Whether to include template details", alias="includeTemplate"
+    ),
+    db: AsyncSession = Depends(get_db),
 ):
     """
     List all mappings (paginated)
@@ -207,14 +220,16 @@ async def list_mappings(
         # Compute skip
         skip = (page - 1) * size
 
-        logger.info(f"List mappings: page={page}, size={size}, include_template={include_template}")
+        logger.info(
+            f"List mappings: page={page}, size={size}, include_template={include_template}"
+        )
 
         # Fetch rows and total count
         mappings, total = await service.get_all_mappings_with_count(
             skip=skip,
             limit=size,
             include_deleted=False,
-            include_template=include_template
+            include_template=include_template,
         )
 
         # Compute total pages
@@ -226,26 +241,22 @@ async def list_mappings(
             size=size,
             total_elements=total,
             total_pages=total_pages,
-            content=mappings
+            content=mappings,
         )
 
-        logger.info(f"List mappings: page={page}, returned {len(mappings)}/{total}, templates_included: {include_template}")
-
-        return StandardResponse(
-            code=200,
-            message="success",
-            data=paginated_data
+        logger.info(
+            f"List mappings: page={page}, returned {len(mappings)}/{total}, templates_included: {include_template}"
         )
 
+        return StandardResponse(code=200, message="success", data=paginated_data)
+
     except Exception as e:
         logger.error(f"Error listing mappings: {e}")
         raise HTTPException(status_code=500, detail="Internal server error")
 
+
 @router.get("/{mapping_id}", response_model=StandardResponse[DatasetMappingResponse])
-async def get_mapping(
-    mapping_id: str,
-    db: AsyncSession = Depends(get_db)
-):
+async def get_mapping(mapping_id: str, db: AsyncSession = Depends(get_db)):
     """
     Get a single mapping by UUID (including the associated annotation template details)
 
@@ -265,31 +276,34 @@ async def get_mapping(
 
         if not mapping:
             raise HTTPException(
-                status_code=404,
-                detail=f"Mapping not found: {mapping_id}"
+                status_code=404, detail=f"Mapping not found: {mapping_id}"
             )
 
-        logger.info(f"Found mapping: {mapping.id}, template_included: {mapping.template is not None}")
-
-        return StandardResponse(
-            code=200,
-            message="success",
-            data=mapping
+        logger.info(
+            f"Found mapping: {mapping.id}, template_included: {mapping.template is not None}"
         )
 
+        return StandardResponse(code=200, message="success", data=mapping)
+
     except HTTPException:
         raise
     except Exception as e:
         logger.error(f"Error getting mapping: {e}")
         raise HTTPException(status_code=500, detail="Internal server error")
 
-@router.get("/by-source/{dataset_id}", response_model=StandardResponse[PaginatedData[DatasetMappingResponse]])
+
+@router.get(
+    "/by-source/{dataset_id}",
+    response_model=StandardResponse[PaginatedData[DatasetMappingResponse]],
+)
 async def get_mappings_by_source(
     dataset_id: str,
     page: int = Query(1, ge=1, description="Page number (1-based)"),
     size: int = Query(20, ge=1, le=100, description="Records per page"),
-    include_template: bool = Query(True, description="Whether to include template details", alias="includeTemplate"),
-    db: AsyncSession = Depends(get_db)
+    include_template: bool = Query(
+        True, description="Whether to include template details", alias="includeTemplate"
+    ),
+    db: AsyncSession = Depends(get_db),
 ):
     """
     List all mappings for a source dataset ID (paginated, including template details)
@@ -309,14 +323,16 @@ async def get_mappings_by_source(
         # Compute skip
         skip = (page - 1) * size
 
-        logger.info(f"Get mappings by source dataset id: {dataset_id}, page={page}, size={size}, include_template={include_template}")
+        logger.info(
+            f"Get mappings by source dataset id: {dataset_id}, page={page}, size={size}, include_template={include_template}"
+        )
 
         # Fetch rows and total count (including template info)
         mappings, total = await service.get_mappings_by_source_with_count(
             dataset_id=dataset_id,
             skip=skip,
             limit=size,
-            include_template=include_template
+            include_template=include_template,
         )
 
         # Compute total pages
@@ -328,27 +344,26 @@ async def get_mappings_by_source(
             size=size,
             total_elements=total,
             total_pages=total_pages,
-            content=mappings
+            content=mappings,
         )
 
-        logger.info(f"Found {len(mappings)} mappings on page {page}, total: {total}, templates_included: {include_template}")
-
-        return StandardResponse(
-            code=200,
-            message="success",
-            data=paginated_data
+        logger.info(
+            f"Found {len(mappings)} mappings on page {page}, total: {total}, templates_included: {include_template}"
         )
 
+        return StandardResponse(code=200, message="success", data=paginated_data)
+
     except HTTPException:
         raise
     except Exception as e:
         logger.error(f"Error getting mappings: {e}")
         raise HTTPException(status_code=500, detail="Internal server error")
 
+
 @router.delete("/{project_id}", response_model=StandardResponse[DeleteDatasetResponse])
 async def delete_mapping(
     project_id: str = Path(..., description="Mapping UUID (path param)"),
-    db: AsyncSession = Depends(get_db)
+    db: AsyncSession = Depends(get_db),
 ):
     """
     Delete a mapping (soft delete)
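Both list endpoints derive `skip` from the 1-based `page` and `size` query params. A small worked check of that arithmetic (the ceil-division `total_pages` formula is an assumption, since the hunks elide it):

```python
def page_window(page: int, size: int, total: int) -> tuple[int, int]:
    """Return (skip, total_pages) for the 1-based page/size params above."""
    skip = (page - 1) * size                  # offset of the first row on this page
    total_pages = (total + size - 1) // size  # ceil division (assumed formula)
    return skip, total_pages


assert page_window(1, 20, 45) == (0, 3)   # first page starts at row 0
assert page_window(3, 20, 45) == (40, 3)  # last page holds the remaining 5 rows
```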
@@ -370,12 +385,12 @@ async def delete_mapping(
 
         if not mapping:
             raise HTTPException(
-                status_code=404,
-                detail=f"Mapping either not found or not specified."
+                status_code=404, detail="Mapping either not found or not specified."
             )
 
         id = mapping.id
-        logger.debug(f"Found mapping: {id}")
+        dataset_id = mapping.dataset_id
+        logger.debug(f"Found mapping: {id}, dataset_id: {dataset_id}")
 
         # Soft-delete the mapping record
         soft_delete_success = await service.soft_delete_mapping(id)
@@ -383,19 +398,22 @@ async def delete_mapping(
 
         if not soft_delete_success:
             raise HTTPException(
-                status_code=500,
-                detail="Failed to delete mapping record"
+                status_code=500, detail="Failed to delete mapping record"
             )
 
+        # Clean up associated data in the knowledge set
+        try:
+            knowledge_sync = KnowledgeSyncService(db)
+            await knowledge_sync._cleanup_knowledge_set_for_project(id)
+        except Exception as exc:
+            logger.warning(f"Knowledge set cleanup failed: project_id={id} error={exc}")
+
         logger.info(f"Successfully deleted mapping: {id}")
 
         return StandardResponse(
             code=200,
             message="success",
-            data=DeleteDatasetResponse(
-                id=id,
-                status="success"
-            )
+            data=DeleteDatasetResponse(id=id, status="success"),
         )
 
     except HTTPException:
@@ -409,7 +427,7 @@ async def delete_mapping(
 async def update_mapping(
     project_id: str = Path(..., description="Mapping UUID (path param)"),
     request: DatasetMappingUpdateRequest = None,
-    db: AsyncSession = Depends(get_db)
+    db: AsyncSession = Depends(get_db),
 ):
     """
     Update annotation project info
@@ -429,16 +447,14 @@ async def update_mapping(
         # Query the ORM model directly for the raw data
         result = await db.execute(
             select(LabelingProject).where(
-                LabelingProject.id == project_id,
-                LabelingProject.deleted_at.is_(None)
+                LabelingProject.id == project_id, LabelingProject.deleted_at.is_(None)
             )
         )
         mapping_orm = result.scalar_one_or_none()
 
         if not mapping_orm:
             raise HTTPException(
-                status_code=404,
-                detail=f"Mapping not found: {project_id}"
+                status_code=404, detail=f"Mapping not found: {project_id}"
             )
 
         # Build the update payload
@@ -449,7 +465,11 @@ async def update_mapping(
         # Read and update description and label_config from the configuration field
         configuration = {}
         if mapping_orm.configuration:
-            configuration = mapping_orm.configuration.copy() if isinstance(mapping_orm.configuration, dict) else {}
+            configuration = (
+                mapping_orm.configuration.copy()
+                if isinstance(mapping_orm.configuration, dict)
+                else {}
+            )
 
         if request.description is not None:
             configuration["description"] = request.description
@@ -468,7 +488,7 @@ async def update_mapping(
                 if not template:
                     raise HTTPException(
                         status_code=404,
-                        detail=f"Template not found: {request.template_id}"
+                        detail=f"Template not found: {request.template_id}",
                     )
                 template_labeling_type = getattr(template, "labeling_type", None)
                 if template_labeling_type:
@@ -477,14 +497,11 @@ async def update_mapping(
         if not update_values:
             # Nothing to update; return the current data as-is
             response_data = await service.get_mapping_by_uuid(project_id)
-            return StandardResponse(
-                code=200,
-                message="success",
-                data=response_data
-            )
+            return StandardResponse(code=200, message="success", data=response_data)
 
         # Perform the update
         from datetime import datetime
+
         update_values["updated_at"] = datetime.now()
 
         result = await db.execute(
@@ -495,21 +512,14 @@ async def update_mapping(
         await db.commit()
 
         if result.rowcount == 0:
-            raise HTTPException(
-                status_code=500,
-                detail="Failed to update mapping"
-            )
+            raise HTTPException(status_code=500, detail="Failed to update mapping")
 
         # Re-fetch the updated data
         updated_mapping = await service.get_mapping_by_uuid(project_id)
 
         logger.info(f"Successfully updated mapping: {project_id}")
 
-        return StandardResponse(
-            code=200,
-            message="success",
-            data=updated_mapping
-        )
+        return StandardResponse(code=200, message="success", data=updated_mapping)
 
     except HTTPException:
         raise
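The delete flow soft-deletes the mapping first and only then attempts the knowledge-set cleanup, logging rather than raising on failure. A condensed sketch of that best-effort ordering (the service objects and method names are stand-ins):

```python
import logging

logger = logging.getLogger(__name__)


async def delete_with_cleanup(service, knowledge_sync, project_id: str) -> None:
    """Soft-delete first; treat knowledge-set cleanup as best-effort."""
    if not await service.soft_delete_mapping(project_id):
        raise RuntimeError("Failed to delete mapping record")
    try:
        await knowledge_sync.cleanup_knowledge_set_for_project(project_id)
    except Exception as exc:
        # the mapping is already gone, so a cleanup failure only warrants a warning
        logger.warning("knowledge set cleanup failed: project_id=%s error=%s", project_id, exc)
```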
diff --git a/runtime/datamate-python/app/module/annotation/schema/export.py b/runtime/datamate-python/app/module/annotation/schema/export.py
index 5887d4a..dd5661a 100644
--- a/runtime/datamate-python/app/module/annotation/schema/export.py
+++ b/runtime/datamate-python/app/module/annotation/schema/export.py
@@ -13,17 +13,21 @@ from pydantic import BaseModel, Field
 
 class ExportFormat(str, Enum):
     """Export format enum"""
-    JSON = "json"    # Label Studio native JSON format
-    JSONL = "jsonl"  # JSON Lines format (one record per line)
-    CSV = "csv"      # CSV tabular format
-    COCO = "coco"    # COCO object detection format
-    YOLO = "yolo"    # YOLO format
+
+    JSON = "json"    # Label Studio native JSON format
+    JSONL = "jsonl"  # JSON Lines format (one record per line)
+    CSV = "csv"      # CSV tabular format
+    COCO = "coco"    # COCO object detection format
+    YOLO = "yolo"    # YOLO format
 
 
 class ExportAnnotationsRequest(BaseModel):
     """Export annotations request"""
+
     format: ExportFormat = Field(default=ExportFormat.JSON, description="Export format")
-    include_data: bool = Field(default=True, description="Whether to include raw data (e.g. text content)")
+    include_data: bool = Field(
+        default=True, description="Whether to include raw data (e.g. text content)"
+    )
     only_annotated: bool = Field(default=True, description="Whether to export only annotated data")
 
     model_config = {"use_enum_values": True}
@@ -31,6 +35,7 @@ class ExportAnnotationsRequest(BaseModel):
 
 class ExportAnnotationsResponse(BaseModel):
     """Export annotations response (for preview/statistics)"""
+
     project_id: str = Field(..., description="Project ID")
     project_name: str = Field(..., description="Project name")
     total_files: int = Field(..., description="Total file count")
@@ -42,16 +47,21 @@ class ExportAnnotationsResponse(BaseModel):
 
 class AnnotationExportItem(BaseModel):
     """A single export record"""
+
     file_id: str = Field(..., description="File ID")
     file_name: str = Field(..., description="File name")
+    file_path: Optional[str] = Field(default=None, description="File path")
     data: Optional[Dict[str, Any]] = Field(default=None, description="Raw data")
-    annotations: List[Dict[str, Any]] = Field(default_factory=list, description="Annotation results")
+    annotations: List[Dict[str, Any]] = Field(
+        default_factory=list, description="Annotation results"
+    )
     created_at: Optional[datetime] = Field(default=None, description="Created at")
     updated_at: Optional[datetime] = Field(default=None, description="Updated at")
 
 
 class COCOExportFormat(BaseModel):
     """COCO export structure"""
+
     info: Dict[str, Any] = Field(default_factory=dict)
     licenses: List[Dict[str, Any]] = Field(default_factory=list)
     images: List[Dict[str, Any]] = Field(default_factory=list)
diff --git a/runtime/datamate-python/app/module/annotation/service/export.py b/runtime/datamate-python/app/module/annotation/service/export.py
index dde4553..0b502b0 100644
--- a/runtime/datamate-python/app/module/annotation/service/export.py
+++ b/runtime/datamate-python/app/module/annotation/service/export.py
@@ -21,20 +21,29 @@ from datetime import datetime
 from typing import Any, Dict, List, Optional, Tuple
 
 from fastapi import HTTPException
+from PIL import Image
 from sqlalchemy import func, select
 from sqlalchemy.ext.asyncio import AsyncSession
 
 from app.core.logging import get_logger
-from app.db.models import AnnotationResult, Dataset, DatasetFiles, LabelingProject, LabelingProjectFile
+from app.db.models import (
+    AnnotationResult,
+    Dataset,
+    DatasetFiles,
+    LabelingProject,
+    LabelingProjectFile,
+)
 
 
-async def _read_file_content(file_path: str, max_size: int = 10 * 1024 * 1024) -> Optional[str]:
+async def _read_file_content(
+    file_path: str, max_size: int = 10 * 1024 * 1024
+) -> Optional[str]:
     """Read file content; intended for text files only
-    
+
     Args:
         file_path: file path
         max_size: maximum bytes to read (default 10MB)
-    
+
     Returns:
         The file content as a string, or None if reading fails
     """
@@ -42,17 +51,18 @@ async def _read_file_content(file_path: str, max_size: int = 10 * 1024 * 1024) -
         # Check that the file exists and is within the size limit
         if not os.path.exists(file_path):
             return None
-        
+
         file_size = os.path.getsize(file_path)
         if file_size > max_size:
             return f"[File too large: {file_size} bytes]"
-        
+
         # Try to read as text
-        with open(file_path, 'r', encoding='utf-8', errors='replace') as f:
+        with open(file_path, "r", encoding="utf-8", errors="replace") as f:
             return f.read()
     except Exception:
         return None
 
+
 from ..schema.export import (
     AnnotationExportItem,
     COCOExportFormat,
@@ -79,7 +89,9 @@ class AnnotationExportService:
     async def get_export_stats(self, project_id: str) -> ExportAnnotationsResponse:
         """Get export statistics"""
         project = await self._get_project_or_404(project_id)
-        logger.info(f"Export stats for project: id={project_id}, dataset_id={project.dataset_id}, name={project.name}")
+        logger.info(
+            f"Export stats for project: id={project_id}, dataset_id={project.dataset_id}, name={project.name}"
+        )
 
         # Total file count (files in the labeling project snapshot)
         total_result = await self.db.execute(
@@ -92,7 +104,9 @@ class AnnotationExportService:
             )
         )
         total_files = int(total_result.scalar() or 0)
-        logger.info(f"Total files (snapshot): {total_files} for project_id={project_id}")
+        logger.info(
+            f"Total files (snapshot): {total_files} for project_id={project_id}"
+        )
 
         # Annotated file count (distinct file_id values)
         annotated_result = await self.db.execute(
@@ -132,7 +146,11 @@ class AnnotationExportService:
         )
 
         # Export according to the requested format
-        format_type = ExportFormat(request.format) if isinstance(request.format, str) else request.format
+        format_type = (
+            ExportFormat(request.format)
+            if isinstance(request.format, str)
+            else request.format
+        )
 
         if format_type == ExportFormat.JSON:
             return self._export_json(items, project.name)
@@ -145,7 +163,9 @@ class AnnotationExportService:
         elif format_type == ExportFormat.YOLO:
             return self._export_yolo(items, project.name)
         else:
-            raise HTTPException(status_code=400, detail=f"Unsupported export format: {request.format}")
+            raise HTTPException(
+                status_code=400, detail=f"Unsupported export format: {request.format}"
+            )
 
     async def _get_project_or_404(self, project_id: str) -> LabelingProject:
         """Get the labeling project, or raise 404 if it does not exist"""
@@ -174,7 +194,10 @@ class AnnotationExportService:
             # Fetch only annotated data
             result = await self.db.execute(
                 select(AnnotationResult, DatasetFiles)
-                .join(LabelingProjectFile, LabelingProjectFile.file_id == AnnotationResult.file_id)
+                .join(
+                    LabelingProjectFile,
+                    LabelingProjectFile.file_id == AnnotationResult.file_id,
+                )
                 .join(DatasetFiles, AnnotationResult.file_id == DatasetFiles.id)
                 .where(
                     AnnotationResult.project_id == project_id,
@@ -192,11 +215,12 @@ class AnnotationExportService:
                 if include_data:
                     file_path = getattr(file, "file_path", "")
                     file_content = await _read_file_content(file_path)
-                
+
                 items.append(
                     AnnotationExportItem(
                         file_id=str(file.id),
                         file_name=str(getattr(file, "file_name", "")),
+                        file_path=str(getattr(file, "file_path", "")),
                         data={"text": file_content} if include_data else None,
                         annotations=[annotation_data] if annotation_data else [],
                         created_at=ann.created_at,
@@ -207,7 +231,9 @@ class AnnotationExportService:
             # Fetch all files (based on the labeling project snapshot)
             files_result = await self.db.execute(
                 select(DatasetFiles)
-                .join(LabelingProjectFile, LabelingProjectFile.file_id == DatasetFiles.id)
+                .join(
+                    LabelingProjectFile, LabelingProjectFile.file_id == DatasetFiles.id
+                )
                 .where(
                     LabelingProjectFile.project_id == project_id,
                     DatasetFiles.dataset_id == dataset_id,
@@ -217,7 +243,9 @@ class AnnotationExportService:
 
             # Fetch existing annotations
             ann_result = await self.db.execute(
-                select(AnnotationResult).where(AnnotationResult.project_id == project_id)
+                select(AnnotationResult).where(
+                    AnnotationResult.project_id == project_id
+                )
             )
             annotations = {str(a.file_id): a for a in ann_result.scalars().all()}
 
@@ -225,7 +253,7 @@ class AnnotationExportService:
                 file_id = str(file.id)
                 ann = annotations.get(file_id)
                 annotation_data = ann.annotation if ann else {}
-                
+
                 # Fetch file content (for text files, when the caller asked to include data)
                 file_content = None
                 if include_data:
@@ -236,6 +264,7 @@ class AnnotationExportService:
                     AnnotationExportItem(
                         file_id=file_id,
                         file_name=str(getattr(file, "file_name", "")),
+                        file_path=str(getattr(file, "file_path", "")),
                         data={"text": file_content} if include_data else None,
                         annotations=[annotation_data] if annotation_data else [],
                         created_at=ann.created_at if ann else None,
@@ -262,8 +291,13 @@ class AnnotationExportService:
                 for item in segment_results:
                     if isinstance(item, dict):
                         normalized = dict(item)
-                        if SEGMENT_INDEX_KEY not in normalized and SEGMENT_INDEX_FALLBACK_KEY not in normalized:
-                            normalized[SEGMENT_INDEX_KEY] = int(key) if str(key).isdigit() else key
+                        if (
+                            SEGMENT_INDEX_KEY not in normalized
+                            and SEGMENT_INDEX_FALLBACK_KEY not in normalized
+                        ):
+                            normalized[SEGMENT_INDEX_KEY] = (
+                                int(key) if str(key).isdigit() else key
+                            )
                         results.append(normalized)
         elif isinstance(segments, list):
             for idx, segment in enumerate(segments):
@@ -272,11 +306,16 @@ class AnnotationExportService:
                 segment_results = segment.get(SEGMENT_RESULT_KEY)
                 if not isinstance(segment_results, list):
                     continue
-                segment_index = segment.get(SEGMENT_INDEX_KEY, segment.get(SEGMENT_INDEX_FALLBACK_KEY, idx))
+                segment_index = segment.get(
+                    SEGMENT_INDEX_KEY, segment.get(SEGMENT_INDEX_FALLBACK_KEY, idx)
+                )
                 for item in segment_results:
                     if isinstance(item, dict):
                         normalized = dict(item)
-                        if SEGMENT_INDEX_KEY not in normalized and SEGMENT_INDEX_FALLBACK_KEY not in normalized:
+                        if (
+                            SEGMENT_INDEX_KEY not in normalized
+                            and SEGMENT_INDEX_FALLBACK_KEY not in normalized
+                        ):
                             normalized[SEGMENT_INDEX_KEY] = segment_index
                         results.append(normalized)
         return results
@@ -284,18 +323,43 @@ class AnnotationExportService:
         return result if isinstance(result, list) else []
 
     @classmethod
-    def _normalize_annotation_for_export(cls, annotation: Dict[str, Any]) -> Dict[str, Any]:
+    def _normalize_annotation_for_export(
+        cls, annotation: Dict[str, Any]
+    ) -> Dict[str, Any]:
         if not annotation or not isinstance(annotation, dict):
             return {}
         segments = annotation.get(SEGMENTS_KEY)
         if annotation.get(SEGMENTED_KEY) or isinstance(segments, (dict, list)):
             normalized = dict(annotation)
             normalized_result = cls._flatten_annotation_results(annotation)
-            if SEGMENT_RESULT_KEY not in normalized or not isinstance(normalized.get(SEGMENT_RESULT_KEY), list):
+            if SEGMENT_RESULT_KEY not in normalized or not isinstance(
+                normalized.get(SEGMENT_RESULT_KEY), list
+            ):
                 normalized[SEGMENT_RESULT_KEY] = normalized_result
             return normalized
         return annotation
 
+    @staticmethod
+    def _get_image_dimensions(file_path: str) -> Tuple[int, int]:
+        """Get the width and height of an image file
+
+        Args:
+            file_path: image file path
+
+        Returns:
+            A (width, height) tuple, or (1920, 1080) as a default if reading fails
+        """
+        try:
+            if os.path.exists(file_path):
+                with Image.open(file_path) as img:
+                    width, height = img.size
+                    return width, height
+        except Exception as e:
+            logger.warning(f"Failed to read image dimensions from {file_path}: {e}")
+
+        # Fall back to reasonable defaults
+        return 1920, 1080
+
     def _export_json(
         self, items: List[AnnotationExportItem], project_name: str
     ) -> Tuple[bytes, str, str]:
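The new `_get_image_dimensions` helper reads sizes with Pillow and falls back to 1920x1080 when the file is missing or unreadable. A standalone sketch of the same call, runnable outside the service:

```python
import os
from typing import Tuple

from PIL import Image  # Pillow


def image_size(file_path: str, default: Tuple[int, int] = (1920, 1080)) -> Tuple[int, int]:
    """Return (width, height), falling back to a default on any error."""
    try:
        if os.path.exists(file_path):
            with Image.open(file_path) as img:  # opening is lazy; only the header is read
                return img.size                 # Pillow reports (width, height)
    except Exception:
        pass
    return default
```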
@@ -309,9 +373,16 @@ class AnnotationExportService:
                     "file_id": item.file_id,
                     "file_name": item.file_name,
                     "data": item.data,
-                    "annotations": [self._normalize_annotation_for_export(ann) for ann in item.annotations],
-                    "created_at": item.created_at.isoformat() if item.created_at else None,
-                    "updated_at": item.updated_at.isoformat() if item.updated_at else None,
+                    "annotations": [
+                        self._normalize_annotation_for_export(ann)
+                        for ann in item.annotations
+                    ],
+                    "created_at": item.created_at.isoformat()
+                    if item.created_at
+                    else None,
+                    "updated_at": item.updated_at.isoformat()
+                    if item.updated_at
+                    else None,
                 }
                 for item in items
             ],
@@ -331,7 +402,10 @@ class AnnotationExportService:
                 "file_id": item.file_id,
                 "file_name": item.file_name,
                 "data": item.data,
-                "annotations": [self._normalize_annotation_for_export(ann) for ann in item.annotations],
+                "annotations": [
+                    self._normalize_annotation_for_export(ann)
+                    for ann in item.annotations
+                ],
                 "created_at": item.created_at.isoformat() if item.created_at else None,
                 "updated_at": item.updated_at.isoformat() if item.updated_at else None,
             }
@@ -368,7 +442,7 @@ class AnnotationExportService:
             for r in results:
                 value = r.get("value", {})
                 label_type = r.get("type", "")
-                
+
                 # Extract label values for the different result types
                 if "choices" in value:
                     labels.extend(value["choices"])
@@ -389,14 +463,22 @@ class AnnotationExportService:
                 elif "transcription" in value:
                     labels.append(value["transcription"])
 
-            writer.writerow({
-                "file_id": item.file_id,
-                "file_name": item.file_name,
-                "annotation_result": json.dumps(item.annotations, ensure_ascii=False),
-                "labels": "|".join(labels),
-                "created_at": item.created_at.isoformat() if item.created_at else "",
-                "updated_at": item.updated_at.isoformat() if item.updated_at else "",
-            })
+            writer.writerow(
+                {
+                    "file_id": item.file_id,
+                    "file_name": item.file_name,
+                    "annotation_result": json.dumps(
+                        item.annotations, ensure_ascii=False
+                    ),
+                    "labels": "|".join(labels),
+                    "created_at": item.created_at.isoformat()
+                    if item.created_at
+                    else "",
+                    "updated_at": item.updated_at.isoformat()
+                    if item.updated_at
+                    else "",
+                }
+            )
 
         content = output.getvalue().encode("utf-8-sig")  # BOM for Excel compatibility
         filename = f"{project_name}_annotations.csv"
@@ -405,11 +487,7 @@ class AnnotationExportService:
     def _export_coco(
         self, items: List[AnnotationExportItem], project_name: str
     ) -> Tuple[bytes, str, str]:
-        """Export to COCO format (for object detection annotations)
-
-        Note: the current implementation sets image width/height to 0, since the real
-        size requires reading the actual image file. bbox coordinates use Label Studio
-        percentage values (0-100) and must be converted to pixel coordinates when used.
-        """
+        """Export to COCO format (for object detection annotations)"""
         coco_format = COCOExportFormat(
             info={
                 "description": f"Exported from DataMate project: {project_name}",
@@ -429,13 +507,18 @@ class AnnotationExportService:
         for idx, item in enumerate(items):
             image_id = idx + 1
 
+            # Get the actual image dimensions
+            img_width, img_height = self._get_image_dimensions(item.file_path or "")
+
             # Add image info
-            coco_format.images.append({
-                "id": image_id,
-                "file_name": item.file_name,
-                "width": 0,  # needs the actual image size
-                "height": 0,
-            })
+            coco_format.images.append(
+                {
+                    "id": image_id,
+                    "file_name": item.file_name,
+                    "width": img_width,
+                    "height": img_height,
+                }
+            )
 
             # Process annotations
             for ann in item.annotations:
@@ -449,29 +532,41 @@ class AnnotationExportService:
                     for label in labels:
                         if label not in category_map:
                             category_map[label] = len(category_map) + 1
-                            coco_format.categories.append({
-                                "id": category_map[label],
-                                "name": label,
-                                "supercategory": "",
-                            })
+                            coco_format.categories.append(
+                                {
+                                    "id": category_map[label],
+                                    "name": label,
+                                    "supercategory": "",
+                                }
+                            )
 
-                        # Convert coordinates (Label Studio uses percentages)
-                        x = value.get("x", 0)
value.get("y", 0) - width = value.get("width", 0) - height = value.get("height", 0) + # 转换坐标:Label Studio 使用百分比(0-100)转换为像素坐标 + x_percent = value.get("x", 0) + y_percent = value.get("y", 0) + width_percent = value.get("width", 0) + height_percent = value.get("height", 0) - coco_format.annotations.append({ - "id": annotation_id, - "image_id": image_id, - "category_id": category_map[label], - "bbox": [x, y, width, height], - "area": width * height, - "iscrowd": 0, - }) + # 转换为像素坐标 + x = x_percent * img_width / 100.0 + y = y_percent * img_height / 100.0 + width = width_percent * img_width / 100.0 + height = height_percent * img_height / 100.0 + + coco_format.annotations.append( + { + "id": annotation_id, + "image_id": image_id, + "category_id": category_map[label], + "bbox": [x, y, width, height], + "area": width * height, + "iscrowd": 0, + } + ) annotation_id += 1 - content = json.dumps(coco_format.model_dump(), ensure_ascii=False, indent=2).encode("utf-8") + content = json.dumps( + coco_format.model_dump(), ensure_ascii=False, indent=2 + ).encode("utf-8") filename = f"{project_name}_coco.json" return content, filename, "application/json" @@ -510,7 +605,9 @@ class AnnotationExportService: x_center = x + w / 2 y_center = y + h / 2 - lines.append(f"{label} {x_center:.6f} {y_center:.6f} {w:.6f} {h:.6f}") + lines.append( + f"{label} {x_center:.6f} {y_center:.6f} {w:.6f} {h:.6f}" + ) if lines: # 生成对应的 txt 文件名 diff --git a/runtime/datamate-python/app/module/annotation/service/knowledge_sync.py b/runtime/datamate-python/app/module/annotation/service/knowledge_sync.py index ae345ea..0883f6e 100644 --- a/runtime/datamate-python/app/module/annotation/service/knowledge_sync.py +++ b/runtime/datamate-python/app/module/annotation/service/knowledge_sync.py @@ -43,7 +43,9 @@ class KnowledgeSyncService: logger.warning("标注同步失败:无法获取知识集") return - item = await self._get_item_by_source(set_id, project.dataset_id, str(file_record.id)) + item = await self._get_item_by_source( + set_id, project.dataset_id, str(file_record.id) + ) if item and item.get("status") in {"PUBLISHED", "ARCHIVED", "DEPRECATED"}: logger.info( "知识条目为只读状态,跳过同步:item_id=%s status=%s", @@ -71,26 +73,46 @@ class KnowledgeSyncService: logger.warning("标注同步到知识管理失败:%s", exc) async def _ensure_knowledge_set(self, project: LabelingProject) -> Optional[str]: - config = project.configuration if isinstance(project.configuration, dict) else {} + result = await self.db.execute( + select(LabelingProject) + .where(LabelingProject.id == project.id) + .with_for_update() + ) + locked_project = result.scalar_one_or_none() + if not locked_project: + logger.warning("标注同步失败:无法锁定项目:project_id=%s", project.id) + return None + + config = ( + locked_project.configuration + if isinstance(locked_project.configuration, dict) + else {} + ) set_id = config.get(self.CONFIG_KEY_SET_ID) if set_id: exists = await self._get_knowledge_set(set_id) - if exists and self._metadata_matches_project(exists.get("metadata"), project.id): + if exists and self._metadata_matches_project( + exists.get("metadata"), locked_project.id + ): return set_id logger.warning( "知识集不存在或归属不匹配,准备重建:set_id=%s project_id=%s", set_id, - project.id, + locked_project.id, ) - project_name = (project.name or "annotation-project").strip() or "annotation-project" - metadata = self._build_set_metadata(project) + project_name = ( + locked_project.name or "annotation-project" + ).strip() or "annotation-project" + metadata = self._build_set_metadata(locked_project) - existing = await 
diff --git a/runtime/datamate-python/app/module/annotation/service/knowledge_sync.py b/runtime/datamate-python/app/module/annotation/service/knowledge_sync.py
index ae345ea..0883f6e 100644
--- a/runtime/datamate-python/app/module/annotation/service/knowledge_sync.py
+++ b/runtime/datamate-python/app/module/annotation/service/knowledge_sync.py
@@ -43,7 +43,9 @@ class KnowledgeSyncService:
             logger.warning("Annotation sync failed: could not obtain a knowledge set")
             return
 
-        item = await self._get_item_by_source(set_id, project.dataset_id, str(file_record.id))
+        item = await self._get_item_by_source(
+            set_id, project.dataset_id, str(file_record.id)
+        )
         if item and item.get("status") in {"PUBLISHED", "ARCHIVED", "DEPRECATED"}:
             logger.info(
                 "Knowledge item is read-only, skipping sync: item_id=%s status=%s",
@@ -71,26 +73,46 @@ class KnowledgeSyncService:
             logger.warning("Failed to sync annotation to knowledge management: %s", exc)
 
     async def _ensure_knowledge_set(self, project: LabelingProject) -> Optional[str]:
-        config = project.configuration if isinstance(project.configuration, dict) else {}
+        result = await self.db.execute(
+            select(LabelingProject)
+            .where(LabelingProject.id == project.id)
+            .with_for_update()
+        )
+        locked_project = result.scalar_one_or_none()
+        if not locked_project:
+            logger.warning("Annotation sync failed: could not lock project: project_id=%s", project.id)
+            return None
+
+        config = (
+            locked_project.configuration
+            if isinstance(locked_project.configuration, dict)
+            else {}
+        )
         set_id = config.get(self.CONFIG_KEY_SET_ID)
         if set_id:
             exists = await self._get_knowledge_set(set_id)
-            if exists and self._metadata_matches_project(exists.get("metadata"), project.id):
+            if exists and self._metadata_matches_project(
+                exists.get("metadata"), locked_project.id
+            ):
                 return set_id
             logger.warning(
                 "Knowledge set missing or ownership mismatch, rebuilding: set_id=%s project_id=%s",
                 set_id,
-                project.id,
+                locked_project.id,
             )
 
-        project_name = (project.name or "annotation-project").strip() or "annotation-project"
-        metadata = self._build_set_metadata(project)
+        project_name = (
+            locked_project.name or "annotation-project"
+        ).strip() or "annotation-project"
+        metadata = self._build_set_metadata(locked_project)
 
-        existing = await self._find_knowledge_set_by_name_and_project(project_name, project.id)
+        existing = await self._find_knowledge_set_by_name_and_project(
+            project_name, locked_project.id
+        )
         if existing:
             await self._update_project_config(
-                project,
+                locked_project,
                 {
                     self.CONFIG_KEY_SET_ID: existing.get("id"),
                     self.CONFIG_KEY_SET_NAME: existing.get("name"),
@@ -100,23 +122,31 @@ class KnowledgeSyncService:
 
         created = await self._create_knowledge_set(project_name, metadata)
         if not created:
-            created = await self._find_knowledge_set_by_name_and_project(project_name, project.id)
+            created = await self._find_knowledge_set_by_name_and_project(
+                project_name, locked_project.id
+            )
         if not created:
-            fallback_name = self._build_fallback_set_name(project_name, project.id)
-            existing = await self._find_knowledge_set_by_name_and_project(fallback_name, project.id)
+            fallback_name = self._build_fallback_set_name(
+                project_name, locked_project.id
+            )
+            existing = await self._find_knowledge_set_by_name_and_project(
+                fallback_name, locked_project.id
+            )
             if existing:
                 created = existing
             else:
                 created = await self._create_knowledge_set(fallback_name, metadata)
                 if not created:
-                    created = await self._find_knowledge_set_by_name_and_project(fallback_name, project.id)
+                    created = await self._find_knowledge_set_by_name_and_project(
+                        fallback_name, locked_project.id
+                    )
         if not created:
            return None
 
         await self._update_project_config(
-            project,
+            locked_project,
             {
                 self.CONFIG_KEY_SET_ID: created.get("id"),
                 self.CONFIG_KEY_SET_NAME: created.get("name"),
@@ -126,13 +156,17 @@ class KnowledgeSyncService:
 
     async def _get_knowledge_set(self, set_id: str) -> Optional[Dict[str, Any]]:
         try:
-            return await self._request("GET", f"/data-management/knowledge-sets/{set_id}")
+            return await self._request(
+                "GET", f"/data-management/knowledge-sets/{set_id}"
+            )
         except httpx.HTTPStatusError as exc:
             if exc.response.status_code == 404:
                 return None
             raise
 
-    async def _list_knowledge_sets(self, keyword: Optional[str]) -> list[Dict[str, Any]]:
+    async def _list_knowledge_sets(
+        self, keyword: Optional[str]
+    ) -> list[Dict[str, Any]]:
         params: Dict[str, Any] = {
             "page": 1,
             "size": self.KNOWLEDGE_SET_LIST_SIZE,
@@ -140,7 +174,9 @@ class KnowledgeSyncService:
         if keyword:
             params["keyword"] = keyword
         try:
-            data = await self._request("GET", "/data-management/knowledge-sets", params=params)
+            data = await self._request(
+                "GET", "/data-management/knowledge-sets", params=params
+            )
         except httpx.HTTPStatusError as exc:
             logger.warning(
                 "Failed to query knowledge sets: keyword=%s status=%s",
@@ -155,7 +191,9 @@ class KnowledgeSyncService:
             return []
         return [item for item in content if isinstance(item, dict)]
 
-    async def _find_knowledge_set_by_name_and_project(self, name: str, project_id: str) -> Optional[Dict[str, Any]]:
+    async def _find_knowledge_set_by_name_and_project(
+        self, name: str, project_id: str
+    ) -> Optional[Dict[str, Any]]:
         if not name:
             return None
         items = await self._list_knowledge_sets(name)
@@ -168,7 +206,9 @@ class KnowledgeSyncService:
                 return item
         return None
 
-    async def _create_knowledge_set(self, name: str, metadata: str) -> Optional[Dict[str, Any]]:
+    async def _create_knowledge_set(
+        self, name: str, metadata: str
+    ) -> Optional[Dict[str, Any]]:
         payload = {
             "name": name,
             "description": "Knowledge set auto-created by an annotation project",
@@ -176,7 +216,9 @@ class KnowledgeSyncService:
             "metadata": metadata,
         }
         try:
-            return await self._request("POST", "/data-management/knowledge-sets", json=payload)
+            return await self._request(
+                "POST", "/data-management/knowledge-sets", json=payload
+            )
         except httpx.HTTPStatusError as exc:
             logger.warning(
                 "Failed to create knowledge set: name=%s status=%s detail=%s",
@@ -199,7 +241,9 @@ class KnowledgeSyncService:
             "sourceFileId": file_id,
         }
         try:
-            data = await self._request("GET", f"/data-management/knowledge-sets/{set_id}/items", params=params)
+            data = await self._request(
+                "GET", f"/data-management/knowledge-sets/{set_id}/items", params=params
+            )
         except httpx.HTTPStatusError as exc:
             logger.warning(
                 "Failed to query knowledge items: set_id=%s status=%s",
@@ -216,9 +260,13 @@ class KnowledgeSyncService:
         return content[0]
 
     async def _create_item(self, set_id: str, payload: Dict[str, Any]) -> None:
-        await self._request("POST", f"/data-management/knowledge-sets/{set_id}/items", json=payload)
+        await self._request(
+            "POST", f"/data-management/knowledge-sets/{set_id}/items", json=payload
+        )
 
-    async def _update_item(self, set_id: str, item_id: str, payload: Dict[str, Any]) -> None:
+    async def _update_item(
+        self, set_id: str, item_id: str, payload: Dict[str, Any]
+    ) -> None:
         update_payload = dict(payload)
         update_payload.pop("sourceDatasetId", None)
         update_payload.pop("sourceFileId", None)
@@ -228,6 +276,62 @@ class KnowledgeSyncService:
             json=update_payload,
         )
 
+    async def _cleanup_knowledge_set_for_project(self, project_id: str) -> None:
+        """Clean up the knowledge set associated with a project, including all of its items"""
+        items = await self._list_knowledge_sets(None)
+        for item in items:
+            if self._metadata_matches_project(item.get("metadata"), project_id):
+                set_id = item.get("id")
+                if not set_id:
+                    continue
+                try:
+                    await self._request(
+                        "DELETE", f"/data-management/knowledge-sets/{set_id}"
+                    )
+                    logger.info(
+                        "Deleted knowledge set: set_id=%s project_id=%s", set_id, project_id
+                    )
+                except Exception as exc:
+                    logger.warning(
+                        "Failed to delete knowledge set: set_id=%s project_id=%s error=%s",
+                        set_id,
+                        project_id,
+                        exc,
+                    )
+
+    async def _cleanup_knowledge_item_for_file(
+        self, dataset_id: str, file_id: str
+    ) -> None:
+        """Clean up the knowledge items for a file"""
+        items = await self._list_knowledge_sets(None)
+        for set_item in items:
+            set_id = set_item.get("id")
+            if not set_id:
+                continue
+            item = await self._get_item_by_source(set_id, dataset_id, file_id)
+            if item and item.get("id"):
+                try:
+                    await self._request(
+                        "DELETE",
+                        f"/data-management/knowledge-sets/{set_id}/items/{item['id']}",
+                    )
+                    logger.info(
+                        "Deleted knowledge item: item_id=%s set_id=%s dataset_id=%s file_id=%s",
+                        item.get("id"),
+                        set_id,
+                        dataset_id,
+                        file_id,
+                    )
+                except Exception as exc:
+                    logger.warning(
+                        "Failed to delete knowledge item: item_id=%s set_id=%s dataset_id=%s file_id=%s error=%s",
+                        item.get("id"),
+                        set_id,
+                        dataset_id,
+                        file_id,
+                        exc,
+                    )
+
     async def _build_item_payload(
         self,
         project: LabelingProject,
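Both cleanup helpers scan every knowledge set returned by `_list_knowledge_sets(None)` (bounded by `KNOWLEDGE_SET_LIST_SIZE`) and treat each deletion as best-effort. A condensed sketch of the ownership-scan loop (the `match` and `delete` callables stand in for the service's private helpers):

```python
import logging
from typing import Any, Awaitable, Callable, Dict, Iterable

logger = logging.getLogger(__name__)


async def cleanup_sets_for_project(
    sets: Iterable[Dict[str, Any]],
    project_id: str,
    match: Callable[[Any, str], bool],
    delete: Callable[[str], Awaitable[None]],
) -> None:
    """Delete every knowledge set whose metadata marks it as owned by project_id."""
    for entry in sets:
        if not match(entry.get("metadata"), project_id):
            continue
        set_id = entry.get("id")
        if not set_id:
            continue
        try:
            await delete(set_id)  # DELETE /data-management/knowledge-sets/{set_id}
        except Exception as exc:
            # deletion is best-effort: log and keep scanning the remaining sets
            logger.warning("failed to delete knowledge set %s: %s", set_id, exc)
```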
@@ -323,12 +427,28 @@ class KnowledgeSyncService:
         short_id = project_id.replace("-", "")[:8]
         return f"{base_name}-annotation-{short_id}"
 
-    async def _update_project_config(self, project: LabelingProject, updates: Dict[str, Any]) -> None:
-        config = project.configuration if isinstance(project.configuration, dict) else {}
+    async def _update_project_config(
+        self, project: LabelingProject, updates: Dict[str, Any]
+    ) -> None:
+        result = await self.db.execute(
+            select(LabelingProject)
+            .where(LabelingProject.id == project.id)
+            .with_for_update()
+        )
+        locked_project = result.scalar_one_or_none()
+        if not locked_project:
+            logger.warning("Failed to update project config: could not lock project: project_id=%s", project.id)
+            return
+
+        config = (
+            locked_project.configuration
+            if isinstance(locked_project.configuration, dict)
+            else {}
+        )
         config.update(updates)
-        project.configuration = config
+        locked_project.configuration = config
         await self.db.commit()
-        await self.db.refresh(project)
+        await self.db.refresh(locked_project)
 
     async def _request(self, method: str, path: str, **kwargs) -> Any:
         url = f"{self.base_url}{path}"
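`_ensure_knowledge_set` and `_update_project_config` now re-read the project with `SELECT ... FOR UPDATE` before touching `configuration`, so concurrent syncs serialize on the project row. A minimal sketch of that SQLAlchemy pattern (the model class is passed in to keep the sketch self-contained):

```python
from typing import Any, Dict, Optional, Type

from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession


async def update_config_locked(
    db: AsyncSession, model: Type[Any], project_id: str, updates: Dict[str, Any]
) -> Optional[Dict[str, Any]]:
    """Re-read the row with SELECT ... FOR UPDATE, then mutate and commit.
    The row lock holds until commit, serializing concurrent config writers."""
    result = await db.execute(
        select(model).where(model.id == project_id).with_for_update()
    )
    locked = result.scalar_one_or_none()
    if locked is None:
        return None  # row vanished; caller logs and gives up
    config = dict(locked.configuration) if isinstance(locked.configuration, dict) else {}
    config.update(updates)
    locked.configuration = config  # reassign so the ORM registers the change
    await db.commit()
    return config
```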