# === runtime/python-executor/datamate/annotation_result_converter.py (new file) ===
# -*- coding: utf-8 -*-
"""Convert auto-annotation operator output into Label Studio compatible format.

Supported operator types:

- ``LLMTextClassification``           -> ``choices``
- ``LLMNamedEntityRecognition``       -> ``labels`` (text spans)
- ``LLMRelationExtraction``           -> ``labels`` + ``relation``
- ``ImageObjectDetectionBoundingBox`` -> ``rectanglelabels``
"""

from __future__ import annotations

import hashlib
import uuid
import xml.etree.ElementTree as ET
from typing import Any, Dict, List, Optional, Tuple

try:
    from loguru import logger
except ImportError:  # pragma: no cover - keep this pure-conversion module importable
    # Fallback so environments without loguru can still import the converters.
    # NOTE: loguru-style "{}" placeholders are not %-formatted by stdlib
    # logging, but log calls will not crash.
    import logging

    logger = logging.getLogger(__name__)

# ---------------------------------------------------------------------------
# Color palette (Label Studio label background colors)
# ---------------------------------------------------------------------------

_LABEL_COLORS = [
    "#e53935",
    "#fb8c00",
    "#43a047",
    "#1e88e5",
    "#8e24aa",
    "#00897b",
    "#d81b60",
    "#3949ab",
    "#fdd835",
    "#6d4c41",
    "#546e7a",
    "#f4511e",
]


def _pick_color(index: int) -> str:
    """Return a palette color for *index*, cycling once the palette is exhausted."""
    return _LABEL_COLORS[index % len(_LABEL_COLORS)]


# ---------------------------------------------------------------------------
# Stable LS ID generation (kept consistent with editor.py:118-136)
# ---------------------------------------------------------------------------

def _stable_ls_id(seed: str) -> int:
    """Generate a stable Label Studio style integer ID.

    The first 13 hex digits of a SHA-1 digest give a value < 16**13 ~= 4.5e15,
    comfortably inside the JavaScript safe-integer range (2**53 - 1).
    """
    digest = hashlib.sha1(seed.encode("utf-8")).hexdigest()
    value = int(digest[:13], 16)
    # IDs must be positive; an all-zero prefix is astronomically unlikely but
    # cheap to guard against.
    return value if value > 0 else 1


def make_ls_task_id(project_id: str, file_id: str) -> int:
    """Deterministic LS task ID for a (project, file) pair."""
    return _stable_ls_id(f"task:{project_id}:{file_id}")


def make_ls_annotation_id(project_id: str, file_id: str) -> int:
    """Deterministic LS annotation ID for a (project, file) pair."""
    return _stable_ls_id(f"annotation:{project_id}:{file_id}")


# ---------------------------------------------------------------------------
# Deterministic region IDs (relation extraction needs stable references)
# ---------------------------------------------------------------------------

# Fixed namespace so uuid5-derived region IDs are reproducible across runs.
_NAMESPACE_REGION = uuid.UUID("a1b2c3d4-e5f6-7890-abcd-ef1234567890")


def _make_region_id(
    project_id: str, file_id: str, text: str, entity_type: str, start: Any, end: Any
) -> str:
    """Derive a deterministic region ID from the entity's identifying fields.

    Relation results reference entity regions by ID, so the same entity must
    map to the same ID on every conversion run.
    """
    seed = f"{project_id}:{file_id}:{text}:{entity_type}:{start}:{end}"
    return str(uuid.uuid5(_NAMESPACE_REGION, seed))


# ---------------------------------------------------------------------------
# 1. Text classification
# ---------------------------------------------------------------------------

def convert_text_classification(
    annotation: Dict[str, Any], file_id: str, project_id: str
) -> Optional[Dict[str, Any]]:
    """Convert LLMTextClassification operator output to an LS annotation.

    Expects ``annotation["result"]`` to be a dict carrying a ``label`` key;
    returns ``None`` when the payload is missing or empty.
    """
    result_data = annotation.get("result")
    if not isinstance(result_data, dict):
        return None

    label = result_data.get("label")
    if not label:
        return None

    # A classification choice covers the whole document; nothing references
    # this region later, so a random ID is sufficient here (unlike spans).
    ls_result = [
        {
            "id": str(uuid.uuid4()),
            "from_name": "sentiment",
            "to_name": "text",
            "type": "choices",
            "value": {"choices": [str(label)]},
        }
    ]

    return {
        "id": make_ls_annotation_id(project_id, file_id),
        "task": make_ls_task_id(project_id, file_id),
        "result": ls_result,
    }


# ---------------------------------------------------------------------------
# 2. Named entity recognition
# ---------------------------------------------------------------------------

def convert_ner(
    annotation: Dict[str, Any], file_id: str, project_id: str
) -> Optional[Dict[str, Any]]:
    """Convert LLMNamedEntityRecognition operator output to an LS annotation.

    Expects ``annotation["entities"]`` to be a list of dicts with ``text``,
    ``type``, ``start`` and ``end`` keys; entities missing any of these are
    silently skipped.  Returns ``None`` when nothing usable remains.
    """
    entities = annotation.get("entities")
    if not isinstance(entities, list) or not entities:
        return None

    ls_result: List[Dict[str, Any]] = []
    for ent in entities:
        if not isinstance(ent, dict):
            continue
        ent_text = ent.get("text", "")
        ent_type = ent.get("type", "")
        start = ent.get("start")
        end = ent.get("end")

        if not ent_text or start is None or end is None:
            continue

        # Deterministic ID so repeated conversions yield identical regions.
        region_id = _make_region_id(project_id, file_id, ent_text, ent_type, start, end)
        ls_result.append(
            {
                "id": region_id,
                "from_name": "label",
                "to_name": "text",
                "type": "labels",
                "value": {
                    "start": int(start),
                    "end": int(end),
                    "text": str(ent_text),
                    "labels": [str(ent_type)],
                },
            }
        )

    if not ls_result:
        return None

    return {
        "id": make_ls_annotation_id(project_id, file_id),
        "task": make_ls_task_id(project_id, file_id),
        "result": ls_result,
    }


# ---------------------------------------------------------------------------
# 3. Relation extraction
# ---------------------------------------------------------------------------

def _find_entity_region_id(
    entity_regions: List[Dict[str, Any]], text: str, entity_type: str
) -> Optional[str]:
    """Look up the ID of an already-built entity region by text and type.

    NOTE(review): matching is by (text, type) only — if the same surface text
    occurs at several positions, the first region wins; confirm this is the
    intended disambiguation.
    """
    for region in entity_regions:
        value = region.get("value", {})
        if value.get("text") == text and entity_type in value.get("labels", []):
            return region["id"]
    return None


def convert_relation_extraction(
    annotation: Dict[str, Any], file_id: str, project_id: str
) -> Optional[Dict[str, Any]]:
    """Convert LLMRelationExtraction operator output to an LS annotation.

    Builds span regions for ``annotation["entities"]`` (deduplicated on
    text/type/start/end), then ``relation`` results referencing those regions
    by ID.  Relations whose subject or object cannot be resolved to a region
    are dropped with a debug log.  Returns ``None`` if nothing usable remains.
    """
    entities = annotation.get("entities", [])
    relations = annotation.get("relations", [])
    if not isinstance(entities, list):
        entities = []
    if not isinstance(relations, list):
        relations = []

    if not entities and not relations:
        return None

    # --- entity label regions (deduplicated) --------------------------------
    entity_regions: List[Dict[str, Any]] = []
    seen_keys: set = set()

    for ent in entities:
        if not isinstance(ent, dict):
            continue
        ent_text = str(ent.get("text", "")).strip()
        ent_type = str(ent.get("type", "")).strip()
        start = ent.get("start")
        end = ent.get("end")

        if not ent_text or start is None or end is None:
            continue

        dedup_key = (ent_text, ent_type, int(start), int(end))
        if dedup_key in seen_keys:
            continue
        seen_keys.add(dedup_key)

        region_id = _make_region_id(project_id, file_id, ent_text, ent_type, start, end)
        entity_regions.append(
            {
                "id": region_id,
                "from_name": "label",
                "to_name": "text",
                "type": "labels",
                "value": {
                    "start": int(start),
                    "end": int(end),
                    "text": ent_text,
                    "labels": [ent_type],
                },
            }
        )

    # --- relation results ---------------------------------------------------
    relation_results: List[Dict[str, Any]] = []
    for rel in relations:
        if not isinstance(rel, dict):
            continue

        subject = rel.get("subject")
        obj = rel.get("object")
        relation_type = str(rel.get("relation", "")).strip()

        if not isinstance(subject, dict) or not isinstance(obj, dict) or not relation_type:
            continue

        subj_text = str(subject.get("text", "")).strip()
        subj_type = str(subject.get("type", "")).strip()
        obj_text = str(obj.get("text", "")).strip()
        obj_type = str(obj.get("type", "")).strip()

        from_id = _find_entity_region_id(entity_regions, subj_text, subj_type)
        to_id = _find_entity_region_id(entity_regions, obj_text, obj_type)

        if not from_id or not to_id:
            logger.debug(
                "Skipping relation '{}': could not find region for subject='{}' or object='{}'",
                relation_type,
                subj_text,
                obj_text,
            )
            continue

        relation_results.append(
            {
                "from_id": from_id,
                "to_id": to_id,
                "type": "relation",
                "direction": "right",
                "labels": [relation_type],
            }
        )

    # Entity regions first so relation from_id/to_id always resolve in LS.
    ls_result = entity_regions + relation_results
    if not ls_result:
        return None

    return {
        "id": make_ls_annotation_id(project_id, file_id),
        "task": make_ls_task_id(project_id, file_id),
        "result": ls_result,
    }


# ---------------------------------------------------------------------------
# 4. Object detection
# ---------------------------------------------------------------------------

def convert_object_detection(
    annotation: Dict[str, Any], file_id: str, project_id: str
) -> Optional[Dict[str, Any]]:
    """Convert ImageObjectDetectionBoundingBox operator output to an LS annotation.

    Expects ``annotation["detections"]`` (list of dicts with ``label`` and a
    4-element pixel ``bbox_xyxy``) plus image ``width``/``height``, which are
    required to express boxes in the percentage coordinates Label Studio
    uses.  Returns ``None`` when detections or dimensions are missing.
    """
    detections = annotation.get("detections")
    if not isinstance(detections, list) or not detections:
        return None

    img_width = annotation.get("width", 0)
    img_height = annotation.get("height", 0)
    if not img_width or not img_height:
        # Without pixel dimensions the percentage conversion is undefined.
        return None

    ls_result: List[Dict[str, Any]] = []
    for det in detections:
        if not isinstance(det, dict):
            continue

        label = det.get("label", "unknown")
        bbox = det.get("bbox_xyxy")
        if not isinstance(bbox, (list, tuple)) or len(bbox) < 4:
            continue

        x1, y1, x2, y2 = float(bbox[0]), float(bbox[1]), float(bbox[2]), float(bbox[3])
        # Pixel coordinates -> percentages of the full image.
        x_pct = x1 * 100.0 / img_width
        y_pct = y1 * 100.0 / img_height
        w_pct = (x2 - x1) * 100.0 / img_width
        h_pct = (y2 - y1) * 100.0 / img_height

        ls_result.append(
            {
                "id": str(uuid.uuid4()),
                "from_name": "label",
                "to_name": "image",
                "type": "rectanglelabels",
                "original_width": int(img_width),
                "original_height": int(img_height),
                "image_rotation": 0,
                "value": {
                    "x": round(x_pct, 4),
                    "y": round(y_pct, 4),
                    "width": round(w_pct, 4),
                    "height": round(h_pct, 4),
                    "rotation": 0,
                    "rectanglelabels": [str(label)],
                },
            }
        )

    if not ls_result:
        return None

    return {
        "id": make_ls_annotation_id(project_id, file_id),
        "task": make_ls_task_id(project_id, file_id),
        "result": ls_result,
    }


# ---------------------------------------------------------------------------
# Dispatcher
# ---------------------------------------------------------------------------

# Maps task_type -> converter; extend here when new operator kinds are added.
TASK_TYPE_CONVERTERS = {
    "text_classification": convert_text_classification,
    "ner": convert_ner,
    "relation_extraction": convert_relation_extraction,
    "object_detection": convert_object_detection,
}
def convert_annotation(
    annotation: Dict[str, Any], file_id: str, project_id: str
) -> Optional[Dict[str, Any]]:
    """Dispatch *annotation* to the converter matching its ``task_type``.

    Payloads without ``task_type`` but with a ``detections`` key are treated
    as object detection.  Unknown types and converter crashes are logged and
    yield ``None`` instead of raising.
    """
    task_type = annotation.get("task_type")
    if task_type is None and "detections" in annotation:
        task_type = "object_detection"
    converter = TASK_TYPE_CONVERTERS.get(task_type)  # type: ignore[arg-type]
    if converter is None:
        logger.warning("No LS converter for task_type: {}", task_type)
        return None
    try:
        return converter(annotation, file_id, project_id)
    except Exception as exc:
        # One malformed record must not abort the whole batch.
        logger.error("Failed to convert annotation (task_type={}): {}", task_type, exc)
        return None


# ---------------------------------------------------------------------------
# label_config XML generation
# ---------------------------------------------------------------------------

def _split_labels(raw: str) -> List[str]:
    """Comma-separated string -> list of stripped, non-empty labels."""
    return [s.strip() for s in raw.split(",") if s.strip()]


def generate_label_config_xml(task_type: str, operator_params: Dict[str, Any]) -> str:
    """Build the Label Studio ``label_config`` XML for *task_type*.

    *operator_params* supplies the label vocabularies (``categories``,
    ``entityTypes``, ``relationTypes``, ``_detected_labels``); each generator
    falls back to sensible defaults when a key is absent.
    """
    if task_type == "text_classification":
        return _gen_text_classification_xml(operator_params)
    if task_type == "ner":
        return _gen_ner_xml(operator_params)
    if task_type == "relation_extraction":
        return _gen_relation_extraction_xml(operator_params)
    if task_type == "object_detection":
        return _gen_object_detection_xml(operator_params)

    # Unknown type: return a minimal usable config.  NOTE(review): the source
    # span was damaged here and joined to an empty string literal, which would
    # NOT be "minimal usable XML" as the original comment promised — an empty
    # <View> element is returned instead; confirm against the upstream repo.
    return "<View></View>"


def _gen_text_classification_xml(params: Dict[str, Any]) -> str:
    """<Choices> config; categories default to the Chinese sentiment triple."""
    categories = _split_labels(str(params.get("categories", "正面,负面,中性")))

    view = ET.Element("View")
    ET.SubElement(view, "Text", name="text", value="$text")
    choices = ET.SubElement(
        view,
        "Choices",
        name="sentiment",
        toName="text",
        choice="single",
        showInline="true",
    )
    for cat in categories:
        ET.SubElement(choices, "Choice", value=cat)

    return ET.tostring(view, encoding="unicode")


def _gen_ner_xml(params: Dict[str, Any]) -> str:
    """<Labels> span config for NER."""
    entity_types = _split_labels(str(params.get("entityTypes", "PER,ORG,LOC,DATE")))

    view = ET.Element("View")
    labels = ET.SubElement(view, "Labels", name="label", toName="text")
    for i, et in enumerate(entity_types):
        ET.SubElement(labels, "Label", value=et, background=_pick_color(i))
    ET.SubElement(view, "Text", name="text", value="$text")

    return ET.tostring(view, encoding="unicode")


def _gen_relation_extraction_xml(params: Dict[str, Any]) -> str:
    """<Relations> + <Labels> config for relation extraction."""
    entity_types = _split_labels(str(params.get("entityTypes", "PER,ORG,LOC")))
    relation_types = _split_labels(
        str(params.get("relationTypes", "属于,位于,创立,工作于"))
    )

    view = ET.Element("View")
    relations = ET.SubElement(view, "Relations")
    for rt in relation_types:
        ET.SubElement(relations, "Relation", value=rt)
    labels = ET.SubElement(view, "Labels", name="label", toName="text")
    for i, et in enumerate(entity_types):
        ET.SubElement(labels, "Label", value=et, background=_pick_color(i))
    ET.SubElement(view, "Text", name="text", value="$text")

    return ET.tostring(view, encoding="unicode")


def _gen_object_detection_xml(params: Dict[str, Any]) -> str:
    """<RectangleLabels> config; labels come from actual detection output."""
    # "_detected_labels" is injected by the worker from real detection results.
    detected_labels: List[str] = params.get("_detected_labels", [])
    if not detected_labels:
        detected_labels = ["object"]

    view = ET.Element("View")
    ET.SubElement(view, "Image", name="image", value="$image")
    rect_labels = ET.SubElement(view, "RectangleLabels", name="label", toName="image")
    for i, lbl in enumerate(detected_labels):
        ET.SubElement(rect_labels, "Label", value=lbl, background=_pick_color(i))

    return ET.tostring(view, encoding="unicode")


# ---------------------------------------------------------------------------
# Pipeline info extraction
# ---------------------------------------------------------------------------

# operator_id -> task_type understood by this module.
OPERATOR_TASK_TYPE_MAP: Dict[str, str] = {
    "LLMTextClassification": "text_classification",
    "LLMNamedEntityRecognition": "ner",
    "LLMRelationExtraction": "relation_extraction",
    "ImageObjectDetectionBoundingBox": "object_detection",
}


def infer_task_type_from_pipeline(pipeline: List[Dict[str, Any]]) -> Optional[str]:
    """Infer the annotation task type from a normalized pipeline (first match wins)."""
    for step in pipeline:
        task_type = OPERATOR_TASK_TYPE_MAP.get(str(step.get("operator_id", "")))
        if task_type is not None:
            return task_type
    return None


def extract_operator_params(pipeline: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Return the ``overrides`` of the first annotation operator in the pipeline."""
    for step in pipeline:
        if str(step.get("operator_id", "")) in OPERATOR_TASK_TYPE_MAP:
            return dict(step.get("overrides", {}))
    return {}


# ===========================================================================
# NOTE(review): the original source continues here with a patch to
# runtime/python-executor/datamate/auto_annotation_worker.py (function added
# after _register_output_dataset).  The added function is reproduced below.
# ===========================================================================

def _create_labeling_project_with_annotations(
    task_id: str,
    dataset_id: str,
    dataset_name: str,
    task_name: str,
    dataset_type: str,
    normalized_pipeline: List[Dict[str, Any]],
    file_results: List[Tuple[str, Dict[str, Any]]],
    all_file_ids: List[str],
) -> None:
    """Convert auto-annotation results to Label Studio format, create a
    labeling project and persist the converted annotations.

    Args:
        task_id: Auto-annotation task ID (used in names and for traceability).
        dataset_id: Source dataset the project is attached to.
        dataset_name / task_name: Used to build a human-readable project name.
        dataset_type: Currently unused here — TODO confirm whether it should
            influence the label config.
        normalized_pipeline: Pipeline steps; task type and operator params are
            inferred from the first recognized annotation operator.
        file_results: (file_id, operator annotation dict) pairs to convert.
        all_file_ids: Every file in the run; snapshotted into the project.
    """
    # Local import keeps the worker importable even if the converter module
    # has issues; only this code path depends on it.
    from datamate.annotation_result_converter import (
        convert_annotation,
        extract_operator_params,
        generate_label_config_xml,
        infer_task_type_from_pipeline,
    )

    task_type = infer_task_type_from_pipeline(normalized_pipeline)
    if not task_type:
        logger.warning(
            "Cannot infer task_type from pipeline for task {}, skipping labeling project creation",
            task_id,
        )
        return

    operator_params = extract_operator_params(normalized_pipeline)

    # Object detection: collect the unique label set from actual detections so
    # the generated config only offers labels that really occur.
    if task_type == "object_detection":
        all_labels: set = set()
        for _, ann in file_results:
            for det in ann.get("detections", []):
                if isinstance(det, dict):
                    all_labels.add(str(det.get("label", "unknown")))
        operator_params["_detected_labels"] = sorted(all_labels)

    label_config = generate_label_config_xml(task_type, operator_params)

    project_id = str(uuid.uuid4())
    # 8-digit external project ID; collision chance is accepted here.
    labeling_project_id = str(uuid.uuid4().int % 10**8).zfill(8)
    project_name = f"自动标注 - {task_name or dataset_name or task_id[:8]}"[:100]

    now = datetime.now()
    configuration = json.dumps(
        {
            "label_config": label_config,
            "description": f"由自动标注任务 {task_id[:8]} 自动创建",
            "auto_annotation_task_id": task_id,
        },
        ensure_ascii=False,
    )

    insert_project_sql = text(
        """
        INSERT INTO t_dm_labeling_projects
        (id, dataset_id, name, labeling_project_id, template_id, configuration, created_at, updated_at)
        VALUES
        (:id, :dataset_id, :name, :labeling_project_id, NULL, :configuration, :now, :now)
        """
    )
    insert_snapshot_sql = text(
        """
        INSERT INTO t_dm_labeling_project_files (id, project_id, file_id, created_at)
        VALUES (:id, :project_id, :file_id, :now)
        """
    )
    insert_annotation_sql = text(
        """
        INSERT INTO t_dm_annotation_results
        (id, project_id, file_id, annotation, annotation_status, file_version, created_at, updated_at)
        VALUES
        (:id, :project_id, :file_id, :annotation, :annotation_status, :file_version, :now, :now)
        """
    )

    with SQLManager.create_connect() as conn:
        # 1. Create the labeling project.
        conn.execute(
            insert_project_sql,
            {
                "id": project_id,
                "dataset_id": dataset_id,
                "name": project_name,
                "labeling_project_id": labeling_project_id,
                "configuration": configuration,
                "now": now,
            },
        )

        # 2. Snapshot every file of the run into the project.
        for file_id in all_file_ids:
            conn.execute(
                insert_snapshot_sql,
                {
                    "id": str(uuid.uuid4()),
                    "project_id": project_id,
                    "file_id": file_id,
                    "now": now,
                },
            )

        # 3. Convert and persist the annotation results.
        converted_count = 0
        for file_id, annotation in file_results:
            ls_annotation = convert_annotation(annotation, file_id, project_id)
            if ls_annotation is None:
                # Unknown task type or converter failure — already logged.
                continue

            conn.execute(
                insert_annotation_sql,
                {
                    "id": str(uuid.uuid4()),
                    "project_id": project_id,
                    "file_id": file_id,
                    "annotation": json.dumps(ls_annotation, ensure_ascii=False),
                    "annotation_status": "ANNOTATED",
                    "file_version": 1,
                    "now": now,
                },
            )
            converted_count += 1

    logger.info(
        "Created labeling project {} ({}) with {} annotations for auto-annotation task {}",
        project_id,
        project_name,
        converted_count,
        task_id,
    )


# NOTE(review): the remaining hunks of the original patch (truncated in this
# view) modify _process_single_task and are NOT reconstructed here because the
# function's full definition is not visible.  Summary of those hunks:
#   * iterate dataset files as (file_id, file_path, file_name) triples instead
#     of (path, name) pairs;
#   * accumulate (file_id, annotations) into a file_results list per file;
#   * replace the early `return` on a stop request with `break`, moving the
#     "completed" status update into a for/else clause, so partial results of
#     a stopped task can still be turned into a labeling project;
#   * after the loop, call _create_labeling_project_with_annotations inside a
#     defensive try/except so a conversion failure never fails the task.
try: sample_key = _get_sample_key(dataset_type) @@ -1016,6 +1150,10 @@ def _process_single_task(task: Dict[str, Any]) -> None: detected_total += _count_detections(result) processed += 1 + ann = result.get("annotations") + if isinstance(ann, dict): + file_results.append((file_id, ann)) + progress = int(processed * 100 / total_images) if total_images > 0 else 100 _update_task_status( @@ -1038,19 +1176,21 @@ def _process_single_task(task: Dict[str, Any]) -> None: ) continue - _update_task_status( - task_id, - run_token=run_token, - status="completed", - progress=100, - processed_images=processed, - detected_objects=detected_total, - total_images=total_images, - output_path=output_dir, - output_dataset_id=output_dataset_id, - completed=True, - clear_run_token=True, - ) + else: + # Loop completed without break (not stopped) + _update_task_status( + task_id, + run_token=run_token, + status="completed", + progress=100, + processed_images=processed, + detected_objects=detected_total, + total_images=total_images, + output_path=output_dir, + output_dataset_id=output_dataset_id, + completed=True, + clear_run_token=True, + ) logger.info( "Completed auto-annotation task: id={}, total_images={}, processed={}, detected_objects={}, output_path={}", @@ -1077,6 +1217,26 @@ def _process_single_task(task: Dict[str, Any]) -> None: task_id, e, ) + + # 将自动标注结果转换为 Label Studio 格式并写入标注项目 + if file_results: + try: + _create_labeling_project_with_annotations( + task_id=task_id, + dataset_id=dataset_id, + dataset_name=source_dataset_name, + task_name=task_name, + dataset_type=dataset_type, + normalized_pipeline=normalized_pipeline, + file_results=file_results, + all_file_ids=[fid for fid, _, _ in all_files], + ) + except Exception as e: # pragma: no cover - 防御性日志 + logger.error( + "Failed to create labeling project for auto-annotation task {}: {}", + task_id, + e, + ) except Exception as e: logger.error("Task execution failed for task {}: {}", task_id, e) _update_task_status(