You've already forked DataMate
- 在DatasetApplicationService中注入ResourceAccessService并添加所有权验证 - 在KnowledgeSetApplicationService中注入ResourceAccessService并添加所有权验证 - 修改DatasetRepository接口和实现类,增加按创建者过滤的方法 - 修改KnowledgeSetRepository接口和实现类,增加按创建者过滤的方法 - 在RAG索引器服务中添加知识库访问权限检查和作用域过滤 - 更新实体元对象处理器以使用请求用户上下文获取当前用户 - 在前端设置页面添加用户权限管理功能和角色权限控制 - 为Python标注服务增加用户上下文和数据集访问权限验证
178 lines
6.3 KiB
Python
178 lines
6.3 KiB
Python
"""Service layer for Auto Annotation tasks"""
|
|
from __future__ import annotations

import logging
from datetime import datetime
from typing import List, Optional
from uuid import uuid4

from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession

from app.db.models.annotation_management import AutoAnnotationTask
from app.db.models.dataset_management import Dataset, DatasetFiles
from app.module.annotation.security import RequestUserContext

from ..schema.auto import (
    CreateAutoAnnotationTaskRequest,
    AutoAnnotationTaskResponse,
)
|
|
|
|
|
|
class AutoAnnotationTaskService:
    """Auto-annotation task service.

    Manages task metadata only (create / list / get / soft delete); the
    actual execution of a task is the responsibility of the runtime.
    """

    # Used to report failures that are deliberately degraded instead of raised.
    _logger = logging.getLogger(__name__)

    async def create_task(
        self,
        db: AsyncSession,
        request: CreateAutoAnnotationTaskRequest,
        dataset_name: Optional[str] = None,
        total_images: int = 0,
    ) -> AutoAnnotationTaskResponse:
        """Create an auto-annotation task whose initial status is ``pending``.

        Only the task record is inserted here; no YOLO inference is run.
        A scheduler/worker can later read the table and update the progress.

        Args:
            db: Async database session used to persist the task.
            request: Creation payload (name, dataset id, config, file ids).
            dataset_name: Optional dataset name stored redundantly on the task.
            total_images: Value stored in the task's ``total_images`` field.

        Returns:
            The response model built from the persisted task, with
            ``source_datasets`` filled in.
        """
        now = datetime.now()

        task = AutoAnnotationTask(
            id=str(uuid4()),
            name=request.name,
            dataset_id=request.dataset_id,
            dataset_name=dataset_name,
            config=request.config.model_dump(by_alias=True),
            file_ids=request.file_ids,  # list of file IDs selected by the user
            status="pending",
            progress=0,
            total_images=total_images,
            processed_images=0,
            detected_objects=0,
            created_at=now,
            updated_at=now,
        )

        db.add(task)
        await db.commit()
        await db.refresh(task)

        # Attach the source dataset names right after creation; on failure the
        # values supplied by the caller are used as the fallback.
        return await self._build_response(
            db,
            task,
            fallback_name=dataset_name,
            fallback_id=request.dataset_id,
        )

    def _apply_dataset_scope(self, query, user_context: RequestUserContext):
        """Restrict a task query to datasets created by the requesting user.

        Admin users get the query back unchanged. For everyone else the query
        is joined with ``Dataset`` and filtered on
        ``Dataset.created_by == user_context.user_id``.
        """
        if user_context.is_admin:
            return query
        return query.join(
            Dataset,
            AutoAnnotationTask.dataset_id == Dataset.id,
        ).where(Dataset.created_by == user_context.user_id)

    async def list_tasks(
        self,
        db: AsyncSession,
        user_context: RequestUserContext,
    ) -> List[AutoAnnotationTaskResponse]:
        """Return the tasks that are not soft-deleted, newest first.

        The result is limited by ``_apply_dataset_scope`` for non-admin users.
        """
        query = select(AutoAnnotationTask).where(AutoAnnotationTask.deleted_at.is_(None))
        query = self._apply_dataset_scope(query, user_context)
        result = await db.execute(
            query.order_by(AutoAnnotationTask.created_at.desc())
        )
        tasks: List[AutoAnnotationTask] = list(result.scalars().all())

        responses: List[AutoAnnotationTaskResponse] = []
        for task in tasks:
            responses.append(await self._build_response(db, task))
        return responses

    async def get_task(
        self,
        db: AsyncSession,
        task_id: str,
        user_context: RequestUserContext,
    ) -> Optional[AutoAnnotationTaskResponse]:
        """Return one non-deleted task by id, or ``None``.

        ``None`` is returned both when the task does not exist and when the
        dataset scope applied for ``user_context`` filters it out.
        """
        query = select(AutoAnnotationTask).where(
            AutoAnnotationTask.id == task_id,
            AutoAnnotationTask.deleted_at.is_(None),
        )
        query = self._apply_dataset_scope(query, user_context)
        result = await db.execute(query)
        task = result.scalar_one_or_none()
        if not task:
            return None

        return await self._build_response(db, task)

    async def _build_response(
        self,
        db: AsyncSession,
        task: AutoAnnotationTask,
        fallback_name: Optional[str] = None,
        fallback_id: Optional[str] = None,
    ) -> AutoAnnotationTaskResponse:
        """Convert a task row into its response model with ``source_datasets``.

        Resolving the dataset names is best-effort: if it raises, the failure
        is logged and the response degrades to a single dataset name/id.

        Args:
            db: Async database session used for the name lookup.
            task: The task to convert.
            fallback_name: Preferred fallback name; defaults to the task's
                ``dataset_name`` attribute.
            fallback_id: Preferred fallback id; defaults to the task's
                ``dataset_id`` attribute.
        """
        resp = AutoAnnotationTaskResponse.model_validate(task)
        try:
            resp.source_datasets = await self._compute_source_datasets(db, task)
        except Exception:
            # Deliberate best-effort: keep serving the task, but leave a trace
            # instead of swallowing the error silently.
            self._logger.warning(
                "Failed to resolve source datasets for auto annotation task %s; "
                "falling back to the task's own dataset fields",
                getattr(task, "id", None),
                exc_info=True,
            )
            resp.source_datasets = self._fallback_source_datasets(
                fallback_name or getattr(task, "dataset_name", None),
                fallback_id or getattr(task, "dataset_id", None),
            )
        return resp

    @staticmethod
    def _fallback_source_datasets(
        dataset_name: Optional[str],
        dataset_id: Optional[str],
    ) -> List[str]:
        """Return ``[dataset_name]``, else ``[dataset_id]``, else ``[]``.

        Empty or missing values are skipped so the result never contains
        ``None`` or an empty string.
        """
        if dataset_name:
            return [dataset_name]
        if dataset_id:
            return [dataset_id]
        return []

    async def _compute_source_datasets(
        self,
        db: AsyncSession,
        task: AutoAnnotationTask,
    ) -> List[str]:
        """Infer the names of all datasets a task actually touches.

        - If the task has ``file_ids``, the distinct ``Dataset.name`` values
          are looked up by joining ``DatasetFiles`` on those ids.
        - If there are no ``file_ids`` (or the lookup yields no names), the
          redundant ``dataset_name`` / ``dataset_id`` on the task is used.
        """
        file_ids = task.file_ids or []
        if file_ids:
            stmt = (
                select(Dataset.name)
                .join(DatasetFiles, Dataset.id == DatasetFiles.dataset_id)
                .where(DatasetFiles.id.in_(file_ids))
                .distinct()
            )
            result = await db.execute(stmt)
            names = [row[0] for row in result.fetchall() if row[0]]
            if names:
                return names

        # Fallback: show a single dataset only.
        return self._fallback_source_datasets(task.dataset_name, task.dataset_id)

    async def soft_delete_task(
        self,
        db: AsyncSession,
        task_id: str,
        user_context: RequestUserContext,
    ) -> bool:
        """Soft-delete a task by setting ``deleted_at`` to the current time.

        Returns:
            ``True`` if a matching, non-deleted task within the caller's
            dataset scope was found and marked deleted; ``False`` otherwise.
        """
        query = select(AutoAnnotationTask).where(
            AutoAnnotationTask.id == task_id,
            AutoAnnotationTask.deleted_at.is_(None),
        )
        query = self._apply_dataset_scope(query, user_context)
        result = await db.execute(query)
        task = result.scalar_one_or_none()
        if not task:
            return False

        task.deleted_at = datetime.now()
        await db.commit()
        return True
|