feat(auth): 为数据管理和RAG服务增加资源访问控制

- 在DatasetApplicationService中注入ResourceAccessService并添加所有权验证
- 在KnowledgeSetApplicationService中注入ResourceAccessService并添加所有权验证
- 修改DatasetRepository接口和实现类,增加按创建者过滤的方法
- 修改KnowledgeSetRepository接口和实现类,增加按创建者过滤的方法
- 在RAG索引器服务中添加知识库访问权限检查和作用域过滤
- 更新实体元对象处理器以使用请求用户上下文获取当前用户
- 在前端设置页面添加用户权限管理功能和角色权限控制
- 为Python标注服务增加用户上下文和数据集访问权限验证
This commit is contained in:
2026-02-06 14:58:46 +08:00
parent 056cee11cc
commit 6a4c4ae3d7
28 changed files with 1063 additions and 158 deletions

View File

@@ -17,12 +17,17 @@ from sqlalchemy.ext.asyncio import AsyncSession
from app.db.session import get_db
from app.module.shared.schema import StandardResponse
from app.module.dataset import DatasetManagementService
from app.core.logging import get_logger
from ..schema.auto import (
CreateAutoAnnotationTaskRequest,
AutoAnnotationTaskResponse,
from app.module.dataset import DatasetManagementService
from app.core.logging import get_logger
from ..security import (
RequestUserContext,
assert_dataset_access,
get_request_user_context,
)
from ..schema.auto import (
CreateAutoAnnotationTaskRequest,
AutoAnnotationTaskResponse,
)
from ..service.auto import AutoAnnotationTaskService
@@ -37,15 +42,16 @@ service = AutoAnnotationTaskService()
@router.get("", response_model=StandardResponse[List[AutoAnnotationTaskResponse]])
async def list_auto_annotation_tasks(
db: AsyncSession = Depends(get_db),
):
async def list_auto_annotation_tasks(
db: AsyncSession = Depends(get_db),
user_context: RequestUserContext = Depends(get_request_user_context),
):
"""获取自动标注任务列表。
前端当前不传分页参数,这里直接返回所有未删除任务。
"""
tasks = await service.list_tasks(db)
tasks = await service.list_tasks(db, user_context)
return StandardResponse(
code=200,
message="success",
@@ -54,28 +60,30 @@ async def list_auto_annotation_tasks(
@router.post("", response_model=StandardResponse[AutoAnnotationTaskResponse])
async def create_auto_annotation_task(
request: CreateAutoAnnotationTaskRequest,
db: AsyncSession = Depends(get_db),
):
async def create_auto_annotation_task(
request: CreateAutoAnnotationTaskRequest,
db: AsyncSession = Depends(get_db),
user_context: RequestUserContext = Depends(get_request_user_context),
):
"""创建自动标注任务。
当前仅创建任务记录并置为 pending,实际执行由后续调度/worker 完成。
"""
logger.info(
"Creating auto annotation task: name=%s, dataset_id=%s, config=%s, file_ids=%s",
request.name,
logger.info(
"Creating auto annotation task: name=%s, dataset_id=%s, config=%s, file_ids=%s",
request.name,
request.dataset_id,
request.config.model_dump(by_alias=True),
request.file_ids,
)
# 尝试获取数据集名称和文件数量用于冗余字段,失败时不阻塞任务创建
dataset_name = None
total_images = 0
try:
dm_client = DatasetManagementService(db)
# 尝试获取数据集名称和文件数量用于冗余字段,失败时不阻塞任务创建
dataset_name = None
total_images = 0
await assert_dataset_access(db, request.dataset_id, user_context)
try:
dm_client = DatasetManagementService(db)
# Service.get_dataset 返回 DatasetResponse,包含 name 和 fileCount
dataset = await dm_client.get_dataset(request.dataset_id)
if dataset is not None:
@@ -103,16 +111,17 @@ async def create_auto_annotation_task(
@router.get("/{task_id}/status", response_model=StandardResponse[AutoAnnotationTaskResponse])
async def get_auto_annotation_task_status(
task_id: str = Path(..., description="任务ID"),
db: AsyncSession = Depends(get_db),
):
async def get_auto_annotation_task_status(
task_id: str = Path(..., description="任务ID"),
db: AsyncSession = Depends(get_db),
user_context: RequestUserContext = Depends(get_request_user_context),
):
"""获取单个自动标注任务状态。
前端当前主要通过列表轮询,这里提供按 ID 查询的补充接口。
"""
task = await service.get_task(db, task_id)
task = await service.get_task(db, task_id, user_context)
if not task:
raise HTTPException(status_code=404, detail="Task not found")
@@ -124,13 +133,14 @@ async def get_auto_annotation_task_status(
@router.delete("/{task_id}", response_model=StandardResponse[bool])
async def delete_auto_annotation_task(
task_id: str = Path(..., description="任务ID"),
db: AsyncSession = Depends(get_db),
):
async def delete_auto_annotation_task(
task_id: str = Path(..., description="任务ID"),
db: AsyncSession = Depends(get_db),
user_context: RequestUserContext = Depends(get_request_user_context),
):
"""删除(软删除)自动标注任务,仅标记 deleted_at。"""
ok = await service.soft_delete_task(db, task_id)
ok = await service.soft_delete_task(db, task_id, user_context)
if not ok:
raise HTTPException(status_code=404, detail="Task not found")
@@ -142,10 +152,11 @@ async def delete_auto_annotation_task(
@router.get("/{task_id}/download")
async def download_auto_annotation_result(
task_id: str = Path(..., description="任务ID"),
db: AsyncSession = Depends(get_db),
):
async def download_auto_annotation_result(
task_id: str = Path(..., description="任务ID"),
db: AsyncSession = Depends(get_db),
user_context: RequestUserContext = Depends(get_request_user_context),
):
"""下载指定自动标注任务的结果 ZIP。"""
import io
@@ -154,7 +165,7 @@ async def download_auto_annotation_result(
import tempfile
# 复用服务层获取任务信息
task = await service.get_task(db, task_id)
task = await service.get_task(db, task_id, user_context)
if not task:
raise HTTPException(status_code=404, detail="Task not found")

View File

@@ -27,6 +27,10 @@ from app.module.annotation.schema.editor import (
UpsertAnnotationResponse,
)
from app.module.annotation.service.editor import AnnotationEditorService
from app.module.annotation.security import (
RequestUserContext,
get_request_user_context,
)
from app.module.shared.schema import StandardResponse
logger = get_logger(__name__)
@@ -44,8 +48,9 @@ router = APIRouter(
async def get_editor_project_info(
project_id: str = Path(..., description="标注项目ID(t_dm_labeling_projects.id)"),
db: AsyncSession = Depends(get_db),
user_context: RequestUserContext = Depends(get_request_user_context),
):
service = AnnotationEditorService(db)
service = AnnotationEditorService(db, user_context)
info = await service.get_project_info(project_id)
return StandardResponse(code=200, message="success", data=info)
@@ -64,8 +69,9 @@ async def list_editor_tasks(
description="是否排除已被转换为TXT的源文档文件(PDF/DOC/DOCX,仅文本数据集生效)",
),
db: AsyncSession = Depends(get_db),
user_context: RequestUserContext = Depends(get_request_user_context),
):
service = AnnotationEditorService(db)
service = AnnotationEditorService(db, user_context)
result = await service.list_tasks(
project_id,
page=page,
@@ -86,8 +92,9 @@ async def get_editor_task(
None, alias="segmentIndex", description="段落索引(分段模式下使用)"
),
db: AsyncSession = Depends(get_db),
user_context: RequestUserContext = Depends(get_request_user_context),
):
service = AnnotationEditorService(db)
service = AnnotationEditorService(db, user_context)
task = await service.get_task(project_id, file_id, segment_index=segment_index)
return StandardResponse(code=200, message="success", data=task)
@@ -103,8 +110,9 @@ async def get_editor_task_segment(
..., ge=0, alias="segmentIndex", description="段落索引(从0开始)"
),
db: AsyncSession = Depends(get_db),
user_context: RequestUserContext = Depends(get_request_user_context),
):
service = AnnotationEditorService(db)
service = AnnotationEditorService(db, user_context)
result = await service.get_task_segment(project_id, file_id, segment_index)
return StandardResponse(code=200, message="success", data=result)
@@ -118,8 +126,9 @@ async def upsert_editor_annotation(
project_id: str = Path(..., description="标注项目ID(t_dm_labeling_projects.id)"),
file_id: str = Path(..., description="文件ID(t_dm_dataset_files.id)"),
db: AsyncSession = Depends(get_db),
user_context: RequestUserContext = Depends(get_request_user_context),
):
service = AnnotationEditorService(db)
service = AnnotationEditorService(db, user_context)
result = await service.upsert_annotation(project_id, file_id, request)
return StandardResponse(code=200, message="success", data=result)
@@ -132,11 +141,12 @@ async def check_file_version(
project_id: str = Path(..., description="标注项目ID(t_dm_labeling_projects.id)"),
file_id: str = Path(..., description="文件ID(t_dm_dataset_files.id)"),
db: AsyncSession = Depends(get_db),
user_context: RequestUserContext = Depends(get_request_user_context),
):
"""
检查文件是否有新版本
"""
service = AnnotationEditorService(db)
service = AnnotationEditorService(db, user_context)
result = await service.check_file_version(project_id, file_id)
return StandardResponse(code=200, message="success", data=result)
@@ -149,10 +159,11 @@ async def use_new_version(
project_id: str = Path(..., description="标注项目ID(t_dm_labeling_projects.id)"),
file_id: str = Path(..., description="文件ID(t_dm_dataset_files.id)"),
db: AsyncSession = Depends(get_db),
user_context: RequestUserContext = Depends(get_request_user_context),
):
"""
使用文件新版本并清空标注
"""
service = AnnotationEditorService(db)
service = AnnotationEditorService(db, user_context)
result = await service.use_new_version(project_id, file_id)
return StandardResponse(code=200, message="success", data=result)

View File

@@ -12,6 +12,11 @@ from app.module.shared.schema import StandardResponse, PaginatedData
from app.module.dataset import DatasetManagementService
from app.core.logging import get_logger
from ..security import (
RequestUserContext,
assert_dataset_access,
get_request_user_context,
)
from ..service.mapping import DatasetMappingService
from ..service.template import AnnotationTemplateService
from ..service.knowledge_sync import KnowledgeSyncService
@@ -42,7 +47,9 @@ async def login_label_studio(mapping_id: str, db: AsyncSession = Depends(get_db)
"", response_model=StandardResponse[DatasetMappingCreateResponse], status_code=201
)
async def create_mapping(
request: DatasetMappingCreateRequest, db: AsyncSession = Depends(get_db)
request: DatasetMappingCreateRequest,
db: AsyncSession = Depends(get_db),
user_context: RequestUserContext = Depends(get_request_user_context),
):
"""
创建数据集映射
@@ -58,6 +65,8 @@ async def create_mapping(
mapping_service = DatasetMappingService(db)
template_service = AnnotationTemplateService()
await assert_dataset_access(db, request.dataset_id, user_context)
logger.info(f"Create dataset mapping request: {request.dataset_id}")
# 从DM服务获取数据集信息
@@ -163,7 +172,7 @@ async def create_mapping(
try:
from ..service.editor import AnnotationEditorService
editor_service = AnnotationEditorService(db)
editor_service = AnnotationEditorService(db, user_context)
# 异步预计算切片(不阻塞创建响应)
segmentation_result = (
await editor_service.precompute_segmentation_for_project(
@@ -202,6 +211,7 @@ async def list_mappings(
False, description="是否包含模板详情", alias="includeTemplate"
),
db: AsyncSession = Depends(get_db),
user_context: RequestUserContext = Depends(get_request_user_context),
):
"""
查询所有映射关系(分页)
@@ -230,6 +240,8 @@ async def list_mappings(
limit=size,
include_deleted=False,
include_template=include_template,
current_user_id=user_context.user_id,
is_admin=user_context.is_admin,
)
# 计算总页数
@@ -256,7 +268,11 @@ async def list_mappings(
@router.get("/{mapping_id}", response_model=StandardResponse[DatasetMappingResponse])
async def get_mapping(mapping_id: str, db: AsyncSession = Depends(get_db)):
async def get_mapping(
mapping_id: str,
db: AsyncSession = Depends(get_db),
user_context: RequestUserContext = Depends(get_request_user_context),
):
"""
根据 UUID 查询单个映射关系(包含关联的标注模板详情)
@@ -278,6 +294,7 @@ async def get_mapping(mapping_id: str, db: AsyncSession = Depends(get_db)):
raise HTTPException(
status_code=404, detail=f"Mapping not found: {mapping_id}"
)
await assert_dataset_access(db, mapping.dataset_id, user_context)
logger.info(
f"Found mapping: {mapping.id}, template_included: {mapping.template is not None}"
@@ -304,6 +321,7 @@ async def get_mappings_by_source(
True, description="是否包含模板详情", alias="includeTemplate"
),
db: AsyncSession = Depends(get_db),
user_context: RequestUserContext = Depends(get_request_user_context),
):
"""
根据源数据集 ID 查询所有映射关系(分页,包含模板详情)
@@ -319,6 +337,7 @@ async def get_mappings_by_source(
"""
try:
service = DatasetMappingService(db)
await assert_dataset_access(db, dataset_id, user_context)
# 计算 skip
skip = (page - 1) * size
@@ -333,6 +352,8 @@ async def get_mappings_by_source(
skip=skip,
limit=size,
include_template=include_template,
current_user_id=user_context.user_id,
is_admin=user_context.is_admin,
)
# 计算总页数
@@ -364,6 +385,7 @@ async def get_mappings_by_source(
async def delete_mapping(
project_id: str = Path(..., description="映射UUID(path param)"),
db: AsyncSession = Depends(get_db),
user_context: RequestUserContext = Depends(get_request_user_context),
):
"""
删除映射关系(软删除)
@@ -387,6 +409,7 @@ async def delete_mapping(
raise HTTPException(
status_code=404, detail=f"Mapping either not found or not specified."
)
await assert_dataset_access(db, mapping.dataset_id, user_context)
id = mapping.id
dataset_id = mapping.dataset_id
@@ -428,6 +451,7 @@ async def update_mapping(
project_id: str = Path(..., description="映射UUID(path param)"),
request: DatasetMappingUpdateRequest = None,
db: AsyncSession = Depends(get_db),
user_context: RequestUserContext = Depends(get_request_user_context),
):
"""
更新标注项目信息
@@ -456,6 +480,7 @@ async def update_mapping(
raise HTTPException(
status_code=404, detail=f"Mapping not found: {project_id}"
)
await assert_dataset_access(db, mapping_orm.dataset_id, user_context)
# 构建更新数据
update_values = {}

View File

@@ -10,6 +10,11 @@ from app.module.dataset import DatasetManagementService
from app.core.logging import get_logger
from app.core.config import settings
from ..security import (
RequestUserContext,
assert_dataset_access,
get_request_user_context,
)
from ..service.mapping import DatasetMappingService
from ..schema import (
SyncDatasetRequest,
@@ -32,7 +37,8 @@ logger = get_logger(__name__)
@router.post("/sync", response_model=StandardResponse[SyncDatasetResponse])
async def sync_dataset_content(
request: SyncDatasetRequest,
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
user_context: RequestUserContext = Depends(get_request_user_context),
):
"""
Sync Dataset Content (Files and Annotations)
@@ -51,6 +57,7 @@ async def sync_dataset_content(
status_code=404,
detail=f"Mapping not found: {request.id}"
)
await assert_dataset_access(db, mapping.dataset_id, user_context)
dm_client = DatasetManagementService(db)
dataset_info = await dm_client.get_dataset(mapping.dataset_id)
@@ -82,7 +89,8 @@ async def sync_dataset_content(
@router.post("/annotation/sync", response_model=StandardResponse[SyncAnnotationsResponse])
async def sync_annotations(
request: SyncAnnotationsRequest,
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
user_context: RequestUserContext = Depends(get_request_user_context),
):
"""
Sync Annotations Only (Bidirectional Support)
@@ -102,6 +110,7 @@ async def sync_annotations(
status_code=404,
detail=f"Mapping not found: {request.id}"
)
await assert_dataset_access(db, mapping.dataset_id, user_context)
result = SyncAnnotationsResponse(
id=mapping.id,
@@ -156,7 +165,8 @@ async def check_label_studio_connection():
async def update_file_tags(
request: UpdateFileTagsRequest,
file_id: str = Path(..., description="文件ID"),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
user_context: RequestUserContext = Depends(get_request_user_context),
):
"""
Update File Tags (Partial Update with Auto Format Conversion)
@@ -189,6 +199,7 @@ async def update_file_tags(
raise HTTPException(status_code=404, detail=f"File not found: {file_id}")
dataset_id = str(file_record.dataset_id) # type: ignore - Convert Column to str
await assert_dataset_access(db, dataset_id, user_context)
# 查找数据集关联的模板ID
from ..service.mapping import DatasetMappingService