"""Definitions and idempotent seeding of the built-in annotation templates."""

from dataclasses import dataclass
from datetime import datetime
from typing import List

from sqlalchemy import select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession

from app.core.logging import get_logger
from app.db.models.annotation_management import AnnotationTemplate
from app.module.annotation.utils.config_validator import LabelStudioConfigValidator

logger = get_logger(__name__)

DATA_TYPE_IMAGE = "image"
CATEGORY_COMPUTER_VISION = "computer-vision"
STYLE_HORIZONTAL = "horizontal"
VERSION_DEFAULT = "1.0.0"

# NOTE(review): the four Label Studio config literals below contain only
# whitespace in this copy of the file. ensure_builtin_annotation_templates()
# rejects a config that is empty after strip(), so seeding will raise until
# the XML bodies are present -- confirm they were not lost.
IMAGE_CLASSIFICATION_LABEL_CONFIG = """ """

OBJECT_DETECTION_LABEL_CONFIG = """ """

SEMANTIC_SEGMENTATION_MASK_LABEL_CONFIG = """ """

SEMANTIC_SEGMENTATION_POLYGON_LABEL_CONFIG = """
"""


@dataclass(frozen=True)
class BuiltInTemplateDefinition:
    """Immutable description of one built-in annotation template to seed."""

    id: str  # used as the AnnotationTemplate primary key / lookup value
    name: str
    description: str
    data_type: str
    labeling_type: str
    label_config: str  # Label Studio config; stripped and validated before insert
    style: str
    category: str
    version: str


BUILT_IN_TEMPLATES: List[BuiltInTemplateDefinition] = [
    BuiltInTemplateDefinition(
        id="tpl-image-classification-001",
        name="图像分类",
        description=(
            "对图像进行分类,适用于内容审核、安全检测、社交媒体审核等场景。"
            "关联模型:ResNet、EfficientNet、Vision Transformer"
        ),
        data_type=DATA_TYPE_IMAGE,
        labeling_type="image-classification",
        label_config=IMAGE_CLASSIFICATION_LABEL_CONFIG,
        style=STYLE_HORIZONTAL,
        category=CATEGORY_COMPUTER_VISION,
        version=VERSION_DEFAULT,
    ),
    BuiltInTemplateDefinition(
        id="tpl-object-detection-001",
        name="目标检测(边界框)",
        description=(
            "在目标周围绘制边界框,适用于自动驾驶、交通监控、安防监控、零售分析等场景。"
            "关联模型:YOLO、R-CNN、SSD"
        ),
        data_type=DATA_TYPE_IMAGE,
        labeling_type="object-detection",
        label_config=OBJECT_DETECTION_LABEL_CONFIG,
        style=STYLE_HORIZONTAL,
        category=CATEGORY_COMPUTER_VISION,
        version=VERSION_DEFAULT,
    ),
    BuiltInTemplateDefinition(
        id="tpl-semantic-segmentation-mask-001",
        name="语义分割(掩码)",
        description=(
            "使用画笔工具在目标周围绘制掩码,适用于自动驾驶、医学图像分析、卫星图像分析等场景。"
            "关联模型:U-Net、DeepLab、Mask R-CNN"
        ),
        data_type=DATA_TYPE_IMAGE,
        labeling_type="semantic-segmentation-mask",
        label_config=SEMANTIC_SEGMENTATION_MASK_LABEL_CONFIG,
        style=STYLE_HORIZONTAL,
        category=CATEGORY_COMPUTER_VISION,
        version=VERSION_DEFAULT,
    ),
    BuiltInTemplateDefinition(
        id="tpl-semantic-segmentation-polygon-001",
        name="语义分割(多边形)",
        description=(
            "在目标周围绘制多边形,适用于自动驾驶、医学图像、卫星图像、精准农业等场景。"
            "关联模型:DeepLab、PSPNet、U-Net"
        ),
        data_type=DATA_TYPE_IMAGE,
        labeling_type="semantic-segmentation-polygon",
        label_config=SEMANTIC_SEGMENTATION_POLYGON_LABEL_CONFIG,
        style=STYLE_HORIZONTAL,
        category=CATEGORY_COMPUTER_VISION,
        version=VERSION_DEFAULT,
    ),
]

# Explicit raise instead of `assert`: asserts are stripped under `python -O`,
# which would silently drop this import-time sanity check.
if len({template.id for template in BUILT_IN_TEMPLATES}) != len(BUILT_IN_TEMPLATES):
    raise ValueError("内置模板ID不能重复")


def _validated_label_config(template: BuiltInTemplateDefinition) -> str:
    """Return ``template.label_config`` stripped, after validating it.

    Raises:
        ValueError: if the stripped config is empty, or if
            ``LabelStudioConfigValidator.validate_xml`` reports it invalid.
    """
    label_config = template.label_config.strip()
    if not label_config:
        raise ValueError(f"内置模板 {template.id} 的 label_config 不能为空")
    valid, error = LabelStudioConfigValidator.validate_xml(label_config)
    if not valid:
        raise ValueError(f"内置模板 {template.id} 的 label_config 无效: {error}")
    return label_config


async def ensure_builtin_annotation_templates(db: AsyncSession) -> int:
    """Insert every built-in annotation template that is not yet stored.

    All label configs are validated before any row is written, so an invalid
    built-in definition aborts the call without partially seeding the table.
    Each missing template is then added and committed on its own; a template
    whose id is already present (found by the lookup, or rejected with an
    ``IntegrityError`` on commit) is skipped.

    Args:
        db: Async SQLAlchemy session used for the lookups, inserts and commits.

    Returns:
        The number of templates inserted by this call.

    Raises:
        ValueError: if a built-in ``label_config`` is empty or invalid.
        Exception: any other error raised by ``commit`` is re-raised after the
            session has been rolled back.
    """
    # Validate up front: a bad definition must not leave earlier templates
    # committed and later ones missing.
    validated = [
        (template, _validated_label_config(template)) for template in BUILT_IN_TEMPLATES
    ]

    inserted = 0
    for template, label_config in validated:
        # Only existence matters here, so fetch the id rather than the full row.
        result = await db.execute(
            select(AnnotationTemplate.id).where(AnnotationTemplate.id == template.id)
        )
        if result.scalar_one_or_none() is not None:
            continue

        record = AnnotationTemplate(
            id=template.id,
            name=template.name,
            description=template.description,
            data_type=template.data_type,
            labeling_type=template.labeling_type,
            configuration=None,
            label_config=label_config,
            style=template.style,
            category=template.category,
            built_in=True,
            version=template.version,
            created_at=datetime.now(),
        )
        db.add(record)
        try:
            await db.commit()
        except IntegrityError:
            # Treated as "row already exists" (e.g. inserted concurrently after
            # the lookup above); undo the pending insert and move on.
            await db.rollback()
            logger.warning(f"内置模板已存在,跳过插入: {template.id}")
            continue
        except Exception:
            await db.rollback()
            logger.exception(f"写入内置模板失败: {template.id}")
            raise
        inserted += 1

    return inserted