Files
DataMate/runtime/datamate-python/app/module/annotation/service/builtin_templates.py
Jerry Yan 250a13ff70 feat(annotation): 支持图像标注项目并添加内置标注模板
- 扩展标注编辑器支持 TEXT/IMAGE 数据类型
- 添加三个内置图像标注模板:目标检测、语义分割(掩码)、语义分割(多边形)
- 实现内置标注模板的数据库初始化功能
- 集成标注配置验证和模板管理服务
- 更新项目不支持提示信息以反映新的数据类型支持
2026-01-25 18:35:07 +08:00

155 lines
5.2 KiB
Python

from dataclasses import dataclass
from datetime import datetime
from typing import List
from sqlalchemy import select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.logging import get_logger
from app.db.models.annotation_management import AnnotationTemplate
from app.module.annotation.utils.config_validator import LabelStudioConfigValidator
# Module-level logger obtained from the project's logging helper, keyed by
# this module's ``__name__``.
logger = get_logger(__name__)
# Metadata values shared by every built-in template declared in this module.
DATA_TYPE_IMAGE = "image"
CATEGORY_COMPUTER_VISION = "computer-vision"
STYLE_HORIZONTAL = "horizontal"
VERSION_DEFAULT = "1.0.0"

# Labeling configuration (XML) for bounding-box object detection.
# Double quotes need no escaping inside a triple-quoted literal.
OBJECT_DETECTION_LABEL_CONFIG = """<View>
<Image name="image" value="$image"/>
<RectangleLabels name="label" toName="image">
<Label value="Airplane" background="green"/>
<Label value="Car" background="blue"/>
</RectangleLabels>
</View>"""

# Labeling configuration (XML) for brush/mask based semantic segmentation.
SEMANTIC_SEGMENTATION_MASK_LABEL_CONFIG = """<View>
<Image name="image" value="$image" zoom="true"/>
<BrushLabels name="tag" toName="image">
<Label value="Airplane" background="rgba(255, 0, 0, 0.7)"/>
<Label value="Car" background="rgba(0, 0, 255, 0.7)"/>
</BrushLabels>
</View>"""

# Labeling configuration (XML) for polygon based semantic segmentation.
SEMANTIC_SEGMENTATION_POLYGON_LABEL_CONFIG = """<View>
<Header value="选择标签并点击图像开始"/>
<Image name="image" value="$image" zoom="true"/>
<PolygonLabels name="label" toName="image" strokeWidth="3" pointSize="small" opacity="0.9">
<Label value="Airplane" background="red"/>
<Label value="Car" background="blue"/>
</PolygonLabels>
</View>"""
@dataclass(frozen=True)
class BuiltInTemplateDefinition:
    """Immutable, in-code description of one built-in annotation template.

    Instances are declared in ``BUILT_IN_TEMPLATES`` and copied field by
    field into ``AnnotationTemplate`` rows by
    ``ensure_builtin_annotation_templates``.
    """

    id: str  # stored as AnnotationTemplate.id and used for the existence lookup
    name: str  # stored as AnnotationTemplate.name
    description: str  # stored as AnnotationTemplate.description
    data_type: str  # stored as AnnotationTemplate.data_type (e.g. "image")
    labeling_type: str  # stored as AnnotationTemplate.labeling_type
    label_config: str  # XML config; stripped and validated before being stored
    style: str  # stored as AnnotationTemplate.style
    category: str  # stored as AnnotationTemplate.category
    version: str  # stored as AnnotationTemplate.version
def _image_template(
    *,
    template_id: str,
    name: str,
    description: str,
    labeling_type: str,
    label_config: str,
) -> BuiltInTemplateDefinition:
    """Build a definition that uses the shared image / computer-vision metadata."""
    return BuiltInTemplateDefinition(
        id=template_id,
        name=name,
        description=description,
        data_type=DATA_TYPE_IMAGE,
        labeling_type=labeling_type,
        label_config=label_config,
        style=STYLE_HORIZONTAL,
        category=CATEGORY_COMPUTER_VISION,
        version=VERSION_DEFAULT,
    )


# Every template shipped with the application. All current entries target
# image data and differ only in id, texts, labeling type and XML config.
BUILT_IN_TEMPLATES: List[BuiltInTemplateDefinition] = [
    # Bounding-box object detection.
    _image_template(
        template_id="tpl-object-detection-001",
        name="目标检测(边界框)",
        description=(
            "在目标周围绘制边界框,适用于自动驾驶、交通监控、安防监控、零售分析等场景。"
            "关联模型:YOLO、R-CNN、SSD"
        ),
        labeling_type="object-detection",
        label_config=OBJECT_DETECTION_LABEL_CONFIG,
    ),
    # Semantic segmentation drawn with the brush (mask) tool.
    _image_template(
        template_id="tpl-semantic-segmentation-mask-001",
        name="语义分割(掩码)",
        description=(
            "使用画笔工具在目标周围绘制掩码,适用于自动驾驶、医学图像分析、卫星图像分析等场景。"
            "关联模型:U-Net、DeepLab、Mask R-CNN"
        ),
        labeling_type="semantic-segmentation-mask",
        label_config=SEMANTIC_SEGMENTATION_MASK_LABEL_CONFIG,
    ),
    # Semantic segmentation drawn as polygons.
    _image_template(
        template_id="tpl-semantic-segmentation-polygon-001",
        name="语义分割(多边形)",
        description=(
            "在目标周围绘制多边形,适用于自动驾驶、医学图像、卫星图像、精准农业等场景。"
            "关联模型:DeepLab、PSPNet、U-Net"
        ),
        labeling_type="semantic-segmentation-polygon",
        label_config=SEMANTIC_SEGMENTATION_POLYGON_LABEL_CONFIG,
    ),
]

# Import-time sanity check: ids are used as primary-key lookups when seeding,
# so two definitions must never share one.
assert len(BUILT_IN_TEMPLATES) == len(
    {definition.id for definition in BUILT_IN_TEMPLATES}
), "内置模板ID不能重复"
async def ensure_builtin_annotation_templates(db: AsyncSession) -> int:
    """Insert any built-in annotation templates that are missing from the database.

    Every definition in ``BUILT_IN_TEMPLATES`` is validated first; only when
    all of them are valid does the function start querying and writing. Each
    missing template is then added and committed on its own, so one insert
    that hits an integrity error does not affect the others.

    Args:
        db: Async SQLAlchemy session used for the lookups and inserts.

    Returns:
        The number of templates committed by this call.

    Raises:
        ValueError: If a definition's ``label_config`` is empty after
            stripping or is rejected by ``LabelStudioConfigValidator``.
            Raised before any database access.
        Exception: Any commit failure other than ``IntegrityError`` is
            rolled back, logged and re-raised.
    """
    # Phase 1: validate everything up front so that a broken definition can
    # never leave the table partially seeded.
    validated = []
    for template in BUILT_IN_TEMPLATES:
        label_config = template.label_config.strip()
        # Explicit raise instead of ``assert``: asserts are removed when
        # Python runs with ``-O``, which would silently skip this guard.
        if not label_config:
            raise ValueError(f"内置模板 {template.id} 的 label_config 不能为空")
        valid, error = LabelStudioConfigValidator.validate_xml(label_config)
        if not valid:
            raise ValueError(f"内置模板 {template.id} 的 label_config 无效: {error}")
        validated.append((template, label_config))

    # Phase 2: insert the templates that do not exist yet, one commit each.
    inserted = 0
    for template, label_config in validated:
        result = await db.execute(
            select(AnnotationTemplate).where(AnnotationTemplate.id == template.id)
        )
        if result.scalar_one_or_none():
            continue

        db.add(
            AnnotationTemplate(
                id=template.id,
                name=template.name,
                description=template.description,
                data_type=template.data_type,
                labeling_type=template.labeling_type,
                configuration=None,
                label_config=label_config,
                style=template.style,
                category=template.category,
                built_in=True,
                version=template.version,
                # Naive local timestamp, unchanged from the previous behavior.
                # NOTE(review): confirm the column expects naive datetimes.
                created_at=datetime.now(),
            )
        )
        try:
            await db.commit()
        except IntegrityError:
            # Treated as "row already exists" (for instance inserted between
            # the lookup above and this commit); skip it and keep going.
            await db.rollback()
            logger.warning(f"内置模板已存在,跳过插入: {template.id}")
        except Exception:
            await db.rollback()
            logger.exception(f"写入内置模板失败: {template.id}")
            raise
        else:
            inserted += 1
    return inserted