fix: 修复知识库同步的并发控制、数据清理、文件事务和COCO导出问题

问题1 - 并发控制缺失:
- 在 _ensure_knowledge_set 方法中添加数据库行锁(with_for_update)
- 修改 _update_project_config 方法,使用行锁保护配置更新

问题3 - 数据清理机制缺失:
- 添加 _cleanup_knowledge_set_for_project 方法,项目删除时清理知识集
- 添加 _cleanup_knowledge_item_for_file 方法,文件删除时清理知识条目
- 在 delete_mapping 接口中调用清理方法

问题4 - 文件操作事务问题:
- 修改 uploadKnowledgeItems,添加事务失败后的文件清理逻辑
- 修改 deleteKnowledgeItem,删除记录前先删除关联文件
- 新增 deleteKnowledgeItemFile 辅助方法

问题5 - COCO导出格式问题:
- 添加 _get_image_dimensions 方法读取图片实际宽高
- 将百分比坐标转换为像素坐标
- 在 AnnotationExportItem 中添加 file_path 字段

涉及文件:
- knowledge_sync.py
- project.py
- KnowledgeItemApplicationService.java
- export.py
- export schema.py
This commit is contained in:
2026-02-05 03:55:01 +08:00
parent c03bdf1a24
commit 99bd83d312
5 changed files with 513 additions and 238 deletions

View File

@@ -21,20 +21,29 @@ from datetime import datetime
from typing import Any, Dict, List, Optional, Tuple
from fastapi import HTTPException
from PIL import Image
from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.logging import get_logger
from app.db.models import AnnotationResult, Dataset, DatasetFiles, LabelingProject, LabelingProjectFile
from app.db.models import (
AnnotationResult,
Dataset,
DatasetFiles,
LabelingProject,
LabelingProjectFile,
)
async def _read_file_content(file_path: str, max_size: int = 10 * 1024 * 1024) -> Optional[str]:
async def _read_file_content(
file_path: str, max_size: int = 10 * 1024 * 1024
) -> Optional[str]:
"""读取文件内容,仅适用于文本文件
Args:
file_path: 文件路径
max_size: 最大读取字节数(默认10MB)
Returns:
文件内容字符串,如果读取失败返回 None
"""
@@ -42,17 +51,18 @@ async def _read_file_content(file_path: str, max_size: int = 10 * 1024 * 1024) -
# 检查文件是否存在且大小在限制内
if not os.path.exists(file_path):
return None
file_size = os.path.getsize(file_path)
if file_size > max_size:
return f"[File too large: {file_size} bytes]"
# 尝试以文本方式读取
with open(file_path, 'r', encoding='utf-8', errors='replace') as f:
with open(file_path, "r", encoding="utf-8", errors="replace") as f:
return f.read()
except Exception:
return None
from ..schema.export import (
AnnotationExportItem,
COCOExportFormat,
@@ -79,7 +89,9 @@ class AnnotationExportService:
async def get_export_stats(self, project_id: str) -> ExportAnnotationsResponse:
"""获取导出统计信息"""
project = await self._get_project_or_404(project_id)
logger.info(f"Export stats for project: id={project_id}, dataset_id={project.dataset_id}, name={project.name}")
logger.info(
f"Export stats for project: id={project_id}, dataset_id={project.dataset_id}, name={project.name}"
)
# 获取总文件数(标注项目快照内的文件)
total_result = await self.db.execute(
@@ -92,7 +104,9 @@ class AnnotationExportService:
)
)
total_files = int(total_result.scalar() or 0)
logger.info(f"Total files (snapshot): {total_files} for project_id={project_id}")
logger.info(
f"Total files (snapshot): {total_files} for project_id={project_id}"
)
# 获取已标注文件数(统计不同的 file_id 数量)
annotated_result = await self.db.execute(
@@ -132,7 +146,11 @@ class AnnotationExportService:
)
# 根据格式导出
format_type = ExportFormat(request.format) if isinstance(request.format, str) else request.format
format_type = (
ExportFormat(request.format)
if isinstance(request.format, str)
else request.format
)
if format_type == ExportFormat.JSON:
return self._export_json(items, project.name)
@@ -145,7 +163,9 @@ class AnnotationExportService:
elif format_type == ExportFormat.YOLO:
return self._export_yolo(items, project.name)
else:
raise HTTPException(status_code=400, detail=f"不支持的导出格式: {request.format}")
raise HTTPException(
status_code=400, detail=f"不支持的导出格式: {request.format}"
)
async def _get_project_or_404(self, project_id: str) -> LabelingProject:
"""获取标注项目,不存在则抛出 404"""
@@ -174,7 +194,10 @@ class AnnotationExportService:
# 只获取已标注的数据
result = await self.db.execute(
select(AnnotationResult, DatasetFiles)
.join(LabelingProjectFile, LabelingProjectFile.file_id == AnnotationResult.file_id)
.join(
LabelingProjectFile,
LabelingProjectFile.file_id == AnnotationResult.file_id,
)
.join(DatasetFiles, AnnotationResult.file_id == DatasetFiles.id)
.where(
AnnotationResult.project_id == project_id,
@@ -192,11 +215,12 @@ class AnnotationExportService:
if include_data:
file_path = getattr(file, "file_path", "")
file_content = await _read_file_content(file_path)
items.append(
AnnotationExportItem(
file_id=str(file.id),
file_name=str(getattr(file, "file_name", "")),
file_path=str(getattr(file, "file_path", "")),
data={"text": file_content} if include_data else None,
annotations=[annotation_data] if annotation_data else [],
created_at=ann.created_at,
@@ -207,7 +231,9 @@ class AnnotationExportService:
# 获取所有文件(基于标注项目快照)
files_result = await self.db.execute(
select(DatasetFiles)
.join(LabelingProjectFile, LabelingProjectFile.file_id == DatasetFiles.id)
.join(
LabelingProjectFile, LabelingProjectFile.file_id == DatasetFiles.id
)
.where(
LabelingProjectFile.project_id == project_id,
DatasetFiles.dataset_id == dataset_id,
@@ -217,7 +243,9 @@ class AnnotationExportService:
# 获取已有的标注
ann_result = await self.db.execute(
select(AnnotationResult).where(AnnotationResult.project_id == project_id)
select(AnnotationResult).where(
AnnotationResult.project_id == project_id
)
)
annotations = {str(a.file_id): a for a in ann_result.scalars().all()}
@@ -225,7 +253,7 @@ class AnnotationExportService:
file_id = str(file.id)
ann = annotations.get(file_id)
annotation_data = ann.annotation if ann else {}
# 获取文件内容(如果是文本文件且用户要求包含数据)
file_content = None
if include_data:
@@ -236,6 +264,7 @@ class AnnotationExportService:
AnnotationExportItem(
file_id=file_id,
file_name=str(getattr(file, "file_name", "")),
file_path=str(getattr(file, "file_path", "")),
data={"text": file_content} if include_data else None,
annotations=[annotation_data] if annotation_data else [],
created_at=ann.created_at if ann else None,
@@ -262,8 +291,13 @@ class AnnotationExportService:
for item in segment_results:
if isinstance(item, dict):
normalized = dict(item)
if SEGMENT_INDEX_KEY not in normalized and SEGMENT_INDEX_FALLBACK_KEY not in normalized:
normalized[SEGMENT_INDEX_KEY] = int(key) if str(key).isdigit() else key
if (
SEGMENT_INDEX_KEY not in normalized
and SEGMENT_INDEX_FALLBACK_KEY not in normalized
):
normalized[SEGMENT_INDEX_KEY] = (
int(key) if str(key).isdigit() else key
)
results.append(normalized)
elif isinstance(segments, list):
for idx, segment in enumerate(segments):
@@ -272,11 +306,16 @@ class AnnotationExportService:
segment_results = segment.get(SEGMENT_RESULT_KEY)
if not isinstance(segment_results, list):
continue
segment_index = segment.get(SEGMENT_INDEX_KEY, segment.get(SEGMENT_INDEX_FALLBACK_KEY, idx))
segment_index = segment.get(
SEGMENT_INDEX_KEY, segment.get(SEGMENT_INDEX_FALLBACK_KEY, idx)
)
for item in segment_results:
if isinstance(item, dict):
normalized = dict(item)
if SEGMENT_INDEX_KEY not in normalized and SEGMENT_INDEX_FALLBACK_KEY not in normalized:
if (
SEGMENT_INDEX_KEY not in normalized
and SEGMENT_INDEX_FALLBACK_KEY not in normalized
):
normalized[SEGMENT_INDEX_KEY] = segment_index
results.append(normalized)
return results
@@ -284,18 +323,43 @@ class AnnotationExportService:
return result if isinstance(result, list) else []
@classmethod
def _normalize_annotation_for_export(cls, annotation: Dict[str, Any]) -> Dict[str, Any]:
def _normalize_annotation_for_export(
cls, annotation: Dict[str, Any]
) -> Dict[str, Any]:
if not annotation or not isinstance(annotation, dict):
return {}
segments = annotation.get(SEGMENTS_KEY)
if annotation.get(SEGMENTED_KEY) or isinstance(segments, (dict, list)):
normalized = dict(annotation)
normalized_result = cls._flatten_annotation_results(annotation)
if SEGMENT_RESULT_KEY not in normalized or not isinstance(normalized.get(SEGMENT_RESULT_KEY), list):
if SEGMENT_RESULT_KEY not in normalized or not isinstance(
normalized.get(SEGMENT_RESULT_KEY), list
):
normalized[SEGMENT_RESULT_KEY] = normalized_result
return normalized
return annotation
@staticmethod
def _get_image_dimensions(file_path: str) -> Tuple[int, int]:
"""获取图片文件的宽度和高度
Args:
file_path: 图片文件路径
Returns:
(width, height) 元组,如果读取失败则返回 (1920, 1080) 作为默认值
"""
try:
if os.path.exists(file_path):
with Image.open(file_path) as img:
width, height = img.size
return width, height
except Exception as e:
logger.warning(f"Failed to read image dimensions from {file_path}: {e}")
# 使用合理的默认值
return 1920, 1080
def _export_json(
self, items: List[AnnotationExportItem], project_name: str
) -> Tuple[bytes, str, str]:
@@ -309,9 +373,16 @@ class AnnotationExportService:
"file_id": item.file_id,
"file_name": item.file_name,
"data": item.data,
"annotations": [self._normalize_annotation_for_export(ann) for ann in item.annotations],
"created_at": item.created_at.isoformat() if item.created_at else None,
"updated_at": item.updated_at.isoformat() if item.updated_at else None,
"annotations": [
self._normalize_annotation_for_export(ann)
for ann in item.annotations
],
"created_at": item.created_at.isoformat()
if item.created_at
else None,
"updated_at": item.updated_at.isoformat()
if item.updated_at
else None,
}
for item in items
],
@@ -331,7 +402,10 @@ class AnnotationExportService:
"file_id": item.file_id,
"file_name": item.file_name,
"data": item.data,
"annotations": [self._normalize_annotation_for_export(ann) for ann in item.annotations],
"annotations": [
self._normalize_annotation_for_export(ann)
for ann in item.annotations
],
"created_at": item.created_at.isoformat() if item.created_at else None,
"updated_at": item.updated_at.isoformat() if item.updated_at else None,
}
@@ -368,7 +442,7 @@ class AnnotationExportService:
for r in results:
value = r.get("value", {})
label_type = r.get("type", "")
# 提取不同类型的标签值
if "choices" in value:
labels.extend(value["choices"])
@@ -389,14 +463,22 @@ class AnnotationExportService:
elif "transcription" in value:
labels.append(value["transcription"])
writer.writerow({
"file_id": item.file_id,
"file_name": item.file_name,
"annotation_result": json.dumps(item.annotations, ensure_ascii=False),
"labels": "|".join(labels),
"created_at": item.created_at.isoformat() if item.created_at else "",
"updated_at": item.updated_at.isoformat() if item.updated_at else "",
})
writer.writerow(
{
"file_id": item.file_id,
"file_name": item.file_name,
"annotation_result": json.dumps(
item.annotations, ensure_ascii=False
),
"labels": "|".join(labels),
"created_at": item.created_at.isoformat()
if item.created_at
else "",
"updated_at": item.updated_at.isoformat()
if item.updated_at
else "",
}
)
content = output.getvalue().encode("utf-8-sig") # BOM for Excel compatibility
filename = f"{project_name}_annotations.csv"
@@ -405,11 +487,7 @@ class AnnotationExportService:
def _export_coco(
self, items: List[AnnotationExportItem], project_name: str
) -> Tuple[bytes, str, str]:
"""导出为 COCO 格式(适用于目标检测标注)
注意:当前实现中图片宽高被设置为0,因为需要读取实际图片文件获取尺寸。
bbox 坐标使用 Label Studio 的百分比值(0-100),使用时需要转换为像素坐标。
"""
"""导出为 COCO 格式(适用于目标检测标注)"""
coco_format = COCOExportFormat(
info={
"description": f"Exported from DataMate project: {project_name}",
@@ -429,13 +507,18 @@ class AnnotationExportService:
for idx, item in enumerate(items):
image_id = idx + 1
# 获取图片实际尺寸
img_width, img_height = self._get_image_dimensions(item.file_path or "")
# 添加图片信息
coco_format.images.append({
"id": image_id,
"file_name": item.file_name,
"width": 0, # 需要实际图片尺寸
"height": 0,
})
coco_format.images.append(
{
"id": image_id,
"file_name": item.file_name,
"width": img_width,
"height": img_height,
}
)
# 处理标注
for ann in item.annotations:
@@ -449,29 +532,41 @@ class AnnotationExportService:
for label in labels:
if label not in category_map:
category_map[label] = len(category_map) + 1
coco_format.categories.append({
"id": category_map[label],
"name": label,
"supercategory": "",
})
coco_format.categories.append(
{
"id": category_map[label],
"name": label,
"supercategory": "",
}
)
# 转换坐标Label Studio 使用百分比
x = value.get("x", 0)
y = value.get("y", 0)
width = value.get("width", 0)
height = value.get("height", 0)
# 转换坐标Label Studio 使用百分比(0-100)转换为像素坐标
x_percent = value.get("x", 0)
y_percent = value.get("y", 0)
width_percent = value.get("width", 0)
height_percent = value.get("height", 0)
coco_format.annotations.append({
"id": annotation_id,
"image_id": image_id,
"category_id": category_map[label],
"bbox": [x, y, width, height],
"area": width * height,
"iscrowd": 0,
})
# 转换为像素坐标
x = x_percent * img_width / 100.0
y = y_percent * img_height / 100.0
width = width_percent * img_width / 100.0
height = height_percent * img_height / 100.0
coco_format.annotations.append(
{
"id": annotation_id,
"image_id": image_id,
"category_id": category_map[label],
"bbox": [x, y, width, height],
"area": width * height,
"iscrowd": 0,
}
)
annotation_id += 1
content = json.dumps(coco_format.model_dump(), ensure_ascii=False, indent=2).encode("utf-8")
content = json.dumps(
coco_format.model_dump(), ensure_ascii=False, indent=2
).encode("utf-8")
filename = f"{project_name}_coco.json"
return content, filename, "application/json"
@@ -510,7 +605,9 @@ class AnnotationExportService:
x_center = x + w / 2
y_center = y + h / 2
lines.append(f"{label} {x_center:.6f} {y_center:.6f} {w:.6f} {h:.6f}")
lines.append(
f"{label} {x_center:.6f} {y_center:.6f} {w:.6f} {h:.6f}"
)
if lines:
# 生成对应的 txt 文件名