From cddfe9b149413455d60aca980a50888175dd2250 Mon Sep 17 00:00:00 2001
From: hefanli <76611805+hefanli@users.noreply.github.com>
Date: Thu, 20 Nov 2025 18:50:51 +0800
Subject: [PATCH] =?UTF-8?q?feature:=20=E6=95=B0=E6=8D=AE=E9=85=8D=E6=AF=94?=
=?UTF-8?q?=E5=A2=9E=E5=8A=A0=E9=80=9A=E8=BF=87=E6=9B=B4=E6=96=B0=E6=97=B6?=
=?UTF-8?q?=E9=97=B4=E6=9D=A5=E9=85=8D=E7=BD=AE=20(#95)?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
* feature: 数据配比增加通过更新时间来配置
* fix: 修复配比时间参数传递的问题
---
.../RatioTask/Create/CreateRatioTask.tsx | 28 +-
.../Create/components/BasicInformation.tsx | 2 +-
.../Create/components/RatioConfig.tsx | 375 +++++++++---------
.../Create/components/RatioTransfer.tsx | 169 --------
.../Create/components/SelectDataset.tsx | 66 ++-
.../app/module/shared/schema/__init__.py | 8 +-
.../app/module/shared/schema/common.py | 13 +-
.../module/synthesis/interface/ratio_task.py | 88 ++--
.../app/module/synthesis/schema/ratio_task.py | 40 +-
.../module/synthesis/service/ratio_task.py | 264 +++++++-----
10 files changed, 458 insertions(+), 595 deletions(-)
delete mode 100644 frontend/src/pages/RatioTask/Create/components/RatioTransfer.tsx
diff --git a/frontend/src/pages/RatioTask/Create/CreateRatioTask.tsx b/frontend/src/pages/RatioTask/Create/CreateRatioTask.tsx
index 2c94f0d..e27b4aa 100644
--- a/frontend/src/pages/RatioTask/Create/CreateRatioTask.tsx
+++ b/frontend/src/pages/RatioTask/Create/CreateRatioTask.tsx
@@ -7,7 +7,6 @@ import { useNavigate } from "react-router";
import SelectDataset from "@/pages/RatioTask/Create/components/SelectDataset.tsx";
import BasicInformation from "@/pages/RatioTask/Create/components/BasicInformation.tsx";
import RatioConfig from "@/pages/RatioTask/Create/components/RatioConfig.tsx";
-import RatioTransfer from "./components/RatioTransfer";
export default function CreateRatioTask() {
const navigate = useNavigate();
@@ -36,27 +35,12 @@ export default function CreateRatioTask() {
message.error("请配置配比项");
return;
}
- // Build request payload
- const ratio_method =
- ratioTaskForm.ratioType === "dataset" ? "DATASET" : "TAG";
const totals = String(values.totalTargetCount);
const config = ratioTaskForm.ratioConfigs.map((c) => {
- if (ratio_method === "DATASET") {
- return {
- datasetId: String(c.source),
- counts: String(c.quantity ?? 0),
- filter_conditions: "",
- };
- }
- // TAG mode: source key like `${datasetId}_${label}`
- const source = String(c.source || "");
- const idx = source.indexOf("_");
- const datasetId = idx > 0 ? source.slice(0, idx) : source;
- const label = idx > 0 ? source.slice(idx + 1) : "";
return {
- datasetId,
+ datasetId: c.id,
counts: String(c.quantity ?? 0),
- filter_conditions: label ? JSON.stringify({ label }) : "",
+ filterConditions: { label: c.labelFilter, dateRange: String(c.dateRange ?? 0)},
};
});
@@ -65,7 +49,6 @@ export default function CreateRatioTask() {
name: values.name,
description: values.description,
totals,
- ratio_method,
config,
});
message.success("配比任务创建成功");
@@ -108,13 +91,6 @@ export default function CreateRatioTask() {
totalTargetCount={ratioTaskForm.totalTargetCount}
/>
- {/* */}
-
= ({
-
+
);
diff --git a/frontend/src/pages/RatioTask/Create/components/RatioConfig.tsx b/frontend/src/pages/RatioTask/Create/components/RatioConfig.tsx
index ea0ed32..98e5f78 100644
--- a/frontend/src/pages/RatioTask/Create/components/RatioConfig.tsx
+++ b/frontend/src/pages/RatioTask/Create/components/RatioConfig.tsx
@@ -1,7 +1,19 @@
import React, { useMemo, useState } from "react";
-import { Badge, Card, Input, Progress, Button, Divider } from "antd";
-import { BarChart3 } from "lucide-react";
+import { Badge, Card, Input, Progress, Button, DatePicker, Select } from "antd";
+import { BarChart3, Filter, Clock } from "lucide-react";
import type { Dataset } from "@/pages/DataManagement/dataset.model.ts";
+import dayjs from 'dayjs';
+
+const { RangePicker } = DatePicker;
+const { Option } = Select;
+
+const TIME_RANGE_OPTIONS = [
+ { label: '最近1天', value: 1 },
+ { label: '最近3天', value: 3 },
+ { label: '最近7天', value: 7 },
+ { label: '最近15天', value: 15 },
+ { label: '最近30天', value: 30 },
+];
interface RatioConfigItem {
id: string;
@@ -10,6 +22,8 @@ interface RatioConfigItem {
quantity: number;
percentage: number;
source: string;
+ labelFilter?: string;
+ dateRange?: string;
}
interface RatioConfigProps {
@@ -22,14 +36,18 @@ interface RatioConfigProps {
}
const RatioConfig: React.FC = ({
- ratioType,
- selectedDatasets,
- datasets,
- totalTargetCount,
- distributions,
- onChange,
-}) => {
+ ratioType,
+ selectedDatasets,
+ datasets,
+ totalTargetCount,
+ distributions,
+ onChange,
+ }) => {
const [ratioConfigs, setRatioConfigs] = useState([]);
+ const [datasetFilters, setDatasetFilters] = useState>({});
// 配比项总数
const totalConfigured = useMemo(
@@ -37,6 +55,36 @@ const RatioConfig: React.FC = ({
[ratioConfigs]
);
+ // 获取数据集的标签列表
+ const getDatasetLabels = (datasetId: string): string[] => {
+ const dist = distributions[String(datasetId)] || {};
+ return Object.keys(dist);
+ };
+
+ // 自动平均分配
+ const generateAutoRatio = () => {
+ const selectedCount = selectedDatasets.length;
+ if (selectedCount === 0) return;
+ const baseQuantity = Math.floor(totalTargetCount / selectedCount);
+ const remainder = totalTargetCount % selectedCount;
+ const newConfigs = selectedDatasets.map((datasetId, index) => {
+ const dataset = datasets.find((d) => String(d.id) === datasetId);
+ const quantity = baseQuantity + (index < remainder ? 1 : 0);
+ return {
+ id: datasetId,
+ name: dataset?.name || datasetId,
+ type: ratioType,
+ quantity,
+ percentage: Math.round((quantity / totalTargetCount) * 100),
+ source: datasetId,
+ labelFilter: datasetFilters[datasetId]?.labelFilter,
+ dateRange: datasetFilters[datasetId]?.dateRange,
+ };
+ });
+ setRatioConfigs(newConfigs);
+ onChange?.(newConfigs);
+ };
+
// 更新数据集配比项
const updateDatasetQuantity = (datasetId: string, quantity: number) => {
setRatioConfigs((prev) => {
@@ -55,6 +103,8 @@ const RatioConfig: React.FC = ({
quantity: Math.min(quantity, totalTargetCount - totalOtherQuantity),
percentage: Math.round((quantity / totalTargetCount) * 100),
source: datasetId,
+ labelFilter: datasetFilters[datasetId]?.labelFilter,
+ dateRange: datasetFilters[datasetId]?.dateRange,
};
let newConfigs;
@@ -69,78 +119,85 @@ const RatioConfig: React.FC = ({
});
};
- // 自动平均分配
- const generateAutoRatio = () => {
- const selectedCount = selectedDatasets.length;
- if (selectedCount === 0) return;
- const baseQuantity = Math.floor(totalTargetCount / selectedCount);
- const remainder = totalTargetCount % selectedCount;
- const newConfigs = selectedDatasets.map((datasetId, index) => {
- const dataset = datasets.find((d) => String(d.id) === datasetId);
- const quantity = baseQuantity + (index < remainder ? 1 : 0);
- return {
- id: datasetId,
- name: dataset?.name || datasetId,
- type: ratioType,
- quantity,
- percentage: Math.round((quantity / totalTargetCount) * 100),
- source: datasetId,
- };
- });
- setRatioConfigs(newConfigs);
- onChange?.(newConfigs);
- };
-
- // 标签模式下,更新某数据集的某个标签的数量
- const updateLabelQuantity = (
- datasetId: string,
- label: string,
- quantity: number
- ) => {
- const sourceKey = `${datasetId}_${label}`;
- setRatioConfigs((prev) => {
- const existingIndex = prev.findIndex((c) => c.source === sourceKey);
- const totalOtherQuantity = prev
- .filter((c) => c.source !== sourceKey)
- .reduce((sum, c) => sum + c.quantity, 0);
- const dist = distributions[datasetId] || {};
- const labelMax = dist[label] ?? Infinity;
- const cappedQuantity = Math.max(
- 0,
- Math.min(quantity, totalTargetCount - totalOtherQuantity, labelMax)
- );
- const newConfig: RatioConfigItem = {
- id: sourceKey,
- name: label,
- type: "label",
- quantity: cappedQuantity,
- percentage: Math.round((cappedQuantity / totalTargetCount) * 100),
- source: sourceKey,
- };
- let newConfigs;
- if (existingIndex >= 0) {
- newConfigs = [...prev];
- newConfigs[existingIndex] = newConfig;
- } else {
- newConfigs = [...prev, newConfig];
+ // 更新筛选条件
+ const updateFilters = (datasetId: string, updates: {
+ labelFilter?: string;
+ dateRange?: [string, string];
+ }) => {
+ setDatasetFilters(prev => ({
+ ...prev,
+ [datasetId]: {
+ ...prev[datasetId],
+ ...updates,
}
- onChange?.(newConfigs);
- return newConfigs;
- });
+ }));
};
- // 选中数据集变化时,移除未选中的配比项
+ // 渲染筛选器
+ const renderFilters = (datasetId: string) => {
+ const labels = getDatasetLabels(datasetId);
+ const config = ratioConfigs.find(c => c.source === datasetId);
+ const filters = datasetFilters[datasetId] || {};
+
+ return (
+
+
+
+ 筛选条件
+
+
+
+
+
标签筛选
+
+
+
+
+
标签更新时间
+
+
+
+
+ );
+ };
+
+ // 选中数据集变化时,初始化筛选条件
React.useEffect(() => {
- setRatioConfigs((prev) => {
- const next = prev.filter((c) => {
- const id = String(c.source);
- const dsId = id.includes("_") ? id.split("_")[0] : id;
- return selectedDatasets.includes(dsId);
- });
- if (next !== prev) onChange?.(next);
- return next;
+ const initialFilters: Record = {};
+ selectedDatasets.forEach(datasetId => {
+ const config = ratioConfigs.find(c => c.source === datasetId);
+ if (config) {
+ initialFilters[datasetId] = {
+ labelFilter: config.labelFilter,
+ dateRange: config.dateRange,
+ };
+ }
});
- // eslint-disable-next-line
+ setDatasetFilters(prev => ({ ...prev, ...initialFilters }));
}, [selectedDatasets]);
return (
@@ -148,7 +205,7 @@ const RatioConfig: React.FC = ({
配比配置
-
+
(已配置:{totalConfigured}/{totalTargetCount}条)
@@ -170,41 +227,36 @@ const RatioConfig: React.FC = ({
{/* 配比预览 */}
{ratioConfigs.length > 0 && (
-
-
-
-
- 总配比数量:
-
- {ratioConfigs
- .reduce((sum, config) => sum + config.quantity, 0)
- .toLocaleString()}
-
-
-
- 目标数量:
-
- {totalTargetCount.toLocaleString()}
-
-
-
- 配比项目:
-
- {ratioConfigs.length}个
-
-
+
+
+
+ 总配比数量:
+
+ {ratioConfigs
+ .reduce((sum, config) => sum + config.quantity, 0)
+ .toLocaleString()}
+
+
+
+ 目标数量:
+
+ {totalTargetCount.toLocaleString()}
+
)}
-
+
+
{selectedDatasets.map((datasetId) => {
const dataset = datasets.find((d) => String(d.id) === datasetId);
const config = ratioConfigs.find((c) => c.source === datasetId);
const currentQuantity = config?.quantity || 0;
+
if (!dataset) return null;
+
return (
-
+
@@ -216,97 +268,36 @@ const RatioConfig: React.FC = ({
{config?.percentage || 0}%
- {ratioType === "dataset" ? (
-
-
- 数量:
-
- updateDatasetQuantity(
- datasetId,
- Number(e.target.value)
- )
- }
- style={{ width: 80 }}
- min={0}
- max={Math.min(
- dataset.fileCount || 0,
- totalTargetCount
- )}
- />
- 条
-
-
-
- ) : (
-
- {!distributions[String(dataset.id)] ? (
-
- 加载标签分布...
-
- ) : Object.entries(distributions[String(dataset.id)])
- .length === 0 ? (
-
- 该数据集暂无标签
-
- ) : (
-
- {Object.entries(
- distributions[String(dataset.id)]
- ).map(([label, count]) => {
- const sourceKey = `${datasetId}_${label}`;
- const labelConfig = ratioConfigs.find(
- (c) => c.source === sourceKey
- );
- const labelQuantity = labelConfig?.quantity || 0;
- return (
-
-
- {label}
-
- {count}条
-
-
-
- 数量:
-
- updateLabelQuantity(
- datasetId,
- label,
- Number(e.target.value)
- )
- }
- style={{ width: 80 }}
- min={0}
- max={Math.min(
- Number(count) || 0,
- totalTargetCount
- )}
- />
-
- 条
-
-
-
- );
- })}
-
+
+ {/* 筛选条件 */}
+ {renderFilters(datasetId)}
+
+
+ 数量:
+
+ updateDatasetQuantity(
+ datasetId,
+ Number(e.target.value)
+ )
+ }
+ style={{ width: 100 }}
+ min={0}
+ max={Math.min(
+ dataset.fileCount || 0,
+ totalTargetCount
)}
-
- )}
+ />
+
条
+
+
);
})}
diff --git a/frontend/src/pages/RatioTask/Create/components/RatioTransfer.tsx b/frontend/src/pages/RatioTask/Create/components/RatioTransfer.tsx
deleted file mode 100644
index d182725..0000000
--- a/frontend/src/pages/RatioTask/Create/components/RatioTransfer.tsx
+++ /dev/null
@@ -1,169 +0,0 @@
-import React, { useMemo } from "react";
-import { Table } from "antd";
-import { TransferItem } from "antd/es/transfer";
-import RatioConfig from "./RatioConfig";
-import useFetchData from "@/hooks/useFetchData";
-import { queryDatasetsUsingGet } from "@/pages/DataManagement/dataset.api";
-import {
- datasetTypeMap,
- mapDataset,
-} from "@/pages/DataManagement/dataset.const";
-import { SearchControls } from "@/components/SearchControls";
-
-const leftColumns = [
- {
- dataIndex: "name",
- title: "名称",
- ellipsis: true,
- },
- {
- dataIndex: "datasetType",
- title: "类型",
- ellipsis: true,
- width: 100,
- render: (type: string) => datasetTypeMap[type].label,
- },
- {
- dataIndex: "size",
- title: "大小",
- width: 100,
- ellipsis: true,
- },
-];
-
-export default function RatioTransfer(props: {
- distributions: Record>;
- ratioTaskForm: any;
- updateRatioConfig: (datasetId: string, quantity: number) => void;
- updateLabelRatioConfig: (
- datasetId: string,
- label: string,
- quantity: number
- ) => void;
-}) {
- const {
- updateLabelRatioConfig,
- updateRatioConfig,
- ratioTaskForm,
- distributions,
- } = props;
- const {
- tableData: datasets,
- loading,
- pagination,
- searchParams,
- setSearchParams,
- handleFiltersChange,
- } = useFetchData(queryDatasetsUsingGet, mapDataset);
-
- const [selectedDatasets, setSelectedDatasets] = React.useState<
- TransferItem[]
- >([]);
-
- const selectedRowKeys = useMemo(() => {
- return selectedDatasets.map((item) => item.key);
- }, [selectedDatasets]);
-
- const [listDisabled, setListDisabled] = React.useState(false);
-
- const generateAutoRatio = () => {
- const selectedCount = ratioTaskForm.selectedDatasets.length;
- if (selectedCount === 0) return;
-
- const baseQuantity = Math.floor(
- ratioTaskForm.totalTargetCount / selectedCount
- );
- const remainder = ratioTaskForm.totalTargetCount % selectedCount;
-
- const newConfigs = ratioTaskForm.selectedDatasets.map(
- (datasetId, index) => {
- const quantity = baseQuantity + (index < remainder ? 1 : 0);
- return {
- id: datasetId,
- name: datasetId,
- type: ratioTaskForm.ratioType,
- quantity,
- percentage: Math.round(
- (quantity / ratioTaskForm.totalTargetCount) * 100
- ),
- source: datasetId,
- };
- }
- );
- setRatioTaskForm((prev) => ({ ...prev, ratioConfigs: newConfigs }));
- };
-
- return (
-
-
-
{`${selectedDatasets.length} / ${datasets.length} 项`}
-
- setSearchParams({ ...searchParams, keyword })
- }
- searchPlaceholder="搜索数据集名称..."
- filters={[
- {
- key: "type",
- label: "数据集类型",
- options: [
- { value: "dataset", label: "按数据集" },
- { value: "tag", label: "按标签" },
- ],
- },
- ]}
- onFiltersChange={handleFiltersChange}
- onClearFilters={() =>
- setSearchParams({ ...searchParams, filter: {} })
- }
- showViewToggle={false}
- showReload={false}
- className="m-4"
- />
- {
- setSelectedDatasets(selectedRows);
- },
- selectedRowKeys,
- selections: [
- Table.SELECTION_ALL,
- Table.SELECTION_INVERT,
- Table.SELECTION_NONE,
- ],
- }}
- columns={leftColumns}
- dataSource={datasets}
- loading={loading}
- pagination={pagination}
- size="small"
- rowKey="id"
- style={{ pointerEvents: listDisabled ? "none" : undefined }}
- onRow={(record) => ({
- onClick: () => {
- if (record.disabled || listDisabled) {
- return;
- }
- setSelectedDatasets((prev) => {
- if (prev.includes(record.key)) {
- return prev.filter((k) => k !== record.key);
- }
- return [...prev, record.key];
- });
- },
- })}
- />
-
-
-
-
-
- );
-}
diff --git a/frontend/src/pages/RatioTask/Create/components/SelectDataset.tsx b/frontend/src/pages/RatioTask/Create/components/SelectDataset.tsx
index c99018e..ef04988 100644
--- a/frontend/src/pages/RatioTask/Create/components/SelectDataset.tsx
+++ b/frontend/src/pages/RatioTask/Create/components/SelectDataset.tsx
@@ -1,5 +1,5 @@
import React, { useEffect, useState } from "react";
-import { Badge, Button, Card, Checkbox, Input, Pagination, Select } from "antd";
+import { Badge, Button, Card, Checkbox, Input, Pagination } from "antd";
import { Search as SearchIcon } from "lucide-react";
import type { Dataset } from "@/pages/DataManagement/dataset.model.ts";
import {
@@ -10,8 +10,6 @@ import {
interface SelectDatasetProps {
selectedDatasets: string[];
- ratioType: "dataset" | "label";
- onRatioTypeChange: (val: "dataset" | "label") => void;
onSelectedDatasetsChange: (next: string[]) => void;
onDistributionsChange?: (
next: Record>
@@ -21,8 +19,6 @@ interface SelectDatasetProps {
const SelectDataset: React.FC = ({
selectedDatasets,
- ratioType,
- onRatioTypeChange,
onSelectedDatasetsChange,
onDistributionsChange,
onDatasetsChange,
@@ -62,7 +58,7 @@ const SelectDataset: React.FC = ({
// Fetch label distributions when in label mode
useEffect(() => {
const fetchDistributions = async () => {
- if (ratioType !== "label" || !datasets?.length) return;
+ if (!datasets?.length) return;
const idsToFetch = datasets
.map((d) => String(d.id))
.filter((id) => !distributions[id]);
@@ -147,7 +143,7 @@ const SelectDataset: React.FC = ({
};
fetchDistributions();
// eslint-disable-next-line react-hooks/exhaustive-deps
- }, [ratioType, datasets]);
+ }, [datasets]);
const onToggleDataset = (datasetId: string, checked: boolean) => {
if (checked) {
@@ -180,18 +176,6 @@ const SelectDataset: React.FC = ({
-
- 配比方式:
-
}
placeholder="搜索数据集"
@@ -239,32 +223,30 @@ const SelectDataset: React.FC
= ({
{dataset.fileCount}条
{dataset.size}
- {ratioType === "label" && (
-
- {distributions[idStr] ? (
- Object.entries(distributions[idStr]).length > 0 ? (
-
- {Object.entries(distributions[idStr])
- .slice(0, 8)
- .map(([tag, count]) => (
- {`${tag}: ${count}`}
- ))}
-
- ) : (
-
- 未检测到标签分布
-
- )
+
+ {distributions[idStr] ? (
+ Object.entries(distributions[idStr]).length > 0 ? (
+
+ {Object.entries(distributions[idStr])
+ .slice(0, 8)
+ .map(([tag, count]) => (
+ {`${tag}: ${count}`}
+ ))}
+
) : (
- 加载标签分布...
+ 未检测到标签分布
- )}
-
- )}
+ )
+ ) : (
+
+ 加载标签分布...
+
+ )}
+
diff --git a/runtime/datamate-python/app/module/shared/schema/__init__.py b/runtime/datamate-python/app/module/shared/schema/__init__.py
index 88cf3b0..f0c2765 100644
--- a/runtime/datamate-python/app/module/shared/schema/__init__.py
+++ b/runtime/datamate-python/app/module/shared/schema/__init__.py
@@ -1,11 +1,13 @@
from .common import (
BaseResponseModel,
StandardResponse,
- PaginatedData
+ PaginatedData,
+ TaskStatus
)
__all__ = [
"BaseResponseModel",
"StandardResponse",
- "PaginatedData"
-]
\ No newline at end of file
+ "PaginatedData",
+ "TaskStatus"
+]
diff --git a/runtime/datamate-python/app/module/shared/schema/common.py b/runtime/datamate-python/app/module/shared/schema/common.py
index c79231a..beaf1c7 100644
--- a/runtime/datamate-python/app/module/shared/schema/common.py
+++ b/runtime/datamate-python/app/module/shared/schema/common.py
@@ -1,8 +1,9 @@
"""
通用响应模型
"""
-from typing import Generic, TypeVar, Optional, List, Type
+from typing import Generic, TypeVar, List
from pydantic import BaseModel, Field
+from enum import Enum
# 定义泛型类型变量
T = TypeVar('T')
@@ -16,7 +17,7 @@ def to_camel(string: str) -> str:
class BaseResponseModel(BaseModel):
"""基础响应模型,启用别名生成器"""
-
+
class Config:
populate_by_name = True
alias_generator = to_camel
@@ -24,7 +25,7 @@ class BaseResponseModel(BaseModel):
class StandardResponse(BaseResponseModel, Generic[T]):
"""
标准API响应格式
-
+
所有API端点应返回此格式,确保响应的一致性
"""
code: int = Field(..., description="HTTP状态码")
@@ -42,3 +43,9 @@ class PaginatedData(BaseResponseModel, Generic[T]):
total_elements: int = Field(..., description="总条数")
total_pages: int = Field(..., description="总页数")
content: List[T] = Field(..., description="当前页数据")
+
+class TaskStatus(Enum):
+ PENDING = "PENDING"
+ RUNNING = "RUNNING"
+ COMPLETED = "COMPLETED"
+ FAILED = "FAILED"
diff --git a/runtime/datamate-python/app/module/synthesis/interface/ratio_task.py b/runtime/datamate-python/app/module/synthesis/interface/ratio_task.py
index acc262a..a67c5f7 100644
--- a/runtime/datamate-python/app/module/synthesis/interface/ratio_task.py
+++ b/runtime/datamate-python/app/module/synthesis/interface/ratio_task.py
@@ -12,7 +12,7 @@ from app.core.logging import get_logger
from app.db.models import Dataset
from app.db.session import get_db
from app.module.dataset import DatasetManagementService
-from app.module.shared.schema import StandardResponse
+from app.module.shared.schema import StandardResponse, TaskStatus
from app.module.synthesis.schema.ratio_task import (
CreateRatioTaskResponse,
CreateRatioTaskRequest,
@@ -49,52 +49,18 @@ async def create_ratio_task(
await valid_exists(db, req)
- # 创建目标数据集:名称使用“<任务名称>-配比生成-时间戳”
- target_dataset_name = f"{req.name}-配比生成-{datetime.now().strftime('%Y%m%d%H%M%S')}"
+ target_dataset = await create_target_dataset(db, req, source_types)
- target_type = get_target_dataset_type(source_types)
+ instance = await create_ratio_instance(db, req, target_dataset)
- target_dataset = Dataset(
- id=str(uuid.uuid4()),
- name=target_dataset_name,
- description=req.description or "",
- dataset_type=target_type,
- status="DRAFT",
- )
- target_dataset.path = f"/dataset/{target_dataset.id}"
- db.add(target_dataset)
- await db.flush() # 获取 target_dataset.id
-
- service = RatioTaskService(db)
- instance = await service.create_task(
- name=req.name,
- description=req.description,
- totals=int(req.totals),
- ratio_method=req.ratio_method,
- config=[
- {
- "dataset_id": item.dataset_id,
- "counts": int(item.counts),
- "filter_conditions": item.filter_conditions,
- }
- for item in req.config
- ],
- target_dataset_id=target_dataset.id,
- )
-
- # 异步执行配比任务(支持 DATASET / TAG)
asyncio.create_task(RatioTaskService.execute_dataset_ratio_task(instance.id))
- return StandardResponse(
- code=200,
- message="success",
- data=CreateRatioTaskResponse(
+ response_data = CreateRatioTaskResponse(
id=instance.id,
name=instance.name,
description=instance.description,
totals=instance.totals or 0,
- ratio_method=instance.ratio_method or req.ratio_method,
- status=instance.status or "PENDING",
+ status=instance.status or TaskStatus.PENDING.name,
config=req.config,
targetDataset=TargetDatasetInfo(
id=str(target_dataset.id),
@@ -103,6 +69,10 @@ async def create_ratio_task(
status=str(target_dataset.status),
)
)
+ return StandardResponse(
+ code=200,
+ message="success",
+ data=response_data
)
except HTTPException:
await db.rollback()
@@ -113,6 +83,46 @@ async def create_ratio_task(
raise HTTPException(status_code=500, detail="Internal server error")
+async def create_ratio_instance(db, req: CreateRatioTaskRequest, target_dataset: Dataset) -> RatioInstance:
+ service = RatioTaskService(db)
+ logger.info(f"create_ratio_instance: {req}")
+ instance = await service.create_task(
+ name=req.name,
+ description=req.description,
+ totals=int(req.totals),
+ config=[
+ {
+ "dataset_id": item.dataset_id,
+ "counts": int(item.counts),
+ "filter_conditions": item.filter_conditions,
+ }
+ for item in req.config
+ ],
+ target_dataset_id=target_dataset.id,
+ )
+ return instance
+
+
+async def create_target_dataset(db, req: CreateRatioTaskRequest, source_types: set[str]) -> Dataset:
+ # 创建目标数据集:名称使用“<任务名称>-时间戳”
+ target_dataset_name = f"{req.name}-{datetime.now().strftime('%Y%m%d%H%M%S')}"
+
+ target_type = get_target_dataset_type(source_types)
+ target_dataset_id = uuid.uuid4()
+
+ target_dataset = Dataset(
+ id=str(target_dataset_id),
+ name=target_dataset_name,
+ description=req.description or "",
+ dataset_type=target_type,
+ status="DRAFT",
+ path=f"/dataset/{target_dataset_id}",
+ )
+ db.add(target_dataset)
+ await db.flush() # 获取 target_dataset.id
+ return target_dataset
+
+
@router.get("", response_model=StandardResponse[PagedRatioTaskResponse], status_code=200)
async def list_ratio_tasks(
page: int = 1,
diff --git a/runtime/datamate-python/app/module/synthesis/schema/ratio_task.py b/runtime/datamate-python/app/module/synthesis/schema/ratio_task.py
index 743eb00..a781829 100644
--- a/runtime/datamate-python/app/module/synthesis/schema/ratio_task.py
+++ b/runtime/datamate-python/app/module/synthesis/schema/ratio_task.py
@@ -2,10 +2,36 @@ from typing import List, Optional, Dict, Any
from datetime import datetime
from pydantic import BaseModel, Field, field_validator
+from app.core.logging import get_logger
+from app.module.shared.schema.common import TaskStatus
+
+logger = get_logger(__name__)
+
+class FilterCondition(BaseModel):
+ date_range: Optional[str] = Field(None, description="数据范围", alias="dateRange")
+ label: Optional[str] = Field(None, description="标签")
+
+ @field_validator("date_range")
+ @classmethod
+ def validate_date_range(cls, v: Optional[str]) -> Optional[str]:
+ # ensure it's a numeric string if provided
+ if not v:
+ return v
+ try:
+ int(v)
+ return v
+ except (ValueError, TypeError) as e:
+ raise ValueError("date_range must be a numeric string")
+
+ class Config:
+ # allow population by field name when constructing model programmatically
+ validate_by_name = True
+
+
class RatioConfigItem(BaseModel):
dataset_id: str = Field(..., alias="datasetId", description="数据集id")
counts: str = Field(..., description="数量")
- filter_conditions: str = Field(..., description="过滤条件")
+ filter_conditions: FilterCondition = Field(..., alias="filterConditions", description="过滤条件")
@field_validator("counts")
@classmethod
@@ -22,17 +48,8 @@ class CreateRatioTaskRequest(BaseModel):
name: str = Field(..., description="名称")
description: Optional[str] = Field(None, description="描述")
totals: str = Field(..., description="目标数量")
- ratio_method: str = Field(..., description="配比方式", alias="ratio_method")
config: List[RatioConfigItem] = Field(..., description="配比设置列表")
- @field_validator("ratio_method")
- @classmethod
- def validate_ratio_method(cls, v: str) -> str:
- allowed = {"TAG", "DATASET"}
- if v not in allowed:
- raise ValueError(f"ratio_method must be one of {allowed}")
- return v
-
@field_validator("totals")
@classmethod
def validate_totals(cls, v: str) -> str:
@@ -58,8 +75,7 @@ class CreateRatioTaskResponse(BaseModel):
name: str
description: Optional[str] = None
totals: int
- ratio_method: str
- status: str
+ status: TaskStatus
# echoed config
config: List[RatioConfigItem]
# created dataset
diff --git a/runtime/datamate-python/app/module/synthesis/service/ratio_task.py b/runtime/datamate-python/app/module/synthesis/service/ratio_task.py
index c26c987..bd55061 100644
--- a/runtime/datamate-python/app/module/synthesis/service/ratio_task.py
+++ b/runtime/datamate-python/app/module/synthesis/service/ratio_task.py
@@ -14,6 +14,8 @@ from app.db.models.ratio_task import RatioInstance, RatioRelation
from app.db.models import Dataset, DatasetFiles
from app.db.session import AsyncSessionLocal
from app.module.dataset.schema.dataset_file import DatasetFileTag
+from app.module.shared.schema import TaskStatus
+from app.module.synthesis.schema.ratio_task import FilterCondition
logger = get_logger(__name__)
@@ -30,7 +32,6 @@ class RatioTaskService:
name: str,
description: Optional[str],
totals: int,
- ratio_method: str,
config: List[Dict[str, Any]],
target_dataset_id: Optional[str] = None,
) -> RatioInstance:
@@ -38,12 +39,11 @@ class RatioTaskService:
config item format: {"dataset_id": str, "counts": int, "filter_conditions": str}
"""
- logger.info(f"Creating ratio task: name={name}, method={ratio_method}, totals={totals}, items={len(config or [])}")
+ logger.info(f"Creating ratio task: name={name}, totals={totals}, items={len(config or [])}")
instance = RatioInstance(
name=name,
description=description,
- ratio_method=ratio_method,
totals=totals,
target_dataset_id=target_dataset_id,
status="PENDING",
@@ -56,8 +56,12 @@ class RatioTaskService:
ratio_instance_id=instance.id,
source_dataset_id=item.get("dataset_id"),
counts=int(item.get("counts", 0)),
- filter_conditions=item.get("filter_conditions"),
+ filter_conditions=json.dumps({
+ 'date_range': item.get("filter_conditions").date_range,
+ 'label': item.get("filter_conditions").label,
+ })
)
+ logger.info(f"Relation created: {relation.id}, {relation}, {item}, {config}")
self.db.add(relation)
await self.db.commit()
@@ -97,94 +101,17 @@ class RatioTaskService:
relations: List[RatioRelation] = list(rel_res.scalars().all())
# Mark running
- instance.status = "RUNNING"
-
- if instance.ratio_method not in {"DATASET", "TAG"}:
- logger.info(f"Instance {instance_id} ratio_method={instance.ratio_method} not supported yet")
- instance.status = "SUCCESS"
- return
+ instance.status = TaskStatus.RUNNING.name
# Load target dataset
ds_res = await session.execute(select(Dataset).where(Dataset.id == instance.target_dataset_id))
target_ds: Optional[Dataset] = ds_res.scalar_one_or_none()
if not target_ds:
logger.error(f"Target dataset not found for instance {instance_id}")
- instance.status = "FAILED"
+ instance.status = TaskStatus.FAILED.name
return
- # Preload existing target file paths for deduplication
- existing_path_rows = await session.execute(
- select(DatasetFiles.file_path).where(DatasetFiles.dataset_id == target_ds.id)
- )
- existing_paths = set(p for p in existing_path_rows.scalars().all() if p)
-
- added_count = 0
- added_size = 0
-
- for rel in relations:
- if not rel.source_dataset_id or not rel.counts or rel.counts <= 0:
- continue
-
- # Fetch all files for the source dataset (ACTIVE only)
- files_res = await session.execute(
- select(DatasetFiles).where(
- DatasetFiles.dataset_id == rel.source_dataset_id,
- DatasetFiles.status == "ACTIVE",
- )
- )
- files = list(files_res.scalars().all())
-
- # TAG mode: filter by tags according to relation.filter_conditions
- if instance.ratio_method == "TAG":
- required_tags = RatioTaskService._parse_required_tags(rel.filter_conditions)
- if required_tags:
- files = [f for f in files if RatioTaskService._file_contains_tags(f, required_tags)]
-
- if not files:
- continue
-
- pick_n = min(rel.counts or 0, len(files))
- chosen = random.sample(files, pick_n) if pick_n < len(files) else files
-
- # Copy into target dataset with de-dup by target path
- for f in chosen:
- src_path = f.file_path
- new_path = src_path
- needs_copy = False
- src_prefix = f"/dataset/{rel.source_dataset_id}"
- if isinstance(src_path, str) and src_path.startswith(src_prefix):
- dst_prefix = f"/dataset/{target_ds.id}"
- new_path = src_path.replace(src_prefix, dst_prefix, 1)
- needs_copy = True
-
- # De-dup by target path
- if new_path in existing_paths:
- continue
-
- # Perform copy only when needed
- if needs_copy:
- dst_dir = os.path.dirname(new_path)
- await asyncio.to_thread(os.makedirs, dst_dir, exist_ok=True)
- await asyncio.to_thread(shutil.copy2, src_path, new_path)
-
- new_file = DatasetFiles(
- dataset_id=target_ds.id, # type: ignore
- file_name=f.file_name,
- file_path=new_path,
- file_type=f.file_type,
- file_size=f.file_size,
- check_sum=f.check_sum,
- tags=f.tags,
- dataset_filemetadata=f.dataset_filemetadata,
- status="ACTIVE",
- )
- session.add(new_file)
- existing_paths.add(new_path)
- added_count += 1
- added_size += int(f.file_size or 0)
-
- # Periodically flush to avoid huge transactions
- await session.flush()
+ added_count, added_size = await RatioTaskService.handle_ratio_relations(relations,session, target_ds)
# Update target dataset statistics
target_ds.file_count = (target_ds.file_count or 0) + added_count # type: ignore
@@ -194,8 +121,8 @@ class RatioTaskService:
target_ds.status = "ACTIVE"
# Done
- instance.status = "SUCCESS"
- logger.info(f"Dataset ratio execution completed: instance={instance_id}, files={added_count}, size={added_size}")
+ instance.status = TaskStatus.COMPLETED.name
+ logger.info(f"Dataset ratio execution completed: instance={instance_id}, files={added_count}, size={added_size}, {instance.status}")
except Exception as e:
logger.exception(f"Dataset ratio execution failed for {instance_id}: {e}")
@@ -204,42 +131,163 @@ class RatioTaskService:
inst_res = await session.execute(select(RatioInstance).where(RatioInstance.id == instance_id))
instance = inst_res.scalar_one_or_none()
if instance:
- instance.status = "FAILED"
+ instance.status = TaskStatus.FAILED.name
finally:
pass
finally:
await session.commit()
+ @staticmethod
+ async def handle_ratio_relations(relations: list[RatioRelation], session, target_ds: Dataset) -> tuple[int, int]:
+ # Preload existing target file paths for deduplication
+ existing_path_rows = await session.execute(
+ select(DatasetFiles.file_path).where(DatasetFiles.dataset_id == target_ds.id)
+ )
+ existing_paths = set(p for p in existing_path_rows.scalars().all() if p)
+
+ added_count = 0
+ added_size = 0
+
+ for rel in relations:
+ if not rel.source_dataset_id or not rel.counts or rel.counts <= 0:
+ continue
+
+ files = await RatioTaskService.get_files(rel, session)
+
+ if not files:
+ continue
+
+ pick_n = min(rel.counts or 0, len(files))
+ chosen = random.sample(files, pick_n) if pick_n < len(files) else files
+
+ # Copy into target dataset with de-dup by target path
+ for f in chosen:
+ await RatioTaskService.handle_selected_file(existing_paths, f, session, target_ds)
+ added_count += 1
+ added_size += int(f.file_size or 0)
+
+ # Periodically flush to avoid huge transactions
+ await session.flush()
+ return added_count, added_size
+
+ @staticmethod
+ async def handle_selected_file(existing_paths: set[Any], f, session, target_ds: Dataset):
+ src_path = f.file_path
+ dst_prefix = f"/dataset/{target_ds.id}"
+ file_name = RatioTaskService.get_new_file_name(dst_prefix, existing_paths, f)
+
+ new_path = dst_prefix + file_name
+ dst_dir = os.path.dirname(new_path)
+ await asyncio.to_thread(os.makedirs, dst_dir, exist_ok=True)
+ await asyncio.to_thread(shutil.copy2, src_path, new_path)
+
+ new_file = DatasetFiles(
+ dataset_id=target_ds.id, # type: ignore
+ file_name=file_name,
+ file_path=new_path,
+ file_type=f.file_type,
+ file_size=f.file_size,
+ check_sum=f.check_sum,
+ tags=f.tags,
+ dataset_filemetadata=f.dataset_filemetadata,
+ status="ACTIVE",
+ )
+ session.add(new_file)
+ existing_paths.add(new_path)
+
+ @staticmethod
+ def get_new_file_name(dst_prefix: str, existing_paths: set[Any], f) -> str:
+ file_name = f.file_name
+ new_path = dst_prefix + file_name
+
+ # Handle file path conflicts by appending a number to the filename
+ if new_path in existing_paths:
+ file_name_base, file_ext = os.path.splitext(file_name)
+ counter = 1
+ original_file_name = file_name
+ while new_path in existing_paths:
+ file_name = f"{file_name_base}_{counter}{file_ext}"
+ new_path = f"{dst_prefix}{file_name}"
+ counter += 1
+ if counter > 1000: # Safety check to prevent infinite loops
+ logger.error(f"Could not find unique filename for {original_file_name} after 1000 attempts")
+ break
+ return file_name
+
+ @staticmethod
+ async def get_files(rel: RatioRelation, session) -> list[Any]:
+ # Fetch all files for the source dataset (ACTIVE only)
+ files_res = await session.execute(
+ select(DatasetFiles).where(
+ DatasetFiles.dataset_id == rel.source_dataset_id,
+ DatasetFiles.status == "ACTIVE",
+ )
+ )
+ files = list(files_res.scalars().all())
+
+ # TAG mode: filter by tags according to relation.filter_conditions
+ conditions = RatioTaskService._parse_conditions(rel.filter_conditions)
+ if conditions:
+ files = [f for f in files if RatioTaskService._filter_file(f, conditions)]
+ return files
+
# ------------------------- helpers for TAG filtering ------------------------- #
@staticmethod
- def _parse_required_tags(conditions: Optional[str]) -> set[str]:
- """Parse filter_conditions into a set of required tag strings.
+ def _parse_conditions(conditions: Optional[str]) -> Optional[FilterCondition]:
+ """Parse filter_conditions JSON string into a FilterCondition object.
- Supports simple separators: comma, semicolon, space. Empty/None -> empty set.
+ Args:
+ conditions: JSON string containing filter conditions
+
+ Returns:
+ FilterCondition object if conditions is not None/empty, otherwise None
"""
if not conditions:
- return set()
- data = json.loads(conditions)
- required_tags = set()
- if data.get("label"):
- required_tags.add(data["label"])
- return required_tags
+ return None
+ try:
+ data = json.loads(conditions)
+ return FilterCondition(**data)
+ except json.JSONDecodeError as e:
+ logger.error(f"Failed to parse filter conditions: {e}")
+ return None
+ except Exception as e:
+ logger.error(f"Error creating FilterCondition: {e}")
+ return None
@staticmethod
- def _file_contains_tags(file: DatasetFiles, required: set[str]) -> bool:
- if not required:
+ def _filter_file(file: DatasetFiles, conditions: FilterCondition) -> bool:
+ if not conditions:
return True
- tags = file.tags
- if not tags:
- return False
- try:
- # tags could be a list of strings or list of objects with 'name'
- tag_names = RatioTaskService.get_all_tags(tags)
- return required.issubset(tag_names)
- except Exception as e:
- logger.exception(f"Failed to get tags for {file}", e)
- return False
+ logger.info(f"start filter file: {file}, conditions: {conditions}")
+
+ # Check data range condition if provided
+ if conditions.date_range:
+ try:
+ from datetime import datetime, timedelta
+ data_range_days = int(conditions.date_range)
+ if data_range_days > 0:
+ cutoff_date = datetime.now() - timedelta(days=data_range_days)
+ if file.tags_updated_at and file.tags_updated_at < cutoff_date:
+ return False
+ except (ValueError, TypeError) as e:
+ logger.warning(f"Invalid data_range value: {conditions.date_range}", e)
+ return False
+
+ # Check label condition if provided
+ if conditions.label:
+ tags = file.tags
+ if not tags:
+ return False
+ try:
+ # tags could be a list of strings or list of objects with 'name'
+ tag_names = RatioTaskService.get_all_tags(tags)
+ return conditions.label in tag_names
+ except Exception as e:
+ logger.exception(f"Failed to get tags for {file}", e)
+ return False
+
+ return True
@staticmethod
def get_all_tags(tags) -> set[str]: