You've already forked DataMate
feature: 数据配比增加通过更新时间来配置 (#95)
* feature: 数据配比增加通过更新时间来配置 * fix: 修复配比时间参数传递的问题
This commit is contained in:
@@ -7,7 +7,6 @@ import { useNavigate } from "react-router";
|
||||
import SelectDataset from "@/pages/RatioTask/Create/components/SelectDataset.tsx";
|
||||
import BasicInformation from "@/pages/RatioTask/Create/components/BasicInformation.tsx";
|
||||
import RatioConfig from "@/pages/RatioTask/Create/components/RatioConfig.tsx";
|
||||
import RatioTransfer from "./components/RatioTransfer";
|
||||
|
||||
export default function CreateRatioTask() {
|
||||
const navigate = useNavigate();
|
||||
@@ -36,27 +35,12 @@ export default function CreateRatioTask() {
|
||||
message.error("请配置配比项");
|
||||
return;
|
||||
}
|
||||
// Build request payload
|
||||
const ratio_method =
|
||||
ratioTaskForm.ratioType === "dataset" ? "DATASET" : "TAG";
|
||||
const totals = String(values.totalTargetCount);
|
||||
const config = ratioTaskForm.ratioConfigs.map((c) => {
|
||||
if (ratio_method === "DATASET") {
|
||||
return {
|
||||
datasetId: String(c.source),
|
||||
counts: String(c.quantity ?? 0),
|
||||
filter_conditions: "",
|
||||
};
|
||||
}
|
||||
// TAG mode: source key like `${datasetId}_${label}`
|
||||
const source = String(c.source || "");
|
||||
const idx = source.indexOf("_");
|
||||
const datasetId = idx > 0 ? source.slice(0, idx) : source;
|
||||
const label = idx > 0 ? source.slice(idx + 1) : "";
|
||||
return {
|
||||
datasetId,
|
||||
datasetId: c.id,
|
||||
counts: String(c.quantity ?? 0),
|
||||
filter_conditions: label ? JSON.stringify({ label }) : "",
|
||||
filterConditions: { label: c.labelFilter, dateRange: String(c.dateRange ?? 0)},
|
||||
};
|
||||
});
|
||||
|
||||
@@ -65,7 +49,6 @@ export default function CreateRatioTask() {
|
||||
name: values.name,
|
||||
description: values.description,
|
||||
totals,
|
||||
ratio_method,
|
||||
config,
|
||||
});
|
||||
message.success("配比任务创建成功");
|
||||
@@ -108,13 +91,6 @@ export default function CreateRatioTask() {
|
||||
totalTargetCount={ratioTaskForm.totalTargetCount}
|
||||
/>
|
||||
|
||||
{/* <RatioTransfer
|
||||
ratioTaskForm={ratioTaskForm}
|
||||
distributions={distributions}
|
||||
updateRatioConfig={updateRatioConfig}
|
||||
updateLabelRatioConfig={updateLabelRatioConfig}
|
||||
/> */}
|
||||
|
||||
<div className="flex h-full">
|
||||
<SelectDataset
|
||||
selectedDatasets={ratioTaskForm.selectedDatasets}
|
||||
|
||||
@@ -27,7 +27,7 @@ const BasicInformation: React.FC<BasicInformationProps> = ({
|
||||
<Input type="number" placeholder="目标总数量" min={1} />
|
||||
</Form.Item>
|
||||
<Form.Item label="任务描述" name="description" className="col-span-2">
|
||||
<TextArea placeholder="描述配比任务的目的和要求(可选)" rows={2} />
|
||||
<TextArea placeholder="描述配比任务的目的和要求" rows={2} />
|
||||
</Form.Item>
|
||||
</div>
|
||||
);
|
||||
|
||||
@@ -1,7 +1,19 @@
|
||||
import React, { useMemo, useState } from "react";
|
||||
import { Badge, Card, Input, Progress, Button, Divider } from "antd";
|
||||
import { BarChart3 } from "lucide-react";
|
||||
import { Badge, Card, Input, Progress, Button, DatePicker, Select } from "antd";
|
||||
import { BarChart3, Filter, Clock } from "lucide-react";
|
||||
import type { Dataset } from "@/pages/DataManagement/dataset.model.ts";
|
||||
import dayjs from 'dayjs';
|
||||
|
||||
const { RangePicker } = DatePicker;
|
||||
const { Option } = Select;
|
||||
|
||||
const TIME_RANGE_OPTIONS = [
|
||||
{ label: '最近1天', value: 1 },
|
||||
{ label: '最近3天', value: 3 },
|
||||
{ label: '最近7天', value: 7 },
|
||||
{ label: '最近15天', value: 15 },
|
||||
{ label: '最近30天', value: 30 },
|
||||
];
|
||||
|
||||
interface RatioConfigItem {
|
||||
id: string;
|
||||
@@ -10,6 +22,8 @@ interface RatioConfigItem {
|
||||
quantity: number;
|
||||
percentage: number;
|
||||
source: string;
|
||||
labelFilter?: string;
|
||||
dateRange?: string;
|
||||
}
|
||||
|
||||
interface RatioConfigProps {
|
||||
@@ -22,14 +36,18 @@ interface RatioConfigProps {
|
||||
}
|
||||
|
||||
const RatioConfig: React.FC<RatioConfigProps> = ({
|
||||
ratioType,
|
||||
selectedDatasets,
|
||||
datasets,
|
||||
totalTargetCount,
|
||||
distributions,
|
||||
onChange,
|
||||
}) => {
|
||||
ratioType,
|
||||
selectedDatasets,
|
||||
datasets,
|
||||
totalTargetCount,
|
||||
distributions,
|
||||
onChange,
|
||||
}) => {
|
||||
const [ratioConfigs, setRatioConfigs] = useState<RatioConfigItem[]>([]);
|
||||
const [datasetFilters, setDatasetFilters] = useState<Record<string, {
|
||||
labelFilter?: string;
|
||||
dateRange?: string;
|
||||
}>>({});
|
||||
|
||||
// 配比项总数
|
||||
const totalConfigured = useMemo(
|
||||
@@ -37,6 +55,36 @@ const RatioConfig: React.FC<RatioConfigProps> = ({
|
||||
[ratioConfigs]
|
||||
);
|
||||
|
||||
// 获取数据集的标签列表
|
||||
const getDatasetLabels = (datasetId: string): string[] => {
|
||||
const dist = distributions[String(datasetId)] || {};
|
||||
return Object.keys(dist);
|
||||
};
|
||||
|
||||
// 自动平均分配
|
||||
const generateAutoRatio = () => {
|
||||
const selectedCount = selectedDatasets.length;
|
||||
if (selectedCount === 0) return;
|
||||
const baseQuantity = Math.floor(totalTargetCount / selectedCount);
|
||||
const remainder = totalTargetCount % selectedCount;
|
||||
const newConfigs = selectedDatasets.map((datasetId, index) => {
|
||||
const dataset = datasets.find((d) => String(d.id) === datasetId);
|
||||
const quantity = baseQuantity + (index < remainder ? 1 : 0);
|
||||
return {
|
||||
id: datasetId,
|
||||
name: dataset?.name || datasetId,
|
||||
type: ratioType,
|
||||
quantity,
|
||||
percentage: Math.round((quantity / totalTargetCount) * 100),
|
||||
source: datasetId,
|
||||
labelFilter: datasetFilters[datasetId]?.labelFilter,
|
||||
dateRange: datasetFilters[datasetId]?.dateRange,
|
||||
};
|
||||
});
|
||||
setRatioConfigs(newConfigs);
|
||||
onChange?.(newConfigs);
|
||||
};
|
||||
|
||||
// 更新数据集配比项
|
||||
const updateDatasetQuantity = (datasetId: string, quantity: number) => {
|
||||
setRatioConfigs((prev) => {
|
||||
@@ -55,6 +103,8 @@ const RatioConfig: React.FC<RatioConfigProps> = ({
|
||||
quantity: Math.min(quantity, totalTargetCount - totalOtherQuantity),
|
||||
percentage: Math.round((quantity / totalTargetCount) * 100),
|
||||
source: datasetId,
|
||||
labelFilter: datasetFilters[datasetId]?.labelFilter,
|
||||
dateRange: datasetFilters[datasetId]?.dateRange,
|
||||
};
|
||||
|
||||
let newConfigs;
|
||||
@@ -69,78 +119,85 @@ const RatioConfig: React.FC<RatioConfigProps> = ({
|
||||
});
|
||||
};
|
||||
|
||||
// 自动平均分配
|
||||
const generateAutoRatio = () => {
|
||||
const selectedCount = selectedDatasets.length;
|
||||
if (selectedCount === 0) return;
|
||||
const baseQuantity = Math.floor(totalTargetCount / selectedCount);
|
||||
const remainder = totalTargetCount % selectedCount;
|
||||
const newConfigs = selectedDatasets.map((datasetId, index) => {
|
||||
const dataset = datasets.find((d) => String(d.id) === datasetId);
|
||||
const quantity = baseQuantity + (index < remainder ? 1 : 0);
|
||||
return {
|
||||
id: datasetId,
|
||||
name: dataset?.name || datasetId,
|
||||
type: ratioType,
|
||||
quantity,
|
||||
percentage: Math.round((quantity / totalTargetCount) * 100),
|
||||
source: datasetId,
|
||||
};
|
||||
});
|
||||
setRatioConfigs(newConfigs);
|
||||
onChange?.(newConfigs);
|
||||
};
|
||||
|
||||
// 标签模式下,更新某数据集的某个标签的数量
|
||||
const updateLabelQuantity = (
|
||||
datasetId: string,
|
||||
label: string,
|
||||
quantity: number
|
||||
) => {
|
||||
const sourceKey = `${datasetId}_${label}`;
|
||||
setRatioConfigs((prev) => {
|
||||
const existingIndex = prev.findIndex((c) => c.source === sourceKey);
|
||||
const totalOtherQuantity = prev
|
||||
.filter((c) => c.source !== sourceKey)
|
||||
.reduce((sum, c) => sum + c.quantity, 0);
|
||||
const dist = distributions[datasetId] || {};
|
||||
const labelMax = dist[label] ?? Infinity;
|
||||
const cappedQuantity = Math.max(
|
||||
0,
|
||||
Math.min(quantity, totalTargetCount - totalOtherQuantity, labelMax)
|
||||
);
|
||||
const newConfig: RatioConfigItem = {
|
||||
id: sourceKey,
|
||||
name: label,
|
||||
type: "label",
|
||||
quantity: cappedQuantity,
|
||||
percentage: Math.round((cappedQuantity / totalTargetCount) * 100),
|
||||
source: sourceKey,
|
||||
};
|
||||
let newConfigs;
|
||||
if (existingIndex >= 0) {
|
||||
newConfigs = [...prev];
|
||||
newConfigs[existingIndex] = newConfig;
|
||||
} else {
|
||||
newConfigs = [...prev, newConfig];
|
||||
// 更新筛选条件
|
||||
const updateFilters = (datasetId: string, updates: {
|
||||
labelFilter?: string;
|
||||
dateRange?: [string, string];
|
||||
}) => {
|
||||
setDatasetFilters(prev => ({
|
||||
...prev,
|
||||
[datasetId]: {
|
||||
...prev[datasetId],
|
||||
...updates,
|
||||
}
|
||||
onChange?.(newConfigs);
|
||||
return newConfigs;
|
||||
});
|
||||
}));
|
||||
};
|
||||
|
||||
// 选中数据集变化时,移除未选中的配比项
|
||||
// 渲染筛选器
|
||||
const renderFilters = (datasetId: string) => {
|
||||
const labels = getDatasetLabels(datasetId);
|
||||
const config = ratioConfigs.find(c => c.source === datasetId);
|
||||
const filters = datasetFilters[datasetId] || {};
|
||||
|
||||
return (
|
||||
<div className="mb-3 p-2 bg-gray-50 rounded">
|
||||
<div className="flex items-center gap-2 mb-2">
|
||||
<Filter size={14} className="text-gray-400" />
|
||||
<span className="text-xs font-medium">筛选条件</span>
|
||||
</div>
|
||||
|
||||
<div className="grid grid-cols-1 md:grid-cols-2 gap-3">
|
||||
<div>
|
||||
<div className="text-xs text-gray-500 mb-1">标签筛选</div>
|
||||
<Select
|
||||
style={{ width: '100%' }}
|
||||
placeholder="选择标签"
|
||||
value={filters.labelFilter}
|
||||
onChange={(value) => updateFilters(datasetId, { labelFilter: value })}
|
||||
allowClear
|
||||
onClear={() => updateFilters(datasetId, { labelFilter: undefined })}
|
||||
>
|
||||
{labels.map(label => (
|
||||
<Option key={label} value={label}>{label}</Option>
|
||||
))}
|
||||
</Select>
|
||||
</div>
|
||||
|
||||
<div>
|
||||
<div className="text-xs text-gray-500 mb-1">标签更新时间</div>
|
||||
<Select
|
||||
style={{ width: '100%' }}
|
||||
placeholder="选择标签更新时间"
|
||||
value={filters.dateRange}
|
||||
onChange={(dates) => updateFilters(datasetId, { dateRange: dates })}
|
||||
allowClear
|
||||
onClear={() => updateFilters(datasetId, { dateRange: undefined })}
|
||||
>
|
||||
{TIME_RANGE_OPTIONS.map(option => (
|
||||
<Option key={option.value} value={option.value}>
|
||||
{option.label}
|
||||
</Option>
|
||||
))}
|
||||
</Select>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
// 选中数据集变化时,初始化筛选条件
|
||||
React.useEffect(() => {
|
||||
setRatioConfigs((prev) => {
|
||||
const next = prev.filter((c) => {
|
||||
const id = String(c.source);
|
||||
const dsId = id.includes("_") ? id.split("_")[0] : id;
|
||||
return selectedDatasets.includes(dsId);
|
||||
});
|
||||
if (next !== prev) onChange?.(next);
|
||||
return next;
|
||||
const initialFilters: Record<string, any> = {};
|
||||
selectedDatasets.forEach(datasetId => {
|
||||
const config = ratioConfigs.find(c => c.source === datasetId);
|
||||
if (config) {
|
||||
initialFilters[datasetId] = {
|
||||
labelFilter: config.labelFilter,
|
||||
dateRange: config.dateRange,
|
||||
};
|
||||
}
|
||||
});
|
||||
// eslint-disable-next-line
|
||||
setDatasetFilters(prev => ({ ...prev, ...initialFilters }));
|
||||
}, [selectedDatasets]);
|
||||
|
||||
return (
|
||||
@@ -148,7 +205,7 @@ const RatioConfig: React.FC<RatioConfigProps> = ({
|
||||
<div className="flex items-center justify-between p-4 border-bottom">
|
||||
<span className="text-sm font-bold">
|
||||
配比配置
|
||||
<span className="text-xs text-gray-500">
|
||||
<span className="text-xs text-gray-500 ml-1">
|
||||
(已配置:{totalConfigured}/{totalTargetCount}条)
|
||||
</span>
|
||||
</span>
|
||||
@@ -170,41 +227,36 @@ const RatioConfig: React.FC<RatioConfigProps> = ({
|
||||
<div className="flex-overflow-auto gap-4 p-4">
|
||||
{/* 配比预览 */}
|
||||
{ratioConfigs.length > 0 && (
|
||||
<div>
|
||||
<div className="p-3 bg-gray-50 rounded-lg">
|
||||
<div className="grid grid-cols-2 gap-4 text-sm">
|
||||
<div>
|
||||
<span className="text-gray-500">总配比数量:</span>
|
||||
<span className="ml-2 font-medium">
|
||||
{ratioConfigs
|
||||
.reduce((sum, config) => sum + config.quantity, 0)
|
||||
.toLocaleString()}
|
||||
</span>
|
||||
</div>
|
||||
<div>
|
||||
<span className="text-gray-500">目标数量:</span>
|
||||
<span className="ml-2 font-medium">
|
||||
{totalTargetCount.toLocaleString()}
|
||||
</span>
|
||||
</div>
|
||||
<div>
|
||||
<span className="text-gray-500">配比项目:</span>
|
||||
<span className="ml-2 font-medium">
|
||||
{ratioConfigs.length}个
|
||||
</span>
|
||||
</div>
|
||||
<div className="p-3 bg-gray-50 rounded-lg mb-4">
|
||||
<div className="grid grid-cols-2 gap-4 text-sm">
|
||||
<div>
|
||||
<span className="text-gray-500">总配比数量:</span>
|
||||
<span className="ml-2 font-medium">
|
||||
{ratioConfigs
|
||||
.reduce((sum, config) => sum + config.quantity, 0)
|
||||
.toLocaleString()}
|
||||
</span>
|
||||
</div>
|
||||
<div>
|
||||
<span className="text-gray-500">目标数量:</span>
|
||||
<span className="ml-2 font-medium">
|
||||
{totalTargetCount.toLocaleString()}
|
||||
</span>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
<div className="flex-1 overflow-auto">
|
||||
|
||||
<div className="flex-1 overflow-auto space-y-4">
|
||||
{selectedDatasets.map((datasetId) => {
|
||||
const dataset = datasets.find((d) => String(d.id) === datasetId);
|
||||
const config = ratioConfigs.find((c) => c.source === datasetId);
|
||||
const currentQuantity = config?.quantity || 0;
|
||||
|
||||
if (!dataset) return null;
|
||||
|
||||
return (
|
||||
<Card key={datasetId} size="small" className="mb-2">
|
||||
<Card key={datasetId} size="small" className="mb-4">
|
||||
<div className="flex items-center justify-between mb-3">
|
||||
<div className="flex items-center gap-2">
|
||||
<span className="font-medium text-sm">
|
||||
@@ -216,97 +268,36 @@ const RatioConfig: React.FC<RatioConfigProps> = ({
|
||||
{config?.percentage || 0}%
|
||||
</div>
|
||||
</div>
|
||||
{ratioType === "dataset" ? (
|
||||
<div>
|
||||
<div className="flex items-center gap-2 mb-2">
|
||||
<span className="text-xs">数量:</span>
|
||||
<Input
|
||||
type="number"
|
||||
value={currentQuantity}
|
||||
onChange={(e) =>
|
||||
updateDatasetQuantity(
|
||||
datasetId,
|
||||
Number(e.target.value)
|
||||
)
|
||||
}
|
||||
style={{ width: 80 }}
|
||||
min={0}
|
||||
max={Math.min(
|
||||
dataset.fileCount || 0,
|
||||
totalTargetCount
|
||||
)}
|
||||
/>
|
||||
<span className="text-xs text-gray-500">条</span>
|
||||
</div>
|
||||
<Progress
|
||||
percent={Math.round(
|
||||
(currentQuantity / totalTargetCount) * 100
|
||||
)}
|
||||
size="small"
|
||||
/>
|
||||
</div>
|
||||
) : (
|
||||
<div>
|
||||
{!distributions[String(dataset.id)] ? (
|
||||
<div className="text-xs text-gray-400">
|
||||
加载标签分布...
|
||||
</div>
|
||||
) : Object.entries(distributions[String(dataset.id)])
|
||||
.length === 0 ? (
|
||||
<div className="text-xs text-gray-400">
|
||||
该数据集暂无标签
|
||||
</div>
|
||||
) : (
|
||||
<div className="flex flex-col gap-2">
|
||||
{Object.entries(
|
||||
distributions[String(dataset.id)]
|
||||
).map(([label, count]) => {
|
||||
const sourceKey = `${datasetId}_${label}`;
|
||||
const labelConfig = ratioConfigs.find(
|
||||
(c) => c.source === sourceKey
|
||||
);
|
||||
const labelQuantity = labelConfig?.quantity || 0;
|
||||
return (
|
||||
<div
|
||||
key={label}
|
||||
className="flex items-center justify-between gap-2"
|
||||
>
|
||||
<div className="flex items-center gap-2">
|
||||
<Badge color="gray">{label}</Badge>
|
||||
<span className="text-xs text-gray-500">
|
||||
{count}条
|
||||
</span>
|
||||
</div>
|
||||
<div className="flex items-center gap-2">
|
||||
<span className="text-xs">数量:</span>
|
||||
<Input
|
||||
type="number"
|
||||
value={labelQuantity}
|
||||
onChange={(e) =>
|
||||
updateLabelQuantity(
|
||||
datasetId,
|
||||
label,
|
||||
Number(e.target.value)
|
||||
)
|
||||
}
|
||||
style={{ width: 80 }}
|
||||
min={0}
|
||||
max={Math.min(
|
||||
Number(count) || 0,
|
||||
totalTargetCount
|
||||
)}
|
||||
/>
|
||||
<span className="text-xs text-gray-500">
|
||||
条
|
||||
</span>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
})}
|
||||
</div>
|
||||
|
||||
{/* 筛选条件 */}
|
||||
{renderFilters(datasetId)}
|
||||
|
||||
<div className="flex items-center gap-2 mb-2">
|
||||
<span className="text-xs">数量:</span>
|
||||
<Input
|
||||
type="number"
|
||||
value={currentQuantity}
|
||||
onChange={(e) =>
|
||||
updateDatasetQuantity(
|
||||
datasetId,
|
||||
Number(e.target.value)
|
||||
)
|
||||
}
|
||||
style={{ width: 100 }}
|
||||
min={0}
|
||||
max={Math.min(
|
||||
dataset.fileCount || 0,
|
||||
totalTargetCount
|
||||
)}
|
||||
</div>
|
||||
)}
|
||||
/>
|
||||
<span className="text-xs text-gray-500">条</span>
|
||||
</div>
|
||||
<Progress
|
||||
percent={Math.round(
|
||||
(currentQuantity / totalTargetCount) * 100
|
||||
)}
|
||||
size="small"
|
||||
/>
|
||||
</Card>
|
||||
);
|
||||
})}
|
||||
|
||||
@@ -1,169 +0,0 @@
|
||||
import React, { useMemo } from "react";
|
||||
import { Table } from "antd";
|
||||
import { TransferItem } from "antd/es/transfer";
|
||||
import RatioConfig from "./RatioConfig";
|
||||
import useFetchData from "@/hooks/useFetchData";
|
||||
import { queryDatasetsUsingGet } from "@/pages/DataManagement/dataset.api";
|
||||
import {
|
||||
datasetTypeMap,
|
||||
mapDataset,
|
||||
} from "@/pages/DataManagement/dataset.const";
|
||||
import { SearchControls } from "@/components/SearchControls";
|
||||
|
||||
const leftColumns = [
|
||||
{
|
||||
dataIndex: "name",
|
||||
title: "名称",
|
||||
ellipsis: true,
|
||||
},
|
||||
{
|
||||
dataIndex: "datasetType",
|
||||
title: "类型",
|
||||
ellipsis: true,
|
||||
width: 100,
|
||||
render: (type: string) => datasetTypeMap[type].label,
|
||||
},
|
||||
{
|
||||
dataIndex: "size",
|
||||
title: "大小",
|
||||
width: 100,
|
||||
ellipsis: true,
|
||||
},
|
||||
];
|
||||
|
||||
export default function RatioTransfer(props: {
|
||||
distributions: Record<string, Record<string, number>>;
|
||||
ratioTaskForm: any;
|
||||
updateRatioConfig: (datasetId: string, quantity: number) => void;
|
||||
updateLabelRatioConfig: (
|
||||
datasetId: string,
|
||||
label: string,
|
||||
quantity: number
|
||||
) => void;
|
||||
}) {
|
||||
const {
|
||||
updateLabelRatioConfig,
|
||||
updateRatioConfig,
|
||||
ratioTaskForm,
|
||||
distributions,
|
||||
} = props;
|
||||
const {
|
||||
tableData: datasets,
|
||||
loading,
|
||||
pagination,
|
||||
searchParams,
|
||||
setSearchParams,
|
||||
handleFiltersChange,
|
||||
} = useFetchData(queryDatasetsUsingGet, mapDataset);
|
||||
|
||||
const [selectedDatasets, setSelectedDatasets] = React.useState<
|
||||
TransferItem[]
|
||||
>([]);
|
||||
|
||||
const selectedRowKeys = useMemo(() => {
|
||||
return selectedDatasets.map((item) => item.key);
|
||||
}, [selectedDatasets]);
|
||||
|
||||
const [listDisabled, setListDisabled] = React.useState(false);
|
||||
|
||||
const generateAutoRatio = () => {
|
||||
const selectedCount = ratioTaskForm.selectedDatasets.length;
|
||||
if (selectedCount === 0) return;
|
||||
|
||||
const baseQuantity = Math.floor(
|
||||
ratioTaskForm.totalTargetCount / selectedCount
|
||||
);
|
||||
const remainder = ratioTaskForm.totalTargetCount % selectedCount;
|
||||
|
||||
const newConfigs = ratioTaskForm.selectedDatasets.map(
|
||||
(datasetId, index) => {
|
||||
const quantity = baseQuantity + (index < remainder ? 1 : 0);
|
||||
return {
|
||||
id: datasetId,
|
||||
name: datasetId,
|
||||
type: ratioTaskForm.ratioType,
|
||||
quantity,
|
||||
percentage: Math.round(
|
||||
(quantity / ratioTaskForm.totalTargetCount) * 100
|
||||
),
|
||||
source: datasetId,
|
||||
};
|
||||
}
|
||||
);
|
||||
setRatioTaskForm((prev) => ({ ...prev, ratioConfigs: newConfigs }));
|
||||
};
|
||||
|
||||
return (
|
||||
<div className="flex">
|
||||
<div className="border-card flex-1 mr-4">
|
||||
<h3 className="p-2 border-bottom">{`${selectedDatasets.length} / ${datasets.length} 项`}</h3>
|
||||
<SearchControls
|
||||
searchTerm={searchParams.keyword}
|
||||
onSearchChange={(keyword) =>
|
||||
setSearchParams({ ...searchParams, keyword })
|
||||
}
|
||||
searchPlaceholder="搜索数据集名称..."
|
||||
filters={[
|
||||
{
|
||||
key: "type",
|
||||
label: "数据集类型",
|
||||
options: [
|
||||
{ value: "dataset", label: "按数据集" },
|
||||
{ value: "tag", label: "按标签" },
|
||||
],
|
||||
},
|
||||
]}
|
||||
onFiltersChange={handleFiltersChange}
|
||||
onClearFilters={() =>
|
||||
setSearchParams({ ...searchParams, filter: {} })
|
||||
}
|
||||
showViewToggle={false}
|
||||
showReload={false}
|
||||
className="m-4"
|
||||
/>
|
||||
<Table
|
||||
rowSelection={{
|
||||
onChange: (_, selectedRows) => {
|
||||
setSelectedDatasets(selectedRows);
|
||||
},
|
||||
selectedRowKeys,
|
||||
selections: [
|
||||
Table.SELECTION_ALL,
|
||||
Table.SELECTION_INVERT,
|
||||
Table.SELECTION_NONE,
|
||||
],
|
||||
}}
|
||||
columns={leftColumns}
|
||||
dataSource={datasets}
|
||||
loading={loading}
|
||||
pagination={pagination}
|
||||
size="small"
|
||||
rowKey="id"
|
||||
style={{ pointerEvents: listDisabled ? "none" : undefined }}
|
||||
onRow={(record) => ({
|
||||
onClick: () => {
|
||||
if (record.disabled || listDisabled) {
|
||||
return;
|
||||
}
|
||||
setSelectedDatasets((prev) => {
|
||||
if (prev.includes(record.key)) {
|
||||
return prev.filter((k) => k !== record.key);
|
||||
}
|
||||
return [...prev, record.key];
|
||||
});
|
||||
},
|
||||
})}
|
||||
/>
|
||||
</div>
|
||||
<div className="border-card flex-1">
|
||||
<RatioConfig
|
||||
datasets={selectedDatasets}
|
||||
ratioTaskForm={ratioTaskForm}
|
||||
distributions={distributions}
|
||||
onUpdateDatasetQuantity={updateRatioConfig}
|
||||
onUpdateLabelQuantity={updateLabelRatioConfig}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -1,5 +1,5 @@
|
||||
import React, { useEffect, useState } from "react";
|
||||
import { Badge, Button, Card, Checkbox, Input, Pagination, Select } from "antd";
|
||||
import { Badge, Button, Card, Checkbox, Input, Pagination } from "antd";
|
||||
import { Search as SearchIcon } from "lucide-react";
|
||||
import type { Dataset } from "@/pages/DataManagement/dataset.model.ts";
|
||||
import {
|
||||
@@ -10,8 +10,6 @@ import {
|
||||
|
||||
interface SelectDatasetProps {
|
||||
selectedDatasets: string[];
|
||||
ratioType: "dataset" | "label";
|
||||
onRatioTypeChange: (val: "dataset" | "label") => void;
|
||||
onSelectedDatasetsChange: (next: string[]) => void;
|
||||
onDistributionsChange?: (
|
||||
next: Record<string, Record<string, number>>
|
||||
@@ -21,8 +19,6 @@ interface SelectDatasetProps {
|
||||
|
||||
const SelectDataset: React.FC<SelectDatasetProps> = ({
|
||||
selectedDatasets,
|
||||
ratioType,
|
||||
onRatioTypeChange,
|
||||
onSelectedDatasetsChange,
|
||||
onDistributionsChange,
|
||||
onDatasetsChange,
|
||||
@@ -62,7 +58,7 @@ const SelectDataset: React.FC<SelectDatasetProps> = ({
|
||||
// Fetch label distributions when in label mode
|
||||
useEffect(() => {
|
||||
const fetchDistributions = async () => {
|
||||
if (ratioType !== "label" || !datasets?.length) return;
|
||||
if (!datasets?.length) return;
|
||||
const idsToFetch = datasets
|
||||
.map((d) => String(d.id))
|
||||
.filter((id) => !distributions[id]);
|
||||
@@ -147,7 +143,7 @@ const SelectDataset: React.FC<SelectDatasetProps> = ({
|
||||
};
|
||||
fetchDistributions();
|
||||
// eslint-disable-next-line react-hooks/exhaustive-deps
|
||||
}, [ratioType, datasets]);
|
||||
}, [datasets]);
|
||||
|
||||
const onToggleDataset = (datasetId: string, checked: boolean) => {
|
||||
if (checked) {
|
||||
@@ -180,18 +176,6 @@ const SelectDataset: React.FC<SelectDatasetProps> = ({
|
||||
</Button>
|
||||
</div>
|
||||
<div className="flex-overflow-auto gap-4 p-4">
|
||||
<div className="flex items-center gap-4">
|
||||
<span className="text-sm">配比方式:</span>
|
||||
<Select
|
||||
className="flex-1 min-w-[120px]"
|
||||
value={ratioType}
|
||||
onChange={(v) => onRatioTypeChange(v)}
|
||||
options={[
|
||||
{ label: "按数据集", value: "dataset" },
|
||||
{ label: "按标签", value: "label" },
|
||||
]}
|
||||
/>
|
||||
</div>
|
||||
<Input
|
||||
prefix={<SearchIcon className="text-gray-400" />}
|
||||
placeholder="搜索数据集"
|
||||
@@ -239,32 +223,30 @@ const SelectDataset: React.FC<SelectDatasetProps> = ({
|
||||
<span>{dataset.fileCount}条</span>
|
||||
<span>{dataset.size}</span>
|
||||
</div>
|
||||
{ratioType === "label" && (
|
||||
<div className="mt-2">
|
||||
{distributions[idStr] ? (
|
||||
Object.entries(distributions[idStr]).length > 0 ? (
|
||||
<div className="flex flex-wrap gap-2 text-xs">
|
||||
{Object.entries(distributions[idStr])
|
||||
.slice(0, 8)
|
||||
.map(([tag, count]) => (
|
||||
<Badge
|
||||
key={tag}
|
||||
color="gray"
|
||||
>{`${tag}: ${count}`}</Badge>
|
||||
))}
|
||||
</div>
|
||||
) : (
|
||||
<div className="text-xs text-gray-400">
|
||||
未检测到标签分布
|
||||
</div>
|
||||
)
|
||||
<div className="mt-2">
|
||||
{distributions[idStr] ? (
|
||||
Object.entries(distributions[idStr]).length > 0 ? (
|
||||
<div className="flex flex-wrap gap-2 text-xs">
|
||||
{Object.entries(distributions[idStr])
|
||||
.slice(0, 8)
|
||||
.map(([tag, count]) => (
|
||||
<Badge
|
||||
key={tag}
|
||||
color="gray"
|
||||
>{`${tag}: ${count}`}</Badge>
|
||||
))}
|
||||
</div>
|
||||
) : (
|
||||
<div className="text-xs text-gray-400">
|
||||
加载标签分布...
|
||||
未检测到标签分布
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
)}
|
||||
)
|
||||
) : (
|
||||
<div className="text-xs text-gray-400">
|
||||
加载标签分布...
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</Card>
|
||||
|
||||
@@ -1,11 +1,13 @@
|
||||
from .common import (
|
||||
BaseResponseModel,
|
||||
StandardResponse,
|
||||
PaginatedData
|
||||
PaginatedData,
|
||||
TaskStatus
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"BaseResponseModel",
|
||||
"StandardResponse",
|
||||
"PaginatedData"
|
||||
]
|
||||
"PaginatedData",
|
||||
"TaskStatus"
|
||||
]
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
"""
|
||||
通用响应模型
|
||||
"""
|
||||
from typing import Generic, TypeVar, Optional, List, Type
|
||||
from typing import Generic, TypeVar, List
|
||||
from pydantic import BaseModel, Field
|
||||
from enum import Enum
|
||||
|
||||
# 定义泛型类型变量
|
||||
T = TypeVar('T')
|
||||
@@ -16,7 +17,7 @@ def to_camel(string: str) -> str:
|
||||
|
||||
class BaseResponseModel(BaseModel):
|
||||
"""基础响应模型,启用别名生成器"""
|
||||
|
||||
|
||||
class Config:
|
||||
populate_by_name = True
|
||||
alias_generator = to_camel
|
||||
@@ -24,7 +25,7 @@ class BaseResponseModel(BaseModel):
|
||||
class StandardResponse(BaseResponseModel, Generic[T]):
|
||||
"""
|
||||
标准API响应格式
|
||||
|
||||
|
||||
所有API端点应返回此格式,确保响应的一致性
|
||||
"""
|
||||
code: int = Field(..., description="HTTP状态码")
|
||||
@@ -42,3 +43,9 @@ class PaginatedData(BaseResponseModel, Generic[T]):
|
||||
total_elements: int = Field(..., description="总条数")
|
||||
total_pages: int = Field(..., description="总页数")
|
||||
content: List[T] = Field(..., description="当前页数据")
|
||||
|
||||
class TaskStatus(Enum):
|
||||
PENDING = "PENDING"
|
||||
RUNNING = "RUNNING"
|
||||
COMPLETED = "COMPLETED"
|
||||
FAILED = "FAILED"
|
||||
|
||||
@@ -12,7 +12,7 @@ from app.core.logging import get_logger
|
||||
from app.db.models import Dataset
|
||||
from app.db.session import get_db
|
||||
from app.module.dataset import DatasetManagementService
|
||||
from app.module.shared.schema import StandardResponse
|
||||
from app.module.shared.schema import StandardResponse, TaskStatus
|
||||
from app.module.synthesis.schema.ratio_task import (
|
||||
CreateRatioTaskResponse,
|
||||
CreateRatioTaskRequest,
|
||||
@@ -49,52 +49,18 @@ async def create_ratio_task(
|
||||
|
||||
await valid_exists(db, req)
|
||||
|
||||
# 创建目标数据集:名称使用“<任务名称>-配比生成-时间戳”
|
||||
target_dataset_name = f"{req.name}-配比生成-{datetime.now().strftime('%Y%m%d%H%M%S')}"
|
||||
target_dataset = await create_target_dataset(db, req, source_types)
|
||||
|
||||
target_type = get_target_dataset_type(source_types)
|
||||
instance = await create_ratio_instance(db, req, target_dataset)
|
||||
|
||||
target_dataset = Dataset(
|
||||
id=str(uuid.uuid4()),
|
||||
name=target_dataset_name,
|
||||
description=req.description or "",
|
||||
dataset_type=target_type,
|
||||
status="DRAFT",
|
||||
)
|
||||
target_dataset.path = f"/dataset/{target_dataset.id}"
|
||||
db.add(target_dataset)
|
||||
await db.flush() # 获取 target_dataset.id
|
||||
|
||||
service = RatioTaskService(db)
|
||||
instance = await service.create_task(
|
||||
name=req.name,
|
||||
description=req.description,
|
||||
totals=int(req.totals),
|
||||
ratio_method=req.ratio_method,
|
||||
config=[
|
||||
{
|
||||
"dataset_id": item.dataset_id,
|
||||
"counts": int(item.counts),
|
||||
"filter_conditions": item.filter_conditions,
|
||||
}
|
||||
for item in req.config
|
||||
],
|
||||
target_dataset_id=target_dataset.id,
|
||||
)
|
||||
|
||||
# 异步执行配比任务(支持 DATASET / TAG)
|
||||
asyncio.create_task(RatioTaskService.execute_dataset_ratio_task(instance.id))
|
||||
|
||||
return StandardResponse(
|
||||
code=200,
|
||||
message="success",
|
||||
data=CreateRatioTaskResponse(
|
||||
response_data = CreateRatioTaskResponse(
|
||||
id=instance.id,
|
||||
name=instance.name,
|
||||
description=instance.description,
|
||||
totals=instance.totals or 0,
|
||||
ratio_method=instance.ratio_method or req.ratio_method,
|
||||
status=instance.status or "PENDING",
|
||||
status=instance.status or TaskStatus.PENDING.name,
|
||||
config=req.config,
|
||||
targetDataset=TargetDatasetInfo(
|
||||
id=str(target_dataset.id),
|
||||
@@ -103,6 +69,10 @@ async def create_ratio_task(
|
||||
status=str(target_dataset.status),
|
||||
)
|
||||
)
|
||||
return StandardResponse(
|
||||
code=200,
|
||||
message="success",
|
||||
data=response_data
|
||||
)
|
||||
except HTTPException:
|
||||
await db.rollback()
|
||||
@@ -113,6 +83,46 @@ async def create_ratio_task(
|
||||
raise HTTPException(status_code=500, detail="Internal server error")
|
||||
|
||||
|
||||
async def create_ratio_instance(db, req: CreateRatioTaskRequest, target_dataset: Dataset) -> RatioInstance:
|
||||
service = RatioTaskService(db)
|
||||
logger.info(f"create_ratio_instance: {req}")
|
||||
instance = await service.create_task(
|
||||
name=req.name,
|
||||
description=req.description,
|
||||
totals=int(req.totals),
|
||||
config=[
|
||||
{
|
||||
"dataset_id": item.dataset_id,
|
||||
"counts": int(item.counts),
|
||||
"filter_conditions": item.filter_conditions,
|
||||
}
|
||||
for item in req.config
|
||||
],
|
||||
target_dataset_id=target_dataset.id,
|
||||
)
|
||||
return instance
|
||||
|
||||
|
||||
async def create_target_dataset(db, req: CreateRatioTaskRequest, source_types: set[str]) -> Dataset:
|
||||
# 创建目标数据集:名称使用“<任务名称>-时间戳”
|
||||
target_dataset_name = f"{req.name}-{datetime.now().strftime('%Y%m%d%H%M%S')}"
|
||||
|
||||
target_type = get_target_dataset_type(source_types)
|
||||
target_dataset_id = uuid.uuid4()
|
||||
|
||||
target_dataset = Dataset(
|
||||
id=str(target_dataset_id),
|
||||
name=target_dataset_name,
|
||||
description=req.description or "",
|
||||
dataset_type=target_type,
|
||||
status="DRAFT",
|
||||
path=f"/dataset/{target_dataset_id}",
|
||||
)
|
||||
db.add(target_dataset)
|
||||
await db.flush() # 获取 target_dataset.id
|
||||
return target_dataset
|
||||
|
||||
|
||||
@router.get("", response_model=StandardResponse[PagedRatioTaskResponse], status_code=200)
|
||||
async def list_ratio_tasks(
|
||||
page: int = 1,
|
||||
|
||||
@@ -2,10 +2,36 @@ from typing import List, Optional, Dict, Any
|
||||
from datetime import datetime
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from app.core.logging import get_logger
|
||||
from app.module.shared.schema.common import TaskStatus
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
class FilterCondition(BaseModel):
|
||||
date_range: Optional[str] = Field(None, description="数据范围", alias="dateRange")
|
||||
label: Optional[str] = Field(None, description="标签")
|
||||
|
||||
@field_validator("date_range")
|
||||
@classmethod
|
||||
def validate_date_range(cls, v: Optional[str]) -> Optional[str]:
|
||||
# ensure it's a numeric string if provided
|
||||
if not v:
|
||||
return v
|
||||
try:
|
||||
int(v)
|
||||
return v
|
||||
except (ValueError, TypeError) as e:
|
||||
raise ValueError("date_range must be a numeric string")
|
||||
|
||||
class Config:
|
||||
# allow population by field name when constructing model programmatically
|
||||
validate_by_name = True
|
||||
|
||||
|
||||
class RatioConfigItem(BaseModel):
|
||||
dataset_id: str = Field(..., alias="datasetId", description="数据集id")
|
||||
counts: str = Field(..., description="数量")
|
||||
filter_conditions: str = Field(..., description="过滤条件")
|
||||
filter_conditions: FilterCondition = Field(..., alias="filterConditions", description="过滤条件")
|
||||
|
||||
@field_validator("counts")
|
||||
@classmethod
|
||||
@@ -22,17 +48,8 @@ class CreateRatioTaskRequest(BaseModel):
|
||||
name: str = Field(..., description="名称")
|
||||
description: Optional[str] = Field(None, description="描述")
|
||||
totals: str = Field(..., description="目标数量")
|
||||
ratio_method: str = Field(..., description="配比方式", alias="ratio_method")
|
||||
config: List[RatioConfigItem] = Field(..., description="配比设置列表")
|
||||
|
||||
@field_validator("ratio_method")
|
||||
@classmethod
|
||||
def validate_ratio_method(cls, v: str) -> str:
|
||||
allowed = {"TAG", "DATASET"}
|
||||
if v not in allowed:
|
||||
raise ValueError(f"ratio_method must be one of {allowed}")
|
||||
return v
|
||||
|
||||
@field_validator("totals")
|
||||
@classmethod
|
||||
def validate_totals(cls, v: str) -> str:
|
||||
@@ -58,8 +75,7 @@ class CreateRatioTaskResponse(BaseModel):
|
||||
name: str
|
||||
description: Optional[str] = None
|
||||
totals: int
|
||||
ratio_method: str
|
||||
status: str
|
||||
status: TaskStatus
|
||||
# echoed config
|
||||
config: List[RatioConfigItem]
|
||||
# created dataset
|
||||
|
||||
@@ -14,6 +14,8 @@ from app.db.models.ratio_task import RatioInstance, RatioRelation
|
||||
from app.db.models import Dataset, DatasetFiles
|
||||
from app.db.session import AsyncSessionLocal
|
||||
from app.module.dataset.schema.dataset_file import DatasetFileTag
|
||||
from app.module.shared.schema import TaskStatus
|
||||
from app.module.synthesis.schema.ratio_task import FilterCondition
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -30,7 +32,6 @@ class RatioTaskService:
|
||||
name: str,
|
||||
description: Optional[str],
|
||||
totals: int,
|
||||
ratio_method: str,
|
||||
config: List[Dict[str, Any]],
|
||||
target_dataset_id: Optional[str] = None,
|
||||
) -> RatioInstance:
|
||||
@@ -38,12 +39,11 @@ class RatioTaskService:
|
||||
|
||||
config item format: {"dataset_id": str, "counts": int, "filter_conditions": str}
|
||||
"""
|
||||
logger.info(f"Creating ratio task: name={name}, method={ratio_method}, totals={totals}, items={len(config or [])}")
|
||||
logger.info(f"Creating ratio task: name={name}, totals={totals}, items={len(config or [])}")
|
||||
|
||||
instance = RatioInstance(
|
||||
name=name,
|
||||
description=description,
|
||||
ratio_method=ratio_method,
|
||||
totals=totals,
|
||||
target_dataset_id=target_dataset_id,
|
||||
status="PENDING",
|
||||
@@ -56,8 +56,12 @@ class RatioTaskService:
|
||||
ratio_instance_id=instance.id,
|
||||
source_dataset_id=item.get("dataset_id"),
|
||||
counts=int(item.get("counts", 0)),
|
||||
filter_conditions=item.get("filter_conditions"),
|
||||
filter_conditions=json.dumps({
|
||||
'date_range': item.get("filter_conditions").date_range,
|
||||
'label': item.get("filter_conditions").label,
|
||||
})
|
||||
)
|
||||
logger.info(f"Relation created: {relation.id}, {relation}, {item}, {config}")
|
||||
self.db.add(relation)
|
||||
|
||||
await self.db.commit()
|
||||
@@ -97,94 +101,17 @@ class RatioTaskService:
|
||||
relations: List[RatioRelation] = list(rel_res.scalars().all())
|
||||
|
||||
# Mark running
|
||||
instance.status = "RUNNING"
|
||||
|
||||
if instance.ratio_method not in {"DATASET", "TAG"}:
|
||||
logger.info(f"Instance {instance_id} ratio_method={instance.ratio_method} not supported yet")
|
||||
instance.status = "SUCCESS"
|
||||
return
|
||||
instance.status = TaskStatus.RUNNING.name
|
||||
|
||||
# Load target dataset
|
||||
ds_res = await session.execute(select(Dataset).where(Dataset.id == instance.target_dataset_id))
|
||||
target_ds: Optional[Dataset] = ds_res.scalar_one_or_none()
|
||||
if not target_ds:
|
||||
logger.error(f"Target dataset not found for instance {instance_id}")
|
||||
instance.status = "FAILED"
|
||||
instance.status = TaskStatus.FAILED.name
|
||||
return
|
||||
|
||||
# Preload existing target file paths for deduplication
|
||||
existing_path_rows = await session.execute(
|
||||
select(DatasetFiles.file_path).where(DatasetFiles.dataset_id == target_ds.id)
|
||||
)
|
||||
existing_paths = set(p for p in existing_path_rows.scalars().all() if p)
|
||||
|
||||
added_count = 0
|
||||
added_size = 0
|
||||
|
||||
for rel in relations:
|
||||
if not rel.source_dataset_id or not rel.counts or rel.counts <= 0:
|
||||
continue
|
||||
|
||||
# Fetch all files for the source dataset (ACTIVE only)
|
||||
files_res = await session.execute(
|
||||
select(DatasetFiles).where(
|
||||
DatasetFiles.dataset_id == rel.source_dataset_id,
|
||||
DatasetFiles.status == "ACTIVE",
|
||||
)
|
||||
)
|
||||
files = list(files_res.scalars().all())
|
||||
|
||||
# TAG mode: filter by tags according to relation.filter_conditions
|
||||
if instance.ratio_method == "TAG":
|
||||
required_tags = RatioTaskService._parse_required_tags(rel.filter_conditions)
|
||||
if required_tags:
|
||||
files = [f for f in files if RatioTaskService._file_contains_tags(f, required_tags)]
|
||||
|
||||
if not files:
|
||||
continue
|
||||
|
||||
pick_n = min(rel.counts or 0, len(files))
|
||||
chosen = random.sample(files, pick_n) if pick_n < len(files) else files
|
||||
|
||||
# Copy into target dataset with de-dup by target path
|
||||
for f in chosen:
|
||||
src_path = f.file_path
|
||||
new_path = src_path
|
||||
needs_copy = False
|
||||
src_prefix = f"/dataset/{rel.source_dataset_id}"
|
||||
if isinstance(src_path, str) and src_path.startswith(src_prefix):
|
||||
dst_prefix = f"/dataset/{target_ds.id}"
|
||||
new_path = src_path.replace(src_prefix, dst_prefix, 1)
|
||||
needs_copy = True
|
||||
|
||||
# De-dup by target path
|
||||
if new_path in existing_paths:
|
||||
continue
|
||||
|
||||
# Perform copy only when needed
|
||||
if needs_copy:
|
||||
dst_dir = os.path.dirname(new_path)
|
||||
await asyncio.to_thread(os.makedirs, dst_dir, exist_ok=True)
|
||||
await asyncio.to_thread(shutil.copy2, src_path, new_path)
|
||||
|
||||
new_file = DatasetFiles(
|
||||
dataset_id=target_ds.id, # type: ignore
|
||||
file_name=f.file_name,
|
||||
file_path=new_path,
|
||||
file_type=f.file_type,
|
||||
file_size=f.file_size,
|
||||
check_sum=f.check_sum,
|
||||
tags=f.tags,
|
||||
dataset_filemetadata=f.dataset_filemetadata,
|
||||
status="ACTIVE",
|
||||
)
|
||||
session.add(new_file)
|
||||
existing_paths.add(new_path)
|
||||
added_count += 1
|
||||
added_size += int(f.file_size or 0)
|
||||
|
||||
# Periodically flush to avoid huge transactions
|
||||
await session.flush()
|
||||
added_count, added_size = await RatioTaskService.handle_ratio_relations(relations,session, target_ds)
|
||||
|
||||
# Update target dataset statistics
|
||||
target_ds.file_count = (target_ds.file_count or 0) + added_count # type: ignore
|
||||
@@ -194,8 +121,8 @@ class RatioTaskService:
|
||||
target_ds.status = "ACTIVE"
|
||||
|
||||
# Done
|
||||
instance.status = "SUCCESS"
|
||||
logger.info(f"Dataset ratio execution completed: instance={instance_id}, files={added_count}, size={added_size}")
|
||||
instance.status = TaskStatus.COMPLETED.name
|
||||
logger.info(f"Dataset ratio execution completed: instance={instance_id}, files={added_count}, size={added_size}, {instance.status}")
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"Dataset ratio execution failed for {instance_id}: {e}")
|
||||
@@ -204,42 +131,163 @@ class RatioTaskService:
|
||||
inst_res = await session.execute(select(RatioInstance).where(RatioInstance.id == instance_id))
|
||||
instance = inst_res.scalar_one_or_none()
|
||||
if instance:
|
||||
instance.status = "FAILED"
|
||||
instance.status = TaskStatus.FAILED.name
|
||||
finally:
|
||||
pass
|
||||
finally:
|
||||
await session.commit()
|
||||
|
||||
@staticmethod
|
||||
async def handle_ratio_relations(relations: list[RatioRelation], session, target_ds: Dataset) -> tuple[int, int]:
|
||||
# Preload existing target file paths for deduplication
|
||||
existing_path_rows = await session.execute(
|
||||
select(DatasetFiles.file_path).where(DatasetFiles.dataset_id == target_ds.id)
|
||||
)
|
||||
existing_paths = set(p for p in existing_path_rows.scalars().all() if p)
|
||||
|
||||
added_count = 0
|
||||
added_size = 0
|
||||
|
||||
for rel in relations:
|
||||
if not rel.source_dataset_id or not rel.counts or rel.counts <= 0:
|
||||
continue
|
||||
|
||||
files = await RatioTaskService.get_files(rel, session)
|
||||
|
||||
if not files:
|
||||
continue
|
||||
|
||||
pick_n = min(rel.counts or 0, len(files))
|
||||
chosen = random.sample(files, pick_n) if pick_n < len(files) else files
|
||||
|
||||
# Copy into target dataset with de-dup by target path
|
||||
for f in chosen:
|
||||
await RatioTaskService.handle_selected_file(existing_paths, f, session, target_ds)
|
||||
added_count += 1
|
||||
added_size += int(f.file_size or 0)
|
||||
|
||||
# Periodically flush to avoid huge transactions
|
||||
await session.flush()
|
||||
return added_count, added_size
|
||||
|
||||
@staticmethod
|
||||
async def handle_selected_file(existing_paths: set[Any], f, session, target_ds: Dataset):
|
||||
src_path = f.file_path
|
||||
dst_prefix = f"/dataset/{target_ds.id}"
|
||||
file_name = RatioTaskService.get_new_file_name(dst_prefix, existing_paths, f)
|
||||
|
||||
new_path = dst_prefix + file_name
|
||||
dst_dir = os.path.dirname(new_path)
|
||||
await asyncio.to_thread(os.makedirs, dst_dir, exist_ok=True)
|
||||
await asyncio.to_thread(shutil.copy2, src_path, new_path)
|
||||
|
||||
new_file = DatasetFiles(
|
||||
dataset_id=target_ds.id, # type: ignore
|
||||
file_name=file_name,
|
||||
file_path=new_path,
|
||||
file_type=f.file_type,
|
||||
file_size=f.file_size,
|
||||
check_sum=f.check_sum,
|
||||
tags=f.tags,
|
||||
dataset_filemetadata=f.dataset_filemetadata,
|
||||
status="ACTIVE",
|
||||
)
|
||||
session.add(new_file)
|
||||
existing_paths.add(new_path)
|
||||
|
||||
@staticmethod
|
||||
def get_new_file_name(dst_prefix: str, existing_paths: set[Any], f) -> str:
|
||||
file_name = f.file_name
|
||||
new_path = dst_prefix + file_name
|
||||
|
||||
# Handle file path conflicts by appending a number to the filename
|
||||
if new_path in existing_paths:
|
||||
file_name_base, file_ext = os.path.splitext(file_name)
|
||||
counter = 1
|
||||
original_file_name = file_name
|
||||
while new_path in existing_paths:
|
||||
file_name = f"{file_name_base}_{counter}{file_ext}"
|
||||
new_path = f"{dst_prefix}{file_name}"
|
||||
counter += 1
|
||||
if counter > 1000: # Safety check to prevent infinite loops
|
||||
logger.error(f"Could not find unique filename for {original_file_name} after 1000 attempts")
|
||||
break
|
||||
return file_name
|
||||
|
||||
@staticmethod
|
||||
async def get_files(rel: RatioRelation, session) -> list[Any]:
|
||||
# Fetch all files for the source dataset (ACTIVE only)
|
||||
files_res = await session.execute(
|
||||
select(DatasetFiles).where(
|
||||
DatasetFiles.dataset_id == rel.source_dataset_id,
|
||||
DatasetFiles.status == "ACTIVE",
|
||||
)
|
||||
)
|
||||
files = list(files_res.scalars().all())
|
||||
|
||||
# TAG mode: filter by tags according to relation.filter_conditions
|
||||
conditions = RatioTaskService._parse_conditions(rel.filter_conditions)
|
||||
if conditions:
|
||||
files = [f for f in files if RatioTaskService._filter_file(f, conditions)]
|
||||
return files
|
||||
|
||||
# ------------------------- helpers for TAG filtering ------------------------- #
|
||||
|
||||
@staticmethod
|
||||
def _parse_required_tags(conditions: Optional[str]) -> set[str]:
|
||||
"""Parse filter_conditions into a set of required tag strings.
|
||||
def _parse_conditions(conditions: Optional[str]) -> Optional[FilterCondition]:
|
||||
"""Parse filter_conditions JSON string into a FilterCondition object.
|
||||
|
||||
Supports simple separators: comma, semicolon, space. Empty/None -> empty set.
|
||||
Args:
|
||||
conditions: JSON string containing filter conditions
|
||||
|
||||
Returns:
|
||||
FilterCondition object if conditions is not None/empty, otherwise None
|
||||
"""
|
||||
if not conditions:
|
||||
return set()
|
||||
data = json.loads(conditions)
|
||||
required_tags = set()
|
||||
if data.get("label"):
|
||||
required_tags.add(data["label"])
|
||||
return required_tags
|
||||
return None
|
||||
try:
|
||||
data = json.loads(conditions)
|
||||
return FilterCondition(**data)
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(f"Failed to parse filter conditions: {e}")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating FilterCondition: {e}")
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _file_contains_tags(file: DatasetFiles, required: set[str]) -> bool:
|
||||
if not required:
|
||||
def _filter_file(file: DatasetFiles, conditions: FilterCondition) -> bool:
|
||||
if not conditions:
|
||||
return True
|
||||
tags = file.tags
|
||||
if not tags:
|
||||
return False
|
||||
try:
|
||||
# tags could be a list of strings or list of objects with 'name'
|
||||
tag_names = RatioTaskService.get_all_tags(tags)
|
||||
return required.issubset(tag_names)
|
||||
except Exception as e:
|
||||
logger.exception(f"Failed to get tags for {file}", e)
|
||||
return False
|
||||
logger.info(f"start filter file: {file}, conditions: {conditions}")
|
||||
|
||||
# Check data range condition if provided
|
||||
if conditions.date_range:
|
||||
try:
|
||||
from datetime import datetime, timedelta
|
||||
data_range_days = int(conditions.date_range)
|
||||
if data_range_days > 0:
|
||||
cutoff_date = datetime.now() - timedelta(days=data_range_days)
|
||||
if file.tags_updated_at and file.tags_updated_at < cutoff_date:
|
||||
return False
|
||||
except (ValueError, TypeError) as e:
|
||||
logger.warning(f"Invalid data_range value: {conditions.date_range}", e)
|
||||
return False
|
||||
|
||||
# Check label condition if provided
|
||||
if conditions.label:
|
||||
tags = file.tags
|
||||
if not tags:
|
||||
return False
|
||||
try:
|
||||
# tags could be a list of strings or list of objects with 'name'
|
||||
tag_names = RatioTaskService.get_all_tags(tags)
|
||||
return conditions.label in tag_names
|
||||
except Exception as e:
|
||||
logger.exception(f"Failed to get tags for {file}", e)
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def get_all_tags(tags) -> set[str]:
|
||||
|
||||
Reference in New Issue
Block a user