from dataclasses import dataclass
from datetime import datetime
from typing import List
from sqlalchemy import select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.logging import get_logger
from app.db.models.annotation_management import AnnotationTemplate
from app.module.annotation.utils.config_validator import LabelStudioConfigValidator
# Module-level logger used by the seeding routine in this file.
logger = get_logger(__name__)

# Field values shared by every entry in BUILT_IN_TEMPLATES.
DATA_TYPE_IMAGE = "image"
CATEGORY_COMPUTER_VISION = "computer-vision"
STYLE_HORIZONTAL = "horizontal"
VERSION_DEFAULT = "1.0.0"

# Label config text for each built-in template. Each value is stripped and
# checked with LabelStudioConfigValidator.validate_xml before it is stored.
# NOTE(review): as visible here every literal contains only a newline, which
# strips to an empty string and is rejected by
# ensure_builtin_annotation_templates -- confirm the XML bodies are present
# in the real file and were not lost in transit.
IMAGE_CLASSIFICATION_LABEL_CONFIG = """
"""
OBJECT_DETECTION_LABEL_CONFIG = """
"""
SEMANTIC_SEGMENTATION_MASK_LABEL_CONFIG = """
"""
SEMANTIC_SEGMENTATION_POLYGON_LABEL_CONFIG = """
"""
@dataclass(frozen=True)
class BuiltInTemplateDefinition:
    """Immutable description of one built-in annotation template.

    Instances are plain value objects: every field is a string and, because
    the dataclass is frozen, attributes cannot be reassigned after
    construction. Field order is part of the positional constructor
    signature and must not be changed.
    """

    # Primary key used when looking up / inserting the template row.
    id: str
    # Display name of the template.
    name: str
    # Free-text description of the template.
    description: str
    # Kind of data the template applies to (e.g. DATA_TYPE_IMAGE).
    data_type: str
    # Identifier of the labeling task type.
    labeling_type: str
    # Label configuration text; validated as XML before being stored.
    label_config: str
    # Layout style value (e.g. STYLE_HORIZONTAL).
    style: str
    # Category value (e.g. CATEGORY_COMPUTER_VISION).
    category: str
    # Version string of the template definition.
    version: str
# Definitions of the templates that ensure_builtin_annotation_templates seeds
# into the database. All entries share the image data type, the
# computer-vision category, the horizontal style and the default version.
BUILT_IN_TEMPLATES: List[BuiltInTemplateDefinition] = [
    # Image classification.
    BuiltInTemplateDefinition(
        id="tpl-image-classification-001",
        name="图像分类",
        labeling_type="image-classification",
        label_config=IMAGE_CLASSIFICATION_LABEL_CONFIG,
        description=(
            "对图像进行分类,适用于内容审核、安全检测、社交媒体审核等场景。"
            "关联模型:ResNet、EfficientNet、Vision Transformer"
        ),
        data_type=DATA_TYPE_IMAGE,
        category=CATEGORY_COMPUTER_VISION,
        style=STYLE_HORIZONTAL,
        version=VERSION_DEFAULT,
    ),
    # Object detection with bounding boxes.
    BuiltInTemplateDefinition(
        id="tpl-object-detection-001",
        name="目标检测(边界框)",
        labeling_type="object-detection",
        label_config=OBJECT_DETECTION_LABEL_CONFIG,
        description=(
            "在目标周围绘制边界框,适用于自动驾驶、交通监控、安防监控、零售分析等场景。"
            "关联模型:YOLO、R-CNN、SSD"
        ),
        data_type=DATA_TYPE_IMAGE,
        category=CATEGORY_COMPUTER_VISION,
        style=STYLE_HORIZONTAL,
        version=VERSION_DEFAULT,
    ),
    # Semantic segmentation drawn as masks.
    BuiltInTemplateDefinition(
        id="tpl-semantic-segmentation-mask-001",
        name="语义分割(掩码)",
        labeling_type="semantic-segmentation-mask",
        label_config=SEMANTIC_SEGMENTATION_MASK_LABEL_CONFIG,
        description=(
            "使用画笔工具在目标周围绘制掩码,适用于自动驾驶、医学图像分析、卫星图像分析等场景。"
            "关联模型:U-Net、DeepLab、Mask R-CNN"
        ),
        data_type=DATA_TYPE_IMAGE,
        category=CATEGORY_COMPUTER_VISION,
        style=STYLE_HORIZONTAL,
        version=VERSION_DEFAULT,
    ),
    # Semantic segmentation drawn as polygons.
    BuiltInTemplateDefinition(
        id="tpl-semantic-segmentation-polygon-001",
        name="语义分割(多边形)",
        labeling_type="semantic-segmentation-polygon",
        label_config=SEMANTIC_SEGMENTATION_POLYGON_LABEL_CONFIG,
        description=(
            "在目标周围绘制多边形,适用于自动驾驶、医学图像、卫星图像、精准农业等场景。"
            "关联模型:DeepLab、PSPNet、U-Net"
        ),
        data_type=DATA_TYPE_IMAGE,
        category=CATEGORY_COMPUTER_VISION,
        style=STYLE_HORIZONTAL,
        version=VERSION_DEFAULT,
    ),
]

# Import-time sanity check: template ids are used as primary keys, so the
# number of distinct ids must equal the number of definitions.
assert len(BUILT_IN_TEMPLATES) == len(
    {definition.id for definition in BUILT_IN_TEMPLATES}
), "内置模板ID不能重复"
async def ensure_builtin_annotation_templates(db: AsyncSession) -> int:
    """Insert every built-in annotation template that is not yet stored.

    For each entry in ``BUILT_IN_TEMPLATES`` the label config is stripped and
    validated, the table is queried by template id, and a new
    ``AnnotationTemplate`` row is added and committed when no row with that
    id exists. Each template is committed on its own, so templates handled
    before a failure stay committed.

    Args:
        db: Async SQLAlchemy session used for the lookups and the inserts.

    Returns:
        The number of templates committed by this call.

    Raises:
        ValueError: If a template's label config is empty after stripping or
            is rejected by ``LabelStudioConfigValidator.validate_xml``.
        Exception: Any non-``IntegrityError`` failure raised by the commit is
            logged, the session is rolled back, and the error is re-raised.
    """
    inserted = 0
    for template in BUILT_IN_TEMPLATES:
        label_config = template.label_config.strip()
        # Raise explicitly instead of using ``assert``: assertions are removed
        # when Python runs with -O, which would let an empty config through.
        # ValueError also matches the error type used for an invalid config.
        if not label_config:
            raise ValueError(f"内置模板 {template.id} 的 label_config 不能为空")

        valid, error = LabelStudioConfigValidator.validate_xml(label_config)
        if not valid:
            raise ValueError(f"内置模板 {template.id} 的 label_config 无效: {error}")

        # Skip templates whose id is already present in the table.
        result = await db.execute(
            select(AnnotationTemplate).where(AnnotationTemplate.id == template.id)
        )
        if result.scalar_one_or_none() is not None:
            continue

        db.add(
            AnnotationTemplate(
                id=template.id,
                name=template.name,
                description=template.description,
                data_type=template.data_type,
                labeling_type=template.labeling_type,
                configuration=None,
                label_config=label_config,
                style=template.style,
                category=template.category,
                built_in=True,
                version=template.version,
                # NOTE(review): naive local timestamp; confirm whether the
                # column expects UTC / timezone-aware values.
                created_at=datetime.now(),
            )
        )

        try:
            await db.commit()
        except IntegrityError:
            # Treated as "row already exists" (presumably inserted by another
            # writer between the lookup above and this commit); not an error.
            await db.rollback()
            logger.warning(f"内置模板已存在,跳过插入: {template.id}")
        except Exception:
            await db.rollback()
            logger.exception(f"写入内置模板失败: {template.id}")
            raise
        else:
            # Count only rows whose commit actually succeeded.
            inserted += 1
    return inserted