feature: 数据配比增加通过更新时间来配置 (#95)

* feature: 数据配比增加通过更新时间来配置

* fix: 修复配比时间参数传递的问题
This commit is contained in:
hefanli
2025-11-20 18:50:51 +08:00
committed by GitHub
parent 955ffff6cd
commit cddfe9b149
10 changed files with 458 additions and 595 deletions

View File

@@ -7,7 +7,6 @@ import { useNavigate } from "react-router";
import SelectDataset from "@/pages/RatioTask/Create/components/SelectDataset.tsx";
import BasicInformation from "@/pages/RatioTask/Create/components/BasicInformation.tsx";
import RatioConfig from "@/pages/RatioTask/Create/components/RatioConfig.tsx";
import RatioTransfer from "./components/RatioTransfer";
export default function CreateRatioTask() {
const navigate = useNavigate();
@@ -36,27 +35,12 @@ export default function CreateRatioTask() {
message.error("请配置配比项");
return;
}
// Build request payload
const ratio_method =
ratioTaskForm.ratioType === "dataset" ? "DATASET" : "TAG";
const totals = String(values.totalTargetCount);
const config = ratioTaskForm.ratioConfigs.map((c) => {
if (ratio_method === "DATASET") {
return {
datasetId: String(c.source),
datasetId: c.id,
counts: String(c.quantity ?? 0),
filter_conditions: "",
};
}
// TAG mode: source key like `${datasetId}_${label}`
const source = String(c.source || "");
const idx = source.indexOf("_");
const datasetId = idx > 0 ? source.slice(0, idx) : source;
const label = idx > 0 ? source.slice(idx + 1) : "";
return {
datasetId,
counts: String(c.quantity ?? 0),
filter_conditions: label ? JSON.stringify({ label }) : "",
filterConditions: { label: c.labelFilter, dateRange: String(c.dateRange ?? 0)},
};
});
@@ -65,7 +49,6 @@ export default function CreateRatioTask() {
name: values.name,
description: values.description,
totals,
ratio_method,
config,
});
message.success("配比任务创建成功");
@@ -108,13 +91,6 @@ export default function CreateRatioTask() {
totalTargetCount={ratioTaskForm.totalTargetCount}
/>
{/* <RatioTransfer
ratioTaskForm={ratioTaskForm}
distributions={distributions}
updateRatioConfig={updateRatioConfig}
updateLabelRatioConfig={updateLabelRatioConfig}
/> */}
<div className="flex h-full">
<SelectDataset
selectedDatasets={ratioTaskForm.selectedDatasets}

View File

@@ -27,7 +27,7 @@ const BasicInformation: React.FC<BasicInformationProps> = ({
<Input type="number" placeholder="目标总数量" min={1} />
</Form.Item>
<Form.Item label="任务描述" name="description" className="col-span-2">
<TextArea placeholder="描述配比任务的目的和要求(可选)" rows={2} />
<TextArea placeholder="描述配比任务的目的和要求" rows={2} />
</Form.Item>
</div>
);

View File

@@ -1,7 +1,19 @@
import React, { useMemo, useState } from "react";
import { Badge, Card, Input, Progress, Button, Divider } from "antd";
import { BarChart3 } from "lucide-react";
import { Badge, Card, Input, Progress, Button, DatePicker, Select } from "antd";
import { BarChart3, Filter, Clock } from "lucide-react";
import type { Dataset } from "@/pages/DataManagement/dataset.model.ts";
import dayjs from 'dayjs';
const { RangePicker } = DatePicker;
const { Option } = Select;
// Preset look-back windows, in days, offered by the tag-update-time filter.
// Each entry is rendered as one <Option>; `value` is the number of days.
const TIME_RANGE_OPTIONS = [1, 3, 7, 15, 30].map((days) => ({
  label: `最近${days}天`,
  value: days,
}));
interface RatioConfigItem {
id: string;
@@ -10,6 +22,8 @@ interface RatioConfigItem {
quantity: number;
percentage: number;
source: string;
labelFilter?: string;
dateRange?: string;
}
interface RatioConfigProps {
@@ -28,8 +42,12 @@ const RatioConfig: React.FC<RatioConfigProps> = ({
totalTargetCount,
distributions,
onChange,
}) => {
}) => {
const [ratioConfigs, setRatioConfigs] = useState<RatioConfigItem[]>([]);
const [datasetFilters, setDatasetFilters] = useState<Record<string, {
labelFilter?: string;
dateRange?: string;
}>>({});
// 配比项总数
const totalConfigured = useMemo(
@@ -37,6 +55,36 @@ const RatioConfig: React.FC<RatioConfigProps> = ({
[ratioConfigs]
);
// Label names known for a dataset: the keys of its entry in `distributions`,
// or an empty list when no distribution has been loaded for that id.
const getDatasetLabels = (datasetId: string): string[] => {
  const key = String(datasetId);
  const labelCounts = distributions[key] || {};
  return Object.keys(labelCounts);
};
// Split the target total evenly across the selected datasets.
// The first `remainder` datasets receive one extra item so the quantities
// always add up to the (integer) target total. Each generated item carries
// the filters currently stored for its dataset.
const generateAutoRatio = () => {
  const selectedCount = selectedDatasets.length;
  if (selectedCount === 0) return;

  // Normalise the total: NaN / negative -> 0, fractional -> floored, so the
  // split below never yields NaN quantities or over-allocates.
  const total = Math.max(0, Math.floor(totalTargetCount)) || 0;
  const baseQuantity = Math.floor(total / selectedCount);
  const remainder = total % selectedCount;

  const newConfigs = selectedDatasets.map((datasetId, index) => {
    const dataset = datasets.find((d) => String(d.id) === datasetId);
    const quantity = baseQuantity + (index < remainder ? 1 : 0);
    const filters = datasetFilters[datasetId];
    return {
      id: datasetId,
      name: dataset?.name || datasetId,
      type: ratioType,
      quantity,
      // Guard the division: a zero total would otherwise produce NaN.
      percentage: total > 0 ? Math.round((quantity / total) * 100) : 0,
      source: datasetId,
      labelFilter: filters?.labelFilter,
      dateRange: filters?.dateRange,
    };
  });

  setRatioConfigs(newConfigs);
  onChange?.(newConfigs);
};
// 更新数据集配比项
const updateDatasetQuantity = (datasetId: string, quantity: number) => {
setRatioConfigs((prev) => {
@@ -55,6 +103,8 @@ const RatioConfig: React.FC<RatioConfigProps> = ({
quantity: Math.min(quantity, totalTargetCount - totalOtherQuantity),
percentage: Math.round((quantity / totalTargetCount) * 100),
source: datasetId,
labelFilter: datasetFilters[datasetId]?.labelFilter,
dateRange: datasetFilters[datasetId]?.dateRange,
};
let newConfigs;
@@ -69,78 +119,85 @@ const RatioConfig: React.FC<RatioConfigProps> = ({
});
};
// 自动平均分配
const generateAutoRatio = () => {
const selectedCount = selectedDatasets.length;
if (selectedCount === 0) return;
const baseQuantity = Math.floor(totalTargetCount / selectedCount);
const remainder = totalTargetCount % selectedCount;
const newConfigs = selectedDatasets.map((datasetId, index) => {
const dataset = datasets.find((d) => String(d.id) === datasetId);
const quantity = baseQuantity + (index < remainder ? 1 : 0);
return {
id: datasetId,
name: dataset?.name || datasetId,
type: ratioType,
quantity,
percentage: Math.round((quantity / totalTargetCount) * 100),
source: datasetId,
};
});
setRatioConfigs(newConfigs);
onChange?.(newConfigs);
// Merge filter changes for one dataset into the filter state.
// `dateRange` is typed like the filter state and RatioConfigItem (a single
// value, not a [start, end] tuple). If a ratio item already exists for the
// dataset it is patched as well and reported through `onChange`, so filters
// chosen after a quantity was entered are not lost on submit.
const updateFilters = (datasetId: string, updates: {
  labelFilter?: string;
  dateRange?: string;
}) => {
  setDatasetFilters(prev => ({
    ...prev,
    [datasetId]: {
      ...prev[datasetId],
      ...updates,
    }
  }));

  setRatioConfigs(prev => {
    // Nothing to sync until a ratio item has been created for this dataset.
    if (!prev.some(c => c.source === datasetId)) return prev;
    const next = prev.map(c =>
      c.source === datasetId ? { ...c, ...updates } : c
    );
    onChange?.(next);
    return next;
  });
};
// 标签模式下,更新某数据集的某个标签的数量
const updateLabelQuantity = (
datasetId: string,
label: string,
quantity: number
) => {
const sourceKey = `${datasetId}_${label}`;
setRatioConfigs((prev) => {
const existingIndex = prev.findIndex((c) => c.source === sourceKey);
const totalOtherQuantity = prev
.filter((c) => c.source !== sourceKey)
.reduce((sum, c) => sum + c.quantity, 0);
const dist = distributions[datasetId] || {};
const labelMax = dist[label] ?? Infinity;
const cappedQuantity = Math.max(
0,
Math.min(quantity, totalTargetCount - totalOtherQuantity, labelMax)
// 渲染筛选器
const renderFilters = (datasetId: string) => {
const labels = getDatasetLabels(datasetId);
const config = ratioConfigs.find(c => c.source === datasetId);
const filters = datasetFilters[datasetId] || {};
return (
<div className="mb-3 p-2 bg-gray-50 rounded">
<div className="flex items-center gap-2 mb-2">
<Filter size={14} className="text-gray-400" />
<span className="text-xs font-medium"></span>
</div>
<div className="grid grid-cols-1 md:grid-cols-2 gap-3">
<div>
<div className="text-xs text-gray-500 mb-1"></div>
<Select
style={{ width: '100%' }}
placeholder="选择标签"
value={filters.labelFilter}
onChange={(value) => updateFilters(datasetId, { labelFilter: value })}
allowClear
onClear={() => updateFilters(datasetId, { labelFilter: undefined })}
>
{labels.map(label => (
<Option key={label} value={label}>{label}</Option>
))}
</Select>
</div>
<div>
<div className="text-xs text-gray-500 mb-1"></div>
<Select
style={{ width: '100%' }}
placeholder="选择标签更新时间"
value={filters.dateRange}
onChange={(dates) => updateFilters(datasetId, { dateRange: dates })}
allowClear
onClear={() => updateFilters(datasetId, { dateRange: undefined })}
>
{TIME_RANGE_OPTIONS.map(option => (
<Option key={option.value} value={option.value}>
{option.label}
</Option>
))}
</Select>
</div>
</div>
</div>
);
const newConfig: RatioConfigItem = {
id: sourceKey,
name: label,
type: "label",
quantity: cappedQuantity,
percentage: Math.round((cappedQuantity / totalTargetCount) * 100),
source: sourceKey,
};
let newConfigs;
if (existingIndex >= 0) {
newConfigs = [...prev];
newConfigs[existingIndex] = newConfig;
} else {
newConfigs = [...prev, newConfig];
}
onChange?.(newConfigs);
return newConfigs;
});
};
// 选中数据集变化时,移除未选中的配比项
// 选中数据集变化时,初始化筛选条件
React.useEffect(() => {
setRatioConfigs((prev) => {
const next = prev.filter((c) => {
const id = String(c.source);
const dsId = id.includes("_") ? id.split("_")[0] : id;
return selectedDatasets.includes(dsId);
const initialFilters: Record<string, any> = {};
selectedDatasets.forEach(datasetId => {
const config = ratioConfigs.find(c => c.source === datasetId);
if (config) {
initialFilters[datasetId] = {
labelFilter: config.labelFilter,
dateRange: config.dateRange,
};
}
});
if (next !== prev) onChange?.(next);
return next;
});
// eslint-disable-next-line
setDatasetFilters(prev => ({ ...prev, ...initialFilters }));
}, [selectedDatasets]);
return (
@@ -148,7 +205,7 @@ const RatioConfig: React.FC<RatioConfigProps> = ({
<div className="flex items-center justify-between p-4 border-bottom">
<span className="text-sm font-bold">
<span className="text-xs text-gray-500">
<span className="text-xs text-gray-500 ml-1">
(:{totalConfigured}/{totalTargetCount})
</span>
</span>
@@ -170,8 +227,7 @@ const RatioConfig: React.FC<RatioConfigProps> = ({
<div className="flex-overflow-auto gap-4 p-4">
{/* 配比预览 */}
{ratioConfigs.length > 0 && (
<div>
<div className="p-3 bg-gray-50 rounded-lg">
<div className="p-3 bg-gray-50 rounded-lg mb-4">
<div className="grid grid-cols-2 gap-4 text-sm">
<div>
<span className="text-gray-500">:</span>
@@ -187,24 +243,20 @@ const RatioConfig: React.FC<RatioConfigProps> = ({
{totalTargetCount.toLocaleString()}
</span>
</div>
<div>
<span className="text-gray-500">:</span>
<span className="ml-2 font-medium">
{ratioConfigs.length}
</span>
</div>
</div>
</div>
</div>
)}
<div className="flex-1 overflow-auto">
<div className="flex-1 overflow-auto space-y-4">
{selectedDatasets.map((datasetId) => {
const dataset = datasets.find((d) => String(d.id) === datasetId);
const config = ratioConfigs.find((c) => c.source === datasetId);
const currentQuantity = config?.quantity || 0;
if (!dataset) return null;
return (
<Card key={datasetId} size="small" className="mb-2">
<Card key={datasetId} size="small" className="mb-4">
<div className="flex items-center justify-between mb-3">
<div className="flex items-center gap-2">
<span className="font-medium text-sm">
@@ -216,8 +268,10 @@ const RatioConfig: React.FC<RatioConfigProps> = ({
{config?.percentage || 0}%
</div>
</div>
{ratioType === "dataset" ? (
<div>
{/* 筛选条件 */}
{renderFilters(datasetId)}
<div className="flex items-center gap-2 mb-2">
<span className="text-xs">:</span>
<Input
@@ -229,7 +283,7 @@ const RatioConfig: React.FC<RatioConfigProps> = ({
Number(e.target.value)
)
}
style={{ width: 80 }}
style={{ width: 100 }}
min={0}
max={Math.min(
dataset.fileCount || 0,
@@ -244,69 +298,6 @@ const RatioConfig: React.FC<RatioConfigProps> = ({
)}
size="small"
/>
</div>
) : (
<div>
{!distributions[String(dataset.id)] ? (
<div className="text-xs text-gray-400">
...
</div>
) : Object.entries(distributions[String(dataset.id)])
.length === 0 ? (
<div className="text-xs text-gray-400">
</div>
) : (
<div className="flex flex-col gap-2">
{Object.entries(
distributions[String(dataset.id)]
).map(([label, count]) => {
const sourceKey = `${datasetId}_${label}`;
const labelConfig = ratioConfigs.find(
(c) => c.source === sourceKey
);
const labelQuantity = labelConfig?.quantity || 0;
return (
<div
key={label}
className="flex items-center justify-between gap-2"
>
<div className="flex items-center gap-2">
<Badge color="gray">{label}</Badge>
<span className="text-xs text-gray-500">
{count}
</span>
</div>
<div className="flex items-center gap-2">
<span className="text-xs">:</span>
<Input
type="number"
value={labelQuantity}
onChange={(e) =>
updateLabelQuantity(
datasetId,
label,
Number(e.target.value)
)
}
style={{ width: 80 }}
min={0}
max={Math.min(
Number(count) || 0,
totalTargetCount
)}
/>
<span className="text-xs text-gray-500">
</span>
</div>
</div>
);
})}
</div>
)}
</div>
)}
</Card>
);
})}

View File

@@ -1,169 +0,0 @@
import React, { useMemo } from "react";
import { Table } from "antd";
import { TransferItem } from "antd/es/transfer";
import RatioConfig from "./RatioConfig";
import useFetchData from "@/hooks/useFetchData";
import { queryDatasetsUsingGet } from "@/pages/DataManagement/dataset.api";
import {
datasetTypeMap,
mapDataset,
} from "@/pages/DataManagement/dataset.const";
import { SearchControls } from "@/components/SearchControls";
// Column definitions for the dataset table on the left-hand side.
const leftColumns = [
  {
    dataIndex: "name",
    title: "名称",
    ellipsis: true,
  },
  {
    dataIndex: "datasetType",
    title: "类型",
    ellipsis: true,
    width: 100,
    // Fall back to the raw type string when the type is missing from
    // datasetTypeMap instead of throwing on `undefined.label`.
    render: (type: string) => datasetTypeMap[type]?.label ?? type,
  },
  {
    dataIndex: "size",
    title: "大小",
    width: 100,
    ellipsis: true,
  },
];
/**
 * Two-pane picker: a searchable, selectable dataset table on the left and
 * the RatioConfig panel for the selected datasets on the right.
 *
 * Selection is kept as a list of table records; `selectedRowKeys` is derived
 * from it, so both the checkbox column and row clicks must store records
 * (never bare keys).
 */
export default function RatioTransfer(props: {
  distributions: Record<string, Record<string, number>>;
  ratioTaskForm: any;
  updateRatioConfig: (datasetId: string, quantity: number) => void;
  updateLabelRatioConfig: (
    datasetId: string,
    label: string,
    quantity: number
  ) => void;
}) {
  const {
    updateLabelRatioConfig,
    updateRatioConfig,
    ratioTaskForm,
    distributions,
  } = props;

  const {
    tableData: datasets,
    loading,
    pagination,
    searchParams,
    setSearchParams,
    handleFiltersChange,
  } = useFetchData(queryDatasetsUsingGet, mapDataset);

  const [selectedDatasets, setSelectedDatasets] = React.useState<
    TransferItem[]
  >([]);
  const selectedRowKeys = useMemo(() => {
    return selectedDatasets.map((item) => item.key);
  }, [selectedDatasets]);
  const [listDisabled, setListDisabled] = React.useState(false);

  // NOTE: the former local `generateAutoRatio` helper was removed: it was
  // never called and referenced `setRatioTaskForm`, which is not defined in
  // this component.

  return (
    <div className="flex">
      <div className="border-card flex-1 mr-4">
        <h3 className="p-2 border-bottom">{`${selectedDatasets.length} / ${datasets.length}`}</h3>
        <SearchControls
          searchTerm={searchParams.keyword}
          onSearchChange={(keyword) =>
            setSearchParams({ ...searchParams, keyword })
          }
          searchPlaceholder="搜索数据集名称..."
          filters={[
            {
              key: "type",
              label: "数据集类型",
              options: [
                { value: "dataset", label: "按数据集" },
                { value: "tag", label: "按标签" },
              ],
            },
          ]}
          onFiltersChange={handleFiltersChange}
          onClearFilters={() =>
            setSearchParams({ ...searchParams, filter: {} })
          }
          showViewToggle={false}
          showReload={false}
          className="m-4"
        />
        <Table
          rowSelection={{
            onChange: (_, selectedRows) => {
              setSelectedDatasets(selectedRows);
            },
            selectedRowKeys,
            selections: [
              Table.SELECTION_ALL,
              Table.SELECTION_INVERT,
              Table.SELECTION_NONE,
            ],
          }}
          columns={leftColumns}
          dataSource={datasets}
          loading={loading}
          pagination={pagination}
          size="small"
          rowKey="id"
          style={{ pointerEvents: listDisabled ? "none" : undefined }}
          onRow={(record) => ({
            onClick: () => {
              if (record.disabled || listDisabled) {
                return;
              }
              // Toggle the clicked record. Compare by key and store the
              // record itself so the state stays a list of records, matching
              // what the checkbox column writes.
              setSelectedDatasets((prev) => {
                if (prev.some((item) => item.key === record.key)) {
                  return prev.filter((item) => item.key !== record.key);
                }
                return [...prev, record];
              });
            },
          })}
        />
      </div>
      <div className="border-card flex-1">
        <RatioConfig
          datasets={selectedDatasets}
          ratioTaskForm={ratioTaskForm}
          distributions={distributions}
          onUpdateDatasetQuantity={updateRatioConfig}
          onUpdateLabelQuantity={updateLabelRatioConfig}
        />
      </div>
    </div>
  );
}

View File

@@ -1,5 +1,5 @@
import React, { useEffect, useState } from "react";
import { Badge, Button, Card, Checkbox, Input, Pagination, Select } from "antd";
import { Badge, Button, Card, Checkbox, Input, Pagination } from "antd";
import { Search as SearchIcon } from "lucide-react";
import type { Dataset } from "@/pages/DataManagement/dataset.model.ts";
import {
@@ -10,8 +10,6 @@ import {
interface SelectDatasetProps {
selectedDatasets: string[];
ratioType: "dataset" | "label";
onRatioTypeChange: (val: "dataset" | "label") => void;
onSelectedDatasetsChange: (next: string[]) => void;
onDistributionsChange?: (
next: Record<string, Record<string, number>>
@@ -21,8 +19,6 @@ interface SelectDatasetProps {
const SelectDataset: React.FC<SelectDatasetProps> = ({
selectedDatasets,
ratioType,
onRatioTypeChange,
onSelectedDatasetsChange,
onDistributionsChange,
onDatasetsChange,
@@ -62,7 +58,7 @@ const SelectDataset: React.FC<SelectDatasetProps> = ({
// Fetch label distributions when in label mode
useEffect(() => {
const fetchDistributions = async () => {
if (ratioType !== "label" || !datasets?.length) return;
if (!datasets?.length) return;
const idsToFetch = datasets
.map((d) => String(d.id))
.filter((id) => !distributions[id]);
@@ -147,7 +143,7 @@ const SelectDataset: React.FC<SelectDatasetProps> = ({
};
fetchDistributions();
// eslint-disable-next-line react-hooks/exhaustive-deps
}, [ratioType, datasets]);
}, [datasets]);
const onToggleDataset = (datasetId: string, checked: boolean) => {
if (checked) {
@@ -180,18 +176,6 @@ const SelectDataset: React.FC<SelectDatasetProps> = ({
</Button>
</div>
<div className="flex-overflow-auto gap-4 p-4">
<div className="flex items-center gap-4">
<span className="text-sm">:</span>
<Select
className="flex-1 min-w-[120px]"
value={ratioType}
onChange={(v) => onRatioTypeChange(v)}
options={[
{ label: "按数据集", value: "dataset" },
{ label: "按标签", value: "label" },
]}
/>
</div>
<Input
prefix={<SearchIcon className="text-gray-400" />}
placeholder="搜索数据集"
@@ -239,7 +223,6 @@ const SelectDataset: React.FC<SelectDatasetProps> = ({
<span>{dataset.fileCount}</span>
<span>{dataset.size}</span>
</div>
{ratioType === "label" && (
<div className="mt-2">
{distributions[idStr] ? (
Object.entries(distributions[idStr]).length > 0 ? (
@@ -264,7 +247,6 @@ const SelectDataset: React.FC<SelectDatasetProps> = ({
</div>
)}
</div>
)}
</div>
</div>
</Card>

View File

@@ -1,11 +1,13 @@
from .common import (
BaseResponseModel,
StandardResponse,
PaginatedData
PaginatedData,
TaskStatus
)
__all__ = [
"BaseResponseModel",
"StandardResponse",
"PaginatedData"
"PaginatedData",
"TaskStatus"
]

View File

@@ -1,8 +1,9 @@
"""
通用响应模型
"""
from typing import Generic, TypeVar, Optional, List, Type
from typing import Generic, TypeVar, List
from pydantic import BaseModel, Field
from enum import Enum
# 定义泛型类型变量
T = TypeVar('T')
@@ -42,3 +43,9 @@ class PaginatedData(BaseResponseModel, Generic[T]):
total_elements: int = Field(..., description="总条数")
total_pages: int = Field(..., description="总页数")
content: List[T] = Field(..., description="当前页数据")
class TaskStatus(str, Enum):
    """Lifecycle states of an asynchronous task.

    The ``str`` mixin makes every member compare equal to its plain string
    value (e.g. ``TaskStatus.PENDING == "PENDING"``) and serialise as that
    string, so members interoperate with status columns that store the name
    as text. Each member's ``name`` and ``value`` are identical.
    """

    PENDING = "PENDING"
    RUNNING = "RUNNING"
    COMPLETED = "COMPLETED"
    FAILED = "FAILED"

View File

@@ -12,7 +12,7 @@ from app.core.logging import get_logger
from app.db.models import Dataset
from app.db.session import get_db
from app.module.dataset import DatasetManagementService
from app.module.shared.schema import StandardResponse
from app.module.shared.schema import StandardResponse, TaskStatus
from app.module.synthesis.schema.ratio_task import (
CreateRatioTaskResponse,
CreateRatioTaskRequest,
@@ -49,28 +49,47 @@ async def create_ratio_task(
await valid_exists(db, req)
# 创建目标数据集:名称使用“<任务名称>-配比生成-时间戳”
target_dataset_name = f"{req.name}-配比生成-{datetime.now().strftime('%Y%m%d%H%M%S')}"
target_dataset = await create_target_dataset(db, req, source_types)
target_type = get_target_dataset_type(source_types)
instance = await create_ratio_instance(db, req, target_dataset)
target_dataset = Dataset(
id=str(uuid.uuid4()),
name=target_dataset_name,
description=req.description or "",
dataset_type=target_type,
status="DRAFT",
asyncio.create_task(RatioTaskService.execute_dataset_ratio_task(instance.id))
response_data = CreateRatioTaskResponse(
id=instance.id,
name=instance.name,
description=instance.description,
totals=instance.totals or 0,
status=instance.status or TaskStatus.PENDING.name,
config=req.config,
targetDataset=TargetDatasetInfo(
id=str(target_dataset.id),
name=str(target_dataset.name),
datasetType=str(target_dataset.dataset_type),
status=str(target_dataset.status),
)
target_dataset.path = f"/dataset/{target_dataset.id}"
db.add(target_dataset)
await db.flush() # 获取 target_dataset.id
)
return StandardResponse(
code=200,
message="success",
data=response_data
)
except HTTPException:
await db.rollback()
raise
except Exception as e:
await db.rollback()
logger.error(f"Failed to create ratio task: {e}")
raise HTTPException(status_code=500, detail="Internal server error")
async def create_ratio_instance(db, req: CreateRatioTaskRequest, target_dataset: Dataset) -> RatioInstance:
service = RatioTaskService(db)
logger.info(f"create_ratio_instance: {req}")
instance = await service.create_task(
name=req.name,
description=req.description,
totals=int(req.totals),
ratio_method=req.ratio_method,
config=[
{
"dataset_id": item.dataset_id,
@@ -81,36 +100,27 @@ async def create_ratio_task(
],
target_dataset_id=target_dataset.id,
)
return instance
# 异步执行配比任务(支持 DATASET / TAG)
asyncio.create_task(RatioTaskService.execute_dataset_ratio_task(instance.id))
return StandardResponse(
code=200,
message="success",
data=CreateRatioTaskResponse(
id=instance.id,
name=instance.name,
description=instance.description,
totals=instance.totals or 0,
ratio_method=instance.ratio_method or req.ratio_method,
status=instance.status or "PENDING",
config=req.config,
targetDataset=TargetDatasetInfo(
id=str(target_dataset.id),
name=str(target_dataset.name),
datasetType=str(target_dataset.dataset_type),
status=str(target_dataset.status),
async def create_target_dataset(db, req: CreateRatioTaskRequest, source_types: set[str]) -> Dataset:
# 创建目标数据集:名称使用“<任务名称>-时间戳”
target_dataset_name = f"{req.name}-{datetime.now().strftime('%Y%m%d%H%M%S')}"
target_type = get_target_dataset_type(source_types)
target_dataset_id = uuid.uuid4()
target_dataset = Dataset(
id=str(target_dataset_id),
name=target_dataset_name,
description=req.description or "",
dataset_type=target_type,
status="DRAFT",
path=f"/dataset/{target_dataset_id}",
)
)
)
except HTTPException:
await db.rollback()
raise
except Exception as e:
await db.rollback()
logger.error(f"Failed to create ratio task: {e}")
raise HTTPException(status_code=500, detail="Internal server error")
db.add(target_dataset)
await db.flush() # 获取 target_dataset.id
return target_dataset
@router.get("", response_model=StandardResponse[PagedRatioTaskResponse], status_code=200)

View File

@@ -2,10 +2,36 @@ from typing import List, Optional, Dict, Any
from datetime import datetime
from pydantic import BaseModel, Field, field_validator
from app.core.logging import get_logger
from app.module.shared.schema.common import TaskStatus
logger = get_logger(__name__)
class FilterCondition(BaseModel):
    """Optional per-dataset filters applied when picking files for a ratio task.

    Attributes:
        date_range: Number of days, as a numeric string (alias ``dateRange``).
            Empty / ``None`` means no date filter.
        label: Tag label a file must carry. Empty / ``None`` means no label filter.
    """

    # Allow population by field name (`date_range`) in addition to the alias
    # (`dateRange`) when the model is constructed programmatically.
    model_config = {"validate_by_name": True}

    date_range: Optional[str] = Field(None, description="数据范围", alias="dateRange")
    label: Optional[str] = Field(None, description="标签")

    @field_validator("date_range")
    @classmethod
    def validate_date_range(cls, v: Optional[str]) -> Optional[str]:
        """Ensure ``date_range`` is a numeric string when provided.

        Raises:
            ValueError: if the value cannot be parsed as an integer.
        """
        if not v:
            return v
        try:
            int(v)
        except (ValueError, TypeError) as e:
            # Chain the original error so the parse failure stays visible.
            raise ValueError("date_range must be a numeric string") from e
        return v
class RatioConfigItem(BaseModel):
dataset_id: str = Field(..., alias="datasetId", description="数据集id")
counts: str = Field(..., description="数量")
filter_conditions: str = Field(..., description="过滤条件")
filter_conditions: FilterCondition = Field(..., alias="filterConditions", description="过滤条件")
@field_validator("counts")
@classmethod
@@ -22,17 +48,8 @@ class CreateRatioTaskRequest(BaseModel):
name: str = Field(..., description="名称")
description: Optional[str] = Field(None, description="描述")
totals: str = Field(..., description="目标数量")
ratio_method: str = Field(..., description="配比方式", alias="ratio_method")
config: List[RatioConfigItem] = Field(..., description="配比设置列表")
@field_validator("ratio_method")
@classmethod
def validate_ratio_method(cls, v: str) -> str:
allowed = {"TAG", "DATASET"}
if v not in allowed:
raise ValueError(f"ratio_method must be one of {allowed}")
return v
@field_validator("totals")
@classmethod
def validate_totals(cls, v: str) -> str:
@@ -58,8 +75,7 @@ class CreateRatioTaskResponse(BaseModel):
name: str
description: Optional[str] = None
totals: int
ratio_method: str
status: str
status: TaskStatus
# echoed config
config: List[RatioConfigItem]
# created dataset

View File

@@ -14,6 +14,8 @@ from app.db.models.ratio_task import RatioInstance, RatioRelation
from app.db.models import Dataset, DatasetFiles
from app.db.session import AsyncSessionLocal
from app.module.dataset.schema.dataset_file import DatasetFileTag
from app.module.shared.schema import TaskStatus
from app.module.synthesis.schema.ratio_task import FilterCondition
logger = get_logger(__name__)
@@ -30,7 +32,6 @@ class RatioTaskService:
name: str,
description: Optional[str],
totals: int,
ratio_method: str,
config: List[Dict[str, Any]],
target_dataset_id: Optional[str] = None,
) -> RatioInstance:
@@ -38,12 +39,11 @@ class RatioTaskService:
config item format: {"dataset_id": str, "counts": int, "filter_conditions": str}
"""
logger.info(f"Creating ratio task: name={name}, method={ratio_method}, totals={totals}, items={len(config or [])}")
logger.info(f"Creating ratio task: name={name}, totals={totals}, items={len(config or [])}")
instance = RatioInstance(
name=name,
description=description,
ratio_method=ratio_method,
totals=totals,
target_dataset_id=target_dataset_id,
status="PENDING",
@@ -56,8 +56,12 @@ class RatioTaskService:
ratio_instance_id=instance.id,
source_dataset_id=item.get("dataset_id"),
counts=int(item.get("counts", 0)),
filter_conditions=item.get("filter_conditions"),
filter_conditions=json.dumps({
'date_range': item.get("filter_conditions").date_range,
'label': item.get("filter_conditions").label,
})
)
logger.info(f"Relation created: {relation.id}, {relation}, {item}, {config}")
self.db.add(relation)
await self.db.commit()
@@ -97,21 +101,44 @@ class RatioTaskService:
relations: List[RatioRelation] = list(rel_res.scalars().all())
# Mark running
instance.status = "RUNNING"
if instance.ratio_method not in {"DATASET", "TAG"}:
logger.info(f"Instance {instance_id} ratio_method={instance.ratio_method} not supported yet")
instance.status = "SUCCESS"
return
instance.status = TaskStatus.RUNNING.name
# Load target dataset
ds_res = await session.execute(select(Dataset).where(Dataset.id == instance.target_dataset_id))
target_ds: Optional[Dataset] = ds_res.scalar_one_or_none()
if not target_ds:
logger.error(f"Target dataset not found for instance {instance_id}")
instance.status = "FAILED"
instance.status = TaskStatus.FAILED.name
return
added_count, added_size = await RatioTaskService.handle_ratio_relations(relations,session, target_ds)
# Update target dataset statistics
target_ds.file_count = (target_ds.file_count or 0) + added_count # type: ignore
target_ds.size_bytes = (target_ds.size_bytes or 0) + added_size # type: ignore
# If target dataset has files, mark it ACTIVE
if (target_ds.file_count or 0) > 0: # type: ignore
target_ds.status = "ACTIVE"
# Done
instance.status = TaskStatus.COMPLETED.name
logger.info(f"Dataset ratio execution completed: instance={instance_id}, files={added_count}, size={added_size}, {instance.status}")
except Exception as e:
logger.exception(f"Dataset ratio execution failed for {instance_id}: {e}")
try:
# Try mark failed
inst_res = await session.execute(select(RatioInstance).where(RatioInstance.id == instance_id))
instance = inst_res.scalar_one_or_none()
if instance:
instance.status = TaskStatus.FAILED.name
finally:
pass
finally:
await session.commit()
@staticmethod
async def handle_ratio_relations(relations: list[RatioRelation], session, target_ds: Dataset) -> tuple[int, int]:
# Preload existing target file paths for deduplication
existing_path_rows = await session.execute(
select(DatasetFiles.file_path).where(DatasetFiles.dataset_id == target_ds.id)
@@ -125,20 +152,7 @@ class RatioTaskService:
if not rel.source_dataset_id or not rel.counts or rel.counts <= 0:
continue
# Fetch all files for the source dataset (ACTIVE only)
files_res = await session.execute(
select(DatasetFiles).where(
DatasetFiles.dataset_id == rel.source_dataset_id,
DatasetFiles.status == "ACTIVE",
)
)
files = list(files_res.scalars().all())
# TAG mode: filter by tags according to relation.filter_conditions
if instance.ratio_method == "TAG":
required_tags = RatioTaskService._parse_required_tags(rel.filter_conditions)
if required_tags:
files = [f for f in files if RatioTaskService._file_contains_tags(f, required_tags)]
files = await RatioTaskService.get_files(rel, session)
if not files:
continue
@@ -148,28 +162,28 @@ class RatioTaskService:
# Copy into target dataset with de-dup by target path
for f in chosen:
await RatioTaskService.handle_selected_file(existing_paths, f, session, target_ds)
added_count += 1
added_size += int(f.file_size or 0)
# Periodically flush to avoid huge transactions
await session.flush()
return added_count, added_size
@staticmethod
async def handle_selected_file(existing_paths: set[Any], f, session, target_ds: Dataset):
src_path = f.file_path
new_path = src_path
needs_copy = False
src_prefix = f"/dataset/{rel.source_dataset_id}"
if isinstance(src_path, str) and src_path.startswith(src_prefix):
dst_prefix = f"/dataset/{target_ds.id}"
new_path = src_path.replace(src_prefix, dst_prefix, 1)
needs_copy = True
file_name = RatioTaskService.get_new_file_name(dst_prefix, existing_paths, f)
# De-dup by target path
if new_path in existing_paths:
continue
# Perform copy only when needed
if needs_copy:
new_path = dst_prefix + file_name
dst_dir = os.path.dirname(new_path)
await asyncio.to_thread(os.makedirs, dst_dir, exist_ok=True)
await asyncio.to_thread(shutil.copy2, src_path, new_path)
new_file = DatasetFiles(
dataset_id=target_ds.id, # type: ignore
file_name=f.file_name,
file_name=file_name,
file_path=new_path,
file_type=f.file_type,
file_size=f.file_size,
@@ -180,67 +194,101 @@ class RatioTaskService:
)
session.add(new_file)
existing_paths.add(new_path)
added_count += 1
added_size += int(f.file_size or 0)
# Periodically flush to avoid huge transactions
await session.flush()
@staticmethod
def get_new_file_name(dst_prefix: str, existing_paths: set[Any], f) -> str:
file_name = f.file_name
new_path = dst_prefix + file_name
# Update target dataset statistics
target_ds.file_count = (target_ds.file_count or 0) + added_count # type: ignore
target_ds.size_bytes = (target_ds.size_bytes or 0) + added_size # type: ignore
# If target dataset has files, mark it ACTIVE
if (target_ds.file_count or 0) > 0: # type: ignore
target_ds.status = "ACTIVE"
# Handle file path conflicts by appending a number to the filename
if new_path in existing_paths:
file_name_base, file_ext = os.path.splitext(file_name)
counter = 1
original_file_name = file_name
while new_path in existing_paths:
file_name = f"{file_name_base}_{counter}{file_ext}"
new_path = f"{dst_prefix}{file_name}"
counter += 1
if counter > 1000: # Safety check to prevent infinite loops
logger.error(f"Could not find unique filename for {original_file_name} after 1000 attempts")
break
return file_name
# Done
instance.status = "SUCCESS"
logger.info(f"Dataset ratio execution completed: instance={instance_id}, files={added_count}, size={added_size}")
@staticmethod
async def get_files(rel: RatioRelation, session) -> list[Any]:
# Fetch all files for the source dataset (ACTIVE only)
files_res = await session.execute(
select(DatasetFiles).where(
DatasetFiles.dataset_id == rel.source_dataset_id,
DatasetFiles.status == "ACTIVE",
)
)
files = list(files_res.scalars().all())
except Exception as e:
logger.exception(f"Dataset ratio execution failed for {instance_id}: {e}")
try:
# Try mark failed
inst_res = await session.execute(select(RatioInstance).where(RatioInstance.id == instance_id))
instance = inst_res.scalar_one_or_none()
if instance:
instance.status = "FAILED"
finally:
pass
finally:
await session.commit()
# TAG mode: filter by tags according to relation.filter_conditions
conditions = RatioTaskService._parse_conditions(rel.filter_conditions)
if conditions:
files = [f for f in files if RatioTaskService._filter_file(f, conditions)]
return files
# ------------------------- helpers for TAG filtering ------------------------- #
@staticmethod
def _parse_required_tags(conditions: Optional[str]) -> set[str]:
"""Parse filter_conditions into a set of required tag strings.
def _parse_conditions(conditions: Optional[str]) -> Optional[FilterCondition]:
"""Parse filter_conditions JSON string into a FilterCondition object.
Supports simple separators: comma, semicolon, space. Empty/None -> empty set.
Args:
conditions: JSON string containing filter conditions
Returns:
FilterCondition object if conditions is not None/empty, otherwise None
"""
if not conditions:
return set()
return None
try:
data = json.loads(conditions)
required_tags = set()
if data.get("label"):
required_tags.add(data["label"])
return required_tags
return FilterCondition(**data)
except json.JSONDecodeError as e:
logger.error(f"Failed to parse filter conditions: {e}")
return None
except Exception as e:
logger.error(f"Error creating FilterCondition: {e}")
return None
@staticmethod
def _filter_file(file: DatasetFiles, conditions: FilterCondition) -> bool:
    """Return True when ``file`` satisfies every provided filter condition.

    Args:
        file: Candidate dataset file; ``tags_updated_at`` and ``tags`` are read.
        conditions: Parsed filter conditions; a falsy value accepts every file.

    Returns:
        False when the file's tags were last updated before the cut-off
        implied by ``conditions.date_range`` (in days), when that value is
        invalid, or when ``conditions.label`` is set and the file does not
        carry that tag. True otherwise.
    """
    if not conditions:
        return True
    logger.info(f"start filter file: {file}, conditions: {conditions}")

    # Check data range condition if provided
    if conditions.date_range:
        try:
            from datetime import datetime, timedelta
            data_range_days = int(conditions.date_range)
            if data_range_days > 0:
                cutoff_date = datetime.now() - timedelta(days=data_range_days)
                # Files without a tags_updated_at value are not excluded here.
                if file.tags_updated_at and file.tags_updated_at < cutoff_date:
                    return False
        except (ValueError, TypeError) as e:
            # The error is interpolated into the message: passing it as an
            # extra positional argument without a %-placeholder makes the
            # log record fail to format.
            logger.warning(f"Invalid data_range value: {conditions.date_range}: {e}")
            return False

    # Check label condition if provided
    if conditions.label:
        tags = file.tags
        if not tags:
            return False
        try:
            # tags could be a list of strings or list of objects with 'name'
            tag_names = RatioTaskService.get_all_tags(tags)
            return conditions.label in tag_names
        except Exception:
            # logger.exception already attaches the active traceback.
            logger.exception(f"Failed to get tags for {file}")
            return False

    return True
@staticmethod
def get_all_tags(tags) -> set[str]:
"""获取所有处理后的标签字符串列表"""