You've already forked DataMate
feat(auth): 为数据管理和RAG服务增加资源访问控制
- 在DatasetApplicationService中注入ResourceAccessService并添加所有权验证 - 在KnowledgeSetApplicationService中注入ResourceAccessService并添加所有权验证 - 修改DatasetRepository接口和实现类,增加按创建者过滤的方法 - 修改KnowledgeSetRepository接口和实现类,增加按创建者过滤的方法 - 在RAG索引器服务中添加知识库访问权限检查和作用域过滤 - 更新实体元对象处理器以使用请求用户上下文获取当前用户 - 在前端设置页面添加用户权限管理功能和角色权限控制 - 为Python标注服务增加用户上下文和数据集访问权限验证
This commit is contained in:
@@ -5,11 +5,12 @@ from typing import List, Optional
|
||||
from datetime import datetime
|
||||
from uuid import uuid4
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db.models.annotation_management import AutoAnnotationTask
|
||||
from app.db.models.dataset_management import Dataset, DatasetFiles
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db.models.annotation_management import AutoAnnotationTask
|
||||
from app.db.models.dataset_management import Dataset, DatasetFiles
|
||||
from app.module.annotation.security import RequestUserContext
|
||||
|
||||
from ..schema.auto import (
|
||||
CreateAutoAnnotationTaskRequest,
|
||||
@@ -17,7 +18,7 @@ from ..schema.auto import (
|
||||
)
|
||||
|
||||
|
||||
class AutoAnnotationTaskService:
|
||||
class AutoAnnotationTaskService:
|
||||
"""自动标注任务服务(仅管理任务元数据,真正执行由 runtime 负责)"""
|
||||
|
||||
async def create_task(
|
||||
@@ -63,15 +64,27 @@ class AutoAnnotationTaskService:
|
||||
resp.source_datasets = [dataset_name] if dataset_name else [request.dataset_id]
|
||||
return resp
|
||||
|
||||
async def list_tasks(self, db: AsyncSession) -> List[AutoAnnotationTaskResponse]:
|
||||
"""获取未软删除的自动标注任务列表,按创建时间倒序。"""
|
||||
|
||||
result = await db.execute(
|
||||
select(AutoAnnotationTask)
|
||||
.where(AutoAnnotationTask.deleted_at.is_(None))
|
||||
.order_by(AutoAnnotationTask.created_at.desc())
|
||||
)
|
||||
tasks: List[AutoAnnotationTask] = list(result.scalars().all())
|
||||
def _apply_dataset_scope(self, query, user_context: RequestUserContext):
|
||||
if user_context.is_admin:
|
||||
return query
|
||||
return query.join(
|
||||
Dataset,
|
||||
AutoAnnotationTask.dataset_id == Dataset.id,
|
||||
).where(Dataset.created_by == user_context.user_id)
|
||||
|
||||
async def list_tasks(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
user_context: RequestUserContext,
|
||||
) -> List[AutoAnnotationTaskResponse]:
|
||||
"""获取未软删除的自动标注任务列表,按创建时间倒序。"""
|
||||
|
||||
query = select(AutoAnnotationTask).where(AutoAnnotationTask.deleted_at.is_(None))
|
||||
query = self._apply_dataset_scope(query, user_context)
|
||||
result = await db.execute(
|
||||
query.order_by(AutoAnnotationTask.created_at.desc())
|
||||
)
|
||||
tasks: List[AutoAnnotationTask] = list(result.scalars().all())
|
||||
|
||||
responses: List[AutoAnnotationTaskResponse] = []
|
||||
for task in tasks:
|
||||
@@ -87,16 +100,21 @@ class AutoAnnotationTaskService:
|
||||
|
||||
return responses
|
||||
|
||||
async def get_task(self, db: AsyncSession, task_id: str) -> Optional[AutoAnnotationTaskResponse]:
|
||||
result = await db.execute(
|
||||
select(AutoAnnotationTask).where(
|
||||
AutoAnnotationTask.id == task_id,
|
||||
AutoAnnotationTask.deleted_at.is_(None),
|
||||
)
|
||||
)
|
||||
task = result.scalar_one_or_none()
|
||||
if not task:
|
||||
return None
|
||||
async def get_task(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
task_id: str,
|
||||
user_context: RequestUserContext,
|
||||
) -> Optional[AutoAnnotationTaskResponse]:
|
||||
query = select(AutoAnnotationTask).where(
|
||||
AutoAnnotationTask.id == task_id,
|
||||
AutoAnnotationTask.deleted_at.is_(None),
|
||||
)
|
||||
query = self._apply_dataset_scope(query, user_context)
|
||||
result = await db.execute(query)
|
||||
task = result.scalar_one_or_none()
|
||||
if not task:
|
||||
return None
|
||||
|
||||
resp = AutoAnnotationTaskResponse.model_validate(task)
|
||||
try:
|
||||
@@ -138,16 +156,21 @@ class AutoAnnotationTaskService:
|
||||
return [task.dataset_id]
|
||||
return []
|
||||
|
||||
async def soft_delete_task(self, db: AsyncSession, task_id: str) -> bool:
|
||||
result = await db.execute(
|
||||
select(AutoAnnotationTask).where(
|
||||
AutoAnnotationTask.id == task_id,
|
||||
AutoAnnotationTask.deleted_at.is_(None),
|
||||
)
|
||||
)
|
||||
task = result.scalar_one_or_none()
|
||||
if not task:
|
||||
return False
|
||||
async def soft_delete_task(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
task_id: str,
|
||||
user_context: RequestUserContext,
|
||||
) -> bool:
|
||||
query = select(AutoAnnotationTask).where(
|
||||
AutoAnnotationTask.id == task_id,
|
||||
AutoAnnotationTask.deleted_at.is_(None),
|
||||
)
|
||||
query = self._apply_dataset_scope(query, user_context)
|
||||
result = await db.execute(query)
|
||||
task = result.scalar_one_or_none()
|
||||
if not task:
|
||||
return False
|
||||
|
||||
task.deleted_at = datetime.now()
|
||||
await db.commit()
|
||||
|
||||
@@ -54,6 +54,10 @@ from app.module.annotation.service.knowledge_sync import KnowledgeSyncService
|
||||
from app.module.annotation.service.annotation_text_splitter import (
|
||||
AnnotationTextSplitter,
|
||||
)
|
||||
from app.module.annotation.security import (
|
||||
RequestUserContext,
|
||||
ensure_dataset_owner_access,
|
||||
)
|
||||
from app.module.annotation.service.text_fetcher import (
|
||||
fetch_text_content_via_download_api,
|
||||
)
|
||||
@@ -104,8 +108,9 @@ class AnnotationEditorService:
|
||||
# 分段阈值:超过此字符数自动分段
|
||||
SEGMENT_THRESHOLD = 200
|
||||
|
||||
def __init__(self, db: AsyncSession):
|
||||
def __init__(self, db: AsyncSession, user_context: RequestUserContext):
|
||||
self.db = db
|
||||
self.user_context = user_context
|
||||
self.template_service = AnnotationTemplateService()
|
||||
|
||||
@staticmethod
|
||||
@@ -157,14 +162,24 @@ class AnnotationEditorService:
|
||||
|
||||
async def _get_project_or_404(self, project_id: str) -> LabelingProject:
|
||||
result = await self.db.execute(
|
||||
select(LabelingProject).where(
|
||||
select(LabelingProject, Dataset.created_by).join(
|
||||
Dataset,
|
||||
LabelingProject.dataset_id == Dataset.id,
|
||||
).where(
|
||||
LabelingProject.id == project_id,
|
||||
LabelingProject.deleted_at.is_(None),
|
||||
)
|
||||
)
|
||||
project = result.scalar_one_or_none()
|
||||
if not project:
|
||||
row = result.first()
|
||||
if not row:
|
||||
raise HTTPException(status_code=404, detail=f"标注项目不存在: {project_id}")
|
||||
project = row[0]
|
||||
dataset_owner = row[1]
|
||||
ensure_dataset_owner_access(
|
||||
self.user_context,
|
||||
str(dataset_owner) if dataset_owner is not None else None,
|
||||
project.dataset_id,
|
||||
)
|
||||
return project
|
||||
|
||||
async def _get_dataset_type(self, dataset_id: str) -> Optional[str]:
|
||||
|
||||
@@ -478,7 +478,9 @@ class DatasetMappingService:
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
include_deleted: bool = False,
|
||||
include_template: bool = False
|
||||
include_template: bool = False,
|
||||
current_user_id: Optional[str] = None,
|
||||
is_admin: bool = False,
|
||||
) -> Tuple[List[DatasetMappingResponse], int]:
|
||||
"""
|
||||
获取所有映射及总数(用于分页)
|
||||
@@ -495,9 +497,16 @@ class DatasetMappingService:
|
||||
query = self._build_query_with_dataset_name()
|
||||
if not include_deleted:
|
||||
query = query.where(LabelingProject.deleted_at.is_(None))
|
||||
if not is_admin:
|
||||
query = query.where(Dataset.created_by == current_user_id)
|
||||
|
||||
# 获取总数
|
||||
count_query = select(func.count()).select_from(LabelingProject)
|
||||
if not is_admin:
|
||||
count_query = count_query.join(
|
||||
Dataset,
|
||||
LabelingProject.dataset_id == Dataset.id,
|
||||
).where(Dataset.created_by == current_user_id)
|
||||
if not include_deleted:
|
||||
count_query = count_query.where(LabelingProject.deleted_at.is_(None))
|
||||
|
||||
@@ -557,7 +566,9 @@ class DatasetMappingService:
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
include_deleted: bool = False,
|
||||
include_template: bool = False
|
||||
include_template: bool = False,
|
||||
current_user_id: Optional[str] = None,
|
||||
is_admin: bool = False,
|
||||
) -> Tuple[List[DatasetMappingResponse], int]:
|
||||
"""
|
||||
根据源数据集ID获取映射关系及总数(用于分页)
|
||||
@@ -578,11 +589,18 @@ class DatasetMappingService:
|
||||
|
||||
if not include_deleted:
|
||||
query = query.where(LabelingProject.deleted_at.is_(None))
|
||||
if not is_admin:
|
||||
query = query.where(Dataset.created_by == current_user_id)
|
||||
|
||||
# 获取总数
|
||||
count_query = select(func.count()).select_from(LabelingProject).where(
|
||||
LabelingProject.dataset_id == dataset_id
|
||||
)
|
||||
if not is_admin:
|
||||
count_query = count_query.join(
|
||||
Dataset,
|
||||
LabelingProject.dataset_id == Dataset.id,
|
||||
).where(Dataset.created_by == current_user_id)
|
||||
if not include_deleted:
|
||||
count_query = count_query.where(LabelingProject.deleted_at.is_(None))
|
||||
|
||||
|
||||
Reference in New Issue
Block a user