You've already forked DataMate
feat(auto-annotation): unify annotation results with Label Studio format
Automatically convert auto-annotation outputs to Label Studio format and write to t_dm_annotation_results table, enabling seamless editing in the annotation editor.
New file:
- runtime/python-executor/datamate/annotation_result_converter.py
* 4 converters for different annotation types:
- convert_text_classification → choices type
- convert_ner → labels (span) type
- convert_relation_extraction → labels + relation type
- convert_object_detection → rectanglelabels type
* convert_annotation() dispatcher (auto-detects task_type)
* generate_label_config_xml() for dynamic XML generation
* Pipeline introspection utilities
* Label Studio ID generation logic
Modified file:
- runtime/python-executor/datamate/auto_annotation_worker.py
* Preserve file_id through processing loop (line 918)
* Collect file_results as (file_id, annotations) pairs
* New _create_labeling_project_with_annotations() function:
- Creates labeling project linked to source dataset
- Snapshots all files
- Converts results to Label Studio format
- Writes to t_dm_annotation_results in single transaction
* label_config XML stored in t_dm_labeling_projects.configuration
Key features:
- Supports 4 annotation types: text classification, NER, relation extraction, object detection
- Deterministic region IDs for entity references in relation extraction
- Pixel to percentage conversion for object detection
- XML escaping handled by xml.etree.ElementTree
- Partial results preserved on task stop
Users can now view and edit auto-annotation results seamlessly in the annotation editor.
This commit is contained in:
491
runtime/python-executor/datamate/annotation_result_converter.py
Normal file
491
runtime/python-executor/datamate/annotation_result_converter.py
Normal file
@@ -0,0 +1,491 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""将自动标注算子输出转换为 Label Studio 兼容格式。
|
||||
|
||||
支持的算子类型:
|
||||
- LLMTextClassification → choices
|
||||
- LLMNamedEntityRecognition → labels (span)
|
||||
- LLMRelationExtraction → labels + relation
|
||||
- ImageObjectDetectionBoundingBox → rectanglelabels
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import uuid
|
||||
import xml.etree.ElementTree as ET
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
from loguru import logger
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 颜色调色板(Label Studio 背景色)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_LABEL_COLORS = [
|
||||
"#e53935",
|
||||
"#fb8c00",
|
||||
"#43a047",
|
||||
"#1e88e5",
|
||||
"#8e24aa",
|
||||
"#00897b",
|
||||
"#d81b60",
|
||||
"#3949ab",
|
||||
"#fdd835",
|
||||
"#6d4c41",
|
||||
"#546e7a",
|
||||
"#f4511e",
|
||||
]
|
||||
|
||||
|
||||
def _pick_color(index: int) -> str:
|
||||
return _LABEL_COLORS[index % len(_LABEL_COLORS)]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 稳定 LS ID 生成(与 editor.py:118-136 一致)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _stable_ls_id(seed: str) -> int:
|
||||
"""生成稳定的 Label Studio 风格整数 ID(JS 安全整数范围内)。"""
|
||||
digest = hashlib.sha1(seed.encode("utf-8")).hexdigest()
|
||||
value = int(digest[:13], 16)
|
||||
return value if value > 0 else 1
|
||||
|
||||
|
||||
def make_ls_task_id(project_id: str, file_id: str) -> int:
    """Stable LS task ID derived from the (project, file) pair."""
    seed = ":".join(("task", project_id, file_id))
    return _stable_ls_id(seed)
|
||||
|
||||
|
||||
def make_ls_annotation_id(project_id: str, file_id: str) -> int:
    """Stable LS annotation ID derived from the (project, file) pair."""
    seed = ":".join(("annotation", project_id, file_id))
    return _stable_ls_id(seed)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 确定性 region ID(关系抽取需要稳定引用)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_NAMESPACE_REGION = uuid.UUID("a1b2c3d4-e5f6-7890-abcd-ef1234567890")
|
||||
|
||||
|
||||
def _make_region_id(
|
||||
project_id: str, file_id: str, text: str, entity_type: str, start: Any, end: Any
|
||||
) -> str:
|
||||
seed = f"{project_id}:{file_id}:{text}:{entity_type}:{start}:{end}"
|
||||
return str(uuid.uuid5(_NAMESPACE_REGION, seed))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 1. 文本分类
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def convert_text_classification(
    annotation: Dict[str, Any], file_id: str, project_id: str
) -> Optional[Dict[str, Any]]:
    """Convert LLMTextClassification operator output into an LS annotation.

    Returns None when the payload carries no usable label.
    """
    payload = annotation.get("result")
    if not isinstance(payload, dict):
        return None

    label = payload.get("label")
    if not label:
        return None

    # A single "choices" region holding the predicted class.
    choice_region = {
        "id": str(uuid.uuid4()),
        "from_name": "sentiment",
        "to_name": "text",
        "type": "choices",
        "value": {"choices": [str(label)]},
    }

    return {
        "id": make_ls_annotation_id(project_id, file_id),
        "task": make_ls_task_id(project_id, file_id),
        "result": [choice_region],
    }
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 2. 命名实体识别
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def convert_ner(
    annotation: Dict[str, Any], file_id: str, project_id: str
) -> Optional[Dict[str, Any]]:
    """Convert LLMNamedEntityRecognition operator output into an LS annotation.

    Each well-formed entity becomes one span-labels region; returns None when
    no usable entity exists.
    """
    entities = annotation.get("entities")
    if not isinstance(entities, list) or not entities:
        return None

    regions: List[Dict[str, Any]] = []
    for entity in entities:
        if not isinstance(entity, dict):
            continue

        span_text = entity.get("text", "")
        span_type = entity.get("type", "")
        span_start = entity.get("start")
        span_end = entity.get("end")

        # Skip malformed entities: missing text or offsets.
        if not span_text or span_start is None or span_end is None:
            continue

        regions.append(
            {
                # Deterministic ID so relation extraction can reference the same span.
                "id": _make_region_id(
                    project_id, file_id, span_text, span_type, span_start, span_end
                ),
                "from_name": "label",
                "to_name": "text",
                "type": "labels",
                "value": {
                    "start": int(span_start),
                    "end": int(span_end),
                    "text": str(span_text),
                    "labels": [str(span_type)],
                },
            }
        )

    if not regions:
        return None

    return {
        "id": make_ls_annotation_id(project_id, file_id),
        "task": make_ls_task_id(project_id, file_id),
        "result": regions,
    }
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 3. 关系抽取
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _find_entity_region_id(
|
||||
entity_regions: List[Dict[str, Any]], text: str, entity_type: str
|
||||
) -> Optional[str]:
|
||||
"""在已生成的 entity regions 中查找匹配的 region ID。"""
|
||||
for region in entity_regions:
|
||||
value = region.get("value", {})
|
||||
if value.get("text") == text and entity_type in value.get("labels", []):
|
||||
return region["id"]
|
||||
return None
|
||||
|
||||
|
||||
def convert_relation_extraction(
    annotation: Dict[str, Any], file_id: str, project_id: str
) -> Optional[Dict[str, Any]]:
    """Convert LLMRelationExtraction operator output into an LS annotation.

    Builds deduplicated entity span regions first, then relation results
    whose from_id/to_id point at those regions. Relations whose endpoints
    cannot be matched to a built region are dropped (logged at debug level).
    Returns None when nothing usable can be produced.
    """

    entities = annotation.get("entities", [])
    relations = annotation.get("relations", [])
    # Tolerate malformed payloads: treat non-list fields as empty.
    if not isinstance(entities, list):
        entities = []
    if not isinstance(relations, list):
        relations = []

    if not entities and not relations:
        return None

    # Build entity label regions, deduplicated by (text, type, start, end).
    entity_regions: List[Dict[str, Any]] = []
    seen_keys: set = set()

    for ent in entities:
        if not isinstance(ent, dict):
            continue
        ent_text = str(ent.get("text", "")).strip()
        ent_type = str(ent.get("type", "")).strip()
        start = ent.get("start")
        end = ent.get("end")

        # Skip entities without text or offsets.
        if not ent_text or start is None or end is None:
            continue

        dedup_key = (ent_text, ent_type, int(start), int(end))
        if dedup_key in seen_keys:
            continue
        seen_keys.add(dedup_key)

        # Deterministic ID so the relations below can reference this region.
        region_id = _make_region_id(project_id, file_id, ent_text, ent_type, start, end)
        entity_regions.append(
            {
                "id": region_id,
                "from_name": "label",
                "to_name": "text",
                "type": "labels",
                "value": {
                    "start": int(start),
                    "end": int(end),
                    "text": ent_text,
                    "labels": [ent_type],
                },
            }
        )

    # Build relation results referencing the regions created above.
    relation_results: List[Dict[str, Any]] = []
    for rel in relations:
        if not isinstance(rel, dict):
            continue

        subject = rel.get("subject")
        obj = rel.get("object")
        relation_type = str(rel.get("relation", "")).strip()

        if not isinstance(subject, dict) or not isinstance(obj, dict) or not relation_type:
            continue

        subj_text = str(subject.get("text", "")).strip()
        subj_type = str(subject.get("type", "")).strip()
        obj_text = str(obj.get("text", "")).strip()
        obj_type = str(obj.get("type", "")).strip()

        # Endpoints are matched by (text, type) only; span offsets are ignored.
        from_id = _find_entity_region_id(entity_regions, subj_text, subj_type)
        to_id = _find_entity_region_id(entity_regions, obj_text, obj_type)

        if not from_id or not to_id:
            logger.debug(
                "Skipping relation '{}': could not find region for subject='{}' or object='{}'",
                relation_type,
                subj_text,
                obj_text,
            )
            continue

        relation_results.append(
            {
                "from_id": from_id,
                "to_id": to_id,
                "type": "relation",
                "direction": "right",
                "labels": [relation_type],
            }
        )

    # Entity regions must precede relations so LS can resolve from_id/to_id.
    ls_result = entity_regions + relation_results
    if not ls_result:
        return None

    return {
        "id": make_ls_annotation_id(project_id, file_id),
        "task": make_ls_task_id(project_id, file_id),
        "result": ls_result,
    }
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 4. 目标检测
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def convert_object_detection(
    annotation: Dict[str, Any], file_id: str, project_id: str
) -> Optional[Dict[str, Any]]:
    """Convert ImageObjectDetectionBoundingBox operator output into an LS annotation.

    Pixel bbox coordinates (x1, y1, x2, y2) are converted to the percentage
    values Label Studio expects. Returns None when there are no usable
    detections or the image dimensions are missing/zero.
    """

    detections = annotation.get("detections")
    if not isinstance(detections, list) or not detections:
        return None

    img_width = annotation.get("width", 0)
    img_height = annotation.get("height", 0)
    if not img_width or not img_height:
        # Without image dimensions the pixel-to-percent conversion is impossible.
        return None

    ls_result: List[Dict[str, Any]] = []
    for det in detections:
        if not isinstance(det, dict):
            continue

        label = det.get("label", "unknown")
        bbox = det.get("bbox_xyxy")
        if not isinstance(bbox, (list, tuple)) or len(bbox) < 4:
            continue

        try:
            x1, y1, x2, y2 = (float(bbox[i]) for i in range(4))
        except (TypeError, ValueError):
            # One malformed detection must not discard the whole file's results.
            continue

        # Normalize swapped corners so width/height are never negative.
        x1, x2 = min(x1, x2), max(x1, x2)
        y1, y2 = min(y1, y2), max(y1, y2)

        x_pct = x1 * 100.0 / img_width
        y_pct = y1 * 100.0 / img_height
        w_pct = (x2 - x1) * 100.0 / img_width
        h_pct = (y2 - y1) * 100.0 / img_height

        region_id = str(uuid.uuid4())
        ls_result.append(
            {
                "id": region_id,
                "from_name": "label",
                "to_name": "image",
                "type": "rectanglelabels",
                "original_width": int(img_width),
                "original_height": int(img_height),
                "image_rotation": 0,
                "value": {
                    "x": round(x_pct, 4),
                    "y": round(y_pct, 4),
                    "width": round(w_pct, 4),
                    "height": round(h_pct, 4),
                    "rotation": 0,
                    "rectanglelabels": [str(label)],
                },
            }
        )

    if not ls_result:
        return None

    return {
        "id": make_ls_annotation_id(project_id, file_id),
        "task": make_ls_task_id(project_id, file_id),
        "result": ls_result,
    }
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 分发器
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Maps a task_type string to its converter function.
TASK_TYPE_CONVERTERS = {
    "text_classification": convert_text_classification,
    "ner": convert_ner,
    "relation_extraction": convert_relation_extraction,
    "object_detection": convert_object_detection,
}


def convert_annotation(
    annotation: Dict[str, Any], file_id: str, project_id: str
) -> Optional[Dict[str, Any]]:
    """Dispatch *annotation* to the converter matching its task_type.

    Payloads lacking a task_type but carrying "detections" are treated as
    object detection. Returns None for unknown types or converter failures.
    """
    task_type = annotation.get("task_type")
    if task_type is None and "detections" in annotation:
        task_type = "object_detection"

    converter = TASK_TYPE_CONVERTERS.get(task_type)  # type: ignore[arg-type]
    if converter is None:
        logger.warning("No LS converter for task_type: {}", task_type)
        return None

    try:
        return converter(annotation, file_id, project_id)
    except Exception as exc:
        # Conversion is best-effort per file; log and carry on.
        logger.error("Failed to convert annotation (task_type={}): {}", task_type, exc)
        return None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# label_config XML 生成
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _split_labels(raw: str) -> List[str]:
|
||||
"""逗号分隔字符串 → 去空白列表。"""
|
||||
return [s.strip() for s in raw.split(",") if s.strip()]
|
||||
|
||||
|
||||
def generate_label_config_xml(task_type: str, operator_params: Dict[str, Any]) -> str:
    """Generate the Label Studio label_config XML for an annotation type.

    Delegates to the type-specific generator; unknown types fall back to a
    minimal but valid config.
    """
    if task_type == "object_detection":
        return _gen_object_detection_xml(operator_params)
    if task_type == "relation_extraction":
        return _gen_relation_extraction_xml(operator_params)
    if task_type == "ner":
        return _gen_ner_xml(operator_params)
    if task_type == "text_classification":
        return _gen_text_classification_xml(operator_params)

    # Unknown type: return the smallest usable XML.
    return '<View><Header value="Unknown annotation type"/></View>'
|
||||
|
||||
|
||||
def _gen_text_classification_xml(params: Dict[str, Any]) -> str:
    """Text-classification config: a Text tag plus single-choice options."""
    category_names = _split_labels(str(params.get("categories", "正面,负面,中性")))

    root = ET.Element("View")
    ET.SubElement(root, "Text", name="text", value="$text")
    choices_el = ET.SubElement(
        root,
        "Choices",
        name="sentiment",
        toName="text",
        choice="single",
        showInline="true",
    )
    for category in category_names:
        ET.SubElement(choices_el, "Choice", value=category)

    return ET.tostring(root, encoding="unicode")
|
||||
|
||||
|
||||
def _gen_ner_xml(params: Dict[str, Any]) -> str:
    """NER config: span Labels followed by the Text tag."""
    type_names = _split_labels(str(params.get("entityTypes", "PER,ORG,LOC,DATE")))

    root = ET.Element("View")
    labels_el = ET.SubElement(root, "Labels", name="label", toName="text")
    for idx, type_name in enumerate(type_names):
        ET.SubElement(labels_el, "Label", value=type_name, background=_pick_color(idx))
    ET.SubElement(root, "Text", name="text", value="$text")

    return ET.tostring(root, encoding="unicode")
|
||||
|
||||
|
||||
def _gen_relation_extraction_xml(params: Dict[str, Any]) -> str:
    """Relation-extraction config: Relations, entity Labels, then the Text tag."""
    type_names = _split_labels(str(params.get("entityTypes", "PER,ORG,LOC")))
    relation_names = _split_labels(
        str(params.get("relationTypes", "属于,位于,创立,工作于"))
    )

    root = ET.Element("View")
    relations_el = ET.SubElement(root, "Relations")
    for relation_name in relation_names:
        ET.SubElement(relations_el, "Relation", value=relation_name)
    labels_el = ET.SubElement(root, "Labels", name="label", toName="text")
    for idx, type_name in enumerate(type_names):
        ET.SubElement(labels_el, "Label", value=type_name, background=_pick_color(idx))
    ET.SubElement(root, "Text", name="text", value="$text")

    return ET.tostring(root, encoding="unicode")
|
||||
|
||||
|
||||
def _gen_object_detection_xml(params: Dict[str, Any]) -> str:
    """Object-detection config: the Image tag plus RectangleLabels.

    Labels come from the "_detected_labels" param (collected from actual
    detection results); falls back to a single generic "object" label.
    """
    labels: List[str] = params.get("_detected_labels", [])
    if not labels:
        labels = ["object"]

    root = ET.Element("View")
    ET.SubElement(root, "Image", name="image", value="$image")
    rect_el = ET.SubElement(root, "RectangleLabels", name="label", toName="image")
    for idx, label_name in enumerate(labels):
        ET.SubElement(rect_el, "Label", value=label_name, background=_pick_color(idx))

    return ET.tostring(root, encoding="unicode")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Pipeline 信息提取
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Annotation operator ID → annotation task type.
OPERATOR_TASK_TYPE_MAP: Dict[str, str] = {
    "LLMTextClassification": "text_classification",
    "LLMNamedEntityRecognition": "ner",
    "LLMRelationExtraction": "relation_extraction",
    "ImageObjectDetectionBoundingBox": "object_detection",
}


def infer_task_type_from_pipeline(pipeline: List[Dict[str, Any]]) -> Optional[str]:
    """Infer the annotation task type from a normalized pipeline.

    The first step whose operator_id is a known annotation operator wins;
    returns None when no step matches.
    """
    for step in pipeline:
        mapped = OPERATOR_TASK_TYPE_MAP.get(str(step.get("operator_id", "")))
        if mapped is not None:
            return mapped
    return None


def extract_operator_params(pipeline: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Return the overrides of the first annotation operator in the pipeline.

    Returns a fresh dict (empty when no annotation operator is present).
    """
    for step in pipeline:
        if str(step.get("operator_id", "")) in OPERATOR_TASK_TYPE_MAP:
            return dict(step.get("overrides", {}))
    return {}
|
||||
@@ -857,6 +857,139 @@ def _register_output_dataset(
|
||||
)
|
||||
|
||||
|
||||
def _create_labeling_project_with_annotations(
    task_id: str,
    dataset_id: str,
    dataset_name: str,
    task_name: str,
    dataset_type: str,
    normalized_pipeline: List[Dict[str, Any]],
    file_results: List[Tuple[str, Dict[str, Any]]],
    all_file_ids: List[str],
) -> None:
    """Convert auto-annotation results into Label Studio format, create a
    labeling project linked to the source dataset, and persist the results.

    Args:
        task_id: Auto-annotation task ID (used in names and metadata).
        dataset_id: Source dataset the new labeling project attaches to.
        dataset_name: Display name of the source dataset (naming fallback).
        task_name: Display name of the task (preferred for naming).
        dataset_type: Dataset type; accepted but not read in this function.
        normalized_pipeline: Pipeline steps used to infer the task type and
            operator parameters.
        file_results: (file_id, operator annotations) pairs to convert.
        all_file_ids: Every file in the dataset, snapshotted into the project.

    All inserts run on a single connection; failures propagate to the caller.
    """

    # Imported here (not at module level) — presumably to avoid an import
    # cycle between the worker and the converter module; TODO confirm.
    from datamate.annotation_result_converter import (
        convert_annotation,
        extract_operator_params,
        generate_label_config_xml,
        infer_task_type_from_pipeline,
    )

    task_type = infer_task_type_from_pipeline(normalized_pipeline)
    if not task_type:
        logger.warning(
            "Cannot infer task_type from pipeline for task {}, skipping labeling project creation",
            task_id,
        )
        return

    operator_params = extract_operator_params(normalized_pipeline)

    # Object detection: collect the unique label set from the actual results
    # so the generated label_config lists exactly the detected classes.
    if task_type == "object_detection":
        all_labels: set = set()
        for _, ann in file_results:
            for det in ann.get("detections", []):
                if isinstance(det, dict):
                    all_labels.add(str(det.get("label", "unknown")))
        operator_params["_detected_labels"] = sorted(all_labels)

    label_config = generate_label_config_xml(task_type, operator_params)

    project_id = str(uuid.uuid4())
    # 8-digit, zero-padded external project number derived from a random UUID.
    labeling_project_id = str(uuid.uuid4().int % 10**8).zfill(8)
    project_name = f"自动标注 - {task_name or dataset_name or task_id[:8]}"[:100]

    now = datetime.now()
    # label_config XML is stored inside the project's JSON configuration.
    configuration = json.dumps(
        {
            "label_config": label_config,
            "description": f"由自动标注任务 {task_id[:8]} 自动创建",
            "auto_annotation_task_id": task_id,
        },
        ensure_ascii=False,
    )

    insert_project_sql = text(
        """
        INSERT INTO t_dm_labeling_projects
        (id, dataset_id, name, labeling_project_id, template_id, configuration, created_at, updated_at)
        VALUES
        (:id, :dataset_id, :name, :labeling_project_id, NULL, :configuration, :now, :now)
        """
    )
    insert_snapshot_sql = text(
        """
        INSERT INTO t_dm_labeling_project_files (id, project_id, file_id, created_at)
        VALUES (:id, :project_id, :file_id, :now)
        """
    )
    insert_annotation_sql = text(
        """
        INSERT INTO t_dm_annotation_results
        (id, project_id, file_id, annotation, annotation_status, file_version, created_at, updated_at)
        VALUES
        (:id, :project_id, :file_id, :annotation, :annotation_status, :file_version, :now, :now)
        """
    )

    with SQLManager.create_connect() as conn:
        # 1. Create the labeling project.
        conn.execute(
            insert_project_sql,
            {
                "id": project_id,
                "dataset_id": dataset_id,
                "name": project_name,
                "labeling_project_id": labeling_project_id,
                "configuration": configuration,
                "now": now,
            },
        )

        # 2. Snapshot every dataset file into the project.
        for file_id in all_file_ids:
            conn.execute(
                insert_snapshot_sql,
                {
                    "id": str(uuid.uuid4()),
                    "project_id": project_id,
                    "file_id": file_id,
                    "now": now,
                },
            )

        # 3. Convert each file's results and write the annotation rows;
        #    files whose conversion yields nothing are skipped silently.
        converted_count = 0
        for file_id, annotation in file_results:
            ls_annotation = convert_annotation(annotation, file_id, project_id)
            if ls_annotation is None:
                continue

            conn.execute(
                insert_annotation_sql,
                {
                    "id": str(uuid.uuid4()),
                    "project_id": project_id,
                    "file_id": file_id,
                    "annotation": json.dumps(ls_annotation, ensure_ascii=False),
                    "annotation_status": "ANNOTATED",
                    "file_version": 1,
                    "now": now,
                },
            )
            converted_count += 1

    logger.info(
        "Created labeling project {} ({}) with {} annotations for auto-annotation task {}",
        project_id,
        project_name,
        converted_count,
        task_id,
    )
|
||||
|
||||
|
||||
def _process_single_task(task: Dict[str, Any]) -> None:
|
||||
"""执行单个自动标注任务。"""
|
||||
|
||||
@@ -915,7 +1048,7 @@ def _process_single_task(task: Dict[str, Any]) -> None:
|
||||
else:
|
||||
all_files = _load_dataset_files(dataset_id)
|
||||
|
||||
files = [(path, name) for _, path, name in all_files]
|
||||
files = all_files # [(file_id, file_path, file_name)]
|
||||
|
||||
total_images = len(files)
|
||||
if total_images == 0:
|
||||
@@ -983,10 +1116,11 @@ def _process_single_task(task: Dict[str, Any]) -> None:
|
||||
|
||||
processed = 0
|
||||
detected_total = 0
|
||||
file_results: List[Tuple[str, Dict[str, Any]]] = [] # (file_id, annotations)
|
||||
|
||||
try:
|
||||
|
||||
for file_path, file_name in files:
|
||||
for file_id, file_path, file_name in files:
|
||||
if _is_stop_requested(task_id, run_token):
|
||||
logger.info("Task stop requested during processing: {}", task_id)
|
||||
_update_task_status(
|
||||
@@ -1003,7 +1137,7 @@ def _process_single_task(task: Dict[str, Any]) -> None:
|
||||
clear_run_token=True,
|
||||
error_message="Task stopped by request",
|
||||
)
|
||||
return
|
||||
break
|
||||
|
||||
try:
|
||||
sample_key = _get_sample_key(dataset_type)
|
||||
@@ -1016,6 +1150,10 @@ def _process_single_task(task: Dict[str, Any]) -> None:
|
||||
detected_total += _count_detections(result)
|
||||
processed += 1
|
||||
|
||||
ann = result.get("annotations")
|
||||
if isinstance(ann, dict):
|
||||
file_results.append((file_id, ann))
|
||||
|
||||
progress = int(processed * 100 / total_images) if total_images > 0 else 100
|
||||
|
||||
_update_task_status(
|
||||
@@ -1038,19 +1176,21 @@ def _process_single_task(task: Dict[str, Any]) -> None:
|
||||
)
|
||||
continue
|
||||
|
||||
_update_task_status(
|
||||
task_id,
|
||||
run_token=run_token,
|
||||
status="completed",
|
||||
progress=100,
|
||||
processed_images=processed,
|
||||
detected_objects=detected_total,
|
||||
total_images=total_images,
|
||||
output_path=output_dir,
|
||||
output_dataset_id=output_dataset_id,
|
||||
completed=True,
|
||||
clear_run_token=True,
|
||||
)
|
||||
else:
|
||||
# Loop completed without break (not stopped)
|
||||
_update_task_status(
|
||||
task_id,
|
||||
run_token=run_token,
|
||||
status="completed",
|
||||
progress=100,
|
||||
processed_images=processed,
|
||||
detected_objects=detected_total,
|
||||
total_images=total_images,
|
||||
output_path=output_dir,
|
||||
output_dataset_id=output_dataset_id,
|
||||
completed=True,
|
||||
clear_run_token=True,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"Completed auto-annotation task: id={}, total_images={}, processed={}, detected_objects={}, output_path={}",
|
||||
@@ -1077,6 +1217,26 @@ def _process_single_task(task: Dict[str, Any]) -> None:
|
||||
task_id,
|
||||
e,
|
||||
)
|
||||
|
||||
# 将自动标注结果转换为 Label Studio 格式并写入标注项目
|
||||
if file_results:
|
||||
try:
|
||||
_create_labeling_project_with_annotations(
|
||||
task_id=task_id,
|
||||
dataset_id=dataset_id,
|
||||
dataset_name=source_dataset_name,
|
||||
task_name=task_name,
|
||||
dataset_type=dataset_type,
|
||||
normalized_pipeline=normalized_pipeline,
|
||||
file_results=file_results,
|
||||
all_file_ids=[fid for fid, _, _ in all_files],
|
||||
)
|
||||
except Exception as e: # pragma: no cover - 防御性日志
|
||||
logger.error(
|
||||
"Failed to create labeling project for auto-annotation task {}: {}",
|
||||
task_id,
|
||||
e,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("Task execution failed for task {}: {}", task_id, e)
|
||||
_update_task_status(
|
||||
|
||||
Reference in New Issue
Block a user