"""Service layer for Auto Annotation tasks"""
from __future__ import annotations

import logging
from typing import List, Optional, Dict, Any
from datetime import datetime
from uuid import uuid4

from fastapi import HTTPException
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession

from app.db.models.annotation_management import (
    AutoAnnotationTask,
    AnnotationTaskOperatorInstance,
)
from app.db.models.dataset_management import Dataset, DatasetFiles
from app.module.annotation.security import RequestUserContext

from ..schema.auto import (
    CreateAutoAnnotationTaskRequest,
    AutoAnnotationTaskResponse,
)

logger = logging.getLogger(__name__)


class AutoAnnotationTaskService:
    """Auto-annotation task service.

    Only manages task metadata; the actual execution is performed by the runtime.
    """

    @staticmethod
    def _normalize_file_ids(file_ids: Optional[List[str]]) -> List[str]:
        """Deduplicate file ids (first occurrence wins) and drop falsy entries.

        Returns an empty list when ``file_ids`` is ``None`` or empty.
        """
        if not file_ids:
            return []
        return [fid for fid in dict.fromkeys(file_ids) if fid]

    @staticmethod
    def _extract_operator_id(step: Dict[str, Any]) -> Optional[str]:
        """Return the operator id of a pipeline step.

        Looks at ``operatorId``, ``operator_id`` and ``id`` (in that order),
        stringifies and strips the value; returns ``None`` if it is missing or blank.
        """
        operator_id = step.get("operatorId") or step.get("operator_id") or step.get("id")
        if operator_id is None:
            return None
        operator_id = str(operator_id).strip()
        return operator_id or None

    @classmethod
    def _to_operator_instances(
        cls,
        task_id: str,
        pipeline: List[Dict[str, Any]],
    ) -> List[AnnotationTaskOperatorInstance]:
        """Convert pipeline steps into operator-instance rows for ``task_id``.

        Non-dict steps and steps without an operator id are skipped; ``op_index``
        is 1-based and counts only the steps that were kept. A non-dict settings
        override is replaced with ``{}``.
        """
        instances: List[AnnotationTaskOperatorInstance] = []
        for step in pipeline:
            if not isinstance(step, dict):
                continue
            operator_id = cls._extract_operator_id(step)
            if not operator_id:
                continue
            settings_override = (
                step.get("overrides")
                or step.get("settingsOverride")
                or step.get("settings_override")
                or {}
            )
            if not isinstance(settings_override, dict):
                settings_override = {}
            instances.append(
                AnnotationTaskOperatorInstance(
                    task_id=task_id,
                    op_index=len(instances) + 1,
                    operator_id=operator_id,
                    settings_override=settings_override,
                    inputs=step.get("inputs"),
                    outputs=step.get("outputs"),
                )
            )
        return instances

    @staticmethod
    def _to_pipeline(request: CreateAutoAnnotationTaskRequest) -> Optional[List[Dict[str, Any]]]:
        """Normalize the request into a pipeline structure.

        An explicit ``pipeline`` wins; otherwise a legacy ``config`` is mapped to a
        single ``ImageObjectDetectionBoundingBox`` step. Returns ``None`` when the
        request has neither.
        """
        if request.pipeline:
            return [step.model_dump(by_alias=True) for step in request.pipeline]

        if request.config is None:
            return None

        # Backward compatibility: legacy YOLO request -> single-step pipeline.
        config = request.config.model_dump(by_alias=True)
        step_overrides: Dict[str, Any] = {
            "modelSize": config.get("modelSize"),
            "confThreshold": config.get("confThreshold"),
            "targetClasses": config.get("targetClasses") or [],
        }

        output_dataset_name = request.output_dataset_name or config.get("outputDatasetName")
        if output_dataset_name:
            step_overrides["outputDatasetName"] = output_dataset_name

        return [
            {
                "operatorId": "ImageObjectDetectionBoundingBox",
                "overrides": step_overrides,
            }
        ]

    async def validate_file_ids(
        self,
        db: AsyncSession,
        dataset_id: str,
        file_ids: Optional[List[str]],
    ) -> List[str]:
        """Check that every file id belongs to ``dataset_id`` and is ACTIVE.

        Returns the deduplicated ids (``[]`` when none were supplied).

        Raises:
            HTTPException(400): if ``file_ids`` was non-empty but contained only
                falsy entries, or if any id is missing, inactive or belongs to
                another dataset (at most 10 offending ids are reported).
        """
        normalized_ids = self._normalize_file_ids(file_ids)
        if not normalized_ids:
            if file_ids:
                raise HTTPException(status_code=400, detail="fileIds 不能为空列表")
            return []

        stmt = select(DatasetFiles.id).where(
            DatasetFiles.id.in_(normalized_ids),
            DatasetFiles.dataset_id == dataset_id,
            DatasetFiles.status == "ACTIVE",
        )
        result = await db.execute(stmt)
        found_ids = {row[0] for row in result.fetchall()}
        missing = [fid for fid in normalized_ids if fid not in found_ids]

        if missing:
            raise HTTPException(
                status_code=400,
                detail=f"部分 fileIds 不存在、不可用或不属于数据集: {missing[:10]}",
            )

        return normalized_ids

    async def create_task(
        self,
        db: AsyncSession,
        request: CreateAutoAnnotationTaskRequest,
        user_context: RequestUserContext,
        dataset_name: Optional[str] = None,
        total_images: int = 0,
        dataset_type: str = "IMAGE",
    ) -> AutoAnnotationTaskResponse:
        """Create an auto-annotation task in the ``pending`` state.

        Only the task record (plus its operator instances) is inserted; no
        inference is run here. A scheduler/worker is expected to pick the row up
        and update its progress.

        When ``fileIds`` are given, ``total_images`` is overridden with their count.

        Raises:
            HTTPException(400): invalid ``fileIds`` or an empty pipeline.
        """
        now = datetime.now()

        validated_file_ids = await self.validate_file_ids(
            db,
            request.dataset_id,
            request.file_ids,
        )
        if validated_file_ids:
            total_images = len(validated_file_ids)

        normalized_pipeline = self._to_pipeline(request)
        if not normalized_pipeline:
            raise HTTPException(status_code=400, detail="pipeline 不能为空")

        normalized_config = request.config.model_dump(by_alias=True) if request.config else {}

        task_id = str(uuid4())
        task = AutoAnnotationTask(
            id=task_id,
            name=request.name,
            dataset_id=request.dataset_id,
            dataset_name=dataset_name,
            dataset_type=dataset_type,
            created_by=user_context.user_id,
            config=normalized_config,
            task_mode=request.task_mode,
            executor_type=request.executor_type,
            pipeline=normalized_pipeline,
            file_ids=validated_file_ids or None,
            status="pending",
            progress=0,
            total_images=total_images,
            processed_images=0,
            detected_objects=0,
            stop_requested=False,
            created_at=now,
            updated_at=now,
        )

        operator_instances = self._to_operator_instances(task_id, normalized_pipeline)

        db.add(task)
        if operator_instances:
            db.add_all(operator_instances)
        await db.commit()
        await db.refresh(task)

        # Attach sourceDatasets info (usually just the single source dataset).
        return await self._build_task_response(db, task)

    def _apply_dataset_scope(self, query, user_context: RequestUserContext):
        """Restrict ``query`` to tasks whose dataset was created by the caller.

        Admins see everything, so the query is returned unchanged for them.
        """
        if user_context.is_admin:
            return query
        return query.join(
            Dataset,
            AutoAnnotationTask.dataset_id == Dataset.id,
        ).where(Dataset.created_by == user_context.user_id)

    async def _build_task_response(
        self,
        db: AsyncSession,
        task: AutoAnnotationTask,
    ) -> AutoAnnotationTaskResponse:
        """Build the API response for ``task`` with ``source_datasets`` filled in.

        Computing the source datasets is best-effort: on failure the error is
        logged and the list degrades to the task's own ``dataset_name`` (or its
        ``dataset_id`` when no name is stored).
        """
        resp = AutoAnnotationTaskResponse.model_validate(task)
        try:
            resp.source_datasets = await self._compute_source_datasets(db, task)
        except Exception:
            # Deliberate best-effort: a lookup failure must not fail the request,
            # but it should not disappear silently either.
            logger.warning(
                "Failed to compute source datasets for auto annotation task %s",
                task.id,
                exc_info=True,
            )
            resp.source_datasets = (
                [task.dataset_name] if task.dataset_name else [task.dataset_id]
            )
        return resp

    async def list_tasks(
        self,
        db: AsyncSession,
        user_context: RequestUserContext,
    ) -> List[AutoAnnotationTaskResponse]:
        """List non-soft-deleted tasks visible to the caller, newest first."""
        query = select(AutoAnnotationTask).where(AutoAnnotationTask.deleted_at.is_(None))
        query = self._apply_dataset_scope(query, user_context)
        result = await db.execute(
            query.order_by(AutoAnnotationTask.created_at.desc())
        )
        tasks: List[AutoAnnotationTask] = list(result.scalars().all())

        responses: List[AutoAnnotationTaskResponse] = []
        for task in tasks:
            responses.append(await self._build_task_response(db, task))
        return responses

    async def get_task(
        self,
        db: AsyncSession,
        task_id: str,
        user_context: RequestUserContext,
    ) -> Optional[AutoAnnotationTaskResponse]:
        """Return one non-deleted task visible to the caller, or ``None``."""
        query = select(AutoAnnotationTask).where(
            AutoAnnotationTask.id == task_id,
            AutoAnnotationTask.deleted_at.is_(None),
        )
        query = self._apply_dataset_scope(query, user_context)
        result = await db.execute(query)
        task = result.scalar_one_or_none()
        if not task:
            return None

        return await self._build_task_response(db, task)

    async def _compute_source_datasets(
        self,
        db: AsyncSession,
        task: AutoAnnotationTask,
    ) -> List[str]:
        """Infer the names of all datasets the task's files belong to.

        - If the task has ``file_ids``, resolve their dataset ids through the
          dataset-files table and join the datasets table to get the names.
        - Otherwise (or if that yields no names) fall back to the task's stored
          ``dataset_name`` / ``dataset_id``.
        """
        file_ids = task.file_ids or []
        if file_ids:
            stmt = (
                select(Dataset.name)
                .join(DatasetFiles, Dataset.id == DatasetFiles.dataset_id)
                .where(DatasetFiles.id.in_(file_ids))
                .distinct()
            )
            result = await db.execute(stmt)
            names = [row[0] for row in result.fetchall() if row[0]]
            if names:
                return names

        # Fallback: report only the task's single dataset.
        if task.dataset_name:
            return [task.dataset_name]
        if task.dataset_id:
            return [task.dataset_id]
        return []

    async def request_stop_task(
        self,
        db: AsyncSession,
        task_id: str,
        user_context: RequestUserContext,
    ) -> Optional[AutoAnnotationTaskResponse]:
        """Request that a task stop; returns the updated task or ``None``.

        Tasks already in a terminal state are returned unchanged. Non-terminal
        tasks are flagged with ``stop_requested``; a still-``pending`` task is
        moved straight to ``stopped``.
        """
        query = select(AutoAnnotationTask).where(
            AutoAnnotationTask.id == task_id,
            AutoAnnotationTask.deleted_at.is_(None),
        )
        query = self._apply_dataset_scope(query, user_context)
        result = await db.execute(query)
        task = result.scalar_one_or_none()
        if not task:
            return None

        now = datetime.now()
        terminal_states = {"completed", "failed", "stopped"}
        if task.status not in terminal_states:
            task.stop_requested = True
            task.error_message = "Task stop requested"
            if task.status == "pending":
                # Never started, so it can be finalized immediately.
                task.status = "stopped"
                task.progress = task.progress or 0
                task.completed_at = now
                task.run_token = None
            task.updated_at = now
            await db.commit()
            await db.refresh(task)

        return await self._build_task_response(db, task)

    async def soft_delete_task(
        self,
        db: AsyncSession,
        task_id: str,
        user_context: RequestUserContext,
    ) -> bool:
        """Soft-delete a task by stamping ``deleted_at``.

        Returns ``False`` if the task does not exist, is already deleted, or is
        not visible to the caller.
        """
        query = select(AutoAnnotationTask).where(
            AutoAnnotationTask.id == task_id,
            AutoAnnotationTask.deleted_at.is_(None),
        )
        query = self._apply_dataset_scope(query, user_context)
        result = await db.execute(query)
        task = result.scalar_one_or_none()
        if not task:
            return False

        task.deleted_at = datetime.now()
        await db.commit()
        return True