You've already forked DataMate
Fix ratio (#162)
* fix: fixed the issue where an error would be reported when only setting the proportioning quantity when creating a proportioning task * fix: prevent adding the same file multiple times * fix: implement a more flexible matching strategy, allowing only the tag name to be configured for matching
This commit is contained in:
@@ -46,7 +46,7 @@ class DatasetFileTag(BaseModel):
|
||||
tags.append(tag_values)
|
||||
# 如果 from_name 不为空,添加前缀
|
||||
if self.from_name:
|
||||
tags = [f"{self.from_name}@{tag}" for tag in tags]
|
||||
tags = [{"label": self.from_name, "value": tag} for tag in tags]
|
||||
return tags
|
||||
|
||||
|
||||
|
||||
@@ -61,7 +61,7 @@ class RatioTaskService:
|
||||
'label': {
|
||||
"label":item.get("filter_conditions").label.label,
|
||||
"value":item.get("filter_conditions").label.value,
|
||||
},
|
||||
} if item.get("filter_conditions").label else None,
|
||||
})
|
||||
)
|
||||
logger.info(f"Relation created: {relation.id}, {relation}, {item}, {config}")
|
||||
@@ -147,6 +147,7 @@ class RatioTaskService:
|
||||
select(DatasetFiles.file_path).where(DatasetFiles.dataset_id == target_ds.id)
|
||||
)
|
||||
existing_paths = set(p for p in existing_path_rows.scalars().all() if p)
|
||||
source_paths = set()
|
||||
|
||||
added_count = 0
|
||||
added_size = 0
|
||||
@@ -164,10 +165,13 @@ class RatioTaskService:
|
||||
chosen = random.sample(files, pick_n) if pick_n < len(files) else files
|
||||
|
||||
# Copy into target dataset with de-dup by target path
|
||||
for f in chosen:
|
||||
await RatioTaskService.handle_selected_file(existing_paths, f, session, target_ds)
|
||||
for file in chosen:
|
||||
if file.file_path in source_paths:
|
||||
continue
|
||||
await RatioTaskService.handle_selected_file(existing_paths, file, session, target_ds)
|
||||
source_paths.add(file.file_path)
|
||||
added_count += 1
|
||||
added_size += int(f.file_size or 0)
|
||||
added_size += int(file.file_size or 0)
|
||||
|
||||
# Periodically flush to avoid huge transactions
|
||||
await session.flush()
|
||||
@@ -286,8 +290,15 @@ class RatioTaskService:
|
||||
return False
|
||||
try:
|
||||
# tags could be a list of strings or list of objects with 'name'
|
||||
tag_names = RatioTaskService.get_all_tags(tags)
|
||||
return f"{conditions.label.label}@{conditions.label.value}" in tag_names
|
||||
all_tags = RatioTaskService.get_all_tags(tags)
|
||||
for tag in all_tags:
|
||||
if conditions.label.label and tag.get("label") != conditions.label.label:
|
||||
continue
|
||||
if conditions.label.value is not None:
|
||||
return True
|
||||
if tag.get("value") == conditions.label.value:
|
||||
return True
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.exception(f"Failed to get tags for {file}", e)
|
||||
return False
|
||||
@@ -295,9 +306,9 @@ class RatioTaskService:
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def get_all_tags(tags) -> set[str]:
|
||||
def get_all_tags(tags) -> list[dict]:
|
||||
"""获取所有处理后的标签字符串列表"""
|
||||
all_tags = set()
|
||||
all_tags = list()
|
||||
if not tags:
|
||||
return all_tags
|
||||
|
||||
@@ -314,5 +325,5 @@ class RatioTaskService:
|
||||
|
||||
for file_tag in file_tags:
|
||||
for tag_data in file_tag.get_tags():
|
||||
all_tags.add(tag_data)
|
||||
all_tags.append(tag_data)
|
||||
return all_tags
|
||||
|
||||
Reference in New Issue
Block a user