feat: add labeling template. refactor: switch to Poetry, build and deploy of backend Python (#79)

* feat: Enhance annotation module with template management and validation

- Added DatasetMappingCreateRequest and DatasetMappingUpdateRequest schemas to handle dataset mapping requests with camelCase and snake_case support.
- Introduced Annotation Template schemas including CreateAnnotationTemplateRequest, UpdateAnnotationTemplateRequest, and AnnotationTemplateResponse for managing annotation templates.
- Implemented AnnotationTemplateService for creating, updating, retrieving, and deleting annotation templates, including validation of configurations and XML generation.
- Added utility class LabelStudioConfigValidator for validating Label Studio configurations and XML formats.
- Updated database schema for annotation templates and labeling projects to include new fields and constraints.
- Seeded initial annotation templates for various use cases including image classification, object detection, and text classification.

* feat: Enhance TemplateForm with improved validation and dynamic field rendering; update LabelStudio config validation for camelCase support

* feat: Update docker-compose.yml to mark datamate dataset volume and network as external

* feat: Add tag configuration management and related components

- Introduced new components for tag selection and browsing in the frontend.
- Added API endpoint to fetch tag configuration from the backend.
- Implemented tag configuration management in the backend, including loading from YAML.
- Enhanced template service to support dynamic tag rendering based on configuration.
- Updated validation utilities to incorporate tag configuration checks.
- Refactored existing code to utilize the new tag configuration structure.

* feat: Refactor LabelStudioTagConfig for improved configuration loading and validation

* feat: Update Makefile to include backend-python-docker-build in the build process

* feat: Migrate to poetry for better deps management

* Add pyyaml dependency and update Dockerfile to use Poetry for dependency management

- Added pyyaml (>=6.0.3,<7.0.0) to pyproject.toml dependencies.
- Updated Dockerfile to install Poetry and manage dependencies using it.
- Improved layer caching by copying only dependency files before the application code.
- Removed unnecessary installation of build dependencies to keep the final image size small.

* feat: Remove duplicated backend-python-docker-build target from Makefile

* fix: airflow is not ready for adding yet

* feat: update Python version to 3.12 and remove project installation step in Dockerfile
This commit is contained in:
Jason Wang
2025-11-13 15:32:30 +08:00
committed by GitHub
parent 2660845b74
commit 45743f39f5
40 changed files with 3223 additions and 262 deletions

View File

@@ -111,6 +111,10 @@ class Client:
"label_config": label_config or "<View></View>"
}
# Log the request body for debugging
logger.debug(f"Request body: {project_data}")
logger.debug(f"Label config being sent:\n{project_data['label_config']}")
response = await self.client.post("/api/projects", json=project_data)
response.raise_for_status()
@@ -127,7 +131,7 @@ class Client:
logger.error(
f"Create project failed - HTTP {e.response.status_code}\n"
f"URL: {e.request.url}\n"
f"Response Headers: {dict(e.response.headers)}\n"
f"Request Body: {e.request.content.decode() if e.request.content else 'None'}\n"
f"Response Body: {e.response.text[:1000]}" # First 1000 chars
)
return None

View File

@@ -0,0 +1,4 @@
"""Tag configuration package"""
from .tag_config import LabelStudioTagConfig
__all__ = ['LabelStudioTagConfig']

View File

@@ -0,0 +1,467 @@
# Label Studio Tag Configuration
# Defines supported tags, their properties, and child element requirements
# Object tags - represent data to be annotated
objects:
Audio:
description: "Display audio files"
required_attrs: [name, value]
optional_attrs: []
category: media
Bitmask:
description: "Display bitmask images for segmentation"
required_attrs: [name, value]
optional_attrs: []
category: image
PDF:
description: "Display PDF documents"
required_attrs: [name, value]
optional_attrs: []
category: document
Markdown:
description: "Display Markdown content"
required_attrs: [name, value]
optional_attrs: []
category: document
ParagraphLabels:
description: "Display paragraphs with label support"
required_attrs: [name, value]
optional_attrs: []
category: text
Timeseries:
description: "Display timeseries data"
required_attrs: [name, value]
optional_attrs: []
category: data
Vector:
description: "Display vector data for annotation"
required_attrs: [name, value]
optional_attrs: []
category: data
Chat:
description: "Display chat data for annotation"
required_attrs: [name, value]
optional_attrs: []
category: text
HyperText:
description: "Display HTML content"
required_attrs: [name, value]
optional_attrs: []
category: document
Image:
description: "Display images for annotation"
required_attrs: [name, value]
optional_attrs: []
category: image
Text:
description: "Display text for annotation"
required_attrs: [name, value]
optional_attrs: []
category: text
Video:
description: "Display video files"
required_attrs: [name, value]
optional_attrs: []
category: media
AudioPlus:
description: "Advanced audio player"
required_attrs: [name, value]
optional_attrs: []
category: media
Paragraphs:
description: "Display paragraphs of text"
required_attrs: [name, value]
optional_attrs: []
category: text
Table:
description: "Display tabular data"
required_attrs: [name, value]
optional_attrs: []
category: data
# Control tags - tools for annotation
# Categories:
# - labeling: Controls used for annotating/labeling objects (shown in template form)
# - layout: UI/layout elements not used for labeling (hidden from template form by default)
controls:
# Choice-based controls (use <Choice> children)
Choices:
description: "Multiple choice classification"
required_attrs: [name, toName]
optional_attrs:
required:
type: boolean
description: "Whether the choice is required"
choice:
type: string
values: [single, multiple]
default: single
description: "Selection mode: single or multiple"
showInline:
type: boolean
default: true
description: "Show choices inline or as dropdown"
requires_children: true
child_tag: Choice
child_required_attrs: [value]
category: labeling
Taxonomy:
description: "Hierarchical multi-label classification"
required_attrs: [name, toName]
optional_attrs:
required:
type: boolean
maxDepth:
type: number
default: 3
description: "Maximum depth of taxonomy tree"
requires_children: true
child_tag: Path
child_required_attrs: [value]
category: labeling
Ranker:
description: "Rank items in order"
required_attrs: [name, toName]
optional_attrs:
required:
type: boolean
maxChoices:
type: number
default: 5
description: "Maximum number of choices to rank"
requires_children: true
child_tag: Choice
child_required_attrs: [value]
category: layout
List:
description: "List selection control"
required_attrs: [name, toName]
optional_attrs:
required:
type: boolean
mode:
type: string
values: [single, multiple]
default: single
requires_children: true
child_tag: Item
child_required_attrs: [value]
category: layout
Filter:
description: "Filter control for annotation"
required_attrs: [name, toName]
optional_attrs:
required:
type: boolean
requires_children: false
category: layout
Collapse:
description: "Collapsible UI section"
required_attrs: [name]
optional_attrs:
collapsed:
type: boolean
default: false
requires_children: false
category: layout
Header:
description: "Section header for UI grouping"
required_attrs: [name]
optional_attrs:
level:
type: number
default: 1
description: "Header level (1-6)"
requires_children: false
category: layout
Shortcut:
description: "Keyboard shortcut definition"
required_attrs: [name, toName]
optional_attrs:
key:
type: string
description: "Shortcut key"
requires_children: false
category: layout
Style:
description: "Custom style for annotation UI"
required_attrs: [name]
optional_attrs:
value:
type: string
description: "CSS style value"
requires_children: false
category: layout
MagicWand:
description: "Magic wand segmentation tool"
required_attrs: [name, toName]
optional_attrs:
required:
type: boolean
requires_children: false
category: labeling
BitmaskLabels:
description: "Bitmask segmentation with labels"
required_attrs: [name, toName]
optional_attrs: [required]
requires_children: true
child_tag: Label
child_required_attrs: [value]
category: labeling
TimeseriesLabels:
description: "Labels for timeseries data"
required_attrs: [name, toName]
optional_attrs: [required]
requires_children: true
child_tag: Label
child_required_attrs: [value]
category: labeling
VectorLabels:
description: "Labels for vector data"
required_attrs: [name, toName]
optional_attrs: [required]
requires_children: true
child_tag: Label
child_required_attrs: [value]
category: labeling
ParagraphLabels:
description: "Labels for paragraphs"
required_attrs: [name, toName]
optional_attrs: [required]
requires_children: true
child_tag: Label
child_required_attrs: [value]
category: labeling
Relation:
description: "Draw relation between objects"
required_attrs: [name, toName]
optional_attrs:
required:
type: boolean
requires_children: false
category: layout
Relations:
description: "Draw multiple relations between objects"
required_attrs: [name, toName]
optional_attrs:
required:
type: boolean
requires_children: false
category: layout
Pairwise:
description: "Pairwise comparison control"
required_attrs: [name, toName]
optional_attrs:
required:
type: boolean
requires_children: false
category: layout
DateTime:
description: "Date and time input"
required_attrs: [name, toName]
optional_attrs:
required:
type: boolean
format:
type: string
default: "YYYY-MM-DD HH:mm:ss"
requires_children: false
category: labeling
Number:
description: "Numeric input field"
required_attrs: [name, toName]
optional_attrs:
required:
type: boolean
min:
type: number
max:
type: number
step:
type: number
default: 1
requires_children: false
category: labeling
# Label-based controls (use <Label> children)
RectangleLabels:
description: "Rectangle bounding boxes with labels"
required_attrs: [name, toName]
optional_attrs:
required:
type: boolean
description: "Whether annotation is required"
strokeWidth:
type: number
default: 3
description: "Width of the bounding box border"
canRotate:
type: boolean
default: true
description: "Allow rotation of rectangles"
requires_children: true
child_tag: Label
child_required_attrs: [value]
category: labeling
PolygonLabels:
description: "Polygon annotations with labels"
required_attrs: [name, toName]
optional_attrs:
required:
type: boolean
strokeWidth:
type: number
default: 3
pointSize:
type: string
values: [small, medium, large]
default: medium
requires_children: true
child_tag: Label
child_required_attrs: [value]
category: labeling
Labels:
description: "Generic labels for classification"
required_attrs: [name, toName]
optional_attrs: [required]
requires_children: true
child_tag: Label
child_required_attrs: [value]
category: labeling
KeyPointLabels:
description: "Keypoint annotations with labels"
required_attrs: [name, toName]
optional_attrs: [required]
requires_children: true
child_tag: Label
child_required_attrs: [value]
category: labeling
BrushLabels:
description: "Brush/semantic segmentation with labels"
required_attrs: [name, toName]
optional_attrs: [required]
requires_children: true
child_tag: Label
child_required_attrs: [value]
category: labeling
EllipseLabels:
description: "Ellipse annotations with labels"
required_attrs: [name, toName]
optional_attrs: [required]
requires_children: true
child_tag: Label
child_required_attrs: [value]
category: labeling
# Simple controls (no children required)
Rectangle:
description: "Rectangle bounding box without labels"
required_attrs: [name, toName]
optional_attrs: [required]
requires_children: false
category: labeling
Polygon:
description: "Polygon annotation without labels"
required_attrs: [name, toName]
optional_attrs: [required]
requires_children: false
category: labeling
Ellipse:
description: "Ellipse annotation without labels"
required_attrs: [name, toName]
optional_attrs: [required]
requires_children: false
category: labeling
KeyPoint:
description: "Keypoint annotation without labels"
required_attrs: [name, toName]
optional_attrs: [required]
requires_children: false
category: labeling
Brush:
description: "Brush annotation without labels"
required_attrs: [name, toName]
optional_attrs: [required]
requires_children: false
category: labeling
TextArea:
description: "Text input field"
required_attrs: [name, toName]
optional_attrs:
required:
type: boolean
placeholder:
type: string
description: "Placeholder text"
maxSubmissions:
type: number
description: "Maximum number of submissions"
rows:
type: number
default: 3
description: "Number of rows in textarea"
editable:
type: boolean
default: true
requires_children: false
category: labeling
Rating:
description: "Star rating or numeric rating"
required_attrs: [name, toName]
optional_attrs:
required:
type: boolean
maxRating:
type: number
default: 5
description: "Maximum rating value"
defaultValue:
type: number
description: "Default rating value"
size:
type: string
values: [small, medium, large]
default: medium
icon:
type: string
values: [star, heart, fire, thumbs]
default: star
requires_children: false
category: labeling
VideoRectangle:
description: "Rectangle annotations for video"
required_attrs: [name, toName]
optional_attrs: [required]
requires_children: false
category: labeling

View File

@@ -0,0 +1,150 @@
"""
Label Studio Tag Configuration Loader
"""
import yaml
from typing import Dict, Any, Optional, Set, Tuple
from pathlib import Path
class LabelStudioTagConfig:
"""Label Studio标签配置管理器"""
_instance: Optional['LabelStudioTagConfig'] = None
_config: Dict[str, Any] = {}
def __new__(cls):
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance
def __init__(self):
"""初始化时加载配置"""
if not self._config:
self._load_config()
@classmethod
def _load_config(cls):
"""加载YAML配置文件"""
config_path = Path(__file__).parent / "label_studio_tags.yaml"
with open(config_path, 'r', encoding='utf-8') as f:
cls._config = yaml.safe_load(f) or {}
@classmethod
def get_object_types(cls) -> Set[str]:
"""获取所有支持的对象类型"""
return set(cls._config.get('objects', {}).keys())
@classmethod
def get_control_types(cls) -> Set[str]:
"""获取所有支持的控件类型"""
return set(cls._config.get('controls', {}).keys())
@classmethod
def get_control_config(cls, control_type: str) -> Optional[Dict[str, Any]]:
"""获取控件的配置信息"""
return cls._config.get('controls', {}).get(control_type)
@classmethod
def get_object_config(cls, object_type: str) -> Optional[Dict[str, Any]]:
"""获取对象的配置信息"""
return cls._config.get('objects', {}).get(object_type)
@classmethod
def requires_children(cls, control_type: str) -> bool:
"""检查控件是否需要子元素"""
config = cls.get_control_config(control_type)
return config.get('requires_children', False) if config else False
@classmethod
def get_child_tag(cls, control_type: str) -> Optional[str]:
"""获取控件的子元素标签名"""
config = cls.get_control_config(control_type)
return config.get('child_tag') if config else None
@classmethod
def get_controls_with_child_tag(cls, child_tag: str) -> Set[str]:
"""获取使用指定子元素标签的所有控件类型"""
controls = set()
for control_type, config in cls._config.get('controls', {}).items():
if config.get('child_tag') == child_tag:
controls.add(control_type)
return controls
@classmethod
def get_optional_attrs(cls, tag_type: str, is_control: bool = True) -> Dict[str, Any]:
"""
获取标签的可选属性配置
Args:
tag_type: 标签类型
is_control: 是否为控件类型(否则为对象类型)
Returns:
可选属性配置字典
"""
config = cls.get_control_config(tag_type) if is_control else cls.get_object_config(tag_type)
if not config:
return {}
optional_attrs = config.get('optional_attrs', {})
# 如果是简单列表格式(旧格式),转换为字典
if isinstance(optional_attrs, list):
return {attr: {} for attr in optional_attrs}
# 确保返回的是字典
return optional_attrs if isinstance(optional_attrs, dict) else {}
@classmethod
def validate_attr_value(cls, tag_type: str, attr_name: str, attr_value: Any, is_control: bool = True) -> Tuple[bool, Optional[str]]:
"""
验证属性值是否符合配置要求
Args:
tag_type: 标签类型
attr_name: 属性名
attr_value: 属性值
is_control: 是否为控件类型
Returns:
(是否有效, 错误信息)
"""
optional_attrs = cls.get_optional_attrs(tag_type, is_control)
if attr_name not in optional_attrs:
return True, None # 不在配置中的属性,不验证
attr_config = optional_attrs.get(attr_name, {})
# 如果配置不是字典,跳过验证
if not isinstance(attr_config, dict):
return True, None
# 检查类型
expected_type = attr_config.get('type')
if expected_type == 'boolean':
if not isinstance(attr_value, (bool, str)) or (isinstance(attr_value, str) and attr_value.lower() not in ['true', 'false']):
return False, f"Attribute '{attr_name}' must be boolean"
elif expected_type == 'number':
try:
float(attr_value)
except (ValueError, TypeError):
return False, f"Attribute '{attr_name}' must be a number"
# 检查枚举值
allowed_values = attr_config.get('values')
if allowed_values and attr_value not in allowed_values:
return False, f"Attribute '{attr_name}' must be one of {allowed_values}, got '{attr_value}'"
return True, None
@classmethod
def get_attr_default(cls, tag_type: str, attr_name: str, is_control: bool = True) -> Optional[Any]:
"""获取属性的默认值"""
optional_attrs = cls.get_optional_attrs(tag_type, is_control)
attr_config = optional_attrs.get(attr_name, {})
# 确保attr_config是字典后再访问
if isinstance(attr_config, dict):
return attr_config.get('default')
return None

View File

@@ -1,6 +1,6 @@
from fastapi import APIRouter
from .about import router as about_router
from .config import router as about_router
from .project import router as project_router
from .task import router as task_router
from .template import router as template_router

View File

@@ -1,25 +0,0 @@
from fastapi import APIRouter
from app.module.shared.schema import StandardResponse
from app.core.logging import get_logger
from app.core.config import settings
from ..schema import ConfigResponse
router = APIRouter(
prefix="/about",
tags=["annotation/about"]
)
logger = get_logger(__name__)
@router.get("", response_model=StandardResponse[ConfigResponse])
async def get_config():
"""获取配置信息"""
return StandardResponse(
code=200,
message="success",
data=ConfigResponse(
label_studio_url=settings.label_studio_base_url,
)
)

View File

@@ -0,0 +1,47 @@
from fastapi import APIRouter
from app.module.shared.schema import StandardResponse
from app.core.logging import get_logger
from app.core.config import settings
from ..schema import (
ConfigResponse,
TagConfigResponse
)
from ..config.tag_config import LabelStudioTagConfig
router = APIRouter(
prefix="/tags",
tags=["annotation/config"]
)
logger = get_logger(__name__)
@router.get("", response_model=StandardResponse[ConfigResponse])
async def get_config():
"""获取配置信息(已废弃,请使用 /api/annotation/about)"""
return StandardResponse(
code=200,
message="success",
data=ConfigResponse(
label_studio_url=settings.label_studio_base_url,
)
)
@router.get("/config", response_model=StandardResponse[TagConfigResponse], summary="获取标签配置")
async def get_tag_config():
"""
获取所有Label Studio标签类型的配置(对象+控件),用于前端动态渲染。
"""
# Ensure config is loaded by instantiating the class
tag_config = LabelStudioTagConfig()
config = LabelStudioTagConfig._config
if not config:
logger.error("Failed to load tag configuration")
return StandardResponse(
code=500,
message="Failed to load tag configuration",
data={"objects": {}, "controls": {}}
)
return StandardResponse(code=200, message="success", data=config)

View File

@@ -2,7 +2,7 @@ from typing import Optional
import math
import uuid
from fastapi import APIRouter, Depends, HTTPException, Query
from fastapi import APIRouter, Depends, HTTPException, Query, Path
from sqlalchemy.ext.asyncio import AsyncSession
from app.db.session import get_db
@@ -149,7 +149,7 @@ async def create_mapping(
@router.get("", response_model=StandardResponse[PaginatedData[DatasetMappingResponse]])
async def list_mappings(
page: int = Query(1, ge=1, description="页码(从1开始)"),
page_size: int = Query(20, ge=1, le=100, description="每页记录数"),
page_size: int = Query(20, ge=1, le=100, description="每页记录数", alias="pageSize"),
db: AsyncSession = Depends(get_db)
):
"""
@@ -163,8 +163,6 @@ async def list_mappings(
# 计算 skip
skip = (page - 1) * page_size
logger.info(f"Listing mappings, page={page}, page_size={page_size}")
# 获取数据和总数
mappings, total = await service.get_all_mappings_with_count(
skip=skip,
@@ -183,7 +181,7 @@ async def list_mappings(
content=mappings
)
logger.info(f"Found {len(mappings)} mappings on page {page}, total: {total}")
logger.info(f"List mappings: page={page}, returned {len(mappings)}/{total}")
return StandardResponse(
code=200,
@@ -234,7 +232,7 @@ async def get_mapping(
async def get_mappings_by_source(
dataset_id: str,
page: int = Query(1, ge=1, description="页码(从1开始)"),
page_size: int = Query(20, ge=1, le=100, description="每页记录数"),
page_size: int = Query(20, ge=1, le=100, description="每页记录数", alias="pageSize"),
db: AsyncSession = Depends(get_db)
):
"""
@@ -283,49 +281,30 @@ async def get_mappings_by_source(
logger.error(f"Error getting mappings: {e}")
raise HTTPException(status_code=500, detail="Internal server error")
@router.delete("", response_model=StandardResponse[DeleteDatasetResponse])
@router.delete("/{project_id}", response_model=StandardResponse[DeleteDatasetResponse])
async def delete_mapping(
m: Optional[str] = Query(None, description="映射UUID"),
proj: Optional[str] = Query(None, description="Label Studio项目ID"),
project_id: str = Path(..., description="映射UUID(path param)"),
db: AsyncSession = Depends(get_db)
):
"""
删除映射关系和对应的 Label Studio 项目
可以通过以下任一方式指定要删除的映射:
- m: 映射UUID
- proj: Label Studio项目ID
- 两者都提供(优先使用 m)
通过 path 参数 `project_id` 指定要删除的映射(映射的 UUID)。
此操作会:
1. 删除 Label Studio 中的项目
2. 软删除数据库中的映射记录
"""
try:
# Log incoming request parameters for debugging
logger.debug(f"Delete mapping request received: m={m!r}, proj={proj!r}")
# 至少需要提供一个参数
if not m and not proj:
logger.debug("Missing both 'm' and 'proj' in delete request")
raise HTTPException(
status_code=400,
detail="Either 'm' (mapping UUID) or 'proj' (project ID) must be provided"
)
logger.debug(f"Delete mapping request received: project_id={project_id!r}")
ls_client = LabelStudioClient(base_url=settings.label_studio_base_url,
token=settings.label_studio_user_token)
service = DatasetMappingService(db)
# 优先使用 mapping_id 查询
if m:
logger.debug(f"Deleting by mapping UUID: {m}")
mapping = await service.get_mapping_by_uuid(m)
# 如果没有提供 m,使用 proj 查询
elif proj:
logger.debug(f"Deleting by project ID: {proj}")
mapping = await service.get_mapping_by_labeling_project_id(proj)
else:
mapping = None
# 使用 mapping UUID 查询映射记录
logger.debug(f"Deleting by mapping UUID: {project_id}")
mapping = await service.get_mapping_by_uuid(project_id)
logger.debug(f"Mapping lookup result: {mapping}")

View File

@@ -1,6 +1,8 @@
from fastapi import APIRouter, Depends, HTTPException, Query
from fastapi import APIRouter, Depends, HTTPException, Query, Path
from sqlalchemy.ext.asyncio import AsyncSession
from typing import List, Optional
from typing import List, Optional, Dict, Any
from datetime import datetime
from pydantic import BaseModel, Field, ConfigDict
from app.db.session import get_db
from app.module.shared.schema import StandardResponse
@@ -17,6 +19,10 @@ from ..schema import (
SyncDatasetResponse,
SyncAnnotationsRequest,
SyncAnnotationsResponse,
UpdateFileTagsRequest,
UpdateFileTagsResponse,
UpdateFileTagsRequest,
UpdateFileTagsResponse
)
@@ -32,24 +38,10 @@ async def sync_dataset_content(
db: AsyncSession = Depends(get_db)
):
"""
同步数据集内容(包括文件和标注)
Sync Dataset Content (Files and Annotations)
根据指定的mapping ID,同步DM程序数据集中的内容到Label Studio数据集中。
默认同时同步文件和标注数据。
Args:
request: 同步请求,包含:
- id: 映射ID(mapping UUID)
- batchSize: 批处理大小
- filePriority: 文件同步优先级
- labelPriority: 标签同步优先级
- syncAnnotations: 是否同步标注(默认True)
- annotationDirection: 标注同步方向(默认bidirectional)
- overwrite: 是否允许覆盖DataMate中的标注(默认True)
- overwriteLabelingProject: 是否允许覆盖Label Studio中的标注(默认True)
Returns:
同步结果
"""
try:
ls_client = LabelStudioClient(base_url=settings.label_studio_base_url,
@@ -123,28 +115,10 @@ async def sync_annotations(
db: AsyncSession = Depends(get_db)
):
"""
仅同步标注结果(支持双向同步)
根据指定mapping ID和同步方向,在DM数据集和Label Studio之间同步标注结果
标注结果存储在数据集文件表的tags字段中,使用简化格式
同步策略:
- 默认为双向同步,基于时间戳自动解决冲突
- overwrite: 控制是否允许用Label Studio的标注覆盖DataMate(基于时间戳比较)
- overwriteLabelingProject: 控制是否允许用DataMate的标注覆盖Label Studio(基于时间戳比较)
- 如果Label Studio标注的updated_at更新,且overwrite=True,则覆盖DataMate
- 如果DataMate标注的updated_at更新,且overwriteLabelingProject=True,则覆盖Label Studio
Args:
request: 同步请求,包含:
- id: 映射ID(mapping UUID)
- batchSize: 批处理大小
- direction: 同步方向 (ls_to_dm/dm_to_ls/bidirectional)
- overwrite: 是否允许覆盖DataMate中的标注(默认True)
- overwriteLabelingProject: 是否允许覆盖Label Studio中的标注(默认True)
Returns:
同步结果,包含同步统计信息和冲突解决情况
Sync Annotations Only (Bidirectional Support)
同步指定 mapping 下的标注数据,支持单向或双向同步,基于时间戳自动解决冲突
请求与响应由 Pydantic 模型 `SyncAnnotationsRequest` / `SyncAnnotationsResponse` 定义
"""
try:
ls_client = LabelStudioClient(base_url=settings.label_studio_base_url,
@@ -207,9 +181,9 @@ async def sync_annotations(
@router.get("/check-ls-connection")
async def check_label_studio_connection():
"""
检查Label Studio连接状态
用于诊断Label Studio连接问题,返回连接状态和配置信息
Check Label Studio Connection Status
诊断 Label Studio 连接并返回简要连接信息(状态、base URL、token 摘要、项目统计)。
"""
try:
ls_client = LabelStudioClient(
@@ -258,4 +232,55 @@ async def check_label_studio_connection():
)
except Exception as e:
logger.error(f"Error checking Label Studio connection: {e}")
raise HTTPException(status_code=500, detail=str(e))
raise HTTPException(status_code=500, detail=str(e))
@router.put(
"/{file_id}",
response_model=StandardResponse[UpdateFileTagsResponse],
)
async def update_file_tags(
request: UpdateFileTagsRequest,
file_id: str = Path(..., description="文件ID"),
db: AsyncSession = Depends(get_db)
):
"""
Update File Tags (Partial Update)
接收部分标签更新并合并到指定文件(只修改提交的标签,其余保持不变),并更新 `tags_updated_at`。
请求与响应使用 Pydantic 模型 `UpdateFileTagsRequest` / `UpdateFileTagsResponse`。
"""
service = DatasetManagementService(db)
success, error_msg, updated_at = await service.update_file_tags_partial(
file_id=file_id,
new_tags=request.tags
)
if not success:
if "not found" in (error_msg or "").lower():
raise HTTPException(status_code=404, detail=error_msg)
raise HTTPException(status_code=500, detail=error_msg or "更新标签失败")
# 获取更新后的完整标签列表
from sqlalchemy.future import select
from app.db.models import DatasetFiles
result = await db.execute(
select(DatasetFiles).where(DatasetFiles.id == file_id)
)
file_record = result.scalar_one_or_none()
if not file_record:
raise HTTPException(status_code=404, detail=f"File not found: {file_id}")
response_data = UpdateFileTagsResponse(
fileId=file_id,
tags=file_record.tags or [], # type: ignore
tagsUpdatedAt=updated_at or datetime.now()
)
return StandardResponse(
code=200,
message="标签更新成功",
data=response_data
)

View File

@@ -7,7 +7,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
from app.db.session import get_db
from app.module.shared.schema import StandardResponse
from app.module.annotation.schema.template import (
from app.module.annotation.schema import (
CreateAnnotationTemplateRequest,
UpdateAnnotationTemplateRequest,
AnnotationTemplateResponse,
@@ -15,7 +15,7 @@ from app.module.annotation.schema.template import (
)
from app.module.annotation.service.template import AnnotationTemplateService
router = APIRouter(prefix="/templates", tags=["Annotation Template"])
router = APIRouter(prefix="/template", tags=["annotation/template"])
template_service = AnnotationTemplateService()
@@ -23,7 +23,6 @@ template_service = AnnotationTemplateService()
@router.post(
"",
response_model=StandardResponse[AnnotationTemplateResponse],
summary="创建标注模板"
)
async def create_template(
request: CreateAnnotationTemplateRequest,
@@ -47,7 +46,6 @@ async def create_template(
@router.get(
"/{template_id}",
response_model=StandardResponse[AnnotationTemplateResponse],
summary="获取模板详情"
)
async def get_template(
template_id: str,
@@ -65,9 +63,8 @@ async def get_template(
@router.get(
"",
response_model=StandardResponse[AnnotationTemplateListResponse],
summary="获取模板列表"
)
async def list_templates(
async def list_template(
page: int = Query(1, ge=1, description="页码"),
size: int = Query(10, ge=1, le=100, description="每页大小"),
category: Optional[str] = Query(None, description="分类筛选"),
@@ -101,7 +98,6 @@ async def list_templates(
@router.put(
"/{template_id}",
response_model=StandardResponse[AnnotationTemplateResponse],
summary="更新模板"
)
async def update_template(
template_id: str,
@@ -122,7 +118,6 @@ async def update_template(
@router.delete(
"/{template_id}",
response_model=StandardResponse[bool],
summary="删除模板"
)
async def delete_template(
template_id: str,

View File

@@ -1,4 +1,7 @@
from .config import ConfigResponse
from .config import (
ConfigResponse,
TagConfigResponse
)
from .mapping import (
DatasetMappingCreateRequest,
@@ -15,8 +18,21 @@ from .sync import (
SyncAnnotationsResponse,
)
from .tag import (
UpdateFileTagsRequest,
UpdateFileTagsResponse,
)
from .template import (
CreateAnnotationTemplateRequest,
UpdateAnnotationTemplateRequest,
AnnotationTemplateResponse,
AnnotationTemplateListResponse
)
__all__ = [
"ConfigResponse",
"TagConfigResponse",
"DatasetMappingCreateRequest",
"DatasetMappingCreateResponse",
"DatasetMappingUpdateRequest",
@@ -26,4 +42,10 @@ __all__ = [
"SyncAnnotationsRequest",
"SyncAnnotationsResponse",
"DeleteDatasetResponse",
"UpdateFileTagsRequest",
"UpdateFileTagsResponse",
"CreateAnnotationTemplateRequest",
"UpdateAnnotationTemplateRequest",
"AnnotationTemplateResponse",
"AnnotationTemplateListResponse",
]

View File

@@ -1,8 +1,36 @@
from pydantic import Field
from typing import Dict, Any, List, Optional
from pydantic import BaseModel, Field, ConfigDict
from app.module.shared.schema import BaseResponseModel
from app.module.shared.schema import StandardResponse
class ConfigResponse(BaseResponseModel):
"""配置信息响应模型"""
label_studio_url: str = Field(..., description="Label Studio基础URL")
label_studio_url: str = Field(..., description="Label Studio基础URL")
class _TagAttributeConfig(BaseModel):
"""标签属性配置"""
type: Optional[str] = Field(None, description="属性类型: boolean/string/number")
values: Optional[List[str]] = Field(None, description="允许的枚举值列表")
default: Optional[Any] = Field(None, description="默认值")
description: Optional[str] = Field(None, description="属性描述")
model_config = ConfigDict(populate_by_name=True)
class _TagDefinition(BaseModel):
"""标签定义"""
description: str = Field(..., description="标签描述")
required_attrs: List[str] = Field(default_factory=list, alias="requiredAttrs", description="必需属性列表")
optional_attrs: Dict[str, _TagAttributeConfig] = Field(default_factory=dict, alias="optionalAttrs", description="可选属性配置")
requires_children: bool = Field(default=False, alias="requiresChildren", description="是否需要子元素")
child_tag: Optional[str] = Field(None, alias="childTag", description="子元素标签名")
child_required_attrs: Optional[List[str]] = Field(None, alias="childRequiredAttrs", description="子元素必需属性")
category: Optional[str] = Field(None, description="标签分类")
class TagConfigResponse(BaseResponseModel):
"""标签配置响应"""
objects: Dict[str, _TagDefinition] = Field(default_factory=dict, description="对象标签配置")
controls: Dict[str, _TagDefinition] = Field(default_factory=dict, description="控件标签配置")

View File

@@ -0,0 +1,17 @@
from datetime import datetime
from typing import List, Dict, Any
from pydantic import BaseModel, Field
from app.module.shared.schema import BaseResponseModel
class UpdateFileTagsRequest(BaseModel):
"""更新文件标签请求"""
tags: List[Dict[str, Any]] = Field(..., description="要更新的标签列表(部分更新)")
class UpdateFileTagsResponse(BaseResponseModel):
"""更新文件标签响应"""
file_id: str = Field(..., alias="fileId", description="文件ID")
tags: List[Dict[str, Any]] = Field(..., description="更新后的完整标签列表")
tags_updated_at: datetime = Field(..., alias="tagsUpdatedAt", description="标签更新时间")

View File

@@ -1,7 +1,7 @@
"""
Annotation Template Schemas
"""
from typing import List, Dict, Any, Optional, Literal
from typing import List, Dict, Any, Optional
from datetime import datetime
from pydantic import BaseModel, Field, ConfigDict

View File

@@ -83,7 +83,7 @@ class DatasetMappingService:
labeling_project: LabelingProject
) -> DatasetMappingResponse:
"""创建数据集映射"""
logger.info(f"Create dataset mapping: {labeling_project.dataset_id} -> {labeling_project.labeling_project_id}")
logger.debug(f"Create dataset mapping: {labeling_project.dataset_id} -> {labeling_project.labeling_project_id}")
# Use the passed object directly
self.db.add(labeling_project)
@@ -201,7 +201,7 @@ class DatasetMappingService:
)
await self.db.commit()
if result.rowcount > 0:
if result.rowcount and result.rowcount > 0: # type: ignore
return await self.get_mapping_by_uuid(mapping_id)
return None
@@ -219,7 +219,7 @@ class DatasetMappingService:
)
await self.db.commit()
success = result.rowcount > 0
success = result.rowcount and result.rowcount > 0 # type: ignore
if success:
logger.info(f"Mapping soft-deleted: {mapping_id}")
else:

View File

@@ -122,7 +122,7 @@ class SyncService:
return {}
all_tasks = result.get("tasks", [])
logger.info(f"Successfully fetched {len(all_tasks)} tasks")
logger.debug(f"Successfully fetched {len(all_tasks)} tasks")
# 使用字典推导式构建映射
dm_file_to_task_mapping = {
@@ -131,7 +131,7 @@ class SyncService:
if task.get('data', {}).get('file_id') is not None
}
logger.info(f"Found {len(dm_file_to_task_mapping)} existing task mappings")
logger.debug(f"Found {len(dm_file_to_task_mapping)} existing task mappings")
return dm_file_to_task_mapping
except Exception as e:
@@ -163,10 +163,10 @@ class SyncService:
)
if not files_response or not files_response.content:
logger.info(f"No more files on page {page + 1}")
logger.debug(f"No more files on page {page + 1}")
break
logger.info(f"Processing page {page + 1}, {len(files_response.content)} files")
logger.debug(f"Processing page {page + 1}, {len(files_response.content)} files")
# 筛选新文件并构建任务数据
new_tasks = []
@@ -178,7 +178,7 @@ class SyncService:
task_data = self._build_task_data(file_info, dataset_id)
new_tasks.append(task_data)
logger.info(f"Page {page + 1}: {len(new_tasks)} new files, {len(files_response.content) - len(new_tasks)} existing")
logger.debug(f"Page {page + 1}: {len(new_tasks)} new files, {len(files_response.content) - len(new_tasks)} existing")
# 批量创建任务
if new_tasks:
@@ -202,16 +202,16 @@ class SyncService:
deleted_file_ids = set(existing_dm_file_mapping.keys()) - current_file_ids
if not deleted_file_ids:
logger.info("No tasks to delete")
logger.debug("No tasks to delete")
return 0
tasks_to_delete = [existing_dm_file_mapping[fid] for fid in deleted_file_ids]
logger.info(f"Deleting {len(tasks_to_delete)} orphaned tasks")
logger.debug(f"Deleting {len(tasks_to_delete)} orphaned tasks")
delete_result = await self.ls_client.delete_tasks_batch(tasks_to_delete)
deleted_count = delete_result.get("successful", 0)
logger.info(f"Successfully deleted {deleted_count} tasks")
logger.debug(f"Successfully deleted {deleted_count} tasks")
return deleted_count
async def sync_dataset_files(
@@ -229,7 +229,7 @@ class SyncService:
Returns:
同步结果响应
"""
logger.info(f"Start syncing dataset files by mapping: {mapping_id}")
logger.debug(f"Start syncing dataset files by mapping: {mapping_id}")
# 获取映射关系
mapping = await self.mapping_service.get_mapping_by_uuid(mapping_id)
@@ -247,7 +247,7 @@ class SyncService:
# 委托给sync_files执行实际同步
result = await self.sync_files(mapping, batch_size)
logger.info(f"Sync completed: created={result['created']}, deleted={result['deleted']}, total={result['total']}")
logger.info(f"Sync files completed: created={result['created']}, deleted={result['deleted']}, total={result['total']}")
return SyncDatasetResponse(
id=mapping.id,
@@ -342,7 +342,7 @@ class SyncService:
Returns:
同步统计信息: {"created": int, "deleted": int, "total": int}
"""
logger.info(f"Syncing files for dataset {mapping.dataset_id} to project {mapping.labeling_project_id}")
logger.debug(f"Syncing files for dataset {mapping.dataset_id} to project {mapping.labeling_project_id}")
# 获取DM数据集信息
dataset_info = await self.dm_client.get_dataset(mapping.dataset_id)
@@ -350,12 +350,12 @@ class SyncService:
raise NoDatasetInfoFoundError(mapping.dataset_id)
total_files = dataset_info.fileCount
logger.info(f"Total files in DM dataset: {total_files}")
logger.debug(f"Total files in DM dataset: {total_files}")
# 获取Label Studio中已存在的文件映射
existing_dm_file_mapping = await self.get_existing_dm_file_mapping(mapping.labeling_project_id)
existing_file_ids = set(existing_dm_file_mapping.keys())
logger.info(f"{len(existing_file_ids)} tasks already exist in Label Studio")
logger.debug(f"{len(existing_file_ids)} tasks already exist in Label Studio")
# 分页获取DM文件并创建新任务
current_file_ids, created_count = await self._fetch_dm_files_paginated(
@@ -371,7 +371,7 @@ class SyncService:
current_file_ids
)
logger.info(f"File sync completed: total={total_files}, created={created_count}, deleted={deleted_count}")
logger.debug(f"File sync completed: total={total_files}, created={created_count}, deleted={deleted_count}")
return {
"created": created_count,

View File

@@ -17,6 +17,7 @@ from app.module.annotation.schema.template import (
TemplateConfiguration
)
from app.module.annotation.utils.config_validator import LabelStudioConfigValidator
from app.module.annotation.config import LabelStudioTagConfig
class AnnotationTemplateService:
@@ -33,6 +34,7 @@ class AnnotationTemplateService:
Returns:
Label Studio XML字符串
"""
tag_config = LabelStudioTagConfig()
xml_parts = ['<View>']
# 生成对象定义
@@ -56,15 +58,22 @@ class AnnotationTemplateService:
tag_type = label.type.capitalize() if label.type else "Choices"
# 处理带选项的标签类型
# 检查是否需要子元素
if label.options or label.labels:
choices = label.options or label.labels or []
xml_parts.append(f' <{tag_type} {" ".join(label_attrs)}>')
# 从配置获取子元素标签名
child_tag = tag_config.get_child_tag(tag_type)
if not child_tag:
# 默认使用 Label
child_tag = "Label"
for choice in choices:
xml_parts.append(f' <Label value="{choice}"/>')
xml_parts.append(f' <{child_tag} value="{choice}"/>')
xml_parts.append(f' </{tag_type}>')
else:
# 处理简单标签类型
# 处理简单标签类型(不需要子元素)
xml_parts.append(f' <{tag_type} {" ".join(label_attrs)}/>')
xml_parts.append('</View>')

View File

@@ -3,29 +3,16 @@ Label Studio Configuration Validation Utilities
"""
from typing import Dict, List, Tuple, Optional
import xml.etree.ElementTree as ET
from app.module.annotation.config import LabelStudioTagConfig
class LabelStudioConfigValidator:
"""验证Label Studio配置的工具类"""
# 支持的控件类型
CONTROL_TYPES = {
'Choices', 'RectangleLabels', 'PolygonLabels', 'Labels',
'TextArea', 'Rating', 'KeyPointLabels', 'BrushLabels',
'EllipseLabels', 'VideoRectangle', 'AudioPlus'
}
# 支持的对象类型
OBJECT_TYPES = {
'Image', 'Text', 'Audio', 'Video', 'HyperText',
'AudioPlus', 'Paragraphs', 'Table'
}
# 需要子标签的控件类型
LABEL_BASED_CONTROLS = {
'Choices', 'RectangleLabels', 'PolygonLabels', 'Labels',
'KeyPointLabels', 'BrushLabels', 'EllipseLabels'
}
@staticmethod
def _get_config() -> LabelStudioTagConfig:
"""获取标签配置实例"""
return LabelStudioTagConfig()
@staticmethod
def validate_xml(xml_string: str) -> Tuple[bool, Optional[str]]:
@@ -39,6 +26,7 @@ class LabelStudioConfigValidator:
(是否有效, 错误信息)
"""
try:
config = LabelStudioConfigValidator._get_config()
root = ET.fromstring(xml_string)
# 检查根元素
@@ -46,12 +34,14 @@ class LabelStudioConfigValidator:
return False, "Root element must be <View>"
# 检查是否有对象定义
objects = [child for child in root if child.tag in LabelStudioConfigValidator.OBJECT_TYPES]
object_types = config.get_object_types()
objects = [child for child in root if child.tag in object_types]
if not objects:
return False, "No data objects (Image, Text, etc.) found"
# 检查是否有控件定义
controls = [child for child in root if child.tag in LabelStudioConfigValidator.CONTROL_TYPES]
control_types = config.get_control_types()
controls = [child for child in root if child.tag in control_types]
if not controls:
return False, "No annotation controls found"
@@ -79,6 +69,8 @@ class LabelStudioConfigValidator:
Returns:
(是否有效, 错误信息)
"""
config = LabelStudioConfigValidator._get_config()
# 检查必需属性
if 'name' not in control.attrib:
return False, "Missing 'name' attribute"
@@ -86,16 +78,20 @@ class LabelStudioConfigValidator:
if 'toName' not in control.attrib:
return False, "Missing 'toName' attribute"
# 检查标签型控件是否有子标签
if control.tag in LabelStudioConfigValidator.LABEL_BASED_CONTROLS:
labels = control.findall('Label')
if not labels:
return False, f"{control.tag} must have at least one <Label> child"
# 检查控件是否需要子元素
if config.requires_children(control.tag):
child_tag = config.get_child_tag(control.tag)
if not child_tag:
return False, f"Configuration error: no child_tag defined for {control.tag}"
# 检查每个标签是否有value
for label in labels:
if 'value' not in label.attrib:
return False, "Label missing 'value' attribute"
children = control.findall(child_tag)
if not children:
return False, f"{control.tag} must have at least one <{child_tag}> child"
# 检查每个子元素是否有value
for child in children:
if 'value' not in child.attrib:
return False, f"{child_tag} missing 'value' attribute"
return True, None
@@ -111,16 +107,24 @@ class LabelStudioConfigValidator:
字典,键为控件名称,值为标签值列表
"""
result = {}
config = LabelStudioConfigValidator._get_config()
try:
root = ET.fromstring(xml_string)
controls = [child for child in root if child.tag in LabelStudioConfigValidator.LABEL_BASED_CONTROLS]
control_types = config.get_control_types()
controls = [child for child in root if child.tag in control_types]
for control in controls:
if not config.requires_children(control.tag):
continue
control_name = control.get('name', 'unknown')
labels = control.findall('Label')
label_values = [label.get('value', '') for label in labels]
result[control_name] = label_values
child_tag = config.get_child_tag(control.tag)
if child_tag:
children = control.findall(child_tag)
label_values = [child.get('value', '') for child in children]
result[control_name] = label_values
except Exception:
pass
@@ -182,6 +186,9 @@ class LabelStudioConfigValidator:
@staticmethod
def _validate_label_definition(label: Dict) -> Tuple[bool, Optional[str]]:
"""验证标签定义"""
config = LabelStudioConfigValidator._get_config()
control_types = config.get_control_types()
# Support both camelCase and snake_case
from_name = label.get('fromName') or label.get('from_name')
to_name = label.get('toName') or label.get('to_name')
@@ -195,11 +202,11 @@ class LabelStudioConfigValidator:
return False, "Missing required field 'type'"
# 检查类型是否支持
if label_type not in LabelStudioConfigValidator.CONTROL_TYPES:
if label_type not in control_types:
return False, f"Unsupported control type '{label_type}'"
# 检查标签型控件是否有选项或标签
if label_type in LabelStudioConfigValidator.LABEL_BASED_CONTROLS:
# 检查是否需要子元素(options 或 labels)
if config.requires_children(label_type):
if 'options' not in label and 'labels' not in label:
return False, f"{label_type} must have 'options' or 'labels' field"
@@ -208,6 +215,9 @@ class LabelStudioConfigValidator:
@staticmethod
def _validate_object_definition(obj: Dict) -> Tuple[bool, Optional[str]]:
"""验证对象定义"""
config = LabelStudioConfigValidator._get_config()
object_types = config.get_object_types()
required_fields = ['name', 'type', 'value']
for field in required_fields:
@@ -215,7 +225,7 @@ class LabelStudioConfigValidator:
return False, f"Missing required field '{field}'"
# 检查类型是否支持
if obj['type'] not in LabelStudioConfigValidator.OBJECT_TYPES:
if obj['type'] not in object_types:
return False, f"Unsupported object type '{obj['type']}'"
# 检查value格式

View File

@@ -1,6 +1,10 @@
from .dataset_file import (
DatasetFileResponse,
PagedDatasetFileResponse,
BatchUpdateFileTagsRequest,
BatchUpdateFileTagsResponse,
FileTagUpdateResult,
FileTagUpdate,
)
from .dataset import (
@@ -13,4 +17,8 @@ __all__ = [
"DatasetFileResponse",
"PagedDatasetFileResponse",
"DatasetTypeResponse",
"BatchUpdateFileTagsRequest",
"BatchUpdateFileTagsResponse",
"FileTagUpdateResult",
"FileTagUpdate",
]

View File

@@ -49,3 +49,42 @@ class DatasetFileTag(BaseModel):
tags = [f"{self.from_name} {tag}" for tag in tags]
return tags
class FileTagUpdate(BaseModel):
"""单个文件的标签更新请求"""
file_id: str = Field(..., alias="fileId", description="文件ID")
tags: List[Dict[str, Any]] = Field(..., description="要更新的标签列表(部分更新)")
class Config:
populate_by_name = True
class BatchUpdateFileTagsRequest(BaseModel):
"""批量更新文件标签请求"""
updates: List[FileTagUpdate] = Field(..., description="文件标签更新列表", min_length=1)
class Config:
populate_by_name = True
class FileTagUpdateResult(BaseModel):
"""单个文件标签更新结果"""
file_id: str = Field(..., alias="fileId", description="文件ID")
success: bool = Field(..., description="是否更新成功")
message: Optional[str] = Field(None, description="结果信息")
tags_updated_at: Optional[datetime] = Field(None, alias="tagsUpdatedAt", description="标签更新时间")
class Config:
populate_by_name = True
class BatchUpdateFileTagsResponse(BaseModel):
"""批量更新文件标签响应"""
results: List[FileTagUpdateResult] = Field(..., description="更新结果列表")
total: int = Field(..., description="总更新数量")
success_count: int = Field(..., alias="successCount", description="成功数量")
failure_count: int = Field(..., alias="failureCount", description="失败数量")
class Config:
populate_by_name = True

View File

@@ -1,7 +1,8 @@
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlalchemy import func
from typing import Optional
from typing import Optional, List, Dict, Any
from datetime import datetime
from app.core.config import settings
from app.core.logging import get_logger
@@ -22,12 +23,12 @@ class Service:
db: 数据库会话
"""
self.db = db
logger.info("Initialize DM service client (Database mode)")
logger.debug("Initialize DM service client (Database mode)")
async def get_dataset(self, dataset_id: str) -> Optional[DatasetResponse]:
"""获取数据集详情"""
try:
logger.info(f"Getting dataset detail: {dataset_id} ...")
logger.debug(f"Getting dataset detail: {dataset_id} ...")
result = await self.db.execute(
select(Dataset).where(Dataset.id == dataset_id)
@@ -66,7 +67,7 @@ class Service:
) -> Optional[PagedDatasetFileResponse]:
"""获取数据集文件列表"""
try:
logger.info(f"Get dataset files: dataset={dataset_id}, page={page}, size={size}")
logger.debug(f"Get dataset files: dataset={dataset_id}, page={page}, size={size}")
# 构建查询
query = select(DatasetFiles).where(DatasetFiles.dataset_id == dataset_id)
@@ -159,4 +160,67 @@ class Service:
async def close(self):
"""关闭客户端连接(数据库模式下无需操作)"""
logger.info("DM service client closed (Database mode)")
logger.info("DM service client closed (Database mode)")
async def update_file_tags_partial(
self,
file_id: str,
new_tags: List[Dict[str, Any]]
) -> tuple[bool, Optional[str], Optional[datetime]]:
"""
部分更新文件标签
Args:
file_id: 文件ID
new_tags: 新的标签列表(部分更新)
Returns:
(成功标志, 错误信息, 更新时间)
"""
try:
logger.info(f"Partial updating tags for file: {file_id}")
# 获取文件记录
result = await self.db.execute(
select(DatasetFiles).where(DatasetFiles.id == file_id)
)
file_record = result.scalar_one_or_none()
if not file_record:
logger.error(f"File not found: {file_id}")
return False, f"File not found: {file_id}", None
# 获取现有标签
existing_tags: List[Dict[str, Any]] = file_record.tags or [] # type: ignore
# 创建标签ID到索引的映射
tag_id_map = {tag.get('id'): idx for idx, tag in enumerate(existing_tags) if tag.get('id')}
# 更新或追加标签
for new_tag in new_tags:
tag_id = new_tag.get('id')
if tag_id and tag_id in tag_id_map:
# 更新现有标签
idx = tag_id_map[tag_id]
existing_tags[idx] = new_tag
logger.debug(f"Updated existing tag with id: {tag_id}")
else:
# 追加新标签
existing_tags.append(new_tag)
logger.debug(f"Added new tag with id: {tag_id}")
# 更新数据库
update_time = datetime.utcnow()
file_record.tags = existing_tags # type: ignore
file_record.tags_updated_at = update_time # type: ignore
await self.db.commit()
await self.db.refresh(file_record)
logger.info(f"Successfully updated tags for file: {file_id}")
return True, None, update_time
except Exception as e:
logger.error(f"Failed to update tags for file {file_id}: {e}")
await self.db.rollback()
return False, str(e), None