DataMate/runtime/datamate-python/app/module/ratio/service/ratio_task.py
hefanli 1d19cd3a62 feature: add data-evaluation
* feature: add evaluation task management function

* feature: add evaluation task detail page

* fix: delete duplicate definition for table t_model_config

* refactor: rename package synthesis to ratio

* refactor: add eval file table and  refactor related code

* fix: calling large models in parallel during evaluation
2025-12-04 09:23:54 +08:00


from datetime import datetime
from typing import List, Optional, Dict, Any
import random
import json
import os
import shutil
import asyncio
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.logging import get_logger
from app.db.models.ratio_task import RatioInstance, RatioRelation
from app.db.models import Dataset, DatasetFiles
from app.db.session import AsyncSessionLocal
from app.module.dataset.schema.dataset_file import DatasetFileTag
from app.module.shared.schema import TaskStatus
from app.module.ratio.schema.ratio_task import FilterCondition
logger = get_logger(__name__)
class RatioTaskService:
"""Service for Ratio Task DB operations."""
def __init__(self, db: AsyncSession):
self.db = db
async def create_task(
self,
*,
name: str,
description: Optional[str],
totals: int,
config: List[Dict[str, Any]],
target_dataset_id: Optional[str] = None,
) -> RatioInstance:
"""Create a ratio task instance and its relations.
config item format: {"dataset_id": str, "counts": int, "filter_conditions": Optional[FilterCondition]}
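Example (illustrative only; the dataset id and label values are hypothetical):
    config = [{
        "dataset_id": "ds-001",
        "counts": 100,
        "filter_conditions": FilterCondition(
            date_range="30",
            label={"label": "domain", "value": "finance"},
        ),
    }]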
"""
logger.info(f"Creating ratio task: name={name}, totals={totals}, items={len(config or [])}")
instance = RatioInstance(
name=name,
description=description,
totals=totals,
target_dataset_id=target_dataset_id,
status="PENDING",
)
self.db.add(instance)
await self.db.flush() # populate instance.id
for item in config or []:
fc = item.get("filter_conditions")
relation = RatioRelation(
ratio_instance_id=instance.id,
source_dataset_id=item.get("dataset_id"),
counts=int(item.get("counts", 0)),
filter_conditions=json.dumps({
'date_range': fc.date_range,
'label': {
"label": fc.label.label,
"value": fc.label.value,
} if fc.label else None,
}) if fc else None,
)
logger.info(f"Relation prepared: instance={instance.id}, dataset={relation.source_dataset_id}, counts={relation.counts}")
self.db.add(relation)
await self.db.commit()
await self.db.refresh(instance)
logger.info(f"Ratio task created: {instance.id}")
return instance
# ========================= Execution (Background) ========================= #
@staticmethod
async def execute_dataset_ratio_task(instance_id: str) -> None:
"""Execute a ratio task in background.
Supported ratio_method:
- DATASET: randomly select counts files from each source dataset
- TAG: randomly select counts files matching relation.filter_conditions tags
Steps:
- Mark instance RUNNING
- For each relation: fetch ACTIVE files, optionally filter by tags
- Copy selected files into target dataset
- Update dataset statistics and mark instance SUCCESS/FAILED
"""
async with AsyncSessionLocal() as session: # type: AsyncSession
try:
# Load instance and relations
inst_res = await session.execute(select(RatioInstance).where(RatioInstance.id == instance_id))
instance: Optional[RatioInstance] = inst_res.scalar_one_or_none()
if not instance:
logger.error(f"Ratio instance not found: {instance_id}")
return
logger.info(f"start execute ratio task: {instance_id}")
rel_res = await session.execute(
select(RatioRelation).where(RatioRelation.ratio_instance_id == instance_id)
)
relations: List[RatioRelation] = list(rel_res.scalars().all())
# Mark running
instance.status = TaskStatus.RUNNING.name
# Load target dataset
ds_res = await session.execute(select(Dataset).where(Dataset.id == instance.target_dataset_id))
target_ds: Optional[Dataset] = ds_res.scalar_one_or_none()
if not target_ds:
logger.error(f"Target dataset not found for instance {instance_id}")
instance.status = TaskStatus.FAILED.name
return
added_count, added_size = await RatioTaskService.handle_ratio_relations(relations, session, target_ds)
# Update target dataset statistics
target_ds.file_count = (target_ds.file_count or 0) + added_count # type: ignore
target_ds.size_bytes = (target_ds.size_bytes or 0) + added_size # type: ignore
# If target dataset has files, mark it ACTIVE
if (target_ds.file_count or 0) > 0: # type: ignore
target_ds.status = "ACTIVE"
# Done
instance.status = TaskStatus.COMPLETED.name
logger.info(f"Dataset ratio execution completed: instance={instance_id}, files={added_count}, size={added_size}, {instance.status}")
except Exception as e:
logger.exception(f"Dataset ratio execution failed for {instance_id}: {e}")
try:
# Best-effort: re-load the instance and mark it FAILED
inst_res = await session.execute(select(RatioInstance).where(RatioInstance.id == instance_id))
instance = inst_res.scalar_one_or_none()
if instance:
instance.status = TaskStatus.FAILED.name
except Exception:
logger.exception(f"Failed to mark ratio instance {instance_id} as FAILED")
finally:
await session.commit()
@staticmethod
async def handle_ratio_relations(relations: list[RatioRelation], session, target_ds: Dataset) -> tuple[int, int]:
# Preload existing target file paths for deduplication
existing_path_rows = await session.execute(
select(DatasetFiles.file_path).where(DatasetFiles.dataset_id == target_ds.id)
)
existing_paths = set(p for p in existing_path_rows.scalars().all() if p)
added_count = 0
added_size = 0
for rel in relations:
if not rel.source_dataset_id or not rel.counts or rel.counts <= 0:
continue
files = await RatioTaskService.get_files(rel, session)
if not files:
continue
pick_n = min(rel.counts or 0, len(files))
chosen = random.sample(files, pick_n) if pick_n < len(files) else files
# Copy into target dataset with de-dup by target path
for f in chosen:
await RatioTaskService.handle_selected_file(existing_paths, f, session, target_ds)
added_count += 1
added_size += int(f.file_size or 0)
# Periodically flush to avoid huge transactions
await session.flush()
return added_count, added_size
@staticmethod
async def handle_selected_file(existing_paths: set[Any], f, session, target_ds: Dataset):
src_path = f.file_path
dst_prefix = f"/dataset/{target_ds.id}/"
file_name = RatioTaskService.get_new_file_name(dst_prefix, existing_paths, f)
new_path = dst_prefix + file_name
dst_dir = os.path.dirname(new_path)
await asyncio.to_thread(os.makedirs, dst_dir, exist_ok=True)
await asyncio.to_thread(shutil.copy2, src_path, new_path)
file_data = {
"dataset_id": target_ds.id, # type: ignore
"file_name": file_name,
"file_path": new_path,
"file_type": f.file_type,
"file_size": f.file_size,
"check_sum": f.check_sum,
"tags": f.tags,
"tags_updated_at": datetime.now(),
"dataset_filemetadata": f.dataset_filemetadata,
"status": "ACTIVE",
}
file_record = {k: v for k, v in file_data.items() if v is not None}
session.add(DatasetFiles(**file_record))
existing_paths.add(new_path)
@staticmethod
def get_new_file_name(dst_prefix: str, existing_paths: set[Any], f) -> str:
file_name = f.file_name
new_path = dst_prefix + file_name
# Handle file path conflicts by appending a number to the filename
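# e.g. an existing "report.csv" in the target dataset becomes "report_1.csv",
# then "report_2.csv", and so on (illustrative file name)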
if new_path in existing_paths:
file_name_base, file_ext = os.path.splitext(file_name)
counter = 1
original_file_name = file_name
while new_path in existing_paths:
file_name = f"{file_name_base}_{counter}{file_ext}"
new_path = f"{dst_prefix}{file_name}"
counter += 1
if counter > 1000: # Safety check to prevent infinite loops
logger.error(f"Could not find unique filename for {original_file_name} after 1000 attempts")
break
return file_name
@staticmethod
async def get_files(rel: RatioRelation, session) -> list[Any]:
# Fetch all files for the source dataset (ACTIVE only)
files_res = await session.execute(
select(DatasetFiles).where(
DatasetFiles.dataset_id == rel.source_dataset_id,
DatasetFiles.status == "ACTIVE",
)
)
files = list(files_res.scalars().all())
# TAG mode: filter by tags according to relation.filter_conditions
conditions = RatioTaskService._parse_conditions(rel.filter_conditions)
if conditions:
files = [f for f in files if RatioTaskService._filter_file(f, conditions)]
return files
# ------------------------- helpers for TAG filtering ------------------------- #
@staticmethod
def _parse_conditions(conditions: Optional[str]) -> Optional[FilterCondition]:
"""Parse filter_conditions JSON string into a FilterCondition object.
Args:
conditions: JSON string containing filter conditions
Returns:
FilterCondition object if conditions is not None/empty, otherwise None
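Example (illustrative values; mirrors the JSON written by create_task):
    '{"date_range": "30", "label": {"label": "domain", "value": "finance"}}'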
"""
if not conditions:
return None
try:
data = json.loads(conditions)
return FilterCondition(**data)
except json.JSONDecodeError as e:
logger.error(f"Failed to parse filter conditions: {e}")
return None
except Exception as e:
logger.error(f"Error creating FilterCondition: {e}")
return None
@staticmethod
def _filter_file(file: DatasetFiles, conditions: FilterCondition) -> bool:
if not conditions:
return True
logger.info(f"start filter file: {file}, conditions: {conditions}")
# Check date range condition if provided
if conditions.date_range:
try:
from datetime import datetime, timedelta
data_range_days = int(conditions.date_range)
if data_range_days > 0:
cutoff_date = datetime.now() - timedelta(days=data_range_days)
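# e.g. date_range="30" excludes files whose tags were last updated more than 30 days ago (illustrative value)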
if file.tags_updated_at and file.tags_updated_at < cutoff_date:
return False
except (ValueError, TypeError) as e:
logger.warning(f"Invalid date_range value: {conditions.date_range}: {e}")
return False
# Check label condition if provided
if conditions.label:
tags = file.tags
if not tags:
return False
try:
# tags is expected to be a list of dicts that map onto DatasetFileTag; matching is done on "label@value" strings
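# e.g. a label condition with label="domain", value="finance" matches a file whose
# processed tags include the string "domain@finance" (illustrative values)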
tag_names = RatioTaskService.get_all_tags(tags)
return f"{conditions.label.label}@{conditions.label.value}" in tag_names
except Exception as e:
logger.exception(f"Failed to get tags for {file}", e)
return False
return True
@staticmethod
def get_all_tags(tags) -> set[str]:
"""获取所有处理后的标签字符串列表"""
all_tags = set()
if not tags:
return all_tags
file_tags = []
for tag_data in tags:
# 处理可能的命名风格转换(下划线转驼峰)
processed_data = {}
for key, value in tag_data.items():
# 将驼峰转为下划线以匹配 Pydantic 模型字段
processed_data[key] = value
# 创建 DatasetFileTag 对象
file_tag = DatasetFileTag(**processed_data)
file_tags.append(file_tag)
for file_tag in file_tags:
for tag_data in file_tag.get_tags():
all_tags.add(tag_data)
return all_tags