Fix ratio (#162)

* fix: fixed the issue where an error would be reported when only setting the proportioning quantity when creating a proportioning task

* fix: prevent adding the same file multiple times

* fix: implement a more flexible matching strategy, allowing only the tag name to be configured for matching
This commit is contained in:
hefanli
2025-12-11 17:45:16 +08:00
committed by GitHub
parent bb8641bea2
commit 8f529952f6
2 changed files with 21 additions and 10 deletions

View File

@@ -46,7 +46,7 @@ class DatasetFileTag(BaseModel):
tags.append(tag_values)
# 如果 from_name 不为空,添加前缀
if self.from_name:
tags = [f"{self.from_name}@{tag}" for tag in tags]
tags = [{"label": self.from_name, "value": tag} for tag in tags]
return tags

View File

@@ -61,7 +61,7 @@ class RatioTaskService:
'label': {
"label":item.get("filter_conditions").label.label,
"value":item.get("filter_conditions").label.value,
},
} if item.get("filter_conditions").label else None,
})
)
logger.info(f"Relation created: {relation.id}, {relation}, {item}, {config}")
@@ -147,6 +147,7 @@ class RatioTaskService:
select(DatasetFiles.file_path).where(DatasetFiles.dataset_id == target_ds.id)
)
existing_paths = set(p for p in existing_path_rows.scalars().all() if p)
source_paths = set()
added_count = 0
added_size = 0
@@ -164,10 +165,13 @@ class RatioTaskService:
chosen = random.sample(files, pick_n) if pick_n < len(files) else files
# Copy into target dataset with de-dup by target path
for f in chosen:
await RatioTaskService.handle_selected_file(existing_paths, f, session, target_ds)
for file in chosen:
if file.file_path in source_paths:
continue
await RatioTaskService.handle_selected_file(existing_paths, file, session, target_ds)
source_paths.add(file.file_path)
added_count += 1
added_size += int(f.file_size or 0)
added_size += int(file.file_size or 0)
# Periodically flush to avoid huge transactions
await session.flush()
@@ -286,8 +290,15 @@ class RatioTaskService:
return False
try:
# tags could be a list of strings or list of objects with 'name'
tag_names = RatioTaskService.get_all_tags(tags)
return f"{conditions.label.label}@{conditions.label.value}" in tag_names
all_tags = RatioTaskService.get_all_tags(tags)
for tag in all_tags:
if conditions.label.label and tag.get("label") != conditions.label.label:
continue
if conditions.label.value is not None:
return True
if tag.get("value") == conditions.label.value:
return True
return False
except Exception as e:
logger.exception(f"Failed to get tags for {file}", e)
return False
@@ -295,9 +306,9 @@ class RatioTaskService:
return True
@staticmethod
def get_all_tags(tags) -> set[str]:
def get_all_tags(tags) -> list[dict]:
"""获取所有处理后的标签字符串列表"""
all_tags = set()
all_tags = list()
if not tags:
return all_tags
@@ -314,5 +325,5 @@ class RatioTaskService:
for file_tag in file_tags:
for tag_data in file_tag.get_tags():
all_tags.add(tag_data)
all_tags.append(tag_data)
return all_tags