feat(annotation): 添加标注项目文件快照功能

- 新增 LabelingProjectFile 模型用于存储标注项目的文件快照
- 在创建标注项目时记录关联的文件快照数据
- 更新查询逻辑以基于项目快照过滤文件列表
- 优化导出统计功能使用快照数据进行计算
- 添加数据库表结构支持项目文件快照关系
This commit is contained in:
2026-01-30 18:10:13 +08:00
parent 3c3ca130b3
commit 8b2a19f09a
7 changed files with 145 additions and 33 deletions

View File

@@ -23,7 +23,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
from app.core.config import settings
from app.core.logging import get_logger
from app.db.models import AnnotationResult, Dataset, DatasetFiles, LabelingProject
from app.db.models import AnnotationResult, Dataset, DatasetFiles, LabelingProject, LabelingProjectFile
from app.module.annotation.config import LabelStudioTagConfig
from app.module.annotation.schema.editor import (
EditorProjectInfo,
@@ -436,14 +436,18 @@ class AnnotationEditorService:
exclude_source_documents if exclude_source_documents is not None else True
)
base_conditions = [DatasetFiles.dataset_id == project.dataset_id]
base_conditions = [
LabelingProjectFile.project_id == project_id,
DatasetFiles.dataset_id == project.dataset_id,
]
if should_exclude_source_documents:
base_conditions.append(~self._build_source_document_filter())
count_result = await self.db.execute(
select(func.count()).select_from(DatasetFiles).where(
*base_conditions
)
select(func.count())
.select_from(LabelingProjectFile)
.join(DatasetFiles, LabelingProjectFile.file_id == DatasetFiles.id)
.where(*base_conditions)
)
total = int(count_result.scalar() or 0)
@@ -453,6 +457,7 @@ class AnnotationEditorService:
)
files_result = await self.db.execute(
select(DatasetFiles, AnnotationResult.id, AnnotationResult.updated_at)
.join(LabelingProjectFile, LabelingProjectFile.file_id == DatasetFiles.id)
.outerjoin(
AnnotationResult,
(AnnotationResult.file_id == DatasetFiles.id)
@@ -827,7 +832,10 @@ class AnnotationEditorService:
# 校验文件归属
file_result = await self.db.execute(
select(DatasetFiles).where(
select(DatasetFiles)
.join(LabelingProjectFile, LabelingProjectFile.file_id == DatasetFiles.id)
.where(
LabelingProjectFile.project_id == project.id,
DatasetFiles.id == file_id,
DatasetFiles.dataset_id == project.dataset_id,
)

View File

@@ -25,7 +25,7 @@ from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.logging import get_logger
from app.db.models import AnnotationResult, Dataset, DatasetFiles, LabelingProject
from app.db.models import AnnotationResult, Dataset, DatasetFiles, LabelingProject, LabelingProjectFile
async def _read_file_content(file_path: str, max_size: int = 10 * 1024 * 1024) -> Optional[str]:
@@ -75,15 +75,18 @@ class AnnotationExportService:
project = await self._get_project_or_404(project_id)
logger.info(f"Export stats for project: id={project_id}, dataset_id={project.dataset_id}, name={project.name}")
# 获取总文件数(只统计 ACTIVE 状态的文件)
# 获取总文件数(标注项目快照内的文件)
total_result = await self.db.execute(
select(func.count()).select_from(DatasetFiles).where(
select(func.count())
.select_from(LabelingProjectFile)
.join(DatasetFiles, LabelingProjectFile.file_id == DatasetFiles.id)
.where(
LabelingProjectFile.project_id == project_id,
DatasetFiles.dataset_id == project.dataset_id,
DatasetFiles.status == "ACTIVE",
)
)
total_files = int(total_result.scalar() or 0)
logger.info(f"Total files (ACTIVE): {total_files} for dataset_id={project.dataset_id}")
logger.info(f"Total files (snapshot): {total_files} for project_id={project_id}")
# 获取已标注文件数(统计不同的 file_id 数量)
annotated_result = await self.db.execute(
@@ -165,8 +168,13 @@ class AnnotationExportService:
# 只获取已标注的数据
result = await self.db.execute(
select(AnnotationResult, DatasetFiles)
.join(LabelingProjectFile, LabelingProjectFile.file_id == AnnotationResult.file_id)
.join(DatasetFiles, AnnotationResult.file_id == DatasetFiles.id)
.where(AnnotationResult.project_id == project_id)
.where(
AnnotationResult.project_id == project_id,
LabelingProjectFile.project_id == project_id,
DatasetFiles.dataset_id == dataset_id,
)
.order_by(AnnotationResult.updated_at.desc())
)
rows = result.all()
@@ -190,11 +198,13 @@ class AnnotationExportService:
)
)
else:
# 获取所有文件,包括未标注的(只获取 ACTIVE 状态的文件
# 获取所有文件(基于标注项目快照
files_result = await self.db.execute(
select(DatasetFiles).where(
select(DatasetFiles)
.join(LabelingProjectFile, LabelingProjectFile.file_id == DatasetFiles.id)
.where(
LabelingProjectFile.project_id == project_id,
DatasetFiles.dataset_id == dataset_id,
DatasetFiles.status == "ACTIVE",
)
)
files = files_result.scalars().all()

View File

@@ -1,13 +1,13 @@
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlalchemy import update, func
from sqlalchemy import update, func, insert
from sqlalchemy.orm import aliased
from typing import Optional, List, Tuple
from datetime import datetime
import uuid
from app.core.logging import get_logger
from app.db.models import LabelingProject, AnnotationTemplate, AnnotationResult
from app.db.models import LabelingProject, AnnotationTemplate, AnnotationResult, LabelingProjectFile
from app.db.models.dataset_management import Dataset, DatasetFiles
from app.module.annotation.schema import (
DatasetMappingCreateRequest,
@@ -20,9 +20,11 @@ logger = get_logger(__name__)
class DatasetMappingService:
"""数据集映射服务"""
def __init__(self, db: AsyncSession):
self.db = db
SNAPSHOT_INSERT_BATCH_SIZE = 500
def _build_query_with_dataset_name(self):
"""Build base query with dataset name joined"""
@@ -49,11 +51,14 @@ class DatasetMappingService:
Returns:
(total_count, annotated_count) 元组
"""
# 获取数据集总数据量(统计 ACTIVE 和 COMPLETED 状态的文件)
# 获取标注项目快照数据量(统计快照内的文件)
total_result = await self.db.execute(
select(func.count()).select_from(DatasetFiles).where(
select(func.count())
.select_from(LabelingProjectFile)
.join(DatasetFiles, LabelingProjectFile.file_id == DatasetFiles.id)
.where(
LabelingProjectFile.project_id == project_id,
DatasetFiles.dataset_id == dataset_id,
DatasetFiles.status.in_(["ACTIVE", "COMPLETED"]),
)
)
total_count = int(total_result.scalar() or 0)
@@ -213,6 +218,48 @@ class DatasetMappingService:
logger.debug(f"Mapping created: {labeling_project.id}")
return await self._to_response(labeling_project)
async def create_mapping_with_snapshot(
self,
labeling_project: LabelingProject,
file_ids: List[str],
) -> DatasetMappingResponse:
"""创建数据集映射并写入快照文件"""
logger.debug(
"Create dataset mapping with snapshot: %s -> %s, files=%d",
labeling_project.dataset_id,
labeling_project.labeling_project_id,
len(file_ids),
)
self.db.add(labeling_project)
await self.db.flush()
assert labeling_project.id, "labeling_project.id must be set before snapshot insert"
if file_ids:
await self._insert_snapshot_records(labeling_project.id, file_ids)
await self.db.commit()
await self.db.refresh(labeling_project)
logger.debug("Mapping created with snapshot: %s", labeling_project.id)
return await self._to_response(labeling_project)
async def _insert_snapshot_records(self, project_id: str, file_ids: List[str]) -> None:
batch: List[dict] = []
for file_id in file_ids:
batch.append(
{
"id": str(uuid.uuid4()),
"project_id": project_id,
"file_id": file_id,
}
)
if len(batch) >= self.SNAPSHOT_INSERT_BATCH_SIZE:
await self.db.execute(insert(LabelingProjectFile).values(batch))
batch.clear()
if batch:
await self.db.execute(insert(LabelingProjectFile).values(batch))
async def get_mapping_by_source_uuid(
self,