feature: data ratio can now be configured by update time (#95)

* feature: data ratio can now be configured by update time

* fix: correct the passing of the ratio time parameter
hefanli
2025-11-20 18:50:51 +08:00
committed by GitHub
parent 955ffff6cd
commit cddfe9b149
10 changed files with 458 additions and 595 deletions

View File

@@ -1,11 +1,13 @@
from .common import (
BaseResponseModel,
StandardResponse,
PaginatedData
PaginatedData,
TaskStatus
)
__all__ = [
"BaseResponseModel",
"StandardResponse",
"PaginatedData"
]
"PaginatedData",
"TaskStatus"
]

View File

@@ -1,8 +1,9 @@
"""
Common response models
"""
from typing import Generic, TypeVar, Optional, List, Type
from typing import Generic, TypeVar, List
from pydantic import BaseModel, Field
from enum import Enum
# Define a generic type variable
T = TypeVar('T')
@@ -16,7 +17,7 @@ def to_camel(string: str) -> str:
class BaseResponseModel(BaseModel):
"""基础响应模型,启用别名生成器"""
class Config:
populate_by_name = True
alias_generator = to_camel
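As a side note on the alias setup above, a minimal sketch (assuming Pydantic v2 and a to_camel helper like the one shown) of how the config behaves: serialization emits camelCase keys, while snake_case construction still works thanks to populate_by_name.

from pydantic import BaseModel

def to_camel(string: str) -> str:
    parts = string.split("_")
    return parts[0] + "".join(w.capitalize() for w in parts[1:])

class Demo(BaseModel):
    class Config:
        populate_by_name = True
        alias_generator = to_camel

    total_pages: int

d = Demo(total_pages=3)                 # snake_case accepted because of populate_by_name
print(d.model_dump(by_alias=True))      # {'totalPages': 3}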
@@ -24,7 +25,7 @@ class BaseResponseModel(BaseModel):
class StandardResponse(BaseResponseModel, Generic[T]):
"""
Standard API response format
All API endpoints should return this format to ensure consistent responses
"""
code: int = Field(..., description="HTTP status code")
@@ -42,3 +43,9 @@ class PaginatedData(BaseResponseModel, Generic[T]):
total_elements: int = Field(..., description="total number of items")
total_pages: int = Field(..., description="total number of pages")
content: List[T] = Field(..., description="items on the current page")
class TaskStatus(Enum):
PENDING = "PENDING"
RUNNING = "RUNNING"
COMPLETED = "COMPLETED"
FAILED = "FAILED"

View File

@@ -12,7 +12,7 @@ from app.core.logging import get_logger
from app.db.models import Dataset
from app.db.session import get_db
from app.module.dataset import DatasetManagementService
from app.module.shared.schema import StandardResponse
from app.module.shared.schema import StandardResponse, TaskStatus
from app.module.synthesis.schema.ratio_task import (
CreateRatioTaskResponse,
CreateRatioTaskRequest,
@@ -49,52 +49,18 @@ async def create_ratio_task(
await valid_exists(db, req)
# Create the target dataset: name is "<task name>-配比生成-<timestamp>" (配比生成 = "ratio-generated")
target_dataset_name = f"{req.name}-配比生成-{datetime.now().strftime('%Y%m%d%H%M%S')}"
target_dataset = await create_target_dataset(db, req, source_types)
target_type = get_target_dataset_type(source_types)
instance = await create_ratio_instance(db, req, target_dataset)
target_dataset = Dataset(
id=str(uuid.uuid4()),
name=target_dataset_name,
description=req.description or "",
dataset_type=target_type,
status="DRAFT",
)
target_dataset.path = f"/dataset/{target_dataset.id}"
db.add(target_dataset)
await db.flush() # populate target_dataset.id
service = RatioTaskService(db)
instance = await service.create_task(
name=req.name,
description=req.description,
totals=int(req.totals),
ratio_method=req.ratio_method,
config=[
{
"dataset_id": item.dataset_id,
"counts": int(item.counts),
"filter_conditions": item.filter_conditions,
}
for item in req.config
],
target_dataset_id=target_dataset.id,
)
# Run the ratio task asynchronously (supports DATASET / TAG)
asyncio.create_task(RatioTaskService.execute_dataset_ratio_task(instance.id))
return StandardResponse(
code=200,
message="success",
data=CreateRatioTaskResponse(
response_data = CreateRatioTaskResponse(
id=instance.id,
name=instance.name,
description=instance.description,
totals=instance.totals or 0,
ratio_method=instance.ratio_method or req.ratio_method,
status=instance.status or "PENDING",
status=instance.status or TaskStatus.PENDING.name,
config=req.config,
targetDataset=TargetDatasetInfo(
id=str(target_dataset.id),
@@ -103,6 +69,10 @@ async def create_ratio_task(
status=str(target_dataset.status),
)
)
return StandardResponse(
code=200,
message="success",
data=response_data
)
except HTTPException:
await db.rollback()
@@ -113,6 +83,46 @@ async def create_ratio_task(
raise HTTPException(status_code=500, detail="Internal server error")
async def create_ratio_instance(db, req: CreateRatioTaskRequest, target_dataset: Dataset) -> RatioInstance:
service = RatioTaskService(db)
logger.info(f"create_ratio_instance: {req}")
instance = await service.create_task(
name=req.name,
description=req.description,
totals=int(req.totals),
config=[
{
"dataset_id": item.dataset_id,
"counts": int(item.counts),
"filter_conditions": item.filter_conditions,
}
for item in req.config
],
target_dataset_id=target_dataset.id,
)
return instance
async def create_target_dataset(db, req: CreateRatioTaskRequest, source_types: set[str]) -> Dataset:
# Create the target dataset: name is "<task name>-<timestamp>"
target_dataset_name = f"{req.name}-{datetime.now().strftime('%Y%m%d%H%M%S')}"
target_type = get_target_dataset_type(source_types)
target_dataset_id = uuid.uuid4()
target_dataset = Dataset(
id=str(target_dataset_id),
name=target_dataset_name,
description=req.description or "",
dataset_type=target_type,
status="DRAFT",
path=f"/dataset/{target_dataset_id}",
)
db.add(target_dataset)
await db.flush() # populate target_dataset.id
return target_dataset
@router.get("", response_model=StandardResponse[PagedRatioTaskResponse], status_code=200)
async def list_ratio_tasks(
page: int = 1,
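For orientation, a hypothetical request body for the create endpoint under the new schema (all ids and values are made up; field names follow the camelCase aliases):

payload = {
    "name": "demo-task",
    "description": "ratio by update time",
    "totals": "100",
    "config": [
        {
            "datasetId": "ds-123",        # hypothetical dataset id
            "counts": "50",
            "filterConditions": {
                "dateRange": "30",        # keep files whose tags were updated in the last 30 days
                "label": "finance",
            },
        }
    ],
}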

View File

@@ -2,10 +2,36 @@ from typing import List, Optional, Dict, Any
from datetime import datetime
from pydantic import BaseModel, Field, field_validator
from app.core.logging import get_logger
from app.module.shared.schema.common import TaskStatus
logger = get_logger(__name__)
class FilterCondition(BaseModel):
date_range: Optional[str] = Field(None, description="date range in days", alias="dateRange")
label: Optional[str] = Field(None, description="label")
@field_validator("date_range")
@classmethod
def validate_date_range(cls, v: Optional[str]) -> Optional[str]:
# ensure it's a numeric string if provided
if not v:
return v
try:
int(v)
return v
except (ValueError, TypeError) as e:
raise ValueError("date_range must be a numeric string") from e
class Config:
# allow population by field name when constructing model programmatically
validate_by_name = True
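A short sketch of what the new validator accepts (assuming the FilterCondition model above and a Pydantic version that honors validate_by_name):

ok = FilterCondition(dateRange="30", label="finance")   # numeric string passes
also_ok = FilterCondition(date_range="7")               # population by field name allowed
try:
    FilterCondition(dateRange="last-month")
except ValueError as err:                               # pydantic's ValidationError subclasses ValueError
    print(err)                                          # ...date_range must be a numeric string...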
class RatioConfigItem(BaseModel):
dataset_id: str = Field(..., alias="datasetId", description="dataset id")
counts: str = Field(..., description="count")
filter_conditions: str = Field(..., description="filter conditions")
filter_conditions: FilterCondition = Field(..., alias="filterConditions", description="filter conditions")
@field_validator("counts")
@classmethod
@@ -22,17 +48,8 @@ class CreateRatioTaskRequest(BaseModel):
name: str = Field(..., description="name")
description: Optional[str] = Field(None, description="description")
totals: str = Field(..., description="target total")
ratio_method: str = Field(..., description="ratio method", alias="ratio_method")
config: List[RatioConfigItem] = Field(..., description="list of ratio settings")
@field_validator("ratio_method")
@classmethod
def validate_ratio_method(cls, v: str) -> str:
allowed = {"TAG", "DATASET"}
if v not in allowed:
raise ValueError(f"ratio_method must be one of {allowed}")
return v
@field_validator("totals")
@classmethod
def validate_totals(cls, v: str) -> str:
@@ -58,8 +75,7 @@ class CreateRatioTaskResponse(BaseModel):
name: str
description: Optional[str] = None
totals: int
ratio_method: str
status: str
status: TaskStatus
# echoed config
config: List[RatioConfigItem]
# created dataset
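Because status is now typed as TaskStatus, Pydantic coerces stored strings into enum members and serializes them back out; a minimal, hypothetical sketch (not the full response model):

from enum import Enum
from pydantic import BaseModel

class TaskStatus(Enum):
    PENDING = "PENDING"

class Resp(BaseModel):
    status: TaskStatus

r = Resp(status="PENDING")          # a plain DB string coerces to the enum
print(r.model_dump(mode="json"))    # {'status': 'PENDING'}, what the API emits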

View File

@@ -14,6 +14,8 @@ from app.db.models.ratio_task import RatioInstance, RatioRelation
from app.db.models import Dataset, DatasetFiles
from app.db.session import AsyncSessionLocal
from app.module.dataset.schema.dataset_file import DatasetFileTag
from app.module.shared.schema import TaskStatus
from app.module.synthesis.schema.ratio_task import FilterCondition
logger = get_logger(__name__)
@@ -30,7 +32,6 @@ class RatioTaskService:
name: str,
description: Optional[str],
totals: int,
ratio_method: str,
config: List[Dict[str, Any]],
target_dataset_id: Optional[str] = None,
) -> RatioInstance:
@@ -38,12 +39,11 @@ class RatioTaskService:
config item format: {"dataset_id": str, "counts": int, "filter_conditions": FilterCondition}
"""
logger.info(f"Creating ratio task: name={name}, method={ratio_method}, totals={totals}, items={len(config or [])}")
logger.info(f"Creating ratio task: name={name}, totals={totals}, items={len(config or [])}")
instance = RatioInstance(
name=name,
description=description,
ratio_method=ratio_method,
totals=totals,
target_dataset_id=target_dataset_id,
status="PENDING",
@@ -56,8 +56,12 @@ class RatioTaskService:
ratio_instance_id=instance.id,
source_dataset_id=item.get("dataset_id"),
counts=int(item.get("counts", 0)),
filter_conditions=item.get("filter_conditions"),
filter_conditions=json.dumps({
'date_range': item.get("filter_conditions").date_range,
'label': item.get("filter_conditions").label,
})
)
logger.info(f"Relation created: {relation.id}, {relation}, {item}, {config}")
self.db.add(relation)
await self.db.commit()
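The relation row stores filter_conditions as a JSON string; a minimal sketch of the round-trip that create_task and _parse_conditions perform between them (made-up values):

import json

stored = json.dumps({"date_range": "30", "label": "finance"})  # written to RatioRelation.filter_conditions
restored = json.loads(stored)                                  # read back before building FilterCondition
assert restored["date_range"] == "30" and restored["label"] == "finance"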
@@ -97,94 +101,17 @@ class RatioTaskService:
relations: List[RatioRelation] = list(rel_res.scalars().all())
# Mark running
instance.status = "RUNNING"
if instance.ratio_method not in {"DATASET", "TAG"}:
logger.info(f"Instance {instance_id} ratio_method={instance.ratio_method} not supported yet")
instance.status = "SUCCESS"
return
instance.status = TaskStatus.RUNNING.name
# Load target dataset
ds_res = await session.execute(select(Dataset).where(Dataset.id == instance.target_dataset_id))
target_ds: Optional[Dataset] = ds_res.scalar_one_or_none()
if not target_ds:
logger.error(f"Target dataset not found for instance {instance_id}")
instance.status = "FAILED"
instance.status = TaskStatus.FAILED.name
return
# Preload existing target file paths for deduplication
existing_path_rows = await session.execute(
select(DatasetFiles.file_path).where(DatasetFiles.dataset_id == target_ds.id)
)
existing_paths = set(p for p in existing_path_rows.scalars().all() if p)
added_count = 0
added_size = 0
for rel in relations:
if not rel.source_dataset_id or not rel.counts or rel.counts <= 0:
continue
# Fetch all files for the source dataset (ACTIVE only)
files_res = await session.execute(
select(DatasetFiles).where(
DatasetFiles.dataset_id == rel.source_dataset_id,
DatasetFiles.status == "ACTIVE",
)
)
files = list(files_res.scalars().all())
# TAG mode: filter by tags according to relation.filter_conditions
if instance.ratio_method == "TAG":
required_tags = RatioTaskService._parse_required_tags(rel.filter_conditions)
if required_tags:
files = [f for f in files if RatioTaskService._file_contains_tags(f, required_tags)]
if not files:
continue
pick_n = min(rel.counts or 0, len(files))
chosen = random.sample(files, pick_n) if pick_n < len(files) else files
# Copy into target dataset with de-dup by target path
for f in chosen:
src_path = f.file_path
new_path = src_path
needs_copy = False
src_prefix = f"/dataset/{rel.source_dataset_id}"
if isinstance(src_path, str) and src_path.startswith(src_prefix):
dst_prefix = f"/dataset/{target_ds.id}"
new_path = src_path.replace(src_prefix, dst_prefix, 1)
needs_copy = True
# De-dup by target path
if new_path in existing_paths:
continue
# Perform copy only when needed
if needs_copy:
dst_dir = os.path.dirname(new_path)
await asyncio.to_thread(os.makedirs, dst_dir, exist_ok=True)
await asyncio.to_thread(shutil.copy2, src_path, new_path)
new_file = DatasetFiles(
dataset_id=target_ds.id, # type: ignore
file_name=f.file_name,
file_path=new_path,
file_type=f.file_type,
file_size=f.file_size,
check_sum=f.check_sum,
tags=f.tags,
dataset_filemetadata=f.dataset_filemetadata,
status="ACTIVE",
)
session.add(new_file)
existing_paths.add(new_path)
added_count += 1
added_size += int(f.file_size or 0)
# Periodically flush to avoid huge transactions
await session.flush()
added_count, added_size = await RatioTaskService.handle_ratio_relations(relations, session, target_ds)
# Update target dataset statistics
target_ds.file_count = (target_ds.file_count or 0) + added_count # type: ignore
@@ -194,8 +121,8 @@ class RatioTaskService:
target_ds.status = "ACTIVE"
# Done
instance.status = "SUCCESS"
logger.info(f"Dataset ratio execution completed: instance={instance_id}, files={added_count}, size={added_size}")
instance.status = TaskStatus.COMPLETED.name
logger.info(f"Dataset ratio execution completed: instance={instance_id}, files={added_count}, size={added_size}, {instance.status}")
except Exception as e:
logger.exception(f"Dataset ratio execution failed for {instance_id}: {e}")
@@ -204,42 +131,163 @@ class RatioTaskService:
inst_res = await session.execute(select(RatioInstance).where(RatioInstance.id == instance_id))
instance = inst_res.scalar_one_or_none()
if instance:
instance.status = "FAILED"
instance.status = TaskStatus.FAILED.name
finally:
pass
finally:
await session.commit()
@staticmethod
async def handle_ratio_relations(relations: list[RatioRelation], session, target_ds: Dataset) -> tuple[int, int]:
# Preload existing target file paths for deduplication
existing_path_rows = await session.execute(
select(DatasetFiles.file_path).where(DatasetFiles.dataset_id == target_ds.id)
)
existing_paths = set(p for p in existing_path_rows.scalars().all() if p)
added_count = 0
added_size = 0
for rel in relations:
if not rel.source_dataset_id or not rel.counts or rel.counts <= 0:
continue
files = await RatioTaskService.get_files(rel, session)
if not files:
continue
pick_n = min(rel.counts or 0, len(files))
chosen = random.sample(files, pick_n) if pick_n < len(files) else files
# Copy into target dataset with de-dup by target path
for f in chosen:
await RatioTaskService.handle_selected_file(existing_paths, f, session, target_ds)
added_count += 1
added_size += int(f.file_size or 0)
# Periodically flush to avoid huge transactions
await session.flush()
return added_count, added_size
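The selection step is plain sampling without replacement; a standalone sketch with a made-up file list:

import random

files = ["a.txt", "b.txt", "c.txt"]
counts = 2
pick_n = min(counts, len(files))
chosen = random.sample(files, pick_n) if pick_n < len(files) else files
print(chosen)  # two distinct entries, randomly chosen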
@staticmethod
async def handle_selected_file(existing_paths: set[Any], f, session, target_ds: Dataset):
src_path = f.file_path
dst_prefix = f"/dataset/{target_ds.id}/"  # trailing slash so prefix + file name forms a valid path
file_name = RatioTaskService.get_new_file_name(dst_prefix, existing_paths, f)
new_path = dst_prefix + file_name
dst_dir = os.path.dirname(new_path)
await asyncio.to_thread(os.makedirs, dst_dir, exist_ok=True)
await asyncio.to_thread(shutil.copy2, src_path, new_path)
new_file = DatasetFiles(
dataset_id=target_ds.id, # type: ignore
file_name=file_name,
file_path=new_path,
file_type=f.file_type,
file_size=f.file_size,
check_sum=f.check_sum,
tags=f.tags,
dataset_filemetadata=f.dataset_filemetadata,
status="ACTIVE",
)
session.add(new_file)
existing_paths.add(new_path)
@staticmethod
def get_new_file_name(dst_prefix: str, existing_paths: set[Any], f) -> str:
file_name = f.file_name
new_path = dst_prefix + file_name
# Handle file path conflicts by appending a number to the filename
if new_path in existing_paths:
file_name_base, file_ext = os.path.splitext(file_name)
counter = 1
original_file_name = file_name
while new_path in existing_paths:
file_name = f"{file_name_base}_{counter}{file_ext}"
new_path = f"{dst_prefix}{file_name}"
counter += 1
if counter > 1000: # Safety check to prevent infinite loops
logger.error(f"Could not find unique filename for {original_file_name} after 1000 attempts")
break
return file_name
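In isolation, the conflict-resolution loop behaves like this hypothetical helper (made-up paths; the real method also caps retries at 1000):

import os

def unique_name(dst_prefix: str, existing: set[str], file_name: str) -> str:
    new_path = dst_prefix + file_name
    base, ext = os.path.splitext(file_name)
    counter = 1
    while new_path in existing:
        file_name = f"{base}_{counter}{ext}"
        new_path = f"{dst_prefix}{file_name}"
        counter += 1
    return file_name

existing = {"/dataset/t1/report.csv", "/dataset/t1/report_1.csv"}
print(unique_name("/dataset/t1/", existing, "report.csv"))  # report_2.csv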
@staticmethod
async def get_files(rel: RatioRelation, session) -> list[Any]:
# Fetch all files for the source dataset (ACTIVE only)
files_res = await session.execute(
select(DatasetFiles).where(
DatasetFiles.dataset_id == rel.source_dataset_id,
DatasetFiles.status == "ACTIVE",
)
)
files = list(files_res.scalars().all())
# TAG mode: filter by tags according to relation.filter_conditions
conditions = RatioTaskService._parse_conditions(rel.filter_conditions)
if conditions:
files = [f for f in files if RatioTaskService._filter_file(f, conditions)]
return files
# ------------------------- helpers for TAG filtering ------------------------- #
@staticmethod
def _parse_required_tags(conditions: Optional[str]) -> set[str]:
"""Parse filter_conditions into a set of required tag strings.
def _parse_conditions(conditions: Optional[str]) -> Optional[FilterCondition]:
"""Parse filter_conditions JSON string into a FilterCondition object.
Supports simple separators: comma, semicolon, space. Empty/None -> empty set.
Args:
conditions: JSON string containing filter conditions
Returns:
FilterCondition object if conditions is not None/empty, otherwise None
"""
if not conditions:
return set()
data = json.loads(conditions)
required_tags = set()
if data.get("label"):
required_tags.add(data["label"])
return required_tags
return None
try:
data = json.loads(conditions)
return FilterCondition(**data)
except json.JSONDecodeError as e:
logger.error(f"Failed to parse filter conditions: {e}")
return None
except Exception as e:
logger.error(f"Error creating FilterCondition: {e}")
return None
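Expected behavior of _parse_conditions on a few inputs (hypothetical strings; assumes the service and schema above are importable):

print(RatioTaskService._parse_conditions(None))                    # None
print(RatioTaskService._parse_conditions('{"date_range": "7"}'))   # FilterCondition(date_range='7', label=None)
print(RatioTaskService._parse_conditions("not json"))              # None, after logging the parse error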
@staticmethod
def _file_contains_tags(file: DatasetFiles, required: set[str]) -> bool:
if not required:
def _filter_file(file: DatasetFiles, conditions: FilterCondition) -> bool:
if not conditions:
return True
tags = file.tags
if not tags:
return False
try:
# tags could be a list of strings or list of objects with 'name'
tag_names = RatioTaskService.get_all_tags(tags)
return required.issubset(tag_names)
except Exception as e:
logger.exception(f"Failed to get tags for {file}: {e}")
return False
logger.info(f"start filter file: {file}, conditions: {conditions}")
# Check data range condition if provided
if conditions.date_range:
try:
from datetime import datetime, timedelta
data_range_days = int(conditions.date_range)
if data_range_days > 0:
cutoff_date = datetime.now() - timedelta(days=data_range_days)
if file.tags_updated_at and file.tags_updated_at < cutoff_date:
return False
except (ValueError, TypeError) as e:
logger.warning(f"Invalid data_range value {conditions.date_range}: {e}")
return False
# Check label condition if provided
if conditions.label:
tags = file.tags
if not tags:
return False
try:
# tags could be a list of strings or list of objects with 'name'
tag_names = RatioTaskService.get_all_tags(tags)
return conditions.label in tag_names
except Exception as e:
logger.exception(f"Failed to get tags for {file}: {e}")
return False
return True
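The date_range branch reduces to a cutoff comparison; a standalone sketch with made-up timestamps:

from datetime import datetime, timedelta

date_range_days = 30
cutoff = datetime.now() - timedelta(days=date_range_days)
tags_updated_at = datetime.now() - timedelta(days=45)   # hypothetical file.tags_updated_at
print(tags_updated_at >= cutoff)  # False -> the file is filtered out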
@staticmethod
def get_all_tags(tags) -> set[str]: