Files
DataMate/runtime/datamate-python/app/module/annotation/service/export.py
Jerry Yan c5c8e6c69e feat(annotation): 添加分段标注功能支持
- 定义分段标注相关常量(segmented、segments、result等键名)
- 实现分段标注提取方法_extract_segment_annotations处理字典和列表格式
- 添加分段标注判断方法_is_segmented_annotation检测标注状态
- 修改_has_annotation_result方法使用新的分段标注处理逻辑
- 在任务创建过程中集成分段标注数据处理
- 更新导出服务中的分段标注结果扁平化处理
- 实现标注归一化方法支持分段标注格式转换
- 调整JSON和CSV导出格式适配分段标注结构
2026-01-31 14:36:16 +08:00

555 lines
22 KiB
Python

"""
标注数据导出服务
支持的导出格式:
- JSON: Label Studio 原生 JSON 格式
- JSONL: JSON Lines 格式(每行一条记录)
- CSV: CSV 表格格式
- COCO: COCO 目标检测格式(适用于图像标注)
- YOLO: YOLO 格式(适用于图像标注)
"""
from __future__ import annotations

import asyncio
import csv
import io
import json
import os
import tempfile
import zipfile
from datetime import datetime
from typing import Any, Dict, List, Optional, Tuple

from fastapi import HTTPException
from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession

from app.core.logging import get_logger
from app.db.models import AnnotationResult, Dataset, DatasetFiles, LabelingProject, LabelingProjectFile
def _read_file_content_sync(file_path: str, max_size: int) -> Optional[str]:
    """Blocking implementation of ``_read_file_content`` (runs in a worker thread)."""
    if not file_path:
        return None
    try:
        # A single stat call replaces the separate exists()/getsize() checks; a missing
        # file surfaces as FileNotFoundError (an OSError) and is handled below.
        file_size = os.path.getsize(file_path)
        if file_size > max_size:
            return f"[File too large: {file_size} bytes]"
        # errors="replace" keeps undecodable bytes from raising; they become U+FFFD.
        with open(file_path, 'r', encoding='utf-8', errors='replace') as f:
            return f.read()
    except (OSError, ValueError):
        # OSError: missing file, permission denied, path is a directory, ...
        # ValueError: e.g. an embedded NUL character in the path.
        return None


async def _read_file_content(file_path: str, max_size: int = 10 * 1024 * 1024) -> Optional[str]:
    """Read a file as UTF-8 text without blocking the event loop.

    Intended for text files only: bytes that are not valid UTF-8 are replaced rather
    than raising.

    Args:
        file_path: Path of the file to read. Empty or ``None`` yields ``None``.
        max_size: Maximum file size in bytes (default 10 MB). Larger files are not read.

    Returns:
        The file content. If the file exceeds ``max_size``, returns a placeholder
        ``"[File too large: <size> bytes]"``. Returns ``None`` if the file is missing
        or cannot be read.
    """
    # The disk read is blocking; run it in the default executor so exporting many
    # files does not stall every other coroutine on the loop.
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(None, _read_file_content_sync, file_path, max_size)
from ..schema.export import (
AnnotationExportItem,
COCOExportFormat,
ExportAnnotationsRequest,
ExportAnnotationsResponse,
ExportFormat,
)
logger = get_logger(__name__)
# Keys of the segmented-annotation layout handled by AnnotationExportService.
# Truthy top-level flag marking an annotation as segmented.
SEGMENTED_KEY = "segmented"
# Top-level container of per-segment data: a dict keyed by segment index, or a list of segments.
SEGMENTS_KEY = "segments"
# List of result entries; used both at the top level of an annotation and inside each segment.
SEGMENT_RESULT_KEY = "result"
# Index attached to each flattened result entry to record which segment it came from.
SEGMENT_INDEX_KEY = "segmentIndex"
# snake_case alternative to SEGMENT_INDEX_KEY; entries that already carry either key keep their own value.
SEGMENT_INDEX_FALLBACK_KEY = "segment_index"
class AnnotationExportService:
    """Export the annotation data of a labeling project.

    Supported formats: JSON, JSONL, CSV, COCO and YOLO (see ``export_annotations``).
    Only files that are part of the project's file snapshot (``LabelingProjectFile``)
    and belong to the project's dataset are considered.
    """

    # Keys of a result ``value`` dict whose contents feed the CSV ``labels`` column, in
    # priority order (only the first key present is used). The label-list keys come
    # before "text" because span-style results may carry both the covered ``text`` and a
    # ``labels`` list; the column should hold the label names, not the span text.
    _CSV_LABEL_VALUE_KEYS = (
        "choices",
        "labels",
        "rectanglelabels",
        "polygonlabels",
        "brushlabels",
        "hypertextlabels",
        "timeserieslabels",
        "text",
        "transcription",
    )

    def __init__(self, db: AsyncSession):
        self.db = db

    async def get_export_stats(self, project_id: str) -> ExportAnnotationsResponse:
        """Return export statistics for a project.

        ``total_files`` counts snapshot files of the project that belong to its dataset.
        ``annotated_files`` counts distinct files with an annotation result, restricted
        by the same snapshot/dataset joins as ``total_files`` and as the
        ``only_annotated`` export, so the numbers are consistent with each other.

        Raises:
            HTTPException: 404 if the project does not exist or is soft-deleted.
        """
        project = await self._get_project_or_404(project_id)
        logger.info(f"Export stats for project: id={project_id}, dataset_id={project.dataset_id}, name={project.name}")

        # Total files in the project snapshot
        total_result = await self.db.execute(
            select(func.count())
            .select_from(LabelingProjectFile)
            .join(DatasetFiles, LabelingProjectFile.file_id == DatasetFiles.id)
            .where(
                LabelingProjectFile.project_id == project_id,
                DatasetFiles.dataset_id == project.dataset_id,
            )
        )
        total_files = int(total_result.scalar() or 0)
        logger.info(f"Total files (snapshot): {total_files} for project_id={project_id}")

        # Distinct annotated files, limited to the same snapshot/dataset as above so that
        # results for files removed from the snapshot are not counted.
        annotated_result = await self.db.execute(
            select(func.count(func.distinct(AnnotationResult.file_id)))
            .select_from(AnnotationResult)
            .join(LabelingProjectFile, LabelingProjectFile.file_id == AnnotationResult.file_id)
            .join(DatasetFiles, AnnotationResult.file_id == DatasetFiles.id)
            .where(
                AnnotationResult.project_id == project_id,
                LabelingProjectFile.project_id == project_id,
                DatasetFiles.dataset_id == project.dataset_id,
            )
        )
        annotated_files = int(annotated_result.scalar() or 0)
        logger.info(f"Annotated files: {annotated_files} for project_id={project_id}")

        return ExportAnnotationsResponse(
            project_id=project_id,
            project_name=project.name,
            total_files=total_files,
            annotated_files=annotated_files,
            export_format="json",
        )

    async def export_annotations(
        self,
        project_id: str,
        request: ExportAnnotationsRequest,
    ) -> Tuple[bytes, str, str]:
        """Export the annotation data of a project.

        Returns:
            ``(file content bytes, file name, content type)``.

        Raises:
            HTTPException: 404 if the project does not exist; 400 if ``request.format``
                is not a supported export format. The format is validated before any
                annotation data is loaded.
        """
        project = await self._get_project_or_404(project_id)

        exporters = {
            ExportFormat.JSON: self._export_json,
            ExportFormat.JSONL: self._export_jsonl,
            ExportFormat.CSV: self._export_csv,
            ExportFormat.COCO: self._export_coco,
            ExportFormat.YOLO: self._export_yolo,
        }
        try:
            format_type = ExportFormat(request.format) if isinstance(request.format, str) else request.format
        except ValueError:
            # Assumes ExportFormat is an Enum, whose constructor raises ValueError for an
            # unknown value; report it as the 400 below instead of an unhandled 500.
            format_type = None
        exporter = exporters.get(format_type)
        if exporter is None:
            raise HTTPException(status_code=400, detail=f"不支持的导出格式: {request.format}")

        items = await self._fetch_annotation_data(
            project_id=project_id,
            dataset_id=project.dataset_id,
            only_annotated=request.only_annotated,
            include_data=request.include_data,
        )
        return exporter(items, project.name)

    async def _get_project_or_404(self, project_id: str) -> LabelingProject:
        """Return the non-deleted labeling project, or raise HTTP 404 if it does not exist."""
        result = await self.db.execute(
            select(LabelingProject).where(
                LabelingProject.id == project_id,
                LabelingProject.deleted_at.is_(None),
            )
        )
        project = result.scalar_one_or_none()
        if not project:
            raise HTTPException(status_code=404, detail=f"标注项目不存在: {project_id}")
        return project

    async def _fetch_annotation_data(
        self,
        project_id: str,
        dataset_id: str,
        only_annotated: bool = True,
        include_data: bool = False,
    ) -> List[AnnotationExportItem]:
        """Load the export items of a project.

        Args:
            project_id: Labeling project id.
            dataset_id: Dataset the project's files must belong to.
            only_annotated: If True, return one item per annotation result (newest
                ``updated_at`` first). Otherwise return one item per snapshot file, with
                an empty annotation list for files that have no result.
            include_data: If True, attach ``{"text": <file content>}`` as ``data``.

        Returns:
            The list of ``AnnotationExportItem`` objects.
        """
        pairs: List[Tuple[Any, Any]] = []
        if only_annotated:
            result = await self.db.execute(
                select(AnnotationResult, DatasetFiles)
                .join(LabelingProjectFile, LabelingProjectFile.file_id == AnnotationResult.file_id)
                .join(DatasetFiles, AnnotationResult.file_id == DatasetFiles.id)
                .where(
                    AnnotationResult.project_id == project_id,
                    LabelingProjectFile.project_id == project_id,
                    DatasetFiles.dataset_id == dataset_id,
                )
                .order_by(AnnotationResult.updated_at.desc())
            )
            pairs = [(file, ann) for ann, file in result.all()]
        else:
            # All files of the project snapshot
            files_result = await self.db.execute(
                select(DatasetFiles)
                .join(LabelingProjectFile, LabelingProjectFile.file_id == DatasetFiles.id)
                .where(
                    LabelingProjectFile.project_id == project_id,
                    DatasetFiles.dataset_id == dataset_id,
                )
            )
            files = files_result.scalars().all()
            ann_result = await self.db.execute(
                select(AnnotationResult).where(AnnotationResult.project_id == project_id)
            )
            annotations = {str(a.file_id): a for a in ann_result.scalars().all()}
            pairs = [(file, annotations.get(str(file.id))) for file in files]

        items: List[AnnotationExportItem] = []
        for file, ann in pairs:
            items.append(await self._build_export_item(file, ann, include_data))
        return items

    async def _build_export_item(
        self,
        file: DatasetFiles,
        ann: Optional[AnnotationResult],
        include_data: bool,
    ) -> AnnotationExportItem:
        """Build one export item from a dataset file and its (optional) annotation result."""
        annotation_data = (ann.annotation if ann else None) or {}
        data = None
        if include_data:
            # Read the file content (meant for text files) only when the caller asks for it.
            file_content = await _read_file_content(getattr(file, "file_path", "") or "")
            data = {"text": file_content}
        return AnnotationExportItem(
            file_id=str(file.id),
            # `or ""` keeps a NULL file name from being exported as the literal "None".
            file_name=str(getattr(file, "file_name", None) or ""),
            data=data,
            annotations=[annotation_data] if annotation_data else [],
            created_at=ann.created_at if ann else None,
            updated_at=ann.updated_at if ann else None,
        )

    @staticmethod
    def _flatten_annotation_results(annotation: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Flatten an annotation into a single list of result entries.

        If ``annotation[SEGMENTS_KEY]`` is a dict (keyed by segment index) or a list, the
        ``result`` lists of all dict segments are concatenated. Each entry is tagged with
        its segment index: the segment's own ``segmentIndex`` / ``segment_index`` if
        present, otherwise the dict key (converted to int when it is all digits) or the
        list position. Dict entries that already carry an index key keep it; non-dict
        entries are wrapped as ``{"value": item, "segmentIndex": index}``.

        Otherwise, including a ``segmented`` flag without a usable segment container,
        the top-level ``result`` list is returned if it is a list, else ``[]``.
        """
        if not annotation or not isinstance(annotation, dict):
            return []

        segments = annotation.get(SEGMENTS_KEY)
        if isinstance(segments, dict):
            indexed_segments = [
                (int(key) if str(key).isdigit() else key, segment)
                for key, segment in segments.items()
            ]
        elif isinstance(segments, list):
            indexed_segments = list(enumerate(segments))
        else:
            # Not segmented, or flagged as segmented without segment data: use the flat
            # top-level result, as _normalize_annotation_for_export also keeps it.
            result = annotation.get(SEGMENT_RESULT_KEY)
            return result if isinstance(result, list) else []

        results: List[Dict[str, Any]] = []
        for default_index, segment in indexed_segments:
            if not isinstance(segment, dict):
                continue
            segment_results = segment.get(SEGMENT_RESULT_KEY)
            if not isinstance(segment_results, list):
                continue
            segment_index = segment.get(
                SEGMENT_INDEX_KEY, segment.get(SEGMENT_INDEX_FALLBACK_KEY, default_index)
            )
            for item in segment_results:
                if not isinstance(item, dict):
                    results.append({"value": item, SEGMENT_INDEX_KEY: segment_index})
                    continue
                normalized = dict(item)
                if SEGMENT_INDEX_KEY not in normalized and SEGMENT_INDEX_FALLBACK_KEY not in normalized:
                    normalized[SEGMENT_INDEX_KEY] = segment_index
                results.append(normalized)
        return results

    @classmethod
    def _normalize_annotation_for_export(cls, annotation: Dict[str, Any]) -> Dict[str, Any]:
        """Return the annotation in export form.

        Empty or non-dict input yields ``{}``. Non-segmented annotations are returned
        unchanged (same object). For segmented annotations (``segmented`` flag set, or
        ``segments`` is a dict/list) a shallow copy is returned. If the annotation does
        not already carry a ``result`` list, the copy's ``result`` is set to the
        flattened segment results.
        """
        if not annotation or not isinstance(annotation, dict):
            return {}
        segments = annotation.get(SEGMENTS_KEY)
        if not (annotation.get(SEGMENTED_KEY) or isinstance(segments, (dict, list))):
            return annotation
        normalized = dict(annotation)
        if not isinstance(normalized.get(SEGMENT_RESULT_KEY), list):
            normalized[SEGMENT_RESULT_KEY] = cls._flatten_annotation_results(annotation)
        return normalized

    def _serialize_item(self, item: AnnotationExportItem) -> Dict[str, Any]:
        """Build the JSON-serializable record shared by the JSON and JSONL exports."""
        return {
            "file_id": item.file_id,
            "file_name": item.file_name,
            "data": item.data,
            "annotations": [self._normalize_annotation_for_export(ann) for ann in item.annotations],
            "created_at": item.created_at.isoformat() if item.created_at else None,
            "updated_at": item.updated_at.isoformat() if item.updated_at else None,
        }

    def _export_json(
        self, items: List[AnnotationExportItem], project_name: str
    ) -> Tuple[bytes, str, str]:
        """Export as a single JSON document with project metadata and all records."""
        export_data = {
            "project_name": project_name,
            "export_time": datetime.utcnow().isoformat() + "Z",
            "total_items": len(items),
            "annotations": [self._serialize_item(item) for item in items],
        }
        content = json.dumps(export_data, ensure_ascii=False, indent=2).encode("utf-8")
        filename = f"{project_name}_annotations.json"
        return content, filename, "application/json"

    def _export_jsonl(
        self, items: List[AnnotationExportItem], project_name: str
    ) -> Tuple[bytes, str, str]:
        """Export as JSON Lines: one record per line, joined with newlines."""
        lines = [json.dumps(self._serialize_item(item), ensure_ascii=False) for item in items]
        content = "\n".join(lines).encode("utf-8")
        filename = f"{project_name}_annotations.jsonl"
        return content, filename, "application/x-ndjson"

    @classmethod
    def _extract_labels(cls, result: Any) -> List[str]:
        """Return the label strings carried by one flattened result entry.

        Only the first key of ``_CSV_LABEL_VALUE_KEYS`` present in ``result["value"]`` is
        used. List values are expanded, scalar values become one entry, ``None`` entries
        are skipped, and everything is stringified so the caller can ``"|".join`` safely.
        Entries that are not dicts, or whose ``value`` is not a dict, yield ``[]``.
        """
        if not isinstance(result, dict):
            return []
        value = result.get("value")
        if not isinstance(value, dict):
            # A membership test on a str value would be a substring match, so such values
            # cannot be treated as label containers.
            return []
        for key in cls._CSV_LABEL_VALUE_KEYS:
            if key in value:
                raw = value[key]
                entries = raw if isinstance(raw, (list, tuple)) else [raw]
                return [str(entry) for entry in entries if entry is not None]
        return []

    def _export_csv(
        self, items: List[AnnotationExportItem], project_name: str
    ) -> Tuple[bytes, str, str]:
        """Export as CSV, one row per item.

        The ``annotation_result`` column holds the normalized annotations as JSON, the
        same form as the JSON/JSONL exports. The ``labels`` column holds the
        ``|``-joined labels of all flattened results. Output is UTF-8 with a BOM.
        """
        output = io.StringIO()
        fieldnames = [
            "file_id",
            "file_name",
            "annotation_result",
            "labels",
            "created_at",
            "updated_at",
        ]
        writer = csv.DictWriter(output, fieldnames=fieldnames)
        writer.writeheader()

        for item in items:
            labels: List[str] = []
            for ann in item.annotations:
                for result in self._flatten_annotation_results(ann):
                    labels.extend(self._extract_labels(result))

            writer.writerow({
                "file_id": item.file_id,
                "file_name": item.file_name,
                "annotation_result": json.dumps(
                    [self._normalize_annotation_for_export(ann) for ann in item.annotations],
                    ensure_ascii=False,
                ),
                "labels": "|".join(labels),
                "created_at": item.created_at.isoformat() if item.created_at else "",
                "updated_at": item.updated_at.isoformat() if item.updated_at else "",
            })

        content = output.getvalue().encode("utf-8-sig")  # BOM for Excel compatibility
        filename = f"{project_name}_annotations.csv"
        return content, filename, "text/csv"

    @staticmethod
    def _rectangle_value(result: Any) -> Optional[Dict[str, Any]]:
        """Return ``result["value"]`` for a ``rectanglelabels`` result entry, else ``None``."""
        if not isinstance(result, dict) or result.get("type") != "rectanglelabels":
            return None
        value = result.get("value")
        return value if isinstance(value, dict) else None

    def _export_coco(
        self, items: List[AnnotationExportItem], project_name: str
    ) -> Tuple[bytes, str, str]:
        """Export as COCO (object detection); only ``rectanglelabels`` results are used.

        Note: image width/height are set to 0, because obtaining them would require
        reading the actual image files. ``bbox`` values are copied unchanged from the
        result ``value`` (x, y, width, height), which the original author documents as
        Label Studio percentages (0-100). Consumers must convert them to pixel
        coordinates.
        """
        now = datetime.utcnow()  # one timestamp so "year" and "date_created" agree
        coco_format = COCOExportFormat(
            info={
                "description": f"Exported from DataMate project: {project_name}",
                "version": "1.0",
                "year": now.year,
                "date_created": now.isoformat(),
            },
            licenses=[{"id": 1, "name": "Unknown", "url": ""}],
            images=[],
            annotations=[],
            categories=[],
        )

        category_map: Dict[str, int] = {}
        annotation_id = 1

        for image_id, item in enumerate(items, start=1):
            coco_format.images.append({
                "id": image_id,
                "file_name": item.file_name,
                "width": 0,  # actual image size is not read
                "height": 0,
            })

            for ann in item.annotations:
                for r in self._flatten_annotation_results(ann):
                    value = self._rectangle_value(r)
                    if value is None:
                        continue
                    labels = value.get("rectanglelabels") or []
                    if not labels:
                        continue
                    x = value.get("x", 0)
                    y = value.get("y", 0)
                    width = value.get("width", 0)
                    height = value.get("height", 0)
                    for label in labels:
                        if label not in category_map:
                            category_map[label] = len(category_map) + 1
                            coco_format.categories.append({
                                "id": category_map[label],
                                "name": label,
                                "supercategory": "",
                            })
                        coco_format.annotations.append({
                            "id": annotation_id,
                            "image_id": image_id,
                            "category_id": category_map[label],
                            "bbox": [x, y, width, height],
                            "area": width * height,
                            "iscrowd": 0,
                        })
                        annotation_id += 1

        content = json.dumps(coco_format.model_dump(), ensure_ascii=False, indent=2).encode("utf-8")
        filename = f"{project_name}_coco.json"
        return content, filename, "application/json"

    def _export_yolo(
        self, items: List[AnnotationExportItem], project_name: str
    ) -> Tuple[bytes, str, str]:
        """Export as YOLO: a ZIP with ``classes.txt`` and ``labels/<file base name>.txt``.

        Each label line is ``class_id x_center y_center width height``, where the
        coordinates are the result's percentage values divided by 100. Class ids are
        positions in the sorted label list written to ``classes.txt``. Only
        ``rectanglelabels`` results are used; items without boxes get no label file.
        """
        category_set: set = set()
        # txt file name -> [(label, "x_center y_center w h"), ...]; the class id is
        # resolved after all labels are known. Labels are never re-parsed from text,
        # so labels containing spaces map to the correct class.
        pending: Dict[str, List[Tuple[str, str]]] = {}

        for item in items:
            boxes: List[Tuple[str, str]] = []
            for ann in item.annotations:
                for r in self._flatten_annotation_results(ann):
                    value = self._rectangle_value(r)
                    if value is None:
                        continue
                    labels = value.get("rectanglelabels") or []
                    if not labels:
                        continue
                    x = value.get("x", 0) / 100
                    y = value.get("y", 0) / 100
                    w = value.get("width", 0) / 100
                    h = value.get("height", 0) / 100
                    coords = f"{x + w / 2:.6f} {y + h / 2:.6f} {w:.6f} {h:.6f}"
                    for label in labels:
                        label = str(label)
                        category_set.add(label)
                        boxes.append((label, coords))
            if boxes:
                base_name = os.path.splitext(item.file_name)[0]
                pending[f"{base_name}.txt"] = boxes

        categories = sorted(category_set)
        category_map = {cat: idx for idx, cat in enumerate(categories)}

        # Build the archive in memory: no temporary file to leak if something raises.
        buffer = io.BytesIO()
        with zipfile.ZipFile(buffer, "w", zipfile.ZIP_DEFLATED) as zf:
            zf.writestr("classes.txt", "\n".join(categories))
            for txt_name, boxes in pending.items():
                zf.writestr(
                    f"labels/{txt_name}",
                    "\n".join(f"{category_map[label]} {coords}" for label, coords in boxes),
                )

        filename = f"{project_name}_yolo.zip"
        return buffer.getvalue(), filename, "application/zip"