feature: 数据配比增加通过更新时间来配置 (#95)

* feature: 数据配比增加通过更新时间来配置

* fix: 修复配比时间参数传递的问题
This commit is contained in:
hefanli
2025-11-20 18:50:51 +08:00
committed by GitHub
parent 955ffff6cd
commit cddfe9b149
10 changed files with 458 additions and 595 deletions

View File

@@ -7,7 +7,6 @@ import { useNavigate } from "react-router";
import SelectDataset from "@/pages/RatioTask/Create/components/SelectDataset.tsx"; import SelectDataset from "@/pages/RatioTask/Create/components/SelectDataset.tsx";
import BasicInformation from "@/pages/RatioTask/Create/components/BasicInformation.tsx"; import BasicInformation from "@/pages/RatioTask/Create/components/BasicInformation.tsx";
import RatioConfig from "@/pages/RatioTask/Create/components/RatioConfig.tsx"; import RatioConfig from "@/pages/RatioTask/Create/components/RatioConfig.tsx";
import RatioTransfer from "./components/RatioTransfer";
export default function CreateRatioTask() { export default function CreateRatioTask() {
const navigate = useNavigate(); const navigate = useNavigate();
@@ -36,27 +35,12 @@ export default function CreateRatioTask() {
message.error("请配置配比项"); message.error("请配置配比项");
return; return;
} }
// Build request payload
const ratio_method =
ratioTaskForm.ratioType === "dataset" ? "DATASET" : "TAG";
const totals = String(values.totalTargetCount); const totals = String(values.totalTargetCount);
const config = ratioTaskForm.ratioConfigs.map((c) => { const config = ratioTaskForm.ratioConfigs.map((c) => {
if (ratio_method === "DATASET") {
return {
datasetId: String(c.source),
counts: String(c.quantity ?? 0),
filter_conditions: "",
};
}
// TAG mode: source key like `${datasetId}_${label}`
const source = String(c.source || "");
const idx = source.indexOf("_");
const datasetId = idx > 0 ? source.slice(0, idx) : source;
const label = idx > 0 ? source.slice(idx + 1) : "";
return { return {
datasetId, datasetId: c.id,
counts: String(c.quantity ?? 0), counts: String(c.quantity ?? 0),
filter_conditions: label ? JSON.stringify({ label }) : "", filterConditions: { label: c.labelFilter, dateRange: String(c.dateRange ?? 0)},
}; };
}); });
@@ -65,7 +49,6 @@ export default function CreateRatioTask() {
name: values.name, name: values.name,
description: values.description, description: values.description,
totals, totals,
ratio_method,
config, config,
}); });
message.success("配比任务创建成功"); message.success("配比任务创建成功");
@@ -108,13 +91,6 @@ export default function CreateRatioTask() {
totalTargetCount={ratioTaskForm.totalTargetCount} totalTargetCount={ratioTaskForm.totalTargetCount}
/> />
{/* <RatioTransfer
ratioTaskForm={ratioTaskForm}
distributions={distributions}
updateRatioConfig={updateRatioConfig}
updateLabelRatioConfig={updateLabelRatioConfig}
/> */}
<div className="flex h-full"> <div className="flex h-full">
<SelectDataset <SelectDataset
selectedDatasets={ratioTaskForm.selectedDatasets} selectedDatasets={ratioTaskForm.selectedDatasets}

View File

@@ -27,7 +27,7 @@ const BasicInformation: React.FC<BasicInformationProps> = ({
<Input type="number" placeholder="目标总数量" min={1} /> <Input type="number" placeholder="目标总数量" min={1} />
</Form.Item> </Form.Item>
<Form.Item label="任务描述" name="description" className="col-span-2"> <Form.Item label="任务描述" name="description" className="col-span-2">
<TextArea placeholder="描述配比任务的目的和要求(可选)" rows={2} /> <TextArea placeholder="描述配比任务的目的和要求" rows={2} />
</Form.Item> </Form.Item>
</div> </div>
); );

View File

@@ -1,7 +1,19 @@
import React, { useMemo, useState } from "react"; import React, { useMemo, useState } from "react";
import { Badge, Card, Input, Progress, Button, Divider } from "antd"; import { Badge, Card, Input, Progress, Button, DatePicker, Select } from "antd";
import { BarChart3 } from "lucide-react"; import { BarChart3, Filter, Clock } from "lucide-react";
import type { Dataset } from "@/pages/DataManagement/dataset.model.ts"; import type { Dataset } from "@/pages/DataManagement/dataset.model.ts";
import dayjs from 'dayjs';
const { RangePicker } = DatePicker;
const { Option } = Select;
// Preset "recent N days" choices offered by the update-time filter.
// Each entry pairs the display label with the day count used as its value.
const TIME_RANGE_OPTIONS = [1, 3, 7, 15, 30].map((days) => ({
  label: `最近${days}天`,
  value: days,
}));
interface RatioConfigItem { interface RatioConfigItem {
id: string; id: string;
@@ -10,6 +22,8 @@ interface RatioConfigItem {
quantity: number; quantity: number;
percentage: number; percentage: number;
source: string; source: string;
labelFilter?: string;
dateRange?: string;
} }
interface RatioConfigProps { interface RatioConfigProps {
@@ -22,14 +36,18 @@ interface RatioConfigProps {
} }
const RatioConfig: React.FC<RatioConfigProps> = ({ const RatioConfig: React.FC<RatioConfigProps> = ({
ratioType, ratioType,
selectedDatasets, selectedDatasets,
datasets, datasets,
totalTargetCount, totalTargetCount,
distributions, distributions,
onChange, onChange,
}) => { }) => {
const [ratioConfigs, setRatioConfigs] = useState<RatioConfigItem[]>([]); const [ratioConfigs, setRatioConfigs] = useState<RatioConfigItem[]>([]);
const [datasetFilters, setDatasetFilters] = useState<Record<string, {
labelFilter?: string;
dateRange?: string;
}>>({});
// 配比项总数 // 配比项总数
const totalConfigured = useMemo( const totalConfigured = useMemo(
@@ -37,6 +55,36 @@ const RatioConfig: React.FC<RatioConfigProps> = ({
[ratioConfigs] [ratioConfigs]
); );
// Return the label names known for a dataset: the keys of its entry in
// `distributions`, or an empty list when the dataset has no entry.
const getDatasetLabels = (datasetId: string): string[] =>
  Object.keys(distributions[String(datasetId)] || {});
/**
 * Distribute `totalTargetCount` evenly across the selected datasets.
 *
 * Every dataset receives `floor(total / count)` items and the first
 * `total % count` datasets receive one extra. The current per-dataset
 * filters from `datasetFilters` are copied onto each generated item.
 * The result replaces the local config state and is reported through
 * `onChange`. Does nothing when no dataset is selected.
 */
const generateAutoRatio = () => {
  const selectedCount = selectedDatasets.length;
  if (selectedCount === 0) return;

  const baseQuantity = Math.floor(totalTargetCount / selectedCount);
  const remainder = totalTargetCount % selectedCount;

  const newConfigs = selectedDatasets.map((datasetId, index) => {
    const dataset = datasets.find((d) => String(d.id) === datasetId);
    const quantity = baseQuantity + (index < remainder ? 1 : 0);
    return {
      id: datasetId,
      name: dataset?.name || datasetId,
      type: ratioType,
      quantity,
      // Guard the division: a total of 0 would otherwise yield NaN (0 / 0).
      percentage:
        totalTargetCount > 0
          ? Math.round((quantity / totalTargetCount) * 100)
          : 0,
      source: datasetId,
      labelFilter: datasetFilters[datasetId]?.labelFilter,
      dateRange: datasetFilters[datasetId]?.dateRange,
    };
  });

  setRatioConfigs(newConfigs);
  onChange?.(newConfigs);
};
// 更新数据集配比项 // 更新数据集配比项
const updateDatasetQuantity = (datasetId: string, quantity: number) => { const updateDatasetQuantity = (datasetId: string, quantity: number) => {
setRatioConfigs((prev) => { setRatioConfigs((prev) => {
@@ -55,6 +103,8 @@ const RatioConfig: React.FC<RatioConfigProps> = ({
quantity: Math.min(quantity, totalTargetCount - totalOtherQuantity), quantity: Math.min(quantity, totalTargetCount - totalOtherQuantity),
percentage: Math.round((quantity / totalTargetCount) * 100), percentage: Math.round((quantity / totalTargetCount) * 100),
source: datasetId, source: datasetId,
labelFilter: datasetFilters[datasetId]?.labelFilter,
dateRange: datasetFilters[datasetId]?.dateRange,
}; };
let newConfigs; let newConfigs;
@@ -69,78 +119,85 @@ const RatioConfig: React.FC<RatioConfigProps> = ({
}); });
}; };
// 自动平均分配 // 更新筛选条件
const generateAutoRatio = () => { const updateFilters = (datasetId: string, updates: {
const selectedCount = selectedDatasets.length; labelFilter?: string;
if (selectedCount === 0) return; dateRange?: [string, string];
const baseQuantity = Math.floor(totalTargetCount / selectedCount); }) => {
const remainder = totalTargetCount % selectedCount; setDatasetFilters(prev => ({
const newConfigs = selectedDatasets.map((datasetId, index) => { ...prev,
const dataset = datasets.find((d) => String(d.id) === datasetId); [datasetId]: {
const quantity = baseQuantity + (index < remainder ? 1 : 0); ...prev[datasetId],
return { ...updates,
id: datasetId,
name: dataset?.name || datasetId,
type: ratioType,
quantity,
percentage: Math.round((quantity / totalTargetCount) * 100),
source: datasetId,
};
});
setRatioConfigs(newConfigs);
onChange?.(newConfigs);
};
// 标签模式下,更新某数据集的某个标签的数量
const updateLabelQuantity = (
datasetId: string,
label: string,
quantity: number
) => {
const sourceKey = `${datasetId}_${label}`;
setRatioConfigs((prev) => {
const existingIndex = prev.findIndex((c) => c.source === sourceKey);
const totalOtherQuantity = prev
.filter((c) => c.source !== sourceKey)
.reduce((sum, c) => sum + c.quantity, 0);
const dist = distributions[datasetId] || {};
const labelMax = dist[label] ?? Infinity;
const cappedQuantity = Math.max(
0,
Math.min(quantity, totalTargetCount - totalOtherQuantity, labelMax)
);
const newConfig: RatioConfigItem = {
id: sourceKey,
name: label,
type: "label",
quantity: cappedQuantity,
percentage: Math.round((cappedQuantity / totalTargetCount) * 100),
source: sourceKey,
};
let newConfigs;
if (existingIndex >= 0) {
newConfigs = [...prev];
newConfigs[existingIndex] = newConfig;
} else {
newConfigs = [...prev, newConfig];
} }
onChange?.(newConfigs); }));
return newConfigs;
});
}; };
// 选中数据集变化时,移除未选中的配比项 // 渲染筛选器
/**
 * Render the filter panel for one dataset: a label selector fed by
 * `getDatasetLabels` and an update-time selector fed by
 * `TIME_RANGE_OPTIONS`. Both write into `datasetFilters` through
 * `updateFilters`; clearing a selector resets its field to `undefined`.
 *
 * NOTE(review): the time options carry numeric values while the filter
 * state declares `dateRange` as a string — confirm the intended type.
 * NOTE(review): the heading and field-caption elements are empty in this
 * view — confirm whether caption text is expected there.
 */
const renderFilters = (datasetId: string) => {
  const labels = getDatasetLabels(datasetId);
  const filters = datasetFilters[datasetId] || {};

  return (
    <div className="mb-3 p-2 bg-gray-50 rounded">
      <div className="flex items-center gap-2 mb-2">
        <Filter size={14} className="text-gray-400" />
        <span className="text-xs font-medium"></span>
      </div>
      <div className="grid grid-cols-1 md:grid-cols-2 gap-3">
        <div>
          <div className="text-xs text-gray-500 mb-1"></div>
          <Select
            style={{ width: '100%' }}
            placeholder="选择标签"
            value={filters.labelFilter}
            onChange={(value) => updateFilters(datasetId, { labelFilter: value })}
            allowClear
            onClear={() => updateFilters(datasetId, { labelFilter: undefined })}
          >
            {labels.map(label => (
              <Option key={label} value={label}>{label}</Option>
            ))}
          </Select>
        </div>
        <div>
          <div className="text-xs text-gray-500 mb-1"></div>
          <Select
            style={{ width: '100%' }}
            placeholder="选择标签更新时间"
            value={filters.dateRange}
            onChange={(dates) => updateFilters(datasetId, { dateRange: dates })}
            allowClear
            onClear={() => updateFilters(datasetId, { dateRange: undefined })}
          >
            {TIME_RANGE_OPTIONS.map(option => (
              <Option key={option.value} value={option.value}>
                {option.label}
              </Option>
            ))}
          </Select>
        </div>
      </div>
    </div>
  );
};
// 选中数据集变化时,初始化筛选条件
React.useEffect(() => { React.useEffect(() => {
setRatioConfigs((prev) => { const initialFilters: Record<string, any> = {};
const next = prev.filter((c) => { selectedDatasets.forEach(datasetId => {
const id = String(c.source); const config = ratioConfigs.find(c => c.source === datasetId);
const dsId = id.includes("_") ? id.split("_")[0] : id; if (config) {
return selectedDatasets.includes(dsId); initialFilters[datasetId] = {
}); labelFilter: config.labelFilter,
if (next !== prev) onChange?.(next); dateRange: config.dateRange,
return next; };
}
}); });
// eslint-disable-next-line setDatasetFilters(prev => ({ ...prev, ...initialFilters }));
}, [selectedDatasets]); }, [selectedDatasets]);
return ( return (
@@ -148,7 +205,7 @@ const RatioConfig: React.FC<RatioConfigProps> = ({
<div className="flex items-center justify-between p-4 border-bottom"> <div className="flex items-center justify-between p-4 border-bottom">
<span className="text-sm font-bold"> <span className="text-sm font-bold">
<span className="text-xs text-gray-500"> <span className="text-xs text-gray-500 ml-1">
(:{totalConfigured}/{totalTargetCount}) (:{totalConfigured}/{totalTargetCount})
</span> </span>
</span> </span>
@@ -170,41 +227,36 @@ const RatioConfig: React.FC<RatioConfigProps> = ({
<div className="flex-overflow-auto gap-4 p-4"> <div className="flex-overflow-auto gap-4 p-4">
{/* 配比预览 */} {/* 配比预览 */}
{ratioConfigs.length > 0 && ( {ratioConfigs.length > 0 && (
<div> <div className="p-3 bg-gray-50 rounded-lg mb-4">
<div className="p-3 bg-gray-50 rounded-lg"> <div className="grid grid-cols-2 gap-4 text-sm">
<div className="grid grid-cols-2 gap-4 text-sm"> <div>
<div> <span className="text-gray-500">:</span>
<span className="text-gray-500">:</span> <span className="ml-2 font-medium">
<span className="ml-2 font-medium"> {ratioConfigs
{ratioConfigs .reduce((sum, config) => sum + config.quantity, 0)
.reduce((sum, config) => sum + config.quantity, 0) .toLocaleString()}
.toLocaleString()} </span>
</span> </div>
</div> <div>
<div> <span className="text-gray-500">:</span>
<span className="text-gray-500">:</span> <span className="ml-2 font-medium">
<span className="ml-2 font-medium"> {totalTargetCount.toLocaleString()}
{totalTargetCount.toLocaleString()} </span>
</span>
</div>
<div>
<span className="text-gray-500">:</span>
<span className="ml-2 font-medium">
{ratioConfigs.length}
</span>
</div>
</div> </div>
</div> </div>
</div> </div>
)} )}
<div className="flex-1 overflow-auto">
<div className="flex-1 overflow-auto space-y-4">
{selectedDatasets.map((datasetId) => { {selectedDatasets.map((datasetId) => {
const dataset = datasets.find((d) => String(d.id) === datasetId); const dataset = datasets.find((d) => String(d.id) === datasetId);
const config = ratioConfigs.find((c) => c.source === datasetId); const config = ratioConfigs.find((c) => c.source === datasetId);
const currentQuantity = config?.quantity || 0; const currentQuantity = config?.quantity || 0;
if (!dataset) return null; if (!dataset) return null;
return ( return (
<Card key={datasetId} size="small" className="mb-2"> <Card key={datasetId} size="small" className="mb-4">
<div className="flex items-center justify-between mb-3"> <div className="flex items-center justify-between mb-3">
<div className="flex items-center gap-2"> <div className="flex items-center gap-2">
<span className="font-medium text-sm"> <span className="font-medium text-sm">
@@ -216,97 +268,36 @@ const RatioConfig: React.FC<RatioConfigProps> = ({
{config?.percentage || 0}% {config?.percentage || 0}%
</div> </div>
</div> </div>
{ratioType === "dataset" ? (
<div> {/* 筛选条件 */}
<div className="flex items-center gap-2 mb-2"> {renderFilters(datasetId)}
<span className="text-xs">:</span>
<Input <div className="flex items-center gap-2 mb-2">
type="number" <span className="text-xs">:</span>
value={currentQuantity} <Input
onChange={(e) => type="number"
updateDatasetQuantity( value={currentQuantity}
datasetId, onChange={(e) =>
Number(e.target.value) updateDatasetQuantity(
) datasetId,
} Number(e.target.value)
style={{ width: 80 }} )
min={0} }
max={Math.min( style={{ width: 100 }}
dataset.fileCount || 0, min={0}
totalTargetCount max={Math.min(
)} dataset.fileCount || 0,
/> totalTargetCount
<span className="text-xs text-gray-500"></span>
</div>
<Progress
percent={Math.round(
(currentQuantity / totalTargetCount) * 100
)}
size="small"
/>
</div>
) : (
<div>
{!distributions[String(dataset.id)] ? (
<div className="text-xs text-gray-400">
...
</div>
) : Object.entries(distributions[String(dataset.id)])
.length === 0 ? (
<div className="text-xs text-gray-400">
</div>
) : (
<div className="flex flex-col gap-2">
{Object.entries(
distributions[String(dataset.id)]
).map(([label, count]) => {
const sourceKey = `${datasetId}_${label}`;
const labelConfig = ratioConfigs.find(
(c) => c.source === sourceKey
);
const labelQuantity = labelConfig?.quantity || 0;
return (
<div
key={label}
className="flex items-center justify-between gap-2"
>
<div className="flex items-center gap-2">
<Badge color="gray">{label}</Badge>
<span className="text-xs text-gray-500">
{count}
</span>
</div>
<div className="flex items-center gap-2">
<span className="text-xs">:</span>
<Input
type="number"
value={labelQuantity}
onChange={(e) =>
updateLabelQuantity(
datasetId,
label,
Number(e.target.value)
)
}
style={{ width: 80 }}
min={0}
max={Math.min(
Number(count) || 0,
totalTargetCount
)}
/>
<span className="text-xs text-gray-500">
</span>
</div>
</div>
);
})}
</div>
)} )}
</div> />
)} <span className="text-xs text-gray-500"></span>
</div>
<Progress
percent={Math.round(
(currentQuantity / totalTargetCount) * 100
)}
size="small"
/>
</Card> </Card>
); );
})} })}

View File

@@ -1,169 +0,0 @@
import React, { useMemo } from "react";
import { Table } from "antd";
import { TransferItem } from "antd/es/transfer";
import RatioConfig from "./RatioConfig";
import useFetchData from "@/hooks/useFetchData";
import { queryDatasetsUsingGet } from "@/pages/DataManagement/dataset.api";
import {
datasetTypeMap,
mapDataset,
} from "@/pages/DataManagement/dataset.const";
import { SearchControls } from "@/components/SearchControls";
// Column definitions for the dataset selection table: name, type and size.
const leftColumns = [
  {
    dataIndex: "name",
    title: "名称",
    ellipsis: true,
  },
  {
    dataIndex: "datasetType",
    title: "类型",
    ellipsis: true,
    width: 100,
    // Fall back to the raw type string when the type has no entry in
    // datasetTypeMap, instead of throwing on `undefined.label`.
    render: (type: string) => datasetTypeMap[type]?.label ?? type,
  },
  {
    dataIndex: "size",
    title: "大小",
    width: 100,
    ellipsis: true,
  },
];
/**
 * Two-pane ratio editor: the left pane is a searchable, selectable table of
 * datasets, the right pane renders RatioConfig for the current selection.
 *
 * Props:
 * - distributions: per-dataset label counts, forwarded to RatioConfig.
 * - ratioTaskForm: form state read by generateAutoRatio and forwarded to
 *   RatioConfig.
 * - updateRatioConfig / updateLabelRatioConfig: callbacks forwarded to
 *   RatioConfig as onUpdateDatasetQuantity / onUpdateLabelQuantity.
 */
export default function RatioTransfer(props: {
distributions: Record<string, Record<string, number>>;
ratioTaskForm: any;
updateRatioConfig: (datasetId: string, quantity: number) => void;
updateLabelRatioConfig: (
datasetId: string,
label: string,
quantity: number
) => void;
}) {
const {
updateLabelRatioConfig,
updateRatioConfig,
ratioTaskForm,
distributions,
} = props;
// Dataset list, loading flag, pagination and search state all come from
// the useFetchData hook, called with the dataset query and mapDataset.
const {
tableData: datasets,
loading,
pagination,
searchParams,
setSearchParams,
handleFiltersChange,
} = useFetchData(queryDatasetsUsingGet, mapDataset);
// Rows currently selected in the left-hand table.
const [selectedDatasets, setSelectedDatasets] = React.useState<
TransferItem[]
>([]);
// Keys of the selected rows, passed to the table's rowSelection.
const selectedRowKeys = useMemo(() => {
return selectedDatasets.map((item) => item.key);
}, [selectedDatasets]);
// NOTE(review): setListDisabled is never called in this component, so
// listDisabled stays false here.
const [listDisabled, setListDisabled] = React.useState(false);
// Evenly splits ratioTaskForm.totalTargetCount across the selected dataset
// ids; the first `remainder` ids receive one extra item.
// NOTE(review): this function is not referenced anywhere in this component.
const generateAutoRatio = () => {
const selectedCount = ratioTaskForm.selectedDatasets.length;
if (selectedCount === 0) return;
const baseQuantity = Math.floor(
ratioTaskForm.totalTargetCount / selectedCount
);
const remainder = ratioTaskForm.totalTargetCount % selectedCount;
const newConfigs = ratioTaskForm.selectedDatasets.map(
(datasetId, index) => {
const quantity = baseQuantity + (index < remainder ? 1 : 0);
return {
id: datasetId,
name: datasetId,
type: ratioTaskForm.ratioType,
quantity,
percentage: Math.round(
(quantity / ratioTaskForm.totalTargetCount) * 100
),
source: datasetId,
};
}
);
// NOTE(review): setRatioTaskForm is neither a prop nor declared in this
// component — calling generateAutoRatio would presumably throw; confirm.
setRatioTaskForm((prev) => ({ ...prev, ratioConfigs: newConfigs }));
};
return (
<div className="flex">
<div className="border-card flex-1 mr-4">
<h3 className="p-2 border-bottom">{`${selectedDatasets.length} / ${datasets.length}`}</h3>
<SearchControls
searchTerm={searchParams.keyword}
onSearchChange={(keyword) =>
setSearchParams({ ...searchParams, keyword })
}
searchPlaceholder="搜索数据集名称..."
filters={[
{
key: "type",
label: "数据集类型",
options: [
{ value: "dataset", label: "按数据集" },
{ value: "tag", label: "按标签" },
],
},
]}
onFiltersChange={handleFiltersChange}
onClearFilters={() =>
setSearchParams({ ...searchParams, filter: {} })
}
showViewToggle={false}
showReload={false}
className="m-4"
/>
<Table
rowSelection={{
onChange: (_, selectedRows) => {
setSelectedDatasets(selectedRows);
},
selectedRowKeys,
selections: [
Table.SELECTION_ALL,
Table.SELECTION_INVERT,
Table.SELECTION_NONE,
],
}}
columns={leftColumns}
dataSource={datasets}
loading={loading}
pagination={pagination}
size="small"
rowKey="id"
style={{ pointerEvents: listDisabled ? "none" : undefined }}
onRow={(record) => ({
onClick: () => {
// Ignore clicks on disabled rows or while the list is disabled.
if (record.disabled || listDisabled) {
return;
}
// Toggle the clicked row in the selection.
// NOTE(review): this stores record.key values, whereas the checkbox
// handler above stores whole row objects, and the table's rowKey is
// "id" — confirm the two selection paths agree.
setSelectedDatasets((prev) => {
if (prev.includes(record.key)) {
return prev.filter((k) => k !== record.key);
}
return [...prev, record.key];
});
},
})}
/>
</div>
<div className="border-card flex-1">
<RatioConfig
datasets={selectedDatasets}
ratioTaskForm={ratioTaskForm}
distributions={distributions}
onUpdateDatasetQuantity={updateRatioConfig}
onUpdateLabelQuantity={updateLabelRatioConfig}
/>
</div>
</div>
);
}

View File

@@ -1,5 +1,5 @@
import React, { useEffect, useState } from "react"; import React, { useEffect, useState } from "react";
import { Badge, Button, Card, Checkbox, Input, Pagination, Select } from "antd"; import { Badge, Button, Card, Checkbox, Input, Pagination } from "antd";
import { Search as SearchIcon } from "lucide-react"; import { Search as SearchIcon } from "lucide-react";
import type { Dataset } from "@/pages/DataManagement/dataset.model.ts"; import type { Dataset } from "@/pages/DataManagement/dataset.model.ts";
import { import {
@@ -10,8 +10,6 @@ import {
interface SelectDatasetProps { interface SelectDatasetProps {
selectedDatasets: string[]; selectedDatasets: string[];
ratioType: "dataset" | "label";
onRatioTypeChange: (val: "dataset" | "label") => void;
onSelectedDatasetsChange: (next: string[]) => void; onSelectedDatasetsChange: (next: string[]) => void;
onDistributionsChange?: ( onDistributionsChange?: (
next: Record<string, Record<string, number>> next: Record<string, Record<string, number>>
@@ -21,8 +19,6 @@ interface SelectDatasetProps {
const SelectDataset: React.FC<SelectDatasetProps> = ({ const SelectDataset: React.FC<SelectDatasetProps> = ({
selectedDatasets, selectedDatasets,
ratioType,
onRatioTypeChange,
onSelectedDatasetsChange, onSelectedDatasetsChange,
onDistributionsChange, onDistributionsChange,
onDatasetsChange, onDatasetsChange,
@@ -62,7 +58,7 @@ const SelectDataset: React.FC<SelectDatasetProps> = ({
// Fetch label distributions when in label mode // Fetch label distributions when in label mode
useEffect(() => { useEffect(() => {
const fetchDistributions = async () => { const fetchDistributions = async () => {
if (ratioType !== "label" || !datasets?.length) return; if (!datasets?.length) return;
const idsToFetch = datasets const idsToFetch = datasets
.map((d) => String(d.id)) .map((d) => String(d.id))
.filter((id) => !distributions[id]); .filter((id) => !distributions[id]);
@@ -147,7 +143,7 @@ const SelectDataset: React.FC<SelectDatasetProps> = ({
}; };
fetchDistributions(); fetchDistributions();
// eslint-disable-next-line react-hooks/exhaustive-deps // eslint-disable-next-line react-hooks/exhaustive-deps
}, [ratioType, datasets]); }, [datasets]);
const onToggleDataset = (datasetId: string, checked: boolean) => { const onToggleDataset = (datasetId: string, checked: boolean) => {
if (checked) { if (checked) {
@@ -180,18 +176,6 @@ const SelectDataset: React.FC<SelectDatasetProps> = ({
</Button> </Button>
</div> </div>
<div className="flex-overflow-auto gap-4 p-4"> <div className="flex-overflow-auto gap-4 p-4">
<div className="flex items-center gap-4">
<span className="text-sm">:</span>
<Select
className="flex-1 min-w-[120px]"
value={ratioType}
onChange={(v) => onRatioTypeChange(v)}
options={[
{ label: "按数据集", value: "dataset" },
{ label: "按标签", value: "label" },
]}
/>
</div>
<Input <Input
prefix={<SearchIcon className="text-gray-400" />} prefix={<SearchIcon className="text-gray-400" />}
placeholder="搜索数据集" placeholder="搜索数据集"
@@ -239,32 +223,30 @@ const SelectDataset: React.FC<SelectDatasetProps> = ({
<span>{dataset.fileCount}</span> <span>{dataset.fileCount}</span>
<span>{dataset.size}</span> <span>{dataset.size}</span>
</div> </div>
{ratioType === "label" && ( <div className="mt-2">
<div className="mt-2"> {distributions[idStr] ? (
{distributions[idStr] ? ( Object.entries(distributions[idStr]).length > 0 ? (
Object.entries(distributions[idStr]).length > 0 ? ( <div className="flex flex-wrap gap-2 text-xs">
<div className="flex flex-wrap gap-2 text-xs"> {Object.entries(distributions[idStr])
{Object.entries(distributions[idStr]) .slice(0, 8)
.slice(0, 8) .map(([tag, count]) => (
.map(([tag, count]) => ( <Badge
<Badge key={tag}
key={tag} color="gray"
color="gray" >{`${tag}: ${count}`}</Badge>
>{`${tag}: ${count}`}</Badge> ))}
))} </div>
</div>
) : (
<div className="text-xs text-gray-400">
</div>
)
) : ( ) : (
<div className="text-xs text-gray-400"> <div className="text-xs text-gray-400">
...
</div> </div>
)} )
</div> ) : (
)} <div className="text-xs text-gray-400">
...
</div>
)}
</div>
</div> </div>
</div> </div>
</Card> </Card>

View File

@@ -1,11 +1,13 @@
from .common import ( from .common import (
BaseResponseModel, BaseResponseModel,
StandardResponse, StandardResponse,
PaginatedData PaginatedData,
TaskStatus
) )
__all__ = [ __all__ = [
"BaseResponseModel", "BaseResponseModel",
"StandardResponse", "StandardResponse",
"PaginatedData" "PaginatedData",
"TaskStatus"
] ]

View File

@@ -1,8 +1,9 @@
""" """
通用响应模型 通用响应模型
""" """
from typing import Generic, TypeVar, Optional, List, Type from typing import Generic, TypeVar, List
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from enum import Enum
# 定义泛型类型变量 # 定义泛型类型变量
T = TypeVar('T') T = TypeVar('T')
@@ -42,3 +43,9 @@ class PaginatedData(BaseResponseModel, Generic[T]):
total_elements: int = Field(..., description="总条数") total_elements: int = Field(..., description="总条数")
total_pages: int = Field(..., description="总页数") total_pages: int = Field(..., description="总页数")
content: List[T] = Field(..., description="当前页数据") content: List[T] = Field(..., description="当前页数据")
# Task lifecycle states.
# The `str` mixin makes each member compare equal to its raw string value
# (TaskStatus.PENDING == "PENDING"), so code that handles statuses as plain
# strings keeps working alongside code that uses the enum.
class TaskStatus(str, Enum):
    PENDING = "PENDING"
    RUNNING = "RUNNING"
    COMPLETED = "COMPLETED"
    FAILED = "FAILED"

View File

@@ -12,7 +12,7 @@ from app.core.logging import get_logger
from app.db.models import Dataset from app.db.models import Dataset
from app.db.session import get_db from app.db.session import get_db
from app.module.dataset import DatasetManagementService from app.module.dataset import DatasetManagementService
from app.module.shared.schema import StandardResponse from app.module.shared.schema import StandardResponse, TaskStatus
from app.module.synthesis.schema.ratio_task import ( from app.module.synthesis.schema.ratio_task import (
CreateRatioTaskResponse, CreateRatioTaskResponse,
CreateRatioTaskRequest, CreateRatioTaskRequest,
@@ -49,52 +49,18 @@ async def create_ratio_task(
await valid_exists(db, req) await valid_exists(db, req)
# 创建目标数据集:名称使用“<任务名称>-配比生成-时间戳” target_dataset = await create_target_dataset(db, req, source_types)
target_dataset_name = f"{req.name}-配比生成-{datetime.now().strftime('%Y%m%d%H%M%S')}"
target_type = get_target_dataset_type(source_types) instance = await create_ratio_instance(db, req, target_dataset)
target_dataset = Dataset(
id=str(uuid.uuid4()),
name=target_dataset_name,
description=req.description or "",
dataset_type=target_type,
status="DRAFT",
)
target_dataset.path = f"/dataset/{target_dataset.id}"
db.add(target_dataset)
await db.flush() # 获取 target_dataset.id
service = RatioTaskService(db)
instance = await service.create_task(
name=req.name,
description=req.description,
totals=int(req.totals),
ratio_method=req.ratio_method,
config=[
{
"dataset_id": item.dataset_id,
"counts": int(item.counts),
"filter_conditions": item.filter_conditions,
}
for item in req.config
],
target_dataset_id=target_dataset.id,
)
# 异步执行配比任务(支持 DATASET / TAG)
asyncio.create_task(RatioTaskService.execute_dataset_ratio_task(instance.id)) asyncio.create_task(RatioTaskService.execute_dataset_ratio_task(instance.id))
return StandardResponse( response_data = CreateRatioTaskResponse(
code=200,
message="success",
data=CreateRatioTaskResponse(
id=instance.id, id=instance.id,
name=instance.name, name=instance.name,
description=instance.description, description=instance.description,
totals=instance.totals or 0, totals=instance.totals or 0,
ratio_method=instance.ratio_method or req.ratio_method, status=instance.status or TaskStatus.PENDING.name,
status=instance.status or "PENDING",
config=req.config, config=req.config,
targetDataset=TargetDatasetInfo( targetDataset=TargetDatasetInfo(
id=str(target_dataset.id), id=str(target_dataset.id),
@@ -103,6 +69,10 @@ async def create_ratio_task(
status=str(target_dataset.status), status=str(target_dataset.status),
) )
) )
return StandardResponse(
code=200,
message="success",
data=response_data
) )
except HTTPException: except HTTPException:
await db.rollback() await db.rollback()
@@ -113,6 +83,46 @@ async def create_ratio_task(
raise HTTPException(status_code=500, detail="Internal server error") raise HTTPException(status_code=500, detail="Internal server error")
async def create_ratio_instance(db, req: CreateRatioTaskRequest, target_dataset: Dataset) -> RatioInstance:
    """Create the ratio task instance for a request.

    Delegates to ``RatioTaskService.create_task``, passing the request's
    name, description and integer total, one dict per config entry
    (dataset id, integer count, filter conditions), and the id of the
    target dataset. Returns whatever ``create_task`` returns.
    """
    service = RatioTaskService(db)
    logger.info(f"create_ratio_instance: {req}")
    task_kwargs = {
        "name": req.name,
        "description": req.description,
        "totals": int(req.totals),
        "config": [
            {
                "dataset_id": entry.dataset_id,
                "counts": int(entry.counts),
                "filter_conditions": entry.filter_conditions,
            }
            for entry in req.config
        ],
        "target_dataset_id": target_dataset.id,
    }
    return await service.create_task(**task_kwargs)
async def create_target_dataset(db, req: CreateRatioTaskRequest, source_types: set[str]) -> Dataset:
    """Build the target dataset for a ratio task, add it to the session and flush.

    The dataset is named "<task name>-<timestamp>", its type is derived from
    the source dataset types via ``get_target_dataset_type``, it starts in
    the "DRAFT" status, and its path is "/dataset/<new uuid>".
    """
    timestamp = datetime.now().strftime('%Y%m%d%H%M%S')
    dataset_type = get_target_dataset_type(source_types)
    new_id = str(uuid.uuid4())

    dataset = Dataset(
        id=new_id,
        name=f"{req.name}-{timestamp}",
        description=req.description or "",
        dataset_type=dataset_type,
        status="DRAFT",
        path=f"/dataset/{new_id}",
    )
    db.add(dataset)
    # Flush the pending insert before handing the dataset back to the caller.
    await db.flush()
    return dataset
@router.get("", response_model=StandardResponse[PagedRatioTaskResponse], status_code=200) @router.get("", response_model=StandardResponse[PagedRatioTaskResponse], status_code=200)
async def list_ratio_tasks( async def list_ratio_tasks(
page: int = 1, page: int = 1,

View File

@@ -2,10 +2,36 @@ from typing import List, Optional, Dict, Any
from datetime import datetime from datetime import datetime
from pydantic import BaseModel, Field, field_validator from pydantic import BaseModel, Field, field_validator
from app.core.logging import get_logger
from app.module.shared.schema.common import TaskStatus
logger = get_logger(__name__)
# Per-dataset filter sent with a ratio config item: an optional label and an
# optional date range. The date range is accepted under the alias "dateRange".
class FilterCondition(BaseModel):
    date_range: Optional[str] = Field(None, description="数据范围", alias="dateRange")
    label: Optional[str] = Field(None, description="标签")

    @field_validator("date_range")
    @classmethod
    def validate_date_range(cls, v: Optional[str]) -> Optional[str]:
        # A missing or empty value passes through untouched; anything else
        # must be parseable by int().
        if v:
            try:
                int(v)
            except (ValueError, TypeError):
                raise ValueError("date_range must be a numeric string")
        return v

    class Config:
        # Let the model be populated by field name (not only by alias) when
        # it is constructed programmatically.
        validate_by_name = True
class RatioConfigItem(BaseModel): class RatioConfigItem(BaseModel):
dataset_id: str = Field(..., alias="datasetId", description="数据集id") dataset_id: str = Field(..., alias="datasetId", description="数据集id")
counts: str = Field(..., description="数量") counts: str = Field(..., description="数量")
filter_conditions: str = Field(..., description="过滤条件") filter_conditions: FilterCondition = Field(..., alias="filterConditions", description="过滤条件")
@field_validator("counts") @field_validator("counts")
@classmethod @classmethod
@@ -22,17 +48,8 @@ class CreateRatioTaskRequest(BaseModel):
name: str = Field(..., description="名称") name: str = Field(..., description="名称")
description: Optional[str] = Field(None, description="描述") description: Optional[str] = Field(None, description="描述")
totals: str = Field(..., description="目标数量") totals: str = Field(..., description="目标数量")
ratio_method: str = Field(..., description="配比方式", alias="ratio_method")
config: List[RatioConfigItem] = Field(..., description="配比设置列表") config: List[RatioConfigItem] = Field(..., description="配比设置列表")
@field_validator("ratio_method")
@classmethod
def validate_ratio_method(cls, v: str) -> str:
allowed = {"TAG", "DATASET"}
if v not in allowed:
raise ValueError(f"ratio_method must be one of {allowed}")
return v
@field_validator("totals") @field_validator("totals")
@classmethod @classmethod
def validate_totals(cls, v: str) -> str: def validate_totals(cls, v: str) -> str:
@@ -58,8 +75,7 @@ class CreateRatioTaskResponse(BaseModel):
name: str name: str
description: Optional[str] = None description: Optional[str] = None
totals: int totals: int
ratio_method: str status: TaskStatus
status: str
# echoed config # echoed config
config: List[RatioConfigItem] config: List[RatioConfigItem]
# created dataset # created dataset

View File

@@ -14,6 +14,8 @@ from app.db.models.ratio_task import RatioInstance, RatioRelation
from app.db.models import Dataset, DatasetFiles from app.db.models import Dataset, DatasetFiles
from app.db.session import AsyncSessionLocal from app.db.session import AsyncSessionLocal
from app.module.dataset.schema.dataset_file import DatasetFileTag from app.module.dataset.schema.dataset_file import DatasetFileTag
from app.module.shared.schema import TaskStatus
from app.module.synthesis.schema.ratio_task import FilterCondition
logger = get_logger(__name__) logger = get_logger(__name__)
@@ -30,7 +32,6 @@ class RatioTaskService:
name: str, name: str,
description: Optional[str], description: Optional[str],
totals: int, totals: int,
ratio_method: str,
config: List[Dict[str, Any]], config: List[Dict[str, Any]],
target_dataset_id: Optional[str] = None, target_dataset_id: Optional[str] = None,
) -> RatioInstance: ) -> RatioInstance:
@@ -38,12 +39,11 @@ class RatioTaskService:
config item format: {"dataset_id": str, "counts": int, "filter_conditions": str} config item format: {"dataset_id": str, "counts": int, "filter_conditions": str}
""" """
logger.info(f"Creating ratio task: name={name}, method={ratio_method}, totals={totals}, items={len(config or [])}") logger.info(f"Creating ratio task: name={name}, totals={totals}, items={len(config or [])}")
instance = RatioInstance( instance = RatioInstance(
name=name, name=name,
description=description, description=description,
ratio_method=ratio_method,
totals=totals, totals=totals,
target_dataset_id=target_dataset_id, target_dataset_id=target_dataset_id,
status="PENDING", status="PENDING",
@@ -56,8 +56,12 @@ class RatioTaskService:
ratio_instance_id=instance.id, ratio_instance_id=instance.id,
source_dataset_id=item.get("dataset_id"), source_dataset_id=item.get("dataset_id"),
counts=int(item.get("counts", 0)), counts=int(item.get("counts", 0)),
filter_conditions=item.get("filter_conditions"), filter_conditions=json.dumps({
'date_range': item.get("filter_conditions").date_range,
'label': item.get("filter_conditions").label,
})
) )
logger.info(f"Relation created: {relation.id}, {relation}, {item}, {config}")
self.db.add(relation) self.db.add(relation)
await self.db.commit() await self.db.commit()
@@ -97,94 +101,17 @@ class RatioTaskService:
relations: List[RatioRelation] = list(rel_res.scalars().all()) relations: List[RatioRelation] = list(rel_res.scalars().all())
# Mark running # Mark running
instance.status = "RUNNING" instance.status = TaskStatus.RUNNING.name
if instance.ratio_method not in {"DATASET", "TAG"}:
logger.info(f"Instance {instance_id} ratio_method={instance.ratio_method} not supported yet")
instance.status = "SUCCESS"
return
# Load target dataset # Load target dataset
ds_res = await session.execute(select(Dataset).where(Dataset.id == instance.target_dataset_id)) ds_res = await session.execute(select(Dataset).where(Dataset.id == instance.target_dataset_id))
target_ds: Optional[Dataset] = ds_res.scalar_one_or_none() target_ds: Optional[Dataset] = ds_res.scalar_one_or_none()
if not target_ds: if not target_ds:
logger.error(f"Target dataset not found for instance {instance_id}") logger.error(f"Target dataset not found for instance {instance_id}")
instance.status = "FAILED" instance.status = TaskStatus.FAILED.name
return return
# Preload existing target file paths for deduplication added_count, added_size = await RatioTaskService.handle_ratio_relations(relations,session, target_ds)
existing_path_rows = await session.execute(
select(DatasetFiles.file_path).where(DatasetFiles.dataset_id == target_ds.id)
)
existing_paths = set(p for p in existing_path_rows.scalars().all() if p)
added_count = 0
added_size = 0
for rel in relations:
if not rel.source_dataset_id or not rel.counts or rel.counts <= 0:
continue
# Fetch all files for the source dataset (ACTIVE only)
files_res = await session.execute(
select(DatasetFiles).where(
DatasetFiles.dataset_id == rel.source_dataset_id,
DatasetFiles.status == "ACTIVE",
)
)
files = list(files_res.scalars().all())
# TAG mode: filter by tags according to relation.filter_conditions
if instance.ratio_method == "TAG":
required_tags = RatioTaskService._parse_required_tags(rel.filter_conditions)
if required_tags:
files = [f for f in files if RatioTaskService._file_contains_tags(f, required_tags)]
if not files:
continue
pick_n = min(rel.counts or 0, len(files))
chosen = random.sample(files, pick_n) if pick_n < len(files) else files
# Copy into target dataset with de-dup by target path
for f in chosen:
src_path = f.file_path
new_path = src_path
needs_copy = False
src_prefix = f"/dataset/{rel.source_dataset_id}"
if isinstance(src_path, str) and src_path.startswith(src_prefix):
dst_prefix = f"/dataset/{target_ds.id}"
new_path = src_path.replace(src_prefix, dst_prefix, 1)
needs_copy = True
# De-dup by target path
if new_path in existing_paths:
continue
# Perform copy only when needed
if needs_copy:
dst_dir = os.path.dirname(new_path)
await asyncio.to_thread(os.makedirs, dst_dir, exist_ok=True)
await asyncio.to_thread(shutil.copy2, src_path, new_path)
new_file = DatasetFiles(
dataset_id=target_ds.id, # type: ignore
file_name=f.file_name,
file_path=new_path,
file_type=f.file_type,
file_size=f.file_size,
check_sum=f.check_sum,
tags=f.tags,
dataset_filemetadata=f.dataset_filemetadata,
status="ACTIVE",
)
session.add(new_file)
existing_paths.add(new_path)
added_count += 1
added_size += int(f.file_size or 0)
# Periodically flush to avoid huge transactions
await session.flush()
# Update target dataset statistics # Update target dataset statistics
target_ds.file_count = (target_ds.file_count or 0) + added_count # type: ignore target_ds.file_count = (target_ds.file_count or 0) + added_count # type: ignore
@@ -194,8 +121,8 @@ class RatioTaskService:
target_ds.status = "ACTIVE" target_ds.status = "ACTIVE"
# Done # Done
instance.status = "SUCCESS" instance.status = TaskStatus.COMPLETED.name
logger.info(f"Dataset ratio execution completed: instance={instance_id}, files={added_count}, size={added_size}") logger.info(f"Dataset ratio execution completed: instance={instance_id}, files={added_count}, size={added_size}, {instance.status}")
except Exception as e: except Exception as e:
logger.exception(f"Dataset ratio execution failed for {instance_id}: {e}") logger.exception(f"Dataset ratio execution failed for {instance_id}: {e}")
@@ -204,42 +131,163 @@ class RatioTaskService:
inst_res = await session.execute(select(RatioInstance).where(RatioInstance.id == instance_id)) inst_res = await session.execute(select(RatioInstance).where(RatioInstance.id == instance_id))
instance = inst_res.scalar_one_or_none() instance = inst_res.scalar_one_or_none()
if instance: if instance:
instance.status = "FAILED" instance.status = TaskStatus.FAILED.name
finally: finally:
pass pass
finally: finally:
await session.commit() await session.commit()
@staticmethod
async def handle_ratio_relations(relations: list[RatioRelation], session, target_ds: Dataset) -> tuple[int, int]:
    """Sample files for every ratio relation and add them to the target dataset.

    Args:
        relations: Ratio relations to process, in order.
        session: Database session used for queries and for flushing new rows.
        target_ds: Dataset that receives the selected files.

    Returns:
        ``(added_count, added_size)``: the number of files handed to
        ``handle_selected_file`` and the sum of their ``file_size`` values.
    """
    # Paths already registered on the target dataset; used downstream for de-duplication.
    path_result = await session.execute(
        select(DatasetFiles.file_path).where(DatasetFiles.dataset_id == target_ds.id)
    )
    known_paths = {path for path in path_result.scalars().all() if path}

    total_files = 0
    total_bytes = 0

    for relation in relations:
        # A relation needs both a source dataset and a positive requested count.
        if not (relation.source_dataset_id and relation.counts and relation.counts > 0):
            continue

        candidates = await RatioTaskService.get_files(relation, session)
        if not candidates:
            continue

        available = len(candidates)
        quota = min(relation.counts or 0, available)
        # Sample randomly only when fewer files are requested than are available.
        if quota < available:
            picked = random.sample(candidates, quota)
        else:
            picked = candidates

        for source_file in picked:
            await RatioTaskService.handle_selected_file(known_paths, source_file, session, target_ds)
            total_files += 1
            total_bytes += int(source_file.file_size or 0)

        # Flush once per processed relation so pending rows do not pile up.
        await session.flush()

    return total_files, total_bytes
@staticmethod
async def handle_selected_file(existing_paths: set[Any], f, session, target_ds: Dataset):
    """Copy one selected source file into the target dataset and register it.

    The file at ``f.file_path`` is copied (in a worker thread) under the target
    dataset directory, a matching ``DatasetFiles`` row is added to ``session``
    and the new path is recorded in ``existing_paths``. Nothing is flushed or
    committed here.

    Args:
        existing_paths: Paths already used in the target dataset; mutated in place.
        f: Source file record; ``file_path``, ``file_name``, ``file_type``,
            ``file_size``, ``check_sum``, ``tags`` and ``dataset_filemetadata`` are read.
        session: Database session the new row is added to.
        target_ds: Dataset receiving the file.
    """
    src_path = f.file_path
    # The prefix is concatenated verbatim with the file name (here and inside
    # get_new_file_name), so it must end with a separator. Without it the file
    # would be written to "/dataset/<id><name>" outside the dataset directory.
    dst_prefix = f"/dataset/{target_ds.id}/"
    file_name = RatioTaskService.get_new_file_name(dst_prefix, existing_paths, f)

    new_path = dst_prefix + file_name
    dst_dir = os.path.dirname(new_path)
    # Blocking filesystem work is pushed off the event loop.
    await asyncio.to_thread(os.makedirs, dst_dir, exist_ok=True)
    await asyncio.to_thread(shutil.copy2, src_path, new_path)

    new_file = DatasetFiles(
        dataset_id=target_ds.id,  # type: ignore
        file_name=file_name,
        file_path=new_path,
        file_type=f.file_type,
        file_size=f.file_size,
        check_sum=f.check_sum,
        tags=f.tags,
        dataset_filemetadata=f.dataset_filemetadata,
        status="ACTIVE",
    )
    session.add(new_file)
    existing_paths.add(new_path)
@staticmethod
def get_new_file_name(dst_prefix: str, existing_paths: set[Any], f, max_attempts: int = 1000) -> str:
    """Return a file name for ``f`` whose target path is not in ``existing_paths``.

    The candidate path is ``dst_prefix + file_name`` (plain concatenation, no
    separator is inserted). On a conflict, ``_1``, ``_2``, ... is appended to
    the base name, before the extension.

    Args:
        dst_prefix: String prepended verbatim to the file name to form the path.
        existing_paths: Paths that are already taken; not modified.
        f: Object exposing ``file_name``.
        max_attempts: Highest numeric suffix tried before falling back to a
            UUID-based suffix.

    Returns:
        A file name whose concatenated path is not present in ``existing_paths``.
    """
    original_file_name = f.file_name
    if dst_prefix + original_file_name not in existing_paths:
        return original_file_name

    file_name_base, file_ext = os.path.splitext(original_file_name)
    for counter in range(1, max_attempts + 1):
        candidate = f"{file_name_base}_{counter}{file_ext}"
        if dst_prefix + candidate not in existing_paths:
            return candidate

    # Every numbered name is taken. Returning the last candidate would hand the
    # caller a path that is known to collide, so fall back to a random suffix.
    logger.error(f"Could not find unique filename for {original_file_name} after {max_attempts} attempts")
    import uuid  # local import: only needed on this rare fallback path

    while True:
        candidate = f"{file_name_base}_{uuid.uuid4().hex}{file_ext}"
        if dst_prefix + candidate not in existing_paths:
            return candidate
@staticmethod
async def get_files(rel: RatioRelation, session) -> list[Any]:
    """Load the candidate files of a relation's source dataset.

    Only files whose status is ``"ACTIVE"`` are fetched. When the relation's
    ``filter_conditions`` parse into a condition object, the files are further
    narrowed with ``_filter_file``; otherwise every active file is returned.

    Args:
        rel: Relation providing ``source_dataset_id`` and ``filter_conditions``.
        session: Database session used to run the query.

    Returns:
        The list of matching file records.
    """
    query = select(DatasetFiles).where(
        DatasetFiles.dataset_id == rel.source_dataset_id,
        DatasetFiles.status == "ACTIVE",
    )
    result = await session.execute(query)
    candidates = list(result.scalars().all())

    # Apply the relation's optional filter conditions.
    conditions = RatioTaskService._parse_conditions(rel.filter_conditions)
    if not conditions:
        return candidates
    return [item for item in candidates if RatioTaskService._filter_file(item, conditions)]
# ------------------------- helpers for condition filtering (label / date range) ------------------------- #
@staticmethod
def _parse_conditions(conditions: Optional[str]) -> Optional[FilterCondition]:
    """Turn a stored ``filter_conditions`` JSON string into a ``FilterCondition``.

    Args:
        conditions: JSON text holding the filter conditions, or ``None``/empty.

    Returns:
        The parsed ``FilterCondition``, or ``None`` when the input is empty,
        is not valid JSON, or cannot be turned into a ``FilterCondition``.
        Failures are logged, never raised.
    """
    if not conditions:
        return None

    parsed = None
    try:
        parsed = FilterCondition(**json.loads(conditions))
    except json.JSONDecodeError as err:
        logger.error(f"Failed to parse filter conditions: {err}")
    except Exception as err:
        logger.error(f"Error creating FilterCondition: {err}")
    return parsed
@staticmethod
def _filter_file(file: DatasetFiles, conditions: FilterCondition) -> bool:
    """Decide whether ``file`` satisfies the given filter conditions.

    Two optional checks are applied in order:

    * ``conditions.date_range`` (a number of days): the file is rejected when
      its ``tags_updated_at`` is older than that many days. A value that cannot
      be converted to ``int`` rejects the file.
    * ``conditions.label``: the file is kept only when the label is among the
      tag names returned by ``get_all_tags``.

    Args:
        file: File record; ``tags_updated_at`` and ``tags`` are read.
        conditions: Parsed filter conditions; a falsy value accepts every file.

    Returns:
        ``True`` when the file passes every configured check, else ``False``.
    """
    if not conditions:
        return True
    logger.info(f"start filter file: {file}, conditions: {conditions}")

    # Check date range condition if provided
    if conditions.date_range:
        from datetime import datetime, timedelta

        try:
            date_range_days = int(conditions.date_range)
            if date_range_days > 0:
                updated_at = file.tags_updated_at
                # NOTE(review): files without tags_updated_at pass the date
                # check - confirm this is intended.
                if updated_at:
                    # Build "now" with the timestamp's own tzinfo (None -> naive),
                    # so aware and naive values compare without a TypeError.
                    now = datetime.now(tz=getattr(updated_at, "tzinfo", None))
                    if updated_at < now - timedelta(days=date_range_days):
                        return False
        except (ValueError, TypeError) as e:
            # The exception goes into the message: passing it as an extra
            # positional argument would be treated as a %-format argument.
            logger.warning(f"Invalid data_range value: {conditions.date_range}: {e}")
            return False

    # Check label condition if provided
    if conditions.label:
        tags = file.tags
        if not tags:
            return False
        try:
            # tags could be a list of strings or list of objects with 'name'
            tag_names = RatioTaskService.get_all_tags(tags)
            return conditions.label in tag_names
        except Exception:
            # logger.exception already attaches the active exception.
            logger.exception(f"Failed to get tags for {file}")
            return False

    return True
@staticmethod @staticmethod
def get_all_tags(tags) -> set[str]: def get_all_tags(tags) -> set[str]: