You've already forked DataMate
feat(annotation): 添加标注项目文件快照功能
- 新增 LabelingProjectFile 模型用于存储标注项目的文件快照
- 在创建标注项目时记录关联的文件快照数据
- 更新查询逻辑以基于项目快照过滤文件列表
- 优化导出统计功能使用快照数据进行计算
- 添加数据库表结构支持项目文件快照关系
This commit is contained in:
@@ -3,14 +3,16 @@ import math
|
||||
import uuid
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, Path
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db.session import get_db
|
||||
from app.db.models import LabelingProject
|
||||
from app.db.models import LabelingProject, DatasetFiles
|
||||
from app.module.shared.schema import StandardResponse, PaginatedData
|
||||
from app.module.dataset import DatasetManagementService
|
||||
from app.core.logging import get_logger
|
||||
|
||||
from app.module.annotation.service.editor import AnnotationEditorService
|
||||
from ..service.mapping import DatasetMappingService
|
||||
from ..service.template import AnnotationTemplateService
|
||||
from ..schema import (
|
||||
@@ -116,8 +118,20 @@ async def create_mapping(
|
||||
configuration=project_configuration or None,
|
||||
)
|
||||
|
||||
# 创建映射关系,包含项目名称(先持久化映射以获得 mapping.id)
|
||||
mapping = await mapping_service.create_mapping(labeling_project)
|
||||
file_query = select(DatasetFiles.id).where(
|
||||
DatasetFiles.dataset_id == request.dataset_id
|
||||
)
|
||||
if dataset_type == TEXT_DATASET_TYPE:
|
||||
file_query = file_query.where(
|
||||
~AnnotationEditorService._build_source_document_filter()
|
||||
)
|
||||
file_result = await db.execute(file_query)
|
||||
snapshot_file_ids = [str(fid) for fid in file_result.scalars().all()]
|
||||
|
||||
# 创建映射关系并写入快照
|
||||
mapping = await mapping_service.create_mapping_with_snapshot(
|
||||
labeling_project, snapshot_file_ids
|
||||
)
|
||||
|
||||
response_data = DatasetMappingCreateResponse(
|
||||
id=mapping.id,
|
||||
|
||||
@@ -23,7 +23,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.logging import get_logger
|
||||
from app.db.models import AnnotationResult, Dataset, DatasetFiles, LabelingProject
|
||||
from app.db.models import AnnotationResult, Dataset, DatasetFiles, LabelingProject, LabelingProjectFile
|
||||
from app.module.annotation.config import LabelStudioTagConfig
|
||||
from app.module.annotation.schema.editor import (
|
||||
EditorProjectInfo,
|
||||
@@ -436,14 +436,18 @@ class AnnotationEditorService:
|
||||
exclude_source_documents if exclude_source_documents is not None else True
|
||||
)
|
||||
|
||||
base_conditions = [DatasetFiles.dataset_id == project.dataset_id]
|
||||
base_conditions = [
|
||||
LabelingProjectFile.project_id == project_id,
|
||||
DatasetFiles.dataset_id == project.dataset_id,
|
||||
]
|
||||
if should_exclude_source_documents:
|
||||
base_conditions.append(~self._build_source_document_filter())
|
||||
|
||||
count_result = await self.db.execute(
|
||||
select(func.count()).select_from(DatasetFiles).where(
|
||||
*base_conditions
|
||||
)
|
||||
select(func.count())
|
||||
.select_from(LabelingProjectFile)
|
||||
.join(DatasetFiles, LabelingProjectFile.file_id == DatasetFiles.id)
|
||||
.where(*base_conditions)
|
||||
)
|
||||
total = int(count_result.scalar() or 0)
|
||||
|
||||
@@ -453,6 +457,7 @@ class AnnotationEditorService:
|
||||
)
|
||||
files_result = await self.db.execute(
|
||||
select(DatasetFiles, AnnotationResult.id, AnnotationResult.updated_at)
|
||||
.join(LabelingProjectFile, LabelingProjectFile.file_id == DatasetFiles.id)
|
||||
.outerjoin(
|
||||
AnnotationResult,
|
||||
(AnnotationResult.file_id == DatasetFiles.id)
|
||||
@@ -827,7 +832,10 @@ class AnnotationEditorService:
|
||||
|
||||
# 校验文件归属
|
||||
file_result = await self.db.execute(
|
||||
select(DatasetFiles).where(
|
||||
select(DatasetFiles)
|
||||
.join(LabelingProjectFile, LabelingProjectFile.file_id == DatasetFiles.id)
|
||||
.where(
|
||||
LabelingProjectFile.project_id == project.id,
|
||||
DatasetFiles.id == file_id,
|
||||
DatasetFiles.dataset_id == project.dataset_id,
|
||||
)
|
||||
|
||||
@@ -25,7 +25,7 @@ from sqlalchemy import func, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.logging import get_logger
|
||||
from app.db.models import AnnotationResult, Dataset, DatasetFiles, LabelingProject
|
||||
from app.db.models import AnnotationResult, Dataset, DatasetFiles, LabelingProject, LabelingProjectFile
|
||||
|
||||
|
||||
async def _read_file_content(file_path: str, max_size: int = 10 * 1024 * 1024) -> Optional[str]:
|
||||
@@ -75,15 +75,18 @@ class AnnotationExportService:
|
||||
project = await self._get_project_or_404(project_id)
|
||||
logger.info(f"Export stats for project: id={project_id}, dataset_id={project.dataset_id}, name={project.name}")
|
||||
|
||||
# 获取总文件数(只统计 ACTIVE 状态的文件)
|
||||
# 获取总文件数(标注项目快照内的文件)
|
||||
total_result = await self.db.execute(
|
||||
select(func.count()).select_from(DatasetFiles).where(
|
||||
select(func.count())
|
||||
.select_from(LabelingProjectFile)
|
||||
.join(DatasetFiles, LabelingProjectFile.file_id == DatasetFiles.id)
|
||||
.where(
|
||||
LabelingProjectFile.project_id == project_id,
|
||||
DatasetFiles.dataset_id == project.dataset_id,
|
||||
DatasetFiles.status == "ACTIVE",
|
||||
)
|
||||
)
|
||||
total_files = int(total_result.scalar() or 0)
|
||||
logger.info(f"Total files (ACTIVE): {total_files} for dataset_id={project.dataset_id}")
|
||||
logger.info(f"Total files (snapshot): {total_files} for project_id={project_id}")
|
||||
|
||||
# 获取已标注文件数(统计不同的 file_id 数量)
|
||||
annotated_result = await self.db.execute(
|
||||
@@ -165,8 +168,13 @@ class AnnotationExportService:
|
||||
# 只获取已标注的数据
|
||||
result = await self.db.execute(
|
||||
select(AnnotationResult, DatasetFiles)
|
||||
.join(LabelingProjectFile, LabelingProjectFile.file_id == AnnotationResult.file_id)
|
||||
.join(DatasetFiles, AnnotationResult.file_id == DatasetFiles.id)
|
||||
.where(AnnotationResult.project_id == project_id)
|
||||
.where(
|
||||
AnnotationResult.project_id == project_id,
|
||||
LabelingProjectFile.project_id == project_id,
|
||||
DatasetFiles.dataset_id == dataset_id,
|
||||
)
|
||||
.order_by(AnnotationResult.updated_at.desc())
|
||||
)
|
||||
rows = result.all()
|
||||
@@ -190,11 +198,13 @@ class AnnotationExportService:
|
||||
)
|
||||
)
|
||||
else:
|
||||
# 获取所有文件,包括未标注的(只获取 ACTIVE 状态的文件)
|
||||
# 获取所有文件(基于标注项目快照)
|
||||
files_result = await self.db.execute(
|
||||
select(DatasetFiles).where(
|
||||
select(DatasetFiles)
|
||||
.join(LabelingProjectFile, LabelingProjectFile.file_id == DatasetFiles.id)
|
||||
.where(
|
||||
LabelingProjectFile.project_id == project_id,
|
||||
DatasetFiles.dataset_id == dataset_id,
|
||||
DatasetFiles.status == "ACTIVE",
|
||||
)
|
||||
)
|
||||
files = files_result.scalars().all()
|
||||
|
||||
@@ -1,13 +1,13 @@
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.future import select
|
||||
from sqlalchemy import update, func
|
||||
from sqlalchemy import update, func, insert
|
||||
from sqlalchemy.orm import aliased
|
||||
from typing import Optional, List, Tuple
|
||||
from datetime import datetime
|
||||
import uuid
|
||||
|
||||
from app.core.logging import get_logger
|
||||
from app.db.models import LabelingProject, AnnotationTemplate, AnnotationResult
|
||||
from app.db.models import LabelingProject, AnnotationTemplate, AnnotationResult, LabelingProjectFile
|
||||
from app.db.models.dataset_management import Dataset, DatasetFiles
|
||||
from app.module.annotation.schema import (
|
||||
DatasetMappingCreateRequest,
|
||||
@@ -20,9 +20,11 @@ logger = get_logger(__name__)
|
||||
|
||||
class DatasetMappingService:
|
||||
"""数据集映射服务"""
|
||||
|
||||
|
||||
def __init__(self, db: AsyncSession):
|
||||
self.db = db
|
||||
|
||||
SNAPSHOT_INSERT_BATCH_SIZE = 500
|
||||
|
||||
def _build_query_with_dataset_name(self):
|
||||
"""Build base query with dataset name joined"""
|
||||
@@ -49,11 +51,14 @@ class DatasetMappingService:
|
||||
Returns:
|
||||
(total_count, annotated_count) 元组
|
||||
"""
|
||||
# 获取数据集总数据量(统计 ACTIVE 和 COMPLETED 状态的文件)
|
||||
# 获取标注项目快照数据量(只统计快照内的文件)
|
||||
total_result = await self.db.execute(
|
||||
select(func.count()).select_from(DatasetFiles).where(
|
||||
select(func.count())
|
||||
.select_from(LabelingProjectFile)
|
||||
.join(DatasetFiles, LabelingProjectFile.file_id == DatasetFiles.id)
|
||||
.where(
|
||||
LabelingProjectFile.project_id == project_id,
|
||||
DatasetFiles.dataset_id == dataset_id,
|
||||
DatasetFiles.status.in_(["ACTIVE", "COMPLETED"]),
|
||||
)
|
||||
)
|
||||
total_count = int(total_result.scalar() or 0)
|
||||
@@ -213,6 +218,48 @@ class DatasetMappingService:
|
||||
|
||||
logger.debug(f"Mapping created: {labeling_project.id}")
|
||||
return await self._to_response(labeling_project)
|
||||
|
||||
async def create_mapping_with_snapshot(
    self,
    labeling_project: LabelingProject,
    file_ids: List[str],
) -> DatasetMappingResponse:
    """Create a dataset mapping and write its file snapshot rows.

    The project is added to the session and flushed so that its ``id`` is
    available, then the snapshot rows for ``file_ids`` are inserted, and
    finally the project and the snapshot rows are committed together.

    Args:
        labeling_project: The labeling project entity to persist.
        file_ids: IDs of the files to record in the project snapshot. When
            empty, the project is created without any snapshot rows.

    Returns:
        The response produced by ``self._to_response`` for the refreshed
        project.

    Raises:
        RuntimeError: If the project still has no ``id`` after the flush.
    """
    logger.debug(
        "Create dataset mapping with snapshot: %s -> %s, files=%d",
        labeling_project.dataset_id,
        labeling_project.labeling_project_id,
        len(file_ids),
    )

    try:
        self.db.add(labeling_project)
        await self.db.flush()
        # Explicit check rather than ``assert``: assertions are removed when
        # Python runs with ``-O``, and snapshot rows must never be written
        # without a project id.
        if not labeling_project.id:
            raise RuntimeError(
                "labeling_project.id must be set before snapshot insert"
            )

        if file_ids:
            await self._insert_snapshot_records(labeling_project.id, file_ids)

        await self.db.commit()
    except Exception:
        # Drop the flushed project row and any partially inserted snapshot
        # batches so the session is not left holding a half-written mapping.
        await self.db.rollback()
        raise

    await self.db.refresh(labeling_project)

    logger.debug("Mapping created with snapshot: %s", labeling_project.id)
    return await self._to_response(labeling_project)
|
||||
|
||||
async def _insert_snapshot_records(self, project_id: str, file_ids: List[str]) -> None:
    """Insert one snapshot row per file ID for the given project.

    The IDs are processed in chunks of at most
    ``SNAPSHOT_INSERT_BATCH_SIZE`` entries; each chunk is sent as a single
    ``insert(LabelingProjectFile).values(...)`` statement. This method only
    executes the statements on ``self.db``; it does not commit.

    Args:
        project_id: ID of the labeling project that owns the snapshot rows.
        file_ids: IDs of the files to record, in insertion order.
    """
    chunk_size = self.SNAPSHOT_INSERT_BATCH_SIZE
    for start in range(0, len(file_ids), chunk_size):
        rows = [
            {
                "id": str(uuid.uuid4()),  # fresh UUID for every snapshot row
                "project_id": project_id,
                "file_id": snapshot_file_id,
            }
            for snapshot_file_id in file_ids[start:start + chunk_size]
        ]
        await self.db.execute(insert(LabelingProjectFile).values(rows))
|
||||
|
||||
async def get_mapping_by_source_uuid(
|
||||
self,
|
||||
|
||||
Reference in New Issue
Block a user