You've already forked DataMate
- 移除 TemplateConfigurationForm 组件并引入 TemplateConfigurationTreeEditor - 使用 useTagConfig Hook 获取标签配置 - 将自定义XML状态 customXml 替换为 labelConfig - 删除模板编辑标签页和选择模板状态管理 - 更新XML解析逻辑支持更多对象和标注控件类型 - 添加配置验证功能确保至少包含数据对象和标注控件 - 在模板详情页面使用树形编辑器显示配置详情 - 更新任务创建页面集成新的树形配置编辑器 - 调整预览数据生成功能适配新的XML解析方式
346 lines
11 KiB
Python
346 lines
11 KiB
Python
"""
|
|
Annotation Template Service
|
|
"""
|
|
from typing import Optional, List
|
|
from datetime import datetime
|
|
from sqlalchemy import select, func, or_
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
from uuid import uuid4
|
|
from fastapi import HTTPException
|
|
|
|
from app.db.models.annotation_management import AnnotationTemplate
|
|
from app.module.annotation.schema.template import (
|
|
CreateAnnotationTemplateRequest,
|
|
UpdateAnnotationTemplateRequest,
|
|
AnnotationTemplateResponse,
|
|
AnnotationTemplateListResponse,
|
|
TemplateConfiguration
|
|
)
|
|
from app.module.annotation.utils.config_validator import LabelStudioConfigValidator
|
|
from app.module.annotation.config import LabelStudioTagConfig
|
|
|
|
|
|
class AnnotationTemplateService:
    """Service layer for annotation templates.

    Offers create / read / list / update / soft-delete operations on
    ``AnnotationTemplate`` rows, plus a helper that renders a
    ``TemplateConfiguration`` into a Label Studio XML labeling config.
    """

    @staticmethod
    def generate_label_studio_config(config: TemplateConfiguration) -> str:
        """Build a Label Studio XML config from a template configuration.

        Every attribute value taken from ``config`` is XML-escaped, so
        characters such as ``&``, ``<`` or ``"`` in names, values or choice
        texts cannot break the generated document.

        Args:
            config: Template configuration whose ``objects`` and ``labels``
                are rendered into XML elements.

        Returns:
            The Label Studio XML string: a root ``<View>`` element with one
            line per object, control and child element.
        """
        # Imported locally because only this method needs it.
        from xml.sax.saxutils import escape

        tag_config = LabelStudioTagConfig()
        control_types = tag_config.get_control_types()

        def attr(name: str, value: object) -> str:
            """Render ``name="value"`` with the value escaped for XML."""
            # f-string formatting keeps the rendering of non-str values
            # identical to plain interpolation; escape() handles & < > and
            # the extra entity covers the double quote delimiter.
            text = escape(f"{value}", {'"': "&quot;"})
            return f'{name}="{text}"'

        def normalize_control_type(raw: Optional[str]) -> str:
            """Map a raw control type onto the configured tag name.

            Falls back to ``"Choices"`` when no type is given, prefers an
            exact match, then a case-insensitive match, and finally returns
            the raw value unchanged.
            """
            if not raw:
                return "Choices"
            if raw in control_types:
                return raw
            raw_lower = raw.lower()
            return next(
                (ct for ct in control_types if ct.lower() == raw_lower),
                raw,
            )

        # Data object definitions (the things being annotated).
        object_parts: List[str] = [
            f' <{obj.type} {attr("name", obj.name)} {attr("value", obj.value)}/>'
            for obj in config.objects
        ]

        # Control (label) definitions.
        control_parts: List[str] = []
        for label in config.labels:
            tag_type = normalize_control_type(label.type)

            label_attrs = [
                attr("name", label.from_name),
                attr("toName", label.to_name),
            ]

            # Optional attributes.
            if label.required:
                label_attrs.append('required="true"')
            if tag_type == "Choices":
                if label.choice:
                    label_attrs.append(attr("choice", label.choice))
                if label.show_inline is not None:
                    flag = "true" if label.show_inline else "false"
                    label_attrs.append(f'showInline="{flag}"')

            attrs_text = " ".join(label_attrs)

            # ``options`` takes precedence over ``labels`` as the child list.
            choices = label.options or label.labels
            if not choices:
                # Simple control without child elements.
                control_parts.append(f' <{tag_type} {attrs_text}/>')
                continue

            # Child tag name comes from the tag config; default to Label.
            child_tag = tag_config.get_child_tag(tag_type) or "Label"

            control_parts.append(f' <{tag_type} {attrs_text}>')
            control_parts.extend(
                f' <{child_tag} {attr("value", choice)}/>' for choice in choices
            )
            control_parts.append(f' </{tag_type}>')

        # Note:
        # - Label Studio Frontend renders control tags in the right-hand
        #   side column (side-column/controls) by default.
        # - Hand-building a two-column layout in the XML would place the
        #   controls in the main area and duplicate the side panel's result
        #   view, which hurts the user experience.
        xml_parts = ['<View>']
        xml_parts.extend(object_parts)
        xml_parts.extend(control_parts)
        xml_parts.append('</View>')
        return '\n'.join(xml_parts)

    async def create_template(
        self,
        db: AsyncSession,
        request: CreateAnnotationTemplateRequest
    ) -> AnnotationTemplateResponse:
        """Create an annotation template.

        Args:
            db: Database session.
            request: Creation request; its ``label_config`` is validated
                before anything is persisted.

        Returns:
            The response for the newly created template.

        Raises:
            HTTPException: 400 when ``label_config`` fails validation.
        """
        label_config = request.label_config
        valid, error = LabelStudioConfigValidator.validate_xml(label_config)
        if not valid:
            raise HTTPException(status_code=400, detail=f"Invalid labelConfig: {error}")

        # Build the ORM object; the XML is stored as-is and the structured
        # ``configuration`` column is left empty.
        template = AnnotationTemplate(
            id=str(uuid4()),
            name=request.name,
            description=request.description,
            data_type=request.data_type,
            labeling_type=request.labeling_type,
            configuration=None,
            label_config=label_config,
            style=request.style,
            category=request.category,
            built_in=False,
            version="1.0.0",
            created_at=datetime.now()
        )

        db.add(template)
        await db.commit()
        await db.refresh(template)

        return self._to_response(template)

    async def get_template(
        self,
        db: AsyncSession,
        template_id: str
    ) -> Optional[AnnotationTemplateResponse]:
        """Fetch a single, non-deleted template.

        Args:
            db: Database session.
            template_id: Template ID.

        Returns:
            The template response, or ``None`` when no live template has
            that ID.
        """
        result = await db.execute(
            select(AnnotationTemplate)
            .where(
                AnnotationTemplate.id == template_id,
                AnnotationTemplate.deleted_at.is_(None)
            )
        )
        template = result.scalar_one_or_none()

        if template is None:
            return None
        return self._to_response(template)

    async def list_templates(
        self,
        db: AsyncSession,
        page: int = 1,
        size: int = 10,
        category: Optional[str] = None,
        data_type: Optional[str] = None,
        labeling_type: Optional[str] = None,
        built_in: Optional[bool] = None,
        keyword: Optional[str] = None
    ) -> AnnotationTemplateListResponse:
        """List non-deleted templates with filtering and pagination.

        Args:
            db: Database session.
            page: 1-based page number; values below 1 are treated as 1.
            size: Page size; values below 1 are treated as 1.
            category: Optional category filter.
            data_type: Optional data type filter.
            labeling_type: Optional labeling type filter.
            built_in: Optional filter on the built-in flag.
            keyword: Optional case-insensitive substring matched against
                name or description. It is used verbatim in the LIKE
                pattern, so ``%`` and ``_`` act as wildcards.

        Returns:
            The paginated template list response.
        """
        # Guard against a negative OFFSET and a division by zero below.
        page = max(page, 1)
        size = max(size, 1)

        # Assemble filter conditions; soft-deleted rows are always excluded.
        conditions: List = [AnnotationTemplate.deleted_at.is_(None)]

        if category:
            conditions.append(AnnotationTemplate.category == category)  # type: ignore
        if data_type:
            conditions.append(AnnotationTemplate.data_type == data_type)  # type: ignore
        if labeling_type:
            conditions.append(AnnotationTemplate.labeling_type == labeling_type)  # type: ignore
        if built_in is not None:
            conditions.append(AnnotationTemplate.built_in == built_in)  # type: ignore
        if keyword:
            like_keyword = f"%{keyword}%"
            conditions.append(
                or_(
                    AnnotationTemplate.name.ilike(like_keyword),  # type: ignore
                    AnnotationTemplate.description.ilike(like_keyword)  # type: ignore
                )
            )

        # Total number of matching rows.
        count_result = await db.execute(
            select(func.count()).select_from(AnnotationTemplate).where(*conditions)
        )
        total = count_result.scalar() or 0

        # Page query. ``id`` is a tie-breaker so rows sharing a
        # ``created_at`` keep a stable position across pages.
        result = await db.execute(
            select(AnnotationTemplate)
            .where(*conditions)
            .order_by(AnnotationTemplate.created_at.desc(), AnnotationTemplate.id)
            .limit(size)
            .offset((page - 1) * size)
        )
        templates = result.scalars().all()

        return AnnotationTemplateListResponse(
            content=[self._to_response(t) for t in templates],
            total=total,
            page=page,
            size=size,
            totalPages=(total + size - 1) // size
        )

    async def update_template(
        self,
        db: AsyncSession,
        template_id: str,
        request: UpdateAnnotationTemplateRequest
    ) -> Optional[AnnotationTemplateResponse]:
        """Update a non-deleted template with the fields set on the request.

        Args:
            db: Database session.
            template_id: Template ID.
            request: Update request; only explicitly set fields are applied.

        Returns:
            The updated template response, or ``None`` when no live template
            has that ID.

        Raises:
            HTTPException: 400 when a new ``label_config`` fails validation.
        """
        result = await db.execute(
            select(AnnotationTemplate)
            .where(
                AnnotationTemplate.id == template_id,
                AnnotationTemplate.deleted_at.is_(None)
            )
        )
        template = result.scalar_one_or_none()

        if not template:
            return None

        update_data = request.model_dump(exclude_unset=True, by_alias=False)

        # Validate before touching the ORM object so a rejected config never
        # leaves the template partially modified in the session.
        new_label_config = update_data.get("label_config")
        if new_label_config is not None:
            valid, error = LabelStudioConfigValidator.validate_xml(new_label_config)
            if not valid:
                raise HTTPException(status_code=400, detail=f"Invalid labelConfig: {error}")

        # Apply the requested field updates.
        for field, value in update_data.items():
            setattr(template, field, value)

        template.updated_at = datetime.now()  # type: ignore

        await db.commit()
        await db.refresh(template)

        return self._to_response(template)

    async def delete_template(
        self,
        db: AsyncSession,
        template_id: str
    ) -> bool:
        """Soft-delete a template by stamping ``deleted_at``.

        Args:
            db: Database session.
            template_id: Template ID.

        Returns:
            ``True`` when a live template was found and marked deleted,
            ``False`` otherwise.
        """
        result = await db.execute(
            select(AnnotationTemplate)
            .where(
                AnnotationTemplate.id == template_id,
                AnnotationTemplate.deleted_at.is_(None)
            )
        )
        template = result.scalar_one_or_none()

        if not template:
            return False

        template.deleted_at = datetime.now()  # type: ignore
        await db.commit()

        return True

    def _to_response(self, template: AnnotationTemplate) -> AnnotationTemplateResponse:
        """Convert a database model into a response object.

        Args:
            template: Database model object.

        Returns:
            The template response. ``configuration`` is the parsed
            ``TemplateConfiguration`` when the stored value can be parsed,
            otherwise ``None``.
        """
        from typing import Any, Dict, cast

        config = None
        if template.configuration:
            try:
                config_dict = cast(Dict[str, Any], template.configuration)
                config = TemplateConfiguration(**config_dict)
            except Exception:
                # Deliberate best effort: a stored configuration that no
                # longer parses must not make the whole template unreadable.
                config = None

        response = AnnotationTemplateResponse.model_validate(template)
        response.configuration = config
        response.label_config = template.label_config  # type: ignore

        return response
|