feat(annotation): 添加标注项目文件快照功能

- 新增 LabelingProjectFile 模型用于存储标注项目的文件快照
- 在创建标注项目时记录关联的文件快照数据
- 更新查询逻辑以基于项目快照过滤文件列表
- 优化导出统计功能使用快照数据进行计算
- 添加数据库表结构支持项目文件快照关系
This commit is contained in:
2026-01-30 18:10:13 +08:00
parent 3c3ca130b3
commit 8b2a19f09a
7 changed files with 145 additions and 33 deletions

View File

@@ -14,7 +14,8 @@ from .user_management import (
from .annotation_management import (
AnnotationTemplate,
LabelingProject,
AnnotationResult
AnnotationResult,
LabelingProjectFile
)
from .data_evaluation import (
@@ -32,6 +33,7 @@ __all__ = [
"AnnotationTemplate",
"LabelingProject",
"AnnotationResult",
"LabelingProjectFile",
"EvaluationTask",
"EvaluationItem",
]

View File

@@ -1,7 +1,7 @@
"""Tables of Annotation Management Module"""
import uuid
from sqlalchemy import Column, String, Boolean, TIMESTAMP, Text, Integer, JSON, ForeignKey
import uuid
from sqlalchemy import Column, String, Boolean, TIMESTAMP, Text, Integer, JSON, ForeignKey, UniqueConstraint, Index
from sqlalchemy.sql import func
from app.db.session import Base
@@ -34,7 +34,7 @@ class AnnotationTemplate(Base):
"""检查是否已被软删除"""
return self.deleted_at is not None
class LabelingProject(Base):
class LabelingProject(Base):
"""标注项目模型"""
__tablename__ = "t_dm_labeling_projects"
@@ -50,13 +50,33 @@ class LabelingProject(Base):
updated_at = Column(TIMESTAMP, server_default=func.current_timestamp(), onupdate=func.current_timestamp(), comment="更新时间")
deleted_at = Column(TIMESTAMP, nullable=True, comment="删除时间(软删除)")
def __repr__(self):
return f"<LabelingProject(id={self.id}, name={self.name}, dataset_id={self.dataset_id})>"
def __repr__(self):
return f"<LabelingProject(id={self.id}, name={self.name}, dataset_id={self.dataset_id})>"
@property
def is_deleted(self) -> bool:
"""检查是否已被软删除"""
return self.deleted_at is not None
def is_deleted(self) -> bool:
"""检查是否已被软删除"""
return self.deleted_at is not None
class LabelingProjectFile(Base):
    """Snapshot model recording which files belong to a labeling project.

    Each row pairs one ``project_id`` with one ``file_id``. The pair is
    unique (``uk_project_file``), and each column also has its own index.
    Neither column declares a foreign-key constraint in this model.
    """

    __tablename__ = "t_dm_labeling_project_files"

    # String column names are resolved when the table is built, so the
    # constraint/index declarations can precede the column definitions.
    __table_args__ = (
        UniqueConstraint("project_id", "file_id", name="uk_project_file"),
        Index("idx_project_id", "project_id"),
        Index("idx_file_id", "file_id"),
    )

    # Primary key: a UUID4 string generated on the Python side at insert time.
    id = Column(
        String(36),
        primary_key=True,
        default=lambda: str(uuid.uuid4()),
        comment="UUID",
    )
    # ID of the labeling project that owns this snapshot entry.
    project_id = Column(
        String(36),
        nullable=False,
        comment="标注项目ID",
    )
    # ID of the file captured in the snapshot.
    file_id = Column(
        String(36),
        nullable=False,
        comment="文件ID",
    )
    # Creation time, filled in by the database server default.
    created_at = Column(
        TIMESTAMP,
        server_default=func.current_timestamp(),
        comment="创建时间",
    )

    def __repr__(self):
        """Return a debug representation with the row, project and file IDs."""
        return f"<LabelingProjectFile(id={self.id}, project_id={self.project_id}, file_id={self.file_id})>"
class AnnotationResult(Base):

View File

@@ -3,14 +3,16 @@ import math
import uuid
from fastapi import APIRouter, Depends, HTTPException, Query, Path
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.db.session import get_db
from app.db.models import LabelingProject
from app.db.models import LabelingProject, DatasetFiles
from app.module.shared.schema import StandardResponse, PaginatedData
from app.module.dataset import DatasetManagementService
from app.core.logging import get_logger
from app.module.annotation.service.editor import AnnotationEditorService
from ..service.mapping import DatasetMappingService
from ..service.template import AnnotationTemplateService
from ..schema import (
@@ -116,8 +118,20 @@ async def create_mapping(
configuration=project_configuration or None,
)
# 创建映射关系,包含项目名称(先持久化映射以获得 mapping.id)
mapping = await mapping_service.create_mapping(labeling_project)
file_query = select(DatasetFiles.id).where(
DatasetFiles.dataset_id == request.dataset_id
)
if dataset_type == TEXT_DATASET_TYPE:
file_query = file_query.where(
~AnnotationEditorService._build_source_document_filter()
)
file_result = await db.execute(file_query)
snapshot_file_ids = [str(fid) for fid in file_result.scalars().all()]
# 创建映射关系并写入快照
mapping = await mapping_service.create_mapping_with_snapshot(
labeling_project, snapshot_file_ids
)
response_data = DatasetMappingCreateResponse(
id=mapping.id,

View File

@@ -23,7 +23,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
from app.core.config import settings
from app.core.logging import get_logger
from app.db.models import AnnotationResult, Dataset, DatasetFiles, LabelingProject
from app.db.models import AnnotationResult, Dataset, DatasetFiles, LabelingProject, LabelingProjectFile
from app.module.annotation.config import LabelStudioTagConfig
from app.module.annotation.schema.editor import (
EditorProjectInfo,
@@ -436,14 +436,18 @@ class AnnotationEditorService:
exclude_source_documents if exclude_source_documents is not None else True
)
base_conditions = [DatasetFiles.dataset_id == project.dataset_id]
base_conditions = [
LabelingProjectFile.project_id == project_id,
DatasetFiles.dataset_id == project.dataset_id,
]
if should_exclude_source_documents:
base_conditions.append(~self._build_source_document_filter())
count_result = await self.db.execute(
select(func.count()).select_from(DatasetFiles).where(
*base_conditions
)
select(func.count())
.select_from(LabelingProjectFile)
.join(DatasetFiles, LabelingProjectFile.file_id == DatasetFiles.id)
.where(*base_conditions)
)
total = int(count_result.scalar() or 0)
@@ -453,6 +457,7 @@ class AnnotationEditorService:
)
files_result = await self.db.execute(
select(DatasetFiles, AnnotationResult.id, AnnotationResult.updated_at)
.join(LabelingProjectFile, LabelingProjectFile.file_id == DatasetFiles.id)
.outerjoin(
AnnotationResult,
(AnnotationResult.file_id == DatasetFiles.id)
@@ -827,7 +832,10 @@ class AnnotationEditorService:
# 校验文件归属
file_result = await self.db.execute(
select(DatasetFiles).where(
select(DatasetFiles)
.join(LabelingProjectFile, LabelingProjectFile.file_id == DatasetFiles.id)
.where(
LabelingProjectFile.project_id == project.id,
DatasetFiles.id == file_id,
DatasetFiles.dataset_id == project.dataset_id,
)

View File

@@ -25,7 +25,7 @@ from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.logging import get_logger
from app.db.models import AnnotationResult, Dataset, DatasetFiles, LabelingProject
from app.db.models import AnnotationResult, Dataset, DatasetFiles, LabelingProject, LabelingProjectFile
async def _read_file_content(file_path: str, max_size: int = 10 * 1024 * 1024) -> Optional[str]:
@@ -75,15 +75,18 @@ class AnnotationExportService:
project = await self._get_project_or_404(project_id)
logger.info(f"Export stats for project: id={project_id}, dataset_id={project.dataset_id}, name={project.name}")
# 获取总文件数(只统计 ACTIVE 状态的文件)
# 获取总文件数(标注项目快照内的文件)
total_result = await self.db.execute(
select(func.count()).select_from(DatasetFiles).where(
select(func.count())
.select_from(LabelingProjectFile)
.join(DatasetFiles, LabelingProjectFile.file_id == DatasetFiles.id)
.where(
LabelingProjectFile.project_id == project_id,
DatasetFiles.dataset_id == project.dataset_id,
DatasetFiles.status == "ACTIVE",
)
)
total_files = int(total_result.scalar() or 0)
logger.info(f"Total files (ACTIVE): {total_files} for dataset_id={project.dataset_id}")
logger.info(f"Total files (snapshot): {total_files} for project_id={project_id}")
# 获取已标注文件数(统计不同的 file_id 数量)
annotated_result = await self.db.execute(
@@ -165,8 +168,13 @@ class AnnotationExportService:
# 只获取已标注的数据
result = await self.db.execute(
select(AnnotationResult, DatasetFiles)
.join(LabelingProjectFile, LabelingProjectFile.file_id == AnnotationResult.file_id)
.join(DatasetFiles, AnnotationResult.file_id == DatasetFiles.id)
.where(AnnotationResult.project_id == project_id)
.where(
AnnotationResult.project_id == project_id,
LabelingProjectFile.project_id == project_id,
DatasetFiles.dataset_id == dataset_id,
)
.order_by(AnnotationResult.updated_at.desc())
)
rows = result.all()
@@ -190,11 +198,13 @@ class AnnotationExportService:
)
)
else:
# 获取所有文件,包括未标注的(只获取 ACTIVE 状态的文件)
# 获取所有文件(基于标注项目快照)
files_result = await self.db.execute(
select(DatasetFiles).where(
select(DatasetFiles)
.join(LabelingProjectFile, LabelingProjectFile.file_id == DatasetFiles.id)
.where(
LabelingProjectFile.project_id == project_id,
DatasetFiles.dataset_id == dataset_id,
DatasetFiles.status == "ACTIVE",
)
)
files = files_result.scalars().all()

View File

@@ -1,13 +1,13 @@
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlalchemy import update, func
from sqlalchemy import update, func, insert
from sqlalchemy.orm import aliased
from typing import Optional, List, Tuple
from datetime import datetime
import uuid
from app.core.logging import get_logger
from app.db.models import LabelingProject, AnnotationTemplate, AnnotationResult
from app.db.models import LabelingProject, AnnotationTemplate, AnnotationResult, LabelingProjectFile
from app.db.models.dataset_management import Dataset, DatasetFiles
from app.module.annotation.schema import (
DatasetMappingCreateRequest,
@@ -20,9 +20,11 @@ logger = get_logger(__name__)
class DatasetMappingService:
"""数据集映射服务"""
def __init__(self, db: AsyncSession):
self.db = db
SNAPSHOT_INSERT_BATCH_SIZE = 500
def _build_query_with_dataset_name(self):
"""Build base query with dataset name joined"""
@@ -49,11 +51,14 @@ class DatasetMappingService:
Returns:
(total_count, annotated_count) 元组
"""
# 获取数据集总数据量(统计 ACTIVE 和 COMPLETED 状态的文件)
# 获取标注项目快照数据量(统计快照内的文件)
total_result = await self.db.execute(
select(func.count()).select_from(DatasetFiles).where(
select(func.count())
.select_from(LabelingProjectFile)
.join(DatasetFiles, LabelingProjectFile.file_id == DatasetFiles.id)
.where(
LabelingProjectFile.project_id == project_id,
DatasetFiles.dataset_id == dataset_id,
DatasetFiles.status.in_(["ACTIVE", "COMPLETED"]),
)
)
total_count = int(total_result.scalar() or 0)
@@ -213,6 +218,48 @@ class DatasetMappingService:
logger.debug(f"Mapping created: {labeling_project.id}")
return await self._to_response(labeling_project)
async def create_mapping_with_snapshot(
    self,
    labeling_project: LabelingProject,
    file_ids: List[str],
) -> DatasetMappingResponse:
    """Create a dataset mapping and write its file snapshot in one transaction.

    The project is added and flushed first so that its ``id`` is available,
    then the snapshot rows are inserted, and finally everything is committed
    together.

    Args:
        labeling_project: The not-yet-persisted labeling project to store.
        file_ids: IDs of the files to record in the project's snapshot.
            An empty list creates the project without snapshot rows.

    Returns:
        The response built by ``_to_response`` for the refreshed project.

    Raises:
        RuntimeError: If the project still has no ``id`` after the flush.
        Exception: Any error raised by the flush, insert or commit is
            re-raised after the session has been rolled back.
    """
    logger.debug(
        "Create dataset mapping with snapshot: %s -> %s, files=%d",
        labeling_project.dataset_id,
        labeling_project.labeling_project_id,
        len(file_ids),
    )

    try:
        self.db.add(labeling_project)
        await self.db.flush()

        # Explicit check instead of `assert`: asserts are removed when
        # Python runs with -O, which would let a missing id slip through.
        if not labeling_project.id:
            raise RuntimeError(
                "labeling_project.id must be set before snapshot insert"
            )

        if file_ids:
            await self._insert_snapshot_records(labeling_project.id, file_ids)

        await self.db.commit()
    except Exception:
        # Do not leave a flushed project (or partial snapshot batches)
        # pending in the session when any step fails.
        await self.db.rollback()
        raise

    await self.db.refresh(labeling_project)
    logger.debug("Mapping created with snapshot: %s", labeling_project.id)
    return await self._to_response(labeling_project)
async def _insert_snapshot_records(self, project_id: str, file_ids: List[str]) -> None:
    """Insert snapshot rows for ``project_id`` in fixed-size batches.

    Args:
        project_id: ID of the labeling project the snapshot belongs to.
        file_ids: File IDs to record. Duplicates are dropped (first
            occurrence wins) because the table enforces a unique
            ``(project_id, file_id)`` pair and a repeated ID would make
            the whole insert fail.

    The caller is responsible for committing; this method only executes
    the INSERT statements on ``self.db``.
    """
    # dict.fromkeys keeps insertion order while removing repeated IDs.
    unique_file_ids = list(dict.fromkeys(file_ids))

    # Guard against a non-positive batch size, which would break range().
    batch_size = max(1, self.SNAPSHOT_INSERT_BATCH_SIZE)

    for start in range(0, len(unique_file_ids), batch_size):
        rows = [
            {
                "id": str(uuid.uuid4()),
                "project_id": project_id,
                "file_id": file_id,
            }
            for file_id in unique_file_ids[start:start + batch_size]
        ]
        await self.db.execute(insert(LabelingProjectFile).values(rows))
async def get_mapping_by_source_uuid(
self,