You've already forked DataMate
* refactor: 修改调整数据归集实现,删除无用代码,优化代码结构 * feature: 每天凌晨00:00扫描所有数据集,检查数据集是否超过了预设的保留天数,超出保留天数的数据集调用删除接口进行删除 * fix: 修改删除数据集文件的逻辑,上传到数据集中的文件会同时删除数据库中的记录和文件系统中的文件,归集过来的文件仅删除数据库中的记录 * fix: 增加参数校验和接口定义,删除不使用的接口 * fix: 数据集统计数据默认为0 * feature: 数据集状态增加流转,创建时为草稿状态,上传文件或者归集文件后修改为活动状态 * refactor: 修改分页查询归集任务的代码 * fix: 更新后重新执行;归集任务执行增加事务控制 * feature: 创建归集任务时能够同步创建数据集,更新归集任务时能更新到指定数据集 * fix: 创建归集任务不需要创建数据集时不应该报错 * fix: 修复删除文件时数据集的统计数据不变动 * feature: 查询数据集详情时能够获取到文件标签分布 * fix: tags为空时不进行分析 * fix: 状态修改为ACTIVE * fix: 修改解析tag的方法 * feature: 实现创建、分页查询、删除配比任务 * feature: 实现创建、分页查询、删除配比任务的前端交互 * fix: 修复进度计算异常导致的页面报错
283 lines
12 KiB
Python
283 lines
12 KiB
Python
from typing import List, Optional, Dict, Any
|
|
import random
|
|
import os
|
|
import shutil
|
|
import asyncio
|
|
|
|
from sqlalchemy import select
|
|
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from app.core.logging import get_logger
|
|
from app.db.models.ratio_task import RatioInstance, RatioRelation
|
|
from app.db.models import Dataset, DatasetFiles
|
|
from app.db.session import AsyncSessionLocal
|
|
|
|
# Module-level logger used by RatioTaskService for task lifecycle and error messages.
logger = get_logger(__name__)
|
|
|
|
|
|
class RatioTaskService:
    """Service for ratio task DB operations and their background execution."""

    # Translation table that turns every accepted tag separator into a space, so a
    # single str.split() (which also splits on any whitespace) tokenizes
    # filter_conditions. Full-width comma/semicolon are accepted because they are
    # easy to type by accident with a CJK input method.
    _TAG_SEPARATORS = str.maketrans({",": " ", ";": " ", "，": " ", "；": " "})

    def __init__(self, db: AsyncSession):
        # Session used by create_task; the background executor opens its own session.
        self.db = db

    async def create_task(
        self,
        *,
        name: str,
        description: Optional[str],
        totals: int,
        ratio_method: str,
        config: List[Dict[str, Any]],
        target_dataset_id: Optional[str] = None,
    ) -> RatioInstance:
        """Create a ratio task instance and one relation per config item.

        Args:
            name: Task name.
            description: Optional task description.
            totals: Total number of files requested by the task.
            ratio_method: Selection method, e.g. "DATASET" or "TAG".
            config: Items of the form
                {"dataset_id": str, "counts": int, "filter_conditions": str}.
                A missing, null or empty "counts" is stored as 0.
            target_dataset_id: Dataset that receives the selected files.

        Returns:
            The committed and refreshed RatioInstance (status "PENDING").
        """
        items = config or []
        logger.info(f"Creating ratio task: name={name}, method={ratio_method}, totals={totals}, items={len(items)}")

        instance = RatioInstance(
            name=name,
            description=description,
            ratio_method=ratio_method,
            totals=totals,
            target_dataset_id=target_dataset_id,
            status="PENDING",
        )
        self.db.add(instance)
        await self.db.flush()  # populate instance.id for the relations below

        for item in items:
            self.db.add(
                RatioRelation(
                    ratio_instance_id=instance.id,
                    source_dataset_id=item.get("dataset_id"),
                    # `or 0` guards against an explicit null/empty value, which int() rejects.
                    counts=int(item.get("counts") or 0),
                    filter_conditions=item.get("filter_conditions"),
                )
            )

        await self.db.commit()
        await self.db.refresh(instance)
        logger.info(f"Ratio task created: {instance.id}")
        return instance

    # ========================= Execution (Background) ========================= #

    @staticmethod
    async def execute_dataset_ratio_task(instance_id: str) -> None:
        """Execute a ratio task in the background using its own DB session.

        Supported ratio_method:
        - DATASET: randomly select `counts` files from each source dataset
        - TAG: same, but only among files carrying every tag listed in
          relation.filter_conditions

        Steps:
        - Mark the instance RUNNING and commit, so the status is visible while running
        - For each relation: fetch ACTIVE files, optionally filter by tags, drop
          files whose target path is already present, then sample
        - Copy selected files into the target dataset and insert their rows
        - Update target dataset statistics and mark the instance SUCCESS

        On any exception the transaction is rolled back (no partial rows are kept),
        the instance is marked FAILED in a fresh transaction, and destination files
        newly created by this run are removed. Errors are logged, not re-raised.
        """
        async with AsyncSessionLocal() as session:  # type: AsyncSession
            # Destination files this run created on disk; removed again on failure.
            created_paths: List[str] = []
            try:
                await RatioTaskService._run_ratio_task(session, instance_id, created_paths)
                await session.commit()
            except Exception as e:
                logger.exception(f"Dataset ratio execution failed for {instance_id}: {e}")
                await RatioTaskService._mark_failed(session, instance_id)
                if created_paths:
                    await asyncio.to_thread(RatioTaskService._remove_files, created_paths)

    @staticmethod
    async def _run_ratio_task(session: AsyncSession, instance_id: str, created_paths: List[str]) -> None:
        """Perform the ratio task work; the caller commits on success or rolls back on error."""
        inst_res = await session.execute(select(RatioInstance).where(RatioInstance.id == instance_id))
        instance: Optional[RatioInstance] = inst_res.scalar_one_or_none()
        if not instance:
            logger.error(f"Ratio instance not found: {instance_id}")
            return
        logger.info(f"start execute ratio task: {instance_id}")

        # Commit RUNNING right away so it is observable; refresh reloads any
        # attributes the commit may have expired (same pattern as create_task).
        instance.status = "RUNNING"
        await session.commit()
        await session.refresh(instance)

        if instance.ratio_method not in {"DATASET", "TAG"}:
            logger.info(f"Instance {instance_id} ratio_method={instance.ratio_method} not supported yet")
            instance.status = "SUCCESS"
            return

        ds_res = await session.execute(select(Dataset).where(Dataset.id == instance.target_dataset_id))
        target_ds: Optional[Dataset] = ds_res.scalar_one_or_none()
        if not target_ds:
            logger.error(f"Target dataset not found for instance {instance_id}")
            instance.status = "FAILED"
            return

        rel_res = await session.execute(
            select(RatioRelation).where(RatioRelation.ratio_instance_id == instance_id)
        )
        relations: List[RatioRelation] = list(rel_res.scalars().all())

        # Paths already in the target dataset; grows as files are added so later
        # relations never insert the same target path twice.
        path_rows = await session.execute(
            select(DatasetFiles.file_path).where(DatasetFiles.dataset_id == target_ds.id)
        )
        existing_paths = {p for p in path_rows.scalars().all() if p}

        added_count = 0
        added_size = 0
        for rel in relations:
            if not rel.source_dataset_id or not rel.counts or rel.counts <= 0:
                continue

            candidates = await RatioTaskService._load_candidates(
                session, instance.ratio_method, rel, target_ds.id, existing_paths
            )
            if not candidates:
                continue

            # Sample only among still-addable files, so de-dup does not reduce the
            # number of files added below min(rel.counts, available).
            chosen = random.sample(candidates, min(rel.counts, len(candidates)))
            for new_path, f, needs_copy in chosen:
                if needs_copy and await asyncio.to_thread(RatioTaskService._copy_file, f.file_path, new_path):
                    created_paths.append(new_path)
                session.add(RatioTaskService._clone_file_record(f, target_ds.id, new_path))
                existing_paths.add(new_path)
                added_count += 1
                added_size += int(f.file_size or 0)

            # One flush per relation bounds the pending set without a round-trip per file.
            await session.flush()

        # Update target dataset statistics
        target_ds.file_count = (target_ds.file_count or 0) + added_count  # type: ignore
        target_ds.size_bytes = (target_ds.size_bytes or 0) + added_size  # type: ignore
        # If target dataset has files, mark it ACTIVE
        if (target_ds.file_count or 0) > 0:  # type: ignore
            target_ds.status = "ACTIVE"

        instance.status = "SUCCESS"
        logger.info(f"Dataset ratio execution completed: instance={instance_id}, files={added_count}, size={added_size}")

    @staticmethod
    async def _load_candidates(
        session: AsyncSession,
        ratio_method: str,
        rel: RatioRelation,
        target_dataset_id: str,
        existing_paths: set[str],
    ) -> list[tuple[str, DatasetFiles, bool]]:
        """Return (target_path, source_file, needs_copy) for each file of rel that can still be added."""
        files_res = await session.execute(
            select(DatasetFiles).where(
                DatasetFiles.dataset_id == rel.source_dataset_id,
                DatasetFiles.status == "ACTIVE",
            )
        )
        files = list(files_res.scalars().all())

        # TAG mode: keep only files carrying every tag listed in filter_conditions.
        if ratio_method == "TAG":
            required_tags = RatioTaskService._parse_required_tags(rel.filter_conditions)
            if required_tags:
                files = [f for f in files if RatioTaskService._file_contains_tags(f, required_tags)]

        candidates: list[tuple[str, DatasetFiles, bool]] = []
        seen: set[str] = set()
        for f in files:
            new_path, needs_copy = RatioTaskService._target_path(
                f.file_path, rel.source_dataset_id, target_dataset_id
            )
            # Skip path-less rows, paths the target already has, and source rows
            # that map onto the same target path as an earlier one.
            if not new_path or new_path in existing_paths or new_path in seen:
                continue
            seen.add(new_path)
            candidates.append((new_path, f, needs_copy))
        return candidates

    @staticmethod
    async def _mark_failed(session: AsyncSession, instance_id: str) -> None:
        """Roll back the failed transaction and persist status FAILED for the instance (best effort)."""
        try:
            # A session whose flush/commit failed must be rolled back before reuse;
            # this also discards any partially added rows.
            await session.rollback()
            inst_res = await session.execute(select(RatioInstance).where(RatioInstance.id == instance_id))
            instance = inst_res.scalar_one_or_none()
            if instance:
                instance.status = "FAILED"
                await session.commit()
        except Exception:
            logger.exception(f"Failed to mark ratio instance {instance_id} as FAILED")

    @staticmethod
    def _target_path(
        src_path: Optional[str], source_dataset_id: str, target_dataset_id: str
    ) -> tuple[Optional[str], bool]:
        """Map a source file path into the target dataset.

        Returns (new_path, needs_copy). Paths under "/dataset/<source_id>/" are
        rewritten to "/dataset/<target_id>/" and must be physically copied; any
        other path is referenced as-is without copying.
        """
        src_prefix = f"/dataset/{source_dataset_id}"
        # Require a "/" boundary so "/dataset/ab" does not match "/dataset/abc/...".
        if isinstance(src_path, str) and src_path.startswith(src_prefix + "/"):
            return f"/dataset/{target_dataset_id}" + src_path[len(src_prefix):], True
        return src_path, False

    @staticmethod
    def _clone_file_record(f: DatasetFiles, target_dataset_id: str, new_path: Optional[str]) -> DatasetFiles:
        """Build a new ACTIVE DatasetFiles row in target_dataset_id carrying f's metadata."""
        return DatasetFiles(
            dataset_id=target_dataset_id,  # type: ignore
            file_name=f.file_name,
            file_path=new_path,
            file_type=f.file_type,
            file_size=f.file_size,
            check_sum=f.check_sum,
            tags=f.tags,
            dataset_filemetadata=f.dataset_filemetadata,
            status="ACTIVE",
        )

    @staticmethod
    def _copy_file(src_path: str, dst_path: str) -> bool:
        """Copy src_path to dst_path, creating parent dirs (blocking; run via asyncio.to_thread).

        Returns True when dst_path did not exist before the copy. If the copy
        fails, a newly created partial destination is removed before re-raising.
        """
        existed = os.path.exists(dst_path)
        os.makedirs(os.path.dirname(dst_path), exist_ok=True)
        try:
            shutil.copy2(src_path, dst_path)
        except Exception:
            if not existed:
                RatioTaskService._remove_files([dst_path])
            raise
        return not existed

    @staticmethod
    def _remove_files(paths: List[str]) -> None:
        """Best-effort deletion of files created by a failed run (blocking; run via asyncio.to_thread)."""
        for path in paths:
            try:
                os.remove(path)
            except FileNotFoundError:
                pass
            except OSError as e:
                logger.warning(f"Failed to remove file {path} during ratio task cleanup: {e}")

    # ------------------------- helpers for TAG filtering ------------------------- #

    @staticmethod
    def _parse_required_tags(conditions: Optional[str]) -> set[str]:
        """Parse filter_conditions into a set of required tag strings.

        Separators: comma, semicolon (ASCII or full-width) and any whitespace,
        including newlines and tabs. Empty/None -> empty set.
        """
        if not conditions:
            return set()
        return set(conditions.translate(RatioTaskService._TAG_SEPARATORS).split())

    @staticmethod
    def _extract_tag_names(tags: Any) -> set[str]:
        """Collect tag names from a list of str / dicts, or from a dict's keys; anything else -> empty set."""
        if isinstance(tags, dict):
            # Flat dict of name -> ...: treat keys as tag names.
            return {str(k) for k in tags}
        names: set[str] = set()
        if isinstance(tags, list):
            for item in tags:
                if isinstance(item, str):
                    names.add(item)
                elif isinstance(item, dict):
                    name = item.get("name") or item.get("label") or item.get("tag")
                    if isinstance(name, str):
                        names.add(name)
        return names

    @staticmethod
    def _file_contains_tags(f: DatasetFiles, required: set[str]) -> bool:
        """Return True if f.tags contains every tag in `required` (always True when `required` is empty)."""
        if not required:
            return True
        return required.issubset(RatioTaskService._extract_tag_names(f.tags))

    @staticmethod
    async def get_new_file(f, rel: RatioRelation, target_ds: Dataset) -> DatasetFiles:
        """Copy f into target_ds when it lives under the source dataset dir; return a new, unsaved row.

        Performs no de-duplication and does not add the row to any session.
        """
        new_path, needs_copy = RatioTaskService._target_path(f.file_path, rel.source_dataset_id, target_ds.id)
        if needs_copy:
            # Directory creation and copy run in a thread to avoid blocking the event loop.
            await asyncio.to_thread(RatioTaskService._copy_file, f.file_path, new_path)
        return RatioTaskService._clone_file_record(f, target_ds.id, new_path)
|