You've already forked DataMate
feature: 数据配比增加通过更新时间来配置 (#95)
* feature: 数据配比增加通过更新时间来配置 * fix: 修复配比时间参数传递的问题
This commit is contained in:
@@ -7,7 +7,6 @@ import { useNavigate } from "react-router";
|
|||||||
import SelectDataset from "@/pages/RatioTask/Create/components/SelectDataset.tsx";
|
import SelectDataset from "@/pages/RatioTask/Create/components/SelectDataset.tsx";
|
||||||
import BasicInformation from "@/pages/RatioTask/Create/components/BasicInformation.tsx";
|
import BasicInformation from "@/pages/RatioTask/Create/components/BasicInformation.tsx";
|
||||||
import RatioConfig from "@/pages/RatioTask/Create/components/RatioConfig.tsx";
|
import RatioConfig from "@/pages/RatioTask/Create/components/RatioConfig.tsx";
|
||||||
import RatioTransfer from "./components/RatioTransfer";
|
|
||||||
|
|
||||||
export default function CreateRatioTask() {
|
export default function CreateRatioTask() {
|
||||||
const navigate = useNavigate();
|
const navigate = useNavigate();
|
||||||
@@ -36,27 +35,12 @@ export default function CreateRatioTask() {
|
|||||||
message.error("请配置配比项");
|
message.error("请配置配比项");
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
// Build request payload
|
|
||||||
const ratio_method =
|
|
||||||
ratioTaskForm.ratioType === "dataset" ? "DATASET" : "TAG";
|
|
||||||
const totals = String(values.totalTargetCount);
|
const totals = String(values.totalTargetCount);
|
||||||
const config = ratioTaskForm.ratioConfigs.map((c) => {
|
const config = ratioTaskForm.ratioConfigs.map((c) => {
|
||||||
if (ratio_method === "DATASET") {
|
|
||||||
return {
|
|
||||||
datasetId: String(c.source),
|
|
||||||
counts: String(c.quantity ?? 0),
|
|
||||||
filter_conditions: "",
|
|
||||||
};
|
|
||||||
}
|
|
||||||
// TAG mode: source key like `${datasetId}_${label}`
|
|
||||||
const source = String(c.source || "");
|
|
||||||
const idx = source.indexOf("_");
|
|
||||||
const datasetId = idx > 0 ? source.slice(0, idx) : source;
|
|
||||||
const label = idx > 0 ? source.slice(idx + 1) : "";
|
|
||||||
return {
|
return {
|
||||||
datasetId,
|
datasetId: c.id,
|
||||||
counts: String(c.quantity ?? 0),
|
counts: String(c.quantity ?? 0),
|
||||||
filter_conditions: label ? JSON.stringify({ label }) : "",
|
filterConditions: { label: c.labelFilter, dateRange: String(c.dateRange ?? 0)},
|
||||||
};
|
};
|
||||||
});
|
});
|
||||||
|
|
||||||
@@ -65,7 +49,6 @@ export default function CreateRatioTask() {
|
|||||||
name: values.name,
|
name: values.name,
|
||||||
description: values.description,
|
description: values.description,
|
||||||
totals,
|
totals,
|
||||||
ratio_method,
|
|
||||||
config,
|
config,
|
||||||
});
|
});
|
||||||
message.success("配比任务创建成功");
|
message.success("配比任务创建成功");
|
||||||
@@ -108,13 +91,6 @@ export default function CreateRatioTask() {
|
|||||||
totalTargetCount={ratioTaskForm.totalTargetCount}
|
totalTargetCount={ratioTaskForm.totalTargetCount}
|
||||||
/>
|
/>
|
||||||
|
|
||||||
{/* <RatioTransfer
|
|
||||||
ratioTaskForm={ratioTaskForm}
|
|
||||||
distributions={distributions}
|
|
||||||
updateRatioConfig={updateRatioConfig}
|
|
||||||
updateLabelRatioConfig={updateLabelRatioConfig}
|
|
||||||
/> */}
|
|
||||||
|
|
||||||
<div className="flex h-full">
|
<div className="flex h-full">
|
||||||
<SelectDataset
|
<SelectDataset
|
||||||
selectedDatasets={ratioTaskForm.selectedDatasets}
|
selectedDatasets={ratioTaskForm.selectedDatasets}
|
||||||
|
|||||||
@@ -27,7 +27,7 @@ const BasicInformation: React.FC<BasicInformationProps> = ({
|
|||||||
<Input type="number" placeholder="目标总数量" min={1} />
|
<Input type="number" placeholder="目标总数量" min={1} />
|
||||||
</Form.Item>
|
</Form.Item>
|
||||||
<Form.Item label="任务描述" name="description" className="col-span-2">
|
<Form.Item label="任务描述" name="description" className="col-span-2">
|
||||||
<TextArea placeholder="描述配比任务的目的和要求(可选)" rows={2} />
|
<TextArea placeholder="描述配比任务的目的和要求" rows={2} />
|
||||||
</Form.Item>
|
</Form.Item>
|
||||||
</div>
|
</div>
|
||||||
);
|
);
|
||||||
|
|||||||
@@ -1,7 +1,19 @@
|
|||||||
import React, { useMemo, useState } from "react";
|
import React, { useMemo, useState } from "react";
|
||||||
import { Badge, Card, Input, Progress, Button, Divider } from "antd";
|
import { Badge, Card, Input, Progress, Button, DatePicker, Select } from "antd";
|
||||||
import { BarChart3 } from "lucide-react";
|
import { BarChart3, Filter, Clock } from "lucide-react";
|
||||||
import type { Dataset } from "@/pages/DataManagement/dataset.model.ts";
|
import type { Dataset } from "@/pages/DataManagement/dataset.model.ts";
|
||||||
|
import dayjs from 'dayjs';
|
||||||
|
|
||||||
|
const { RangePicker } = DatePicker;
|
||||||
|
const { Option } = Select;
|
||||||
|
|
||||||
|
const TIME_RANGE_OPTIONS = [
|
||||||
|
{ label: '最近1天', value: 1 },
|
||||||
|
{ label: '最近3天', value: 3 },
|
||||||
|
{ label: '最近7天', value: 7 },
|
||||||
|
{ label: '最近15天', value: 15 },
|
||||||
|
{ label: '最近30天', value: 30 },
|
||||||
|
];
|
||||||
|
|
||||||
interface RatioConfigItem {
|
interface RatioConfigItem {
|
||||||
id: string;
|
id: string;
|
||||||
@@ -10,6 +22,8 @@ interface RatioConfigItem {
|
|||||||
quantity: number;
|
quantity: number;
|
||||||
percentage: number;
|
percentage: number;
|
||||||
source: string;
|
source: string;
|
||||||
|
labelFilter?: string;
|
||||||
|
dateRange?: string;
|
||||||
}
|
}
|
||||||
|
|
||||||
interface RatioConfigProps {
|
interface RatioConfigProps {
|
||||||
@@ -22,14 +36,18 @@ interface RatioConfigProps {
|
|||||||
}
|
}
|
||||||
|
|
||||||
const RatioConfig: React.FC<RatioConfigProps> = ({
|
const RatioConfig: React.FC<RatioConfigProps> = ({
|
||||||
ratioType,
|
ratioType,
|
||||||
selectedDatasets,
|
selectedDatasets,
|
||||||
datasets,
|
datasets,
|
||||||
totalTargetCount,
|
totalTargetCount,
|
||||||
distributions,
|
distributions,
|
||||||
onChange,
|
onChange,
|
||||||
}) => {
|
}) => {
|
||||||
const [ratioConfigs, setRatioConfigs] = useState<RatioConfigItem[]>([]);
|
const [ratioConfigs, setRatioConfigs] = useState<RatioConfigItem[]>([]);
|
||||||
|
const [datasetFilters, setDatasetFilters] = useState<Record<string, {
|
||||||
|
labelFilter?: string;
|
||||||
|
dateRange?: string;
|
||||||
|
}>>({});
|
||||||
|
|
||||||
// 配比项总数
|
// 配比项总数
|
||||||
const totalConfigured = useMemo(
|
const totalConfigured = useMemo(
|
||||||
@@ -37,6 +55,36 @@ const RatioConfig: React.FC<RatioConfigProps> = ({
|
|||||||
[ratioConfigs]
|
[ratioConfigs]
|
||||||
);
|
);
|
||||||
|
|
||||||
|
// 获取数据集的标签列表
|
||||||
|
const getDatasetLabels = (datasetId: string): string[] => {
|
||||||
|
const dist = distributions[String(datasetId)] || {};
|
||||||
|
return Object.keys(dist);
|
||||||
|
};
|
||||||
|
|
||||||
|
// 自动平均分配
|
||||||
|
const generateAutoRatio = () => {
|
||||||
|
const selectedCount = selectedDatasets.length;
|
||||||
|
if (selectedCount === 0) return;
|
||||||
|
const baseQuantity = Math.floor(totalTargetCount / selectedCount);
|
||||||
|
const remainder = totalTargetCount % selectedCount;
|
||||||
|
const newConfigs = selectedDatasets.map((datasetId, index) => {
|
||||||
|
const dataset = datasets.find((d) => String(d.id) === datasetId);
|
||||||
|
const quantity = baseQuantity + (index < remainder ? 1 : 0);
|
||||||
|
return {
|
||||||
|
id: datasetId,
|
||||||
|
name: dataset?.name || datasetId,
|
||||||
|
type: ratioType,
|
||||||
|
quantity,
|
||||||
|
percentage: Math.round((quantity / totalTargetCount) * 100),
|
||||||
|
source: datasetId,
|
||||||
|
labelFilter: datasetFilters[datasetId]?.labelFilter,
|
||||||
|
dateRange: datasetFilters[datasetId]?.dateRange,
|
||||||
|
};
|
||||||
|
});
|
||||||
|
setRatioConfigs(newConfigs);
|
||||||
|
onChange?.(newConfigs);
|
||||||
|
};
|
||||||
|
|
||||||
// 更新数据集配比项
|
// 更新数据集配比项
|
||||||
const updateDatasetQuantity = (datasetId: string, quantity: number) => {
|
const updateDatasetQuantity = (datasetId: string, quantity: number) => {
|
||||||
setRatioConfigs((prev) => {
|
setRatioConfigs((prev) => {
|
||||||
@@ -55,6 +103,8 @@ const RatioConfig: React.FC<RatioConfigProps> = ({
|
|||||||
quantity: Math.min(quantity, totalTargetCount - totalOtherQuantity),
|
quantity: Math.min(quantity, totalTargetCount - totalOtherQuantity),
|
||||||
percentage: Math.round((quantity / totalTargetCount) * 100),
|
percentage: Math.round((quantity / totalTargetCount) * 100),
|
||||||
source: datasetId,
|
source: datasetId,
|
||||||
|
labelFilter: datasetFilters[datasetId]?.labelFilter,
|
||||||
|
dateRange: datasetFilters[datasetId]?.dateRange,
|
||||||
};
|
};
|
||||||
|
|
||||||
let newConfigs;
|
let newConfigs;
|
||||||
@@ -69,78 +119,85 @@ const RatioConfig: React.FC<RatioConfigProps> = ({
|
|||||||
});
|
});
|
||||||
};
|
};
|
||||||
|
|
||||||
// 自动平均分配
|
// 更新筛选条件
|
||||||
const generateAutoRatio = () => {
|
const updateFilters = (datasetId: string, updates: {
|
||||||
const selectedCount = selectedDatasets.length;
|
labelFilter?: string;
|
||||||
if (selectedCount === 0) return;
|
dateRange?: [string, string];
|
||||||
const baseQuantity = Math.floor(totalTargetCount / selectedCount);
|
}) => {
|
||||||
const remainder = totalTargetCount % selectedCount;
|
setDatasetFilters(prev => ({
|
||||||
const newConfigs = selectedDatasets.map((datasetId, index) => {
|
...prev,
|
||||||
const dataset = datasets.find((d) => String(d.id) === datasetId);
|
[datasetId]: {
|
||||||
const quantity = baseQuantity + (index < remainder ? 1 : 0);
|
...prev[datasetId],
|
||||||
return {
|
...updates,
|
||||||
id: datasetId,
|
|
||||||
name: dataset?.name || datasetId,
|
|
||||||
type: ratioType,
|
|
||||||
quantity,
|
|
||||||
percentage: Math.round((quantity / totalTargetCount) * 100),
|
|
||||||
source: datasetId,
|
|
||||||
};
|
|
||||||
});
|
|
||||||
setRatioConfigs(newConfigs);
|
|
||||||
onChange?.(newConfigs);
|
|
||||||
};
|
|
||||||
|
|
||||||
// 标签模式下,更新某数据集的某个标签的数量
|
|
||||||
const updateLabelQuantity = (
|
|
||||||
datasetId: string,
|
|
||||||
label: string,
|
|
||||||
quantity: number
|
|
||||||
) => {
|
|
||||||
const sourceKey = `${datasetId}_${label}`;
|
|
||||||
setRatioConfigs((prev) => {
|
|
||||||
const existingIndex = prev.findIndex((c) => c.source === sourceKey);
|
|
||||||
const totalOtherQuantity = prev
|
|
||||||
.filter((c) => c.source !== sourceKey)
|
|
||||||
.reduce((sum, c) => sum + c.quantity, 0);
|
|
||||||
const dist = distributions[datasetId] || {};
|
|
||||||
const labelMax = dist[label] ?? Infinity;
|
|
||||||
const cappedQuantity = Math.max(
|
|
||||||
0,
|
|
||||||
Math.min(quantity, totalTargetCount - totalOtherQuantity, labelMax)
|
|
||||||
);
|
|
||||||
const newConfig: RatioConfigItem = {
|
|
||||||
id: sourceKey,
|
|
||||||
name: label,
|
|
||||||
type: "label",
|
|
||||||
quantity: cappedQuantity,
|
|
||||||
percentage: Math.round((cappedQuantity / totalTargetCount) * 100),
|
|
||||||
source: sourceKey,
|
|
||||||
};
|
|
||||||
let newConfigs;
|
|
||||||
if (existingIndex >= 0) {
|
|
||||||
newConfigs = [...prev];
|
|
||||||
newConfigs[existingIndex] = newConfig;
|
|
||||||
} else {
|
|
||||||
newConfigs = [...prev, newConfig];
|
|
||||||
}
|
}
|
||||||
onChange?.(newConfigs);
|
}));
|
||||||
return newConfigs;
|
|
||||||
});
|
|
||||||
};
|
};
|
||||||
|
|
||||||
// 选中数据集变化时,移除未选中的配比项
|
// 渲染筛选器
|
||||||
|
const renderFilters = (datasetId: string) => {
|
||||||
|
const labels = getDatasetLabels(datasetId);
|
||||||
|
const config = ratioConfigs.find(c => c.source === datasetId);
|
||||||
|
const filters = datasetFilters[datasetId] || {};
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div className="mb-3 p-2 bg-gray-50 rounded">
|
||||||
|
<div className="flex items-center gap-2 mb-2">
|
||||||
|
<Filter size={14} className="text-gray-400" />
|
||||||
|
<span className="text-xs font-medium">筛选条件</span>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div className="grid grid-cols-1 md:grid-cols-2 gap-3">
|
||||||
|
<div>
|
||||||
|
<div className="text-xs text-gray-500 mb-1">标签筛选</div>
|
||||||
|
<Select
|
||||||
|
style={{ width: '100%' }}
|
||||||
|
placeholder="选择标签"
|
||||||
|
value={filters.labelFilter}
|
||||||
|
onChange={(value) => updateFilters(datasetId, { labelFilter: value })}
|
||||||
|
allowClear
|
||||||
|
onClear={() => updateFilters(datasetId, { labelFilter: undefined })}
|
||||||
|
>
|
||||||
|
{labels.map(label => (
|
||||||
|
<Option key={label} value={label}>{label}</Option>
|
||||||
|
))}
|
||||||
|
</Select>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div>
|
||||||
|
<div className="text-xs text-gray-500 mb-1">标签更新时间</div>
|
||||||
|
<Select
|
||||||
|
style={{ width: '100%' }}
|
||||||
|
placeholder="选择标签更新时间"
|
||||||
|
value={filters.dateRange}
|
||||||
|
onChange={(dates) => updateFilters(datasetId, { dateRange: dates })}
|
||||||
|
allowClear
|
||||||
|
onClear={() => updateFilters(datasetId, { dateRange: undefined })}
|
||||||
|
>
|
||||||
|
{TIME_RANGE_OPTIONS.map(option => (
|
||||||
|
<Option key={option.value} value={option.value}>
|
||||||
|
{option.label}
|
||||||
|
</Option>
|
||||||
|
))}
|
||||||
|
</Select>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
// 选中数据集变化时,初始化筛选条件
|
||||||
React.useEffect(() => {
|
React.useEffect(() => {
|
||||||
setRatioConfigs((prev) => {
|
const initialFilters: Record<string, any> = {};
|
||||||
const next = prev.filter((c) => {
|
selectedDatasets.forEach(datasetId => {
|
||||||
const id = String(c.source);
|
const config = ratioConfigs.find(c => c.source === datasetId);
|
||||||
const dsId = id.includes("_") ? id.split("_")[0] : id;
|
if (config) {
|
||||||
return selectedDatasets.includes(dsId);
|
initialFilters[datasetId] = {
|
||||||
});
|
labelFilter: config.labelFilter,
|
||||||
if (next !== prev) onChange?.(next);
|
dateRange: config.dateRange,
|
||||||
return next;
|
};
|
||||||
|
}
|
||||||
});
|
});
|
||||||
// eslint-disable-next-line
|
setDatasetFilters(prev => ({ ...prev, ...initialFilters }));
|
||||||
}, [selectedDatasets]);
|
}, [selectedDatasets]);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
@@ -148,7 +205,7 @@ const RatioConfig: React.FC<RatioConfigProps> = ({
|
|||||||
<div className="flex items-center justify-between p-4 border-bottom">
|
<div className="flex items-center justify-between p-4 border-bottom">
|
||||||
<span className="text-sm font-bold">
|
<span className="text-sm font-bold">
|
||||||
配比配置
|
配比配置
|
||||||
<span className="text-xs text-gray-500">
|
<span className="text-xs text-gray-500 ml-1">
|
||||||
(已配置:{totalConfigured}/{totalTargetCount}条)
|
(已配置:{totalConfigured}/{totalTargetCount}条)
|
||||||
</span>
|
</span>
|
||||||
</span>
|
</span>
|
||||||
@@ -170,41 +227,36 @@ const RatioConfig: React.FC<RatioConfigProps> = ({
|
|||||||
<div className="flex-overflow-auto gap-4 p-4">
|
<div className="flex-overflow-auto gap-4 p-4">
|
||||||
{/* 配比预览 */}
|
{/* 配比预览 */}
|
||||||
{ratioConfigs.length > 0 && (
|
{ratioConfigs.length > 0 && (
|
||||||
<div>
|
<div className="p-3 bg-gray-50 rounded-lg mb-4">
|
||||||
<div className="p-3 bg-gray-50 rounded-lg">
|
<div className="grid grid-cols-2 gap-4 text-sm">
|
||||||
<div className="grid grid-cols-2 gap-4 text-sm">
|
<div>
|
||||||
<div>
|
<span className="text-gray-500">总配比数量:</span>
|
||||||
<span className="text-gray-500">总配比数量:</span>
|
<span className="ml-2 font-medium">
|
||||||
<span className="ml-2 font-medium">
|
{ratioConfigs
|
||||||
{ratioConfigs
|
.reduce((sum, config) => sum + config.quantity, 0)
|
||||||
.reduce((sum, config) => sum + config.quantity, 0)
|
.toLocaleString()}
|
||||||
.toLocaleString()}
|
</span>
|
||||||
</span>
|
</div>
|
||||||
</div>
|
<div>
|
||||||
<div>
|
<span className="text-gray-500">目标数量:</span>
|
||||||
<span className="text-gray-500">目标数量:</span>
|
<span className="ml-2 font-medium">
|
||||||
<span className="ml-2 font-medium">
|
{totalTargetCount.toLocaleString()}
|
||||||
{totalTargetCount.toLocaleString()}
|
</span>
|
||||||
</span>
|
|
||||||
</div>
|
|
||||||
<div>
|
|
||||||
<span className="text-gray-500">配比项目:</span>
|
|
||||||
<span className="ml-2 font-medium">
|
|
||||||
{ratioConfigs.length}个
|
|
||||||
</span>
|
|
||||||
</div>
|
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
)}
|
)}
|
||||||
<div className="flex-1 overflow-auto">
|
|
||||||
|
<div className="flex-1 overflow-auto space-y-4">
|
||||||
{selectedDatasets.map((datasetId) => {
|
{selectedDatasets.map((datasetId) => {
|
||||||
const dataset = datasets.find((d) => String(d.id) === datasetId);
|
const dataset = datasets.find((d) => String(d.id) === datasetId);
|
||||||
const config = ratioConfigs.find((c) => c.source === datasetId);
|
const config = ratioConfigs.find((c) => c.source === datasetId);
|
||||||
const currentQuantity = config?.quantity || 0;
|
const currentQuantity = config?.quantity || 0;
|
||||||
|
|
||||||
if (!dataset) return null;
|
if (!dataset) return null;
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<Card key={datasetId} size="small" className="mb-2">
|
<Card key={datasetId} size="small" className="mb-4">
|
||||||
<div className="flex items-center justify-between mb-3">
|
<div className="flex items-center justify-between mb-3">
|
||||||
<div className="flex items-center gap-2">
|
<div className="flex items-center gap-2">
|
||||||
<span className="font-medium text-sm">
|
<span className="font-medium text-sm">
|
||||||
@@ -216,97 +268,36 @@ const RatioConfig: React.FC<RatioConfigProps> = ({
|
|||||||
{config?.percentage || 0}%
|
{config?.percentage || 0}%
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
{ratioType === "dataset" ? (
|
|
||||||
<div>
|
{/* 筛选条件 */}
|
||||||
<div className="flex items-center gap-2 mb-2">
|
{renderFilters(datasetId)}
|
||||||
<span className="text-xs">数量:</span>
|
|
||||||
<Input
|
<div className="flex items-center gap-2 mb-2">
|
||||||
type="number"
|
<span className="text-xs">数量:</span>
|
||||||
value={currentQuantity}
|
<Input
|
||||||
onChange={(e) =>
|
type="number"
|
||||||
updateDatasetQuantity(
|
value={currentQuantity}
|
||||||
datasetId,
|
onChange={(e) =>
|
||||||
Number(e.target.value)
|
updateDatasetQuantity(
|
||||||
)
|
datasetId,
|
||||||
}
|
Number(e.target.value)
|
||||||
style={{ width: 80 }}
|
)
|
||||||
min={0}
|
}
|
||||||
max={Math.min(
|
style={{ width: 100 }}
|
||||||
dataset.fileCount || 0,
|
min={0}
|
||||||
totalTargetCount
|
max={Math.min(
|
||||||
)}
|
dataset.fileCount || 0,
|
||||||
/>
|
totalTargetCount
|
||||||
<span className="text-xs text-gray-500">条</span>
|
|
||||||
</div>
|
|
||||||
<Progress
|
|
||||||
percent={Math.round(
|
|
||||||
(currentQuantity / totalTargetCount) * 100
|
|
||||||
)}
|
|
||||||
size="small"
|
|
||||||
/>
|
|
||||||
</div>
|
|
||||||
) : (
|
|
||||||
<div>
|
|
||||||
{!distributions[String(dataset.id)] ? (
|
|
||||||
<div className="text-xs text-gray-400">
|
|
||||||
加载标签分布...
|
|
||||||
</div>
|
|
||||||
) : Object.entries(distributions[String(dataset.id)])
|
|
||||||
.length === 0 ? (
|
|
||||||
<div className="text-xs text-gray-400">
|
|
||||||
该数据集暂无标签
|
|
||||||
</div>
|
|
||||||
) : (
|
|
||||||
<div className="flex flex-col gap-2">
|
|
||||||
{Object.entries(
|
|
||||||
distributions[String(dataset.id)]
|
|
||||||
).map(([label, count]) => {
|
|
||||||
const sourceKey = `${datasetId}_${label}`;
|
|
||||||
const labelConfig = ratioConfigs.find(
|
|
||||||
(c) => c.source === sourceKey
|
|
||||||
);
|
|
||||||
const labelQuantity = labelConfig?.quantity || 0;
|
|
||||||
return (
|
|
||||||
<div
|
|
||||||
key={label}
|
|
||||||
className="flex items-center justify-between gap-2"
|
|
||||||
>
|
|
||||||
<div className="flex items-center gap-2">
|
|
||||||
<Badge color="gray">{label}</Badge>
|
|
||||||
<span className="text-xs text-gray-500">
|
|
||||||
{count}条
|
|
||||||
</span>
|
|
||||||
</div>
|
|
||||||
<div className="flex items-center gap-2">
|
|
||||||
<span className="text-xs">数量:</span>
|
|
||||||
<Input
|
|
||||||
type="number"
|
|
||||||
value={labelQuantity}
|
|
||||||
onChange={(e) =>
|
|
||||||
updateLabelQuantity(
|
|
||||||
datasetId,
|
|
||||||
label,
|
|
||||||
Number(e.target.value)
|
|
||||||
)
|
|
||||||
}
|
|
||||||
style={{ width: 80 }}
|
|
||||||
min={0}
|
|
||||||
max={Math.min(
|
|
||||||
Number(count) || 0,
|
|
||||||
totalTargetCount
|
|
||||||
)}
|
|
||||||
/>
|
|
||||||
<span className="text-xs text-gray-500">
|
|
||||||
条
|
|
||||||
</span>
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
);
|
|
||||||
})}
|
|
||||||
</div>
|
|
||||||
)}
|
)}
|
||||||
</div>
|
/>
|
||||||
)}
|
<span className="text-xs text-gray-500">条</span>
|
||||||
|
</div>
|
||||||
|
<Progress
|
||||||
|
percent={Math.round(
|
||||||
|
(currentQuantity / totalTargetCount) * 100
|
||||||
|
)}
|
||||||
|
size="small"
|
||||||
|
/>
|
||||||
</Card>
|
</Card>
|
||||||
);
|
);
|
||||||
})}
|
})}
|
||||||
|
|||||||
@@ -1,169 +0,0 @@
|
|||||||
import React, { useMemo } from "react";
|
|
||||||
import { Table } from "antd";
|
|
||||||
import { TransferItem } from "antd/es/transfer";
|
|
||||||
import RatioConfig from "./RatioConfig";
|
|
||||||
import useFetchData from "@/hooks/useFetchData";
|
|
||||||
import { queryDatasetsUsingGet } from "@/pages/DataManagement/dataset.api";
|
|
||||||
import {
|
|
||||||
datasetTypeMap,
|
|
||||||
mapDataset,
|
|
||||||
} from "@/pages/DataManagement/dataset.const";
|
|
||||||
import { SearchControls } from "@/components/SearchControls";
|
|
||||||
|
|
||||||
const leftColumns = [
|
|
||||||
{
|
|
||||||
dataIndex: "name",
|
|
||||||
title: "名称",
|
|
||||||
ellipsis: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
dataIndex: "datasetType",
|
|
||||||
title: "类型",
|
|
||||||
ellipsis: true,
|
|
||||||
width: 100,
|
|
||||||
render: (type: string) => datasetTypeMap[type].label,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
dataIndex: "size",
|
|
||||||
title: "大小",
|
|
||||||
width: 100,
|
|
||||||
ellipsis: true,
|
|
||||||
},
|
|
||||||
];
|
|
||||||
|
|
||||||
export default function RatioTransfer(props: {
|
|
||||||
distributions: Record<string, Record<string, number>>;
|
|
||||||
ratioTaskForm: any;
|
|
||||||
updateRatioConfig: (datasetId: string, quantity: number) => void;
|
|
||||||
updateLabelRatioConfig: (
|
|
||||||
datasetId: string,
|
|
||||||
label: string,
|
|
||||||
quantity: number
|
|
||||||
) => void;
|
|
||||||
}) {
|
|
||||||
const {
|
|
||||||
updateLabelRatioConfig,
|
|
||||||
updateRatioConfig,
|
|
||||||
ratioTaskForm,
|
|
||||||
distributions,
|
|
||||||
} = props;
|
|
||||||
const {
|
|
||||||
tableData: datasets,
|
|
||||||
loading,
|
|
||||||
pagination,
|
|
||||||
searchParams,
|
|
||||||
setSearchParams,
|
|
||||||
handleFiltersChange,
|
|
||||||
} = useFetchData(queryDatasetsUsingGet, mapDataset);
|
|
||||||
|
|
||||||
const [selectedDatasets, setSelectedDatasets] = React.useState<
|
|
||||||
TransferItem[]
|
|
||||||
>([]);
|
|
||||||
|
|
||||||
const selectedRowKeys = useMemo(() => {
|
|
||||||
return selectedDatasets.map((item) => item.key);
|
|
||||||
}, [selectedDatasets]);
|
|
||||||
|
|
||||||
const [listDisabled, setListDisabled] = React.useState(false);
|
|
||||||
|
|
||||||
const generateAutoRatio = () => {
|
|
||||||
const selectedCount = ratioTaskForm.selectedDatasets.length;
|
|
||||||
if (selectedCount === 0) return;
|
|
||||||
|
|
||||||
const baseQuantity = Math.floor(
|
|
||||||
ratioTaskForm.totalTargetCount / selectedCount
|
|
||||||
);
|
|
||||||
const remainder = ratioTaskForm.totalTargetCount % selectedCount;
|
|
||||||
|
|
||||||
const newConfigs = ratioTaskForm.selectedDatasets.map(
|
|
||||||
(datasetId, index) => {
|
|
||||||
const quantity = baseQuantity + (index < remainder ? 1 : 0);
|
|
||||||
return {
|
|
||||||
id: datasetId,
|
|
||||||
name: datasetId,
|
|
||||||
type: ratioTaskForm.ratioType,
|
|
||||||
quantity,
|
|
||||||
percentage: Math.round(
|
|
||||||
(quantity / ratioTaskForm.totalTargetCount) * 100
|
|
||||||
),
|
|
||||||
source: datasetId,
|
|
||||||
};
|
|
||||||
}
|
|
||||||
);
|
|
||||||
setRatioTaskForm((prev) => ({ ...prev, ratioConfigs: newConfigs }));
|
|
||||||
};
|
|
||||||
|
|
||||||
return (
|
|
||||||
<div className="flex">
|
|
||||||
<div className="border-card flex-1 mr-4">
|
|
||||||
<h3 className="p-2 border-bottom">{`${selectedDatasets.length} / ${datasets.length} 项`}</h3>
|
|
||||||
<SearchControls
|
|
||||||
searchTerm={searchParams.keyword}
|
|
||||||
onSearchChange={(keyword) =>
|
|
||||||
setSearchParams({ ...searchParams, keyword })
|
|
||||||
}
|
|
||||||
searchPlaceholder="搜索数据集名称..."
|
|
||||||
filters={[
|
|
||||||
{
|
|
||||||
key: "type",
|
|
||||||
label: "数据集类型",
|
|
||||||
options: [
|
|
||||||
{ value: "dataset", label: "按数据集" },
|
|
||||||
{ value: "tag", label: "按标签" },
|
|
||||||
],
|
|
||||||
},
|
|
||||||
]}
|
|
||||||
onFiltersChange={handleFiltersChange}
|
|
||||||
onClearFilters={() =>
|
|
||||||
setSearchParams({ ...searchParams, filter: {} })
|
|
||||||
}
|
|
||||||
showViewToggle={false}
|
|
||||||
showReload={false}
|
|
||||||
className="m-4"
|
|
||||||
/>
|
|
||||||
<Table
|
|
||||||
rowSelection={{
|
|
||||||
onChange: (_, selectedRows) => {
|
|
||||||
setSelectedDatasets(selectedRows);
|
|
||||||
},
|
|
||||||
selectedRowKeys,
|
|
||||||
selections: [
|
|
||||||
Table.SELECTION_ALL,
|
|
||||||
Table.SELECTION_INVERT,
|
|
||||||
Table.SELECTION_NONE,
|
|
||||||
],
|
|
||||||
}}
|
|
||||||
columns={leftColumns}
|
|
||||||
dataSource={datasets}
|
|
||||||
loading={loading}
|
|
||||||
pagination={pagination}
|
|
||||||
size="small"
|
|
||||||
rowKey="id"
|
|
||||||
style={{ pointerEvents: listDisabled ? "none" : undefined }}
|
|
||||||
onRow={(record) => ({
|
|
||||||
onClick: () => {
|
|
||||||
if (record.disabled || listDisabled) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
setSelectedDatasets((prev) => {
|
|
||||||
if (prev.includes(record.key)) {
|
|
||||||
return prev.filter((k) => k !== record.key);
|
|
||||||
}
|
|
||||||
return [...prev, record.key];
|
|
||||||
});
|
|
||||||
},
|
|
||||||
})}
|
|
||||||
/>
|
|
||||||
</div>
|
|
||||||
<div className="border-card flex-1">
|
|
||||||
<RatioConfig
|
|
||||||
datasets={selectedDatasets}
|
|
||||||
ratioTaskForm={ratioTaskForm}
|
|
||||||
distributions={distributions}
|
|
||||||
onUpdateDatasetQuantity={updateRatioConfig}
|
|
||||||
onUpdateLabelQuantity={updateLabelRatioConfig}
|
|
||||||
/>
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
);
|
|
||||||
}
|
|
||||||
@@ -1,5 +1,5 @@
|
|||||||
import React, { useEffect, useState } from "react";
|
import React, { useEffect, useState } from "react";
|
||||||
import { Badge, Button, Card, Checkbox, Input, Pagination, Select } from "antd";
|
import { Badge, Button, Card, Checkbox, Input, Pagination } from "antd";
|
||||||
import { Search as SearchIcon } from "lucide-react";
|
import { Search as SearchIcon } from "lucide-react";
|
||||||
import type { Dataset } from "@/pages/DataManagement/dataset.model.ts";
|
import type { Dataset } from "@/pages/DataManagement/dataset.model.ts";
|
||||||
import {
|
import {
|
||||||
@@ -10,8 +10,6 @@ import {
|
|||||||
|
|
||||||
interface SelectDatasetProps {
|
interface SelectDatasetProps {
|
||||||
selectedDatasets: string[];
|
selectedDatasets: string[];
|
||||||
ratioType: "dataset" | "label";
|
|
||||||
onRatioTypeChange: (val: "dataset" | "label") => void;
|
|
||||||
onSelectedDatasetsChange: (next: string[]) => void;
|
onSelectedDatasetsChange: (next: string[]) => void;
|
||||||
onDistributionsChange?: (
|
onDistributionsChange?: (
|
||||||
next: Record<string, Record<string, number>>
|
next: Record<string, Record<string, number>>
|
||||||
@@ -21,8 +19,6 @@ interface SelectDatasetProps {
|
|||||||
|
|
||||||
const SelectDataset: React.FC<SelectDatasetProps> = ({
|
const SelectDataset: React.FC<SelectDatasetProps> = ({
|
||||||
selectedDatasets,
|
selectedDatasets,
|
||||||
ratioType,
|
|
||||||
onRatioTypeChange,
|
|
||||||
onSelectedDatasetsChange,
|
onSelectedDatasetsChange,
|
||||||
onDistributionsChange,
|
onDistributionsChange,
|
||||||
onDatasetsChange,
|
onDatasetsChange,
|
||||||
@@ -62,7 +58,7 @@ const SelectDataset: React.FC<SelectDatasetProps> = ({
|
|||||||
// Fetch label distributions when in label mode
|
// Fetch label distributions when in label mode
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
const fetchDistributions = async () => {
|
const fetchDistributions = async () => {
|
||||||
if (ratioType !== "label" || !datasets?.length) return;
|
if (!datasets?.length) return;
|
||||||
const idsToFetch = datasets
|
const idsToFetch = datasets
|
||||||
.map((d) => String(d.id))
|
.map((d) => String(d.id))
|
||||||
.filter((id) => !distributions[id]);
|
.filter((id) => !distributions[id]);
|
||||||
@@ -147,7 +143,7 @@ const SelectDataset: React.FC<SelectDatasetProps> = ({
|
|||||||
};
|
};
|
||||||
fetchDistributions();
|
fetchDistributions();
|
||||||
// eslint-disable-next-line react-hooks/exhaustive-deps
|
// eslint-disable-next-line react-hooks/exhaustive-deps
|
||||||
}, [ratioType, datasets]);
|
}, [datasets]);
|
||||||
|
|
||||||
const onToggleDataset = (datasetId: string, checked: boolean) => {
|
const onToggleDataset = (datasetId: string, checked: boolean) => {
|
||||||
if (checked) {
|
if (checked) {
|
||||||
@@ -180,18 +176,6 @@ const SelectDataset: React.FC<SelectDatasetProps> = ({
|
|||||||
</Button>
|
</Button>
|
||||||
</div>
|
</div>
|
||||||
<div className="flex-overflow-auto gap-4 p-4">
|
<div className="flex-overflow-auto gap-4 p-4">
|
||||||
<div className="flex items-center gap-4">
|
|
||||||
<span className="text-sm">配比方式:</span>
|
|
||||||
<Select
|
|
||||||
className="flex-1 min-w-[120px]"
|
|
||||||
value={ratioType}
|
|
||||||
onChange={(v) => onRatioTypeChange(v)}
|
|
||||||
options={[
|
|
||||||
{ label: "按数据集", value: "dataset" },
|
|
||||||
{ label: "按标签", value: "label" },
|
|
||||||
]}
|
|
||||||
/>
|
|
||||||
</div>
|
|
||||||
<Input
|
<Input
|
||||||
prefix={<SearchIcon className="text-gray-400" />}
|
prefix={<SearchIcon className="text-gray-400" />}
|
||||||
placeholder="搜索数据集"
|
placeholder="搜索数据集"
|
||||||
@@ -239,32 +223,30 @@ const SelectDataset: React.FC<SelectDatasetProps> = ({
|
|||||||
<span>{dataset.fileCount}条</span>
|
<span>{dataset.fileCount}条</span>
|
||||||
<span>{dataset.size}</span>
|
<span>{dataset.size}</span>
|
||||||
</div>
|
</div>
|
||||||
{ratioType === "label" && (
|
<div className="mt-2">
|
||||||
<div className="mt-2">
|
{distributions[idStr] ? (
|
||||||
{distributions[idStr] ? (
|
Object.entries(distributions[idStr]).length > 0 ? (
|
||||||
Object.entries(distributions[idStr]).length > 0 ? (
|
<div className="flex flex-wrap gap-2 text-xs">
|
||||||
<div className="flex flex-wrap gap-2 text-xs">
|
{Object.entries(distributions[idStr])
|
||||||
{Object.entries(distributions[idStr])
|
.slice(0, 8)
|
||||||
.slice(0, 8)
|
.map(([tag, count]) => (
|
||||||
.map(([tag, count]) => (
|
<Badge
|
||||||
<Badge
|
key={tag}
|
||||||
key={tag}
|
color="gray"
|
||||||
color="gray"
|
>{`${tag}: ${count}`}</Badge>
|
||||||
>{`${tag}: ${count}`}</Badge>
|
))}
|
||||||
))}
|
</div>
|
||||||
</div>
|
|
||||||
) : (
|
|
||||||
<div className="text-xs text-gray-400">
|
|
||||||
未检测到标签分布
|
|
||||||
</div>
|
|
||||||
)
|
|
||||||
) : (
|
) : (
|
||||||
<div className="text-xs text-gray-400">
|
<div className="text-xs text-gray-400">
|
||||||
加载标签分布...
|
未检测到标签分布
|
||||||
</div>
|
</div>
|
||||||
)}
|
)
|
||||||
</div>
|
) : (
|
||||||
)}
|
<div className="text-xs text-gray-400">
|
||||||
|
加载标签分布...
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
</Card>
|
</Card>
|
||||||
|
|||||||
@@ -1,11 +1,13 @@
|
|||||||
from .common import (
|
from .common import (
|
||||||
BaseResponseModel,
|
BaseResponseModel,
|
||||||
StandardResponse,
|
StandardResponse,
|
||||||
PaginatedData
|
PaginatedData,
|
||||||
|
TaskStatus
|
||||||
)
|
)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"BaseResponseModel",
|
"BaseResponseModel",
|
||||||
"StandardResponse",
|
"StandardResponse",
|
||||||
"PaginatedData"
|
"PaginatedData",
|
||||||
|
"TaskStatus"
|
||||||
]
|
]
|
||||||
@@ -1,8 +1,9 @@
|
|||||||
"""
|
"""
|
||||||
通用响应模型
|
通用响应模型
|
||||||
"""
|
"""
|
||||||
from typing import Generic, TypeVar, Optional, List, Type
|
from typing import Generic, TypeVar, List
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
from enum import Enum
|
||||||
|
|
||||||
# 定义泛型类型变量
|
# 定义泛型类型变量
|
||||||
T = TypeVar('T')
|
T = TypeVar('T')
|
||||||
@@ -42,3 +43,9 @@ class PaginatedData(BaseResponseModel, Generic[T]):
|
|||||||
total_elements: int = Field(..., description="总条数")
|
total_elements: int = Field(..., description="总条数")
|
||||||
total_pages: int = Field(..., description="总页数")
|
total_pages: int = Field(..., description="总页数")
|
||||||
content: List[T] = Field(..., description="当前页数据")
|
content: List[T] = Field(..., description="当前页数据")
|
||||||
|
|
||||||
|
class TaskStatus(Enum):
|
||||||
|
PENDING = "PENDING"
|
||||||
|
RUNNING = "RUNNING"
|
||||||
|
COMPLETED = "COMPLETED"
|
||||||
|
FAILED = "FAILED"
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ from app.core.logging import get_logger
|
|||||||
from app.db.models import Dataset
|
from app.db.models import Dataset
|
||||||
from app.db.session import get_db
|
from app.db.session import get_db
|
||||||
from app.module.dataset import DatasetManagementService
|
from app.module.dataset import DatasetManagementService
|
||||||
from app.module.shared.schema import StandardResponse
|
from app.module.shared.schema import StandardResponse, TaskStatus
|
||||||
from app.module.synthesis.schema.ratio_task import (
|
from app.module.synthesis.schema.ratio_task import (
|
||||||
CreateRatioTaskResponse,
|
CreateRatioTaskResponse,
|
||||||
CreateRatioTaskRequest,
|
CreateRatioTaskRequest,
|
||||||
@@ -49,52 +49,18 @@ async def create_ratio_task(
|
|||||||
|
|
||||||
await valid_exists(db, req)
|
await valid_exists(db, req)
|
||||||
|
|
||||||
# 创建目标数据集:名称使用“<任务名称>-配比生成-时间戳”
|
target_dataset = await create_target_dataset(db, req, source_types)
|
||||||
target_dataset_name = f"{req.name}-配比生成-{datetime.now().strftime('%Y%m%d%H%M%S')}"
|
|
||||||
|
|
||||||
target_type = get_target_dataset_type(source_types)
|
instance = await create_ratio_instance(db, req, target_dataset)
|
||||||
|
|
||||||
target_dataset = Dataset(
|
|
||||||
id=str(uuid.uuid4()),
|
|
||||||
name=target_dataset_name,
|
|
||||||
description=req.description or "",
|
|
||||||
dataset_type=target_type,
|
|
||||||
status="DRAFT",
|
|
||||||
)
|
|
||||||
target_dataset.path = f"/dataset/{target_dataset.id}"
|
|
||||||
db.add(target_dataset)
|
|
||||||
await db.flush() # 获取 target_dataset.id
|
|
||||||
|
|
||||||
service = RatioTaskService(db)
|
|
||||||
instance = await service.create_task(
|
|
||||||
name=req.name,
|
|
||||||
description=req.description,
|
|
||||||
totals=int(req.totals),
|
|
||||||
ratio_method=req.ratio_method,
|
|
||||||
config=[
|
|
||||||
{
|
|
||||||
"dataset_id": item.dataset_id,
|
|
||||||
"counts": int(item.counts),
|
|
||||||
"filter_conditions": item.filter_conditions,
|
|
||||||
}
|
|
||||||
for item in req.config
|
|
||||||
],
|
|
||||||
target_dataset_id=target_dataset.id,
|
|
||||||
)
|
|
||||||
|
|
||||||
# 异步执行配比任务(支持 DATASET / TAG)
|
|
||||||
asyncio.create_task(RatioTaskService.execute_dataset_ratio_task(instance.id))
|
asyncio.create_task(RatioTaskService.execute_dataset_ratio_task(instance.id))
|
||||||
|
|
||||||
return StandardResponse(
|
response_data = CreateRatioTaskResponse(
|
||||||
code=200,
|
|
||||||
message="success",
|
|
||||||
data=CreateRatioTaskResponse(
|
|
||||||
id=instance.id,
|
id=instance.id,
|
||||||
name=instance.name,
|
name=instance.name,
|
||||||
description=instance.description,
|
description=instance.description,
|
||||||
totals=instance.totals or 0,
|
totals=instance.totals or 0,
|
||||||
ratio_method=instance.ratio_method or req.ratio_method,
|
status=instance.status or TaskStatus.PENDING.name,
|
||||||
status=instance.status or "PENDING",
|
|
||||||
config=req.config,
|
config=req.config,
|
||||||
targetDataset=TargetDatasetInfo(
|
targetDataset=TargetDatasetInfo(
|
||||||
id=str(target_dataset.id),
|
id=str(target_dataset.id),
|
||||||
@@ -103,6 +69,10 @@ async def create_ratio_task(
|
|||||||
status=str(target_dataset.status),
|
status=str(target_dataset.status),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
return StandardResponse(
|
||||||
|
code=200,
|
||||||
|
message="success",
|
||||||
|
data=response_data
|
||||||
)
|
)
|
||||||
except HTTPException:
|
except HTTPException:
|
||||||
await db.rollback()
|
await db.rollback()
|
||||||
@@ -113,6 +83,46 @@ async def create_ratio_task(
|
|||||||
raise HTTPException(status_code=500, detail="Internal server error")
|
raise HTTPException(status_code=500, detail="Internal server error")
|
||||||
|
|
||||||
|
|
||||||
|
async def create_ratio_instance(db, req: CreateRatioTaskRequest, target_dataset: Dataset) -> RatioInstance:
|
||||||
|
service = RatioTaskService(db)
|
||||||
|
logger.info(f"create_ratio_instance: {req}")
|
||||||
|
instance = await service.create_task(
|
||||||
|
name=req.name,
|
||||||
|
description=req.description,
|
||||||
|
totals=int(req.totals),
|
||||||
|
config=[
|
||||||
|
{
|
||||||
|
"dataset_id": item.dataset_id,
|
||||||
|
"counts": int(item.counts),
|
||||||
|
"filter_conditions": item.filter_conditions,
|
||||||
|
}
|
||||||
|
for item in req.config
|
||||||
|
],
|
||||||
|
target_dataset_id=target_dataset.id,
|
||||||
|
)
|
||||||
|
return instance
|
||||||
|
|
||||||
|
|
||||||
|
async def create_target_dataset(db, req: CreateRatioTaskRequest, source_types: set[str]) -> Dataset:
|
||||||
|
# 创建目标数据集:名称使用“<任务名称>-时间戳”
|
||||||
|
target_dataset_name = f"{req.name}-{datetime.now().strftime('%Y%m%d%H%M%S')}"
|
||||||
|
|
||||||
|
target_type = get_target_dataset_type(source_types)
|
||||||
|
target_dataset_id = uuid.uuid4()
|
||||||
|
|
||||||
|
target_dataset = Dataset(
|
||||||
|
id=str(target_dataset_id),
|
||||||
|
name=target_dataset_name,
|
||||||
|
description=req.description or "",
|
||||||
|
dataset_type=target_type,
|
||||||
|
status="DRAFT",
|
||||||
|
path=f"/dataset/{target_dataset_id}",
|
||||||
|
)
|
||||||
|
db.add(target_dataset)
|
||||||
|
await db.flush() # 获取 target_dataset.id
|
||||||
|
return target_dataset
|
||||||
|
|
||||||
|
|
||||||
@router.get("", response_model=StandardResponse[PagedRatioTaskResponse], status_code=200)
|
@router.get("", response_model=StandardResponse[PagedRatioTaskResponse], status_code=200)
|
||||||
async def list_ratio_tasks(
|
async def list_ratio_tasks(
|
||||||
page: int = 1,
|
page: int = 1,
|
||||||
|
|||||||
@@ -2,10 +2,36 @@ from typing import List, Optional, Dict, Any
|
|||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from pydantic import BaseModel, Field, field_validator
|
from pydantic import BaseModel, Field, field_validator
|
||||||
|
|
||||||
|
from app.core.logging import get_logger
|
||||||
|
from app.module.shared.schema.common import TaskStatus
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
class FilterCondition(BaseModel):
|
||||||
|
date_range: Optional[str] = Field(None, description="数据范围", alias="dateRange")
|
||||||
|
label: Optional[str] = Field(None, description="标签")
|
||||||
|
|
||||||
|
@field_validator("date_range")
|
||||||
|
@classmethod
|
||||||
|
def validate_date_range(cls, v: Optional[str]) -> Optional[str]:
|
||||||
|
# ensure it's a numeric string if provided
|
||||||
|
if not v:
|
||||||
|
return v
|
||||||
|
try:
|
||||||
|
int(v)
|
||||||
|
return v
|
||||||
|
except (ValueError, TypeError) as e:
|
||||||
|
raise ValueError("date_range must be a numeric string")
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
# allow population by field name when constructing model programmatically
|
||||||
|
validate_by_name = True
|
||||||
|
|
||||||
|
|
||||||
class RatioConfigItem(BaseModel):
|
class RatioConfigItem(BaseModel):
|
||||||
dataset_id: str = Field(..., alias="datasetId", description="数据集id")
|
dataset_id: str = Field(..., alias="datasetId", description="数据集id")
|
||||||
counts: str = Field(..., description="数量")
|
counts: str = Field(..., description="数量")
|
||||||
filter_conditions: str = Field(..., description="过滤条件")
|
filter_conditions: FilterCondition = Field(..., alias="filterConditions", description="过滤条件")
|
||||||
|
|
||||||
@field_validator("counts")
|
@field_validator("counts")
|
||||||
@classmethod
|
@classmethod
|
||||||
@@ -22,17 +48,8 @@ class CreateRatioTaskRequest(BaseModel):
|
|||||||
name: str = Field(..., description="名称")
|
name: str = Field(..., description="名称")
|
||||||
description: Optional[str] = Field(None, description="描述")
|
description: Optional[str] = Field(None, description="描述")
|
||||||
totals: str = Field(..., description="目标数量")
|
totals: str = Field(..., description="目标数量")
|
||||||
ratio_method: str = Field(..., description="配比方式", alias="ratio_method")
|
|
||||||
config: List[RatioConfigItem] = Field(..., description="配比设置列表")
|
config: List[RatioConfigItem] = Field(..., description="配比设置列表")
|
||||||
|
|
||||||
@field_validator("ratio_method")
|
|
||||||
@classmethod
|
|
||||||
def validate_ratio_method(cls, v: str) -> str:
|
|
||||||
allowed = {"TAG", "DATASET"}
|
|
||||||
if v not in allowed:
|
|
||||||
raise ValueError(f"ratio_method must be one of {allowed}")
|
|
||||||
return v
|
|
||||||
|
|
||||||
@field_validator("totals")
|
@field_validator("totals")
|
||||||
@classmethod
|
@classmethod
|
||||||
def validate_totals(cls, v: str) -> str:
|
def validate_totals(cls, v: str) -> str:
|
||||||
@@ -58,8 +75,7 @@ class CreateRatioTaskResponse(BaseModel):
|
|||||||
name: str
|
name: str
|
||||||
description: Optional[str] = None
|
description: Optional[str] = None
|
||||||
totals: int
|
totals: int
|
||||||
ratio_method: str
|
status: TaskStatus
|
||||||
status: str
|
|
||||||
# echoed config
|
# echoed config
|
||||||
config: List[RatioConfigItem]
|
config: List[RatioConfigItem]
|
||||||
# created dataset
|
# created dataset
|
||||||
|
|||||||
@@ -14,6 +14,8 @@ from app.db.models.ratio_task import RatioInstance, RatioRelation
|
|||||||
from app.db.models import Dataset, DatasetFiles
|
from app.db.models import Dataset, DatasetFiles
|
||||||
from app.db.session import AsyncSessionLocal
|
from app.db.session import AsyncSessionLocal
|
||||||
from app.module.dataset.schema.dataset_file import DatasetFileTag
|
from app.module.dataset.schema.dataset_file import DatasetFileTag
|
||||||
|
from app.module.shared.schema import TaskStatus
|
||||||
|
from app.module.synthesis.schema.ratio_task import FilterCondition
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
@@ -30,7 +32,6 @@ class RatioTaskService:
|
|||||||
name: str,
|
name: str,
|
||||||
description: Optional[str],
|
description: Optional[str],
|
||||||
totals: int,
|
totals: int,
|
||||||
ratio_method: str,
|
|
||||||
config: List[Dict[str, Any]],
|
config: List[Dict[str, Any]],
|
||||||
target_dataset_id: Optional[str] = None,
|
target_dataset_id: Optional[str] = None,
|
||||||
) -> RatioInstance:
|
) -> RatioInstance:
|
||||||
@@ -38,12 +39,11 @@ class RatioTaskService:
|
|||||||
|
|
||||||
config item format: {"dataset_id": str, "counts": int, "filter_conditions": str}
|
config item format: {"dataset_id": str, "counts": int, "filter_conditions": str}
|
||||||
"""
|
"""
|
||||||
logger.info(f"Creating ratio task: name={name}, method={ratio_method}, totals={totals}, items={len(config or [])}")
|
logger.info(f"Creating ratio task: name={name}, totals={totals}, items={len(config or [])}")
|
||||||
|
|
||||||
instance = RatioInstance(
|
instance = RatioInstance(
|
||||||
name=name,
|
name=name,
|
||||||
description=description,
|
description=description,
|
||||||
ratio_method=ratio_method,
|
|
||||||
totals=totals,
|
totals=totals,
|
||||||
target_dataset_id=target_dataset_id,
|
target_dataset_id=target_dataset_id,
|
||||||
status="PENDING",
|
status="PENDING",
|
||||||
@@ -56,8 +56,12 @@ class RatioTaskService:
|
|||||||
ratio_instance_id=instance.id,
|
ratio_instance_id=instance.id,
|
||||||
source_dataset_id=item.get("dataset_id"),
|
source_dataset_id=item.get("dataset_id"),
|
||||||
counts=int(item.get("counts", 0)),
|
counts=int(item.get("counts", 0)),
|
||||||
filter_conditions=item.get("filter_conditions"),
|
filter_conditions=json.dumps({
|
||||||
|
'date_range': item.get("filter_conditions").date_range,
|
||||||
|
'label': item.get("filter_conditions").label,
|
||||||
|
})
|
||||||
)
|
)
|
||||||
|
logger.info(f"Relation created: {relation.id}, {relation}, {item}, {config}")
|
||||||
self.db.add(relation)
|
self.db.add(relation)
|
||||||
|
|
||||||
await self.db.commit()
|
await self.db.commit()
|
||||||
@@ -97,94 +101,17 @@ class RatioTaskService:
|
|||||||
relations: List[RatioRelation] = list(rel_res.scalars().all())
|
relations: List[RatioRelation] = list(rel_res.scalars().all())
|
||||||
|
|
||||||
# Mark running
|
# Mark running
|
||||||
instance.status = "RUNNING"
|
instance.status = TaskStatus.RUNNING.name
|
||||||
|
|
||||||
if instance.ratio_method not in {"DATASET", "TAG"}:
|
|
||||||
logger.info(f"Instance {instance_id} ratio_method={instance.ratio_method} not supported yet")
|
|
||||||
instance.status = "SUCCESS"
|
|
||||||
return
|
|
||||||
|
|
||||||
# Load target dataset
|
# Load target dataset
|
||||||
ds_res = await session.execute(select(Dataset).where(Dataset.id == instance.target_dataset_id))
|
ds_res = await session.execute(select(Dataset).where(Dataset.id == instance.target_dataset_id))
|
||||||
target_ds: Optional[Dataset] = ds_res.scalar_one_or_none()
|
target_ds: Optional[Dataset] = ds_res.scalar_one_or_none()
|
||||||
if not target_ds:
|
if not target_ds:
|
||||||
logger.error(f"Target dataset not found for instance {instance_id}")
|
logger.error(f"Target dataset not found for instance {instance_id}")
|
||||||
instance.status = "FAILED"
|
instance.status = TaskStatus.FAILED.name
|
||||||
return
|
return
|
||||||
|
|
||||||
# Preload existing target file paths for deduplication
|
added_count, added_size = await RatioTaskService.handle_ratio_relations(relations,session, target_ds)
|
||||||
existing_path_rows = await session.execute(
|
|
||||||
select(DatasetFiles.file_path).where(DatasetFiles.dataset_id == target_ds.id)
|
|
||||||
)
|
|
||||||
existing_paths = set(p for p in existing_path_rows.scalars().all() if p)
|
|
||||||
|
|
||||||
added_count = 0
|
|
||||||
added_size = 0
|
|
||||||
|
|
||||||
for rel in relations:
|
|
||||||
if not rel.source_dataset_id or not rel.counts or rel.counts <= 0:
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Fetch all files for the source dataset (ACTIVE only)
|
|
||||||
files_res = await session.execute(
|
|
||||||
select(DatasetFiles).where(
|
|
||||||
DatasetFiles.dataset_id == rel.source_dataset_id,
|
|
||||||
DatasetFiles.status == "ACTIVE",
|
|
||||||
)
|
|
||||||
)
|
|
||||||
files = list(files_res.scalars().all())
|
|
||||||
|
|
||||||
# TAG mode: filter by tags according to relation.filter_conditions
|
|
||||||
if instance.ratio_method == "TAG":
|
|
||||||
required_tags = RatioTaskService._parse_required_tags(rel.filter_conditions)
|
|
||||||
if required_tags:
|
|
||||||
files = [f for f in files if RatioTaskService._file_contains_tags(f, required_tags)]
|
|
||||||
|
|
||||||
if not files:
|
|
||||||
continue
|
|
||||||
|
|
||||||
pick_n = min(rel.counts or 0, len(files))
|
|
||||||
chosen = random.sample(files, pick_n) if pick_n < len(files) else files
|
|
||||||
|
|
||||||
# Copy into target dataset with de-dup by target path
|
|
||||||
for f in chosen:
|
|
||||||
src_path = f.file_path
|
|
||||||
new_path = src_path
|
|
||||||
needs_copy = False
|
|
||||||
src_prefix = f"/dataset/{rel.source_dataset_id}"
|
|
||||||
if isinstance(src_path, str) and src_path.startswith(src_prefix):
|
|
||||||
dst_prefix = f"/dataset/{target_ds.id}"
|
|
||||||
new_path = src_path.replace(src_prefix, dst_prefix, 1)
|
|
||||||
needs_copy = True
|
|
||||||
|
|
||||||
# De-dup by target path
|
|
||||||
if new_path in existing_paths:
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Perform copy only when needed
|
|
||||||
if needs_copy:
|
|
||||||
dst_dir = os.path.dirname(new_path)
|
|
||||||
await asyncio.to_thread(os.makedirs, dst_dir, exist_ok=True)
|
|
||||||
await asyncio.to_thread(shutil.copy2, src_path, new_path)
|
|
||||||
|
|
||||||
new_file = DatasetFiles(
|
|
||||||
dataset_id=target_ds.id, # type: ignore
|
|
||||||
file_name=f.file_name,
|
|
||||||
file_path=new_path,
|
|
||||||
file_type=f.file_type,
|
|
||||||
file_size=f.file_size,
|
|
||||||
check_sum=f.check_sum,
|
|
||||||
tags=f.tags,
|
|
||||||
dataset_filemetadata=f.dataset_filemetadata,
|
|
||||||
status="ACTIVE",
|
|
||||||
)
|
|
||||||
session.add(new_file)
|
|
||||||
existing_paths.add(new_path)
|
|
||||||
added_count += 1
|
|
||||||
added_size += int(f.file_size or 0)
|
|
||||||
|
|
||||||
# Periodically flush to avoid huge transactions
|
|
||||||
await session.flush()
|
|
||||||
|
|
||||||
# Update target dataset statistics
|
# Update target dataset statistics
|
||||||
target_ds.file_count = (target_ds.file_count or 0) + added_count # type: ignore
|
target_ds.file_count = (target_ds.file_count or 0) + added_count # type: ignore
|
||||||
@@ -194,8 +121,8 @@ class RatioTaskService:
|
|||||||
target_ds.status = "ACTIVE"
|
target_ds.status = "ACTIVE"
|
||||||
|
|
||||||
# Done
|
# Done
|
||||||
instance.status = "SUCCESS"
|
instance.status = TaskStatus.COMPLETED.name
|
||||||
logger.info(f"Dataset ratio execution completed: instance={instance_id}, files={added_count}, size={added_size}")
|
logger.info(f"Dataset ratio execution completed: instance={instance_id}, files={added_count}, size={added_size}, {instance.status}")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.exception(f"Dataset ratio execution failed for {instance_id}: {e}")
|
logger.exception(f"Dataset ratio execution failed for {instance_id}: {e}")
|
||||||
@@ -204,42 +131,163 @@ class RatioTaskService:
|
|||||||
inst_res = await session.execute(select(RatioInstance).where(RatioInstance.id == instance_id))
|
inst_res = await session.execute(select(RatioInstance).where(RatioInstance.id == instance_id))
|
||||||
instance = inst_res.scalar_one_or_none()
|
instance = inst_res.scalar_one_or_none()
|
||||||
if instance:
|
if instance:
|
||||||
instance.status = "FAILED"
|
instance.status = TaskStatus.FAILED.name
|
||||||
finally:
|
finally:
|
||||||
pass
|
pass
|
||||||
finally:
|
finally:
|
||||||
await session.commit()
|
await session.commit()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def handle_ratio_relations(relations: list[RatioRelation], session, target_ds: Dataset) -> tuple[int, int]:
|
||||||
|
# Preload existing target file paths for deduplication
|
||||||
|
existing_path_rows = await session.execute(
|
||||||
|
select(DatasetFiles.file_path).where(DatasetFiles.dataset_id == target_ds.id)
|
||||||
|
)
|
||||||
|
existing_paths = set(p for p in existing_path_rows.scalars().all() if p)
|
||||||
|
|
||||||
|
added_count = 0
|
||||||
|
added_size = 0
|
||||||
|
|
||||||
|
for rel in relations:
|
||||||
|
if not rel.source_dataset_id or not rel.counts or rel.counts <= 0:
|
||||||
|
continue
|
||||||
|
|
||||||
|
files = await RatioTaskService.get_files(rel, session)
|
||||||
|
|
||||||
|
if not files:
|
||||||
|
continue
|
||||||
|
|
||||||
|
pick_n = min(rel.counts or 0, len(files))
|
||||||
|
chosen = random.sample(files, pick_n) if pick_n < len(files) else files
|
||||||
|
|
||||||
|
# Copy into target dataset with de-dup by target path
|
||||||
|
for f in chosen:
|
||||||
|
await RatioTaskService.handle_selected_file(existing_paths, f, session, target_ds)
|
||||||
|
added_count += 1
|
||||||
|
added_size += int(f.file_size or 0)
|
||||||
|
|
||||||
|
# Periodically flush to avoid huge transactions
|
||||||
|
await session.flush()
|
||||||
|
return added_count, added_size
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def handle_selected_file(existing_paths: set[Any], f, session, target_ds: Dataset):
|
||||||
|
src_path = f.file_path
|
||||||
|
dst_prefix = f"/dataset/{target_ds.id}"
|
||||||
|
file_name = RatioTaskService.get_new_file_name(dst_prefix, existing_paths, f)
|
||||||
|
|
||||||
|
new_path = dst_prefix + file_name
|
||||||
|
dst_dir = os.path.dirname(new_path)
|
||||||
|
await asyncio.to_thread(os.makedirs, dst_dir, exist_ok=True)
|
||||||
|
await asyncio.to_thread(shutil.copy2, src_path, new_path)
|
||||||
|
|
||||||
|
new_file = DatasetFiles(
|
||||||
|
dataset_id=target_ds.id, # type: ignore
|
||||||
|
file_name=file_name,
|
||||||
|
file_path=new_path,
|
||||||
|
file_type=f.file_type,
|
||||||
|
file_size=f.file_size,
|
||||||
|
check_sum=f.check_sum,
|
||||||
|
tags=f.tags,
|
||||||
|
dataset_filemetadata=f.dataset_filemetadata,
|
||||||
|
status="ACTIVE",
|
||||||
|
)
|
||||||
|
session.add(new_file)
|
||||||
|
existing_paths.add(new_path)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_new_file_name(dst_prefix: str, existing_paths: set[Any], f) -> str:
|
||||||
|
file_name = f.file_name
|
||||||
|
new_path = dst_prefix + file_name
|
||||||
|
|
||||||
|
# Handle file path conflicts by appending a number to the filename
|
||||||
|
if new_path in existing_paths:
|
||||||
|
file_name_base, file_ext = os.path.splitext(file_name)
|
||||||
|
counter = 1
|
||||||
|
original_file_name = file_name
|
||||||
|
while new_path in existing_paths:
|
||||||
|
file_name = f"{file_name_base}_{counter}{file_ext}"
|
||||||
|
new_path = f"{dst_prefix}{file_name}"
|
||||||
|
counter += 1
|
||||||
|
if counter > 1000: # Safety check to prevent infinite loops
|
||||||
|
logger.error(f"Could not find unique filename for {original_file_name} after 1000 attempts")
|
||||||
|
break
|
||||||
|
return file_name
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def get_files(rel: RatioRelation, session) -> list[Any]:
|
||||||
|
# Fetch all files for the source dataset (ACTIVE only)
|
||||||
|
files_res = await session.execute(
|
||||||
|
select(DatasetFiles).where(
|
||||||
|
DatasetFiles.dataset_id == rel.source_dataset_id,
|
||||||
|
DatasetFiles.status == "ACTIVE",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
files = list(files_res.scalars().all())
|
||||||
|
|
||||||
|
# TAG mode: filter by tags according to relation.filter_conditions
|
||||||
|
conditions = RatioTaskService._parse_conditions(rel.filter_conditions)
|
||||||
|
if conditions:
|
||||||
|
files = [f for f in files if RatioTaskService._filter_file(f, conditions)]
|
||||||
|
return files
|
||||||
|
|
||||||
# ------------------------- helpers for TAG filtering ------------------------- #
|
# ------------------------- helpers for TAG filtering ------------------------- #
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _parse_required_tags(conditions: Optional[str]) -> set[str]:
|
def _parse_conditions(conditions: Optional[str]) -> Optional[FilterCondition]:
|
||||||
"""Parse filter_conditions into a set of required tag strings.
|
"""Parse filter_conditions JSON string into a FilterCondition object.
|
||||||
|
|
||||||
Supports simple separators: comma, semicolon, space. Empty/None -> empty set.
|
Args:
|
||||||
|
conditions: JSON string containing filter conditions
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
FilterCondition object if conditions is not None/empty, otherwise None
|
||||||
"""
|
"""
|
||||||
if not conditions:
|
if not conditions:
|
||||||
return set()
|
return None
|
||||||
data = json.loads(conditions)
|
try:
|
||||||
required_tags = set()
|
data = json.loads(conditions)
|
||||||
if data.get("label"):
|
return FilterCondition(**data)
|
||||||
required_tags.add(data["label"])
|
except json.JSONDecodeError as e:
|
||||||
return required_tags
|
logger.error(f"Failed to parse filter conditions: {e}")
|
||||||
|
return None
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error creating FilterCondition: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _file_contains_tags(file: DatasetFiles, required: set[str]) -> bool:
|
def _filter_file(file: DatasetFiles, conditions: FilterCondition) -> bool:
|
||||||
if not required:
|
if not conditions:
|
||||||
return True
|
return True
|
||||||
tags = file.tags
|
logger.info(f"start filter file: {file}, conditions: {conditions}")
|
||||||
if not tags:
|
|
||||||
return False
|
# Check data range condition if provided
|
||||||
try:
|
if conditions.date_range:
|
||||||
# tags could be a list of strings or list of objects with 'name'
|
try:
|
||||||
tag_names = RatioTaskService.get_all_tags(tags)
|
from datetime import datetime, timedelta
|
||||||
return required.issubset(tag_names)
|
data_range_days = int(conditions.date_range)
|
||||||
except Exception as e:
|
if data_range_days > 0:
|
||||||
logger.exception(f"Failed to get tags for {file}", e)
|
cutoff_date = datetime.now() - timedelta(days=data_range_days)
|
||||||
return False
|
if file.tags_updated_at and file.tags_updated_at < cutoff_date:
|
||||||
|
return False
|
||||||
|
except (ValueError, TypeError) as e:
|
||||||
|
logger.warning(f"Invalid data_range value: {conditions.date_range}", e)
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Check label condition if provided
|
||||||
|
if conditions.label:
|
||||||
|
tags = file.tags
|
||||||
|
if not tags:
|
||||||
|
return False
|
||||||
|
try:
|
||||||
|
# tags could be a list of strings or list of objects with 'name'
|
||||||
|
tag_names = RatioTaskService.get_all_tags(tags)
|
||||||
|
return conditions.label in tag_names
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception(f"Failed to get tags for {file}", e)
|
||||||
|
return False
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_all_tags(tags) -> set[str]:
|
def get_all_tags(tags) -> set[str]:
|
||||||
|
|||||||
Reference in New Issue
Block a user