# -*- coding: utf-8 -*- """将自动标注算子输出转换为 Label Studio 兼容格式。 支持的算子类型: - LLMTextClassification → choices - LLMNamedEntityRecognition → labels (span) - LLMRelationExtraction → labels + relation - ImageObjectDetectionBoundingBox → rectanglelabels """ from __future__ import annotations import hashlib import uuid import xml.etree.ElementTree as ET from typing import Any, Dict, List, Optional, Tuple from loguru import logger # --------------------------------------------------------------------------- # 颜色调色板(Label Studio 背景色) # --------------------------------------------------------------------------- _LABEL_COLORS = [ "#e53935", "#fb8c00", "#43a047", "#1e88e5", "#8e24aa", "#00897b", "#d81b60", "#3949ab", "#fdd835", "#6d4c41", "#546e7a", "#f4511e", ] def _pick_color(index: int) -> str: return _LABEL_COLORS[index % len(_LABEL_COLORS)] # --------------------------------------------------------------------------- # 稳定 LS ID 生成(与 editor.py:118-136 一致) # --------------------------------------------------------------------------- def _stable_ls_id(seed: str) -> int: """生成稳定的 Label Studio 风格整数 ID(JS 安全整数范围内)。""" digest = hashlib.sha1(seed.encode("utf-8")).hexdigest() value = int(digest[:13], 16) return value if value > 0 else 1 def make_ls_task_id(project_id: str, file_id: str) -> int: return _stable_ls_id(f"task:{project_id}:{file_id}") def make_ls_annotation_id(project_id: str, file_id: str) -> int: return _stable_ls_id(f"annotation:{project_id}:{file_id}") # --------------------------------------------------------------------------- # 确定性 region ID(关系抽取需要稳定引用) # --------------------------------------------------------------------------- _NAMESPACE_REGION = uuid.UUID("a1b2c3d4-e5f6-7890-abcd-ef1234567890") def _make_region_id( project_id: str, file_id: str, text: str, entity_type: str, start: Any, end: Any ) -> str: seed = f"{project_id}:{file_id}:{text}:{entity_type}:{start}:{end}" return str(uuid.uuid5(_NAMESPACE_REGION, seed)) # 
# ---------------------------------------------------------------------------
# 1. Text classification
# ---------------------------------------------------------------------------
def convert_text_classification(
    annotation: Dict[str, Any], file_id: str, project_id: str
) -> Optional[Dict[str, Any]]:
    """Convert LLMTextClassification operator output into an LS annotation.

    Expects ``annotation["result"]`` to be a dict carrying a truthy
    ``label``; returns ``None`` for missing or malformed payloads.
    """
    result_data = annotation.get("result")
    if not isinstance(result_data, dict):
        return None
    label = result_data.get("label")
    if not label:
        return None

    ls_result = [
        {
            # Classification regions are never cross-referenced (unlike
            # NER/relation regions), so a random region ID is sufficient.
            "id": str(uuid.uuid4()),
            "from_name": "sentiment",
            "to_name": "text",
            "type": "choices",
            "value": {"choices": [str(label)]},
        }
    ]
    return {
        "id": make_ls_annotation_id(project_id, file_id),
        "task": make_ls_task_id(project_id, file_id),
        "result": ls_result,
    }


# ---------------------------------------------------------------------------
# 2. Named entity recognition
# ---------------------------------------------------------------------------
def convert_ner(
    annotation: Dict[str, Any], file_id: str, project_id: str
) -> Optional[Dict[str, Any]]:
    """Convert LLMNamedEntityRecognition operator output into an LS annotation.

    Entities missing ``text``, ``start`` or ``end`` are skipped silently.
    Returns ``None`` when no entity could be converted.
    """
    entities = annotation.get("entities")
    if not isinstance(entities, list) or not entities:
        return None

    ls_result: List[Dict[str, Any]] = []
    for ent in entities:
        if not isinstance(ent, dict):
            continue
        ent_text = ent.get("text", "")
        ent_type = ent.get("type", "")
        start = ent.get("start")
        end = ent.get("end")
        if not ent_text or start is None or end is None:
            continue
        # Deterministic region ID so repeated runs / relation references
        # resolve to the same span.
        region_id = _make_region_id(project_id, file_id, ent_text, ent_type, start, end)
        ls_result.append(
            {
                "id": region_id,
                "from_name": "label",
                "to_name": "text",
                "type": "labels",
                "value": {
                    "start": int(start),
                    "end": int(end),
                    "text": str(ent_text),
                    "labels": [str(ent_type)],
                },
            }
        )

    if not ls_result:
        return None
    return {
        "id": make_ls_annotation_id(project_id, file_id),
        "task": make_ls_task_id(project_id, file_id),
        "result": ls_result,
    }


# ---------------------------------------------------------------------------
# 3. Relation extraction
# ---------------------------------------------------------------------------
def _find_entity_region_id(
    entity_regions: List[Dict[str, Any]], text: str, entity_type: str
) -> Optional[str]:
    """Find a previously generated entity region matching text + type.

    Returns the region ID of the first match, or ``None``. Matching ignores
    span offsets, so duplicate surface forms resolve to the first region.
    """
    for region in entity_regions:
        value = region.get("value", {})
        if value.get("text") == text and entity_type in value.get("labels", []):
            return region["id"]
    return None


def convert_relation_extraction(
    annotation: Dict[str, Any], file_id: str, project_id: str
) -> Optional[Dict[str, Any]]:
    """Convert LLMRelationExtraction operator output into an LS annotation.

    Builds deduplicated entity ``labels`` regions first, then ``relation``
    results whose ``from_id``/``to_id`` point at those regions. Relations
    whose subject or object cannot be matched to a region are dropped.
    """
    entities = annotation.get("entities", [])
    relations = annotation.get("relations", [])
    if not isinstance(entities, list):
        entities = []
    if not isinstance(relations, list):
        relations = []
    if not entities and not relations:
        return None

    # Build entity label regions (deduplicated on text/type/offsets).
    entity_regions: List[Dict[str, Any]] = []
    seen_keys: set = set()
    for ent in entities:
        if not isinstance(ent, dict):
            continue
        ent_text = str(ent.get("text", "")).strip()
        ent_type = str(ent.get("type", "")).strip()
        start = ent.get("start")
        end = ent.get("end")
        if not ent_text or start is None or end is None:
            continue
        dedup_key = (ent_text, ent_type, int(start), int(end))
        if dedup_key in seen_keys:
            continue
        seen_keys.add(dedup_key)
        region_id = _make_region_id(project_id, file_id, ent_text, ent_type, start, end)
        entity_regions.append(
            {
                "id": region_id,
                "from_name": "label",
                "to_name": "text",
                "type": "labels",
                "value": {
                    "start": int(start),
                    "end": int(end),
                    "text": ent_text,
                    "labels": [ent_type],
                },
            }
        )

    # Build relation results referencing the entity regions above.
    relation_results: List[Dict[str, Any]] = []
    for rel in relations:
        if not isinstance(rel, dict):
            continue
        subject = rel.get("subject")
        obj = rel.get("object")
        relation_type = str(rel.get("relation", "")).strip()
        if not isinstance(subject, dict) or not isinstance(obj, dict) or not relation_type:
            continue
        subj_text = str(subject.get("text", "")).strip()
        subj_type = str(subject.get("type", "")).strip()
        obj_text = str(obj.get("text", "")).strip()
        obj_type = str(obj.get("type", "")).strip()
        from_id = _find_entity_region_id(entity_regions, subj_text, subj_type)
        to_id = _find_entity_region_id(entity_regions, obj_text, obj_type)
        if not from_id or not to_id:
            logger.debug(
                "Skipping relation '{}': could not find region for subject='{}' or object='{}'",
                relation_type,
                subj_text,
                obj_text,
            )
            continue
        relation_results.append(
            {
                "from_id": from_id,
                "to_id": to_id,
                "type": "relation",
                "direction": "right",
                "labels": [relation_type],
            }
        )

    ls_result = entity_regions + relation_results
    if not ls_result:
        return None
    return {
        "id": make_ls_annotation_id(project_id, file_id),
        "task": make_ls_task_id(project_id, file_id),
        "result": ls_result,
    }


# ---------------------------------------------------------------------------
# 4. Object detection
# ---------------------------------------------------------------------------
def convert_object_detection(
    annotation: Dict[str, Any], file_id: str, project_id: str
) -> Optional[Dict[str, Any]]:
    """Convert ImageObjectDetectionBoundingBox operator output to LS format.

    Pixel ``bbox_xyxy`` boxes are converted to the percentage-based
    rectangle values Label Studio expects; requires ``width``/``height``
    of the source image to be present and non-zero.
    """
    detections = annotation.get("detections")
    if not isinstance(detections, list) or not detections:
        return None
    img_width = annotation.get("width", 0)
    img_height = annotation.get("height", 0)
    if not img_width or not img_height:
        # Cannot express percentage coordinates without image dimensions.
        return None

    ls_result: List[Dict[str, Any]] = []
    for det in detections:
        if not isinstance(det, dict):
            continue
        label = det.get("label", "unknown")
        bbox = det.get("bbox_xyxy")
        if not isinstance(bbox, (list, tuple)) or len(bbox) < 4:
            continue
        x1, y1, x2, y2 = float(bbox[0]), float(bbox[1]), float(bbox[2]), float(bbox[3])
        # Pixel coordinates -> percentages of the full image.
        x_pct = x1 * 100.0 / img_width
        y_pct = y1 * 100.0 / img_height
        w_pct = (x2 - x1) * 100.0 / img_width
        h_pct = (y2 - y1) * 100.0 / img_height
        ls_result.append(
            {
                "id": str(uuid.uuid4()),
                "from_name": "label",
                "to_name": "image",
                "type": "rectanglelabels",
                "original_width": int(img_width),
                "original_height": int(img_height),
                "image_rotation": 0,
                "value": {
                    "x": round(x_pct, 4),
                    "y": round(y_pct, 4),
                    "width": round(w_pct, 4),
                    "height": round(h_pct, 4),
                    "rotation": 0,
                    "rectanglelabels": [str(label)],
                },
            }
        )

    if not ls_result:
        return None
    return {
        "id": make_ls_annotation_id(project_id, file_id),
        "task": make_ls_task_id(project_id, file_id),
        "result": ls_result,
    }


# ---------------------------------------------------------------------------
# Dispatcher
# ---------------------------------------------------------------------------
TASK_TYPE_CONVERTERS = {
    "text_classification": convert_text_classification,
    "ner": convert_ner,
    "relation_extraction": convert_relation_extraction,
    "object_detection": convert_object_detection,
}


def convert_annotation(
    annotation: Dict[str, Any], file_id: str, project_id: str
) -> Optional[Dict[str, Any]]:
    """Dispatch to the matching converter based on ``task_type``.

    Falls back to ``object_detection`` when ``task_type`` is absent but a
    ``detections`` key is present. Converter failures are logged and
    swallowed (returns ``None``) so one bad annotation cannot abort a batch.
    """
    task_type = annotation.get("task_type")
    if task_type is None and "detections" in annotation:
        task_type = "object_detection"
    converter = TASK_TYPE_CONVERTERS.get(task_type)  # type: ignore[arg-type]
    if converter is None:
        logger.warning("No LS converter for task_type: {}", task_type)
        return None
    try:
        return converter(annotation, file_id, project_id)
    except Exception as exc:
        logger.error("Failed to convert annotation (task_type={}): {}", task_type, exc)
        return None


# ---------------------------------------------------------------------------
# label_config XML generation
# ---------------------------------------------------------------------------
def _split_labels(raw: str) -> List[str]:
    """Comma-separated string -> list of stripped, non-empty items."""
    return [s.strip() for s in raw.split(",") if s.strip()]


def generate_label_config_xml(task_type: str, operator_params: Dict[str, Any]) -> str:
    """Generate a Label Studio ``label_config`` XML for the given task type.

    Unknown task types return a minimal but valid empty config. (Fixes the
    previously broken fallback, which was a multi-line string literal that
    effectively returned only a newline.)
    """
    if task_type == "text_classification":
        return _gen_text_classification_xml(operator_params)
    if task_type == "ner":
        return _gen_ner_xml(operator_params)
    if task_type == "relation_extraction":
        return _gen_relation_extraction_xml(operator_params)
    if task_type == "object_detection":
        return _gen_object_detection_xml(operator_params)
    # Unknown type: minimal usable XML.
    return "<View></View>"


def _gen_text_classification_xml(params: Dict[str, Any]) -> str:
    """<View> with a Text object and single-choice Choices control."""
    categories = _split_labels(str(params.get("categories", "正面,负面,中性")))
    view = ET.Element("View")
    ET.SubElement(view, "Text", name="text", value="$text")
    choices = ET.SubElement(
        view,
        "Choices",
        name="sentiment",
        toName="text",
        choice="single",
        showInline="true",
    )
    for cat in categories:
        ET.SubElement(choices, "Choice", value=cat)
    return ET.tostring(view, encoding="unicode")


def _gen_ner_xml(params: Dict[str, Any]) -> str:
    """<View> with span Labels (one color per entity type) over a Text object."""
    entity_types = _split_labels(str(params.get("entityTypes", "PER,ORG,LOC,DATE")))
    view = ET.Element("View")
    labels = ET.SubElement(view, "Labels", name="label", toName="text")
    for i, et in enumerate(entity_types):
        ET.SubElement(labels, "Label", value=et, background=_pick_color(i))
    ET.SubElement(view, "Text", name="text", value="$text")
    return ET.tostring(view, encoding="unicode")


def _gen_relation_extraction_xml(params: Dict[str, Any]) -> str:
    """<View> with Relations plus span Labels over a Text object."""
    entity_types = _split_labels(str(params.get("entityTypes", "PER,ORG,LOC")))
    relation_types = _split_labels(
        str(params.get("relationTypes", "属于,位于,创立,工作于"))
    )
    view = ET.Element("View")
    relations = ET.SubElement(view, "Relations")
    for rt in relation_types:
        ET.SubElement(relations, "Relation", value=rt)
    labels = ET.SubElement(view, "Labels", name="label", toName="text")
    for i, et in enumerate(entity_types):
        ET.SubElement(labels, "Label", value=et, background=_pick_color(i))
    ET.SubElement(view, "Text", name="text", value="$text")
    return ET.tostring(view, encoding="unicode")


def _gen_object_detection_xml(params: Dict[str, Any]) -> str:
    """<View> with an Image object and RectangleLabels control.

    Uses labels observed during detection (``_detected_labels``), falling
    back to a single generic "object" label.
    """
    detected_labels: List[str] = params.get("_detected_labels", [])
    if not detected_labels:
        detected_labels = ["object"]
    view = ET.Element("View")
    ET.SubElement(view, "Image", name="image", value="$image")
    rect_labels = ET.SubElement(view, "RectangleLabels", name="label", toName="image")
    for i, lbl in enumerate(detected_labels):
        ET.SubElement(rect_labels, "Label", value=lbl, background=_pick_color(i))
    return ET.tostring(view, encoding="unicode")


# ---------------------------------------------------------------------------
# Pipeline info extraction
# ---------------------------------------------------------------------------
OPERATOR_TASK_TYPE_MAP: Dict[str, str] = {
    "LLMTextClassification": "text_classification",
    "LLMNamedEntityRecognition": "ner",
    "LLMRelationExtraction": "relation_extraction",
    "ImageObjectDetectionBoundingBox": "object_detection",
}


def infer_task_type_from_pipeline(pipeline: List[Dict[str, Any]]) -> Optional[str]:
    """Infer the annotation task type from a normalized pipeline.

    Returns the task type of the first recognized operator, or ``None``.
    """
    for step in pipeline:
        operator_id = str(step.get("operator_id", ""))
        task_type = OPERATOR_TASK_TYPE_MAP.get(operator_id)
        if task_type is not None:
            return task_type
    return None


def extract_operator_params(pipeline: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Extract the ``overrides`` params of the first annotation operator.

    Returns a fresh dict (callers may mutate it); ``{}`` when no annotation
    operator is present in the pipeline.
    """
    for step in pipeline:
        operator_id = str(step.get("operator_id", ""))
        if operator_id in OPERATOR_TASK_TYPE_MAP:
            return dict(step.get("overrides", {}))
    return {}