Merge branch 'rbac' into lsf

This commit is contained in:
2026-02-06 15:44:43 +08:00
61 changed files with 2525 additions and 247 deletions

View File

@@ -5,11 +5,12 @@ from typing import List, Optional
from datetime import datetime
from uuid import uuid4
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.db.models.annotation_management import AutoAnnotationTask
from app.db.models.dataset_management import Dataset, DatasetFiles
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.db.models.annotation_management import AutoAnnotationTask
from app.db.models.dataset_management import Dataset, DatasetFiles
from app.module.annotation.security import RequestUserContext
from ..schema.auto import (
CreateAutoAnnotationTaskRequest,
@@ -17,7 +18,7 @@ from ..schema.auto import (
)
class AutoAnnotationTaskService:
class AutoAnnotationTaskService:
"""自动标注任务服务(仅管理任务元数据,真正执行由 runtime 负责)"""
async def create_task(
@@ -63,15 +64,27 @@ class AutoAnnotationTaskService:
resp.source_datasets = [dataset_name] if dataset_name else [request.dataset_id]
return resp
async def list_tasks(self, db: AsyncSession) -> List[AutoAnnotationTaskResponse]:
"""获取未软删除的自动标注任务列表,按创建时间倒序。"""
result = await db.execute(
select(AutoAnnotationTask)
.where(AutoAnnotationTask.deleted_at.is_(None))
.order_by(AutoAnnotationTask.created_at.desc())
)
tasks: List[AutoAnnotationTask] = list(result.scalars().all())
def _apply_dataset_scope(self, query, user_context: RequestUserContext):
if user_context.is_admin:
return query
return query.join(
Dataset,
AutoAnnotationTask.dataset_id == Dataset.id,
).where(Dataset.created_by == user_context.user_id)
async def list_tasks(
self,
db: AsyncSession,
user_context: RequestUserContext,
) -> List[AutoAnnotationTaskResponse]:
"""获取未软删除的自动标注任务列表,按创建时间倒序。"""
query = select(AutoAnnotationTask).where(AutoAnnotationTask.deleted_at.is_(None))
query = self._apply_dataset_scope(query, user_context)
result = await db.execute(
query.order_by(AutoAnnotationTask.created_at.desc())
)
tasks: List[AutoAnnotationTask] = list(result.scalars().all())
responses: List[AutoAnnotationTaskResponse] = []
for task in tasks:
@@ -87,16 +100,21 @@ class AutoAnnotationTaskService:
return responses
async def get_task(self, db: AsyncSession, task_id: str) -> Optional[AutoAnnotationTaskResponse]:
result = await db.execute(
select(AutoAnnotationTask).where(
AutoAnnotationTask.id == task_id,
AutoAnnotationTask.deleted_at.is_(None),
)
)
task = result.scalar_one_or_none()
if not task:
return None
async def get_task(
self,
db: AsyncSession,
task_id: str,
user_context: RequestUserContext,
) -> Optional[AutoAnnotationTaskResponse]:
query = select(AutoAnnotationTask).where(
AutoAnnotationTask.id == task_id,
AutoAnnotationTask.deleted_at.is_(None),
)
query = self._apply_dataset_scope(query, user_context)
result = await db.execute(query)
task = result.scalar_one_or_none()
if not task:
return None
resp = AutoAnnotationTaskResponse.model_validate(task)
try:
@@ -138,16 +156,21 @@ class AutoAnnotationTaskService:
return [task.dataset_id]
return []
async def soft_delete_task(self, db: AsyncSession, task_id: str) -> bool:
result = await db.execute(
select(AutoAnnotationTask).where(
AutoAnnotationTask.id == task_id,
AutoAnnotationTask.deleted_at.is_(None),
)
)
task = result.scalar_one_or_none()
if not task:
return False
async def soft_delete_task(
self,
db: AsyncSession,
task_id: str,
user_context: RequestUserContext,
) -> bool:
query = select(AutoAnnotationTask).where(
AutoAnnotationTask.id == task_id,
AutoAnnotationTask.deleted_at.is_(None),
)
query = self._apply_dataset_scope(query, user_context)
result = await db.execute(query)
task = result.scalar_one_or_none()
if not task:
return False
task.deleted_at = datetime.now()
await db.commit()

View File

@@ -54,6 +54,10 @@ from app.module.annotation.service.knowledge_sync import KnowledgeSyncService
from app.module.annotation.service.annotation_text_splitter import (
AnnotationTextSplitter,
)
from app.module.annotation.security import (
RequestUserContext,
ensure_dataset_owner_access,
)
from app.module.annotation.service.text_fetcher import (
fetch_text_content_via_download_api,
)
@@ -104,8 +108,9 @@ class AnnotationEditorService:
# 分段阈值:超过此字符数自动分段
SEGMENT_THRESHOLD = 200
def __init__(self, db: AsyncSession):
def __init__(self, db: AsyncSession, user_context: RequestUserContext):
self.db = db
self.user_context = user_context
self.template_service = AnnotationTemplateService()
@staticmethod
@@ -157,14 +162,24 @@ class AnnotationEditorService:
async def _get_project_or_404(self, project_id: str) -> LabelingProject:
result = await self.db.execute(
select(LabelingProject).where(
select(LabelingProject, Dataset.created_by).join(
Dataset,
LabelingProject.dataset_id == Dataset.id,
).where(
LabelingProject.id == project_id,
LabelingProject.deleted_at.is_(None),
)
)
project = result.scalar_one_or_none()
if not project:
row = result.first()
if not row:
raise HTTPException(status_code=404, detail=f"标注项目不存在: {project_id}")
project = row[0]
dataset_owner = row[1]
ensure_dataset_owner_access(
self.user_context,
str(dataset_owner) if dataset_owner is not None else None,
project.dataset_id,
)
return project
async def _get_dataset_type(self, dataset_id: str) -> Optional[str]:

View File

@@ -478,7 +478,9 @@ class DatasetMappingService:
skip: int = 0,
limit: int = 100,
include_deleted: bool = False,
include_template: bool = False
include_template: bool = False,
current_user_id: Optional[str] = None,
is_admin: bool = False,
) -> Tuple[List[DatasetMappingResponse], int]:
"""
获取所有映射及总数(用于分页)
@@ -495,9 +497,16 @@ class DatasetMappingService:
query = self._build_query_with_dataset_name()
if not include_deleted:
query = query.where(LabelingProject.deleted_at.is_(None))
if not is_admin:
query = query.where(Dataset.created_by == current_user_id)
# 获取总数
count_query = select(func.count()).select_from(LabelingProject)
if not is_admin:
count_query = count_query.join(
Dataset,
LabelingProject.dataset_id == Dataset.id,
).where(Dataset.created_by == current_user_id)
if not include_deleted:
count_query = count_query.where(LabelingProject.deleted_at.is_(None))
@@ -557,7 +566,9 @@ class DatasetMappingService:
skip: int = 0,
limit: int = 100,
include_deleted: bool = False,
include_template: bool = False
include_template: bool = False,
current_user_id: Optional[str] = None,
is_admin: bool = False,
) -> Tuple[List[DatasetMappingResponse], int]:
"""
根据源数据集ID获取映射关系及总数(用于分页)
@@ -578,11 +589,18 @@ class DatasetMappingService:
if not include_deleted:
query = query.where(LabelingProject.deleted_at.is_(None))
if not is_admin:
query = query.where(Dataset.created_by == current_user_id)
# 获取总数
count_query = select(func.count()).select_from(LabelingProject).where(
LabelingProject.dataset_id == dataset_id
)
if not is_admin:
count_query = count_query.join(
Dataset,
LabelingProject.dataset_id == Dataset.id,
).where(Dataset.created_by == current_user_id)
if not include_deleted:
count_query = count_query.where(LabelingProject.deleted_at.is_(None))