DataMate/runtime/datamate-python/app/module/synthesis/service/ratio_task.py
hefanli 08bd4eca5c feature: add data ratio functionality (#52)
* refactor: rework the data collection implementation, remove dead code, and improve the code structure

* feature: scan all datasets at 00:00 every day, check whether each dataset has exceeded its configured retention period, and call the delete API to remove datasets past retention

* fix: change the dataset file deletion logic: files uploaded to a dataset delete both the database record and the file on disk, while collected files delete only the database record

* fix: add parameter validation and API definitions, remove unused endpoints

* fix: dataset statistics default to 0

* feature: add dataset status transitions: DRAFT on creation, ACTIVE after files are uploaded or collected

* refactor: rework the code for paginated querying of collection tasks

* fix: re-execute after updates; add transaction control to collection task execution

* feature: optionally create a dataset when creating a collection task, and update to a specified dataset when updating a collection task

* fix: creating a collection task without creating a dataset should no longer raise an error

* fix: dataset statistics were not updated when files were deleted

* feature: include the file tag distribution when querying dataset details

* fix: skip analysis when tags are empty

* fix: set status to ACTIVE

* fix: revise the tag parsing method

* feature: implement creating, paginated querying, and deleting ratio tasks

* feature: implement the frontend interactions for creating, paginated querying, and deleting ratio tasks

* fix: page error caused by a progress calculation exception
2025-11-03 10:17:39 +08:00


from typing import List, Optional, Dict, Any
import random
import os
import shutil
import asyncio

from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession

from app.core.logging import get_logger
from app.db.models.ratio_task import RatioInstance, RatioRelation
from app.db.models import Dataset, DatasetFiles
from app.db.session import AsyncSessionLocal

logger = get_logger(__name__)


class RatioTaskService:
    """Service for ratio task DB operations."""

    def __init__(self, db: AsyncSession):
        self.db = db
    async def create_task(
        self,
        *,
        name: str,
        description: Optional[str],
        totals: int,
        ratio_method: str,
        config: List[Dict[str, Any]],
        target_dataset_id: Optional[str] = None,
    ) -> RatioInstance:
        """Create a ratio task instance and its relations.

        Each config item has the shape:
        {"dataset_id": str, "counts": int, "filter_conditions": str}
        """
        logger.info(
            f"Creating ratio task: name={name}, method={ratio_method}, "
            f"totals={totals}, items={len(config or [])}"
        )
        instance = RatioInstance(
            name=name,
            description=description,
            ratio_method=ratio_method,
            totals=totals,
            target_dataset_id=target_dataset_id,
            status="PENDING",
        )
        self.db.add(instance)
        await self.db.flush()  # populate instance.id
        for item in config or []:
            relation = RatioRelation(
                ratio_instance_id=instance.id,
                source_dataset_id=item.get("dataset_id"),
                counts=int(item.get("counts", 0)),
                filter_conditions=item.get("filter_conditions"),
            )
            self.db.add(relation)
        await self.db.commit()
        await self.db.refresh(instance)
        logger.info(f"Ratio task created: {instance.id}")
        return instance
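
    # --- Usage sketch (illustrative only; the ids below are hypothetical) ---
    # Given a session from AsyncSessionLocal, a DATASET-mode task mixing 60 files
    # from one dataset with 40 from another could be created like this:
    #
    #   async with AsyncSessionLocal() as db:
    #       service = RatioTaskService(db)
    #       task = await service.create_task(
    #           name="demo-ratio",
    #           description=None,
    #           totals=100,
    #           ratio_method="DATASET",
    #           config=[
    #               {"dataset_id": "ds-001", "counts": 60, "filter_conditions": None},
    #               {"dataset_id": "ds-002", "counts": 40, "filter_conditions": None},
    #           ],
    #           target_dataset_id="ds-target",
    #       )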
    # ========================= Execution (Background) ========================= #

    @staticmethod
    async def execute_dataset_ratio_task(instance_id: str) -> None:
        """Execute a ratio task in the background.

        Supported ratio_method values:
        - DATASET: randomly select `counts` files from each source dataset
        - TAG: randomly select `counts` files matching relation.filter_conditions tags

        Steps:
        - Mark the instance RUNNING
        - For each relation: fetch ACTIVE files, optionally filter by tags
        - Copy selected files into the target dataset
        - Update dataset statistics and mark the instance SUCCESS/FAILED
        """
        async with AsyncSessionLocal() as session:  # type: AsyncSession
            try:
                # Load instance and relations
                inst_res = await session.execute(
                    select(RatioInstance).where(RatioInstance.id == instance_id)
                )
                instance: Optional[RatioInstance] = inst_res.scalar_one_or_none()
                if not instance:
                    logger.error(f"Ratio instance not found: {instance_id}")
                    return
                logger.info(f"Start executing ratio task: {instance_id}")
                rel_res = await session.execute(
                    select(RatioRelation).where(RatioRelation.ratio_instance_id == instance_id)
                )
                relations: List[RatioRelation] = list(rel_res.scalars().all())
                # Mark running
                instance.status = "RUNNING"
                if instance.ratio_method not in {"DATASET", "TAG"}:
                    logger.info(
                        f"Instance {instance_id} ratio_method={instance.ratio_method} not supported yet"
                    )
                    instance.status = "SUCCESS"
                    return
                # Load target dataset
                ds_res = await session.execute(
                    select(Dataset).where(Dataset.id == instance.target_dataset_id)
                )
                target_ds: Optional[Dataset] = ds_res.scalar_one_or_none()
                if not target_ds:
                    logger.error(f"Target dataset not found for instance {instance_id}")
                    instance.status = "FAILED"
                    return
                # Preload existing target file paths for deduplication
                existing_path_rows = await session.execute(
                    select(DatasetFiles.file_path).where(DatasetFiles.dataset_id == target_ds.id)
                )
                existing_paths = {p for p in existing_path_rows.scalars().all() if p}
                added_count = 0
                added_size = 0
                for rel in relations:
                    if not rel.source_dataset_id or not rel.counts or rel.counts <= 0:
                        continue
                    # Fetch all files for the source dataset (ACTIVE only)
                    files_res = await session.execute(
                        select(DatasetFiles).where(
                            DatasetFiles.dataset_id == rel.source_dataset_id,
                            DatasetFiles.status == "ACTIVE",
                        )
                    )
                    files = list(files_res.scalars().all())
                    # TAG mode: filter by tags according to relation.filter_conditions
                    if instance.ratio_method == "TAG":
                        required_tags = RatioTaskService._parse_required_tags(rel.filter_conditions)
                        if required_tags:
                            files = [f for f in files if RatioTaskService._file_contains_tags(f, required_tags)]
                    if not files:
                        continue
                    pick_n = min(rel.counts or 0, len(files))
                    chosen = random.sample(files, pick_n) if pick_n < len(files) else files
                    # Copy into the target dataset, de-duplicating by target path
                    for f in chosen:
                        src_path = f.file_path
                        new_path = src_path
                        needs_copy = False
                        src_prefix = f"/dataset/{rel.source_dataset_id}"
                        if isinstance(src_path, str) and src_path.startswith(src_prefix):
                            dst_prefix = f"/dataset/{target_ds.id}"
                            new_path = src_path.replace(src_prefix, dst_prefix, 1)
                            needs_copy = True
                        # Skip files whose target path already exists
                        if new_path in existing_paths:
                            continue
                        # Perform the filesystem copy only when the path was rewritten
                        if needs_copy:
                            dst_dir = os.path.dirname(new_path)
                            await asyncio.to_thread(os.makedirs, dst_dir, exist_ok=True)
                            await asyncio.to_thread(shutil.copy2, src_path, new_path)
                        new_file = DatasetFiles(
                            dataset_id=target_ds.id,  # type: ignore
                            file_name=f.file_name,
                            file_path=new_path,
                            file_type=f.file_type,
                            file_size=f.file_size,
                            check_sum=f.check_sum,
                            tags=f.tags,
                            dataset_filemetadata=f.dataset_filemetadata,
                            status="ACTIVE",
                        )
                        session.add(new_file)
                        existing_paths.add(new_path)
                        added_count += 1
                        added_size += int(f.file_size or 0)
                    # Flush once per relation to keep the transaction manageable
                    await session.flush()
                # Update target dataset statistics
                target_ds.file_count = (target_ds.file_count or 0) + added_count  # type: ignore
                target_ds.size_bytes = (target_ds.size_bytes or 0) + added_size  # type: ignore
                # If the target dataset now has files, mark it ACTIVE
                if (target_ds.file_count or 0) > 0:  # type: ignore
                    target_ds.status = "ACTIVE"
                # Done
                instance.status = "SUCCESS"
                logger.info(
                    f"Dataset ratio execution completed: instance={instance_id}, "
                    f"files={added_count}, size={added_size}"
                )
            except Exception as e:
                logger.exception(f"Dataset ratio execution failed for {instance_id}: {e}")
                try:
                    # Best effort: mark the instance FAILED
                    inst_res = await session.execute(
                        select(RatioInstance).where(RatioInstance.id == instance_id)
                    )
                    instance = inst_res.scalar_one_or_none()
                    if instance:
                        instance.status = "FAILED"
                except Exception:
                    logger.warning(f"Could not mark instance {instance_id} as FAILED")
            finally:
                await session.commit()
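
    # --- Scheduling sketch (illustrative only; the caller shown is an assumption) ---
    # Because execute_dataset_ratio_task opens its own session, a caller could fire
    # it without blocking, e.g. right after create_task():
    #
    #   asyncio.create_task(RatioTaskService.execute_dataset_ratio_task(task.id))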
    # ------------------------- helpers for TAG filtering ------------------------- #

    @staticmethod
    def _parse_required_tags(conditions: Optional[str]) -> set[str]:
        """Parse filter_conditions into a set of required tag strings.

        Supports simple separators: comma, semicolon, whitespace.
        Empty/None -> empty set.
        """
        if not conditions:
            return set()
        raw = conditions.replace("\n", " ")
        seps = [",", ";", " "]
        tokens = [raw]
        for sep in seps:
            nxt = []
            for t in tokens:
                nxt.extend(t.split(sep))
            tokens = nxt
        return {t.strip() for t in tokens if t and t.strip()}
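
    # Illustrative behaviour of the parser above (inputs are hypothetical):
    #   _parse_required_tags("cat, dog;bird fish") -> {"cat", "dog", "bird", "fish"}
    #   _parse_required_tags(None)                 -> set()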
    @staticmethod
    def _file_contains_tags(f: DatasetFiles, required: set[str]) -> bool:
        if not required:
            return True
        tags = f.tags
        if not tags:
            return False
        try:
            # tags may be a list of strings or a list of objects with a 'name' key
            tag_names = set()
            if isinstance(tags, list):
                for item in tags:
                    if isinstance(item, str):
                        tag_names.add(item)
                    elif isinstance(item, dict):
                        name = item.get("name") or item.get("label") or item.get("tag")
                        if isinstance(name, str):
                            tag_names.add(name)
            elif isinstance(tags, dict):
                # Flat dict of name -> ...; treat the keys as tags
                tag_names = set(map(str, tags.keys()))
            else:
                return False
            logger.debug(f"Tag match check: tags={tags}, required={required}, parsed={tag_names}")
            return required.issubset(tag_names)
        except Exception:
            return False
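
    # Illustrative matches for the shapes handled above (hypothetical values):
    #   tags=["cat", "dog"]       with required={"cat"}         -> True
    #   tags=[{"name": "cat"}]    with required={"cat", "dog"}  -> False
    #   tags={"cat": 3, "dog": 1} with required={"dog"}         -> True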
    @staticmethod
    async def get_new_file(f: DatasetFiles, rel: RatioRelation, target_ds: Dataset) -> DatasetFiles:
        """Copy a single source file into the target dataset and build its record.

        Note: unlike the inline loop in execute_dataset_ratio_task, this helper
        does not de-duplicate by target path.
        """
        new_path = f.file_path
        src_prefix = f"/dataset/{rel.source_dataset_id}"
        if isinstance(f.file_path, str) and f.file_path.startswith(src_prefix):
            dst_prefix = f"/dataset/{target_ds.id}"
            new_path = f.file_path.replace(src_prefix, dst_prefix, 1)
            dst_dir = os.path.dirname(new_path)
            # Ensure directory and copy the file in a thread to avoid blocking the event loop
            await asyncio.to_thread(os.makedirs, dst_dir, exist_ok=True)
            await asyncio.to_thread(shutil.copy2, f.file_path, new_path)
        new_file = DatasetFiles(
            dataset_id=target_ds.id,  # type: ignore
            file_name=f.file_name,
            file_path=new_path,
            file_type=f.file_type,
            file_size=f.file_size,
            check_sum=f.check_sum,
            tags=f.tags,
            dataset_filemetadata=f.dataset_filemetadata,
            status="ACTIVE",
        )
        return new_file
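
# --- End-to-end sketch (illustrative only; ids and the surrounding app are assumptions) ---
# A TAG-mode flow could look like this; filter_conditions feeds _parse_required_tags,
# so "cat,dog" selects only files tagged with both "cat" and "dog":
#
#   async def run_demo() -> None:
#       async with AsyncSessionLocal() as db:
#           task = await RatioTaskService(db).create_task(
#               name="tag-ratio-demo",
#               description="mix files tagged cat and dog",
#               totals=50,
#               ratio_method="TAG",
#               config=[{"dataset_id": "ds-001", "counts": 50, "filter_conditions": "cat,dog"}],
#               target_dataset_id="ds-target",
#           )
#           task_id = task.id
#       await RatioTaskService.execute_dataset_ratio_task(task_id)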