"""
|
|
标注数据导出服务
|
|
|
|
支持的导出格式:
|
|
- JSON: Label Studio 原生 JSON 格式
|
|
- JSONL: JSON Lines 格式(每行一条记录)
|
|
- CSV: CSV 表格格式
|
|
- COCO: COCO 目标检测格式(适用于图像标注)
|
|
- YOLO: YOLO 格式(适用于图像标注)
|
|
"""

from __future__ import annotations

import csv
import io
import json
import os
import tempfile
import zipfile
from datetime import datetime
from typing import Any, Dict, List, Optional, Tuple

from fastapi import HTTPException
from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession

from app.core.logging import get_logger
from app.db.models import AnnotationResult, Dataset, DatasetFiles, LabelingProject, LabelingProjectFile

from ..schema.export import (
    AnnotationExportItem,
    COCOExportFormat,
    ExportAnnotationsRequest,
    ExportAnnotationsResponse,
    ExportFormat,
)


async def _read_file_content(file_path: str, max_size: int = 10 * 1024 * 1024) -> Optional[str]:
    """Read a file's content; intended for text files only.

    Args:
        file_path: Path to the file.
        max_size: Maximum number of bytes to read (default 10 MB).

    Returns:
        The file content as a string, or None if reading fails.
    """
    try:
        # Check that the file exists and is within the size limit
        if not os.path.exists(file_path):
            return None

        file_size = os.path.getsize(file_path)
        if file_size > max_size:
            return f"[File too large: {file_size} bytes]"

        # Try to read the file as text
        with open(file_path, 'r', encoding='utf-8', errors='replace') as f:
            return f.read()
    except Exception:
        return None


logger = get_logger(__name__)

SEGMENTED_KEY = "segmented"
SEGMENTS_KEY = "segments"
SEGMENT_RESULT_KEY = "result"
SEGMENT_INDEX_KEY = "segmentIndex"
SEGMENT_INDEX_FALLBACK_KEY = "segment_index"


class AnnotationExportService:
    """Annotation data export service."""

    def __init__(self, db: AsyncSession):
        self.db = db

    async def get_export_stats(self, project_id: str) -> ExportAnnotationsResponse:
        """Return export statistics for the project."""
        project = await self._get_project_or_404(project_id)
        logger.info(f"Export stats for project: id={project_id}, dataset_id={project.dataset_id}, name={project.name}")

        # Total number of files (files in the labeling project snapshot)
        total_result = await self.db.execute(
            select(func.count())
            .select_from(LabelingProjectFile)
            .join(DatasetFiles, LabelingProjectFile.file_id == DatasetFiles.id)
            .where(
                LabelingProjectFile.project_id == project_id,
                DatasetFiles.dataset_id == project.dataset_id,
            )
        )
        total_files = int(total_result.scalar() or 0)
        logger.info(f"Total files (snapshot): {total_files} for project_id={project_id}")

        # Number of annotated files (count of distinct file_id values)
        annotated_result = await self.db.execute(
            select(func.count(func.distinct(AnnotationResult.file_id))).where(
                AnnotationResult.project_id == project_id
            )
        )
        annotated_files = int(annotated_result.scalar() or 0)
        logger.info(f"Annotated files: {annotated_files} for project_id={project_id}")

        return ExportAnnotationsResponse(
            project_id=project_id,
            project_name=project.name,
            total_files=total_files,
            annotated_files=annotated_files,
            export_format="json",
        )

    async def export_annotations(
        self,
        project_id: str,
        request: ExportAnnotationsRequest,
    ) -> Tuple[bytes, str, str]:
        """
        Export annotation data.

        Returns: (file content bytes, filename, content_type)
        """
        project = await self._get_project_or_404(project_id)

        # Fetch the annotation data
        items = await self._fetch_annotation_data(
            project_id=project_id,
            dataset_id=project.dataset_id,
            only_annotated=request.only_annotated,
            include_data=request.include_data,
        )

        # Dispatch on the requested format
        format_type = ExportFormat(request.format) if isinstance(request.format, str) else request.format

        if format_type == ExportFormat.JSON:
            return self._export_json(items, project.name)
        elif format_type == ExportFormat.JSONL:
            return self._export_jsonl(items, project.name)
        elif format_type == ExportFormat.CSV:
            return self._export_csv(items, project.name)
        elif format_type == ExportFormat.COCO:
            return self._export_coco(items, project.name)
        elif format_type == ExportFormat.YOLO:
            return self._export_yolo(items, project.name)
        else:
            raise HTTPException(status_code=400, detail=f"Unsupported export format: {request.format}")

    async def _get_project_or_404(self, project_id: str) -> LabelingProject:
        """Fetch the labeling project, raising 404 if it does not exist."""
        result = await self.db.execute(
            select(LabelingProject).where(
                LabelingProject.id == project_id,
                LabelingProject.deleted_at.is_(None),
            )
        )
        project = result.scalar_one_or_none()
        if not project:
            raise HTTPException(status_code=404, detail=f"Labeling project not found: {project_id}")
        return project

    async def _fetch_annotation_data(
        self,
        project_id: str,
        dataset_id: str,
        only_annotated: bool = True,
        include_data: bool = False,
    ) -> List[AnnotationExportItem]:
        """Fetch the list of annotation data items."""
        items: List[AnnotationExportItem] = []

        if only_annotated:
            # Fetch only files that have annotations
            result = await self.db.execute(
                select(AnnotationResult, DatasetFiles)
                .join(LabelingProjectFile, LabelingProjectFile.file_id == AnnotationResult.file_id)
                .join(DatasetFiles, AnnotationResult.file_id == DatasetFiles.id)
                .where(
                    AnnotationResult.project_id == project_id,
                    LabelingProjectFile.project_id == project_id,
                    DatasetFiles.dataset_id == dataset_id,
                )
                .order_by(AnnotationResult.updated_at.desc())
            )
            rows = result.all()

            for ann, file in rows:
                annotation_data = ann.annotation or {}
                # Read the file content (text files only, and only when requested)
                file_content = None
                if include_data:
                    file_path = getattr(file, "file_path", "")
                    file_content = await _read_file_content(file_path)

                items.append(
                    AnnotationExportItem(
                        file_id=str(file.id),
                        file_name=str(getattr(file, "file_name", "")),
                        data={"text": file_content} if include_data else None,
                        annotations=[annotation_data] if annotation_data else [],
                        created_at=ann.created_at,
                        updated_at=ann.updated_at,
                    )
                )
        else:
            # Fetch all files (based on the labeling project snapshot)
            files_result = await self.db.execute(
                select(DatasetFiles)
                .join(LabelingProjectFile, LabelingProjectFile.file_id == DatasetFiles.id)
                .where(
                    LabelingProjectFile.project_id == project_id,
                    DatasetFiles.dataset_id == dataset_id,
                )
            )
            files = files_result.scalars().all()

            # Fetch the existing annotations
            ann_result = await self.db.execute(
                select(AnnotationResult).where(AnnotationResult.project_id == project_id)
            )
            annotations = {str(a.file_id): a for a in ann_result.scalars().all()}

            for file in files:
                file_id = str(file.id)
                ann = annotations.get(file_id)
                annotation_data = ann.annotation if ann else {}

                # Read the file content (text files only, and only when requested)
                file_content = None
                if include_data:
                    file_path = getattr(file, "file_path", "")
                    file_content = await _read_file_content(file_path)

                items.append(
                    AnnotationExportItem(
                        file_id=file_id,
                        file_name=str(getattr(file, "file_name", "")),
                        data={"text": file_content} if include_data else None,
                        annotations=[annotation_data] if annotation_data else [],
                        created_at=ann.created_at if ann else None,
                        updated_at=ann.updated_at if ann else None,
                    )
                )

        return items

    @staticmethod
    def _flatten_annotation_results(annotation: Dict[str, Any]) -> List[Dict[str, Any]]:
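        """Flatten a (possibly segmented) annotation payload into a flat result list.

        Non-segmented payloads simply return their top-level "result" list. Segmented
        payloads carry their results under "segments", either keyed by segment index or
        as a list of segment dicts; each flattened item is tagged with "segmentIndex"
        if it does not already carry one. Illustrative input shape (an assumed example,
        not taken from real data):

            {"segmented": True, "segments": {"0": {"result": [...]}}}
        """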
        if not annotation or not isinstance(annotation, dict):
            return []
        segments = annotation.get(SEGMENTS_KEY)
        if annotation.get(SEGMENTED_KEY) or isinstance(segments, (dict, list)):
            results: List[Dict[str, Any]] = []
            if isinstance(segments, dict):
                for key, segment in segments.items():
                    if not isinstance(segment, dict):
                        continue
                    segment_results = segment.get(SEGMENT_RESULT_KEY)
                    if not isinstance(segment_results, list):
                        continue
                    for item in segment_results:
                        if isinstance(item, dict):
                            normalized = dict(item)
                            if SEGMENT_INDEX_KEY not in normalized and SEGMENT_INDEX_FALLBACK_KEY not in normalized:
                                normalized[SEGMENT_INDEX_KEY] = int(key) if str(key).isdigit() else key
                            results.append(normalized)
            elif isinstance(segments, list):
                for idx, segment in enumerate(segments):
                    if not isinstance(segment, dict):
                        continue
                    segment_results = segment.get(SEGMENT_RESULT_KEY)
                    if not isinstance(segment_results, list):
                        continue
                    segment_index = segment.get(SEGMENT_INDEX_KEY, segment.get(SEGMENT_INDEX_FALLBACK_KEY, idx))
                    for item in segment_results:
                        if isinstance(item, dict):
                            normalized = dict(item)
                            if SEGMENT_INDEX_KEY not in normalized and SEGMENT_INDEX_FALLBACK_KEY not in normalized:
                                normalized[SEGMENT_INDEX_KEY] = segment_index
                            results.append(normalized)
            return results
        result = annotation.get(SEGMENT_RESULT_KEY)
        return result if isinstance(result, list) else []

    @classmethod
    def _normalize_annotation_for_export(cls, annotation: Dict[str, Any]) -> Dict[str, Any]:
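        """Return an annotation payload ready for export.

        Segmented payloads get a flattened top-level "result" list (built via
        _flatten_annotation_results) when they do not already carry one; everything
        else is returned unchanged.
        """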
        if not annotation or not isinstance(annotation, dict):
            return {}
        segments = annotation.get(SEGMENTS_KEY)
        if annotation.get(SEGMENTED_KEY) or isinstance(segments, (dict, list)):
            normalized = dict(annotation)
            normalized_result = cls._flatten_annotation_results(annotation)
            if SEGMENT_RESULT_KEY not in normalized or not isinstance(normalized.get(SEGMENT_RESULT_KEY), list):
                normalized[SEGMENT_RESULT_KEY] = normalized_result
            return normalized
        return annotation

    def _export_json(
        self, items: List[AnnotationExportItem], project_name: str
    ) -> Tuple[bytes, str, str]:
        """Export as JSON."""
        export_data = {
            "project_name": project_name,
            "export_time": datetime.utcnow().isoformat() + "Z",
            "total_items": len(items),
            "annotations": [
                {
                    "file_id": item.file_id,
                    "file_name": item.file_name,
                    "data": item.data,
                    "annotations": [self._normalize_annotation_for_export(ann) for ann in item.annotations],
                    "created_at": item.created_at.isoformat() if item.created_at else None,
                    "updated_at": item.updated_at.isoformat() if item.updated_at else None,
                }
                for item in items
            ],
        }

        content = json.dumps(export_data, ensure_ascii=False, indent=2).encode("utf-8")
        filename = f"{project_name}_annotations.json"
        return content, filename, "application/json"

    def _export_jsonl(
        self, items: List[AnnotationExportItem], project_name: str
    ) -> Tuple[bytes, str, str]:
        """Export as JSON Lines (one record per line)."""
        lines = []
        for item in items:
            record = {
                "file_id": item.file_id,
                "file_name": item.file_name,
                "data": item.data,
                "annotations": [self._normalize_annotation_for_export(ann) for ann in item.annotations],
                "created_at": item.created_at.isoformat() if item.created_at else None,
                "updated_at": item.updated_at.isoformat() if item.updated_at else None,
            }
            lines.append(json.dumps(record, ensure_ascii=False))

        content = "\n".join(lines).encode("utf-8")
        filename = f"{project_name}_annotations.jsonl"
        return content, filename, "application/x-ndjson"

    def _export_csv(
        self, items: List[AnnotationExportItem], project_name: str
    ) -> Tuple[bytes, str, str]:
        """Export as CSV."""
        output = io.StringIO()

        # CSV header
        fieldnames = [
            "file_id",
            "file_name",
            "annotation_result",
            "labels",
            "created_at",
            "updated_at",
        ]

        writer = csv.DictWriter(output, fieldnames=fieldnames)
        writer.writeheader()

        for item in items:
            # Collect label information (multiple annotation types are supported)
            labels = []
            for ann in item.annotations:
                results = self._flatten_annotation_results(ann)
                for r in results:
                    value = r.get("value", {})
                    label_type = r.get("type", "")

                    # Extract the label values for the different annotation types
                    if "choices" in value:
                        labels.extend(value["choices"])
                    elif "text" in value:
                        labels.append(value["text"])
                    elif "labels" in value:
                        labels.extend(value["labels"])
                    elif "rectanglelabels" in value:
                        labels.extend(value["rectanglelabels"])
                    elif "polygonlabels" in value:
                        labels.extend(value["polygonlabels"])
                    elif "brushlabels" in value:
                        labels.extend(value["brushlabels"])
                    elif "hypertextlabels" in value:
                        labels.extend(value["hypertextlabels"])
                    elif "timeserieslabels" in value:
                        labels.extend(value["timeserieslabels"])
                    elif "transcription" in value:
                        labels.append(value["transcription"])

            writer.writerow({
                "file_id": item.file_id,
                "file_name": item.file_name,
                "annotation_result": json.dumps(item.annotations, ensure_ascii=False),
                "labels": "|".join(labels),
                "created_at": item.created_at.isoformat() if item.created_at else "",
                "updated_at": item.updated_at.isoformat() if item.updated_at else "",
            })

        content = output.getvalue().encode("utf-8-sig")  # BOM for Excel compatibility
        filename = f"{project_name}_annotations.csv"
        return content, filename, "text/csv"

    def _export_coco(
        self, items: List[AnnotationExportItem], project_name: str
    ) -> Tuple[bytes, str, str]:
        """Export as COCO (suitable for object-detection annotations).

        Note: image width and height are currently written as 0, since the real image
        files would have to be read to obtain their dimensions. The bbox coordinates
        are Label Studio percentage values (0-100) and must be converted to pixel
        coordinates before use.
        """
        coco_format = COCOExportFormat(
            info={
                "description": f"Exported from DataMate project: {project_name}",
                "version": "1.0",
                "year": datetime.utcnow().year,
                "date_created": datetime.utcnow().isoformat(),
            },
            licenses=[{"id": 1, "name": "Unknown", "url": ""}],
            images=[],
            annotations=[],
            categories=[],
        )

        category_map: Dict[str, int] = {}
        annotation_id = 1

        for idx, item in enumerate(items):
            image_id = idx + 1

            # Add the image entry
            coco_format.images.append({
                "id": image_id,
                "file_name": item.file_name,
                "width": 0,  # actual image dimensions would be needed here
                "height": 0,
            })

            # Process the annotations
            for ann in item.annotations:
                results = self._flatten_annotation_results(ann)
                for r in results:
                    # Handle bounding-box annotations (rectanglelabels)
                    if r.get("type") == "rectanglelabels":
                        value = r.get("value", {})
                        labels = value.get("rectanglelabels", [])

                        for label in labels:
                            if label not in category_map:
                                category_map[label] = len(category_map) + 1
                                coco_format.categories.append({
                                    "id": category_map[label],
                                    "name": label,
                                    "supercategory": "",
                                })

                            # Coordinates are kept as Label Studio percentages (see docstring note)
                            x = value.get("x", 0)
                            y = value.get("y", 0)
                            width = value.get("width", 0)
                            height = value.get("height", 0)

                            coco_format.annotations.append({
                                "id": annotation_id,
                                "image_id": image_id,
                                "category_id": category_map[label],
                                "bbox": [x, y, width, height],
                                "area": width * height,
                                "iscrowd": 0,
                            })
                            annotation_id += 1

        content = json.dumps(coco_format.model_dump(), ensure_ascii=False, indent=2).encode("utf-8")
        filename = f"{project_name}_coco.json"
        return content, filename, "application/json"

    def _export_yolo(
        self, items: List[AnnotationExportItem], project_name: str
    ) -> Tuple[bytes, str, str]:
        """Export as YOLO (a ZIP containing per-file txt annotations plus classes.txt)."""

        # Create a temporary file for the ZIP archive
        tmp_fd, tmp_path = tempfile.mkstemp(suffix=".zip")
        os.close(tmp_fd)

        category_set: set = set()
        txt_files: Dict[str, str] = {}

        for item in items:
            lines = []

            for ann in item.annotations:
                results = self._flatten_annotation_results(ann)
                for r in results:
                    # Handle bounding-box annotations
                    if r.get("type") == "rectanglelabels":
                        value = r.get("value", {})
                        labels = value.get("rectanglelabels", [])

                        for label in labels:
                            category_set.add(label)

                            # YOLO format: class_id x_center y_center width height (normalized to 0-1)
                            x = value.get("x", 0) / 100
                            y = value.get("y", 0) / 100
                            w = value.get("width", 0) / 100
                            h = value.get("height", 0) / 100

                            x_center = x + w / 2
                            y_center = y + h / 2

                            lines.append(f"{label} {x_center:.6f} {y_center:.6f} {w:.6f} {h:.6f}")

            if lines:
                # Derive the matching txt filename
                base_name = os.path.splitext(item.file_name)[0]
                txt_files[f"{base_name}.txt"] = "\n".join(lines)

        # Build the sorted class list
        categories = sorted(category_set)
        category_map = {cat: idx for idx, cat in enumerate(categories)}

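        # The first pass wrote label *names* as placeholders because the final class ids are
        # only known once every item has been scanned and the category list sorted.
        # Caveat: the split(" ", 1) below assumes label names contain no spaces; a label
        # with a space would be truncated and silently mapped to class 0.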
        # Replace label names with class indices in the txt files
        for filename, content in txt_files.items():
            updated_lines = []
            for line in content.split("\n"):
                parts = line.split(" ", 1)
                if len(parts) == 2:
                    label, coords = parts
                    class_id = category_map.get(label, 0)
                    updated_lines.append(f"{class_id} {coords}")
            txt_files[filename] = "\n".join(updated_lines)

        # Create the ZIP archive
        with zipfile.ZipFile(tmp_path, "w", zipfile.ZIP_DEFLATED) as zf:
            # Write classes.txt
            zf.writestr("classes.txt", "\n".join(categories))

            # Write the per-file annotation files
            for filename, content in txt_files.items():
                zf.writestr(f"labels/{filename}", content)

        with open(tmp_path, "rb") as f:
            content = f.read()

        os.unlink(tmp_path)

        filename = f"{project_name}_yolo.zip"
        return content, filename, "application/zip"