feature: 数据配比增加通过更新时间来配置 (#95)

* feature: 数据配比增加通过更新时间来配置

* fix: 修复配比时间参数传递的问题
This commit is contained in:
hefanli
2025-11-20 18:50:51 +08:00
committed by GitHub
parent 955ffff6cd
commit cddfe9b149
10 changed files with 458 additions and 595 deletions

View File

@@ -7,7 +7,6 @@ import { useNavigate } from "react-router";
import SelectDataset from "@/pages/RatioTask/Create/components/SelectDataset.tsx"; import SelectDataset from "@/pages/RatioTask/Create/components/SelectDataset.tsx";
import BasicInformation from "@/pages/RatioTask/Create/components/BasicInformation.tsx"; import BasicInformation from "@/pages/RatioTask/Create/components/BasicInformation.tsx";
import RatioConfig from "@/pages/RatioTask/Create/components/RatioConfig.tsx"; import RatioConfig from "@/pages/RatioTask/Create/components/RatioConfig.tsx";
import RatioTransfer from "./components/RatioTransfer";
export default function CreateRatioTask() { export default function CreateRatioTask() {
const navigate = useNavigate(); const navigate = useNavigate();
@@ -36,27 +35,12 @@ export default function CreateRatioTask() {
message.error("请配置配比项"); message.error("请配置配比项");
return; return;
} }
// Build request payload
const ratio_method =
ratioTaskForm.ratioType === "dataset" ? "DATASET" : "TAG";
const totals = String(values.totalTargetCount); const totals = String(values.totalTargetCount);
const config = ratioTaskForm.ratioConfigs.map((c) => { const config = ratioTaskForm.ratioConfigs.map((c) => {
if (ratio_method === "DATASET") {
return { return {
datasetId: String(c.source), datasetId: c.id,
counts: String(c.quantity ?? 0), counts: String(c.quantity ?? 0),
filter_conditions: "", filterConditions: { label: c.labelFilter, dateRange: String(c.dateRange ?? 0)},
};
}
// TAG mode: source key like `${datasetId}_${label}`
const source = String(c.source || "");
const idx = source.indexOf("_");
const datasetId = idx > 0 ? source.slice(0, idx) : source;
const label = idx > 0 ? source.slice(idx + 1) : "";
return {
datasetId,
counts: String(c.quantity ?? 0),
filter_conditions: label ? JSON.stringify({ label }) : "",
}; };
}); });
@@ -65,7 +49,6 @@ export default function CreateRatioTask() {
name: values.name, name: values.name,
description: values.description, description: values.description,
totals, totals,
ratio_method,
config, config,
}); });
message.success("配比任务创建成功"); message.success("配比任务创建成功");
@@ -108,13 +91,6 @@ export default function CreateRatioTask() {
totalTargetCount={ratioTaskForm.totalTargetCount} totalTargetCount={ratioTaskForm.totalTargetCount}
/> />
{/* <RatioTransfer
ratioTaskForm={ratioTaskForm}
distributions={distributions}
updateRatioConfig={updateRatioConfig}
updateLabelRatioConfig={updateLabelRatioConfig}
/> */}
<div className="flex h-full"> <div className="flex h-full">
<SelectDataset <SelectDataset
selectedDatasets={ratioTaskForm.selectedDatasets} selectedDatasets={ratioTaskForm.selectedDatasets}

View File

@@ -27,7 +27,7 @@ const BasicInformation: React.FC<BasicInformationProps> = ({
<Input type="number" placeholder="目标总数量" min={1} /> <Input type="number" placeholder="目标总数量" min={1} />
</Form.Item> </Form.Item>
<Form.Item label="任务描述" name="description" className="col-span-2"> <Form.Item label="任务描述" name="description" className="col-span-2">
<TextArea placeholder="描述配比任务的目的和要求(可选)" rows={2} /> <TextArea placeholder="描述配比任务的目的和要求" rows={2} />
</Form.Item> </Form.Item>
</div> </div>
); );

View File

@@ -1,7 +1,19 @@
import React, { useMemo, useState } from "react"; import React, { useMemo, useState } from "react";
import { Badge, Card, Input, Progress, Button, Divider } from "antd"; import { Badge, Card, Input, Progress, Button, DatePicker, Select } from "antd";
import { BarChart3 } from "lucide-react"; import { BarChart3, Filter, Clock } from "lucide-react";
import type { Dataset } from "@/pages/DataManagement/dataset.model.ts"; import type { Dataset } from "@/pages/DataManagement/dataset.model.ts";
import dayjs from 'dayjs';
const { RangePicker } = DatePicker;
const { Option } = Select;
const TIME_RANGE_OPTIONS = [
{ label: '最近1天', value: 1 },
{ label: '最近3天', value: 3 },
{ label: '最近7天', value: 7 },
{ label: '最近15天', value: 15 },
{ label: '最近30天', value: 30 },
];
interface RatioConfigItem { interface RatioConfigItem {
id: string; id: string;
@@ -10,6 +22,8 @@ interface RatioConfigItem {
quantity: number; quantity: number;
percentage: number; percentage: number;
source: string; source: string;
labelFilter?: string;
dateRange?: string;
} }
interface RatioConfigProps { interface RatioConfigProps {
@@ -30,6 +44,10 @@ const RatioConfig: React.FC<RatioConfigProps> = ({
onChange, onChange,
}) => { }) => {
const [ratioConfigs, setRatioConfigs] = useState<RatioConfigItem[]>([]); const [ratioConfigs, setRatioConfigs] = useState<RatioConfigItem[]>([]);
const [datasetFilters, setDatasetFilters] = useState<Record<string, {
labelFilter?: string;
dateRange?: string;
}>>({});
// 配比项总数 // 配比项总数
const totalConfigured = useMemo( const totalConfigured = useMemo(
@@ -37,6 +55,36 @@ const RatioConfig: React.FC<RatioConfigProps> = ({
[ratioConfigs] [ratioConfigs]
); );
// 获取数据集的标签列表
const getDatasetLabels = (datasetId: string): string[] => {
const dist = distributions[String(datasetId)] || {};
return Object.keys(dist);
};
// 自动平均分配
const generateAutoRatio = () => {
const selectedCount = selectedDatasets.length;
if (selectedCount === 0) return;
const baseQuantity = Math.floor(totalTargetCount / selectedCount);
const remainder = totalTargetCount % selectedCount;
const newConfigs = selectedDatasets.map((datasetId, index) => {
const dataset = datasets.find((d) => String(d.id) === datasetId);
const quantity = baseQuantity + (index < remainder ? 1 : 0);
return {
id: datasetId,
name: dataset?.name || datasetId,
type: ratioType,
quantity,
percentage: Math.round((quantity / totalTargetCount) * 100),
source: datasetId,
labelFilter: datasetFilters[datasetId]?.labelFilter,
dateRange: datasetFilters[datasetId]?.dateRange,
};
});
setRatioConfigs(newConfigs);
onChange?.(newConfigs);
};
// 更新数据集配比项 // 更新数据集配比项
const updateDatasetQuantity = (datasetId: string, quantity: number) => { const updateDatasetQuantity = (datasetId: string, quantity: number) => {
setRatioConfigs((prev) => { setRatioConfigs((prev) => {
@@ -55,6 +103,8 @@ const RatioConfig: React.FC<RatioConfigProps> = ({
quantity: Math.min(quantity, totalTargetCount - totalOtherQuantity), quantity: Math.min(quantity, totalTargetCount - totalOtherQuantity),
percentage: Math.round((quantity / totalTargetCount) * 100), percentage: Math.round((quantity / totalTargetCount) * 100),
source: datasetId, source: datasetId,
labelFilter: datasetFilters[datasetId]?.labelFilter,
dateRange: datasetFilters[datasetId]?.dateRange,
}; };
let newConfigs; let newConfigs;
@@ -69,78 +119,85 @@ const RatioConfig: React.FC<RatioConfigProps> = ({
}); });
}; };
// 自动平均分配 // 更新筛选条件
const generateAutoRatio = () => { const updateFilters = (datasetId: string, updates: {
const selectedCount = selectedDatasets.length; labelFilter?: string;
if (selectedCount === 0) return; dateRange?: [string, string];
const baseQuantity = Math.floor(totalTargetCount / selectedCount); }) => {
const remainder = totalTargetCount % selectedCount; setDatasetFilters(prev => ({
const newConfigs = selectedDatasets.map((datasetId, index) => { ...prev,
const dataset = datasets.find((d) => String(d.id) === datasetId); [datasetId]: {
const quantity = baseQuantity + (index < remainder ? 1 : 0); ...prev[datasetId],
return { ...updates,
id: datasetId, }
name: dataset?.name || datasetId, }));
type: ratioType,
quantity,
percentage: Math.round((quantity / totalTargetCount) * 100),
source: datasetId,
};
});
setRatioConfigs(newConfigs);
onChange?.(newConfigs);
}; };
// 标签模式下,更新某数据集的某个标签的数量 // 渲染筛选器
const updateLabelQuantity = ( const renderFilters = (datasetId: string) => {
datasetId: string, const labels = getDatasetLabels(datasetId);
label: string, const config = ratioConfigs.find(c => c.source === datasetId);
quantity: number const filters = datasetFilters[datasetId] || {};
) => {
const sourceKey = `${datasetId}_${label}`; return (
setRatioConfigs((prev) => { <div className="mb-3 p-2 bg-gray-50 rounded">
const existingIndex = prev.findIndex((c) => c.source === sourceKey); <div className="flex items-center gap-2 mb-2">
const totalOtherQuantity = prev <Filter size={14} className="text-gray-400" />
.filter((c) => c.source !== sourceKey) <span className="text-xs font-medium"></span>
.reduce((sum, c) => sum + c.quantity, 0); </div>
const dist = distributions[datasetId] || {};
const labelMax = dist[label] ?? Infinity; <div className="grid grid-cols-1 md:grid-cols-2 gap-3">
const cappedQuantity = Math.max( <div>
0, <div className="text-xs text-gray-500 mb-1"></div>
Math.min(quantity, totalTargetCount - totalOtherQuantity, labelMax) <Select
style={{ width: '100%' }}
placeholder="选择标签"
value={filters.labelFilter}
onChange={(value) => updateFilters(datasetId, { labelFilter: value })}
allowClear
onClear={() => updateFilters(datasetId, { labelFilter: undefined })}
>
{labels.map(label => (
<Option key={label} value={label}>{label}</Option>
))}
</Select>
</div>
<div>
<div className="text-xs text-gray-500 mb-1"></div>
<Select
style={{ width: '100%' }}
placeholder="选择标签更新时间"
value={filters.dateRange}
onChange={(dates) => updateFilters(datasetId, { dateRange: dates })}
allowClear
onClear={() => updateFilters(datasetId, { dateRange: undefined })}
>
{TIME_RANGE_OPTIONS.map(option => (
<Option key={option.value} value={option.value}>
{option.label}
</Option>
))}
</Select>
</div>
</div>
</div>
); );
const newConfig: RatioConfigItem = {
id: sourceKey,
name: label,
type: "label",
quantity: cappedQuantity,
percentage: Math.round((cappedQuantity / totalTargetCount) * 100),
source: sourceKey,
};
let newConfigs;
if (existingIndex >= 0) {
newConfigs = [...prev];
newConfigs[existingIndex] = newConfig;
} else {
newConfigs = [...prev, newConfig];
}
onChange?.(newConfigs);
return newConfigs;
});
}; };
// 选中数据集变化时,移除未选中的配比项 // 选中数据集变化时,初始化筛选条件
React.useEffect(() => { React.useEffect(() => {
setRatioConfigs((prev) => { const initialFilters: Record<string, any> = {};
const next = prev.filter((c) => { selectedDatasets.forEach(datasetId => {
const id = String(c.source); const config = ratioConfigs.find(c => c.source === datasetId);
const dsId = id.includes("_") ? id.split("_")[0] : id; if (config) {
return selectedDatasets.includes(dsId); initialFilters[datasetId] = {
labelFilter: config.labelFilter,
dateRange: config.dateRange,
};
}
}); });
if (next !== prev) onChange?.(next); setDatasetFilters(prev => ({ ...prev, ...initialFilters }));
return next;
});
// eslint-disable-next-line
}, [selectedDatasets]); }, [selectedDatasets]);
return ( return (
@@ -148,7 +205,7 @@ const RatioConfig: React.FC<RatioConfigProps> = ({
<div className="flex items-center justify-between p-4 border-bottom"> <div className="flex items-center justify-between p-4 border-bottom">
<span className="text-sm font-bold"> <span className="text-sm font-bold">
<span className="text-xs text-gray-500"> <span className="text-xs text-gray-500 ml-1">
(:{totalConfigured}/{totalTargetCount}) (:{totalConfigured}/{totalTargetCount})
</span> </span>
</span> </span>
@@ -170,8 +227,7 @@ const RatioConfig: React.FC<RatioConfigProps> = ({
<div className="flex-overflow-auto gap-4 p-4"> <div className="flex-overflow-auto gap-4 p-4">
{/* 配比预览 */} {/* 配比预览 */}
{ratioConfigs.length > 0 && ( {ratioConfigs.length > 0 && (
<div> <div className="p-3 bg-gray-50 rounded-lg mb-4">
<div className="p-3 bg-gray-50 rounded-lg">
<div className="grid grid-cols-2 gap-4 text-sm"> <div className="grid grid-cols-2 gap-4 text-sm">
<div> <div>
<span className="text-gray-500">:</span> <span className="text-gray-500">:</span>
@@ -187,24 +243,20 @@ const RatioConfig: React.FC<RatioConfigProps> = ({
{totalTargetCount.toLocaleString()} {totalTargetCount.toLocaleString()}
</span> </span>
</div> </div>
<div>
<span className="text-gray-500">:</span>
<span className="ml-2 font-medium">
{ratioConfigs.length}
</span>
</div>
</div>
</div> </div>
</div> </div>
)} )}
<div className="flex-1 overflow-auto">
<div className="flex-1 overflow-auto space-y-4">
{selectedDatasets.map((datasetId) => { {selectedDatasets.map((datasetId) => {
const dataset = datasets.find((d) => String(d.id) === datasetId); const dataset = datasets.find((d) => String(d.id) === datasetId);
const config = ratioConfigs.find((c) => c.source === datasetId); const config = ratioConfigs.find((c) => c.source === datasetId);
const currentQuantity = config?.quantity || 0; const currentQuantity = config?.quantity || 0;
if (!dataset) return null; if (!dataset) return null;
return ( return (
<Card key={datasetId} size="small" className="mb-2"> <Card key={datasetId} size="small" className="mb-4">
<div className="flex items-center justify-between mb-3"> <div className="flex items-center justify-between mb-3">
<div className="flex items-center gap-2"> <div className="flex items-center gap-2">
<span className="font-medium text-sm"> <span className="font-medium text-sm">
@@ -216,8 +268,10 @@ const RatioConfig: React.FC<RatioConfigProps> = ({
{config?.percentage || 0}% {config?.percentage || 0}%
</div> </div>
</div> </div>
{ratioType === "dataset" ? (
<div> {/* 筛选条件 */}
{renderFilters(datasetId)}
<div className="flex items-center gap-2 mb-2"> <div className="flex items-center gap-2 mb-2">
<span className="text-xs">:</span> <span className="text-xs">:</span>
<Input <Input
@@ -229,7 +283,7 @@ const RatioConfig: React.FC<RatioConfigProps> = ({
Number(e.target.value) Number(e.target.value)
) )
} }
style={{ width: 80 }} style={{ width: 100 }}
min={0} min={0}
max={Math.min( max={Math.min(
dataset.fileCount || 0, dataset.fileCount || 0,
@@ -244,69 +298,6 @@ const RatioConfig: React.FC<RatioConfigProps> = ({
)} )}
size="small" size="small"
/> />
</div>
) : (
<div>
{!distributions[String(dataset.id)] ? (
<div className="text-xs text-gray-400">
...
</div>
) : Object.entries(distributions[String(dataset.id)])
.length === 0 ? (
<div className="text-xs text-gray-400">
</div>
) : (
<div className="flex flex-col gap-2">
{Object.entries(
distributions[String(dataset.id)]
).map(([label, count]) => {
const sourceKey = `${datasetId}_${label}`;
const labelConfig = ratioConfigs.find(
(c) => c.source === sourceKey
);
const labelQuantity = labelConfig?.quantity || 0;
return (
<div
key={label}
className="flex items-center justify-between gap-2"
>
<div className="flex items-center gap-2">
<Badge color="gray">{label}</Badge>
<span className="text-xs text-gray-500">
{count}
</span>
</div>
<div className="flex items-center gap-2">
<span className="text-xs">:</span>
<Input
type="number"
value={labelQuantity}
onChange={(e) =>
updateLabelQuantity(
datasetId,
label,
Number(e.target.value)
)
}
style={{ width: 80 }}
min={0}
max={Math.min(
Number(count) || 0,
totalTargetCount
)}
/>
<span className="text-xs text-gray-500">
</span>
</div>
</div>
);
})}
</div>
)}
</div>
)}
</Card> </Card>
); );
})} })}

View File

@@ -1,169 +0,0 @@
import React, { useMemo } from "react";
import { Table } from "antd";
import { TransferItem } from "antd/es/transfer";
import RatioConfig from "./RatioConfig";
import useFetchData from "@/hooks/useFetchData";
import { queryDatasetsUsingGet } from "@/pages/DataManagement/dataset.api";
import {
datasetTypeMap,
mapDataset,
} from "@/pages/DataManagement/dataset.const";
import { SearchControls } from "@/components/SearchControls";
const leftColumns = [
{
dataIndex: "name",
title: "名称",
ellipsis: true,
},
{
dataIndex: "datasetType",
title: "类型",
ellipsis: true,
width: 100,
render: (type: string) => datasetTypeMap[type].label,
},
{
dataIndex: "size",
title: "大小",
width: 100,
ellipsis: true,
},
];
export default function RatioTransfer(props: {
distributions: Record<string, Record<string, number>>;
ratioTaskForm: any;
updateRatioConfig: (datasetId: string, quantity: number) => void;
updateLabelRatioConfig: (
datasetId: string,
label: string,
quantity: number
) => void;
}) {
const {
updateLabelRatioConfig,
updateRatioConfig,
ratioTaskForm,
distributions,
} = props;
const {
tableData: datasets,
loading,
pagination,
searchParams,
setSearchParams,
handleFiltersChange,
} = useFetchData(queryDatasetsUsingGet, mapDataset);
const [selectedDatasets, setSelectedDatasets] = React.useState<
TransferItem[]
>([]);
const selectedRowKeys = useMemo(() => {
return selectedDatasets.map((item) => item.key);
}, [selectedDatasets]);
const [listDisabled, setListDisabled] = React.useState(false);
const generateAutoRatio = () => {
const selectedCount = ratioTaskForm.selectedDatasets.length;
if (selectedCount === 0) return;
const baseQuantity = Math.floor(
ratioTaskForm.totalTargetCount / selectedCount
);
const remainder = ratioTaskForm.totalTargetCount % selectedCount;
const newConfigs = ratioTaskForm.selectedDatasets.map(
(datasetId, index) => {
const quantity = baseQuantity + (index < remainder ? 1 : 0);
return {
id: datasetId,
name: datasetId,
type: ratioTaskForm.ratioType,
quantity,
percentage: Math.round(
(quantity / ratioTaskForm.totalTargetCount) * 100
),
source: datasetId,
};
}
);
setRatioTaskForm((prev) => ({ ...prev, ratioConfigs: newConfigs }));
};
return (
<div className="flex">
<div className="border-card flex-1 mr-4">
<h3 className="p-2 border-bottom">{`${selectedDatasets.length} / ${datasets.length}`}</h3>
<SearchControls
searchTerm={searchParams.keyword}
onSearchChange={(keyword) =>
setSearchParams({ ...searchParams, keyword })
}
searchPlaceholder="搜索数据集名称..."
filters={[
{
key: "type",
label: "数据集类型",
options: [
{ value: "dataset", label: "按数据集" },
{ value: "tag", label: "按标签" },
],
},
]}
onFiltersChange={handleFiltersChange}
onClearFilters={() =>
setSearchParams({ ...searchParams, filter: {} })
}
showViewToggle={false}
showReload={false}
className="m-4"
/>
<Table
rowSelection={{
onChange: (_, selectedRows) => {
setSelectedDatasets(selectedRows);
},
selectedRowKeys,
selections: [
Table.SELECTION_ALL,
Table.SELECTION_INVERT,
Table.SELECTION_NONE,
],
}}
columns={leftColumns}
dataSource={datasets}
loading={loading}
pagination={pagination}
size="small"
rowKey="id"
style={{ pointerEvents: listDisabled ? "none" : undefined }}
onRow={(record) => ({
onClick: () => {
if (record.disabled || listDisabled) {
return;
}
setSelectedDatasets((prev) => {
if (prev.includes(record.key)) {
return prev.filter((k) => k !== record.key);
}
return [...prev, record.key];
});
},
})}
/>
</div>
<div className="border-card flex-1">
<RatioConfig
datasets={selectedDatasets}
ratioTaskForm={ratioTaskForm}
distributions={distributions}
onUpdateDatasetQuantity={updateRatioConfig}
onUpdateLabelQuantity={updateLabelRatioConfig}
/>
</div>
</div>
);
}

View File

@@ -1,5 +1,5 @@
import React, { useEffect, useState } from "react"; import React, { useEffect, useState } from "react";
import { Badge, Button, Card, Checkbox, Input, Pagination, Select } from "antd"; import { Badge, Button, Card, Checkbox, Input, Pagination } from "antd";
import { Search as SearchIcon } from "lucide-react"; import { Search as SearchIcon } from "lucide-react";
import type { Dataset } from "@/pages/DataManagement/dataset.model.ts"; import type { Dataset } from "@/pages/DataManagement/dataset.model.ts";
import { import {
@@ -10,8 +10,6 @@ import {
interface SelectDatasetProps { interface SelectDatasetProps {
selectedDatasets: string[]; selectedDatasets: string[];
ratioType: "dataset" | "label";
onRatioTypeChange: (val: "dataset" | "label") => void;
onSelectedDatasetsChange: (next: string[]) => void; onSelectedDatasetsChange: (next: string[]) => void;
onDistributionsChange?: ( onDistributionsChange?: (
next: Record<string, Record<string, number>> next: Record<string, Record<string, number>>
@@ -21,8 +19,6 @@ interface SelectDatasetProps {
const SelectDataset: React.FC<SelectDatasetProps> = ({ const SelectDataset: React.FC<SelectDatasetProps> = ({
selectedDatasets, selectedDatasets,
ratioType,
onRatioTypeChange,
onSelectedDatasetsChange, onSelectedDatasetsChange,
onDistributionsChange, onDistributionsChange,
onDatasetsChange, onDatasetsChange,
@@ -62,7 +58,7 @@ const SelectDataset: React.FC<SelectDatasetProps> = ({
// Fetch label distributions when in label mode // Fetch label distributions when in label mode
useEffect(() => { useEffect(() => {
const fetchDistributions = async () => { const fetchDistributions = async () => {
if (ratioType !== "label" || !datasets?.length) return; if (!datasets?.length) return;
const idsToFetch = datasets const idsToFetch = datasets
.map((d) => String(d.id)) .map((d) => String(d.id))
.filter((id) => !distributions[id]); .filter((id) => !distributions[id]);
@@ -147,7 +143,7 @@ const SelectDataset: React.FC<SelectDatasetProps> = ({
}; };
fetchDistributions(); fetchDistributions();
// eslint-disable-next-line react-hooks/exhaustive-deps // eslint-disable-next-line react-hooks/exhaustive-deps
}, [ratioType, datasets]); }, [datasets]);
const onToggleDataset = (datasetId: string, checked: boolean) => { const onToggleDataset = (datasetId: string, checked: boolean) => {
if (checked) { if (checked) {
@@ -180,18 +176,6 @@ const SelectDataset: React.FC<SelectDatasetProps> = ({
</Button> </Button>
</div> </div>
<div className="flex-overflow-auto gap-4 p-4"> <div className="flex-overflow-auto gap-4 p-4">
<div className="flex items-center gap-4">
<span className="text-sm">:</span>
<Select
className="flex-1 min-w-[120px]"
value={ratioType}
onChange={(v) => onRatioTypeChange(v)}
options={[
{ label: "按数据集", value: "dataset" },
{ label: "按标签", value: "label" },
]}
/>
</div>
<Input <Input
prefix={<SearchIcon className="text-gray-400" />} prefix={<SearchIcon className="text-gray-400" />}
placeholder="搜索数据集" placeholder="搜索数据集"
@@ -239,7 +223,6 @@ const SelectDataset: React.FC<SelectDatasetProps> = ({
<span>{dataset.fileCount}</span> <span>{dataset.fileCount}</span>
<span>{dataset.size}</span> <span>{dataset.size}</span>
</div> </div>
{ratioType === "label" && (
<div className="mt-2"> <div className="mt-2">
{distributions[idStr] ? ( {distributions[idStr] ? (
Object.entries(distributions[idStr]).length > 0 ? ( Object.entries(distributions[idStr]).length > 0 ? (
@@ -264,7 +247,6 @@ const SelectDataset: React.FC<SelectDatasetProps> = ({
</div> </div>
)} )}
</div> </div>
)}
</div> </div>
</div> </div>
</Card> </Card>

View File

@@ -1,11 +1,13 @@
from .common import ( from .common import (
BaseResponseModel, BaseResponseModel,
StandardResponse, StandardResponse,
PaginatedData PaginatedData,
TaskStatus
) )
__all__ = [ __all__ = [
"BaseResponseModel", "BaseResponseModel",
"StandardResponse", "StandardResponse",
"PaginatedData" "PaginatedData",
"TaskStatus"
] ]

View File

@@ -1,8 +1,9 @@
""" """
通用响应模型 通用响应模型
""" """
from typing import Generic, TypeVar, Optional, List, Type from typing import Generic, TypeVar, List
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from enum import Enum
# 定义泛型类型变量 # 定义泛型类型变量
T = TypeVar('T') T = TypeVar('T')
@@ -42,3 +43,9 @@ class PaginatedData(BaseResponseModel, Generic[T]):
total_elements: int = Field(..., description="总条数") total_elements: int = Field(..., description="总条数")
total_pages: int = Field(..., description="总页数") total_pages: int = Field(..., description="总页数")
content: List[T] = Field(..., description="当前页数据") content: List[T] = Field(..., description="当前页数据")
class TaskStatus(Enum):
PENDING = "PENDING"
RUNNING = "RUNNING"
COMPLETED = "COMPLETED"
FAILED = "FAILED"

View File

@@ -12,7 +12,7 @@ from app.core.logging import get_logger
from app.db.models import Dataset from app.db.models import Dataset
from app.db.session import get_db from app.db.session import get_db
from app.module.dataset import DatasetManagementService from app.module.dataset import DatasetManagementService
from app.module.shared.schema import StandardResponse from app.module.shared.schema import StandardResponse, TaskStatus
from app.module.synthesis.schema.ratio_task import ( from app.module.synthesis.schema.ratio_task import (
CreateRatioTaskResponse, CreateRatioTaskResponse,
CreateRatioTaskRequest, CreateRatioTaskRequest,
@@ -49,28 +49,47 @@ async def create_ratio_task(
await valid_exists(db, req) await valid_exists(db, req)
# 创建目标数据集:名称使用“<任务名称>-配比生成-时间戳” target_dataset = await create_target_dataset(db, req, source_types)
target_dataset_name = f"{req.name}-配比生成-{datetime.now().strftime('%Y%m%d%H%M%S')}"
target_type = get_target_dataset_type(source_types) instance = await create_ratio_instance(db, req, target_dataset)
target_dataset = Dataset( asyncio.create_task(RatioTaskService.execute_dataset_ratio_task(instance.id))
id=str(uuid.uuid4()),
name=target_dataset_name, response_data = CreateRatioTaskResponse(
description=req.description or "", id=instance.id,
dataset_type=target_type, name=instance.name,
status="DRAFT", description=instance.description,
totals=instance.totals or 0,
status=instance.status or TaskStatus.PENDING.name,
config=req.config,
targetDataset=TargetDatasetInfo(
id=str(target_dataset.id),
name=str(target_dataset.name),
datasetType=str(target_dataset.dataset_type),
status=str(target_dataset.status),
) )
target_dataset.path = f"/dataset/{target_dataset.id}" )
db.add(target_dataset) return StandardResponse(
await db.flush() # 获取 target_dataset.id code=200,
message="success",
data=response_data
)
except HTTPException:
await db.rollback()
raise
except Exception as e:
await db.rollback()
logger.error(f"Failed to create ratio task: {e}")
raise HTTPException(status_code=500, detail="Internal server error")
async def create_ratio_instance(db, req: CreateRatioTaskRequest, target_dataset: Dataset) -> RatioInstance:
service = RatioTaskService(db) service = RatioTaskService(db)
logger.info(f"create_ratio_instance: {req}")
instance = await service.create_task( instance = await service.create_task(
name=req.name, name=req.name,
description=req.description, description=req.description,
totals=int(req.totals), totals=int(req.totals),
ratio_method=req.ratio_method,
config=[ config=[
{ {
"dataset_id": item.dataset_id, "dataset_id": item.dataset_id,
@@ -81,36 +100,27 @@ async def create_ratio_task(
], ],
target_dataset_id=target_dataset.id, target_dataset_id=target_dataset.id,
) )
return instance
# 异步执行配比任务(支持 DATASET / TAG)
asyncio.create_task(RatioTaskService.execute_dataset_ratio_task(instance.id))
return StandardResponse( async def create_target_dataset(db, req: CreateRatioTaskRequest, source_types: set[str]) -> Dataset:
code=200, # 创建目标数据集:名称使用“<任务名称>-时间戳”
message="success", target_dataset_name = f"{req.name}-{datetime.now().strftime('%Y%m%d%H%M%S')}"
data=CreateRatioTaskResponse(
id=instance.id, target_type = get_target_dataset_type(source_types)
name=instance.name, target_dataset_id = uuid.uuid4()
description=instance.description,
totals=instance.totals or 0, target_dataset = Dataset(
ratio_method=instance.ratio_method or req.ratio_method, id=str(target_dataset_id),
status=instance.status or "PENDING", name=target_dataset_name,
config=req.config, description=req.description or "",
targetDataset=TargetDatasetInfo( dataset_type=target_type,
id=str(target_dataset.id), status="DRAFT",
name=str(target_dataset.name), path=f"/dataset/{target_dataset_id}",
datasetType=str(target_dataset.dataset_type),
status=str(target_dataset.status),
) )
) db.add(target_dataset)
) await db.flush() # 获取 target_dataset.id
except HTTPException: return target_dataset
await db.rollback()
raise
except Exception as e:
await db.rollback()
logger.error(f"Failed to create ratio task: {e}")
raise HTTPException(status_code=500, detail="Internal server error")
@router.get("", response_model=StandardResponse[PagedRatioTaskResponse], status_code=200) @router.get("", response_model=StandardResponse[PagedRatioTaskResponse], status_code=200)

View File

@@ -2,10 +2,36 @@ from typing import List, Optional, Dict, Any
from datetime import datetime from datetime import datetime
from pydantic import BaseModel, Field, field_validator from pydantic import BaseModel, Field, field_validator
from app.core.logging import get_logger
from app.module.shared.schema.common import TaskStatus
logger = get_logger(__name__)
class FilterCondition(BaseModel):
date_range: Optional[str] = Field(None, description="数据范围", alias="dateRange")
label: Optional[str] = Field(None, description="标签")
@field_validator("date_range")
@classmethod
def validate_date_range(cls, v: Optional[str]) -> Optional[str]:
# ensure it's a numeric string if provided
if not v:
return v
try:
int(v)
return v
except (ValueError, TypeError) as e:
raise ValueError("date_range must be a numeric string")
class Config:
# allow population by field name when constructing model programmatically
validate_by_name = True
class RatioConfigItem(BaseModel): class RatioConfigItem(BaseModel):
dataset_id: str = Field(..., alias="datasetId", description="数据集id") dataset_id: str = Field(..., alias="datasetId", description="数据集id")
counts: str = Field(..., description="数量") counts: str = Field(..., description="数量")
filter_conditions: str = Field(..., description="过滤条件") filter_conditions: FilterCondition = Field(..., alias="filterConditions", description="过滤条件")
@field_validator("counts") @field_validator("counts")
@classmethod @classmethod
@@ -22,17 +48,8 @@ class CreateRatioTaskRequest(BaseModel):
name: str = Field(..., description="名称") name: str = Field(..., description="名称")
description: Optional[str] = Field(None, description="描述") description: Optional[str] = Field(None, description="描述")
totals: str = Field(..., description="目标数量") totals: str = Field(..., description="目标数量")
ratio_method: str = Field(..., description="配比方式", alias="ratio_method")
config: List[RatioConfigItem] = Field(..., description="配比设置列表") config: List[RatioConfigItem] = Field(..., description="配比设置列表")
@field_validator("ratio_method")
@classmethod
def validate_ratio_method(cls, v: str) -> str:
allowed = {"TAG", "DATASET"}
if v not in allowed:
raise ValueError(f"ratio_method must be one of {allowed}")
return v
@field_validator("totals") @field_validator("totals")
@classmethod @classmethod
def validate_totals(cls, v: str) -> str: def validate_totals(cls, v: str) -> str:
@@ -58,8 +75,7 @@ class CreateRatioTaskResponse(BaseModel):
name: str name: str
description: Optional[str] = None description: Optional[str] = None
totals: int totals: int
ratio_method: str status: TaskStatus
status: str
# echoed config # echoed config
config: List[RatioConfigItem] config: List[RatioConfigItem]
# created dataset # created dataset

View File

@@ -14,6 +14,8 @@ from app.db.models.ratio_task import RatioInstance, RatioRelation
from app.db.models import Dataset, DatasetFiles from app.db.models import Dataset, DatasetFiles
from app.db.session import AsyncSessionLocal from app.db.session import AsyncSessionLocal
from app.module.dataset.schema.dataset_file import DatasetFileTag from app.module.dataset.schema.dataset_file import DatasetFileTag
from app.module.shared.schema import TaskStatus
from app.module.synthesis.schema.ratio_task import FilterCondition
logger = get_logger(__name__) logger = get_logger(__name__)
@@ -30,7 +32,6 @@ class RatioTaskService:
name: str, name: str,
description: Optional[str], description: Optional[str],
totals: int, totals: int,
ratio_method: str,
config: List[Dict[str, Any]], config: List[Dict[str, Any]],
target_dataset_id: Optional[str] = None, target_dataset_id: Optional[str] = None,
) -> RatioInstance: ) -> RatioInstance:
@@ -38,12 +39,11 @@ class RatioTaskService:
config item format: {"dataset_id": str, "counts": int, "filter_conditions": str} config item format: {"dataset_id": str, "counts": int, "filter_conditions": str}
""" """
logger.info(f"Creating ratio task: name={name}, method={ratio_method}, totals={totals}, items={len(config or [])}") logger.info(f"Creating ratio task: name={name}, totals={totals}, items={len(config or [])}")
instance = RatioInstance( instance = RatioInstance(
name=name, name=name,
description=description, description=description,
ratio_method=ratio_method,
totals=totals, totals=totals,
target_dataset_id=target_dataset_id, target_dataset_id=target_dataset_id,
status="PENDING", status="PENDING",
@@ -56,8 +56,12 @@ class RatioTaskService:
ratio_instance_id=instance.id, ratio_instance_id=instance.id,
source_dataset_id=item.get("dataset_id"), source_dataset_id=item.get("dataset_id"),
counts=int(item.get("counts", 0)), counts=int(item.get("counts", 0)),
filter_conditions=item.get("filter_conditions"), filter_conditions=json.dumps({
'date_range': item.get("filter_conditions").date_range,
'label': item.get("filter_conditions").label,
})
) )
logger.info(f"Relation created: {relation.id}, {relation}, {item}, {config}")
self.db.add(relation) self.db.add(relation)
await self.db.commit() await self.db.commit()
@@ -97,21 +101,44 @@ class RatioTaskService:
relations: List[RatioRelation] = list(rel_res.scalars().all()) relations: List[RatioRelation] = list(rel_res.scalars().all())
# Mark running # Mark running
instance.status = "RUNNING" instance.status = TaskStatus.RUNNING.name
if instance.ratio_method not in {"DATASET", "TAG"}:
logger.info(f"Instance {instance_id} ratio_method={instance.ratio_method} not supported yet")
instance.status = "SUCCESS"
return
# Load target dataset # Load target dataset
ds_res = await session.execute(select(Dataset).where(Dataset.id == instance.target_dataset_id)) ds_res = await session.execute(select(Dataset).where(Dataset.id == instance.target_dataset_id))
target_ds: Optional[Dataset] = ds_res.scalar_one_or_none() target_ds: Optional[Dataset] = ds_res.scalar_one_or_none()
if not target_ds: if not target_ds:
logger.error(f"Target dataset not found for instance {instance_id}") logger.error(f"Target dataset not found for instance {instance_id}")
instance.status = "FAILED" instance.status = TaskStatus.FAILED.name
return return
added_count, added_size = await RatioTaskService.handle_ratio_relations(relations,session, target_ds)
# Update target dataset statistics
target_ds.file_count = (target_ds.file_count or 0) + added_count # type: ignore
target_ds.size_bytes = (target_ds.size_bytes or 0) + added_size # type: ignore
# If target dataset has files, mark it ACTIVE
if (target_ds.file_count or 0) > 0: # type: ignore
target_ds.status = "ACTIVE"
# Done
instance.status = TaskStatus.COMPLETED.name
logger.info(f"Dataset ratio execution completed: instance={instance_id}, files={added_count}, size={added_size}, {instance.status}")
except Exception as e:
logger.exception(f"Dataset ratio execution failed for {instance_id}: {e}")
try:
# Try mark failed
inst_res = await session.execute(select(RatioInstance).where(RatioInstance.id == instance_id))
instance = inst_res.scalar_one_or_none()
if instance:
instance.status = TaskStatus.FAILED.name
finally:
pass
finally:
await session.commit()
@staticmethod
async def handle_ratio_relations(relations: list[RatioRelation], session, target_ds: Dataset) -> tuple[int, int]:
# Preload existing target file paths for deduplication # Preload existing target file paths for deduplication
existing_path_rows = await session.execute( existing_path_rows = await session.execute(
select(DatasetFiles.file_path).where(DatasetFiles.dataset_id == target_ds.id) select(DatasetFiles.file_path).where(DatasetFiles.dataset_id == target_ds.id)
@@ -125,20 +152,7 @@ class RatioTaskService:
if not rel.source_dataset_id or not rel.counts or rel.counts <= 0: if not rel.source_dataset_id or not rel.counts or rel.counts <= 0:
continue continue
# Fetch all files for the source dataset (ACTIVE only) files = await RatioTaskService.get_files(rel, session)
files_res = await session.execute(
select(DatasetFiles).where(
DatasetFiles.dataset_id == rel.source_dataset_id,
DatasetFiles.status == "ACTIVE",
)
)
files = list(files_res.scalars().all())
# TAG mode: filter by tags according to relation.filter_conditions
if instance.ratio_method == "TAG":
required_tags = RatioTaskService._parse_required_tags(rel.filter_conditions)
if required_tags:
files = [f for f in files if RatioTaskService._file_contains_tags(f, required_tags)]
if not files: if not files:
continue continue
@@ -148,28 +162,28 @@ class RatioTaskService:
# Copy into target dataset with de-dup by target path # Copy into target dataset with de-dup by target path
for f in chosen: for f in chosen:
await RatioTaskService.handle_selected_file(existing_paths, f, session, target_ds)
added_count += 1
added_size += int(f.file_size or 0)
# Periodically flush to avoid huge transactions
await session.flush()
return added_count, added_size
@staticmethod
async def handle_selected_file(existing_paths: set[Any], f, session, target_ds: Dataset):
src_path = f.file_path src_path = f.file_path
new_path = src_path
needs_copy = False
src_prefix = f"/dataset/{rel.source_dataset_id}"
if isinstance(src_path, str) and src_path.startswith(src_prefix):
dst_prefix = f"/dataset/{target_ds.id}" dst_prefix = f"/dataset/{target_ds.id}"
new_path = src_path.replace(src_prefix, dst_prefix, 1) file_name = RatioTaskService.get_new_file_name(dst_prefix, existing_paths, f)
needs_copy = True
# De-dup by target path new_path = dst_prefix + file_name
if new_path in existing_paths:
continue
# Perform copy only when needed
if needs_copy:
dst_dir = os.path.dirname(new_path) dst_dir = os.path.dirname(new_path)
await asyncio.to_thread(os.makedirs, dst_dir, exist_ok=True) await asyncio.to_thread(os.makedirs, dst_dir, exist_ok=True)
await asyncio.to_thread(shutil.copy2, src_path, new_path) await asyncio.to_thread(shutil.copy2, src_path, new_path)
new_file = DatasetFiles( new_file = DatasetFiles(
dataset_id=target_ds.id, # type: ignore dataset_id=target_ds.id, # type: ignore
file_name=f.file_name, file_name=file_name,
file_path=new_path, file_path=new_path,
file_type=f.file_type, file_type=f.file_type,
file_size=f.file_size, file_size=f.file_size,
@@ -180,67 +194,101 @@ class RatioTaskService:
) )
session.add(new_file) session.add(new_file)
existing_paths.add(new_path) existing_paths.add(new_path)
added_count += 1
added_size += int(f.file_size or 0)
# Periodically flush to avoid huge transactions @staticmethod
await session.flush() def get_new_file_name(dst_prefix: str, existing_paths: set[Any], f) -> str:
file_name = f.file_name
new_path = dst_prefix + file_name
# Update target dataset statistics # Handle file path conflicts by appending a number to the filename
target_ds.file_count = (target_ds.file_count or 0) + added_count # type: ignore if new_path in existing_paths:
target_ds.size_bytes = (target_ds.size_bytes or 0) + added_size # type: ignore file_name_base, file_ext = os.path.splitext(file_name)
# If target dataset has files, mark it ACTIVE counter = 1
if (target_ds.file_count or 0) > 0: # type: ignore original_file_name = file_name
target_ds.status = "ACTIVE" while new_path in existing_paths:
file_name = f"{file_name_base}_{counter}{file_ext}"
new_path = f"{dst_prefix}{file_name}"
counter += 1
if counter > 1000: # Safety check to prevent infinite loops
logger.error(f"Could not find unique filename for {original_file_name} after 1000 attempts")
break
return file_name
# Done @staticmethod
instance.status = "SUCCESS" async def get_files(rel: RatioRelation, session) -> list[Any]:
logger.info(f"Dataset ratio execution completed: instance={instance_id}, files={added_count}, size={added_size}") # Fetch all files for the source dataset (ACTIVE only)
files_res = await session.execute(
select(DatasetFiles).where(
DatasetFiles.dataset_id == rel.source_dataset_id,
DatasetFiles.status == "ACTIVE",
)
)
files = list(files_res.scalars().all())
except Exception as e: # TAG mode: filter by tags according to relation.filter_conditions
logger.exception(f"Dataset ratio execution failed for {instance_id}: {e}") conditions = RatioTaskService._parse_conditions(rel.filter_conditions)
try: if conditions:
# Try mark failed files = [f for f in files if RatioTaskService._filter_file(f, conditions)]
inst_res = await session.execute(select(RatioInstance).where(RatioInstance.id == instance_id)) return files
instance = inst_res.scalar_one_or_none()
if instance:
instance.status = "FAILED"
finally:
pass
finally:
await session.commit()
# ------------------------- helpers for TAG filtering ------------------------- # # ------------------------- helpers for TAG filtering ------------------------- #
@staticmethod @staticmethod
def _parse_required_tags(conditions: Optional[str]) -> set[str]: def _parse_conditions(conditions: Optional[str]) -> Optional[FilterCondition]:
"""Parse filter_conditions into a set of required tag strings. """Parse filter_conditions JSON string into a FilterCondition object.
Supports simple separators: comma, semicolon, space. Empty/None -> empty set. Args:
conditions: JSON string containing filter conditions
Returns:
FilterCondition object if conditions is not None/empty, otherwise None
""" """
if not conditions: if not conditions:
return set() return None
try:
data = json.loads(conditions) data = json.loads(conditions)
required_tags = set() return FilterCondition(**data)
if data.get("label"): except json.JSONDecodeError as e:
required_tags.add(data["label"]) logger.error(f"Failed to parse filter conditions: {e}")
return required_tags return None
except Exception as e:
logger.error(f"Error creating FilterCondition: {e}")
return None
@staticmethod @staticmethod
def _file_contains_tags(file: DatasetFiles, required: set[str]) -> bool: def _filter_file(file: DatasetFiles, conditions: FilterCondition) -> bool:
if not required: if not conditions:
return True return True
logger.info(f"start filter file: {file}, conditions: {conditions}")
# Check data range condition if provided
if conditions.date_range:
try:
from datetime import datetime, timedelta
data_range_days = int(conditions.date_range)
if data_range_days > 0:
cutoff_date = datetime.now() - timedelta(days=data_range_days)
if file.tags_updated_at and file.tags_updated_at < cutoff_date:
return False
except (ValueError, TypeError) as e:
logger.warning(f"Invalid data_range value: {conditions.date_range}", e)
return False
# Check label condition if provided
if conditions.label:
tags = file.tags tags = file.tags
if not tags: if not tags:
return False return False
try: try:
# tags could be a list of strings or list of objects with 'name' # tags could be a list of strings or list of objects with 'name'
tag_names = RatioTaskService.get_all_tags(tags) tag_names = RatioTaskService.get_all_tags(tags)
return required.issubset(tag_names) return conditions.label in tag_names
except Exception as e: except Exception as e:
logger.exception(f"Failed to get tags for {file}", e) logger.exception(f"Failed to get tags for {file}", e)
return False return False
return True
@staticmethod @staticmethod
def get_all_tags(tags) -> set[str]: def get_all_tags(tags) -> set[str]:
"""获取所有处理后的标签字符串列表""" """获取所有处理后的标签字符串列表"""