feature: data ratio supports configuration by update time (#95)
* feature: data ratio supports configuration by update time
* fix: fix passing of the ratio time parameter
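In practice, a ratio config item can now carry a time window next to the existing label filter. A minimal sketch of what one config item might look like after this commit (ids and values are hypothetical; the real request schema is defined in app.module.synthesis.schema.ratio_task):

config = [
    {
        "dataset_id": "ds-001",  # hypothetical source dataset id
        "counts": 100,           # number of files to sample from it
        "filter_conditions": FilterCondition(
            date_range="7",      # keep only files whose tags were updated in the last 7 days
            label="finance",     # and that carry this tag
        ),
    },
]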
@@ -14,6 +14,8 @@ from app.db.models.ratio_task import RatioInstance, RatioRelation
 from app.db.models import Dataset, DatasetFiles
 from app.db.session import AsyncSessionLocal
 from app.module.dataset.schema.dataset_file import DatasetFileTag
+from app.module.shared.schema import TaskStatus
+from app.module.synthesis.schema.ratio_task import FilterCondition

 logger = get_logger(__name__)

@@ -30,7 +32,6 @@ class RatioTaskService:
         name: str,
         description: Optional[str],
         totals: int,
-        ratio_method: str,
         config: List[Dict[str, Any]],
         target_dataset_id: Optional[str] = None,
     ) -> RatioInstance:
@@ -38,12 +39,11 @@ class RatioTaskService:

        config item format: {"dataset_id": str, "counts": int, "filter_conditions": str}
        """
-        logger.info(f"Creating ratio task: name={name}, method={ratio_method}, totals={totals}, items={len(config or [])}")
+        logger.info(f"Creating ratio task: name={name}, totals={totals}, items={len(config or [])}")

         instance = RatioInstance(
             name=name,
             description=description,
-            ratio_method=ratio_method,
             totals=totals,
             target_dataset_id=target_dataset_id,
             status="PENDING",
@@ -56,8 +56,12 @@ class RatioTaskService:
                 ratio_instance_id=instance.id,
                 source_dataset_id=item.get("dataset_id"),
                 counts=int(item.get("counts", 0)),
-                filter_conditions=item.get("filter_conditions"),
+                filter_conditions=json.dumps({
+                    'date_range': item.get("filter_conditions").date_range,
+                    'label': item.get("filter_conditions").label,
+                })
             )
+            logger.info(f"Relation created: {relation.id}, {relation}, {item}, {config}")
             self.db.add(relation)

         await self.db.commit()
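The relation row persists filter_conditions as a JSON string, and _parse_conditions (further down in this diff) rebuilds a FilterCondition from it. A minimal round-trip sketch, assuming FilterCondition is a Pydantic-style model with optional date_range and label fields (the stand-in below is not the real class from app.module.synthesis.schema.ratio_task):

import json
from typing import Optional
from pydantic import BaseModel

# Stand-in for app.module.synthesis.schema.ratio_task.FilterCondition
class FilterCondition(BaseModel):
    date_range: Optional[str] = None
    label: Optional[str] = None

fc = FilterCondition(date_range="7", label="finance")

# Serialize the way create_ratio_task now does ...
stored = json.dumps({"date_range": fc.date_range, "label": fc.label})

# ... and rebuild the way _parse_conditions does
restored = FilterCondition(**json.loads(stored))
assert restored == fc

Note that item.get("filter_conditions").date_range assumes every config item carries a filter_conditions object; an item without one would raise AttributeError here.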
@@ -97,94 +101,17 @@ class RatioTaskService:
                 relations: List[RatioRelation] = list(rel_res.scalars().all())

                 # Mark running
-                instance.status = "RUNNING"
-
-                if instance.ratio_method not in {"DATASET", "TAG"}:
-                    logger.info(f"Instance {instance_id} ratio_method={instance.ratio_method} not supported yet")
-                    instance.status = "SUCCESS"
-                    return
+                instance.status = TaskStatus.RUNNING.name

                 # Load target dataset
                 ds_res = await session.execute(select(Dataset).where(Dataset.id == instance.target_dataset_id))
                 target_ds: Optional[Dataset] = ds_res.scalar_one_or_none()
                 if not target_ds:
                     logger.error(f"Target dataset not found for instance {instance_id}")
-                    instance.status = "FAILED"
+                    instance.status = TaskStatus.FAILED.name
                     return

-                # Preload existing target file paths for deduplication
-                existing_path_rows = await session.execute(
-                    select(DatasetFiles.file_path).where(DatasetFiles.dataset_id == target_ds.id)
-                )
-                existing_paths = set(p for p in existing_path_rows.scalars().all() if p)
-
-                added_count = 0
-                added_size = 0
-
-                for rel in relations:
-                    if not rel.source_dataset_id or not rel.counts or rel.counts <= 0:
-                        continue
-
-                    # Fetch all files for the source dataset (ACTIVE only)
-                    files_res = await session.execute(
-                        select(DatasetFiles).where(
-                            DatasetFiles.dataset_id == rel.source_dataset_id,
-                            DatasetFiles.status == "ACTIVE",
-                        )
-                    )
-                    files = list(files_res.scalars().all())
-
-                    # TAG mode: filter by tags according to relation.filter_conditions
-                    if instance.ratio_method == "TAG":
-                        required_tags = RatioTaskService._parse_required_tags(rel.filter_conditions)
-                        if required_tags:
-                            files = [f for f in files if RatioTaskService._file_contains_tags(f, required_tags)]
-
-                    if not files:
-                        continue
-
-                    pick_n = min(rel.counts or 0, len(files))
-                    chosen = random.sample(files, pick_n) if pick_n < len(files) else files
-
-                    # Copy into target dataset with de-dup by target path
-                    for f in chosen:
-                        src_path = f.file_path
-                        new_path = src_path
-                        needs_copy = False
-                        src_prefix = f"/dataset/{rel.source_dataset_id}"
-                        if isinstance(src_path, str) and src_path.startswith(src_prefix):
-                            dst_prefix = f"/dataset/{target_ds.id}"
-                            new_path = src_path.replace(src_prefix, dst_prefix, 1)
-                            needs_copy = True
-
-                        # De-dup by target path
-                        if new_path in existing_paths:
-                            continue
-
-                        # Perform copy only when needed
-                        if needs_copy:
-                            dst_dir = os.path.dirname(new_path)
-                            await asyncio.to_thread(os.makedirs, dst_dir, exist_ok=True)
-                            await asyncio.to_thread(shutil.copy2, src_path, new_path)
-
-                        new_file = DatasetFiles(
-                            dataset_id=target_ds.id,  # type: ignore
-                            file_name=f.file_name,
-                            file_path=new_path,
-                            file_type=f.file_type,
-                            file_size=f.file_size,
-                            check_sum=f.check_sum,
-                            tags=f.tags,
-                            dataset_filemetadata=f.dataset_filemetadata,
-                            status="ACTIVE",
-                        )
-                        session.add(new_file)
-                        existing_paths.add(new_path)
-                        added_count += 1
-                        added_size += int(f.file_size or 0)
-
-                    # Periodically flush to avoid huge transactions
-                    await session.flush()
+                added_count, added_size = await RatioTaskService.handle_ratio_relations(relations, session, target_ds)

                 # Update target dataset statistics
                 target_ds.file_count = (target_ds.file_count or 0) + added_count  # type: ignore
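The sampling rule itself is unchanged by this refactor: each relation contributes at most counts files, drawn uniformly without replacement. A standalone sketch of that pick logic:

import random

def pick(files: list, counts: int) -> list:
    # Sample without replacement; if fewer files exist than requested, take them all
    pick_n = min(counts, len(files))
    return random.sample(files, pick_n) if pick_n < len(files) else files

assert len(pick(list(range(10)), 3)) == 3
assert pick(list(range(2)), 5) == [0, 1]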
@@ -194,8 +121,8 @@ class RatioTaskService:
                 target_ds.status = "ACTIVE"

                 # Done
-                instance.status = "SUCCESS"
-                logger.info(f"Dataset ratio execution completed: instance={instance_id}, files={added_count}, size={added_size}")
+                instance.status = TaskStatus.COMPLETED.name
+                logger.info(f"Dataset ratio execution completed: instance={instance_id}, files={added_count}, size={added_size}, {instance.status}")

             except Exception as e:
                 logger.exception(f"Dataset ratio execution failed for {instance_id}: {e}")
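Raw status strings ("RUNNING", "SUCCESS", "FAILED") give way to names from the shared TaskStatus enum, and the terminal success state is now COMPLETED rather than SUCCESS. A stand-in sketch of the pattern (the real enum lives in app.module.shared.schema; only the members used in this diff are assumed):

from enum import Enum

# Stand-in; the real members come from app.module.shared.schema
class TaskStatus(Enum):
    RUNNING = "RUNNING"
    COMPLETED = "COMPLETED"
    FAILED = "FAILED"

# .name yields the stable string stored on the instance row
assert TaskStatus.COMPLETED.name == "COMPLETED"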
@@ -204,42 +131,163 @@ class RatioTaskService:
                 inst_res = await session.execute(select(RatioInstance).where(RatioInstance.id == instance_id))
                 instance = inst_res.scalar_one_or_none()
                 if instance:
-                    instance.status = "FAILED"
-            finally:
-                pass
+                    instance.status = TaskStatus.FAILED.name
+            finally:
+                await session.commit()
+
+    @staticmethod
+    async def handle_ratio_relations(relations: list[RatioRelation], session, target_ds: Dataset) -> tuple[int, int]:
+        # Preload existing target file paths for deduplication
+        existing_path_rows = await session.execute(
+            select(DatasetFiles.file_path).where(DatasetFiles.dataset_id == target_ds.id)
+        )
+        existing_paths = set(p for p in existing_path_rows.scalars().all() if p)
+
+        added_count = 0
+        added_size = 0
+
+        for rel in relations:
+            if not rel.source_dataset_id or not rel.counts or rel.counts <= 0:
+                continue
+
+            files = await RatioTaskService.get_files(rel, session)
+
+            if not files:
+                continue
+
+            pick_n = min(rel.counts or 0, len(files))
+            chosen = random.sample(files, pick_n) if pick_n < len(files) else files
+
+            # Copy into target dataset with de-dup by target path
+            for f in chosen:
+                await RatioTaskService.handle_selected_file(existing_paths, f, session, target_ds)
+                added_count += 1
+                added_size += int(f.file_size or 0)
+
+            # Periodically flush to avoid huge transactions
+            await session.flush()
+        return added_count, added_size
+
+    @staticmethod
+    async def handle_selected_file(existing_paths: set[Any], f, session, target_ds: Dataset):
+        src_path = f.file_path
+        dst_prefix = f"/dataset/{target_ds.id}"
+        file_name = RatioTaskService.get_new_file_name(dst_prefix, existing_paths, f)
+
+        new_path = dst_prefix + file_name
+        dst_dir = os.path.dirname(new_path)
+        await asyncio.to_thread(os.makedirs, dst_dir, exist_ok=True)
+        await asyncio.to_thread(shutil.copy2, src_path, new_path)
+
+        new_file = DatasetFiles(
+            dataset_id=target_ds.id,  # type: ignore
+            file_name=file_name,
+            file_path=new_path,
+            file_type=f.file_type,
+            file_size=f.file_size,
+            check_sum=f.check_sum,
+            tags=f.tags,
+            dataset_filemetadata=f.dataset_filemetadata,
+            status="ACTIVE",
+        )
+        session.add(new_file)
+        existing_paths.add(new_path)
+
+    @staticmethod
+    def get_new_file_name(dst_prefix: str, existing_paths: set[Any], f) -> str:
+        file_name = f.file_name
+        new_path = dst_prefix + file_name
+
+        # Handle file path conflicts by appending a number to the filename
+        if new_path in existing_paths:
+            file_name_base, file_ext = os.path.splitext(file_name)
+            counter = 1
+            original_file_name = file_name
+            while new_path in existing_paths:
+                file_name = f"{file_name_base}_{counter}{file_ext}"
+                new_path = f"{dst_prefix}{file_name}"
+                counter += 1
+                if counter > 1000:  # Safety check to prevent infinite loops
+                    logger.error(f"Could not find unique filename for {original_file_name} after 1000 attempts")
+                    break
+        return file_name
+
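The renaming scheme can be exercised in isolation. A quick sketch mirroring get_new_file_name against a set of already-taken target paths (paths are hypothetical; the 1000-attempt safety cap is omitted for brevity):

import os

def next_free_name(dst_prefix: str, existing_paths: set, file_name: str) -> str:
    # Mirrors get_new_file_name: append _1, _2, ... before the extension
    new_path = dst_prefix + file_name
    base, ext = os.path.splitext(file_name)
    counter = 1
    while new_path in existing_paths:
        file_name = f"{base}_{counter}{ext}"
        new_path = f"{dst_prefix}{file_name}"
        counter += 1
    return file_name

taken = {"/dataset/42/a.txt", "/dataset/42/a_1.txt"}
assert next_free_name("/dataset/42/", taken, "a.txt") == "a_2.txt"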
+    @staticmethod
+    async def get_files(rel: RatioRelation, session) -> list[Any]:
+        # Fetch all files for the source dataset (ACTIVE only)
+        files_res = await session.execute(
+            select(DatasetFiles).where(
+                DatasetFiles.dataset_id == rel.source_dataset_id,
+                DatasetFiles.status == "ACTIVE",
+            )
+        )
+        files = list(files_res.scalars().all())
+
+        # Filter by relation.filter_conditions (tags and/or update time)
+        conditions = RatioTaskService._parse_conditions(rel.filter_conditions)
+        if conditions:
+            files = [f for f in files if RatioTaskService._filter_file(f, conditions)]
+        return files
+
     # ------------------------- helpers for TAG filtering ------------------------- #

     @staticmethod
-    def _parse_required_tags(conditions: Optional[str]) -> set[str]:
-        """Parse filter_conditions into a set of required tag strings.
+    def _parse_conditions(conditions: Optional[str]) -> Optional[FilterCondition]:
+        """Parse filter_conditions JSON string into a FilterCondition object.

-        Supports simple separators: comma, semicolon, space. Empty/None -> empty set.
+        Args:
+            conditions: JSON string containing filter conditions
+
+        Returns:
+            FilterCondition object if conditions is not None/empty, otherwise None
         """
         if not conditions:
-            return set()
-        data = json.loads(conditions)
-        required_tags = set()
-        if data.get("label"):
-            required_tags.add(data["label"])
-        return required_tags
+            return None
+        try:
+            data = json.loads(conditions)
+            return FilterCondition(**data)
+        except json.JSONDecodeError as e:
+            logger.error(f"Failed to parse filter conditions: {e}")
+            return None
+        except Exception as e:
+            logger.error(f"Error creating FilterCondition: {e}")
+            return None
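_parse_conditions fails closed to None, so a malformed filter string disables filtering for that relation instead of aborting the whole task. A standalone mirror (reusing the FilterCondition stand-in from the earlier sketch):

import json
from typing import Optional

def parse_conditions(conditions: Optional[str]) -> Optional[FilterCondition]:
    # Standalone mirror of _parse_conditions
    if not conditions:
        return None
    try:
        return FilterCondition(**json.loads(conditions))
    except Exception:
        return None

assert parse_conditions(None) is None
assert parse_conditions("not json") is None  # malformed -> None, relation stays unfiltered
assert parse_conditions('{"date_range": "30"}').date_range == "30"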
     @staticmethod
-    def _file_contains_tags(file: DatasetFiles, required: set[str]) -> bool:
-        if not required:
+    def _filter_file(file: DatasetFiles, conditions: FilterCondition) -> bool:
+        if not conditions:
             return True
-        tags = file.tags
-        if not tags:
-            return False
-        try:
-            # tags could be a list of strings or list of objects with 'name'
-            tag_names = RatioTaskService.get_all_tags(tags)
-            return required.issubset(tag_names)
-        except Exception as e:
-            logger.exception(f"Failed to get tags for {file}", e)
-            return False
+        logger.info(f"start filter file: {file}, conditions: {conditions}")
+
+        # Check date range condition if provided
+        if conditions.date_range:
+            try:
+                from datetime import datetime, timedelta
+                date_range_days = int(conditions.date_range)
+                if date_range_days > 0:
+                    cutoff_date = datetime.now() - timedelta(days=date_range_days)
+                    if file.tags_updated_at and file.tags_updated_at < cutoff_date:
+                        return False
+            except (ValueError, TypeError) as e:
+                logger.warning(f"Invalid date_range value: {conditions.date_range}: {e}")
+                return False
+
+        # Check label condition if provided
+        if conditions.label:
+            tags = file.tags
+            if not tags:
+                return False
+            try:
+                # tags could be a list of strings or list of objects with 'name'
+                tag_names = RatioTaskService.get_all_tags(tags)
+                return conditions.label in tag_names
+            except Exception as e:
+                logger.exception(f"Failed to get tags for {file}: {e}")
+                return False
+
+        return True
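The date_range check keeps a file only when its tags_updated_at is no older than N days. A small sketch of the cutoff arithmetic (field and variable names follow the diff; the timestamps are hypothetical, and files with no tags_updated_at pass the check, as in _filter_file):

from datetime import datetime, timedelta
from typing import Optional

def within_date_range(tags_updated_at: Optional[datetime], date_range: str, now: datetime) -> bool:
    # Mirrors _filter_file: files older than the cutoff are filtered out
    days = int(date_range)
    if days <= 0 or tags_updated_at is None:
        return True
    cutoff = now - timedelta(days=days)
    return tags_updated_at >= cutoff

now = datetime(2025, 1, 31)
assert within_date_range(datetime(2025, 1, 30), "7", now) is True
assert within_date_range(datetime(2025, 1, 10), "7", now) is False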
     @staticmethod
     def get_all_tags(tags) -> set[str]: