fix: 修复知识库同步的并发控制、数据清理、文件事务和COCO导出问题

问题1 - 并发控制缺失:
- 在 _ensure_knowledge_set 方法中添加数据库行锁(with_for_update)
- 修改 _update_project_config 方法,使用行锁保护配置更新

问题3 - 数据清理机制缺失:
- 添加 _cleanup_knowledge_set_for_project 方法,项目删除时清理知识集
- 添加 _cleanup_knowledge_item_for_file 方法,文件删除时清理知识条目
- 在 delete_mapping 接口中调用清理方法

问题4 - 文件操作事务问题:
- 修改 uploadKnowledgeItems,添加事务失败后的文件清理逻辑
- 修改 deleteKnowledgeItem,删除记录前先删除关联文件
- 新增 deleteKnowledgeItemFile 辅助方法

问题5 - COCO导出格式问题:
- 添加 _get_image_dimensions 方法读取图片实际宽高
- 将百分比坐标转换为像素坐标
- 在 AnnotationExportItem 中添加 file_path 字段

涉及文件:
- knowledge_sync.py
- project.py
- KnowledgeItemApplicationService.java
- export.py
- schema/export.py
This commit is contained in:
2026-02-05 03:55:01 +08:00
parent c03bdf1a24
commit 99bd83d312
5 changed files with 513 additions and 238 deletions

View File

@@ -126,7 +126,9 @@ public class KnowledgeItemApplicationService {
createDirectories(setDir); createDirectories(setDir);
List<KnowledgeItem> items = new ArrayList<>(); List<KnowledgeItem> items = new ArrayList<>();
List<Path> savedFilePaths = new ArrayList<>();
try {
for (MultipartFile file : files) { for (MultipartFile file : files) {
BusinessAssert.notNull(file, CommonErrorCode.PARAM_ERROR); BusinessAssert.notNull(file, CommonErrorCode.PARAM_ERROR);
BusinessAssert.isTrue(!file.isEmpty(), CommonErrorCode.PARAM_ERROR); BusinessAssert.isTrue(!file.isEmpty(), CommonErrorCode.PARAM_ERROR);
@@ -144,6 +146,7 @@ public class KnowledgeItemApplicationService {
BusinessAssert.isTrue(targetPath.startsWith(setDir), CommonErrorCode.PARAM_ERROR); BusinessAssert.isTrue(targetPath.startsWith(setDir), CommonErrorCode.PARAM_ERROR);
saveMultipartFile(file, targetPath); saveMultipartFile(file, targetPath);
savedFilePaths.add(targetPath);
KnowledgeItem knowledgeItem = new KnowledgeItem(); KnowledgeItem knowledgeItem = new KnowledgeItem();
knowledgeItem.setId(UUID.randomUUID().toString()); knowledgeItem.setId(UUID.randomUUID().toString());
@@ -161,6 +164,15 @@ public class KnowledgeItemApplicationService {
knowledgeItemRepository.saveBatch(items, items.size()); knowledgeItemRepository.saveBatch(items, items.size());
} }
return items; return items;
} catch (Exception e) {
for (Path filePath : savedFilePaths) {
deleteFileQuietly(filePath);
}
if (e instanceof BusinessException) {
throw (BusinessException) e;
}
throw BusinessException.of(SystemErrorCode.FILE_SYSTEM_ERROR);
}
} }
public KnowledgeItem updateKnowledgeItem(String setId, String itemId, UpdateKnowledgeItemRequest request) { public KnowledgeItem updateKnowledgeItem(String setId, String itemId, UpdateKnowledgeItemRequest request) {
@@ -190,6 +202,9 @@ public class KnowledgeItemApplicationService {
KnowledgeItem knowledgeItem = knowledgeItemRepository.getById(itemId); KnowledgeItem knowledgeItem = knowledgeItemRepository.getById(itemId);
BusinessAssert.notNull(knowledgeItem, DataManagementErrorCode.KNOWLEDGE_ITEM_NOT_FOUND); BusinessAssert.notNull(knowledgeItem, DataManagementErrorCode.KNOWLEDGE_ITEM_NOT_FOUND);
BusinessAssert.isTrue(Objects.equals(knowledgeItem.getSetId(), setId), CommonErrorCode.PARAM_ERROR); BusinessAssert.isTrue(Objects.equals(knowledgeItem.getSetId(), setId), CommonErrorCode.PARAM_ERROR);
deleteKnowledgeItemFile(knowledgeItem);
knowledgeItemPreviewService.deletePreviewFileQuietly(setId, itemId);
knowledgeItemRepository.removeById(itemId); knowledgeItemRepository.removeById(itemId);
} }
@@ -205,6 +220,11 @@ public class KnowledgeItemApplicationService {
boolean allMatch = items.stream().allMatch(item -> Objects.equals(item.getSetId(), setId)); boolean allMatch = items.stream().allMatch(item -> Objects.equals(item.getSetId(), setId));
BusinessAssert.isTrue(allMatch, CommonErrorCode.PARAM_ERROR); BusinessAssert.isTrue(allMatch, CommonErrorCode.PARAM_ERROR);
for (KnowledgeItem item : items) {
deleteKnowledgeItemFile(item);
knowledgeItemPreviewService.deletePreviewFileQuietly(setId, item.getId());
}
List<String> deleteIds = items.stream().map(KnowledgeItem::getId).toList(); List<String> deleteIds = items.stream().map(KnowledgeItem::getId).toList();
knowledgeItemRepository.removeByIds(deleteIds); knowledgeItemRepository.removeByIds(deleteIds);
} }
@@ -785,6 +805,24 @@ public class KnowledgeItemApplicationService {
} }
} }
/**
 * Best-effort removal of the on-disk file backing a knowledge item.
 * Called before the item's database record is deleted so the stored file
 * does not become orphaned. Any failure is logged and swallowed — file
 * cleanup must never block deletion of the record itself.
 */
private void deleteKnowledgeItemFile(KnowledgeItem knowledgeItem) {
// Defensive guard: nothing to clean up for a null item.
if (knowledgeItem == null) {
return;
}
// Only items that own a stored file (uploaded via FILE_UPLOAD, or whose
// content type is FILE) have anything on disk; for those, the item's
// content field holds the storage-relative path.
if (knowledgeItem.getSourceType() == KnowledgeSourceType.FILE_UPLOAD
|| knowledgeItem.getContentType() == KnowledgeContentType.FILE) {
String relativePath = knowledgeItem.getContent();
if (StringUtils.isNotBlank(relativePath)) {
try {
// Resolve the relative path against the storage root, then delete
// quietly (no error if the file is already gone).
Path filePath = resolveKnowledgeItemStoragePath(relativePath);
deleteFileQuietly(filePath);
} catch (Exception e) {
// Deliberately swallowed: a failed file delete must not abort the
// surrounding record deletion; path is logged for manual cleanup.
log.warn("delete knowledge item file error, itemId: {}, path: {}", knowledgeItem.getId(), relativePath, e);
}
}
}
}
private String resolveOriginalFileName(MultipartFile file) { private String resolveOriginalFileName(MultipartFile file) {
String originalName = file.getOriginalFilename(); String originalName = file.getOriginalFilename();
if (StringUtils.isBlank(originalName)) { if (StringUtils.isBlank(originalName)) {

View File

@@ -14,6 +14,7 @@ from app.core.logging import get_logger
from ..service.mapping import DatasetMappingService from ..service.mapping import DatasetMappingService
from ..service.template import AnnotationTemplateService from ..service.template import AnnotationTemplateService
from ..service.knowledge_sync import KnowledgeSyncService
from ..schema import ( from ..schema import (
DatasetMappingCreateRequest, DatasetMappingCreateRequest,
DatasetMappingCreateResponse, DatasetMappingCreateResponse,
@@ -22,26 +23,26 @@ from ..schema import (
DatasetMappingResponse, DatasetMappingResponse,
) )
router = APIRouter( router = APIRouter(prefix="/project", tags=["annotation/project"])
prefix="/project",
tags=["annotation/project"]
)
logger = get_logger(__name__) logger = get_logger(__name__)
TEXT_DATASET_TYPE = "TEXT" TEXT_DATASET_TYPE = "TEXT"
SOURCE_DOCUMENT_FILE_TYPES = {"pdf", "doc", "docx", "xls", "xlsx"} SOURCE_DOCUMENT_FILE_TYPES = {"pdf", "doc", "docx", "xls", "xlsx"}
LABELING_TYPE_CONFIG_KEY = "labeling_type" LABELING_TYPE_CONFIG_KEY = "labeling_type"
@router.get("/{mapping_id}/login")
async def login_label_studio(
mapping_id: str,
db: AsyncSession = Depends(get_db)
):
raise HTTPException(status_code=410, detail="当前为内嵌编辑器模式,不再支持 Label Studio 登录代理接口")
@router.post("", response_model=StandardResponse[DatasetMappingCreateResponse], status_code=201) @router.get("/{mapping_id}/login")
async def login_label_studio(mapping_id: str, db: AsyncSession = Depends(get_db)):
raise HTTPException(
status_code=410,
detail="当前为内嵌编辑器模式,不再支持 Label Studio 登录代理接口",
)
@router.post(
"", response_model=StandardResponse[DatasetMappingCreateResponse], status_code=201
)
async def create_mapping( async def create_mapping(
request: DatasetMappingCreateRequest, request: DatasetMappingCreateRequest, db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db)
): ):
""" """
创建数据集映射 创建数据集映射
@@ -64,7 +65,7 @@ async def create_mapping(
if not dataset_info: if not dataset_info:
raise HTTPException( raise HTTPException(
status_code=404, status_code=404,
detail=f"Dataset not found in DM service: {request.dataset_id}" detail=f"Dataset not found in DM service: {request.dataset_id}",
) )
dataset_type = ( dataset_type = (
@@ -73,13 +74,15 @@ async def create_mapping(
or "" or ""
).upper() ).upper()
project_name = request.name or \ project_name = (
dataset_info.name or \ request.name or dataset_info.name or "A new project from DataMate"
"A new project from DataMate" )
project_description = request.description or \ project_description = (
dataset_info.description or \ request.description
f"Imported from DM dataset {dataset_info.name} ({dataset_info.id})" or dataset_info.description
or f"Imported from DM dataset {dataset_info.name} ({dataset_info.id})"
)
# 如果提供了模板ID,获取模板配置 # 如果提供了模板ID,获取模板配置
label_config = None label_config = None
@@ -89,8 +92,7 @@ async def create_mapping(
template = await template_service.get_template(db, request.template_id) template = await template_service.get_template(db, request.template_id)
if not template: if not template:
raise HTTPException( raise HTTPException(
status_code=404, status_code=404, detail=f"Template not found: {request.template_id}"
detail=f"Template not found: {request.template_id}"
) )
label_config = template.label_config label_config = template.label_config
template_labeling_type = getattr(template, "labeling_type", None) template_labeling_type = getattr(template, "labeling_type", None)
@@ -110,8 +112,13 @@ async def create_mapping(
project_configuration["label_config"] = label_config project_configuration["label_config"] = label_config
if project_description: if project_description:
project_configuration["description"] = project_description project_configuration["description"] = project_description
if dataset_type == TEXT_DATASET_TYPE and request.segmentation_enabled is not None: if (
project_configuration["segmentation_enabled"] = bool(request.segmentation_enabled) dataset_type == TEXT_DATASET_TYPE
and request.segmentation_enabled is not None
):
project_configuration["segmentation_enabled"] = bool(
request.segmentation_enabled
)
if template_labeling_type: if template_labeling_type:
project_configuration[LABELING_TYPE_CONFIG_KEY] = template_labeling_type project_configuration[LABELING_TYPE_CONFIG_KEY] = template_labeling_type
@@ -143,9 +150,7 @@ async def create_mapping(
snapshot_file_ids.append(str(file_record.id)) snapshot_file_ids.append(str(file_record.id))
else: else:
snapshot_file_ids = [ snapshot_file_ids = [
str(file_record.id) str(file_record.id) for file_record in file_records if file_record.id
for file_record in file_records
if file_record.id
] ]
# 创建映射关系并写入快照 # 创建映射关系并写入快照
@@ -157,25 +162,30 @@ async def create_mapping(
if dataset_type == TEXT_DATASET_TYPE and request.segmentation_enabled: if dataset_type == TEXT_DATASET_TYPE and request.segmentation_enabled:
try: try:
from ..service.editor import AnnotationEditorService from ..service.editor import AnnotationEditorService
editor_service = AnnotationEditorService(db) editor_service = AnnotationEditorService(db)
# 异步预计算切片(不阻塞创建响应) # 异步预计算切片(不阻塞创建响应)
segmentation_result = await editor_service.precompute_segmentation_for_project(labeling_project.id) segmentation_result = (
logger.info(f"Precomputed segmentation for project {labeling_project.id}: {segmentation_result}") await editor_service.precompute_segmentation_for_project(
labeling_project.id
)
)
logger.info(
f"Precomputed segmentation for project {labeling_project.id}: {segmentation_result}"
)
except Exception as e: except Exception as e:
logger.warning(f"Failed to precompute segmentation for project {labeling_project.id}: {e}") logger.warning(
f"Failed to precompute segmentation for project {labeling_project.id}: {e}"
)
# 不影响项目创建,只记录警告 # 不影响项目创建,只记录警告
response_data = DatasetMappingCreateResponse( response_data = DatasetMappingCreateResponse(
id=mapping.id, id=mapping.id,
labeling_project_id=str(mapping.labeling_project_id), labeling_project_id=str(mapping.labeling_project_id),
labeling_project_name=mapping.name or project_name labeling_project_name=mapping.name or project_name,
) )
return StandardResponse( return StandardResponse(code=201, message="success", data=response_data)
code=201,
message="success",
data=response_data
)
except HTTPException: except HTTPException:
raise raise
@@ -183,12 +193,15 @@ async def create_mapping(
logger.error(f"Error while creating dataset mapping: {e}") logger.error(f"Error while creating dataset mapping: {e}")
raise HTTPException(status_code=500, detail="Internal server error") raise HTTPException(status_code=500, detail="Internal server error")
@router.get("", response_model=StandardResponse[PaginatedData[DatasetMappingResponse]]) @router.get("", response_model=StandardResponse[PaginatedData[DatasetMappingResponse]])
async def list_mappings( async def list_mappings(
page: int = Query(1, ge=1, description="页码(从1开始)"), page: int = Query(1, ge=1, description="页码(从1开始)"),
size: int = Query(20, ge=1, le=100, description="每页记录数"), size: int = Query(20, ge=1, le=100, description="每页记录数"),
include_template: bool = Query(False, description="是否包含模板详情", alias="includeTemplate"), include_template: bool = Query(
db: AsyncSession = Depends(get_db) False, description="是否包含模板详情", alias="includeTemplate"
),
db: AsyncSession = Depends(get_db),
): ):
""" """
查询所有映射关系(分页) 查询所有映射关系(分页)
@@ -207,14 +220,16 @@ async def list_mappings(
# 计算 skip # 计算 skip
skip = (page - 1) * size skip = (page - 1) * size
logger.info(f"List mappings: page={page}, size={size}, include_template={include_template}") logger.info(
f"List mappings: page={page}, size={size}, include_template={include_template}"
)
# 获取数据和总数 # 获取数据和总数
mappings, total = await service.get_all_mappings_with_count( mappings, total = await service.get_all_mappings_with_count(
skip=skip, skip=skip,
limit=size, limit=size,
include_deleted=False, include_deleted=False,
include_template=include_template include_template=include_template,
) )
# 计算总页数 # 计算总页数
@@ -226,26 +241,22 @@ async def list_mappings(
size=size, size=size,
total_elements=total, total_elements=total,
total_pages=total_pages, total_pages=total_pages,
content=mappings content=mappings,
) )
logger.info(f"List mappings: page={page}, returned {len(mappings)}/{total}, templates_included: {include_template}") logger.info(
f"List mappings: page={page}, returned {len(mappings)}/{total}, templates_included: {include_template}"
return StandardResponse(
code=200,
message="success",
data=paginated_data
) )
return StandardResponse(code=200, message="success", data=paginated_data)
except Exception as e: except Exception as e:
logger.error(f"Error listing mappings: {e}") logger.error(f"Error listing mappings: {e}")
raise HTTPException(status_code=500, detail="Internal server error") raise HTTPException(status_code=500, detail="Internal server error")
@router.get("/{mapping_id}", response_model=StandardResponse[DatasetMappingResponse]) @router.get("/{mapping_id}", response_model=StandardResponse[DatasetMappingResponse])
async def get_mapping( async def get_mapping(mapping_id: str, db: AsyncSession = Depends(get_db)):
mapping_id: str,
db: AsyncSession = Depends(get_db)
):
""" """
根据 UUID 查询单个映射关系(包含关联的标注模板详情) 根据 UUID 查询单个映射关系(包含关联的标注模板详情)
@@ -265,31 +276,34 @@ async def get_mapping(
if not mapping: if not mapping:
raise HTTPException( raise HTTPException(
status_code=404, status_code=404, detail=f"Mapping not found: {mapping_id}"
detail=f"Mapping not found: {mapping_id}"
) )
logger.info(f"Found mapping: {mapping.id}, template_included: {mapping.template is not None}") logger.info(
f"Found mapping: {mapping.id}, template_included: {mapping.template is not None}"
return StandardResponse(
code=200,
message="success",
data=mapping
) )
return StandardResponse(code=200, message="success", data=mapping)
except HTTPException: except HTTPException:
raise raise
except Exception as e: except Exception as e:
logger.error(f"Error getting mapping: {e}") logger.error(f"Error getting mapping: {e}")
raise HTTPException(status_code=500, detail="Internal server error") raise HTTPException(status_code=500, detail="Internal server error")
@router.get("/by-source/{dataset_id}", response_model=StandardResponse[PaginatedData[DatasetMappingResponse]])
@router.get(
"/by-source/{dataset_id}",
response_model=StandardResponse[PaginatedData[DatasetMappingResponse]],
)
async def get_mappings_by_source( async def get_mappings_by_source(
dataset_id: str, dataset_id: str,
page: int = Query(1, ge=1, description="页码(从1开始)"), page: int = Query(1, ge=1, description="页码(从1开始)"),
size: int = Query(20, ge=1, le=100, description="每页记录数"), size: int = Query(20, ge=1, le=100, description="每页记录数"),
include_template: bool = Query(True, description="是否包含模板详情", alias="includeTemplate"), include_template: bool = Query(
db: AsyncSession = Depends(get_db) True, description="是否包含模板详情", alias="includeTemplate"
),
db: AsyncSession = Depends(get_db),
): ):
""" """
根据源数据集 ID 查询所有映射关系(分页,包含模板详情) 根据源数据集 ID 查询所有映射关系(分页,包含模板详情)
@@ -309,14 +323,16 @@ async def get_mappings_by_source(
# 计算 skip # 计算 skip
skip = (page - 1) * size skip = (page - 1) * size
logger.info(f"Get mappings by source dataset id: {dataset_id}, page={page}, size={size}, include_template={include_template}") logger.info(
f"Get mappings by source dataset id: {dataset_id}, page={page}, size={size}, include_template={include_template}"
)
# 获取数据和总数(包含模板信息) # 获取数据和总数(包含模板信息)
mappings, total = await service.get_mappings_by_source_with_count( mappings, total = await service.get_mappings_by_source_with_count(
dataset_id=dataset_id, dataset_id=dataset_id,
skip=skip, skip=skip,
limit=size, limit=size,
include_template=include_template include_template=include_template,
) )
# 计算总页数 # 计算总页数
@@ -328,27 +344,26 @@ async def get_mappings_by_source(
size=size, size=size,
total_elements=total, total_elements=total,
total_pages=total_pages, total_pages=total_pages,
content=mappings content=mappings,
) )
logger.info(f"Found {len(mappings)} mappings on page {page}, total: {total}, templates_included: {include_template}") logger.info(
f"Found {len(mappings)} mappings on page {page}, total: {total}, templates_included: {include_template}"
return StandardResponse(
code=200,
message="success",
data=paginated_data
) )
return StandardResponse(code=200, message="success", data=paginated_data)
except HTTPException: except HTTPException:
raise raise
except Exception as e: except Exception as e:
logger.error(f"Error getting mappings: {e}") logger.error(f"Error getting mappings: {e}")
raise HTTPException(status_code=500, detail="Internal server error") raise HTTPException(status_code=500, detail="Internal server error")
@router.delete("/{project_id}", response_model=StandardResponse[DeleteDatasetResponse]) @router.delete("/{project_id}", response_model=StandardResponse[DeleteDatasetResponse])
async def delete_mapping( async def delete_mapping(
project_id: str = Path(..., description="映射UUID(path param)"), project_id: str = Path(..., description="映射UUID(path param)"),
db: AsyncSession = Depends(get_db) db: AsyncSession = Depends(get_db),
): ):
""" """
删除映射关系(软删除) 删除映射关系(软删除)
@@ -370,12 +385,12 @@ async def delete_mapping(
if not mapping: if not mapping:
raise HTTPException( raise HTTPException(
status_code=404, status_code=404, detail=f"Mapping either not found or not specified."
detail=f"Mapping either not found or not specified."
) )
id = mapping.id id = mapping.id
logger.debug(f"Found mapping: {id}") dataset_id = mapping.dataset_id
logger.debug(f"Found mapping: {id}, dataset_id: {dataset_id}")
# 软删除映射记录 # 软删除映射记录
soft_delete_success = await service.soft_delete_mapping(id) soft_delete_success = await service.soft_delete_mapping(id)
@@ -383,19 +398,22 @@ async def delete_mapping(
if not soft_delete_success: if not soft_delete_success:
raise HTTPException( raise HTTPException(
status_code=500, status_code=500, detail="Failed to delete mapping record"
detail="Failed to delete mapping record"
) )
# 清理知识集中的关联数据
try:
knowledge_sync = KnowledgeSyncService(db)
await knowledge_sync._cleanup_knowledge_set_for_project(id)
except Exception as exc:
logger.warning(f"清理知识集失败:project_id={id} error={exc}")
logger.info(f"Successfully deleted mapping: {id}") logger.info(f"Successfully deleted mapping: {id}")
return StandardResponse( return StandardResponse(
code=200, code=200,
message="success", message="success",
data=DeleteDatasetResponse( data=DeleteDatasetResponse(id=id, status="success"),
id=id,
status="success"
)
) )
except HTTPException: except HTTPException:
@@ -409,7 +427,7 @@ async def delete_mapping(
async def update_mapping( async def update_mapping(
project_id: str = Path(..., description="映射UUID(path param)"), project_id: str = Path(..., description="映射UUID(path param)"),
request: DatasetMappingUpdateRequest = None, request: DatasetMappingUpdateRequest = None,
db: AsyncSession = Depends(get_db) db: AsyncSession = Depends(get_db),
): ):
""" """
更新标注项目信息 更新标注项目信息
@@ -429,16 +447,14 @@ async def update_mapping(
# 直接查询 ORM 模型获取原始数据 # 直接查询 ORM 模型获取原始数据
result = await db.execute( result = await db.execute(
select(LabelingProject).where( select(LabelingProject).where(
LabelingProject.id == project_id, LabelingProject.id == project_id, LabelingProject.deleted_at.is_(None)
LabelingProject.deleted_at.is_(None)
) )
) )
mapping_orm = result.scalar_one_or_none() mapping_orm = result.scalar_one_or_none()
if not mapping_orm: if not mapping_orm:
raise HTTPException( raise HTTPException(
status_code=404, status_code=404, detail=f"Mapping not found: {project_id}"
detail=f"Mapping not found: {project_id}"
) )
# 构建更新数据 # 构建更新数据
@@ -449,7 +465,11 @@ async def update_mapping(
# 从 configuration 字段中读取和更新 description 和 label_config # 从 configuration 字段中读取和更新 description 和 label_config
configuration = {} configuration = {}
if mapping_orm.configuration: if mapping_orm.configuration:
configuration = mapping_orm.configuration.copy() if isinstance(mapping_orm.configuration, dict) else {} configuration = (
mapping_orm.configuration.copy()
if isinstance(mapping_orm.configuration, dict)
else {}
)
if request.description is not None: if request.description is not None:
configuration["description"] = request.description configuration["description"] = request.description
@@ -468,7 +488,7 @@ async def update_mapping(
if not template: if not template:
raise HTTPException( raise HTTPException(
status_code=404, status_code=404,
detail=f"Template not found: {request.template_id}" detail=f"Template not found: {request.template_id}",
) )
template_labeling_type = getattr(template, "labeling_type", None) template_labeling_type = getattr(template, "labeling_type", None)
if template_labeling_type: if template_labeling_type:
@@ -477,14 +497,11 @@ async def update_mapping(
if not update_values: if not update_values:
# 没有要更新的字段,直接返回当前数据 # 没有要更新的字段,直接返回当前数据
response_data = await service.get_mapping_by_uuid(project_id) response_data = await service.get_mapping_by_uuid(project_id)
return StandardResponse( return StandardResponse(code=200, message="success", data=response_data)
code=200,
message="success",
data=response_data
)
# 执行更新 # 执行更新
from datetime import datetime from datetime import datetime
update_values["updated_at"] = datetime.now() update_values["updated_at"] = datetime.now()
result = await db.execute( result = await db.execute(
@@ -495,21 +512,14 @@ async def update_mapping(
await db.commit() await db.commit()
if result.rowcount == 0: if result.rowcount == 0:
raise HTTPException( raise HTTPException(status_code=500, detail="Failed to update mapping")
status_code=500,
detail="Failed to update mapping"
)
# 重新获取更新后的数据 # 重新获取更新后的数据
updated_mapping = await service.get_mapping_by_uuid(project_id) updated_mapping = await service.get_mapping_by_uuid(project_id)
logger.info(f"Successfully updated mapping: {project_id}") logger.info(f"Successfully updated mapping: {project_id}")
return StandardResponse( return StandardResponse(code=200, message="success", data=updated_mapping)
code=200,
message="success",
data=updated_mapping
)
except HTTPException: except HTTPException:
raise raise

View File

@@ -13,6 +13,7 @@ from pydantic import BaseModel, Field
class ExportFormat(str, Enum): class ExportFormat(str, Enum):
"""导出格式枚举""" """导出格式枚举"""
JSON = "json" # Label Studio 原生 JSON 格式 JSON = "json" # Label Studio 原生 JSON 格式
JSONL = "jsonl" # JSON Lines 格式(每行一条记录) JSONL = "jsonl" # JSON Lines 格式(每行一条记录)
CSV = "csv" # CSV 表格格式 CSV = "csv" # CSV 表格格式
@@ -22,8 +23,11 @@ class ExportFormat(str, Enum):
class ExportAnnotationsRequest(BaseModel): class ExportAnnotationsRequest(BaseModel):
"""导出标注数据请求""" """导出标注数据请求"""
format: ExportFormat = Field(default=ExportFormat.JSON, description="导出格式") format: ExportFormat = Field(default=ExportFormat.JSON, description="导出格式")
include_data: bool = Field(default=True, description="是否包含原始数据(如文本内容)") include_data: bool = Field(
default=True, description="是否包含原始数据(如文本内容)"
)
only_annotated: bool = Field(default=True, description="是否只导出已标注的数据") only_annotated: bool = Field(default=True, description="是否只导出已标注的数据")
model_config = {"use_enum_values": True} model_config = {"use_enum_values": True}
@@ -31,6 +35,7 @@ class ExportAnnotationsRequest(BaseModel):
class ExportAnnotationsResponse(BaseModel): class ExportAnnotationsResponse(BaseModel):
"""导出标注数据响应(用于预览/统计)""" """导出标注数据响应(用于预览/统计)"""
project_id: str = Field(..., description="项目ID") project_id: str = Field(..., description="项目ID")
project_name: str = Field(..., description="项目名称") project_name: str = Field(..., description="项目名称")
total_files: int = Field(..., description="总文件数") total_files: int = Field(..., description="总文件数")
@@ -42,16 +47,21 @@ class ExportAnnotationsResponse(BaseModel):
class AnnotationExportItem(BaseModel): class AnnotationExportItem(BaseModel):
"""单条导出记录""" """单条导出记录"""
file_id: str = Field(..., description="文件ID") file_id: str = Field(..., description="文件ID")
file_name: str = Field(..., description="文件名") file_name: str = Field(..., description="文件名")
file_path: Optional[str] = Field(default=None, description="文件路径")
data: Optional[Dict[str, Any]] = Field(default=None, description="原始数据") data: Optional[Dict[str, Any]] = Field(default=None, description="原始数据")
annotations: List[Dict[str, Any]] = Field(default_factory=list, description="标注结果") annotations: List[Dict[str, Any]] = Field(
default_factory=list, description="标注结果"
)
created_at: Optional[datetime] = Field(default=None, description="创建时间") created_at: Optional[datetime] = Field(default=None, description="创建时间")
updated_at: Optional[datetime] = Field(default=None, description="更新时间") updated_at: Optional[datetime] = Field(default=None, description="更新时间")
class COCOExportFormat(BaseModel): class COCOExportFormat(BaseModel):
"""COCO 格式导出结构""" """COCO 格式导出结构"""
info: Dict[str, Any] = Field(default_factory=dict) info: Dict[str, Any] = Field(default_factory=dict)
licenses: List[Dict[str, Any]] = Field(default_factory=list) licenses: List[Dict[str, Any]] = Field(default_factory=list)
images: List[Dict[str, Any]] = Field(default_factory=list) images: List[Dict[str, Any]] = Field(default_factory=list)

View File

@@ -21,14 +21,23 @@ from datetime import datetime
from typing import Any, Dict, List, Optional, Tuple from typing import Any, Dict, List, Optional, Tuple
from fastapi import HTTPException from fastapi import HTTPException
from PIL import Image
from sqlalchemy import func, select from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from app.core.logging import get_logger from app.core.logging import get_logger
from app.db.models import AnnotationResult, Dataset, DatasetFiles, LabelingProject, LabelingProjectFile from app.db.models import (
AnnotationResult,
Dataset,
DatasetFiles,
LabelingProject,
LabelingProjectFile,
)
async def _read_file_content(file_path: str, max_size: int = 10 * 1024 * 1024) -> Optional[str]: async def _read_file_content(
file_path: str, max_size: int = 10 * 1024 * 1024
) -> Optional[str]:
"""读取文件内容,仅适用于文本文件 """读取文件内容,仅适用于文本文件
Args: Args:
@@ -48,11 +57,12 @@ async def _read_file_content(file_path: str, max_size: int = 10 * 1024 * 1024) -
return f"[File too large: {file_size} bytes]" return f"[File too large: {file_size} bytes]"
# 尝试以文本方式读取 # 尝试以文本方式读取
with open(file_path, 'r', encoding='utf-8', errors='replace') as f: with open(file_path, "r", encoding="utf-8", errors="replace") as f:
return f.read() return f.read()
except Exception: except Exception:
return None return None
from ..schema.export import ( from ..schema.export import (
AnnotationExportItem, AnnotationExportItem,
COCOExportFormat, COCOExportFormat,
@@ -79,7 +89,9 @@ class AnnotationExportService:
async def get_export_stats(self, project_id: str) -> ExportAnnotationsResponse: async def get_export_stats(self, project_id: str) -> ExportAnnotationsResponse:
"""获取导出统计信息""" """获取导出统计信息"""
project = await self._get_project_or_404(project_id) project = await self._get_project_or_404(project_id)
logger.info(f"Export stats for project: id={project_id}, dataset_id={project.dataset_id}, name={project.name}") logger.info(
f"Export stats for project: id={project_id}, dataset_id={project.dataset_id}, name={project.name}"
)
# 获取总文件数(标注项目快照内的文件) # 获取总文件数(标注项目快照内的文件)
total_result = await self.db.execute( total_result = await self.db.execute(
@@ -92,7 +104,9 @@ class AnnotationExportService:
) )
) )
total_files = int(total_result.scalar() or 0) total_files = int(total_result.scalar() or 0)
logger.info(f"Total files (snapshot): {total_files} for project_id={project_id}") logger.info(
f"Total files (snapshot): {total_files} for project_id={project_id}"
)
# 获取已标注文件数(统计不同的 file_id 数量) # 获取已标注文件数(统计不同的 file_id 数量)
annotated_result = await self.db.execute( annotated_result = await self.db.execute(
@@ -132,7 +146,11 @@ class AnnotationExportService:
) )
# 根据格式导出 # 根据格式导出
format_type = ExportFormat(request.format) if isinstance(request.format, str) else request.format format_type = (
ExportFormat(request.format)
if isinstance(request.format, str)
else request.format
)
if format_type == ExportFormat.JSON: if format_type == ExportFormat.JSON:
return self._export_json(items, project.name) return self._export_json(items, project.name)
@@ -145,7 +163,9 @@ class AnnotationExportService:
elif format_type == ExportFormat.YOLO: elif format_type == ExportFormat.YOLO:
return self._export_yolo(items, project.name) return self._export_yolo(items, project.name)
else: else:
raise HTTPException(status_code=400, detail=f"不支持的导出格式: {request.format}") raise HTTPException(
status_code=400, detail=f"不支持的导出格式: {request.format}"
)
async def _get_project_or_404(self, project_id: str) -> LabelingProject: async def _get_project_or_404(self, project_id: str) -> LabelingProject:
"""获取标注项目,不存在则抛出 404""" """获取标注项目,不存在则抛出 404"""
@@ -174,7 +194,10 @@ class AnnotationExportService:
# 只获取已标注的数据 # 只获取已标注的数据
result = await self.db.execute( result = await self.db.execute(
select(AnnotationResult, DatasetFiles) select(AnnotationResult, DatasetFiles)
.join(LabelingProjectFile, LabelingProjectFile.file_id == AnnotationResult.file_id) .join(
LabelingProjectFile,
LabelingProjectFile.file_id == AnnotationResult.file_id,
)
.join(DatasetFiles, AnnotationResult.file_id == DatasetFiles.id) .join(DatasetFiles, AnnotationResult.file_id == DatasetFiles.id)
.where( .where(
AnnotationResult.project_id == project_id, AnnotationResult.project_id == project_id,
@@ -197,6 +220,7 @@ class AnnotationExportService:
AnnotationExportItem( AnnotationExportItem(
file_id=str(file.id), file_id=str(file.id),
file_name=str(getattr(file, "file_name", "")), file_name=str(getattr(file, "file_name", "")),
file_path=str(getattr(file, "file_path", "")),
data={"text": file_content} if include_data else None, data={"text": file_content} if include_data else None,
annotations=[annotation_data] if annotation_data else [], annotations=[annotation_data] if annotation_data else [],
created_at=ann.created_at, created_at=ann.created_at,
@@ -207,7 +231,9 @@ class AnnotationExportService:
# 获取所有文件(基于标注项目快照) # 获取所有文件(基于标注项目快照)
files_result = await self.db.execute( files_result = await self.db.execute(
select(DatasetFiles) select(DatasetFiles)
.join(LabelingProjectFile, LabelingProjectFile.file_id == DatasetFiles.id) .join(
LabelingProjectFile, LabelingProjectFile.file_id == DatasetFiles.id
)
.where( .where(
LabelingProjectFile.project_id == project_id, LabelingProjectFile.project_id == project_id,
DatasetFiles.dataset_id == dataset_id, DatasetFiles.dataset_id == dataset_id,
@@ -217,7 +243,9 @@ class AnnotationExportService:
# 获取已有的标注 # 获取已有的标注
ann_result = await self.db.execute( ann_result = await self.db.execute(
select(AnnotationResult).where(AnnotationResult.project_id == project_id) select(AnnotationResult).where(
AnnotationResult.project_id == project_id
)
) )
annotations = {str(a.file_id): a for a in ann_result.scalars().all()} annotations = {str(a.file_id): a for a in ann_result.scalars().all()}
@@ -236,6 +264,7 @@ class AnnotationExportService:
AnnotationExportItem( AnnotationExportItem(
file_id=file_id, file_id=file_id,
file_name=str(getattr(file, "file_name", "")), file_name=str(getattr(file, "file_name", "")),
file_path=str(getattr(file, "file_path", "")),
data={"text": file_content} if include_data else None, data={"text": file_content} if include_data else None,
annotations=[annotation_data] if annotation_data else [], annotations=[annotation_data] if annotation_data else [],
created_at=ann.created_at if ann else None, created_at=ann.created_at if ann else None,
@@ -262,8 +291,13 @@ class AnnotationExportService:
for item in segment_results: for item in segment_results:
if isinstance(item, dict): if isinstance(item, dict):
normalized = dict(item) normalized = dict(item)
if SEGMENT_INDEX_KEY not in normalized and SEGMENT_INDEX_FALLBACK_KEY not in normalized: if (
normalized[SEGMENT_INDEX_KEY] = int(key) if str(key).isdigit() else key SEGMENT_INDEX_KEY not in normalized
and SEGMENT_INDEX_FALLBACK_KEY not in normalized
):
normalized[SEGMENT_INDEX_KEY] = (
int(key) if str(key).isdigit() else key
)
results.append(normalized) results.append(normalized)
elif isinstance(segments, list): elif isinstance(segments, list):
for idx, segment in enumerate(segments): for idx, segment in enumerate(segments):
@@ -272,11 +306,16 @@ class AnnotationExportService:
segment_results = segment.get(SEGMENT_RESULT_KEY) segment_results = segment.get(SEGMENT_RESULT_KEY)
if not isinstance(segment_results, list): if not isinstance(segment_results, list):
continue continue
segment_index = segment.get(SEGMENT_INDEX_KEY, segment.get(SEGMENT_INDEX_FALLBACK_KEY, idx)) segment_index = segment.get(
SEGMENT_INDEX_KEY, segment.get(SEGMENT_INDEX_FALLBACK_KEY, idx)
)
for item in segment_results: for item in segment_results:
if isinstance(item, dict): if isinstance(item, dict):
normalized = dict(item) normalized = dict(item)
if SEGMENT_INDEX_KEY not in normalized and SEGMENT_INDEX_FALLBACK_KEY not in normalized: if (
SEGMENT_INDEX_KEY not in normalized
and SEGMENT_INDEX_FALLBACK_KEY not in normalized
):
normalized[SEGMENT_INDEX_KEY] = segment_index normalized[SEGMENT_INDEX_KEY] = segment_index
results.append(normalized) results.append(normalized)
return results return results
@@ -284,18 +323,43 @@ class AnnotationExportService:
return result if isinstance(result, list) else [] return result if isinstance(result, list) else []
@classmethod @classmethod
def _normalize_annotation_for_export(cls, annotation: Dict[str, Any]) -> Dict[str, Any]: def _normalize_annotation_for_export(
cls, annotation: Dict[str, Any]
) -> Dict[str, Any]:
if not annotation or not isinstance(annotation, dict): if not annotation or not isinstance(annotation, dict):
return {} return {}
segments = annotation.get(SEGMENTS_KEY) segments = annotation.get(SEGMENTS_KEY)
if annotation.get(SEGMENTED_KEY) or isinstance(segments, (dict, list)): if annotation.get(SEGMENTED_KEY) or isinstance(segments, (dict, list)):
normalized = dict(annotation) normalized = dict(annotation)
normalized_result = cls._flatten_annotation_results(annotation) normalized_result = cls._flatten_annotation_results(annotation)
if SEGMENT_RESULT_KEY not in normalized or not isinstance(normalized.get(SEGMENT_RESULT_KEY), list): if SEGMENT_RESULT_KEY not in normalized or not isinstance(
normalized.get(SEGMENT_RESULT_KEY), list
):
normalized[SEGMENT_RESULT_KEY] = normalized_result normalized[SEGMENT_RESULT_KEY] = normalized_result
return normalized return normalized
return annotation return annotation
@staticmethod
def _get_image_dimensions(file_path: str) -> Tuple[int, int]:
"""获取图片文件的宽度和高度
Args:
file_path: 图片文件路径
Returns:
(width, height) 元组,如果读取失败则返回 (1920, 1080) 作为默认值
"""
try:
if os.path.exists(file_path):
with Image.open(file_path) as img:
width, height = img.size
return width, height
except Exception as e:
logger.warning(f"Failed to read image dimensions from {file_path}: {e}")
# 使用合理的默认值
return 1920, 1080
def _export_json( def _export_json(
self, items: List[AnnotationExportItem], project_name: str self, items: List[AnnotationExportItem], project_name: str
) -> Tuple[bytes, str, str]: ) -> Tuple[bytes, str, str]:
@@ -309,9 +373,16 @@ class AnnotationExportService:
"file_id": item.file_id, "file_id": item.file_id,
"file_name": item.file_name, "file_name": item.file_name,
"data": item.data, "data": item.data,
"annotations": [self._normalize_annotation_for_export(ann) for ann in item.annotations], "annotations": [
"created_at": item.created_at.isoformat() if item.created_at else None, self._normalize_annotation_for_export(ann)
"updated_at": item.updated_at.isoformat() if item.updated_at else None, for ann in item.annotations
],
"created_at": item.created_at.isoformat()
if item.created_at
else None,
"updated_at": item.updated_at.isoformat()
if item.updated_at
else None,
} }
for item in items for item in items
], ],
@@ -331,7 +402,10 @@ class AnnotationExportService:
"file_id": item.file_id, "file_id": item.file_id,
"file_name": item.file_name, "file_name": item.file_name,
"data": item.data, "data": item.data,
"annotations": [self._normalize_annotation_for_export(ann) for ann in item.annotations], "annotations": [
self._normalize_annotation_for_export(ann)
for ann in item.annotations
],
"created_at": item.created_at.isoformat() if item.created_at else None, "created_at": item.created_at.isoformat() if item.created_at else None,
"updated_at": item.updated_at.isoformat() if item.updated_at else None, "updated_at": item.updated_at.isoformat() if item.updated_at else None,
} }
@@ -389,14 +463,22 @@ class AnnotationExportService:
elif "transcription" in value: elif "transcription" in value:
labels.append(value["transcription"]) labels.append(value["transcription"])
writer.writerow({ writer.writerow(
{
"file_id": item.file_id, "file_id": item.file_id,
"file_name": item.file_name, "file_name": item.file_name,
"annotation_result": json.dumps(item.annotations, ensure_ascii=False), "annotation_result": json.dumps(
item.annotations, ensure_ascii=False
),
"labels": "|".join(labels), "labels": "|".join(labels),
"created_at": item.created_at.isoformat() if item.created_at else "", "created_at": item.created_at.isoformat()
"updated_at": item.updated_at.isoformat() if item.updated_at else "", if item.created_at
}) else "",
"updated_at": item.updated_at.isoformat()
if item.updated_at
else "",
}
)
content = output.getvalue().encode("utf-8-sig") # BOM for Excel compatibility content = output.getvalue().encode("utf-8-sig") # BOM for Excel compatibility
filename = f"{project_name}_annotations.csv" filename = f"{project_name}_annotations.csv"
@@ -405,11 +487,7 @@ class AnnotationExportService:
def _export_coco( def _export_coco(
self, items: List[AnnotationExportItem], project_name: str self, items: List[AnnotationExportItem], project_name: str
) -> Tuple[bytes, str, str]: ) -> Tuple[bytes, str, str]:
"""导出为 COCO 格式(适用于目标检测标注) """导出为 COCO 格式(适用于目标检测标注)"""
注意:当前实现中图片宽高被设置为0,因为需要读取实际图片文件获取尺寸。
bbox 坐标使用 Label Studio 的百分比值(0-100),使用时需要转换为像素坐标。
"""
coco_format = COCOExportFormat( coco_format = COCOExportFormat(
info={ info={
"description": f"Exported from DataMate project: {project_name}", "description": f"Exported from DataMate project: {project_name}",
@@ -429,13 +507,18 @@ class AnnotationExportService:
for idx, item in enumerate(items): for idx, item in enumerate(items):
image_id = idx + 1 image_id = idx + 1
# 获取图片实际尺寸
img_width, img_height = self._get_image_dimensions(item.file_path or "")
# 添加图片信息 # 添加图片信息
coco_format.images.append({ coco_format.images.append(
{
"id": image_id, "id": image_id,
"file_name": item.file_name, "file_name": item.file_name,
"width": 0, # 需要实际图片尺寸 "width": img_width,
"height": 0, "height": img_height,
}) }
)
# 处理标注 # 处理标注
for ann in item.annotations: for ann in item.annotations:
@@ -449,29 +532,41 @@ class AnnotationExportService:
for label in labels: for label in labels:
if label not in category_map: if label not in category_map:
category_map[label] = len(category_map) + 1 category_map[label] = len(category_map) + 1
coco_format.categories.append({ coco_format.categories.append(
{
"id": category_map[label], "id": category_map[label],
"name": label, "name": label,
"supercategory": "", "supercategory": "",
}) }
)
# 转换坐标Label Studio 使用百分比 # 转换坐标Label Studio 使用百分比(0-100)转换为像素坐标
x = value.get("x", 0) x_percent = value.get("x", 0)
y = value.get("y", 0) y_percent = value.get("y", 0)
width = value.get("width", 0) width_percent = value.get("width", 0)
height = value.get("height", 0) height_percent = value.get("height", 0)
coco_format.annotations.append({ # 转换为像素坐标
x = x_percent * img_width / 100.0
y = y_percent * img_height / 100.0
width = width_percent * img_width / 100.0
height = height_percent * img_height / 100.0
coco_format.annotations.append(
{
"id": annotation_id, "id": annotation_id,
"image_id": image_id, "image_id": image_id,
"category_id": category_map[label], "category_id": category_map[label],
"bbox": [x, y, width, height], "bbox": [x, y, width, height],
"area": width * height, "area": width * height,
"iscrowd": 0, "iscrowd": 0,
}) }
)
annotation_id += 1 annotation_id += 1
content = json.dumps(coco_format.model_dump(), ensure_ascii=False, indent=2).encode("utf-8") content = json.dumps(
coco_format.model_dump(), ensure_ascii=False, indent=2
).encode("utf-8")
filename = f"{project_name}_coco.json" filename = f"{project_name}_coco.json"
return content, filename, "application/json" return content, filename, "application/json"
@@ -510,7 +605,9 @@ class AnnotationExportService:
x_center = x + w / 2 x_center = x + w / 2
y_center = y + h / 2 y_center = y + h / 2
lines.append(f"{label} {x_center:.6f} {y_center:.6f} {w:.6f} {h:.6f}") lines.append(
f"{label} {x_center:.6f} {y_center:.6f} {w:.6f} {h:.6f}"
)
if lines: if lines:
# 生成对应的 txt 文件名 # 生成对应的 txt 文件名

View File

@@ -43,7 +43,9 @@ class KnowledgeSyncService:
logger.warning("标注同步失败:无法获取知识集") logger.warning("标注同步失败:无法获取知识集")
return return
item = await self._get_item_by_source(set_id, project.dataset_id, str(file_record.id)) item = await self._get_item_by_source(
set_id, project.dataset_id, str(file_record.id)
)
if item and item.get("status") in {"PUBLISHED", "ARCHIVED", "DEPRECATED"}: if item and item.get("status") in {"PUBLISHED", "ARCHIVED", "DEPRECATED"}:
logger.info( logger.info(
"知识条目为只读状态,跳过同步:item_id=%s status=%s", "知识条目为只读状态,跳过同步:item_id=%s status=%s",
@@ -71,26 +73,46 @@ class KnowledgeSyncService:
logger.warning("标注同步到知识管理失败:%s", exc) logger.warning("标注同步到知识管理失败:%s", exc)
async def _ensure_knowledge_set(self, project: LabelingProject) -> Optional[str]: async def _ensure_knowledge_set(self, project: LabelingProject) -> Optional[str]:
config = project.configuration if isinstance(project.configuration, dict) else {} result = await self.db.execute(
select(LabelingProject)
.where(LabelingProject.id == project.id)
.with_for_update()
)
locked_project = result.scalar_one_or_none()
if not locked_project:
logger.warning("标注同步失败:无法锁定项目:project_id=%s", project.id)
return None
config = (
locked_project.configuration
if isinstance(locked_project.configuration, dict)
else {}
)
set_id = config.get(self.CONFIG_KEY_SET_ID) set_id = config.get(self.CONFIG_KEY_SET_ID)
if set_id: if set_id:
exists = await self._get_knowledge_set(set_id) exists = await self._get_knowledge_set(set_id)
if exists and self._metadata_matches_project(exists.get("metadata"), project.id): if exists and self._metadata_matches_project(
exists.get("metadata"), locked_project.id
):
return set_id return set_id
logger.warning( logger.warning(
"知识集不存在或归属不匹配,准备重建:set_id=%s project_id=%s", "知识集不存在或归属不匹配,准备重建:set_id=%s project_id=%s",
set_id, set_id,
project.id, locked_project.id,
) )
project_name = (project.name or "annotation-project").strip() or "annotation-project" project_name = (
metadata = self._build_set_metadata(project) locked_project.name or "annotation-project"
).strip() or "annotation-project"
metadata = self._build_set_metadata(locked_project)
existing = await self._find_knowledge_set_by_name_and_project(project_name, project.id) existing = await self._find_knowledge_set_by_name_and_project(
project_name, locked_project.id
)
if existing: if existing:
await self._update_project_config( await self._update_project_config(
project, locked_project,
{ {
self.CONFIG_KEY_SET_ID: existing.get("id"), self.CONFIG_KEY_SET_ID: existing.get("id"),
self.CONFIG_KEY_SET_NAME: existing.get("name"), self.CONFIG_KEY_SET_NAME: existing.get("name"),
@@ -100,23 +122,31 @@ class KnowledgeSyncService:
created = await self._create_knowledge_set(project_name, metadata) created = await self._create_knowledge_set(project_name, metadata)
if not created: if not created:
created = await self._find_knowledge_set_by_name_and_project(project_name, project.id) created = await self._find_knowledge_set_by_name_and_project(
project_name, locked_project.id
)
if not created: if not created:
fallback_name = self._build_fallback_set_name(project_name, project.id) fallback_name = self._build_fallback_set_name(
existing = await self._find_knowledge_set_by_name_and_project(fallback_name, project.id) project_name, locked_project.id
)
existing = await self._find_knowledge_set_by_name_and_project(
fallback_name, locked_project.id
)
if existing: if existing:
created = existing created = existing
else: else:
created = await self._create_knowledge_set(fallback_name, metadata) created = await self._create_knowledge_set(fallback_name, metadata)
if not created: if not created:
created = await self._find_knowledge_set_by_name_and_project(fallback_name, project.id) created = await self._find_knowledge_set_by_name_and_project(
fallback_name, locked_project.id
)
if not created: if not created:
return None return None
await self._update_project_config( await self._update_project_config(
project, locked_project,
{ {
self.CONFIG_KEY_SET_ID: created.get("id"), self.CONFIG_KEY_SET_ID: created.get("id"),
self.CONFIG_KEY_SET_NAME: created.get("name"), self.CONFIG_KEY_SET_NAME: created.get("name"),
@@ -126,13 +156,17 @@ class KnowledgeSyncService:
async def _get_knowledge_set(self, set_id: str) -> Optional[Dict[str, Any]]: async def _get_knowledge_set(self, set_id: str) -> Optional[Dict[str, Any]]:
try: try:
return await self._request("GET", f"/data-management/knowledge-sets/{set_id}") return await self._request(
"GET", f"/data-management/knowledge-sets/{set_id}"
)
except httpx.HTTPStatusError as exc: except httpx.HTTPStatusError as exc:
if exc.response.status_code == 404: if exc.response.status_code == 404:
return None return None
raise raise
async def _list_knowledge_sets(self, keyword: Optional[str]) -> list[Dict[str, Any]]: async def _list_knowledge_sets(
self, keyword: Optional[str]
) -> list[Dict[str, Any]]:
params: Dict[str, Any] = { params: Dict[str, Any] = {
"page": 1, "page": 1,
"size": self.KNOWLEDGE_SET_LIST_SIZE, "size": self.KNOWLEDGE_SET_LIST_SIZE,
@@ -140,7 +174,9 @@ class KnowledgeSyncService:
if keyword: if keyword:
params["keyword"] = keyword params["keyword"] = keyword
try: try:
data = await self._request("GET", "/data-management/knowledge-sets", params=params) data = await self._request(
"GET", "/data-management/knowledge-sets", params=params
)
except httpx.HTTPStatusError as exc: except httpx.HTTPStatusError as exc:
logger.warning( logger.warning(
"查询知识集失败:keyword=%s status=%s", "查询知识集失败:keyword=%s status=%s",
@@ -155,7 +191,9 @@ class KnowledgeSyncService:
return [] return []
return [item for item in content if isinstance(item, dict)] return [item for item in content if isinstance(item, dict)]
async def _find_knowledge_set_by_name_and_project(self, name: str, project_id: str) -> Optional[Dict[str, Any]]: async def _find_knowledge_set_by_name_and_project(
self, name: str, project_id: str
) -> Optional[Dict[str, Any]]:
if not name: if not name:
return None return None
items = await self._list_knowledge_sets(name) items = await self._list_knowledge_sets(name)
@@ -168,7 +206,9 @@ class KnowledgeSyncService:
return item return item
return None return None
async def _create_knowledge_set(self, name: str, metadata: str) -> Optional[Dict[str, Any]]: async def _create_knowledge_set(
self, name: str, metadata: str
) -> Optional[Dict[str, Any]]:
payload = { payload = {
"name": name, "name": name,
"description": "标注项目自动创建的知识集", "description": "标注项目自动创建的知识集",
@@ -176,7 +216,9 @@ class KnowledgeSyncService:
"metadata": metadata, "metadata": metadata,
} }
try: try:
return await self._request("POST", "/data-management/knowledge-sets", json=payload) return await self._request(
"POST", "/data-management/knowledge-sets", json=payload
)
except httpx.HTTPStatusError as exc: except httpx.HTTPStatusError as exc:
logger.warning( logger.warning(
"创建知识集失败:name=%s status=%s detail=%s", "创建知识集失败:name=%s status=%s detail=%s",
@@ -199,7 +241,9 @@ class KnowledgeSyncService:
"sourceFileId": file_id, "sourceFileId": file_id,
} }
try: try:
data = await self._request("GET", f"/data-management/knowledge-sets/{set_id}/items", params=params) data = await self._request(
"GET", f"/data-management/knowledge-sets/{set_id}/items", params=params
)
except httpx.HTTPStatusError as exc: except httpx.HTTPStatusError as exc:
logger.warning( logger.warning(
"查询知识条目失败:set_id=%s status=%s", "查询知识条目失败:set_id=%s status=%s",
@@ -216,9 +260,13 @@ class KnowledgeSyncService:
return content[0] return content[0]
async def _create_item(self, set_id: str, payload: Dict[str, Any]) -> None: async def _create_item(self, set_id: str, payload: Dict[str, Any]) -> None:
await self._request("POST", f"/data-management/knowledge-sets/{set_id}/items", json=payload) await self._request(
"POST", f"/data-management/knowledge-sets/{set_id}/items", json=payload
)
async def _update_item(self, set_id: str, item_id: str, payload: Dict[str, Any]) -> None: async def _update_item(
self, set_id: str, item_id: str, payload: Dict[str, Any]
) -> None:
update_payload = dict(payload) update_payload = dict(payload)
update_payload.pop("sourceDatasetId", None) update_payload.pop("sourceDatasetId", None)
update_payload.pop("sourceFileId", None) update_payload.pop("sourceFileId", None)
@@ -228,6 +276,62 @@ class KnowledgeSyncService:
json=update_payload, json=update_payload,
) )
async def _cleanup_knowledge_set_for_project(self, project_id: str) -> None:
"""清理项目关联的知识集及其所有知识条目"""
items = await self._list_knowledge_sets(None)
for item in items:
if self._metadata_matches_project(item.get("metadata"), project_id):
set_id = item.get("id")
if not set_id:
continue
try:
await self._request(
"DELETE", f"/data-management/knowledge-sets/{set_id}"
)
logger.info(
"已删除知识集:set_id=%s project_id=%s", set_id, project_id
)
except Exception as exc:
logger.warning(
"删除知识集失败:set_id=%s project_id=%s error=%s",
set_id,
project_id,
exc,
)
async def _cleanup_knowledge_item_for_file(
self, dataset_id: str, file_id: str
) -> None:
"""清理文件的知识条目"""
items = await self._list_knowledge_sets(None)
for set_item in items:
set_id = set_item.get("id")
if not set_id:
continue
item = await self._get_item_by_source(set_id, dataset_id, file_id)
if item and item.get("id"):
try:
await self._request(
"DELETE",
f"/data-management/knowledge-sets/{set_id}/items/{item['id']}",
)
logger.info(
"已删除知识条目:item_id=%s set_id=%s dataset_id=%s file_id=%s",
item.get("id"),
set_id,
dataset_id,
file_id,
)
except Exception as exc:
logger.warning(
"删除知识条目失败:item_id=%s set_id=%s dataset_id=%s file_id=%s error=%s",
item.get("id"),
set_id,
dataset_id,
file_id,
exc,
)
async def _build_item_payload( async def _build_item_payload(
self, self,
project: LabelingProject, project: LabelingProject,
@@ -323,12 +427,28 @@ class KnowledgeSyncService:
short_id = project_id.replace("-", "")[:8] short_id = project_id.replace("-", "")[:8]
return f"{base_name}-annotation-{short_id}" return f"{base_name}-annotation-{short_id}"
async def _update_project_config(self, project: LabelingProject, updates: Dict[str, Any]) -> None: async def _update_project_config(
config = project.configuration if isinstance(project.configuration, dict) else {} self, project: LabelingProject, updates: Dict[str, Any]
) -> None:
result = await self.db.execute(
select(LabelingProject)
.where(LabelingProject.id == project.id)
.with_for_update()
)
locked_project = result.scalar_one_or_none()
if not locked_project:
logger.warning("更新项目配置失败:无法锁定项目:project_id=%s", project.id)
return
config = (
locked_project.configuration
if isinstance(locked_project.configuration, dict)
else {}
)
config.update(updates) config.update(updates)
project.configuration = config locked_project.configuration = config
await self.db.commit() await self.db.commit()
await self.db.refresh(project) await self.db.refresh(locked_project)
async def _request(self, method: str, path: str, **kwargs) -> Any: async def _request(self, method: str, path: str, **kwargs) -> Any:
url = f"{self.base_url}{path}" url = f"{self.base_url}{path}"