from dataclasses import dataclass
from datetime import datetime
from typing import List
from sqlalchemy import select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.logging import get_logger
from app.db.models.annotation_management import AnnotationTemplate
from app.module.annotation.utils.config_validator import LabelStudioConfigValidator
# Module-level logger; ensure_builtin_annotation_templates uses it to report
# skipped and failed template inserts.
logger = get_logger(__name__)
# Values stored in a template's ``data_type`` field (input modality).
DATA_TYPE_IMAGE: str = "image"
DATA_TYPE_AUDIO: str = "audio"
DATA_TYPE_VIDEO: str = "video"

# Values stored in a template's ``category`` field.
CATEGORY_COMPUTER_VISION: str = "computer-vision"
CATEGORY_AUDIO_SPEECH: str = "audio-speech"
CATEGORY_VIDEO: str = "video"

# Shared layout style and version string for built-in template definitions.
STYLE_HORIZONTAL: str = "horizontal"
VERSION_DEFAULT: str = "1.0.0"
# Label Studio XML configurations, one per built-in template.
# ensure_builtin_annotation_templates strips each one, checks that it is not
# blank, and runs LabelStudioConfigValidator.validate_xml on it before storing.
# NOTE(review): every literal below is blank in this view (just a newline),
# which would fail that non-empty check -- confirm the XML bodies are present
# in the real file and were not lost.

# --- Image templates ---
IMAGE_CLASSIFICATION_LABEL_CONFIG = """
"""
OBJECT_DETECTION_LABEL_CONFIG = """
"""
SEMANTIC_SEGMENTATION_MASK_LABEL_CONFIG = """
"""
SEMANTIC_SEGMENTATION_POLYGON_LABEL_CONFIG = """
"""
# --- Audio templates ---
ASR_SEGMENTS_LABEL_CONFIG = """
"""
ASR_LABEL_CONFIG = """
"""
CONVERSATION_ANALYSIS_LABEL_CONFIG = """
"""
INTENT_CLASSIFICATION_LABEL_CONFIG = """
"""
SIGNAL_QUALITY_LABEL_CONFIG = """
"""
SOUND_EVENT_DETECTION_LABEL_CONFIG = """
"""
SPEAKER_SEGMENTATION_LABEL_CONFIG = """
"""
# --- Video templates ---
VIDEO_CLASSIFICATION_LABEL_CONFIG = """
"""
VIDEO_OBJECT_TRACKING_LABEL_CONFIG = """
"""
VIDEO_TIMELINE_SEGMENTATION_LABEL_CONFIG = """
"""
@dataclass(frozen=True)
class BuiltInTemplateDefinition:
    """Immutable, in-code description of one built-in annotation template.

    ensure_builtin_annotation_templates copies these fields into an
    ``AnnotationTemplate`` row. Every field is a plain string, and the
    positional order below is the constructor order.
    """

    # Stable identifier; written to AnnotationTemplate.id when seeding.
    id: str
    # Human-readable template name.
    name: str
    # Free-text description of the template.
    description: str
    # Input modality, e.g. "image", "audio" or "video".
    data_type: str
    # Task identifier, e.g. "object-detection".
    labeling_type: str
    # Label Studio XML configuration (validated before seeding).
    label_config: str
    # Layout style hint, e.g. "horizontal".
    style: str
    # Grouping category, e.g. "computer-vision".
    category: str
    # Template version string.
    version: str
def _builtin_template(**fields: str) -> BuiltInTemplateDefinition:
    """Build a BuiltInTemplateDefinition, defaulting ``style`` and ``version``.

    Every built-in template below uses STYLE_HORIZONTAL and VERSION_DEFAULT,
    so only the fields that distinguish one template from another are spelled
    out at each call site.
    """
    fields.setdefault("style", STYLE_HORIZONTAL)
    fields.setdefault("version", VERSION_DEFAULT)
    return BuiltInTemplateDefinition(**fields)


# Built-in templates seeded by ensure_builtin_annotation_templates,
# grouped by modality: image, then audio, then video.
BUILT_IN_TEMPLATES: List[BuiltInTemplateDefinition] = [
    # ---- Image / computer vision ----
    _builtin_template(
        id="tpl-image-classification-001",
        name="图像分类",
        description=(
            "对图像进行分类,适用于内容审核、安全检测、社交媒体审核等场景。"
            "关联模型:ResNet、EfficientNet、Vision Transformer"
        ),
        data_type=DATA_TYPE_IMAGE,
        labeling_type="image-classification",
        label_config=IMAGE_CLASSIFICATION_LABEL_CONFIG,
        category=CATEGORY_COMPUTER_VISION,
    ),
    _builtin_template(
        id="tpl-object-detection-001",
        name="目标检测(边界框)",
        description=(
            "在目标周围绘制边界框,适用于自动驾驶、交通监控、安防监控、零售分析等场景。"
            "关联模型:YOLO、R-CNN、SSD"
        ),
        data_type=DATA_TYPE_IMAGE,
        labeling_type="object-detection",
        label_config=OBJECT_DETECTION_LABEL_CONFIG,
        category=CATEGORY_COMPUTER_VISION,
    ),
    _builtin_template(
        id="tpl-semantic-segmentation-mask-001",
        name="语义分割(掩码)",
        description=(
            "使用画笔工具在目标周围绘制掩码,适用于自动驾驶、医学图像分析、卫星图像分析等场景。"
            "关联模型:U-Net、DeepLab、Mask R-CNN"
        ),
        data_type=DATA_TYPE_IMAGE,
        labeling_type="semantic-segmentation-mask",
        label_config=SEMANTIC_SEGMENTATION_MASK_LABEL_CONFIG,
        category=CATEGORY_COMPUTER_VISION,
    ),
    _builtin_template(
        id="tpl-semantic-segmentation-polygon-001",
        name="语义分割(多边形)",
        description=(
            "在目标周围绘制多边形,适用于自动驾驶、医学图像、卫星图像、精准农业等场景。"
            "关联模型:DeepLab、PSPNet、U-Net"
        ),
        data_type=DATA_TYPE_IMAGE,
        labeling_type="semantic-segmentation-polygon",
        label_config=SEMANTIC_SEGMENTATION_POLYGON_LABEL_CONFIG,
        category=CATEGORY_COMPUTER_VISION,
    ),
    # ---- Audio / speech ----
    _builtin_template(
        id="tpl-asr-segments-001",
        name="语音识别(分段)",
        description=(
            "对音频进行语音活动分段并转录文本,适用于呼叫中心转录、会议记录、播客转录等场景。"
            "关联模型:Whisper、Wav2Vec2、DeepSpeech"
        ),
        data_type=DATA_TYPE_AUDIO,
        labeling_type="asr-segments",
        label_config=ASR_SEGMENTS_LABEL_CONFIG,
        category=CATEGORY_AUDIO_SPEECH,
    ),
    _builtin_template(
        id="tpl-asr-001",
        name="语音识别",
        description=(
            "转录音频内容,适用于播客转录、会议记录、客服通话、字幕生成等场景。"
            "关联模型:Whisper、Wav2Vec、DeepSpeech"
        ),
        data_type=DATA_TYPE_AUDIO,
        labeling_type="asr",
        label_config=ASR_LABEL_CONFIG,
        category=CATEGORY_AUDIO_SPEECH,
    ),
    _builtin_template(
        id="tpl-conversation-analysis-001",
        name="对话分析",
        description="分析对话语句并标注事实和情感方面,适用于呼叫中心质检、客服分析、会议分析等场景",
        data_type=DATA_TYPE_AUDIO,
        labeling_type="conversation-analysis",
        label_config=CONVERSATION_ANALYSIS_LABEL_CONFIG,
        category=CATEGORY_AUDIO_SPEECH,
    ),
    _builtin_template(
        id="tpl-intent-classification-001",
        name="意图分类",
        description="进行语音活动分段并选择语音意图,适用于语音助手、智能音箱、IVR系统等场景",
        data_type=DATA_TYPE_AUDIO,
        labeling_type="intent-classification",
        label_config=INTENT_CLASSIFICATION_LABEL_CONFIG,
        category=CATEGORY_AUDIO_SPEECH,
    ),
    _builtin_template(
        id="tpl-signal-quality-001",
        name="信号质量检测",
        description="评估音频信号质量,适用于电信、呼叫中心质检、音频制作、VoIP质量评估等场景",
        data_type=DATA_TYPE_AUDIO,
        labeling_type="signal-quality",
        label_config=SIGNAL_QUALITY_LABEL_CONFIG,
        category=CATEGORY_AUDIO_SPEECH,
    ),
    _builtin_template(
        id="tpl-sound-event-001",
        name="声音事件检测",
        description="选择音频片段并分类声音事件,适用于安防监控、智慧城市、环境监测、工业监测等场景",
        data_type=DATA_TYPE_AUDIO,
        labeling_type="sound-event-detection",
        label_config=SOUND_EVENT_DETECTION_LABEL_CONFIG,
        category=CATEGORY_AUDIO_SPEECH,
    ),
    _builtin_template(
        id="tpl-speaker-segmentation-001",
        name="说话人分割",
        description="执行说话人分割/话者分离任务,适用于会议转录、播客制作、呼叫中心分析等场景",
        data_type=DATA_TYPE_AUDIO,
        labeling_type="speaker-segmentation",
        label_config=SPEAKER_SEGMENTATION_LABEL_CONFIG,
        category=CATEGORY_AUDIO_SPEECH,
    ),
    # ---- Video ----
    _builtin_template(
        id="tpl-video-classification-001",
        name="视频分类",
        description="对视频进行整体分类,适用于内容审核、媒体分析、质检等场景",
        data_type=DATA_TYPE_VIDEO,
        labeling_type="video-classification",
        label_config=VIDEO_CLASSIFICATION_LABEL_CONFIG,
        category=CATEGORY_VIDEO,
    ),
    _builtin_template(
        id="tpl-video-object-tracking-001",
        name="视频目标追踪",
        description="在视频中追踪目标对象,适用于安防监控、交通分析、行为分析等场景",
        data_type=DATA_TYPE_VIDEO,
        labeling_type="video-object-tracking",
        label_config=VIDEO_OBJECT_TRACKING_LABEL_CONFIG,
        category=CATEGORY_VIDEO,
    ),
    _builtin_template(
        id="tpl-video-timeline-segmentation-001",
        name="视频时间线分割",
        description="对视频时间线进行分段标注,适用于视频剪辑、内容索引等场景",
        data_type=DATA_TYPE_VIDEO,
        labeling_type="video-timeline-segmentation",
        label_config=VIDEO_TIMELINE_SEGMENTATION_LABEL_CONFIG,
        category=CATEGORY_VIDEO,
    ),
]
def _duplicate_template_ids(templates: List[BuiltInTemplateDefinition]) -> List[str]:
    """Return the template IDs that occur more than once in *templates*, sorted."""
    seen = set()
    duplicates = set()
    for template in templates:
        if template.id in seen:
            duplicates.add(template.id)
        else:
            seen.add(template.id)
    return sorted(duplicates)


# Import-time guard: template IDs are written to AnnotationTemplate.id when
# seeding, so they must be unique. This is an explicit check rather than an
# ``assert`` so that it still runs under ``python -O``, which strips asserts.
_DUPLICATE_TEMPLATE_IDS = _duplicate_template_ids(BUILT_IN_TEMPLATES)
if _DUPLICATE_TEMPLATE_IDS:
    raise ValueError(f"内置模板ID不能重复: {', '.join(_DUPLICATE_TEMPLATE_IDS)}")
def _validated_label_config(template: BuiltInTemplateDefinition) -> str:
    """Return *template*'s stripped label_config after checking it is usable.

    Raises:
        ValueError: if the stripped config is empty, or if
            ``LabelStudioConfigValidator.validate_xml`` reports it as invalid.
    """
    label_config = template.label_config.strip()
    # Explicit check rather than ``assert``: asserts are stripped under
    # ``python -O``, which would let a blank config reach the database.
    if not label_config:
        raise ValueError(f"内置模板 {template.id} 的 label_config 不能为空")
    valid, error = LabelStudioConfigValidator.validate_xml(label_config)
    if not valid:
        raise ValueError(f"内置模板 {template.id} 的 label_config 无效: {error}")
    return label_config


async def ensure_builtin_annotation_templates(db: AsyncSession) -> int:
    """Insert every built-in annotation template that is not yet in the database.

    All definitions in ``BUILT_IN_TEMPLATES`` are validated first. Existing IDs
    are looked up with a single query, and each missing template is added and
    committed on its own. An ``IntegrityError`` on commit is rolled back,
    logged as "already exists", and skipped.

    Args:
        db: Async session used for the lookup and the inserts. It is
            committed, or rolled back, once per inserted template.

    Returns:
        The number of templates actually inserted by this call.

    Raises:
        ValueError: if any built-in ``label_config`` is blank or invalid. This
            is raised before any database access.
        Exception: any other commit failure is rolled back, logged and re-raised.
    """
    # Validate everything up front so a broken built-in fails fast instead of
    # leaving the earlier templates committed and the later ones missing.
    validated = [
        (template, _validated_label_config(template)) for template in BUILT_IN_TEMPLATES
    ]
    if not validated:
        return 0

    # One SELECT for all IDs instead of a round-trip per template.
    result = await db.execute(
        select(AnnotationTemplate.id).where(
            AnnotationTemplate.id.in_([template.id for template, _ in validated])
        )
    )
    existing_ids = set(result.scalars().all())

    inserted = 0
    for template, label_config in validated:
        if template.id in existing_ids:
            continue
        db.add(
            AnnotationTemplate(
                id=template.id,
                name=template.name,
                description=template.description,
                data_type=template.data_type,
                labeling_type=template.labeling_type,
                configuration=None,
                label_config=label_config,
                style=template.style,
                category=template.category,
                built_in=True,
                version=template.version,
                created_at=datetime.now(),
            )
        )
        try:
            await db.commit()
        except IntegrityError:
            # Treated as "row already exists", e.g. it was inserted after the
            # lookup above; skip this template rather than failing.
            await db.rollback()
            logger.warning(f"内置模板已存在,跳过插入: {template.id}")
        except Exception:
            await db.rollback()
            logger.exception(f"写入内置模板失败: {template.id}")
            raise
        else:
            inserted += 1
        # The ID exists in the database either way; never attempt it twice.
        existing_ids.add(template.id)
    return inserted