You've already forked DataMate
feat(annotation): 添加标注项目文件快照功能
- 新增 LabelingProjectFile 模型用于存储标注项目的文件快照 - 在创建标注项目时记录关联的文件快照数据 - 更新查询逻辑以基于项目快照过滤文件列表 - 优化导出统计功能使用快照数据进行计算 - 添加数据库表结构支持项目文件快照关系
This commit is contained in:
@@ -14,7 +14,8 @@ from .user_management import (
|
|||||||
from .annotation_management import (
|
from .annotation_management import (
|
||||||
AnnotationTemplate,
|
AnnotationTemplate,
|
||||||
LabelingProject,
|
LabelingProject,
|
||||||
AnnotationResult
|
AnnotationResult,
|
||||||
|
LabelingProjectFile
|
||||||
)
|
)
|
||||||
|
|
||||||
from .data_evaluation import (
|
from .data_evaluation import (
|
||||||
@@ -32,6 +33,7 @@ __all__ = [
|
|||||||
"AnnotationTemplate",
|
"AnnotationTemplate",
|
||||||
"LabelingProject",
|
"LabelingProject",
|
||||||
"AnnotationResult",
|
"AnnotationResult",
|
||||||
|
"LabelingProjectFile",
|
||||||
"EvaluationTask",
|
"EvaluationTask",
|
||||||
"EvaluationItem",
|
"EvaluationItem",
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
"""Tables of Annotation Management Module"""
|
"""Tables of Annotation Management Module"""
|
||||||
|
|
||||||
import uuid
|
import uuid
|
||||||
from sqlalchemy import Column, String, Boolean, TIMESTAMP, Text, Integer, JSON, ForeignKey
|
from sqlalchemy import Column, String, Boolean, TIMESTAMP, Text, Integer, JSON, ForeignKey, UniqueConstraint, Index
|
||||||
from sqlalchemy.sql import func
|
from sqlalchemy.sql import func
|
||||||
|
|
||||||
from app.db.session import Base
|
from app.db.session import Base
|
||||||
@@ -34,7 +34,7 @@ class AnnotationTemplate(Base):
|
|||||||
"""检查是否已被软删除"""
|
"""检查是否已被软删除"""
|
||||||
return self.deleted_at is not None
|
return self.deleted_at is not None
|
||||||
|
|
||||||
class LabelingProject(Base):
|
class LabelingProject(Base):
|
||||||
"""标注项目模型"""
|
"""标注项目模型"""
|
||||||
|
|
||||||
__tablename__ = "t_dm_labeling_projects"
|
__tablename__ = "t_dm_labeling_projects"
|
||||||
@@ -50,13 +50,33 @@ class LabelingProject(Base):
|
|||||||
updated_at = Column(TIMESTAMP, server_default=func.current_timestamp(), onupdate=func.current_timestamp(), comment="更新时间")
|
updated_at = Column(TIMESTAMP, server_default=func.current_timestamp(), onupdate=func.current_timestamp(), comment="更新时间")
|
||||||
deleted_at = Column(TIMESTAMP, nullable=True, comment="删除时间(软删除)")
|
deleted_at = Column(TIMESTAMP, nullable=True, comment="删除时间(软删除)")
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return f"<LabelingProject(id={self.id}, name={self.name}, dataset_id={self.dataset_id})>"
|
return f"<LabelingProject(id={self.id}, name={self.name}, dataset_id={self.dataset_id})>"
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def is_deleted(self) -> bool:
|
def is_deleted(self) -> bool:
|
||||||
"""检查是否已被软删除"""
|
"""检查是否已被软删除"""
|
||||||
return self.deleted_at is not None
|
return self.deleted_at is not None
|
||||||
|
|
||||||
|
|
||||||
|
class LabelingProjectFile(Base):
|
||||||
|
"""标注项目文件快照模型"""
|
||||||
|
|
||||||
|
__tablename__ = "t_dm_labeling_project_files"
|
||||||
|
|
||||||
|
id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()), comment="UUID")
|
||||||
|
project_id = Column(String(36), nullable=False, comment="标注项目ID")
|
||||||
|
file_id = Column(String(36), nullable=False, comment="文件ID")
|
||||||
|
created_at = Column(TIMESTAMP, server_default=func.current_timestamp(), comment="创建时间")
|
||||||
|
|
||||||
|
__table_args__ = (
|
||||||
|
UniqueConstraint("project_id", "file_id", name="uk_project_file"),
|
||||||
|
Index("idx_project_id", "project_id"),
|
||||||
|
Index("idx_file_id", "file_id"),
|
||||||
|
)
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return f"<LabelingProjectFile(id={self.id}, project_id={self.project_id}, file_id={self.file_id})>"
|
||||||
|
|
||||||
|
|
||||||
class AnnotationResult(Base):
|
class AnnotationResult(Base):
|
||||||
|
|||||||
@@ -3,14 +3,16 @@ import math
|
|||||||
import uuid
|
import uuid
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, Query, Path
|
from fastapi import APIRouter, Depends, HTTPException, Query, Path
|
||||||
|
from sqlalchemy import select
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from app.db.session import get_db
|
from app.db.session import get_db
|
||||||
from app.db.models import LabelingProject
|
from app.db.models import LabelingProject, DatasetFiles
|
||||||
from app.module.shared.schema import StandardResponse, PaginatedData
|
from app.module.shared.schema import StandardResponse, PaginatedData
|
||||||
from app.module.dataset import DatasetManagementService
|
from app.module.dataset import DatasetManagementService
|
||||||
from app.core.logging import get_logger
|
from app.core.logging import get_logger
|
||||||
|
|
||||||
|
from app.module.annotation.service.editor import AnnotationEditorService
|
||||||
from ..service.mapping import DatasetMappingService
|
from ..service.mapping import DatasetMappingService
|
||||||
from ..service.template import AnnotationTemplateService
|
from ..service.template import AnnotationTemplateService
|
||||||
from ..schema import (
|
from ..schema import (
|
||||||
@@ -116,8 +118,20 @@ async def create_mapping(
|
|||||||
configuration=project_configuration or None,
|
configuration=project_configuration or None,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 创建映射关系,包含项目名称(先持久化映射以获得 mapping.id)
|
file_query = select(DatasetFiles.id).where(
|
||||||
mapping = await mapping_service.create_mapping(labeling_project)
|
DatasetFiles.dataset_id == request.dataset_id
|
||||||
|
)
|
||||||
|
if dataset_type == TEXT_DATASET_TYPE:
|
||||||
|
file_query = file_query.where(
|
||||||
|
~AnnotationEditorService._build_source_document_filter()
|
||||||
|
)
|
||||||
|
file_result = await db.execute(file_query)
|
||||||
|
snapshot_file_ids = [str(fid) for fid in file_result.scalars().all()]
|
||||||
|
|
||||||
|
# 创建映射关系并写入快照
|
||||||
|
mapping = await mapping_service.create_mapping_with_snapshot(
|
||||||
|
labeling_project, snapshot_file_ids
|
||||||
|
)
|
||||||
|
|
||||||
response_data = DatasetMappingCreateResponse(
|
response_data = DatasetMappingCreateResponse(
|
||||||
id=mapping.id,
|
id=mapping.id,
|
||||||
|
|||||||
@@ -23,7 +23,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
|||||||
|
|
||||||
from app.core.config import settings
|
from app.core.config import settings
|
||||||
from app.core.logging import get_logger
|
from app.core.logging import get_logger
|
||||||
from app.db.models import AnnotationResult, Dataset, DatasetFiles, LabelingProject
|
from app.db.models import AnnotationResult, Dataset, DatasetFiles, LabelingProject, LabelingProjectFile
|
||||||
from app.module.annotation.config import LabelStudioTagConfig
|
from app.module.annotation.config import LabelStudioTagConfig
|
||||||
from app.module.annotation.schema.editor import (
|
from app.module.annotation.schema.editor import (
|
||||||
EditorProjectInfo,
|
EditorProjectInfo,
|
||||||
@@ -436,14 +436,18 @@ class AnnotationEditorService:
|
|||||||
exclude_source_documents if exclude_source_documents is not None else True
|
exclude_source_documents if exclude_source_documents is not None else True
|
||||||
)
|
)
|
||||||
|
|
||||||
base_conditions = [DatasetFiles.dataset_id == project.dataset_id]
|
base_conditions = [
|
||||||
|
LabelingProjectFile.project_id == project_id,
|
||||||
|
DatasetFiles.dataset_id == project.dataset_id,
|
||||||
|
]
|
||||||
if should_exclude_source_documents:
|
if should_exclude_source_documents:
|
||||||
base_conditions.append(~self._build_source_document_filter())
|
base_conditions.append(~self._build_source_document_filter())
|
||||||
|
|
||||||
count_result = await self.db.execute(
|
count_result = await self.db.execute(
|
||||||
select(func.count()).select_from(DatasetFiles).where(
|
select(func.count())
|
||||||
*base_conditions
|
.select_from(LabelingProjectFile)
|
||||||
)
|
.join(DatasetFiles, LabelingProjectFile.file_id == DatasetFiles.id)
|
||||||
|
.where(*base_conditions)
|
||||||
)
|
)
|
||||||
total = int(count_result.scalar() or 0)
|
total = int(count_result.scalar() or 0)
|
||||||
|
|
||||||
@@ -453,6 +457,7 @@ class AnnotationEditorService:
|
|||||||
)
|
)
|
||||||
files_result = await self.db.execute(
|
files_result = await self.db.execute(
|
||||||
select(DatasetFiles, AnnotationResult.id, AnnotationResult.updated_at)
|
select(DatasetFiles, AnnotationResult.id, AnnotationResult.updated_at)
|
||||||
|
.join(LabelingProjectFile, LabelingProjectFile.file_id == DatasetFiles.id)
|
||||||
.outerjoin(
|
.outerjoin(
|
||||||
AnnotationResult,
|
AnnotationResult,
|
||||||
(AnnotationResult.file_id == DatasetFiles.id)
|
(AnnotationResult.file_id == DatasetFiles.id)
|
||||||
@@ -827,7 +832,10 @@ class AnnotationEditorService:
|
|||||||
|
|
||||||
# 校验文件归属
|
# 校验文件归属
|
||||||
file_result = await self.db.execute(
|
file_result = await self.db.execute(
|
||||||
select(DatasetFiles).where(
|
select(DatasetFiles)
|
||||||
|
.join(LabelingProjectFile, LabelingProjectFile.file_id == DatasetFiles.id)
|
||||||
|
.where(
|
||||||
|
LabelingProjectFile.project_id == project.id,
|
||||||
DatasetFiles.id == file_id,
|
DatasetFiles.id == file_id,
|
||||||
DatasetFiles.dataset_id == project.dataset_id,
|
DatasetFiles.dataset_id == project.dataset_id,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -25,7 +25,7 @@ from sqlalchemy import func, select
|
|||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from app.core.logging import get_logger
|
from app.core.logging import get_logger
|
||||||
from app.db.models import AnnotationResult, Dataset, DatasetFiles, LabelingProject
|
from app.db.models import AnnotationResult, Dataset, DatasetFiles, LabelingProject, LabelingProjectFile
|
||||||
|
|
||||||
|
|
||||||
async def _read_file_content(file_path: str, max_size: int = 10 * 1024 * 1024) -> Optional[str]:
|
async def _read_file_content(file_path: str, max_size: int = 10 * 1024 * 1024) -> Optional[str]:
|
||||||
@@ -75,15 +75,18 @@ class AnnotationExportService:
|
|||||||
project = await self._get_project_or_404(project_id)
|
project = await self._get_project_or_404(project_id)
|
||||||
logger.info(f"Export stats for project: id={project_id}, dataset_id={project.dataset_id}, name={project.name}")
|
logger.info(f"Export stats for project: id={project_id}, dataset_id={project.dataset_id}, name={project.name}")
|
||||||
|
|
||||||
# 获取总文件数(只统计 ACTIVE 状态的文件)
|
# 获取总文件数(标注项目快照内的文件)
|
||||||
total_result = await self.db.execute(
|
total_result = await self.db.execute(
|
||||||
select(func.count()).select_from(DatasetFiles).where(
|
select(func.count())
|
||||||
|
.select_from(LabelingProjectFile)
|
||||||
|
.join(DatasetFiles, LabelingProjectFile.file_id == DatasetFiles.id)
|
||||||
|
.where(
|
||||||
|
LabelingProjectFile.project_id == project_id,
|
||||||
DatasetFiles.dataset_id == project.dataset_id,
|
DatasetFiles.dataset_id == project.dataset_id,
|
||||||
DatasetFiles.status == "ACTIVE",
|
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
total_files = int(total_result.scalar() or 0)
|
total_files = int(total_result.scalar() or 0)
|
||||||
logger.info(f"Total files (ACTIVE): {total_files} for dataset_id={project.dataset_id}")
|
logger.info(f"Total files (snapshot): {total_files} for project_id={project_id}")
|
||||||
|
|
||||||
# 获取已标注文件数(统计不同的 file_id 数量)
|
# 获取已标注文件数(统计不同的 file_id 数量)
|
||||||
annotated_result = await self.db.execute(
|
annotated_result = await self.db.execute(
|
||||||
@@ -165,8 +168,13 @@ class AnnotationExportService:
|
|||||||
# 只获取已标注的数据
|
# 只获取已标注的数据
|
||||||
result = await self.db.execute(
|
result = await self.db.execute(
|
||||||
select(AnnotationResult, DatasetFiles)
|
select(AnnotationResult, DatasetFiles)
|
||||||
|
.join(LabelingProjectFile, LabelingProjectFile.file_id == AnnotationResult.file_id)
|
||||||
.join(DatasetFiles, AnnotationResult.file_id == DatasetFiles.id)
|
.join(DatasetFiles, AnnotationResult.file_id == DatasetFiles.id)
|
||||||
.where(AnnotationResult.project_id == project_id)
|
.where(
|
||||||
|
AnnotationResult.project_id == project_id,
|
||||||
|
LabelingProjectFile.project_id == project_id,
|
||||||
|
DatasetFiles.dataset_id == dataset_id,
|
||||||
|
)
|
||||||
.order_by(AnnotationResult.updated_at.desc())
|
.order_by(AnnotationResult.updated_at.desc())
|
||||||
)
|
)
|
||||||
rows = result.all()
|
rows = result.all()
|
||||||
@@ -190,11 +198,13 @@ class AnnotationExportService:
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# 获取所有文件,包括未标注的(只获取 ACTIVE 状态的文件)
|
# 获取所有文件(基于标注项目快照)
|
||||||
files_result = await self.db.execute(
|
files_result = await self.db.execute(
|
||||||
select(DatasetFiles).where(
|
select(DatasetFiles)
|
||||||
|
.join(LabelingProjectFile, LabelingProjectFile.file_id == DatasetFiles.id)
|
||||||
|
.where(
|
||||||
|
LabelingProjectFile.project_id == project_id,
|
||||||
DatasetFiles.dataset_id == dataset_id,
|
DatasetFiles.dataset_id == dataset_id,
|
||||||
DatasetFiles.status == "ACTIVE",
|
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
files = files_result.scalars().all()
|
files = files_result.scalars().all()
|
||||||
|
|||||||
@@ -1,13 +1,13 @@
|
|||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
from sqlalchemy.future import select
|
from sqlalchemy.future import select
|
||||||
from sqlalchemy import update, func
|
from sqlalchemy import update, func, insert
|
||||||
from sqlalchemy.orm import aliased
|
from sqlalchemy.orm import aliased
|
||||||
from typing import Optional, List, Tuple
|
from typing import Optional, List, Tuple
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
import uuid
|
import uuid
|
||||||
|
|
||||||
from app.core.logging import get_logger
|
from app.core.logging import get_logger
|
||||||
from app.db.models import LabelingProject, AnnotationTemplate, AnnotationResult
|
from app.db.models import LabelingProject, AnnotationTemplate, AnnotationResult, LabelingProjectFile
|
||||||
from app.db.models.dataset_management import Dataset, DatasetFiles
|
from app.db.models.dataset_management import Dataset, DatasetFiles
|
||||||
from app.module.annotation.schema import (
|
from app.module.annotation.schema import (
|
||||||
DatasetMappingCreateRequest,
|
DatasetMappingCreateRequest,
|
||||||
@@ -20,9 +20,11 @@ logger = get_logger(__name__)
|
|||||||
|
|
||||||
class DatasetMappingService:
|
class DatasetMappingService:
|
||||||
"""数据集映射服务"""
|
"""数据集映射服务"""
|
||||||
|
|
||||||
def __init__(self, db: AsyncSession):
|
def __init__(self, db: AsyncSession):
|
||||||
self.db = db
|
self.db = db
|
||||||
|
|
||||||
|
SNAPSHOT_INSERT_BATCH_SIZE = 500
|
||||||
|
|
||||||
def _build_query_with_dataset_name(self):
|
def _build_query_with_dataset_name(self):
|
||||||
"""Build base query with dataset name joined"""
|
"""Build base query with dataset name joined"""
|
||||||
@@ -49,11 +51,14 @@ class DatasetMappingService:
|
|||||||
Returns:
|
Returns:
|
||||||
(total_count, annotated_count) 元组
|
(total_count, annotated_count) 元组
|
||||||
"""
|
"""
|
||||||
# 获取数据集总数据量(统计 ACTIVE 和 COMPLETED 状态的文件)
|
# 获取标注项目快照数据量(只统计快照内的文件)
|
||||||
total_result = await self.db.execute(
|
total_result = await self.db.execute(
|
||||||
select(func.count()).select_from(DatasetFiles).where(
|
select(func.count())
|
||||||
|
.select_from(LabelingProjectFile)
|
||||||
|
.join(DatasetFiles, LabelingProjectFile.file_id == DatasetFiles.id)
|
||||||
|
.where(
|
||||||
|
LabelingProjectFile.project_id == project_id,
|
||||||
DatasetFiles.dataset_id == dataset_id,
|
DatasetFiles.dataset_id == dataset_id,
|
||||||
DatasetFiles.status.in_(["ACTIVE", "COMPLETED"]),
|
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
total_count = int(total_result.scalar() or 0)
|
total_count = int(total_result.scalar() or 0)
|
||||||
@@ -213,6 +218,48 @@ class DatasetMappingService:
|
|||||||
|
|
||||||
logger.debug(f"Mapping created: {labeling_project.id}")
|
logger.debug(f"Mapping created: {labeling_project.id}")
|
||||||
return await self._to_response(labeling_project)
|
return await self._to_response(labeling_project)
|
||||||
|
|
||||||
|
async def create_mapping_with_snapshot(
|
||||||
|
self,
|
||||||
|
labeling_project: LabelingProject,
|
||||||
|
file_ids: List[str],
|
||||||
|
) -> DatasetMappingResponse:
|
||||||
|
"""创建数据集映射并写入快照文件"""
|
||||||
|
logger.debug(
|
||||||
|
"Create dataset mapping with snapshot: %s -> %s, files=%d",
|
||||||
|
labeling_project.dataset_id,
|
||||||
|
labeling_project.labeling_project_id,
|
||||||
|
len(file_ids),
|
||||||
|
)
|
||||||
|
|
||||||
|
self.db.add(labeling_project)
|
||||||
|
await self.db.flush()
|
||||||
|
assert labeling_project.id, "labeling_project.id must be set before snapshot insert"
|
||||||
|
|
||||||
|
if file_ids:
|
||||||
|
await self._insert_snapshot_records(labeling_project.id, file_ids)
|
||||||
|
|
||||||
|
await self.db.commit()
|
||||||
|
await self.db.refresh(labeling_project)
|
||||||
|
|
||||||
|
logger.debug("Mapping created with snapshot: %s", labeling_project.id)
|
||||||
|
return await self._to_response(labeling_project)
|
||||||
|
|
||||||
|
async def _insert_snapshot_records(self, project_id: str, file_ids: List[str]) -> None:
|
||||||
|
batch: List[dict] = []
|
||||||
|
for file_id in file_ids:
|
||||||
|
batch.append(
|
||||||
|
{
|
||||||
|
"id": str(uuid.uuid4()),
|
||||||
|
"project_id": project_id,
|
||||||
|
"file_id": file_id,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
if len(batch) >= self.SNAPSHOT_INSERT_BATCH_SIZE:
|
||||||
|
await self.db.execute(insert(LabelingProjectFile).values(batch))
|
||||||
|
batch.clear()
|
||||||
|
if batch:
|
||||||
|
await self.db.execute(insert(LabelingProjectFile).values(batch))
|
||||||
|
|
||||||
async def get_mapping_by_source_uuid(
|
async def get_mapping_by_source_uuid(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -48,6 +48,17 @@ CREATE TABLE IF NOT EXISTS t_dm_labeling_projects (
|
|||||||
INDEX idx_labeling_project_id (labeling_project_id)
|
INDEX idx_labeling_project_id (labeling_project_id)
|
||||||
) COMMENT='标注项目表';
|
) COMMENT='标注项目表';
|
||||||
|
|
||||||
|
-- 标注项目文件快照表
|
||||||
|
CREATE TABLE IF NOT EXISTS t_dm_labeling_project_files (
|
||||||
|
id VARCHAR(36) PRIMARY KEY COMMENT 'UUID',
|
||||||
|
project_id VARCHAR(36) NOT NULL COMMENT '标注项目ID',
|
||||||
|
file_id VARCHAR(36) NOT NULL COMMENT '文件ID',
|
||||||
|
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP COMMENT '创建时间',
|
||||||
|
UNIQUE KEY uk_project_file (project_id, file_id),
|
||||||
|
INDEX idx_project_id (project_id),
|
||||||
|
INDEX idx_file_id (file_id)
|
||||||
|
) COMMENT='标注项目文件快照表';
|
||||||
|
|
||||||
-- 标注结果表
|
-- 标注结果表
|
||||||
CREATE TABLE IF NOT EXISTS t_dm_annotation_results (
|
CREATE TABLE IF NOT EXISTS t_dm_annotation_results (
|
||||||
id VARCHAR(36) PRIMARY KEY COMMENT 'UUID',
|
id VARCHAR(36) PRIMARY KEY COMMENT 'UUID',
|
||||||
|
|||||||
Reference in New Issue
Block a user