You've already forked DataMate
feat(auth): 为数据管理和RAG服务增加资源访问控制
- 在DatasetApplicationService中注入ResourceAccessService并添加所有权验证 - 在KnowledgeSetApplicationService中注入ResourceAccessService并添加所有权验证 - 修改DatasetRepository接口和实现类,增加按创建者过滤的方法 - 修改KnowledgeSetRepository接口和实现类,增加按创建者过滤的方法 - 在RAG索引器服务中添加知识库访问权限检查和作用域过滤 - 更新实体元对象处理器以使用请求用户上下文获取当前用户 - 在前端设置页面添加用户权限管理功能和角色权限控制 - 为Python标注服务增加用户上下文和数据集访问权限验证
This commit is contained in:
@@ -17,12 +17,17 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db.session import get_db
|
||||
from app.module.shared.schema import StandardResponse
|
||||
from app.module.dataset import DatasetManagementService
|
||||
from app.core.logging import get_logger
|
||||
|
||||
from ..schema.auto import (
|
||||
CreateAutoAnnotationTaskRequest,
|
||||
AutoAnnotationTaskResponse,
|
||||
from app.module.dataset import DatasetManagementService
|
||||
from app.core.logging import get_logger
|
||||
|
||||
from ..security import (
|
||||
RequestUserContext,
|
||||
assert_dataset_access,
|
||||
get_request_user_context,
|
||||
)
|
||||
from ..schema.auto import (
|
||||
CreateAutoAnnotationTaskRequest,
|
||||
AutoAnnotationTaskResponse,
|
||||
)
|
||||
from ..service.auto import AutoAnnotationTaskService
|
||||
|
||||
@@ -37,15 +42,16 @@ service = AutoAnnotationTaskService()
|
||||
|
||||
|
||||
@router.get("", response_model=StandardResponse[List[AutoAnnotationTaskResponse]])
|
||||
async def list_auto_annotation_tasks(
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
async def list_auto_annotation_tasks(
|
||||
db: AsyncSession = Depends(get_db),
|
||||
user_context: RequestUserContext = Depends(get_request_user_context),
|
||||
):
|
||||
"""获取自动标注任务列表。
|
||||
|
||||
前端当前不传分页参数,这里直接返回所有未删除任务。
|
||||
"""
|
||||
|
||||
tasks = await service.list_tasks(db)
|
||||
tasks = await service.list_tasks(db, user_context)
|
||||
return StandardResponse(
|
||||
code=200,
|
||||
message="success",
|
||||
@@ -54,28 +60,30 @@ async def list_auto_annotation_tasks(
|
||||
|
||||
|
||||
@router.post("", response_model=StandardResponse[AutoAnnotationTaskResponse])
|
||||
async def create_auto_annotation_task(
|
||||
request: CreateAutoAnnotationTaskRequest,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
async def create_auto_annotation_task(
|
||||
request: CreateAutoAnnotationTaskRequest,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
user_context: RequestUserContext = Depends(get_request_user_context),
|
||||
):
|
||||
"""创建自动标注任务。
|
||||
|
||||
当前仅创建任务记录并置为 pending,实际执行由后续调度/worker 完成。
|
||||
"""
|
||||
|
||||
logger.info(
|
||||
"Creating auto annotation task: name=%s, dataset_id=%s, config=%s, file_ids=%s",
|
||||
request.name,
|
||||
logger.info(
|
||||
"Creating auto annotation task: name=%s, dataset_id=%s, config=%s, file_ids=%s",
|
||||
request.name,
|
||||
request.dataset_id,
|
||||
request.config.model_dump(by_alias=True),
|
||||
request.file_ids,
|
||||
)
|
||||
|
||||
# 尝试获取数据集名称和文件数量用于冗余字段,失败时不阻塞任务创建
|
||||
dataset_name = None
|
||||
total_images = 0
|
||||
try:
|
||||
dm_client = DatasetManagementService(db)
|
||||
# 尝试获取数据集名称和文件数量用于冗余字段,失败时不阻塞任务创建
|
||||
dataset_name = None
|
||||
total_images = 0
|
||||
await assert_dataset_access(db, request.dataset_id, user_context)
|
||||
try:
|
||||
dm_client = DatasetManagementService(db)
|
||||
# Service.get_dataset 返回 DatasetResponse,包含 name 和 fileCount
|
||||
dataset = await dm_client.get_dataset(request.dataset_id)
|
||||
if dataset is not None:
|
||||
@@ -103,16 +111,17 @@ async def create_auto_annotation_task(
|
||||
|
||||
|
||||
@router.get("/{task_id}/status", response_model=StandardResponse[AutoAnnotationTaskResponse])
|
||||
async def get_auto_annotation_task_status(
|
||||
task_id: str = Path(..., description="任务ID"),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
async def get_auto_annotation_task_status(
|
||||
task_id: str = Path(..., description="任务ID"),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
user_context: RequestUserContext = Depends(get_request_user_context),
|
||||
):
|
||||
"""获取单个自动标注任务状态。
|
||||
|
||||
前端当前主要通过列表轮询,这里提供按 ID 查询的补充接口。
|
||||
"""
|
||||
|
||||
task = await service.get_task(db, task_id)
|
||||
task = await service.get_task(db, task_id, user_context)
|
||||
if not task:
|
||||
raise HTTPException(status_code=404, detail="Task not found")
|
||||
|
||||
@@ -124,13 +133,14 @@ async def get_auto_annotation_task_status(
|
||||
|
||||
|
||||
@router.delete("/{task_id}", response_model=StandardResponse[bool])
|
||||
async def delete_auto_annotation_task(
|
||||
task_id: str = Path(..., description="任务ID"),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
async def delete_auto_annotation_task(
|
||||
task_id: str = Path(..., description="任务ID"),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
user_context: RequestUserContext = Depends(get_request_user_context),
|
||||
):
|
||||
"""删除(软删除)自动标注任务,仅标记 deleted_at。"""
|
||||
|
||||
ok = await service.soft_delete_task(db, task_id)
|
||||
ok = await service.soft_delete_task(db, task_id, user_context)
|
||||
if not ok:
|
||||
raise HTTPException(status_code=404, detail="Task not found")
|
||||
|
||||
@@ -142,10 +152,11 @@ async def delete_auto_annotation_task(
|
||||
|
||||
|
||||
@router.get("/{task_id}/download")
|
||||
async def download_auto_annotation_result(
|
||||
task_id: str = Path(..., description="任务ID"),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
async def download_auto_annotation_result(
|
||||
task_id: str = Path(..., description="任务ID"),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
user_context: RequestUserContext = Depends(get_request_user_context),
|
||||
):
|
||||
"""下载指定自动标注任务的结果 ZIP。"""
|
||||
|
||||
import io
|
||||
@@ -154,7 +165,7 @@ async def download_auto_annotation_result(
|
||||
import tempfile
|
||||
|
||||
# 复用服务层获取任务信息
|
||||
task = await service.get_task(db, task_id)
|
||||
task = await service.get_task(db, task_id, user_context)
|
||||
if not task:
|
||||
raise HTTPException(status_code=404, detail="Task not found")
|
||||
|
||||
|
||||
@@ -27,6 +27,10 @@ from app.module.annotation.schema.editor import (
|
||||
UpsertAnnotationResponse,
|
||||
)
|
||||
from app.module.annotation.service.editor import AnnotationEditorService
|
||||
from app.module.annotation.security import (
|
||||
RequestUserContext,
|
||||
get_request_user_context,
|
||||
)
|
||||
from app.module.shared.schema import StandardResponse
|
||||
|
||||
logger = get_logger(__name__)
|
||||
@@ -44,8 +48,9 @@ router = APIRouter(
|
||||
async def get_editor_project_info(
|
||||
project_id: str = Path(..., description="标注项目ID(t_dm_labeling_projects.id)"),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
user_context: RequestUserContext = Depends(get_request_user_context),
|
||||
):
|
||||
service = AnnotationEditorService(db)
|
||||
service = AnnotationEditorService(db, user_context)
|
||||
info = await service.get_project_info(project_id)
|
||||
return StandardResponse(code=200, message="success", data=info)
|
||||
|
||||
@@ -64,8 +69,9 @@ async def list_editor_tasks(
|
||||
description="是否排除已被转换为TXT的源文档文件(PDF/DOC/DOCX,仅文本数据集生效)",
|
||||
),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
user_context: RequestUserContext = Depends(get_request_user_context),
|
||||
):
|
||||
service = AnnotationEditorService(db)
|
||||
service = AnnotationEditorService(db, user_context)
|
||||
result = await service.list_tasks(
|
||||
project_id,
|
||||
page=page,
|
||||
@@ -86,8 +92,9 @@ async def get_editor_task(
|
||||
None, alias="segmentIndex", description="段落索引(分段模式下使用)"
|
||||
),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
user_context: RequestUserContext = Depends(get_request_user_context),
|
||||
):
|
||||
service = AnnotationEditorService(db)
|
||||
service = AnnotationEditorService(db, user_context)
|
||||
task = await service.get_task(project_id, file_id, segment_index=segment_index)
|
||||
return StandardResponse(code=200, message="success", data=task)
|
||||
|
||||
@@ -103,8 +110,9 @@ async def get_editor_task_segment(
|
||||
..., ge=0, alias="segmentIndex", description="段落索引(从0开始)"
|
||||
),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
user_context: RequestUserContext = Depends(get_request_user_context),
|
||||
):
|
||||
service = AnnotationEditorService(db)
|
||||
service = AnnotationEditorService(db, user_context)
|
||||
result = await service.get_task_segment(project_id, file_id, segment_index)
|
||||
return StandardResponse(code=200, message="success", data=result)
|
||||
|
||||
@@ -118,8 +126,9 @@ async def upsert_editor_annotation(
|
||||
project_id: str = Path(..., description="标注项目ID(t_dm_labeling_projects.id)"),
|
||||
file_id: str = Path(..., description="文件ID(t_dm_dataset_files.id)"),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
user_context: RequestUserContext = Depends(get_request_user_context),
|
||||
):
|
||||
service = AnnotationEditorService(db)
|
||||
service = AnnotationEditorService(db, user_context)
|
||||
result = await service.upsert_annotation(project_id, file_id, request)
|
||||
return StandardResponse(code=200, message="success", data=result)
|
||||
|
||||
@@ -132,11 +141,12 @@ async def check_file_version(
|
||||
project_id: str = Path(..., description="标注项目ID(t_dm_labeling_projects.id)"),
|
||||
file_id: str = Path(..., description="文件ID(t_dm_dataset_files.id)"),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
user_context: RequestUserContext = Depends(get_request_user_context),
|
||||
):
|
||||
"""
|
||||
检查文件是否有新版本
|
||||
"""
|
||||
service = AnnotationEditorService(db)
|
||||
service = AnnotationEditorService(db, user_context)
|
||||
result = await service.check_file_version(project_id, file_id)
|
||||
return StandardResponse(code=200, message="success", data=result)
|
||||
|
||||
@@ -149,10 +159,11 @@ async def use_new_version(
|
||||
project_id: str = Path(..., description="标注项目ID(t_dm_labeling_projects.id)"),
|
||||
file_id: str = Path(..., description="文件ID(t_dm_dataset_files.id)"),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
user_context: RequestUserContext = Depends(get_request_user_context),
|
||||
):
|
||||
"""
|
||||
使用文件新版本并清空标注
|
||||
"""
|
||||
service = AnnotationEditorService(db)
|
||||
service = AnnotationEditorService(db, user_context)
|
||||
result = await service.use_new_version(project_id, file_id)
|
||||
return StandardResponse(code=200, message="success", data=result)
|
||||
|
||||
@@ -12,6 +12,11 @@ from app.module.shared.schema import StandardResponse, PaginatedData
|
||||
from app.module.dataset import DatasetManagementService
|
||||
from app.core.logging import get_logger
|
||||
|
||||
from ..security import (
|
||||
RequestUserContext,
|
||||
assert_dataset_access,
|
||||
get_request_user_context,
|
||||
)
|
||||
from ..service.mapping import DatasetMappingService
|
||||
from ..service.template import AnnotationTemplateService
|
||||
from ..service.knowledge_sync import KnowledgeSyncService
|
||||
@@ -42,7 +47,9 @@ async def login_label_studio(mapping_id: str, db: AsyncSession = Depends(get_db)
|
||||
"", response_model=StandardResponse[DatasetMappingCreateResponse], status_code=201
|
||||
)
|
||||
async def create_mapping(
|
||||
request: DatasetMappingCreateRequest, db: AsyncSession = Depends(get_db)
|
||||
request: DatasetMappingCreateRequest,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
user_context: RequestUserContext = Depends(get_request_user_context),
|
||||
):
|
||||
"""
|
||||
创建数据集映射
|
||||
@@ -58,6 +65,8 @@ async def create_mapping(
|
||||
mapping_service = DatasetMappingService(db)
|
||||
template_service = AnnotationTemplateService()
|
||||
|
||||
await assert_dataset_access(db, request.dataset_id, user_context)
|
||||
|
||||
logger.info(f"Create dataset mapping request: {request.dataset_id}")
|
||||
|
||||
# 从DM服务获取数据集信息
|
||||
@@ -163,7 +172,7 @@ async def create_mapping(
|
||||
try:
|
||||
from ..service.editor import AnnotationEditorService
|
||||
|
||||
editor_service = AnnotationEditorService(db)
|
||||
editor_service = AnnotationEditorService(db, user_context)
|
||||
# 异步预计算切片(不阻塞创建响应)
|
||||
segmentation_result = (
|
||||
await editor_service.precompute_segmentation_for_project(
|
||||
@@ -202,6 +211,7 @@ async def list_mappings(
|
||||
False, description="是否包含模板详情", alias="includeTemplate"
|
||||
),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
user_context: RequestUserContext = Depends(get_request_user_context),
|
||||
):
|
||||
"""
|
||||
查询所有映射关系(分页)
|
||||
@@ -230,6 +240,8 @@ async def list_mappings(
|
||||
limit=size,
|
||||
include_deleted=False,
|
||||
include_template=include_template,
|
||||
current_user_id=user_context.user_id,
|
||||
is_admin=user_context.is_admin,
|
||||
)
|
||||
|
||||
# 计算总页数
|
||||
@@ -256,7 +268,11 @@ async def list_mappings(
|
||||
|
||||
|
||||
@router.get("/{mapping_id}", response_model=StandardResponse[DatasetMappingResponse])
|
||||
async def get_mapping(mapping_id: str, db: AsyncSession = Depends(get_db)):
|
||||
async def get_mapping(
|
||||
mapping_id: str,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
user_context: RequestUserContext = Depends(get_request_user_context),
|
||||
):
|
||||
"""
|
||||
根据 UUID 查询单个映射关系(包含关联的标注模板详情)
|
||||
|
||||
@@ -278,6 +294,7 @@ async def get_mapping(mapping_id: str, db: AsyncSession = Depends(get_db)):
|
||||
raise HTTPException(
|
||||
status_code=404, detail=f"Mapping not found: {mapping_id}"
|
||||
)
|
||||
await assert_dataset_access(db, mapping.dataset_id, user_context)
|
||||
|
||||
logger.info(
|
||||
f"Found mapping: {mapping.id}, template_included: {mapping.template is not None}"
|
||||
@@ -304,6 +321,7 @@ async def get_mappings_by_source(
|
||||
True, description="是否包含模板详情", alias="includeTemplate"
|
||||
),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
user_context: RequestUserContext = Depends(get_request_user_context),
|
||||
):
|
||||
"""
|
||||
根据源数据集 ID 查询所有映射关系(分页,包含模板详情)
|
||||
@@ -319,6 +337,7 @@ async def get_mappings_by_source(
|
||||
"""
|
||||
try:
|
||||
service = DatasetMappingService(db)
|
||||
await assert_dataset_access(db, dataset_id, user_context)
|
||||
|
||||
# 计算 skip
|
||||
skip = (page - 1) * size
|
||||
@@ -333,6 +352,8 @@ async def get_mappings_by_source(
|
||||
skip=skip,
|
||||
limit=size,
|
||||
include_template=include_template,
|
||||
current_user_id=user_context.user_id,
|
||||
is_admin=user_context.is_admin,
|
||||
)
|
||||
|
||||
# 计算总页数
|
||||
@@ -364,6 +385,7 @@ async def get_mappings_by_source(
|
||||
async def delete_mapping(
|
||||
project_id: str = Path(..., description="映射UUID(path param)"),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
user_context: RequestUserContext = Depends(get_request_user_context),
|
||||
):
|
||||
"""
|
||||
删除映射关系(软删除)
|
||||
@@ -387,6 +409,7 @@ async def delete_mapping(
|
||||
raise HTTPException(
|
||||
status_code=404, detail=f"Mapping either not found or not specified."
|
||||
)
|
||||
await assert_dataset_access(db, mapping.dataset_id, user_context)
|
||||
|
||||
id = mapping.id
|
||||
dataset_id = mapping.dataset_id
|
||||
@@ -428,6 +451,7 @@ async def update_mapping(
|
||||
project_id: str = Path(..., description="映射UUID(path param)"),
|
||||
request: DatasetMappingUpdateRequest = None,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
user_context: RequestUserContext = Depends(get_request_user_context),
|
||||
):
|
||||
"""
|
||||
更新标注项目信息
|
||||
@@ -456,6 +480,7 @@ async def update_mapping(
|
||||
raise HTTPException(
|
||||
status_code=404, detail=f"Mapping not found: {project_id}"
|
||||
)
|
||||
await assert_dataset_access(db, mapping_orm.dataset_id, user_context)
|
||||
|
||||
# 构建更新数据
|
||||
update_values = {}
|
||||
|
||||
@@ -10,6 +10,11 @@ from app.module.dataset import DatasetManagementService
|
||||
from app.core.logging import get_logger
|
||||
from app.core.config import settings
|
||||
|
||||
from ..security import (
|
||||
RequestUserContext,
|
||||
assert_dataset_access,
|
||||
get_request_user_context,
|
||||
)
|
||||
from ..service.mapping import DatasetMappingService
|
||||
from ..schema import (
|
||||
SyncDatasetRequest,
|
||||
@@ -32,7 +37,8 @@ logger = get_logger(__name__)
|
||||
@router.post("/sync", response_model=StandardResponse[SyncDatasetResponse])
|
||||
async def sync_dataset_content(
|
||||
request: SyncDatasetRequest,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
user_context: RequestUserContext = Depends(get_request_user_context),
|
||||
):
|
||||
"""
|
||||
Sync Dataset Content (Files and Annotations)
|
||||
@@ -51,6 +57,7 @@ async def sync_dataset_content(
|
||||
status_code=404,
|
||||
detail=f"Mapping not found: {request.id}"
|
||||
)
|
||||
await assert_dataset_access(db, mapping.dataset_id, user_context)
|
||||
|
||||
dm_client = DatasetManagementService(db)
|
||||
dataset_info = await dm_client.get_dataset(mapping.dataset_id)
|
||||
@@ -82,7 +89,8 @@ async def sync_dataset_content(
|
||||
@router.post("/annotation/sync", response_model=StandardResponse[SyncAnnotationsResponse])
|
||||
async def sync_annotations(
|
||||
request: SyncAnnotationsRequest,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
user_context: RequestUserContext = Depends(get_request_user_context),
|
||||
):
|
||||
"""
|
||||
Sync Annotations Only (Bidirectional Support)
|
||||
@@ -102,6 +110,7 @@ async def sync_annotations(
|
||||
status_code=404,
|
||||
detail=f"Mapping not found: {request.id}"
|
||||
)
|
||||
await assert_dataset_access(db, mapping.dataset_id, user_context)
|
||||
|
||||
result = SyncAnnotationsResponse(
|
||||
id=mapping.id,
|
||||
@@ -156,7 +165,8 @@ async def check_label_studio_connection():
|
||||
async def update_file_tags(
|
||||
request: UpdateFileTagsRequest,
|
||||
file_id: str = Path(..., description="文件ID"),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
user_context: RequestUserContext = Depends(get_request_user_context),
|
||||
):
|
||||
"""
|
||||
Update File Tags (Partial Update with Auto Format Conversion)
|
||||
@@ -189,6 +199,7 @@ async def update_file_tags(
|
||||
raise HTTPException(status_code=404, detail=f"File not found: {file_id}")
|
||||
|
||||
dataset_id = str(file_record.dataset_id) # type: ignore - Convert Column to str
|
||||
await assert_dataset_access(db, dataset_id, user_context)
|
||||
|
||||
# 查找数据集关联的模板ID
|
||||
from ..service.mapping import DatasetMappingService
|
||||
|
||||
69
runtime/datamate-python/app/module/annotation/security.py
Normal file
69
runtime/datamate-python/app/module/annotation/security.py
Normal file
@@ -0,0 +1,69 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Tuple
|
||||
|
||||
from fastapi import HTTPException, Request
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db.models.dataset_management import Dataset
|
||||
|
||||
HEADER_USER_ID = "X-User-Id"
|
||||
HEADER_USER_NAME = "X-User-Name"
|
||||
HEADER_USER_ROLES = "X-User-Roles"
|
||||
ADMIN_ROLE_CODE = "ROLE_ADMIN"
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class RequestUserContext:
|
||||
user_id: str
|
||||
username: str | None
|
||||
roles: Tuple[str, ...]
|
||||
|
||||
@property
|
||||
def is_admin(self) -> bool:
|
||||
return any(role.upper() == ADMIN_ROLE_CODE for role in self.roles)
|
||||
|
||||
|
||||
def get_request_user_context(request: Request) -> RequestUserContext:
|
||||
user_id = (request.headers.get(HEADER_USER_ID) or "").strip()
|
||||
username = (request.headers.get(HEADER_USER_NAME) or "").strip() or None
|
||||
role_header = request.headers.get(HEADER_USER_ROLES) or ""
|
||||
roles = tuple(
|
||||
role.strip()
|
||||
for role in role_header.split(",")
|
||||
if role and role.strip()
|
||||
)
|
||||
if not user_id:
|
||||
raise HTTPException(status_code=403, detail="权限不足:缺少用户身份")
|
||||
return RequestUserContext(user_id=user_id, username=username, roles=roles)
|
||||
|
||||
|
||||
def ensure_dataset_owner_access(
|
||||
user_context: RequestUserContext,
|
||||
dataset_owner_user_id: str | None,
|
||||
dataset_id: str,
|
||||
) -> None:
|
||||
if user_context.is_admin:
|
||||
return
|
||||
if not dataset_owner_user_id or dataset_owner_user_id != user_context.user_id:
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail=f"无权访问数据集: {dataset_id}",
|
||||
)
|
||||
|
||||
|
||||
async def assert_dataset_access(
|
||||
db: AsyncSession,
|
||||
dataset_id: str,
|
||||
user_context: RequestUserContext,
|
||||
) -> None:
|
||||
owner_result = await db.execute(
|
||||
select(Dataset.created_by).where(Dataset.id == dataset_id)
|
||||
)
|
||||
dataset_owner = owner_result.scalar_one_or_none()
|
||||
if dataset_owner is None:
|
||||
raise HTTPException(status_code=404, detail=f"数据集不存在: {dataset_id}")
|
||||
ensure_dataset_owner_access(user_context, str(dataset_owner), dataset_id)
|
||||
|
||||
@@ -5,11 +5,12 @@ from typing import List, Optional
|
||||
from datetime import datetime
|
||||
from uuid import uuid4
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db.models.annotation_management import AutoAnnotationTask
|
||||
from app.db.models.dataset_management import Dataset, DatasetFiles
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db.models.annotation_management import AutoAnnotationTask
|
||||
from app.db.models.dataset_management import Dataset, DatasetFiles
|
||||
from app.module.annotation.security import RequestUserContext
|
||||
|
||||
from ..schema.auto import (
|
||||
CreateAutoAnnotationTaskRequest,
|
||||
@@ -17,7 +18,7 @@ from ..schema.auto import (
|
||||
)
|
||||
|
||||
|
||||
class AutoAnnotationTaskService:
|
||||
class AutoAnnotationTaskService:
|
||||
"""自动标注任务服务(仅管理任务元数据,真正执行由 runtime 负责)"""
|
||||
|
||||
async def create_task(
|
||||
@@ -63,15 +64,27 @@ class AutoAnnotationTaskService:
|
||||
resp.source_datasets = [dataset_name] if dataset_name else [request.dataset_id]
|
||||
return resp
|
||||
|
||||
async def list_tasks(self, db: AsyncSession) -> List[AutoAnnotationTaskResponse]:
|
||||
"""获取未软删除的自动标注任务列表,按创建时间倒序。"""
|
||||
|
||||
result = await db.execute(
|
||||
select(AutoAnnotationTask)
|
||||
.where(AutoAnnotationTask.deleted_at.is_(None))
|
||||
.order_by(AutoAnnotationTask.created_at.desc())
|
||||
)
|
||||
tasks: List[AutoAnnotationTask] = list(result.scalars().all())
|
||||
def _apply_dataset_scope(self, query, user_context: RequestUserContext):
|
||||
if user_context.is_admin:
|
||||
return query
|
||||
return query.join(
|
||||
Dataset,
|
||||
AutoAnnotationTask.dataset_id == Dataset.id,
|
||||
).where(Dataset.created_by == user_context.user_id)
|
||||
|
||||
async def list_tasks(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
user_context: RequestUserContext,
|
||||
) -> List[AutoAnnotationTaskResponse]:
|
||||
"""获取未软删除的自动标注任务列表,按创建时间倒序。"""
|
||||
|
||||
query = select(AutoAnnotationTask).where(AutoAnnotationTask.deleted_at.is_(None))
|
||||
query = self._apply_dataset_scope(query, user_context)
|
||||
result = await db.execute(
|
||||
query.order_by(AutoAnnotationTask.created_at.desc())
|
||||
)
|
||||
tasks: List[AutoAnnotationTask] = list(result.scalars().all())
|
||||
|
||||
responses: List[AutoAnnotationTaskResponse] = []
|
||||
for task in tasks:
|
||||
@@ -87,16 +100,21 @@ class AutoAnnotationTaskService:
|
||||
|
||||
return responses
|
||||
|
||||
async def get_task(self, db: AsyncSession, task_id: str) -> Optional[AutoAnnotationTaskResponse]:
|
||||
result = await db.execute(
|
||||
select(AutoAnnotationTask).where(
|
||||
AutoAnnotationTask.id == task_id,
|
||||
AutoAnnotationTask.deleted_at.is_(None),
|
||||
)
|
||||
)
|
||||
task = result.scalar_one_or_none()
|
||||
if not task:
|
||||
return None
|
||||
async def get_task(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
task_id: str,
|
||||
user_context: RequestUserContext,
|
||||
) -> Optional[AutoAnnotationTaskResponse]:
|
||||
query = select(AutoAnnotationTask).where(
|
||||
AutoAnnotationTask.id == task_id,
|
||||
AutoAnnotationTask.deleted_at.is_(None),
|
||||
)
|
||||
query = self._apply_dataset_scope(query, user_context)
|
||||
result = await db.execute(query)
|
||||
task = result.scalar_one_or_none()
|
||||
if not task:
|
||||
return None
|
||||
|
||||
resp = AutoAnnotationTaskResponse.model_validate(task)
|
||||
try:
|
||||
@@ -138,16 +156,21 @@ class AutoAnnotationTaskService:
|
||||
return [task.dataset_id]
|
||||
return []
|
||||
|
||||
async def soft_delete_task(self, db: AsyncSession, task_id: str) -> bool:
|
||||
result = await db.execute(
|
||||
select(AutoAnnotationTask).where(
|
||||
AutoAnnotationTask.id == task_id,
|
||||
AutoAnnotationTask.deleted_at.is_(None),
|
||||
)
|
||||
)
|
||||
task = result.scalar_one_or_none()
|
||||
if not task:
|
||||
return False
|
||||
async def soft_delete_task(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
task_id: str,
|
||||
user_context: RequestUserContext,
|
||||
) -> bool:
|
||||
query = select(AutoAnnotationTask).where(
|
||||
AutoAnnotationTask.id == task_id,
|
||||
AutoAnnotationTask.deleted_at.is_(None),
|
||||
)
|
||||
query = self._apply_dataset_scope(query, user_context)
|
||||
result = await db.execute(query)
|
||||
task = result.scalar_one_or_none()
|
||||
if not task:
|
||||
return False
|
||||
|
||||
task.deleted_at = datetime.now()
|
||||
await db.commit()
|
||||
|
||||
@@ -54,6 +54,10 @@ from app.module.annotation.service.knowledge_sync import KnowledgeSyncService
|
||||
from app.module.annotation.service.annotation_text_splitter import (
|
||||
AnnotationTextSplitter,
|
||||
)
|
||||
from app.module.annotation.security import (
|
||||
RequestUserContext,
|
||||
ensure_dataset_owner_access,
|
||||
)
|
||||
from app.module.annotation.service.text_fetcher import (
|
||||
fetch_text_content_via_download_api,
|
||||
)
|
||||
@@ -104,8 +108,9 @@ class AnnotationEditorService:
|
||||
# 分段阈值:超过此字符数自动分段
|
||||
SEGMENT_THRESHOLD = 200
|
||||
|
||||
def __init__(self, db: AsyncSession):
|
||||
def __init__(self, db: AsyncSession, user_context: RequestUserContext):
|
||||
self.db = db
|
||||
self.user_context = user_context
|
||||
self.template_service = AnnotationTemplateService()
|
||||
|
||||
@staticmethod
|
||||
@@ -157,14 +162,24 @@ class AnnotationEditorService:
|
||||
|
||||
async def _get_project_or_404(self, project_id: str) -> LabelingProject:
|
||||
result = await self.db.execute(
|
||||
select(LabelingProject).where(
|
||||
select(LabelingProject, Dataset.created_by).join(
|
||||
Dataset,
|
||||
LabelingProject.dataset_id == Dataset.id,
|
||||
).where(
|
||||
LabelingProject.id == project_id,
|
||||
LabelingProject.deleted_at.is_(None),
|
||||
)
|
||||
)
|
||||
project = result.scalar_one_or_none()
|
||||
if not project:
|
||||
row = result.first()
|
||||
if not row:
|
||||
raise HTTPException(status_code=404, detail=f"标注项目不存在: {project_id}")
|
||||
project = row[0]
|
||||
dataset_owner = row[1]
|
||||
ensure_dataset_owner_access(
|
||||
self.user_context,
|
||||
str(dataset_owner) if dataset_owner is not None else None,
|
||||
project.dataset_id,
|
||||
)
|
||||
return project
|
||||
|
||||
async def _get_dataset_type(self, dataset_id: str) -> Optional[str]:
|
||||
|
||||
@@ -478,7 +478,9 @@ class DatasetMappingService:
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
include_deleted: bool = False,
|
||||
include_template: bool = False
|
||||
include_template: bool = False,
|
||||
current_user_id: Optional[str] = None,
|
||||
is_admin: bool = False,
|
||||
) -> Tuple[List[DatasetMappingResponse], int]:
|
||||
"""
|
||||
获取所有映射及总数(用于分页)
|
||||
@@ -495,9 +497,16 @@ class DatasetMappingService:
|
||||
query = self._build_query_with_dataset_name()
|
||||
if not include_deleted:
|
||||
query = query.where(LabelingProject.deleted_at.is_(None))
|
||||
if not is_admin:
|
||||
query = query.where(Dataset.created_by == current_user_id)
|
||||
|
||||
# 获取总数
|
||||
count_query = select(func.count()).select_from(LabelingProject)
|
||||
if not is_admin:
|
||||
count_query = count_query.join(
|
||||
Dataset,
|
||||
LabelingProject.dataset_id == Dataset.id,
|
||||
).where(Dataset.created_by == current_user_id)
|
||||
if not include_deleted:
|
||||
count_query = count_query.where(LabelingProject.deleted_at.is_(None))
|
||||
|
||||
@@ -557,7 +566,9 @@ class DatasetMappingService:
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
include_deleted: bool = False,
|
||||
include_template: bool = False
|
||||
include_template: bool = False,
|
||||
current_user_id: Optional[str] = None,
|
||||
is_admin: bool = False,
|
||||
) -> Tuple[List[DatasetMappingResponse], int]:
|
||||
"""
|
||||
根据源数据集ID获取映射关系及总数(用于分页)
|
||||
@@ -578,11 +589,18 @@ class DatasetMappingService:
|
||||
|
||||
if not include_deleted:
|
||||
query = query.where(LabelingProject.deleted_at.is_(None))
|
||||
if not is_admin:
|
||||
query = query.where(Dataset.created_by == current_user_id)
|
||||
|
||||
# 获取总数
|
||||
count_query = select(func.count()).select_from(LabelingProject).where(
|
||||
LabelingProject.dataset_id == dataset_id
|
||||
)
|
||||
if not is_admin:
|
||||
count_query = count_query.join(
|
||||
Dataset,
|
||||
LabelingProject.dataset_id == Dataset.id,
|
||||
).where(Dataset.created_by == current_user_id)
|
||||
if not include_deleted:
|
||||
count_query = count_query.where(LabelingProject.deleted_at.is_(None))
|
||||
|
||||
|
||||
Reference in New Issue
Block a user