Fix ratio (#162)

* fix: fixed the issue where an error would be reported when only setting the proportioning quantity when creating a proportioning task

* fix: prevent adding the same file multiple times

* fix: implement a more flexible matching strategy, allowing only the tag name to be configured for matching
This commit is contained in:
hefanli
2025-12-11 17:45:16 +08:00
committed by GitHub
parent bb8641bea2
commit 8f529952f6
2 changed files with 21 additions and 10 deletions

View File

@@ -46,7 +46,7 @@ class DatasetFileTag(BaseModel):
tags.append(tag_values) tags.append(tag_values)
# 如果 from_name 不为空,添加前缀 # 如果 from_name 不为空,添加前缀
if self.from_name: if self.from_name:
tags = [f"{self.from_name}@{tag}" for tag in tags] tags = [{"label": self.from_name, "value": tag} for tag in tags]
return tags return tags

View File

@@ -61,7 +61,7 @@ class RatioTaskService:
'label': { 'label': {
"label":item.get("filter_conditions").label.label, "label":item.get("filter_conditions").label.label,
"value":item.get("filter_conditions").label.value, "value":item.get("filter_conditions").label.value,
}, } if item.get("filter_conditions").label else None,
}) })
) )
logger.info(f"Relation created: {relation.id}, {relation}, {item}, {config}") logger.info(f"Relation created: {relation.id}, {relation}, {item}, {config}")
@@ -147,6 +147,7 @@ class RatioTaskService:
select(DatasetFiles.file_path).where(DatasetFiles.dataset_id == target_ds.id) select(DatasetFiles.file_path).where(DatasetFiles.dataset_id == target_ds.id)
) )
existing_paths = set(p for p in existing_path_rows.scalars().all() if p) existing_paths = set(p for p in existing_path_rows.scalars().all() if p)
source_paths = set()
added_count = 0 added_count = 0
added_size = 0 added_size = 0
@@ -164,10 +165,13 @@ class RatioTaskService:
chosen = random.sample(files, pick_n) if pick_n < len(files) else files chosen = random.sample(files, pick_n) if pick_n < len(files) else files
# Copy into target dataset with de-dup by target path # Copy into target dataset with de-dup by target path
for f in chosen: for file in chosen:
await RatioTaskService.handle_selected_file(existing_paths, f, session, target_ds) if file.file_path in source_paths:
continue
await RatioTaskService.handle_selected_file(existing_paths, file, session, target_ds)
source_paths.add(file.file_path)
added_count += 1 added_count += 1
added_size += int(f.file_size or 0) added_size += int(file.file_size or 0)
# Periodically flush to avoid huge transactions # Periodically flush to avoid huge transactions
await session.flush() await session.flush()
@@ -286,8 +290,15 @@ class RatioTaskService:
return False return False
try: try:
# tags could be a list of strings or list of objects with 'name' # tags could be a list of strings or list of objects with 'name'
tag_names = RatioTaskService.get_all_tags(tags) all_tags = RatioTaskService.get_all_tags(tags)
return f"{conditions.label.label}@{conditions.label.value}" in tag_names for tag in all_tags:
if conditions.label.label and tag.get("label") != conditions.label.label:
continue
if conditions.label.value is not None:
return True
if tag.get("value") == conditions.label.value:
return True
return False
except Exception as e: except Exception as e:
logger.exception(f"Failed to get tags for {file}", e) logger.exception(f"Failed to get tags for {file}", e)
return False return False
@@ -295,9 +306,9 @@ class RatioTaskService:
return True return True
@staticmethod @staticmethod
def get_all_tags(tags) -> set[str]: def get_all_tags(tags) -> list[dict]:
"""获取所有处理后的标签字符串列表""" """获取所有处理后的标签字符串列表"""
all_tags = set() all_tags = list()
if not tags: if not tags:
return all_tags return all_tags
@@ -314,5 +325,5 @@ class RatioTaskService:
for file_tag in file_tags: for file_tag in file_tags:
for tag_data in file_tag.get_tags(): for tag_data in file_tag.get_tags():
all_tags.add(tag_data) all_tags.append(tag_data)
return all_tags return all_tags