You've already forked DataMate
Fix ratio (#162)
* fix: fixed the issue where an error would be reported when only setting the proportioning quantity when creating a proportioning task * fix: prevent adding the same file multiple times * fix: implement a more flexible matching strategy, allowing only the tag name to be configured for matching
This commit is contained in:
@@ -46,7 +46,7 @@ class DatasetFileTag(BaseModel):
|
|||||||
tags.append(tag_values)
|
tags.append(tag_values)
|
||||||
# 如果 from_name 不为空,添加前缀
|
# 如果 from_name 不为空,添加前缀
|
||||||
if self.from_name:
|
if self.from_name:
|
||||||
tags = [f"{self.from_name}@{tag}" for tag in tags]
|
tags = [{"label": self.from_name, "value": tag} for tag in tags]
|
||||||
return tags
|
return tags
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -61,7 +61,7 @@ class RatioTaskService:
|
|||||||
'label': {
|
'label': {
|
||||||
"label":item.get("filter_conditions").label.label,
|
"label":item.get("filter_conditions").label.label,
|
||||||
"value":item.get("filter_conditions").label.value,
|
"value":item.get("filter_conditions").label.value,
|
||||||
},
|
} if item.get("filter_conditions").label else None,
|
||||||
})
|
})
|
||||||
)
|
)
|
||||||
logger.info(f"Relation created: {relation.id}, {relation}, {item}, {config}")
|
logger.info(f"Relation created: {relation.id}, {relation}, {item}, {config}")
|
||||||
@@ -147,6 +147,7 @@ class RatioTaskService:
|
|||||||
select(DatasetFiles.file_path).where(DatasetFiles.dataset_id == target_ds.id)
|
select(DatasetFiles.file_path).where(DatasetFiles.dataset_id == target_ds.id)
|
||||||
)
|
)
|
||||||
existing_paths = set(p for p in existing_path_rows.scalars().all() if p)
|
existing_paths = set(p for p in existing_path_rows.scalars().all() if p)
|
||||||
|
source_paths = set()
|
||||||
|
|
||||||
added_count = 0
|
added_count = 0
|
||||||
added_size = 0
|
added_size = 0
|
||||||
@@ -164,10 +165,13 @@ class RatioTaskService:
|
|||||||
chosen = random.sample(files, pick_n) if pick_n < len(files) else files
|
chosen = random.sample(files, pick_n) if pick_n < len(files) else files
|
||||||
|
|
||||||
# Copy into target dataset with de-dup by target path
|
# Copy into target dataset with de-dup by target path
|
||||||
for f in chosen:
|
for file in chosen:
|
||||||
await RatioTaskService.handle_selected_file(existing_paths, f, session, target_ds)
|
if file.file_path in source_paths:
|
||||||
|
continue
|
||||||
|
await RatioTaskService.handle_selected_file(existing_paths, file, session, target_ds)
|
||||||
|
source_paths.add(file.file_path)
|
||||||
added_count += 1
|
added_count += 1
|
||||||
added_size += int(f.file_size or 0)
|
added_size += int(file.file_size or 0)
|
||||||
|
|
||||||
# Periodically flush to avoid huge transactions
|
# Periodically flush to avoid huge transactions
|
||||||
await session.flush()
|
await session.flush()
|
||||||
@@ -286,8 +290,15 @@ class RatioTaskService:
|
|||||||
return False
|
return False
|
||||||
try:
|
try:
|
||||||
# tags could be a list of strings or list of objects with 'name'
|
# tags could be a list of strings or list of objects with 'name'
|
||||||
tag_names = RatioTaskService.get_all_tags(tags)
|
all_tags = RatioTaskService.get_all_tags(tags)
|
||||||
return f"{conditions.label.label}@{conditions.label.value}" in tag_names
|
for tag in all_tags:
|
||||||
|
if conditions.label.label and tag.get("label") != conditions.label.label:
|
||||||
|
continue
|
||||||
|
if conditions.label.value is not None:
|
||||||
|
return True
|
||||||
|
if tag.get("value") == conditions.label.value:
|
||||||
|
return True
|
||||||
|
return False
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.exception(f"Failed to get tags for {file}", e)
|
logger.exception(f"Failed to get tags for {file}", e)
|
||||||
return False
|
return False
|
||||||
@@ -295,9 +306,9 @@ class RatioTaskService:
|
|||||||
return True
|
return True
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_all_tags(tags) -> set[str]:
|
def get_all_tags(tags) -> list[dict]:
|
||||||
"""获取所有处理后的标签字符串列表"""
|
"""获取所有处理后的标签字符串列表"""
|
||||||
all_tags = set()
|
all_tags = list()
|
||||||
if not tags:
|
if not tags:
|
||||||
return all_tags
|
return all_tags
|
||||||
|
|
||||||
@@ -314,5 +325,5 @@ class RatioTaskService:
|
|||||||
|
|
||||||
for file_tag in file_tags:
|
for file_tag in file_tags:
|
||||||
for tag_data in file_tag.get_tags():
|
for tag_data in file_tag.get_tags():
|
||||||
all_tags.add(tag_data)
|
all_tags.append(tag_data)
|
||||||
return all_tags
|
return all_tags
|
||||||
|
|||||||
Reference in New Issue
Block a user