fix: resolve concurrency control, data cleanup, file transaction, and COCO export issues in knowledge base sync

Problem 1 - Missing concurrency control:
- Add a database row lock (with_for_update) in the _ensure_knowledge_set method
- Update the _update_project_config method to protect configuration updates with a row lock

Problem 3 - Missing data cleanup mechanism:
- Add the _cleanup_knowledge_set_for_project method to clean up the knowledge set when a project is deleted
- Add the _cleanup_knowledge_item_for_file method to clean up knowledge items when a file is deleted
- Call the cleanup methods in the delete_mapping endpoint

Problem 4 - File operation transaction issues:
- Update uploadKnowledgeItems to clean up files after a failed transaction
- Update deleteKnowledgeItem to delete the associated file before deleting the record
- Add the deleteKnowledgeItemFile helper method

Problem 5 - COCO export format issues:
- Add the _get_image_dimensions method to read the actual image width and height
- Convert percentage coordinates to pixel coordinates
- Add a file_path field to AnnotationExportItem

Affected files:
- knowledge_sync.py
- project.py
- KnowledgeItemApplicationService.java
- export.py
- export schema.py
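For reference, the concurrency fix re-selects the project row with SELECT ... FOR UPDATE (SQLAlchemy's with_for_update()) and only mutates the locked instance before committing. A minimal, self-contained sketch of that pattern follows; the Project model and session are illustrative stand-ins, not DataMate's actual models, and the lock only takes effect on a backend with row locks such as PostgreSQL or MySQL.

from typing import Any, Dict

from sqlalchemy import JSON, String, select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column


class Base(DeclarativeBase):
    pass


class Project(Base):  # illustrative stand-in for LabelingProject
    __tablename__ = "projects"
    id: Mapped[str] = mapped_column(String, primary_key=True)
    configuration: Mapped[dict] = mapped_column(JSON, default=dict)


async def update_config_with_row_lock(
    db: AsyncSession, project_id: str, updates: Dict[str, Any]
) -> bool:
    # SELECT ... FOR UPDATE: concurrent writers block on this row until commit.
    result = await db.execute(
        select(Project).where(Project.id == project_id).with_for_update()
    )
    locked = result.scalar_one_or_none()
    if locked is None:
        return False

    # Copy the JSON column into a new dict so SQLAlchemy detects the change on assignment.
    config = dict(locked.configuration or {})
    config.update(updates)
    locked.configuration = config

    await db.commit()  # committing releases the row lock
    return True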
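The COCO fix reads the real image size with Pillow and converts Label Studio's percentage-based rectangle values (0-100) into pixel coordinates. A rough standalone sketch of that conversion, with hypothetical helper names (only Pillow's Image.open and .size are real APIs here):

import os
from typing import Dict, List, Tuple

from PIL import Image


def read_image_size(file_path: str, default: Tuple[int, int] = (1920, 1080)) -> Tuple[int, int]:
    # Fall back to a fixed default when the file is missing or unreadable.
    try:
        if os.path.exists(file_path):
            with Image.open(file_path) as img:
                return img.size  # (width, height)
    except Exception:
        pass
    return default


def percent_bbox_to_coco(value: Dict[str, float], img_w: int, img_h: int) -> List[float]:
    # Label Studio rectangles store x, y, width, height as percentages of the image size;
    # COCO expects [x_min, y_min, width, height] in pixels.
    x = value.get("x", 0) * img_w / 100.0
    y = value.get("y", 0) * img_h / 100.0
    w = value.get("width", 0) * img_w / 100.0
    h = value.get("height", 0) * img_h / 100.0
    return [x, y, w, h]


# Example: a 50%-wide box on a 2000x1000 image becomes 1000 px wide:
# percent_bbox_to_coco({"x": 10, "y": 20, "width": 50, "height": 30}, 2000, 1000)
# -> [200.0, 200.0, 1000.0, 300.0]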
@@ -21,20 +21,29 @@ from datetime import datetime
from typing import Any, Dict, List, Optional, Tuple

from fastapi import HTTPException
from PIL import Image
from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession

from app.core.logging import get_logger
from app.db.models import AnnotationResult, Dataset, DatasetFiles, LabelingProject, LabelingProjectFile
from app.db.models import (
AnnotationResult,
Dataset,
DatasetFiles,
LabelingProject,
LabelingProjectFile,
)


async def _read_file_content(file_path: str, max_size: int = 10 * 1024 * 1024) -> Optional[str]:
async def _read_file_content(
file_path: str, max_size: int = 10 * 1024 * 1024
) -> Optional[str]:
"""读取文件内容,仅适用于文本文件

Args:
file_path: 文件路径
max_size: 最大读取字节数(默认10MB)

Returns:
文件内容字符串,如果读取失败返回 None
"""
@@ -42,17 +51,18 @@ async def _read_file_content(file_path: str, max_size: int = 10 * 1024 * 1024) -
# 检查文件是否存在且大小在限制内
if not os.path.exists(file_path):
return None

file_size = os.path.getsize(file_path)
if file_size > max_size:
return f"[File too large: {file_size} bytes]"

# 尝试以文本方式读取
with open(file_path, 'r', encoding='utf-8', errors='replace') as f:
with open(file_path, "r", encoding="utf-8", errors="replace") as f:
return f.read()
except Exception:
return None


from ..schema.export import (
AnnotationExportItem,
COCOExportFormat,
@@ -79,7 +89,9 @@ class AnnotationExportService:
async def get_export_stats(self, project_id: str) -> ExportAnnotationsResponse:
"""获取导出统计信息"""
project = await self._get_project_or_404(project_id)
logger.info(f"Export stats for project: id={project_id}, dataset_id={project.dataset_id}, name={project.name}")
logger.info(
f"Export stats for project: id={project_id}, dataset_id={project.dataset_id}, name={project.name}"
)

# 获取总文件数(标注项目快照内的文件)
total_result = await self.db.execute(
@@ -92,7 +104,9 @@ class AnnotationExportService:
)
)
total_files = int(total_result.scalar() or 0)
logger.info(f"Total files (snapshot): {total_files} for project_id={project_id}")
logger.info(
f"Total files (snapshot): {total_files} for project_id={project_id}"
)

# 获取已标注文件数(统计不同的 file_id 数量)
annotated_result = await self.db.execute(
@@ -132,7 +146,11 @@ class AnnotationExportService:
)

# 根据格式导出
format_type = ExportFormat(request.format) if isinstance(request.format, str) else request.format
format_type = (
ExportFormat(request.format)
if isinstance(request.format, str)
else request.format
)

if format_type == ExportFormat.JSON:
return self._export_json(items, project.name)
@@ -145,7 +163,9 @@ class AnnotationExportService:
elif format_type == ExportFormat.YOLO:
return self._export_yolo(items, project.name)
else:
raise HTTPException(status_code=400, detail=f"不支持的导出格式: {request.format}")
raise HTTPException(
status_code=400, detail=f"不支持的导出格式: {request.format}"
)

async def _get_project_or_404(self, project_id: str) -> LabelingProject:
"""获取标注项目,不存在则抛出 404"""
@@ -174,7 +194,10 @@ class AnnotationExportService:
# 只获取已标注的数据
result = await self.db.execute(
select(AnnotationResult, DatasetFiles)
.join(LabelingProjectFile, LabelingProjectFile.file_id == AnnotationResult.file_id)
.join(
LabelingProjectFile,
LabelingProjectFile.file_id == AnnotationResult.file_id,
)
.join(DatasetFiles, AnnotationResult.file_id == DatasetFiles.id)
.where(
AnnotationResult.project_id == project_id,
@@ -192,11 +215,12 @@ class AnnotationExportService:
if include_data:
file_path = getattr(file, "file_path", "")
file_content = await _read_file_content(file_path)

items.append(
AnnotationExportItem(
file_id=str(file.id),
file_name=str(getattr(file, "file_name", "")),
file_path=str(getattr(file, "file_path", "")),
data={"text": file_content} if include_data else None,
annotations=[annotation_data] if annotation_data else [],
created_at=ann.created_at,
@@ -207,7 +231,9 @@ class AnnotationExportService:
# 获取所有文件(基于标注项目快照)
files_result = await self.db.execute(
select(DatasetFiles)
.join(LabelingProjectFile, LabelingProjectFile.file_id == DatasetFiles.id)
.join(
LabelingProjectFile, LabelingProjectFile.file_id == DatasetFiles.id
)
.where(
LabelingProjectFile.project_id == project_id,
DatasetFiles.dataset_id == dataset_id,
@@ -217,7 +243,9 @@ class AnnotationExportService:
# 获取已有的标注
ann_result = await self.db.execute(
select(AnnotationResult).where(AnnotationResult.project_id == project_id)
select(AnnotationResult).where(
AnnotationResult.project_id == project_id
)
)
annotations = {str(a.file_id): a for a in ann_result.scalars().all()}

@@ -225,7 +253,7 @@ class AnnotationExportService:
file_id = str(file.id)
ann = annotations.get(file_id)
annotation_data = ann.annotation if ann else {}

# 获取文件内容(如果是文本文件且用户要求包含数据)
file_content = None
if include_data:
@@ -236,6 +264,7 @@ class AnnotationExportService:
AnnotationExportItem(
file_id=file_id,
file_name=str(getattr(file, "file_name", "")),
file_path=str(getattr(file, "file_path", "")),
data={"text": file_content} if include_data else None,
annotations=[annotation_data] if annotation_data else [],
created_at=ann.created_at if ann else None,
@@ -262,8 +291,13 @@ class AnnotationExportService:
for item in segment_results:
if isinstance(item, dict):
normalized = dict(item)
if SEGMENT_INDEX_KEY not in normalized and SEGMENT_INDEX_FALLBACK_KEY not in normalized:
normalized[SEGMENT_INDEX_KEY] = int(key) if str(key).isdigit() else key
if (
SEGMENT_INDEX_KEY not in normalized
and SEGMENT_INDEX_FALLBACK_KEY not in normalized
):
normalized[SEGMENT_INDEX_KEY] = (
int(key) if str(key).isdigit() else key
)
results.append(normalized)
elif isinstance(segments, list):
for idx, segment in enumerate(segments):
@@ -272,11 +306,16 @@ class AnnotationExportService:
segment_results = segment.get(SEGMENT_RESULT_KEY)
if not isinstance(segment_results, list):
continue
segment_index = segment.get(SEGMENT_INDEX_KEY, segment.get(SEGMENT_INDEX_FALLBACK_KEY, idx))
segment_index = segment.get(
SEGMENT_INDEX_KEY, segment.get(SEGMENT_INDEX_FALLBACK_KEY, idx)
)
for item in segment_results:
if isinstance(item, dict):
normalized = dict(item)
if SEGMENT_INDEX_KEY not in normalized and SEGMENT_INDEX_FALLBACK_KEY not in normalized:
if (
SEGMENT_INDEX_KEY not in normalized
and SEGMENT_INDEX_FALLBACK_KEY not in normalized
):
normalized[SEGMENT_INDEX_KEY] = segment_index
results.append(normalized)
return results
@@ -284,18 +323,43 @@ class AnnotationExportService:
return result if isinstance(result, list) else []

@classmethod
def _normalize_annotation_for_export(cls, annotation: Dict[str, Any]) -> Dict[str, Any]:
def _normalize_annotation_for_export(
cls, annotation: Dict[str, Any]
) -> Dict[str, Any]:
if not annotation or not isinstance(annotation, dict):
return {}
segments = annotation.get(SEGMENTS_KEY)
if annotation.get(SEGMENTED_KEY) or isinstance(segments, (dict, list)):
normalized = dict(annotation)
normalized_result = cls._flatten_annotation_results(annotation)
if SEGMENT_RESULT_KEY not in normalized or not isinstance(normalized.get(SEGMENT_RESULT_KEY), list):
if SEGMENT_RESULT_KEY not in normalized or not isinstance(
normalized.get(SEGMENT_RESULT_KEY), list
):
normalized[SEGMENT_RESULT_KEY] = normalized_result
return normalized
return annotation

@staticmethod
def _get_image_dimensions(file_path: str) -> Tuple[int, int]:
"""获取图片文件的宽度和高度

Args:
file_path: 图片文件路径

Returns:
(width, height) 元组,如果读取失败则返回 (1920, 1080) 作为默认值
"""
try:
if os.path.exists(file_path):
with Image.open(file_path) as img:
width, height = img.size
return width, height
except Exception as e:
logger.warning(f"Failed to read image dimensions from {file_path}: {e}")

# 使用合理的默认值
return 1920, 1080

def _export_json(
self, items: List[AnnotationExportItem], project_name: str
) -> Tuple[bytes, str, str]:
@@ -309,9 +373,16 @@ class AnnotationExportService:
"file_id": item.file_id,
"file_name": item.file_name,
"data": item.data,
"annotations": [self._normalize_annotation_for_export(ann) for ann in item.annotations],
"created_at": item.created_at.isoformat() if item.created_at else None,
"updated_at": item.updated_at.isoformat() if item.updated_at else None,
"annotations": [
self._normalize_annotation_for_export(ann)
for ann in item.annotations
],
"created_at": item.created_at.isoformat()
if item.created_at
else None,
"updated_at": item.updated_at.isoformat()
if item.updated_at
else None,
}
for item in items
],
@@ -331,7 +402,10 @@ class AnnotationExportService:
"file_id": item.file_id,
"file_name": item.file_name,
"data": item.data,
"annotations": [self._normalize_annotation_for_export(ann) for ann in item.annotations],
"annotations": [
self._normalize_annotation_for_export(ann)
for ann in item.annotations
],
"created_at": item.created_at.isoformat() if item.created_at else None,
"updated_at": item.updated_at.isoformat() if item.updated_at else None,
}
@@ -368,7 +442,7 @@ class AnnotationExportService:
for r in results:
value = r.get("value", {})
label_type = r.get("type", "")

# 提取不同类型的标签值
if "choices" in value:
labels.extend(value["choices"])
@@ -389,14 +463,22 @@ class AnnotationExportService:
elif "transcription" in value:
labels.append(value["transcription"])

writer.writerow({
"file_id": item.file_id,
"file_name": item.file_name,
"annotation_result": json.dumps(item.annotations, ensure_ascii=False),
"labels": "|".join(labels),
"created_at": item.created_at.isoformat() if item.created_at else "",
"updated_at": item.updated_at.isoformat() if item.updated_at else "",
})
writer.writerow(
{
"file_id": item.file_id,
"file_name": item.file_name,
"annotation_result": json.dumps(
item.annotations, ensure_ascii=False
),
"labels": "|".join(labels),
"created_at": item.created_at.isoformat()
if item.created_at
else "",
"updated_at": item.updated_at.isoformat()
if item.updated_at
else "",
}
)

content = output.getvalue().encode("utf-8-sig") # BOM for Excel compatibility
filename = f"{project_name}_annotations.csv"
@@ -405,11 +487,7 @@ class AnnotationExportService:
def _export_coco(
self, items: List[AnnotationExportItem], project_name: str
) -> Tuple[bytes, str, str]:
"""导出为 COCO 格式(适用于目标检测标注)

注意:当前实现中图片宽高被设置为0,因为需要读取实际图片文件获取尺寸。
bbox 坐标使用 Label Studio 的百分比值(0-100),使用时需要转换为像素坐标。
"""
"""导出为 COCO 格式(适用于目标检测标注)"""
coco_format = COCOExportFormat(
info={
"description": f"Exported from DataMate project: {project_name}",
@@ -429,13 +507,18 @@ class AnnotationExportService:
for idx, item in enumerate(items):
image_id = idx + 1

# 获取图片实际尺寸
img_width, img_height = self._get_image_dimensions(item.file_path or "")

# 添加图片信息
coco_format.images.append({
"id": image_id,
"file_name": item.file_name,
"width": 0, # 需要实际图片尺寸
"height": 0,
})
coco_format.images.append(
{
"id": image_id,
"file_name": item.file_name,
"width": img_width,
"height": img_height,
}
)

# 处理标注
for ann in item.annotations:
@@ -449,29 +532,41 @@ class AnnotationExportService:
for label in labels:
if label not in category_map:
category_map[label] = len(category_map) + 1
coco_format.categories.append({
"id": category_map[label],
"name": label,
"supercategory": "",
})
coco_format.categories.append(
{
"id": category_map[label],
"name": label,
"supercategory": "",
}
)

# 转换坐标(Label Studio 使用百分比)
x = value.get("x", 0)
y = value.get("y", 0)
width = value.get("width", 0)
height = value.get("height", 0)
# 转换坐标:Label Studio 使用百分比(0-100)转换为像素坐标
x_percent = value.get("x", 0)
y_percent = value.get("y", 0)
width_percent = value.get("width", 0)
height_percent = value.get("height", 0)

coco_format.annotations.append({
"id": annotation_id,
"image_id": image_id,
"category_id": category_map[label],
"bbox": [x, y, width, height],
"area": width * height,
"iscrowd": 0,
})
# 转换为像素坐标
x = x_percent * img_width / 100.0
y = y_percent * img_height / 100.0
width = width_percent * img_width / 100.0
height = height_percent * img_height / 100.0

coco_format.annotations.append(
{
"id": annotation_id,
"image_id": image_id,
"category_id": category_map[label],
"bbox": [x, y, width, height],
"area": width * height,
"iscrowd": 0,
}
)
annotation_id += 1

content = json.dumps(coco_format.model_dump(), ensure_ascii=False, indent=2).encode("utf-8")
content = json.dumps(
coco_format.model_dump(), ensure_ascii=False, indent=2
).encode("utf-8")
filename = f"{project_name}_coco.json"
return content, filename, "application/json"
@@ -510,7 +605,9 @@ class AnnotationExportService:
x_center = x + w / 2
y_center = y + h / 2

lines.append(f"{label} {x_center:.6f} {y_center:.6f} {w:.6f} {h:.6f}")
lines.append(
f"{label} {x_center:.6f} {y_center:.6f} {w:.6f} {h:.6f}"
)

if lines:
# 生成对应的 txt 文件名

@@ -43,7 +43,9 @@ class KnowledgeSyncService:
logger.warning("标注同步失败:无法获取知识集")
return

item = await self._get_item_by_source(set_id, project.dataset_id, str(file_record.id))
item = await self._get_item_by_source(
set_id, project.dataset_id, str(file_record.id)
)
if item and item.get("status") in {"PUBLISHED", "ARCHIVED", "DEPRECATED"}:
logger.info(
"知识条目为只读状态,跳过同步:item_id=%s status=%s",
@@ -71,26 +73,46 @@ class KnowledgeSyncService:
logger.warning("标注同步到知识管理失败:%s", exc)

async def _ensure_knowledge_set(self, project: LabelingProject) -> Optional[str]:
config = project.configuration if isinstance(project.configuration, dict) else {}
result = await self.db.execute(
select(LabelingProject)
.where(LabelingProject.id == project.id)
.with_for_update()
)
locked_project = result.scalar_one_or_none()
if not locked_project:
logger.warning("标注同步失败:无法锁定项目:project_id=%s", project.id)
return None

config = (
locked_project.configuration
if isinstance(locked_project.configuration, dict)
else {}
)
set_id = config.get(self.CONFIG_KEY_SET_ID)

if set_id:
exists = await self._get_knowledge_set(set_id)
if exists and self._metadata_matches_project(exists.get("metadata"), project.id):
if exists and self._metadata_matches_project(
exists.get("metadata"), locked_project.id
):
return set_id
logger.warning(
"知识集不存在或归属不匹配,准备重建:set_id=%s project_id=%s",
set_id,
project.id,
locked_project.id,
)

project_name = (project.name or "annotation-project").strip() or "annotation-project"
metadata = self._build_set_metadata(project)
project_name = (
locked_project.name or "annotation-project"
).strip() or "annotation-project"
metadata = self._build_set_metadata(locked_project)

existing = await self._find_knowledge_set_by_name_and_project(project_name, project.id)
existing = await self._find_knowledge_set_by_name_and_project(
project_name, locked_project.id
)
if existing:
await self._update_project_config(
project,
locked_project,
{
self.CONFIG_KEY_SET_ID: existing.get("id"),
self.CONFIG_KEY_SET_NAME: existing.get("name"),
@@ -100,23 +122,31 @@ class KnowledgeSyncService:

created = await self._create_knowledge_set(project_name, metadata)
if not created:
created = await self._find_knowledge_set_by_name_and_project(project_name, project.id)
created = await self._find_knowledge_set_by_name_and_project(
project_name, locked_project.id
)

if not created:
fallback_name = self._build_fallback_set_name(project_name, project.id)
existing = await self._find_knowledge_set_by_name_and_project(fallback_name, project.id)
fallback_name = self._build_fallback_set_name(
project_name, locked_project.id
)
existing = await self._find_knowledge_set_by_name_and_project(
fallback_name, locked_project.id
)
if existing:
created = existing
else:
created = await self._create_knowledge_set(fallback_name, metadata)
if not created:
created = await self._find_knowledge_set_by_name_and_project(fallback_name, project.id)
created = await self._find_knowledge_set_by_name_and_project(
fallback_name, locked_project.id
)

if not created:
return None

await self._update_project_config(
project,
locked_project,
{
self.CONFIG_KEY_SET_ID: created.get("id"),
self.CONFIG_KEY_SET_NAME: created.get("name"),
@@ -126,13 +156,17 @@ class KnowledgeSyncService:

async def _get_knowledge_set(self, set_id: str) -> Optional[Dict[str, Any]]:
try:
return await self._request("GET", f"/data-management/knowledge-sets/{set_id}")
return await self._request(
"GET", f"/data-management/knowledge-sets/{set_id}"
)
except httpx.HTTPStatusError as exc:
if exc.response.status_code == 404:
return None
raise

async def _list_knowledge_sets(self, keyword: Optional[str]) -> list[Dict[str, Any]]:
async def _list_knowledge_sets(
self, keyword: Optional[str]
) -> list[Dict[str, Any]]:
params: Dict[str, Any] = {
"page": 1,
"size": self.KNOWLEDGE_SET_LIST_SIZE,
@@ -140,7 +174,9 @@ class KnowledgeSyncService:
if keyword:
params["keyword"] = keyword
try:
data = await self._request("GET", "/data-management/knowledge-sets", params=params)
data = await self._request(
"GET", "/data-management/knowledge-sets", params=params
)
except httpx.HTTPStatusError as exc:
logger.warning(
"查询知识集失败:keyword=%s status=%s",
@@ -155,7 +191,9 @@ class KnowledgeSyncService:
return []
return [item for item in content if isinstance(item, dict)]

async def _find_knowledge_set_by_name_and_project(self, name: str, project_id: str) -> Optional[Dict[str, Any]]:
async def _find_knowledge_set_by_name_and_project(
self, name: str, project_id: str
) -> Optional[Dict[str, Any]]:
if not name:
return None
items = await self._list_knowledge_sets(name)
@@ -168,7 +206,9 @@ class KnowledgeSyncService:
return item
return None

async def _create_knowledge_set(self, name: str, metadata: str) -> Optional[Dict[str, Any]]:
async def _create_knowledge_set(
self, name: str, metadata: str
) -> Optional[Dict[str, Any]]:
payload = {
"name": name,
"description": "标注项目自动创建的知识集",
@@ -176,7 +216,9 @@ class KnowledgeSyncService:
"metadata": metadata,
}
try:
return await self._request("POST", "/data-management/knowledge-sets", json=payload)
return await self._request(
"POST", "/data-management/knowledge-sets", json=payload
)
except httpx.HTTPStatusError as exc:
logger.warning(
"创建知识集失败:name=%s status=%s detail=%s",
@@ -199,7 +241,9 @@ class KnowledgeSyncService:
"sourceFileId": file_id,
}
try:
data = await self._request("GET", f"/data-management/knowledge-sets/{set_id}/items", params=params)
data = await self._request(
"GET", f"/data-management/knowledge-sets/{set_id}/items", params=params
)
except httpx.HTTPStatusError as exc:
logger.warning(
"查询知识条目失败:set_id=%s status=%s",
@@ -216,9 +260,13 @@ class KnowledgeSyncService:
return content[0]

async def _create_item(self, set_id: str, payload: Dict[str, Any]) -> None:
await self._request("POST", f"/data-management/knowledge-sets/{set_id}/items", json=payload)
await self._request(
"POST", f"/data-management/knowledge-sets/{set_id}/items", json=payload
)

async def _update_item(self, set_id: str, item_id: str, payload: Dict[str, Any]) -> None:
async def _update_item(
self, set_id: str, item_id: str, payload: Dict[str, Any]
) -> None:
update_payload = dict(payload)
update_payload.pop("sourceDatasetId", None)
update_payload.pop("sourceFileId", None)
@@ -228,6 +276,62 @@ class KnowledgeSyncService:
json=update_payload,
)

async def _cleanup_knowledge_set_for_project(self, project_id: str) -> None:
"""清理项目关联的知识集及其所有知识条目"""
items = await self._list_knowledge_sets(None)
for item in items:
if self._metadata_matches_project(item.get("metadata"), project_id):
set_id = item.get("id")
if not set_id:
continue
try:
await self._request(
"DELETE", f"/data-management/knowledge-sets/{set_id}"
)
logger.info(
"已删除知识集:set_id=%s project_id=%s", set_id, project_id
)
except Exception as exc:
logger.warning(
"删除知识集失败:set_id=%s project_id=%s error=%s",
set_id,
project_id,
exc,
)

async def _cleanup_knowledge_item_for_file(
self, dataset_id: str, file_id: str
) -> None:
"""清理文件的知识条目"""
items = await self._list_knowledge_sets(None)
for set_item in items:
set_id = set_item.get("id")
if not set_id:
continue
item = await self._get_item_by_source(set_id, dataset_id, file_id)
if item and item.get("id"):
try:
await self._request(
"DELETE",
f"/data-management/knowledge-sets/{set_id}/items/{item['id']}",
)
logger.info(
"已删除知识条目:item_id=%s set_id=%s dataset_id=%s file_id=%s",
item.get("id"),
set_id,
dataset_id,
file_id,
)
except Exception as exc:
logger.warning(
"删除知识条目失败:item_id=%s set_id=%s dataset_id=%s file_id=%s error=%s",
item.get("id"),
set_id,
dataset_id,
file_id,
exc,
)

async def _build_item_payload(
self,
project: LabelingProject,
@@ -323,12 +427,28 @@ class KnowledgeSyncService:
short_id = project_id.replace("-", "")[:8]
return f"{base_name}-annotation-{short_id}"

async def _update_project_config(self, project: LabelingProject, updates: Dict[str, Any]) -> None:
config = project.configuration if isinstance(project.configuration, dict) else {}
async def _update_project_config(
self, project: LabelingProject, updates: Dict[str, Any]
) -> None:
result = await self.db.execute(
select(LabelingProject)
.where(LabelingProject.id == project.id)
.with_for_update()
)
locked_project = result.scalar_one_or_none()
if not locked_project:
logger.warning("更新项目配置失败:无法锁定项目:project_id=%s", project.id)
return

config = (
locked_project.configuration
if isinstance(locked_project.configuration, dict)
else {}
)
config.update(updates)
project.configuration = config
locked_project.configuration = config
await self.db.commit()
await self.db.refresh(project)
await self.db.refresh(locked_project)

async def _request(self, method: str, path: str, **kwargs) -> Any:
url = f"{self.base_url}{path}"