You've already forked DataMate
fix: 修复知识库同步的并发控制、数据清理、文件事务和COCO导出问题
问题1 - 并发控制缺失:
- 在 _ensure_knowledge_set 方法中添加数据库行锁(with_for_update)
- 修改 _update_project_config 方法,使用行锁保护配置更新

问题3 - 数据清理机制缺失:
- 添加 _cleanup_knowledge_set_for_project 方法,项目删除时清理知识集
- 添加 _cleanup_knowledge_item_for_file 方法,文件删除时清理知识条目
- 在 delete_mapping 接口中调用清理方法

问题4 - 文件操作事务问题:
- 修改 uploadKnowledgeItems,添加事务失败后的文件清理逻辑
- 修改 deleteKnowledgeItem,删除记录前先删除关联文件
- 新增 deleteKnowledgeItemFile 辅助方法

问题5 - COCO导出格式问题:
- 添加 _get_image_dimensions 方法读取图片实际宽高
- 将百分比坐标转换为像素坐标
- 在 AnnotationExportItem 中添加 file_path 字段

涉及文件:
- knowledge_sync.py
- project.py
- KnowledgeItemApplicationService.java
- export.py
- export schema.py
This commit is contained in:
@@ -126,41 +126,53 @@ public class KnowledgeItemApplicationService {
|
|||||||
createDirectories(setDir);
|
createDirectories(setDir);
|
||||||
|
|
||||||
List<KnowledgeItem> items = new ArrayList<>();
|
List<KnowledgeItem> items = new ArrayList<>();
|
||||||
|
List<Path> savedFilePaths = new ArrayList<>();
|
||||||
|
|
||||||
for (MultipartFile file : files) {
|
try {
|
||||||
BusinessAssert.notNull(file, CommonErrorCode.PARAM_ERROR);
|
for (MultipartFile file : files) {
|
||||||
BusinessAssert.isTrue(!file.isEmpty(), CommonErrorCode.PARAM_ERROR);
|
BusinessAssert.notNull(file, CommonErrorCode.PARAM_ERROR);
|
||||||
|
BusinessAssert.isTrue(!file.isEmpty(), CommonErrorCode.PARAM_ERROR);
|
||||||
|
|
||||||
String originalName = resolveOriginalFileName(file);
|
String originalName = resolveOriginalFileName(file);
|
||||||
String safeOriginalName = sanitizeFileName(originalName);
|
String safeOriginalName = sanitizeFileName(originalName);
|
||||||
if (StringUtils.isBlank(safeOriginalName)) {
|
if (StringUtils.isBlank(safeOriginalName)) {
|
||||||
safeOriginalName = "file";
|
safeOriginalName = "file";
|
||||||
|
}
|
||||||
|
|
||||||
|
String extension = getFileExtension(safeOriginalName);
|
||||||
|
String storedName = UUID.randomUUID().toString() +
|
||||||
|
(StringUtils.isBlank(extension) ? "" : "." + extension);
|
||||||
|
Path targetPath = setDir.resolve(storedName).normalize();
|
||||||
|
BusinessAssert.isTrue(targetPath.startsWith(setDir), CommonErrorCode.PARAM_ERROR);
|
||||||
|
|
||||||
|
saveMultipartFile(file, targetPath);
|
||||||
|
savedFilePaths.add(targetPath);
|
||||||
|
|
||||||
|
KnowledgeItem knowledgeItem = new KnowledgeItem();
|
||||||
|
knowledgeItem.setId(UUID.randomUUID().toString());
|
||||||
|
knowledgeItem.setSetId(setId);
|
||||||
|
knowledgeItem.setContent(buildRelativeFilePath(setId, storedName));
|
||||||
|
knowledgeItem.setContentType(KnowledgeContentType.FILE);
|
||||||
|
knowledgeItem.setSourceType(KnowledgeSourceType.FILE_UPLOAD);
|
||||||
|
knowledgeItem.setSourceFileId(trimToLength(safeOriginalName, MAX_TITLE_LENGTH));
|
||||||
|
knowledgeItem.setRelativePath(buildRelativePath(parentPrefix, safeOriginalName));
|
||||||
|
|
||||||
|
items.add(knowledgeItem);
|
||||||
}
|
}
|
||||||
|
|
||||||
String extension = getFileExtension(safeOriginalName);
|
if (CollectionUtils.isNotEmpty(items)) {
|
||||||
String storedName = UUID.randomUUID().toString() +
|
knowledgeItemRepository.saveBatch(items, items.size());
|
||||||
(StringUtils.isBlank(extension) ? "" : "." + extension);
|
}
|
||||||
Path targetPath = setDir.resolve(storedName).normalize();
|
return items;
|
||||||
BusinessAssert.isTrue(targetPath.startsWith(setDir), CommonErrorCode.PARAM_ERROR);
|
} catch (Exception e) {
|
||||||
|
for (Path filePath : savedFilePaths) {
|
||||||
saveMultipartFile(file, targetPath);
|
deleteFileQuietly(filePath);
|
||||||
|
}
|
||||||
KnowledgeItem knowledgeItem = new KnowledgeItem();
|
if (e instanceof BusinessException) {
|
||||||
knowledgeItem.setId(UUID.randomUUID().toString());
|
throw (BusinessException) e;
|
||||||
knowledgeItem.setSetId(setId);
|
}
|
||||||
knowledgeItem.setContent(buildRelativeFilePath(setId, storedName));
|
throw BusinessException.of(SystemErrorCode.FILE_SYSTEM_ERROR);
|
||||||
knowledgeItem.setContentType(KnowledgeContentType.FILE);
|
|
||||||
knowledgeItem.setSourceType(KnowledgeSourceType.FILE_UPLOAD);
|
|
||||||
knowledgeItem.setSourceFileId(trimToLength(safeOriginalName, MAX_TITLE_LENGTH));
|
|
||||||
knowledgeItem.setRelativePath(buildRelativePath(parentPrefix, safeOriginalName));
|
|
||||||
|
|
||||||
items.add(knowledgeItem);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if (CollectionUtils.isNotEmpty(items)) {
|
|
||||||
knowledgeItemRepository.saveBatch(items, items.size());
|
|
||||||
}
|
|
||||||
return items;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public KnowledgeItem updateKnowledgeItem(String setId, String itemId, UpdateKnowledgeItemRequest request) {
|
public KnowledgeItem updateKnowledgeItem(String setId, String itemId, UpdateKnowledgeItemRequest request) {
|
||||||
@@ -190,6 +202,9 @@ public class KnowledgeItemApplicationService {
|
|||||||
KnowledgeItem knowledgeItem = knowledgeItemRepository.getById(itemId);
|
KnowledgeItem knowledgeItem = knowledgeItemRepository.getById(itemId);
|
||||||
BusinessAssert.notNull(knowledgeItem, DataManagementErrorCode.KNOWLEDGE_ITEM_NOT_FOUND);
|
BusinessAssert.notNull(knowledgeItem, DataManagementErrorCode.KNOWLEDGE_ITEM_NOT_FOUND);
|
||||||
BusinessAssert.isTrue(Objects.equals(knowledgeItem.getSetId(), setId), CommonErrorCode.PARAM_ERROR);
|
BusinessAssert.isTrue(Objects.equals(knowledgeItem.getSetId(), setId), CommonErrorCode.PARAM_ERROR);
|
||||||
|
|
||||||
|
deleteKnowledgeItemFile(knowledgeItem);
|
||||||
|
knowledgeItemPreviewService.deletePreviewFileQuietly(setId, itemId);
|
||||||
knowledgeItemRepository.removeById(itemId);
|
knowledgeItemRepository.removeById(itemId);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -205,6 +220,11 @@ public class KnowledgeItemApplicationService {
|
|||||||
boolean allMatch = items.stream().allMatch(item -> Objects.equals(item.getSetId(), setId));
|
boolean allMatch = items.stream().allMatch(item -> Objects.equals(item.getSetId(), setId));
|
||||||
BusinessAssert.isTrue(allMatch, CommonErrorCode.PARAM_ERROR);
|
BusinessAssert.isTrue(allMatch, CommonErrorCode.PARAM_ERROR);
|
||||||
|
|
||||||
|
for (KnowledgeItem item : items) {
|
||||||
|
deleteKnowledgeItemFile(item);
|
||||||
|
knowledgeItemPreviewService.deletePreviewFileQuietly(setId, item.getId());
|
||||||
|
}
|
||||||
|
|
||||||
List<String> deleteIds = items.stream().map(KnowledgeItem::getId).toList();
|
List<String> deleteIds = items.stream().map(KnowledgeItem::getId).toList();
|
||||||
knowledgeItemRepository.removeByIds(deleteIds);
|
knowledgeItemRepository.removeByIds(deleteIds);
|
||||||
}
|
}
|
||||||
@@ -785,6 +805,24 @@ public class KnowledgeItemApplicationService {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private void deleteKnowledgeItemFile(KnowledgeItem knowledgeItem) {
|
||||||
|
if (knowledgeItem == null) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
if (knowledgeItem.getSourceType() == KnowledgeSourceType.FILE_UPLOAD
|
||||||
|
|| knowledgeItem.getContentType() == KnowledgeContentType.FILE) {
|
||||||
|
String relativePath = knowledgeItem.getContent();
|
||||||
|
if (StringUtils.isNotBlank(relativePath)) {
|
||||||
|
try {
|
||||||
|
Path filePath = resolveKnowledgeItemStoragePath(relativePath);
|
||||||
|
deleteFileQuietly(filePath);
|
||||||
|
} catch (Exception e) {
|
||||||
|
log.warn("delete knowledge item file error, itemId: {}, path: {}", knowledgeItem.getId(), relativePath, e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
private String resolveOriginalFileName(MultipartFile file) {
|
private String resolveOriginalFileName(MultipartFile file) {
|
||||||
String originalName = file.getOriginalFilename();
|
String originalName = file.getOriginalFilename();
|
||||||
if (StringUtils.isBlank(originalName)) {
|
if (StringUtils.isBlank(originalName)) {
|
||||||
|
|||||||
@@ -14,6 +14,7 @@ from app.core.logging import get_logger
|
|||||||
|
|
||||||
from ..service.mapping import DatasetMappingService
|
from ..service.mapping import DatasetMappingService
|
||||||
from ..service.template import AnnotationTemplateService
|
from ..service.template import AnnotationTemplateService
|
||||||
|
from ..service.knowledge_sync import KnowledgeSyncService
|
||||||
from ..schema import (
|
from ..schema import (
|
||||||
DatasetMappingCreateRequest,
|
DatasetMappingCreateRequest,
|
||||||
DatasetMappingCreateResponse,
|
DatasetMappingCreateResponse,
|
||||||
@@ -22,26 +23,26 @@ from ..schema import (
|
|||||||
DatasetMappingResponse,
|
DatasetMappingResponse,
|
||||||
)
|
)
|
||||||
|
|
||||||
router = APIRouter(
|
router = APIRouter(prefix="/project", tags=["annotation/project"])
|
||||||
prefix="/project",
|
|
||||||
tags=["annotation/project"]
|
|
||||||
)
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
TEXT_DATASET_TYPE = "TEXT"
|
TEXT_DATASET_TYPE = "TEXT"
|
||||||
SOURCE_DOCUMENT_FILE_TYPES = {"pdf", "doc", "docx", "xls", "xlsx"}
|
SOURCE_DOCUMENT_FILE_TYPES = {"pdf", "doc", "docx", "xls", "xlsx"}
|
||||||
LABELING_TYPE_CONFIG_KEY = "labeling_type"
|
LABELING_TYPE_CONFIG_KEY = "labeling_type"
|
||||||
|
|
||||||
@router.get("/{mapping_id}/login")
|
|
||||||
async def login_label_studio(
|
|
||||||
mapping_id: str,
|
|
||||||
db: AsyncSession = Depends(get_db)
|
|
||||||
):
|
|
||||||
raise HTTPException(status_code=410, detail="当前为内嵌编辑器模式,不再支持 Label Studio 登录代理接口")
|
|
||||||
|
|
||||||
@router.post("", response_model=StandardResponse[DatasetMappingCreateResponse], status_code=201)
|
@router.get("/{mapping_id}/login")
|
||||||
|
async def login_label_studio(mapping_id: str, db: AsyncSession = Depends(get_db)):
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=410,
|
||||||
|
detail="当前为内嵌编辑器模式,不再支持 Label Studio 登录代理接口",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post(
|
||||||
|
"", response_model=StandardResponse[DatasetMappingCreateResponse], status_code=201
|
||||||
|
)
|
||||||
async def create_mapping(
|
async def create_mapping(
|
||||||
request: DatasetMappingCreateRequest,
|
request: DatasetMappingCreateRequest, db: AsyncSession = Depends(get_db)
|
||||||
db: AsyncSession = Depends(get_db)
|
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
创建数据集映射
|
创建数据集映射
|
||||||
@@ -64,7 +65,7 @@ async def create_mapping(
|
|||||||
if not dataset_info:
|
if not dataset_info:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=404,
|
status_code=404,
|
||||||
detail=f"Dataset not found in DM service: {request.dataset_id}"
|
detail=f"Dataset not found in DM service: {request.dataset_id}",
|
||||||
)
|
)
|
||||||
|
|
||||||
dataset_type = (
|
dataset_type = (
|
||||||
@@ -73,13 +74,15 @@ async def create_mapping(
|
|||||||
or ""
|
or ""
|
||||||
).upper()
|
).upper()
|
||||||
|
|
||||||
project_name = request.name or \
|
project_name = (
|
||||||
dataset_info.name or \
|
request.name or dataset_info.name or "A new project from DataMate"
|
||||||
"A new project from DataMate"
|
)
|
||||||
|
|
||||||
project_description = request.description or \
|
project_description = (
|
||||||
dataset_info.description or \
|
request.description
|
||||||
f"Imported from DM dataset {dataset_info.name} ({dataset_info.id})"
|
or dataset_info.description
|
||||||
|
or f"Imported from DM dataset {dataset_info.name} ({dataset_info.id})"
|
||||||
|
)
|
||||||
|
|
||||||
# 如果提供了模板ID,获取模板配置
|
# 如果提供了模板ID,获取模板配置
|
||||||
label_config = None
|
label_config = None
|
||||||
@@ -89,8 +92,7 @@ async def create_mapping(
|
|||||||
template = await template_service.get_template(db, request.template_id)
|
template = await template_service.get_template(db, request.template_id)
|
||||||
if not template:
|
if not template:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=404,
|
status_code=404, detail=f"Template not found: {request.template_id}"
|
||||||
detail=f"Template not found: {request.template_id}"
|
|
||||||
)
|
)
|
||||||
label_config = template.label_config
|
label_config = template.label_config
|
||||||
template_labeling_type = getattr(template, "labeling_type", None)
|
template_labeling_type = getattr(template, "labeling_type", None)
|
||||||
@@ -110,19 +112,24 @@ async def create_mapping(
|
|||||||
project_configuration["label_config"] = label_config
|
project_configuration["label_config"] = label_config
|
||||||
if project_description:
|
if project_description:
|
||||||
project_configuration["description"] = project_description
|
project_configuration["description"] = project_description
|
||||||
if dataset_type == TEXT_DATASET_TYPE and request.segmentation_enabled is not None:
|
if (
|
||||||
project_configuration["segmentation_enabled"] = bool(request.segmentation_enabled)
|
dataset_type == TEXT_DATASET_TYPE
|
||||||
|
and request.segmentation_enabled is not None
|
||||||
|
):
|
||||||
|
project_configuration["segmentation_enabled"] = bool(
|
||||||
|
request.segmentation_enabled
|
||||||
|
)
|
||||||
if template_labeling_type:
|
if template_labeling_type:
|
||||||
project_configuration[LABELING_TYPE_CONFIG_KEY] = template_labeling_type
|
project_configuration[LABELING_TYPE_CONFIG_KEY] = template_labeling_type
|
||||||
|
|
||||||
labeling_project = LabelingProject(
|
labeling_project = LabelingProject(
|
||||||
id=str(uuid.uuid4()), # Generate UUID here
|
id=str(uuid.uuid4()), # Generate UUID here
|
||||||
dataset_id=request.dataset_id,
|
dataset_id=request.dataset_id,
|
||||||
labeling_project_id=labeling_project_id,
|
labeling_project_id=labeling_project_id,
|
||||||
name=project_name,
|
name=project_name,
|
||||||
template_id=request.template_id, # Save template_id to database
|
template_id=request.template_id, # Save template_id to database
|
||||||
configuration=project_configuration or None,
|
configuration=project_configuration or None,
|
||||||
)
|
)
|
||||||
|
|
||||||
file_result = await db.execute(
|
file_result = await db.execute(
|
||||||
select(DatasetFiles).where(
|
select(DatasetFiles).where(
|
||||||
@@ -143,9 +150,7 @@ async def create_mapping(
|
|||||||
snapshot_file_ids.append(str(file_record.id))
|
snapshot_file_ids.append(str(file_record.id))
|
||||||
else:
|
else:
|
||||||
snapshot_file_ids = [
|
snapshot_file_ids = [
|
||||||
str(file_record.id)
|
str(file_record.id) for file_record in file_records if file_record.id
|
||||||
for file_record in file_records
|
|
||||||
if file_record.id
|
|
||||||
]
|
]
|
||||||
|
|
||||||
# 创建映射关系并写入快照
|
# 创建映射关系并写入快照
|
||||||
@@ -157,25 +162,30 @@ async def create_mapping(
|
|||||||
if dataset_type == TEXT_DATASET_TYPE and request.segmentation_enabled:
|
if dataset_type == TEXT_DATASET_TYPE and request.segmentation_enabled:
|
||||||
try:
|
try:
|
||||||
from ..service.editor import AnnotationEditorService
|
from ..service.editor import AnnotationEditorService
|
||||||
|
|
||||||
editor_service = AnnotationEditorService(db)
|
editor_service = AnnotationEditorService(db)
|
||||||
# 异步预计算切片(不阻塞创建响应)
|
# 异步预计算切片(不阻塞创建响应)
|
||||||
segmentation_result = await editor_service.precompute_segmentation_for_project(labeling_project.id)
|
segmentation_result = (
|
||||||
logger.info(f"Precomputed segmentation for project {labeling_project.id}: {segmentation_result}")
|
await editor_service.precompute_segmentation_for_project(
|
||||||
|
labeling_project.id
|
||||||
|
)
|
||||||
|
)
|
||||||
|
logger.info(
|
||||||
|
f"Precomputed segmentation for project {labeling_project.id}: {segmentation_result}"
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Failed to precompute segmentation for project {labeling_project.id}: {e}")
|
logger.warning(
|
||||||
|
f"Failed to precompute segmentation for project {labeling_project.id}: {e}"
|
||||||
|
)
|
||||||
# 不影响项目创建,只记录警告
|
# 不影响项目创建,只记录警告
|
||||||
|
|
||||||
response_data = DatasetMappingCreateResponse(
|
response_data = DatasetMappingCreateResponse(
|
||||||
id=mapping.id,
|
id=mapping.id,
|
||||||
labeling_project_id=str(mapping.labeling_project_id),
|
labeling_project_id=str(mapping.labeling_project_id),
|
||||||
labeling_project_name=mapping.name or project_name
|
labeling_project_name=mapping.name or project_name,
|
||||||
)
|
)
|
||||||
|
|
||||||
return StandardResponse(
|
return StandardResponse(code=201, message="success", data=response_data)
|
||||||
code=201,
|
|
||||||
message="success",
|
|
||||||
data=response_data
|
|
||||||
)
|
|
||||||
|
|
||||||
except HTTPException:
|
except HTTPException:
|
||||||
raise
|
raise
|
||||||
@@ -183,12 +193,15 @@ async def create_mapping(
|
|||||||
logger.error(f"Error while creating dataset mapping: {e}")
|
logger.error(f"Error while creating dataset mapping: {e}")
|
||||||
raise HTTPException(status_code=500, detail="Internal server error")
|
raise HTTPException(status_code=500, detail="Internal server error")
|
||||||
|
|
||||||
|
|
||||||
@router.get("", response_model=StandardResponse[PaginatedData[DatasetMappingResponse]])
|
@router.get("", response_model=StandardResponse[PaginatedData[DatasetMappingResponse]])
|
||||||
async def list_mappings(
|
async def list_mappings(
|
||||||
page: int = Query(1, ge=1, description="页码(从1开始)"),
|
page: int = Query(1, ge=1, description="页码(从1开始)"),
|
||||||
size: int = Query(20, ge=1, le=100, description="每页记录数"),
|
size: int = Query(20, ge=1, le=100, description="每页记录数"),
|
||||||
include_template: bool = Query(False, description="是否包含模板详情", alias="includeTemplate"),
|
include_template: bool = Query(
|
||||||
db: AsyncSession = Depends(get_db)
|
False, description="是否包含模板详情", alias="includeTemplate"
|
||||||
|
),
|
||||||
|
db: AsyncSession = Depends(get_db),
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
查询所有映射关系(分页)
|
查询所有映射关系(分页)
|
||||||
@@ -207,14 +220,16 @@ async def list_mappings(
|
|||||||
# 计算 skip
|
# 计算 skip
|
||||||
skip = (page - 1) * size
|
skip = (page - 1) * size
|
||||||
|
|
||||||
logger.info(f"List mappings: page={page}, size={size}, include_template={include_template}")
|
logger.info(
|
||||||
|
f"List mappings: page={page}, size={size}, include_template={include_template}"
|
||||||
|
)
|
||||||
|
|
||||||
# 获取数据和总数
|
# 获取数据和总数
|
||||||
mappings, total = await service.get_all_mappings_with_count(
|
mappings, total = await service.get_all_mappings_with_count(
|
||||||
skip=skip,
|
skip=skip,
|
||||||
limit=size,
|
limit=size,
|
||||||
include_deleted=False,
|
include_deleted=False,
|
||||||
include_template=include_template
|
include_template=include_template,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 计算总页数
|
# 计算总页数
|
||||||
@@ -226,26 +241,22 @@ async def list_mappings(
|
|||||||
size=size,
|
size=size,
|
||||||
total_elements=total,
|
total_elements=total,
|
||||||
total_pages=total_pages,
|
total_pages=total_pages,
|
||||||
content=mappings
|
content=mappings,
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info(f"List mappings: page={page}, returned {len(mappings)}/{total}, templates_included: {include_template}")
|
logger.info(
|
||||||
|
f"List mappings: page={page}, returned {len(mappings)}/{total}, templates_included: {include_template}"
|
||||||
return StandardResponse(
|
|
||||||
code=200,
|
|
||||||
message="success",
|
|
||||||
data=paginated_data
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
return StandardResponse(code=200, message="success", data=paginated_data)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error listing mappings: {e}")
|
logger.error(f"Error listing mappings: {e}")
|
||||||
raise HTTPException(status_code=500, detail="Internal server error")
|
raise HTTPException(status_code=500, detail="Internal server error")
|
||||||
|
|
||||||
|
|
||||||
@router.get("/{mapping_id}", response_model=StandardResponse[DatasetMappingResponse])
|
@router.get("/{mapping_id}", response_model=StandardResponse[DatasetMappingResponse])
|
||||||
async def get_mapping(
|
async def get_mapping(mapping_id: str, db: AsyncSession = Depends(get_db)):
|
||||||
mapping_id: str,
|
|
||||||
db: AsyncSession = Depends(get_db)
|
|
||||||
):
|
|
||||||
"""
|
"""
|
||||||
根据 UUID 查询单个映射关系(包含关联的标注模板详情)
|
根据 UUID 查询单个映射关系(包含关联的标注模板详情)
|
||||||
|
|
||||||
@@ -265,31 +276,34 @@ async def get_mapping(
|
|||||||
|
|
||||||
if not mapping:
|
if not mapping:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=404,
|
status_code=404, detail=f"Mapping not found: {mapping_id}"
|
||||||
detail=f"Mapping not found: {mapping_id}"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info(f"Found mapping: {mapping.id}, template_included: {mapping.template is not None}")
|
logger.info(
|
||||||
|
f"Found mapping: {mapping.id}, template_included: {mapping.template is not None}"
|
||||||
return StandardResponse(
|
|
||||||
code=200,
|
|
||||||
message="success",
|
|
||||||
data=mapping
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
return StandardResponse(code=200, message="success", data=mapping)
|
||||||
|
|
||||||
except HTTPException:
|
except HTTPException:
|
||||||
raise
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error getting mapping: {e}")
|
logger.error(f"Error getting mapping: {e}")
|
||||||
raise HTTPException(status_code=500, detail="Internal server error")
|
raise HTTPException(status_code=500, detail="Internal server error")
|
||||||
|
|
||||||
@router.get("/by-source/{dataset_id}", response_model=StandardResponse[PaginatedData[DatasetMappingResponse]])
|
|
||||||
|
@router.get(
|
||||||
|
"/by-source/{dataset_id}",
|
||||||
|
response_model=StandardResponse[PaginatedData[DatasetMappingResponse]],
|
||||||
|
)
|
||||||
async def get_mappings_by_source(
|
async def get_mappings_by_source(
|
||||||
dataset_id: str,
|
dataset_id: str,
|
||||||
page: int = Query(1, ge=1, description="页码(从1开始)"),
|
page: int = Query(1, ge=1, description="页码(从1开始)"),
|
||||||
size: int = Query(20, ge=1, le=100, description="每页记录数"),
|
size: int = Query(20, ge=1, le=100, description="每页记录数"),
|
||||||
include_template: bool = Query(True, description="是否包含模板详情", alias="includeTemplate"),
|
include_template: bool = Query(
|
||||||
db: AsyncSession = Depends(get_db)
|
True, description="是否包含模板详情", alias="includeTemplate"
|
||||||
|
),
|
||||||
|
db: AsyncSession = Depends(get_db),
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
根据源数据集 ID 查询所有映射关系(分页,包含模板详情)
|
根据源数据集 ID 查询所有映射关系(分页,包含模板详情)
|
||||||
@@ -309,14 +323,16 @@ async def get_mappings_by_source(
|
|||||||
# 计算 skip
|
# 计算 skip
|
||||||
skip = (page - 1) * size
|
skip = (page - 1) * size
|
||||||
|
|
||||||
logger.info(f"Get mappings by source dataset id: {dataset_id}, page={page}, size={size}, include_template={include_template}")
|
logger.info(
|
||||||
|
f"Get mappings by source dataset id: {dataset_id}, page={page}, size={size}, include_template={include_template}"
|
||||||
|
)
|
||||||
|
|
||||||
# 获取数据和总数(包含模板信息)
|
# 获取数据和总数(包含模板信息)
|
||||||
mappings, total = await service.get_mappings_by_source_with_count(
|
mappings, total = await service.get_mappings_by_source_with_count(
|
||||||
dataset_id=dataset_id,
|
dataset_id=dataset_id,
|
||||||
skip=skip,
|
skip=skip,
|
||||||
limit=size,
|
limit=size,
|
||||||
include_template=include_template
|
include_template=include_template,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 计算总页数
|
# 计算总页数
|
||||||
@@ -328,27 +344,26 @@ async def get_mappings_by_source(
|
|||||||
size=size,
|
size=size,
|
||||||
total_elements=total,
|
total_elements=total,
|
||||||
total_pages=total_pages,
|
total_pages=total_pages,
|
||||||
content=mappings
|
content=mappings,
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info(f"Found {len(mappings)} mappings on page {page}, total: {total}, templates_included: {include_template}")
|
logger.info(
|
||||||
|
f"Found {len(mappings)} mappings on page {page}, total: {total}, templates_included: {include_template}"
|
||||||
return StandardResponse(
|
|
||||||
code=200,
|
|
||||||
message="success",
|
|
||||||
data=paginated_data
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
return StandardResponse(code=200, message="success", data=paginated_data)
|
||||||
|
|
||||||
except HTTPException:
|
except HTTPException:
|
||||||
raise
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error getting mappings: {e}")
|
logger.error(f"Error getting mappings: {e}")
|
||||||
raise HTTPException(status_code=500, detail="Internal server error")
|
raise HTTPException(status_code=500, detail="Internal server error")
|
||||||
|
|
||||||
|
|
||||||
@router.delete("/{project_id}", response_model=StandardResponse[DeleteDatasetResponse])
|
@router.delete("/{project_id}", response_model=StandardResponse[DeleteDatasetResponse])
|
||||||
async def delete_mapping(
|
async def delete_mapping(
|
||||||
project_id: str = Path(..., description="映射UUID(path param)"),
|
project_id: str = Path(..., description="映射UUID(path param)"),
|
||||||
db: AsyncSession = Depends(get_db)
|
db: AsyncSession = Depends(get_db),
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
删除映射关系(软删除)
|
删除映射关系(软删除)
|
||||||
@@ -370,12 +385,12 @@ async def delete_mapping(
|
|||||||
|
|
||||||
if not mapping:
|
if not mapping:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=404,
|
status_code=404, detail=f"Mapping either not found or not specified."
|
||||||
detail=f"Mapping either not found or not specified."
|
|
||||||
)
|
)
|
||||||
|
|
||||||
id = mapping.id
|
id = mapping.id
|
||||||
logger.debug(f"Found mapping: {id}")
|
dataset_id = mapping.dataset_id
|
||||||
|
logger.debug(f"Found mapping: {id}, dataset_id: {dataset_id}")
|
||||||
|
|
||||||
# 软删除映射记录
|
# 软删除映射记录
|
||||||
soft_delete_success = await service.soft_delete_mapping(id)
|
soft_delete_success = await service.soft_delete_mapping(id)
|
||||||
@@ -383,19 +398,22 @@ async def delete_mapping(
|
|||||||
|
|
||||||
if not soft_delete_success:
|
if not soft_delete_success:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=500,
|
status_code=500, detail="Failed to delete mapping record"
|
||||||
detail="Failed to delete mapping record"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# 清理知识集中的关联数据
|
||||||
|
try:
|
||||||
|
knowledge_sync = KnowledgeSyncService(db)
|
||||||
|
await knowledge_sync._cleanup_knowledge_set_for_project(id)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning(f"清理知识集失败:project_id={id} error={exc}")
|
||||||
|
|
||||||
logger.info(f"Successfully deleted mapping: {id}")
|
logger.info(f"Successfully deleted mapping: {id}")
|
||||||
|
|
||||||
return StandardResponse(
|
return StandardResponse(
|
||||||
code=200,
|
code=200,
|
||||||
message="success",
|
message="success",
|
||||||
data=DeleteDatasetResponse(
|
data=DeleteDatasetResponse(id=id, status="success"),
|
||||||
id=id,
|
|
||||||
status="success"
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
except HTTPException:
|
except HTTPException:
|
||||||
@@ -409,7 +427,7 @@ async def delete_mapping(
|
|||||||
async def update_mapping(
|
async def update_mapping(
|
||||||
project_id: str = Path(..., description="映射UUID(path param)"),
|
project_id: str = Path(..., description="映射UUID(path param)"),
|
||||||
request: DatasetMappingUpdateRequest = None,
|
request: DatasetMappingUpdateRequest = None,
|
||||||
db: AsyncSession = Depends(get_db)
|
db: AsyncSession = Depends(get_db),
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
更新标注项目信息
|
更新标注项目信息
|
||||||
@@ -429,16 +447,14 @@ async def update_mapping(
|
|||||||
# 直接查询 ORM 模型获取原始数据
|
# 直接查询 ORM 模型获取原始数据
|
||||||
result = await db.execute(
|
result = await db.execute(
|
||||||
select(LabelingProject).where(
|
select(LabelingProject).where(
|
||||||
LabelingProject.id == project_id,
|
LabelingProject.id == project_id, LabelingProject.deleted_at.is_(None)
|
||||||
LabelingProject.deleted_at.is_(None)
|
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
mapping_orm = result.scalar_one_or_none()
|
mapping_orm = result.scalar_one_or_none()
|
||||||
|
|
||||||
if not mapping_orm:
|
if not mapping_orm:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=404,
|
status_code=404, detail=f"Mapping not found: {project_id}"
|
||||||
detail=f"Mapping not found: {project_id}"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# 构建更新数据
|
# 构建更新数据
|
||||||
@@ -449,7 +465,11 @@ async def update_mapping(
|
|||||||
# 从 configuration 字段中读取和更新 description 和 label_config
|
# 从 configuration 字段中读取和更新 description 和 label_config
|
||||||
configuration = {}
|
configuration = {}
|
||||||
if mapping_orm.configuration:
|
if mapping_orm.configuration:
|
||||||
configuration = mapping_orm.configuration.copy() if isinstance(mapping_orm.configuration, dict) else {}
|
configuration = (
|
||||||
|
mapping_orm.configuration.copy()
|
||||||
|
if isinstance(mapping_orm.configuration, dict)
|
||||||
|
else {}
|
||||||
|
)
|
||||||
|
|
||||||
if request.description is not None:
|
if request.description is not None:
|
||||||
configuration["description"] = request.description
|
configuration["description"] = request.description
|
||||||
@@ -468,7 +488,7 @@ async def update_mapping(
|
|||||||
if not template:
|
if not template:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=404,
|
status_code=404,
|
||||||
detail=f"Template not found: {request.template_id}"
|
detail=f"Template not found: {request.template_id}",
|
||||||
)
|
)
|
||||||
template_labeling_type = getattr(template, "labeling_type", None)
|
template_labeling_type = getattr(template, "labeling_type", None)
|
||||||
if template_labeling_type:
|
if template_labeling_type:
|
||||||
@@ -477,14 +497,11 @@ async def update_mapping(
|
|||||||
if not update_values:
|
if not update_values:
|
||||||
# 没有要更新的字段,直接返回当前数据
|
# 没有要更新的字段,直接返回当前数据
|
||||||
response_data = await service.get_mapping_by_uuid(project_id)
|
response_data = await service.get_mapping_by_uuid(project_id)
|
||||||
return StandardResponse(
|
return StandardResponse(code=200, message="success", data=response_data)
|
||||||
code=200,
|
|
||||||
message="success",
|
|
||||||
data=response_data
|
|
||||||
)
|
|
||||||
|
|
||||||
# 执行更新
|
# 执行更新
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
update_values["updated_at"] = datetime.now()
|
update_values["updated_at"] = datetime.now()
|
||||||
|
|
||||||
result = await db.execute(
|
result = await db.execute(
|
||||||
@@ -495,21 +512,14 @@ async def update_mapping(
|
|||||||
await db.commit()
|
await db.commit()
|
||||||
|
|
||||||
if result.rowcount == 0:
|
if result.rowcount == 0:
|
||||||
raise HTTPException(
|
raise HTTPException(status_code=500, detail="Failed to update mapping")
|
||||||
status_code=500,
|
|
||||||
detail="Failed to update mapping"
|
|
||||||
)
|
|
||||||
|
|
||||||
# 重新获取更新后的数据
|
# 重新获取更新后的数据
|
||||||
updated_mapping = await service.get_mapping_by_uuid(project_id)
|
updated_mapping = await service.get_mapping_by_uuid(project_id)
|
||||||
|
|
||||||
logger.info(f"Successfully updated mapping: {project_id}")
|
logger.info(f"Successfully updated mapping: {project_id}")
|
||||||
|
|
||||||
return StandardResponse(
|
return StandardResponse(code=200, message="success", data=updated_mapping)
|
||||||
code=200,
|
|
||||||
message="success",
|
|
||||||
data=updated_mapping
|
|
||||||
)
|
|
||||||
|
|
||||||
except HTTPException:
|
except HTTPException:
|
||||||
raise
|
raise
|
||||||
|
|||||||
@@ -13,17 +13,21 @@ from pydantic import BaseModel, Field
|
|||||||
|
|
||||||
class ExportFormat(str, Enum):
|
class ExportFormat(str, Enum):
|
||||||
"""导出格式枚举"""
|
"""导出格式枚举"""
|
||||||
JSON = "json" # Label Studio 原生 JSON 格式
|
|
||||||
JSONL = "jsonl" # JSON Lines 格式(每行一条记录)
|
JSON = "json" # Label Studio 原生 JSON 格式
|
||||||
CSV = "csv" # CSV 表格格式
|
JSONL = "jsonl" # JSON Lines 格式(每行一条记录)
|
||||||
COCO = "coco" # COCO 目标检测格式
|
CSV = "csv" # CSV 表格格式
|
||||||
YOLO = "yolo" # YOLO 格式
|
COCO = "coco" # COCO 目标检测格式
|
||||||
|
YOLO = "yolo" # YOLO 格式
|
||||||
|
|
||||||
|
|
||||||
class ExportAnnotationsRequest(BaseModel):
    """Request parameters for exporting annotation data.

    All three fields are optional for the caller; the defaults export
    annotated files only, as JSON, with the raw file data included.
    """

    # Target export format; defaults to JSON.
    format: ExportFormat = Field(default=ExportFormat.JSON, description="导出格式")
    # Whether to embed the original data (e.g. text content) in each record.
    include_data: bool = Field(
        default=True, description="是否包含原始数据(如文本内容)"
    )
    # Whether to restrict the export to files that already have annotations.
    only_annotated: bool = Field(default=True, description="是否只导出已标注的数据")

    # Ask pydantic to store enum fields as their raw values.
    # NOTE(review): consumers should be ready for ``format`` to be either a
    # plain string or an ``ExportFormat`` member — confirm against the
    # pydantic version in use.
    model_config = {"use_enum_values": True}
|
||||||
@@ -31,6 +35,7 @@ class ExportAnnotationsRequest(BaseModel):
|
|||||||
|
|
||||||
class ExportAnnotationsResponse(BaseModel):
|
class ExportAnnotationsResponse(BaseModel):
|
||||||
"""导出标注数据响应(用于预览/统计)"""
|
"""导出标注数据响应(用于预览/统计)"""
|
||||||
|
|
||||||
project_id: str = Field(..., description="项目ID")
|
project_id: str = Field(..., description="项目ID")
|
||||||
project_name: str = Field(..., description="项目名称")
|
project_name: str = Field(..., description="项目名称")
|
||||||
total_files: int = Field(..., description="总文件数")
|
total_files: int = Field(..., description="总文件数")
|
||||||
@@ -42,16 +47,21 @@ class ExportAnnotationsResponse(BaseModel):
|
|||||||
|
|
||||||
class AnnotationExportItem(BaseModel):
    """A single export record: one file together with its annotation results."""

    # Identifier of the file (required).
    file_id: str = Field(..., description="文件ID")
    # File name (required).
    file_name: str = Field(..., description="文件名")
    # Path of the file; optional, None when unknown.
    file_path: Optional[str] = Field(default=None, description="文件路径")
    # Original data payload; None when raw data is not included.
    data: Optional[Dict[str, Any]] = Field(default=None, description="原始数据")
    # Annotation results for this file; a fresh empty list per instance.
    annotations: List[Dict[str, Any]] = Field(
        default_factory=list, description="标注结果"
    )
    # Creation / last-update timestamps; None when not available.
    created_at: Optional[datetime] = Field(default=None, description="创建时间")
    updated_at: Optional[datetime] = Field(default=None, description="更新时间")
|
||||||
|
|
||||||
|
|
||||||
class COCOExportFormat(BaseModel):
|
class COCOExportFormat(BaseModel):
|
||||||
"""COCO 格式导出结构"""
|
"""COCO 格式导出结构"""
|
||||||
|
|
||||||
info: Dict[str, Any] = Field(default_factory=dict)
|
info: Dict[str, Any] = Field(default_factory=dict)
|
||||||
licenses: List[Dict[str, Any]] = Field(default_factory=list)
|
licenses: List[Dict[str, Any]] = Field(default_factory=list)
|
||||||
images: List[Dict[str, Any]] = Field(default_factory=list)
|
images: List[Dict[str, Any]] = Field(default_factory=list)
|
||||||
|
|||||||
@@ -21,14 +21,23 @@ from datetime import datetime
|
|||||||
from typing import Any, Dict, List, Optional, Tuple
|
from typing import Any, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
from fastapi import HTTPException
|
from fastapi import HTTPException
|
||||||
|
from PIL import Image
|
||||||
from sqlalchemy import func, select
|
from sqlalchemy import func, select
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from app.core.logging import get_logger
|
from app.core.logging import get_logger
|
||||||
from app.db.models import AnnotationResult, Dataset, DatasetFiles, LabelingProject, LabelingProjectFile
|
from app.db.models import (
|
||||||
|
AnnotationResult,
|
||||||
|
Dataset,
|
||||||
|
DatasetFiles,
|
||||||
|
LabelingProject,
|
||||||
|
LabelingProjectFile,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
async def _read_file_content(file_path: str, max_size: int = 10 * 1024 * 1024) -> Optional[str]:
|
async def _read_file_content(
|
||||||
|
file_path: str, max_size: int = 10 * 1024 * 1024
|
||||||
|
) -> Optional[str]:
|
||||||
"""读取文件内容,仅适用于文本文件
|
"""读取文件内容,仅适用于文本文件
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -48,11 +57,12 @@ async def _read_file_content(file_path: str, max_size: int = 10 * 1024 * 1024) -
|
|||||||
return f"[File too large: {file_size} bytes]"
|
return f"[File too large: {file_size} bytes]"
|
||||||
|
|
||||||
# 尝试以文本方式读取
|
# 尝试以文本方式读取
|
||||||
with open(file_path, 'r', encoding='utf-8', errors='replace') as f:
|
with open(file_path, "r", encoding="utf-8", errors="replace") as f:
|
||||||
return f.read()
|
return f.read()
|
||||||
except Exception:
|
except Exception:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
from ..schema.export import (
|
from ..schema.export import (
|
||||||
AnnotationExportItem,
|
AnnotationExportItem,
|
||||||
COCOExportFormat,
|
COCOExportFormat,
|
||||||
@@ -79,7 +89,9 @@ class AnnotationExportService:
|
|||||||
async def get_export_stats(self, project_id: str) -> ExportAnnotationsResponse:
|
async def get_export_stats(self, project_id: str) -> ExportAnnotationsResponse:
|
||||||
"""获取导出统计信息"""
|
"""获取导出统计信息"""
|
||||||
project = await self._get_project_or_404(project_id)
|
project = await self._get_project_or_404(project_id)
|
||||||
logger.info(f"Export stats for project: id={project_id}, dataset_id={project.dataset_id}, name={project.name}")
|
logger.info(
|
||||||
|
f"Export stats for project: id={project_id}, dataset_id={project.dataset_id}, name={project.name}"
|
||||||
|
)
|
||||||
|
|
||||||
# 获取总文件数(标注项目快照内的文件)
|
# 获取总文件数(标注项目快照内的文件)
|
||||||
total_result = await self.db.execute(
|
total_result = await self.db.execute(
|
||||||
@@ -92,7 +104,9 @@ class AnnotationExportService:
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
total_files = int(total_result.scalar() or 0)
|
total_files = int(total_result.scalar() or 0)
|
||||||
logger.info(f"Total files (snapshot): {total_files} for project_id={project_id}")
|
logger.info(
|
||||||
|
f"Total files (snapshot): {total_files} for project_id={project_id}"
|
||||||
|
)
|
||||||
|
|
||||||
# 获取已标注文件数(统计不同的 file_id 数量)
|
# 获取已标注文件数(统计不同的 file_id 数量)
|
||||||
annotated_result = await self.db.execute(
|
annotated_result = await self.db.execute(
|
||||||
@@ -132,7 +146,11 @@ class AnnotationExportService:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# 根据格式导出
|
# 根据格式导出
|
||||||
format_type = ExportFormat(request.format) if isinstance(request.format, str) else request.format
|
format_type = (
|
||||||
|
ExportFormat(request.format)
|
||||||
|
if isinstance(request.format, str)
|
||||||
|
else request.format
|
||||||
|
)
|
||||||
|
|
||||||
if format_type == ExportFormat.JSON:
|
if format_type == ExportFormat.JSON:
|
||||||
return self._export_json(items, project.name)
|
return self._export_json(items, project.name)
|
||||||
@@ -145,7 +163,9 @@ class AnnotationExportService:
|
|||||||
elif format_type == ExportFormat.YOLO:
|
elif format_type == ExportFormat.YOLO:
|
||||||
return self._export_yolo(items, project.name)
|
return self._export_yolo(items, project.name)
|
||||||
else:
|
else:
|
||||||
raise HTTPException(status_code=400, detail=f"不支持的导出格式: {request.format}")
|
raise HTTPException(
|
||||||
|
status_code=400, detail=f"不支持的导出格式: {request.format}"
|
||||||
|
)
|
||||||
|
|
||||||
async def _get_project_or_404(self, project_id: str) -> LabelingProject:
|
async def _get_project_or_404(self, project_id: str) -> LabelingProject:
|
||||||
"""获取标注项目,不存在则抛出 404"""
|
"""获取标注项目,不存在则抛出 404"""
|
||||||
@@ -174,7 +194,10 @@ class AnnotationExportService:
|
|||||||
# 只获取已标注的数据
|
# 只获取已标注的数据
|
||||||
result = await self.db.execute(
|
result = await self.db.execute(
|
||||||
select(AnnotationResult, DatasetFiles)
|
select(AnnotationResult, DatasetFiles)
|
||||||
.join(LabelingProjectFile, LabelingProjectFile.file_id == AnnotationResult.file_id)
|
.join(
|
||||||
|
LabelingProjectFile,
|
||||||
|
LabelingProjectFile.file_id == AnnotationResult.file_id,
|
||||||
|
)
|
||||||
.join(DatasetFiles, AnnotationResult.file_id == DatasetFiles.id)
|
.join(DatasetFiles, AnnotationResult.file_id == DatasetFiles.id)
|
||||||
.where(
|
.where(
|
||||||
AnnotationResult.project_id == project_id,
|
AnnotationResult.project_id == project_id,
|
||||||
@@ -197,6 +220,7 @@ class AnnotationExportService:
|
|||||||
AnnotationExportItem(
|
AnnotationExportItem(
|
||||||
file_id=str(file.id),
|
file_id=str(file.id),
|
||||||
file_name=str(getattr(file, "file_name", "")),
|
file_name=str(getattr(file, "file_name", "")),
|
||||||
|
file_path=str(getattr(file, "file_path", "")),
|
||||||
data={"text": file_content} if include_data else None,
|
data={"text": file_content} if include_data else None,
|
||||||
annotations=[annotation_data] if annotation_data else [],
|
annotations=[annotation_data] if annotation_data else [],
|
||||||
created_at=ann.created_at,
|
created_at=ann.created_at,
|
||||||
@@ -207,7 +231,9 @@ class AnnotationExportService:
|
|||||||
# 获取所有文件(基于标注项目快照)
|
# 获取所有文件(基于标注项目快照)
|
||||||
files_result = await self.db.execute(
|
files_result = await self.db.execute(
|
||||||
select(DatasetFiles)
|
select(DatasetFiles)
|
||||||
.join(LabelingProjectFile, LabelingProjectFile.file_id == DatasetFiles.id)
|
.join(
|
||||||
|
LabelingProjectFile, LabelingProjectFile.file_id == DatasetFiles.id
|
||||||
|
)
|
||||||
.where(
|
.where(
|
||||||
LabelingProjectFile.project_id == project_id,
|
LabelingProjectFile.project_id == project_id,
|
||||||
DatasetFiles.dataset_id == dataset_id,
|
DatasetFiles.dataset_id == dataset_id,
|
||||||
@@ -217,7 +243,9 @@ class AnnotationExportService:
|
|||||||
|
|
||||||
# 获取已有的标注
|
# 获取已有的标注
|
||||||
ann_result = await self.db.execute(
|
ann_result = await self.db.execute(
|
||||||
select(AnnotationResult).where(AnnotationResult.project_id == project_id)
|
select(AnnotationResult).where(
|
||||||
|
AnnotationResult.project_id == project_id
|
||||||
|
)
|
||||||
)
|
)
|
||||||
annotations = {str(a.file_id): a for a in ann_result.scalars().all()}
|
annotations = {str(a.file_id): a for a in ann_result.scalars().all()}
|
||||||
|
|
||||||
@@ -236,6 +264,7 @@ class AnnotationExportService:
|
|||||||
AnnotationExportItem(
|
AnnotationExportItem(
|
||||||
file_id=file_id,
|
file_id=file_id,
|
||||||
file_name=str(getattr(file, "file_name", "")),
|
file_name=str(getattr(file, "file_name", "")),
|
||||||
|
file_path=str(getattr(file, "file_path", "")),
|
||||||
data={"text": file_content} if include_data else None,
|
data={"text": file_content} if include_data else None,
|
||||||
annotations=[annotation_data] if annotation_data else [],
|
annotations=[annotation_data] if annotation_data else [],
|
||||||
created_at=ann.created_at if ann else None,
|
created_at=ann.created_at if ann else None,
|
||||||
@@ -262,8 +291,13 @@ class AnnotationExportService:
|
|||||||
for item in segment_results:
|
for item in segment_results:
|
||||||
if isinstance(item, dict):
|
if isinstance(item, dict):
|
||||||
normalized = dict(item)
|
normalized = dict(item)
|
||||||
if SEGMENT_INDEX_KEY not in normalized and SEGMENT_INDEX_FALLBACK_KEY not in normalized:
|
if (
|
||||||
normalized[SEGMENT_INDEX_KEY] = int(key) if str(key).isdigit() else key
|
SEGMENT_INDEX_KEY not in normalized
|
||||||
|
and SEGMENT_INDEX_FALLBACK_KEY not in normalized
|
||||||
|
):
|
||||||
|
normalized[SEGMENT_INDEX_KEY] = (
|
||||||
|
int(key) if str(key).isdigit() else key
|
||||||
|
)
|
||||||
results.append(normalized)
|
results.append(normalized)
|
||||||
elif isinstance(segments, list):
|
elif isinstance(segments, list):
|
||||||
for idx, segment in enumerate(segments):
|
for idx, segment in enumerate(segments):
|
||||||
@@ -272,11 +306,16 @@ class AnnotationExportService:
|
|||||||
segment_results = segment.get(SEGMENT_RESULT_KEY)
|
segment_results = segment.get(SEGMENT_RESULT_KEY)
|
||||||
if not isinstance(segment_results, list):
|
if not isinstance(segment_results, list):
|
||||||
continue
|
continue
|
||||||
segment_index = segment.get(SEGMENT_INDEX_KEY, segment.get(SEGMENT_INDEX_FALLBACK_KEY, idx))
|
segment_index = segment.get(
|
||||||
|
SEGMENT_INDEX_KEY, segment.get(SEGMENT_INDEX_FALLBACK_KEY, idx)
|
||||||
|
)
|
||||||
for item in segment_results:
|
for item in segment_results:
|
||||||
if isinstance(item, dict):
|
if isinstance(item, dict):
|
||||||
normalized = dict(item)
|
normalized = dict(item)
|
||||||
if SEGMENT_INDEX_KEY not in normalized and SEGMENT_INDEX_FALLBACK_KEY not in normalized:
|
if (
|
||||||
|
SEGMENT_INDEX_KEY not in normalized
|
||||||
|
and SEGMENT_INDEX_FALLBACK_KEY not in normalized
|
||||||
|
):
|
||||||
normalized[SEGMENT_INDEX_KEY] = segment_index
|
normalized[SEGMENT_INDEX_KEY] = segment_index
|
||||||
results.append(normalized)
|
results.append(normalized)
|
||||||
return results
|
return results
|
||||||
@@ -284,18 +323,43 @@ class AnnotationExportService:
|
|||||||
return result if isinstance(result, list) else []
|
return result if isinstance(result, list) else []
|
||||||
|
|
||||||
@classmethod
def _normalize_annotation_for_export(
    cls, annotation: Dict[str, Any]
) -> Dict[str, Any]:
    """Return an annotation dict in the shape used by the exporters.

    Args:
        annotation: The stored annotation payload.

    Returns:
        * ``{}`` when ``annotation`` is empty or not a dict.
        * The very same ``annotation`` object when it is not segmented
          (no truthy ``SEGMENTED_KEY`` flag and ``SEGMENTS_KEY`` is neither
          a dict nor a list).
        * Otherwise a shallow copy in which ``SEGMENT_RESULT_KEY`` holds the
          flattened per-segment results, unless the annotation already
          carries a list under that key, which is then kept as-is.
    """
    if not annotation or not isinstance(annotation, dict):
        return {}

    segments = annotation.get(SEGMENTS_KEY)
    is_segmented = bool(annotation.get(SEGMENTED_KEY)) or isinstance(
        segments, (dict, list)
    )
    if not is_segmented:
        return annotation

    # Work on a copy so the caller's annotation is never mutated.
    normalized = dict(annotation)
    flattened = cls._flatten_annotation_results(annotation)
    # ``.get`` yields None for a missing key, so one isinstance check covers
    # both "key absent" and "value is not a list".
    if not isinstance(normalized.get(SEGMENT_RESULT_KEY), list):
        normalized[SEGMENT_RESULT_KEY] = flattened
    return normalized
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _get_image_dimensions(file_path: str) -> Tuple[int, int]:
|
||||||
|
"""获取图片文件的宽度和高度
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file_path: 图片文件路径
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(width, height) 元组,如果读取失败则返回 (1920, 1080) 作为默认值
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
if os.path.exists(file_path):
|
||||||
|
with Image.open(file_path) as img:
|
||||||
|
width, height = img.size
|
||||||
|
return width, height
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to read image dimensions from {file_path}: {e}")
|
||||||
|
|
||||||
|
# 使用合理的默认值
|
||||||
|
return 1920, 1080
|
||||||
|
|
||||||
def _export_json(
|
def _export_json(
|
||||||
self, items: List[AnnotationExportItem], project_name: str
|
self, items: List[AnnotationExportItem], project_name: str
|
||||||
) -> Tuple[bytes, str, str]:
|
) -> Tuple[bytes, str, str]:
|
||||||
@@ -309,9 +373,16 @@ class AnnotationExportService:
|
|||||||
"file_id": item.file_id,
|
"file_id": item.file_id,
|
||||||
"file_name": item.file_name,
|
"file_name": item.file_name,
|
||||||
"data": item.data,
|
"data": item.data,
|
||||||
"annotations": [self._normalize_annotation_for_export(ann) for ann in item.annotations],
|
"annotations": [
|
||||||
"created_at": item.created_at.isoformat() if item.created_at else None,
|
self._normalize_annotation_for_export(ann)
|
||||||
"updated_at": item.updated_at.isoformat() if item.updated_at else None,
|
for ann in item.annotations
|
||||||
|
],
|
||||||
|
"created_at": item.created_at.isoformat()
|
||||||
|
if item.created_at
|
||||||
|
else None,
|
||||||
|
"updated_at": item.updated_at.isoformat()
|
||||||
|
if item.updated_at
|
||||||
|
else None,
|
||||||
}
|
}
|
||||||
for item in items
|
for item in items
|
||||||
],
|
],
|
||||||
@@ -331,7 +402,10 @@ class AnnotationExportService:
|
|||||||
"file_id": item.file_id,
|
"file_id": item.file_id,
|
||||||
"file_name": item.file_name,
|
"file_name": item.file_name,
|
||||||
"data": item.data,
|
"data": item.data,
|
||||||
"annotations": [self._normalize_annotation_for_export(ann) for ann in item.annotations],
|
"annotations": [
|
||||||
|
self._normalize_annotation_for_export(ann)
|
||||||
|
for ann in item.annotations
|
||||||
|
],
|
||||||
"created_at": item.created_at.isoformat() if item.created_at else None,
|
"created_at": item.created_at.isoformat() if item.created_at else None,
|
||||||
"updated_at": item.updated_at.isoformat() if item.updated_at else None,
|
"updated_at": item.updated_at.isoformat() if item.updated_at else None,
|
||||||
}
|
}
|
||||||
@@ -389,14 +463,22 @@ class AnnotationExportService:
|
|||||||
elif "transcription" in value:
|
elif "transcription" in value:
|
||||||
labels.append(value["transcription"])
|
labels.append(value["transcription"])
|
||||||
|
|
||||||
writer.writerow({
|
writer.writerow(
|
||||||
"file_id": item.file_id,
|
{
|
||||||
"file_name": item.file_name,
|
"file_id": item.file_id,
|
||||||
"annotation_result": json.dumps(item.annotations, ensure_ascii=False),
|
"file_name": item.file_name,
|
||||||
"labels": "|".join(labels),
|
"annotation_result": json.dumps(
|
||||||
"created_at": item.created_at.isoformat() if item.created_at else "",
|
item.annotations, ensure_ascii=False
|
||||||
"updated_at": item.updated_at.isoformat() if item.updated_at else "",
|
),
|
||||||
})
|
"labels": "|".join(labels),
|
||||||
|
"created_at": item.created_at.isoformat()
|
||||||
|
if item.created_at
|
||||||
|
else "",
|
||||||
|
"updated_at": item.updated_at.isoformat()
|
||||||
|
if item.updated_at
|
||||||
|
else "",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
content = output.getvalue().encode("utf-8-sig") # BOM for Excel compatibility
|
content = output.getvalue().encode("utf-8-sig") # BOM for Excel compatibility
|
||||||
filename = f"{project_name}_annotations.csv"
|
filename = f"{project_name}_annotations.csv"
|
||||||
@@ -405,11 +487,7 @@ class AnnotationExportService:
|
|||||||
def _export_coco(
|
def _export_coco(
|
||||||
self, items: List[AnnotationExportItem], project_name: str
|
self, items: List[AnnotationExportItem], project_name: str
|
||||||
) -> Tuple[bytes, str, str]:
|
) -> Tuple[bytes, str, str]:
|
||||||
"""导出为 COCO 格式(适用于目标检测标注)
|
"""导出为 COCO 格式(适用于目标检测标注)"""
|
||||||
|
|
||||||
注意:当前实现中图片宽高被设置为0,因为需要读取实际图片文件获取尺寸。
|
|
||||||
bbox 坐标使用 Label Studio 的百分比值(0-100),使用时需要转换为像素坐标。
|
|
||||||
"""
|
|
||||||
coco_format = COCOExportFormat(
|
coco_format = COCOExportFormat(
|
||||||
info={
|
info={
|
||||||
"description": f"Exported from DataMate project: {project_name}",
|
"description": f"Exported from DataMate project: {project_name}",
|
||||||
@@ -429,13 +507,18 @@ class AnnotationExportService:
|
|||||||
for idx, item in enumerate(items):
|
for idx, item in enumerate(items):
|
||||||
image_id = idx + 1
|
image_id = idx + 1
|
||||||
|
|
||||||
|
# 获取图片实际尺寸
|
||||||
|
img_width, img_height = self._get_image_dimensions(item.file_path or "")
|
||||||
|
|
||||||
# 添加图片信息
|
# 添加图片信息
|
||||||
coco_format.images.append({
|
coco_format.images.append(
|
||||||
"id": image_id,
|
{
|
||||||
"file_name": item.file_name,
|
"id": image_id,
|
||||||
"width": 0, # 需要实际图片尺寸
|
"file_name": item.file_name,
|
||||||
"height": 0,
|
"width": img_width,
|
||||||
})
|
"height": img_height,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
# 处理标注
|
# 处理标注
|
||||||
for ann in item.annotations:
|
for ann in item.annotations:
|
||||||
@@ -449,29 +532,41 @@ class AnnotationExportService:
|
|||||||
for label in labels:
|
for label in labels:
|
||||||
if label not in category_map:
|
if label not in category_map:
|
||||||
category_map[label] = len(category_map) + 1
|
category_map[label] = len(category_map) + 1
|
||||||
coco_format.categories.append({
|
coco_format.categories.append(
|
||||||
"id": category_map[label],
|
{
|
||||||
"name": label,
|
"id": category_map[label],
|
||||||
"supercategory": "",
|
"name": label,
|
||||||
})
|
"supercategory": "",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
# 转换坐标(Label Studio 使用百分比)
|
# 转换坐标:Label Studio 使用百分比(0-100)转换为像素坐标
|
||||||
x = value.get("x", 0)
|
x_percent = value.get("x", 0)
|
||||||
y = value.get("y", 0)
|
y_percent = value.get("y", 0)
|
||||||
width = value.get("width", 0)
|
width_percent = value.get("width", 0)
|
||||||
height = value.get("height", 0)
|
height_percent = value.get("height", 0)
|
||||||
|
|
||||||
coco_format.annotations.append({
|
# 转换为像素坐标
|
||||||
"id": annotation_id,
|
x = x_percent * img_width / 100.0
|
||||||
"image_id": image_id,
|
y = y_percent * img_height / 100.0
|
||||||
"category_id": category_map[label],
|
width = width_percent * img_width / 100.0
|
||||||
"bbox": [x, y, width, height],
|
height = height_percent * img_height / 100.0
|
||||||
"area": width * height,
|
|
||||||
"iscrowd": 0,
|
coco_format.annotations.append(
|
||||||
})
|
{
|
||||||
|
"id": annotation_id,
|
||||||
|
"image_id": image_id,
|
||||||
|
"category_id": category_map[label],
|
||||||
|
"bbox": [x, y, width, height],
|
||||||
|
"area": width * height,
|
||||||
|
"iscrowd": 0,
|
||||||
|
}
|
||||||
|
)
|
||||||
annotation_id += 1
|
annotation_id += 1
|
||||||
|
|
||||||
content = json.dumps(coco_format.model_dump(), ensure_ascii=False, indent=2).encode("utf-8")
|
content = json.dumps(
|
||||||
|
coco_format.model_dump(), ensure_ascii=False, indent=2
|
||||||
|
).encode("utf-8")
|
||||||
filename = f"{project_name}_coco.json"
|
filename = f"{project_name}_coco.json"
|
||||||
return content, filename, "application/json"
|
return content, filename, "application/json"
|
||||||
|
|
||||||
@@ -510,7 +605,9 @@ class AnnotationExportService:
|
|||||||
x_center = x + w / 2
|
x_center = x + w / 2
|
||||||
y_center = y + h / 2
|
y_center = y + h / 2
|
||||||
|
|
||||||
lines.append(f"{label} {x_center:.6f} {y_center:.6f} {w:.6f} {h:.6f}")
|
lines.append(
|
||||||
|
f"{label} {x_center:.6f} {y_center:.6f} {w:.6f} {h:.6f}"
|
||||||
|
)
|
||||||
|
|
||||||
if lines:
|
if lines:
|
||||||
# 生成对应的 txt 文件名
|
# 生成对应的 txt 文件名
|
||||||
|
|||||||
@@ -43,7 +43,9 @@ class KnowledgeSyncService:
|
|||||||
logger.warning("标注同步失败:无法获取知识集")
|
logger.warning("标注同步失败:无法获取知识集")
|
||||||
return
|
return
|
||||||
|
|
||||||
item = await self._get_item_by_source(set_id, project.dataset_id, str(file_record.id))
|
item = await self._get_item_by_source(
|
||||||
|
set_id, project.dataset_id, str(file_record.id)
|
||||||
|
)
|
||||||
if item and item.get("status") in {"PUBLISHED", "ARCHIVED", "DEPRECATED"}:
|
if item and item.get("status") in {"PUBLISHED", "ARCHIVED", "DEPRECATED"}:
|
||||||
logger.info(
|
logger.info(
|
||||||
"知识条目为只读状态,跳过同步:item_id=%s status=%s",
|
"知识条目为只读状态,跳过同步:item_id=%s status=%s",
|
||||||
@@ -71,26 +73,46 @@ class KnowledgeSyncService:
|
|||||||
logger.warning("标注同步到知识管理失败:%s", exc)
|
logger.warning("标注同步到知识管理失败:%s", exc)
|
||||||
|
|
||||||
async def _ensure_knowledge_set(self, project: LabelingProject) -> Optional[str]:
|
async def _ensure_knowledge_set(self, project: LabelingProject) -> Optional[str]:
|
||||||
config = project.configuration if isinstance(project.configuration, dict) else {}
|
result = await self.db.execute(
|
||||||
|
select(LabelingProject)
|
||||||
|
.where(LabelingProject.id == project.id)
|
||||||
|
.with_for_update()
|
||||||
|
)
|
||||||
|
locked_project = result.scalar_one_or_none()
|
||||||
|
if not locked_project:
|
||||||
|
logger.warning("标注同步失败:无法锁定项目:project_id=%s", project.id)
|
||||||
|
return None
|
||||||
|
|
||||||
|
config = (
|
||||||
|
locked_project.configuration
|
||||||
|
if isinstance(locked_project.configuration, dict)
|
||||||
|
else {}
|
||||||
|
)
|
||||||
set_id = config.get(self.CONFIG_KEY_SET_ID)
|
set_id = config.get(self.CONFIG_KEY_SET_ID)
|
||||||
|
|
||||||
if set_id:
|
if set_id:
|
||||||
exists = await self._get_knowledge_set(set_id)
|
exists = await self._get_knowledge_set(set_id)
|
||||||
if exists and self._metadata_matches_project(exists.get("metadata"), project.id):
|
if exists and self._metadata_matches_project(
|
||||||
|
exists.get("metadata"), locked_project.id
|
||||||
|
):
|
||||||
return set_id
|
return set_id
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"知识集不存在或归属不匹配,准备重建:set_id=%s project_id=%s",
|
"知识集不存在或归属不匹配,准备重建:set_id=%s project_id=%s",
|
||||||
set_id,
|
set_id,
|
||||||
project.id,
|
locked_project.id,
|
||||||
)
|
)
|
||||||
|
|
||||||
project_name = (project.name or "annotation-project").strip() or "annotation-project"
|
project_name = (
|
||||||
metadata = self._build_set_metadata(project)
|
locked_project.name or "annotation-project"
|
||||||
|
).strip() or "annotation-project"
|
||||||
|
metadata = self._build_set_metadata(locked_project)
|
||||||
|
|
||||||
existing = await self._find_knowledge_set_by_name_and_project(project_name, project.id)
|
existing = await self._find_knowledge_set_by_name_and_project(
|
||||||
|
project_name, locked_project.id
|
||||||
|
)
|
||||||
if existing:
|
if existing:
|
||||||
await self._update_project_config(
|
await self._update_project_config(
|
||||||
project,
|
locked_project,
|
||||||
{
|
{
|
||||||
self.CONFIG_KEY_SET_ID: existing.get("id"),
|
self.CONFIG_KEY_SET_ID: existing.get("id"),
|
||||||
self.CONFIG_KEY_SET_NAME: existing.get("name"),
|
self.CONFIG_KEY_SET_NAME: existing.get("name"),
|
||||||
@@ -100,23 +122,31 @@ class KnowledgeSyncService:
|
|||||||
|
|
||||||
created = await self._create_knowledge_set(project_name, metadata)
|
created = await self._create_knowledge_set(project_name, metadata)
|
||||||
if not created:
|
if not created:
|
||||||
created = await self._find_knowledge_set_by_name_and_project(project_name, project.id)
|
created = await self._find_knowledge_set_by_name_and_project(
|
||||||
|
project_name, locked_project.id
|
||||||
|
)
|
||||||
|
|
||||||
if not created:
|
if not created:
|
||||||
fallback_name = self._build_fallback_set_name(project_name, project.id)
|
fallback_name = self._build_fallback_set_name(
|
||||||
existing = await self._find_knowledge_set_by_name_and_project(fallback_name, project.id)
|
project_name, locked_project.id
|
||||||
|
)
|
||||||
|
existing = await self._find_knowledge_set_by_name_and_project(
|
||||||
|
fallback_name, locked_project.id
|
||||||
|
)
|
||||||
if existing:
|
if existing:
|
||||||
created = existing
|
created = existing
|
||||||
else:
|
else:
|
||||||
created = await self._create_knowledge_set(fallback_name, metadata)
|
created = await self._create_knowledge_set(fallback_name, metadata)
|
||||||
if not created:
|
if not created:
|
||||||
created = await self._find_knowledge_set_by_name_and_project(fallback_name, project.id)
|
created = await self._find_knowledge_set_by_name_and_project(
|
||||||
|
fallback_name, locked_project.id
|
||||||
|
)
|
||||||
|
|
||||||
if not created:
|
if not created:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
await self._update_project_config(
|
await self._update_project_config(
|
||||||
project,
|
locked_project,
|
||||||
{
|
{
|
||||||
self.CONFIG_KEY_SET_ID: created.get("id"),
|
self.CONFIG_KEY_SET_ID: created.get("id"),
|
||||||
self.CONFIG_KEY_SET_NAME: created.get("name"),
|
self.CONFIG_KEY_SET_NAME: created.get("name"),
|
||||||
@@ -126,13 +156,17 @@ class KnowledgeSyncService:
|
|||||||
|
|
||||||
async def _get_knowledge_set(self, set_id: str) -> Optional[Dict[str, Any]]:
    """Fetch one knowledge set by its ID.

    Args:
        set_id: Identifier of the knowledge set.

    Returns:
        Whatever ``self._request`` returns for the GET call, or ``None``
        when the server answers with HTTP 404.

    Raises:
        httpx.HTTPStatusError: For any HTTP error status other than 404.
    """
    try:
        return await self._request(
            "GET", f"/data-management/knowledge-sets/{set_id}"
        )
    except httpx.HTTPStatusError as exc:
        if exc.response.status_code != 404:
            raise
        # 404 simply means "no such set"; report that as None.
        return None
|
||||||
|
|
||||||
async def _list_knowledge_sets(self, keyword: Optional[str]) -> list[Dict[str, Any]]:
|
async def _list_knowledge_sets(
|
||||||
|
self, keyword: Optional[str]
|
||||||
|
) -> list[Dict[str, Any]]:
|
||||||
params: Dict[str, Any] = {
|
params: Dict[str, Any] = {
|
||||||
"page": 1,
|
"page": 1,
|
||||||
"size": self.KNOWLEDGE_SET_LIST_SIZE,
|
"size": self.KNOWLEDGE_SET_LIST_SIZE,
|
||||||
@@ -140,7 +174,9 @@ class KnowledgeSyncService:
|
|||||||
if keyword:
|
if keyword:
|
||||||
params["keyword"] = keyword
|
params["keyword"] = keyword
|
||||||
try:
|
try:
|
||||||
data = await self._request("GET", "/data-management/knowledge-sets", params=params)
|
data = await self._request(
|
||||||
|
"GET", "/data-management/knowledge-sets", params=params
|
||||||
|
)
|
||||||
except httpx.HTTPStatusError as exc:
|
except httpx.HTTPStatusError as exc:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"查询知识集失败:keyword=%s status=%s",
|
"查询知识集失败:keyword=%s status=%s",
|
||||||
@@ -155,7 +191,9 @@ class KnowledgeSyncService:
|
|||||||
return []
|
return []
|
||||||
return [item for item in content if isinstance(item, dict)]
|
return [item for item in content if isinstance(item, dict)]
|
||||||
|
|
||||||
async def _find_knowledge_set_by_name_and_project(self, name: str, project_id: str) -> Optional[Dict[str, Any]]:
|
async def _find_knowledge_set_by_name_and_project(
|
||||||
|
self, name: str, project_id: str
|
||||||
|
) -> Optional[Dict[str, Any]]:
|
||||||
if not name:
|
if not name:
|
||||||
return None
|
return None
|
||||||
items = await self._list_knowledge_sets(name)
|
items = await self._list_knowledge_sets(name)
|
||||||
@@ -168,7 +206,9 @@ class KnowledgeSyncService:
|
|||||||
return item
|
return item
|
||||||
return None
|
return None
|
||||||
|
|
||||||
async def _create_knowledge_set(self, name: str, metadata: str) -> Optional[Dict[str, Any]]:
|
async def _create_knowledge_set(
|
||||||
|
self, name: str, metadata: str
|
||||||
|
) -> Optional[Dict[str, Any]]:
|
||||||
payload = {
|
payload = {
|
||||||
"name": name,
|
"name": name,
|
||||||
"description": "标注项目自动创建的知识集",
|
"description": "标注项目自动创建的知识集",
|
||||||
@@ -176,7 +216,9 @@ class KnowledgeSyncService:
|
|||||||
"metadata": metadata,
|
"metadata": metadata,
|
||||||
}
|
}
|
||||||
try:
|
try:
|
||||||
return await self._request("POST", "/data-management/knowledge-sets", json=payload)
|
return await self._request(
|
||||||
|
"POST", "/data-management/knowledge-sets", json=payload
|
||||||
|
)
|
||||||
except httpx.HTTPStatusError as exc:
|
except httpx.HTTPStatusError as exc:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"创建知识集失败:name=%s status=%s detail=%s",
|
"创建知识集失败:name=%s status=%s detail=%s",
|
||||||
@@ -199,7 +241,9 @@ class KnowledgeSyncService:
|
|||||||
"sourceFileId": file_id,
|
"sourceFileId": file_id,
|
||||||
}
|
}
|
||||||
try:
|
try:
|
||||||
data = await self._request("GET", f"/data-management/knowledge-sets/{set_id}/items", params=params)
|
data = await self._request(
|
||||||
|
"GET", f"/data-management/knowledge-sets/{set_id}/items", params=params
|
||||||
|
)
|
||||||
except httpx.HTTPStatusError as exc:
|
except httpx.HTTPStatusError as exc:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"查询知识条目失败:set_id=%s status=%s",
|
"查询知识条目失败:set_id=%s status=%s",
|
||||||
@@ -216,9 +260,13 @@ class KnowledgeSyncService:
|
|||||||
return content[0]
|
return content[0]
|
||||||
|
|
||||||
async def _create_item(self, set_id: str, payload: Dict[str, Any]) -> None:
    """Create a knowledge item inside the given knowledge set.

    Args:
        set_id: Identifier of the knowledge set that receives the item.
        payload: JSON body sent with the POST request.

    Any exception raised by ``self._request`` propagates to the caller; the
    response is discarded.
    """
    await self._request(
        "POST", f"/data-management/knowledge-sets/{set_id}/items", json=payload
    )
|
||||||
|
|
||||||
async def _update_item(self, set_id: str, item_id: str, payload: Dict[str, Any]) -> None:
|
async def _update_item(
|
||||||
|
self, set_id: str, item_id: str, payload: Dict[str, Any]
|
||||||
|
) -> None:
|
||||||
update_payload = dict(payload)
|
update_payload = dict(payload)
|
||||||
update_payload.pop("sourceDatasetId", None)
|
update_payload.pop("sourceDatasetId", None)
|
||||||
update_payload.pop("sourceFileId", None)
|
update_payload.pop("sourceFileId", None)
|
||||||
@@ -228,6 +276,62 @@ class KnowledgeSyncService:
|
|||||||
json=update_payload,
|
json=update_payload,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async def _cleanup_knowledge_set_for_project(self, project_id: str) -> None:
    """Delete every knowledge set whose metadata matches ``project_id``.

    Lists knowledge sets via ``self._list_knowledge_sets(None)``, keeps the
    ones accepted by ``self._metadata_matches_project`` and issues a DELETE
    for each. Deletion is best-effort: a failure for one set is logged as a
    warning and the loop moves on to the next set.

    NOTE(review): only the set itself is deleted here; removal of the
    items it contains is presumably handled by the service -- confirm.
    """
    knowledge_sets = await self._list_knowledge_sets(None)
    for knowledge_set in knowledge_sets:
        # Guard clauses: skip sets of other projects and entries without id.
        if not self._metadata_matches_project(
            knowledge_set.get("metadata"), project_id
        ):
            continue
        set_id = knowledge_set.get("id")
        if not set_id:
            continue

        delete_path = f"/data-management/knowledge-sets/{set_id}"
        try:
            await self._request("DELETE", delete_path)
            logger.info(
                "已删除知识集:set_id=%s project_id=%s", set_id, project_id
            )
        except Exception as exc:
            logger.warning(
                "删除知识集失败:set_id=%s project_id=%s error=%s",
                set_id,
                project_id,
                exc,
            )
|
||||||
|
|
||||||
|
async def _cleanup_knowledge_item_for_file(
    self, dataset_id: str, file_id: str
) -> None:
    """Delete the knowledge items that were created from one source file.

    Walks every knowledge set returned by ``self._list_knowledge_sets(None)``
    and looks up the item for ``(dataset_id, file_id)`` in each one with
    ``self._get_item_by_source``. Each item found is removed with a DELETE
    request. A failed DELETE is logged as a warning and does not stop the
    scan of the remaining sets; errors raised by the lookup itself are not
    caught here.
    """
    for knowledge_set in await self._list_knowledge_sets(None):
        set_id = knowledge_set.get("id")
        if not set_id:
            continue

        found = await self._get_item_by_source(set_id, dataset_id, file_id)
        if not (found and found.get("id")):
            # Nothing in this set originates from the file.
            continue

        try:
            await self._request(
                "DELETE",
                f"/data-management/knowledge-sets/{set_id}/items/{found['id']}",
            )
            logger.info(
                "已删除知识条目:item_id=%s set_id=%s dataset_id=%s file_id=%s",
                found.get("id"),
                set_id,
                dataset_id,
                file_id,
            )
        except Exception as exc:
            logger.warning(
                "删除知识条目失败:item_id=%s set_id=%s dataset_id=%s file_id=%s error=%s",
                found.get("id"),
                set_id,
                dataset_id,
                file_id,
                exc,
            )
|
||||||
|
|
||||||
async def _build_item_payload(
|
async def _build_item_payload(
|
||||||
self,
|
self,
|
||||||
project: LabelingProject,
|
project: LabelingProject,
|
||||||
@@ -323,12 +427,28 @@ class KnowledgeSyncService:
|
|||||||
short_id = project_id.replace("-", "")[:8]
|
short_id = project_id.replace("-", "")[:8]
|
||||||
return f"{base_name}-annotation-{short_id}"
|
return f"{base_name}-annotation-{short_id}"
|
||||||
|
|
||||||
async def _update_project_config(
    self, project: LabelingProject, updates: Dict[str, Any]
) -> None:
    """Merge ``updates`` into the project's configuration under a row lock.

    The project row is re-selected with ``SELECT ... FOR UPDATE`` so that
    concurrent writers are serialized, the merged configuration is assigned
    to the locked instance, and the transaction is committed.

    Args:
        project: Project whose ``id`` identifies the row to lock and update.
        updates: Keys/values to merge into the stored configuration;
            existing keys are overwritten.

    Returns:
        None. If the row cannot be found a warning is logged and nothing
        is written.
    """
    result = await self.db.execute(
        select(LabelingProject)
        .where(LabelingProject.id == project.id)
        .with_for_update()
        # The project is usually already present in the session's identity
        # map; without populate_existing the locked SELECT would hand back
        # that instance with its stale attributes, so a concurrent writer's
        # configuration changes would be silently overwritten.
        .execution_options(populate_existing=True)
    )
    locked_project = result.scalar_one_or_none()
    if not locked_project:
        logger.warning("更新项目配置失败:无法锁定项目:project_id=%s", project.id)
        return

    current = locked_project.configuration
    # Build a NEW dict instead of mutating the loaded one in place: when the
    # same (mutated) object is assigned back, the ORM compares the value with
    # itself, sees no difference and may emit no UPDATE for a plain JSON
    # column.
    if isinstance(current, dict):
        config = {**current, **updates}
    else:
        config = dict(updates)

    locked_project.configuration = config
    await self.db.commit()
    await self.db.refresh(locked_project)
|
||||||
|
|
||||||
async def _request(self, method: str, path: str, **kwargs) -> Any:
|
async def _request(self, method: str, path: str, **kwargs) -> Any:
|
||||||
url = f"{self.base_url}{path}"
|
url = f"{self.base_url}{path}"
|
||||||
|
|||||||
Reference in New Issue
Block a user