feat: optimize the question generation process and COT data generation (#169)

* fix(chart): update Helm chart helpers and values for improved configuration

* feat(SynthesisTaskTab): enhance task table with tooltip support and improved column widths

* feat(CreateTask, SynthFileTask): improve task creation and detail view with enhanced payload handling and UI updates

* feat(SynthFileTask): enhance file display with progress tracking and delete action

* feat(SynthDataDetail): add delete action for chunks with confirmation prompt

* feat(SynthDataDetail): update edit and delete buttons to icon-only format

* feat(SynthDataDetail): add confirmation modals for chunk and synthesis data deletion

* feat(DocumentSplitter): add enhanced document splitting functionality with CJK support and metadata detection

* feat(DataSynthesis): refactor data synthesis models and update task handling logic

* feat(DataSynthesis): streamline synthesis task handling and enhance chunk processing logic

* fix(generation_service): ensure processed chunks are incremented regardless of question generation success

* feat(CreateTask): enhance task creation with new synthesis templates and improved configuration options

Authored by Dallas98 on 2025-12-18 16:51:18 +08:00, committed by GitHub
parent 761f7f6a51 · commit e0e9b1d94d
14 changed files with 1362 additions and 571 deletions

View File

@@ -1,7 +1,7 @@
import { useEffect, useState } from "react"; import { useEffect, useState } from "react";
import type { Dataset, DatasetFile } from "@/pages/DataManagement/dataset.model"; import type { Dataset, DatasetFile } from "@/pages/DataManagement/dataset.model";
import { Steps, Card, Select, Input, Checkbox, Button, Form, message } from "antd"; import { Steps, Card, Select, Input, Button, Form, message, Tag, Tooltip, InputNumber } from "antd";
import { Eye, ArrowLeft, ArrowRight, Play, Search, MoreHorizontal } from "lucide-react"; import { Eye, ArrowLeft, ArrowRight, Play, Search, Sparkles, Brain, Layers } from "lucide-react";
import { Link, useNavigate } from "react-router"; import { Link, useNavigate } from "react-router";
import { queryDatasetsUsingGet } from "../DataManagement/dataset.api"; import { queryDatasetsUsingGet } from "../DataManagement/dataset.api";
import DatasetFileTransfer from "@/components/business/DatasetFileTransfer"; import DatasetFileTransfer from "@/components/business/DatasetFileTransfer";
@@ -31,13 +31,18 @@ export default function SynthesisTaskCreate() {
const [selectedFiles, setSelectedFiles] = useState<string[]>([]); const [selectedFiles, setSelectedFiles] = useState<string[]>([]);
const [selectedMap, setSelectedMap] = useState<Record<string, DatasetFile>>({}); const [selectedMap, setSelectedMap] = useState<Record<string, DatasetFile>>({});
const [selectedDataset, setSelectedDataset] = useState<Dataset | null>(null); const [selectedDataset, setSelectedDataset] = useState<Dataset | null>(null);
// Currently selected template type (QA / COT), used for highlighting
const [selectedSynthesisTypes, setSelectedSynthesisTypes] = useState<string[]>(["qa"]); const [selectedSynthesisTypes, setSelectedSynthesisTypes] = useState<string[]>(["qa"]);
const [taskType, setTaskType] = useState<"qa" | "cot">("qa"); const [taskType, setTaskType] = useState<"qa" | "cot">("qa");
const [promptTemplate, setPromptTemplate] = useState<string>(""); const [questionPrompt, setQuestionPrompt] = useState<string>("");
const [answerPrompt, setAnswerPrompt] = useState<string>("");
const [submitting, setSubmitting] = useState(false); const [submitting, setSubmitting] = useState(false);
const [modelOptions, setModelOptions] = useState<{ label: string; value: string }[]>([]); const [modelOptions, setModelOptions] = useState<{ label: string; value: string }[]>([]);
const [modelsLoading, setModelsLoading] = useState(false); const [modelsLoading, setModelsLoading] = useState(false);
const [selectedModel, setSelectedModel] = useState<string | undefined>(undefined); const [questionModelId, setQuestionModelId] = useState<string | undefined>(undefined);
const [answerModelId, setAnswerModelId] = useState<string | undefined>(undefined);
// Text chunking configuration
const [sliceConfig, setSliceConfig] = useState({ const [sliceConfig, setSliceConfig] = useState({
processType: "DEFAULT_CHUNK" as processType: "DEFAULT_CHUNK" as
| "DEFAULT_CHUNK" | "DEFAULT_CHUNK"
@@ -45,10 +50,23 @@ export default function SynthesisTaskCreate() {
| "PARAGRAPH_CHUNK" | "PARAGRAPH_CHUNK"
| "FIXED_LENGTH_CHUNK" | "FIXED_LENGTH_CHUNK"
| "CUSTOM_SEPARATOR_CHUNK", | "CUSTOM_SEPARATOR_CHUNK",
chunkSize: 500, chunkSize: 3000,
overlapSize: 50, overlapSize: 100,
delimiter: "", delimiter: "",
}); });
// Question/answer synthesis config (aligned with the backend question_synth_config / answer_synth_config)
const [questionConfig, setQuestionConfig] = useState({
number: 1,
temperature: 0.7,
});
const [answerConfig, setAnswerConfig] = useState({
// The answer side no longer needs "number"; only temperature is kept
temperature: 0.7,
});
// Cap on the total number of synthesized pairs, default 5000
const [maxQaPairs, setMaxQaPairs] = useState<number | undefined>(5000);
const sliceOptions = [ const sliceOptions = [
{ label: "默认分块", value: "DEFAULT_CHUNK" }, { label: "默认分块", value: "DEFAULT_CHUNK" },
{ label: "按章节分块", value: "CHAPTER_CHUNK" }, { label: "按章节分块", value: "CHAPTER_CHUNK" },
@@ -62,33 +80,43 @@ export default function SynthesisTaskCreate() {
return data; return data;
}; };
const fetchPrompt = async (type: "qa" | "cot") => { // Question prompt: always fetched with the QUESTION type
const fetchQuestionPrompt = async () => {
try { try {
const synthTypeParam = type.toUpperCase(); const res = await getPromptByTypeUsingGet("QUESTION");
const res = await getPromptByTypeUsingGet(synthTypeParam);
const prompt = typeof res === "string" ? res : (res as { data?: string })?.data ?? ""; const prompt = typeof res === "string" ? res : (res as { data?: string })?.data ?? "";
setPromptTemplate(prompt || ""); setQuestionPrompt(prompt || "");
} catch (e) { } catch (e) {
console.error(e); console.error(e);
message.error("获取提示词模板失败"); message.error("获取问题 Prompt 模板失败");
setPromptTemplate(""); setQuestionPrompt("");
} }
}; };
useEffect(() => { // Answer prompt: fetch the QA/COT template for the current task type
fetchDatasets(); const fetchAnswerPrompt = async (type: "qa" | "cot") => {
}, []); try {
const synthTypeParam = type === "qa" ? "QA" : "COT";
useEffect(() => { const res = await getPromptByTypeUsingGet(synthTypeParam);
fetchPrompt(taskType); const prompt = typeof res === "string" ? res : (res as { data?: string })?.data ?? "";
}, [taskType]); setAnswerPrompt(prompt || "");
} catch (e) {
console.error(e);
message.error("获取答案 Prompt 模板失败");
setAnswerPrompt("");
}
};
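
Taken together, the two fetchers imply a fixed mapping from task type to backend prompt-template type. A minimal sketch of that mapping (Python is used for sketches throughout, matching the backend half of this commit; the helper name is ours, not from the diff):

```python
# Mapping implied by fetchQuestionPrompt / fetchAnswerPrompt above:
# questions always use the QUESTION template; answers use QA or COT
# depending on the selected task type.
def prompt_types(task_type: str) -> tuple[str, str]:
    answer_type = "QA" if task_type == "qa" else "COT"
    return ("QUESTION", answer_type)

assert prompt_types("qa") == ("QUESTION", "QA")
assert prompt_types("cot") == ("QUESTION", "COT")
```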
// Fetch the model list and keep only CHAT models
useEffect(() => { useEffect(() => {
const loadModels = async () => { const loadModels = async () => {
setModelsLoading(true); setModelsLoading(true);
try { try {
const { data } = await queryModelListUsingGet({ page: 0, size: 1000 }); const { data } = await queryModelListUsingGet({ page: 0, size: 1000 });
const options = (data?.content || []).map((model: ModelI) => ({ const chatModels: ModelI[] = (data?.content || []).filter(
(model: ModelI) => model.type === "CHAT"
);
const options = chatModels.map((model) => ({
label: `${model.modelName} (${model.provider})`, label: `${model.modelName} (${model.provider})`,
value: model.id, value: model.id,
})); }));
@@ -102,11 +130,22 @@ export default function SynthesisTaskCreate() {
loadModels(); loadModels();
}, []); }, []);
// Default both the question and the answer model to the first CHAT model
useEffect(() => { useEffect(() => {
if (!selectedModel && modelOptions.length > 0) { if (modelOptions.length > 0) {
setSelectedModel(modelOptions[0].value); setQuestionModelId((prev) => prev ?? modelOptions[0].value);
setAnswerModelId((prev) => prev ?? modelOptions[0].value);
} }
}, [modelOptions, selectedModel]); }, [modelOptions]);
useEffect(() => {
fetchDatasets();
}, []);
useEffect(() => {
fetchQuestionPrompt();
fetchAnswerPrompt(taskType);
}, [taskType]);
// Form data // Form data
const [formValues, setFormValues] = useState<CreateTaskFormValues>({ const [formValues, setFormValues] = useState<CreateTaskFormValues>({
@@ -131,13 +170,12 @@ export default function SynthesisTaskCreate() {
const handleCreateTask = async () => { const handleCreateTask = async () => {
try { try {
const values = (await form.validateFields()) as CreateTaskFormValues; const values = (await form.validateFields()) as CreateTaskFormValues;
// precise validation
if (!(taskType === "qa" || taskType === "cot")) { if (!(taskType === "qa" || taskType === "cot")) {
message.error("请选择一个合成类型"); message.error("请选择一个合成类型");
return; return;
} }
if (!selectedModel) { if (!questionModelId || !answerModelId) {
message.error("请选择模型"); message.error("请选择问题和答案使用的模型");
return; return;
} }
if (selectedFiles.length === 0) { if (selectedFiles.length === 0) {
@@ -145,25 +183,42 @@ export default function SynthesisTaskCreate() {
return; return;
} }
// Build the parameter format the backend expects const synthConfig: Record<string, unknown> = {
const payload: Record<string, unknown> = {
name: values.name || form.getFieldValue("name"),
model_id: selectedModel,
source_file_id: selectedFiles,
text_split_config: { text_split_config: {
chunk_size: sliceConfig.chunkSize, chunk_size: sliceConfig.chunkSize,
chunk_overlap: sliceConfig.overlapSize, chunk_overlap: sliceConfig.overlapSize,
}, },
synthesis_config: { question_synth_config: {
prompt_template: promptTemplate, model_id: questionModelId,
prompt_template: questionPrompt,
number: questionConfig.number,
temperature: questionConfig.temperature,
}, },
synthesis_type: taskType === "qa" ? "QA" : "COT", answer_synth_config: {
model_id: answerModelId,
prompt_template: answerPrompt,
temperature: answerConfig.temperature,
},
max_qa_pairs: typeof maxQaPairs === "number" && maxQaPairs > 0 ? maxQaPairs : undefined,
}; };
// Only include description when it has real content, to avoid forcing an empty string const payload: Record<string, unknown> = {
const desc = values.description ?? form.getFieldValue("description"); name: values.name || form.getFieldValue("name"),
if (typeof desc === "string" && desc.trim().length > 0) { description: values.description ?? form.getFieldValue("description"),
payload.description = desc.trim(); synthesis_type: taskType === "qa" ? "QA" : "COT",
source_file_id: selectedFiles,
synth_config: synthConfig,
};
// Sanitize description: turn empty strings into undefined so the backend validator maps them to None
const desc = payload.description;
if (typeof desc === "string" && desc.trim().length === 0) {
delete payload.description;
}
// If max_qa_pairs is unset, remove it from synth_config to avoid sending undefined
if (synthConfig.max_qa_pairs === undefined) {
delete (synthConfig as { max_qa_pairs?: number }).max_qa_pairs;
} }
setSubmitting(true); setSubmitting(true);
@@ -187,25 +242,43 @@ export default function SynthesisTaskCreate() {
return; return;
} }
console.error(error); console.error(error);
message.error((error instanceof Error ? error.message : "合成任务创建失败")); message.error(error instanceof Error ? error.message : "合成任务创建失败");
} finally { } finally {
setSubmitting(false); setSubmitting(false);
} }
}; };
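
For clarity, here is the shape of the request body the reworked handleCreateTask now submits, rendered as a Python literal. All values below are illustrative examples; only the field names come from the diff above:

```python
import json

# Illustrative payload assembled by handleCreateTask after this change.
payload = {
    "name": "demo-task",
    "synthesis_type": "QA",  # or "COT"
    "source_file_id": ["file-123"],
    "synth_config": {
        "text_split_config": {"chunk_size": 3000, "chunk_overlap": 100},
        "question_synth_config": {
            "model_id": "model-q",
            "prompt_template": "<QUESTION prompt>",
            "number": 1,
            "temperature": 0.7,
        },
        "answer_synth_config": {
            "model_id": "model-a",
            "prompt_template": "<QA or COT prompt>",
            "temperature": 0.7,
        },
        "max_qa_pairs": 5000,  # omitted entirely when unset
    },
    # "description" is dropped when blank instead of sending ""
}
print(json.dumps(payload, ensure_ascii=False, indent=2))
```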
// Only two top-level types, no sub-categories // Only two top-level types, no sub-categories -> expanded into template configs
const synthesisTypes = [ const synthesisTemplates = [
{ id: "qa", name: "生成问答对" }, {
{ id: "cot", name: "生成COT链式推理" }, id: "sft-qa",
] as const; type: "qa" as const,
title: "SFT 问答数据合成",
subtitle: "从长文档自动生成高质量问答样本",
badge: "推荐",
description:
"适用于构建监督微调(SFT)问答数据集,支持从知识库或长文档中抽取关键问答对。",
colorClass: "from-sky-500/10 via-sky-400/5 to-transparent",
borderClass: "border-sky-100 hover:border-sky-300",
icon: Sparkles,
},
{
id: "cot-reasoning",
type: "cot" as const,
title: "COT 链式推理合成",
subtitle: "一步步推理过程与最终答案",
badge: "推理增强",
description:
"生成包含模型推理中间过程的 COT 数据,用于提升模型的复杂推理和解释能力。",
colorClass: "from-violet-500/10 via-violet-400/5 to-transparent",
borderClass: "border-violet-100 hover:border-violet-300",
icon: Brain,
},
];
const handleSynthesisTypeSelect = (typeId: "qa" | "cot") => { const handleTemplateClick = (tpl: (typeof synthesisTemplates)[number]) => {
setSelectedSynthesisTypes((prev) => { setTaskType(tpl.type);
const next = prev.includes(typeId) ? [] : [typeId]; setSelectedSynthesisTypes([tpl.type]);
if (next[0] === "qa") setTaskType("qa");
if (next[0] === "cot") setTaskType("cot");
return next;
});
}; };
useEffect(() => { useEffect(() => {
@@ -247,120 +320,374 @@ export default function SynthesisTaskCreate() {
if (createStep === 2) { if (createStep === 2) {
return ( return (
<div className=""> <div className="px-1 pb-2 pt-1">
<div className="grid grid-cols-12 gap-6 min-h-[500px]"> <div className="grid grid-cols-12 gap-5 min-h-[520px]">
{/* Left: synthesis instructions (two top-level types, single select) */} {/* Left: synthesis template area, 1/3 width */}
<div className="col-span-4 space-y-4"> <div className="col-span-4 space-y-4">
<Card className="shadow-sm border-0 bg-white"> <Card className="shadow-sm border border-slate-100/80 bg-gradient-to-b from-slate-50/70 via-white to-white">
<h1 className="text-base"></h1> <div className="flex items-center justify-between mb-3">
<div className="space-y-3 mb-4"> <div>
<h1 className="text-sm font-semibold text-slate-900 flex items-center gap-1.5">
<Sparkles className="w-4 h-4 text-amber-500" />
</h1>
<p className="text-[11px] text-slate-500 mt-0.5">
Prompt
</p>
</div>
<Tag color="blue" className="text-[10px] px-2 py-0.5 rounded-full">
</Tag>
</div>
<div className="space-y-3">
<div className="relative"> <div className="relative">
<Search className="w-3 h-3 absolute left-2 top-1/2 transform -translate-y-1/2 text-gray-400" /> <Search className="w-3 h-3 absolute left-2 top-1/2 -translate-y-1/2 text-gray-400" />
<Input placeholder="搜索名称" className="pl-7 text-xs h-8" /> <Input
</div> placeholder="搜索模板名称,如:SFT 问答 / COT 推理"
</div> className="pl-6 text-[11px] h-7 rounded-full bg-slate-50/80 border-slate-100 focus:bg-white"
<div className="space-y-2"> disabled
{synthesisTypes.map((type) => (
<div
key={type.id}
className={`flex items-center gap-2 p-2 rounded-lg cursor-pointer text-xs transition-colors ${
selectedSynthesisTypes.includes(type.id)
? "bg-blue-50 text-blue-700 border border-blue-200"
: "hover:bg-gray-50"
}`}
onClick={() => handleSynthesisTypeSelect(type.id)}
>
<Checkbox
checked={selectedSynthesisTypes.includes(type.id)}
onChange={() => handleSynthesisTypeSelect(type.id)}
/> />
<span className="flex-1">{type.name}</span>
<MoreHorizontal className="w-3 h-3 text-gray-400" />
</div> </div>
))}
<div className="space-y-2 max-h-[420px] overflow-auto pr-1 custom-scrollbar-thin">
{synthesisTemplates.map((tpl) => {
const Icon = tpl.icon;
const active = selectedSynthesisTypes.includes(tpl.type);
return (
<div
key={tpl.id}
onClick={() => handleTemplateClick(tpl)}
className={`group relative rounded-xl border p-2.5 text-xs transition-all duration-200 cursor-pointer bg-white/80 hover:bg-white/100 ${
tpl.borderClass
} ${
active
? "ring-1 ring-offset-1 ring-blue-500/60 border-blue-400/70 shadow-sm bg-gradient-to-r " +
tpl.colorClass
: "border-slate-100 hover:shadow-sm"
}`}
>
<div className="flex items-start gap-2.5">
<div
className={`mt-0.5 flex h-7 w-7 items-center justify-center rounded-full bg-white/60 shadow-sm border ${
active ? "border-blue-200" : "border-slate-100"
}`}
>
<Icon
className={`h-3.5 w-3.5 ${
active
? "text-blue-500 drop-shadow-[0_0_6px_rgba(59,130,246,0.45)]"
: "text-slate-400 group-hover:text-slate-500"
}`}
/>
</div>
<div className="flex-1 min-w-0">
<div className="flex items-center gap-1.5 mb-0.5">
<span
className={`truncate text-[12px] font-medium ${
active ? "text-slate-900" : "text-slate-800"
}`}
>
{tpl.title}
</span>
{tpl.badge && (
<Tag
color={tpl.type === "qa" ? "processing" : "purple"}
className="text-[10px] px-1.5 py-0 h-4 flex items-center rounded-full"
>
{tpl.badge}
</Tag>
)}
</div>
<p className="text-[11px] text-slate-500 leading-snug truncate">
{tpl.subtitle}
</p>
<p className="mt-1 text-[11px] text-slate-400 leading-snug line-clamp-2">
{tpl.description}
</p>
</div>
</div>
<div className="absolute inset-y-2 right-1 flex items-center">
<Tooltip title={active ? "当前已选模板" : "点击应用此模板"}>
<div
className={`flex h-5 w-5 items-center justify-center rounded-full border text-[10px] transition-colors ${
active
? "bg-blue-500 text-white border-blue-500 shadow-sm"
: "bg-white/70 text-slate-300 border-slate-100 group-hover:text-slate-400"
}`}
>
{active ? "✓" : ""}
</div>
</Tooltip>
</div>
</div>
);
})}
</div>
</div> </div>
</Card> </Card>
</div> </div>
{/* Right: synthesis config */} {/* Right: synthesis config, 2/3 width */}
<div className="col-span-8"> <div className="col-span-8">
<Card className="h-full shadow-sm border-0 bg-white"> <Card className="h-full shadow-sm border border-slate-100/80 bg-gradient-to-b from-white via-slate-50/60 to-white">
<div className="flex items-center justify-between"> <div className="flex items-center justify-between mb-3">
<h1></h1> <div>
<h1 className="text-sm font-semibold text-slate-900 flex items-center gap-1.5">
<Layers className="w-4 h-4 text-indigo-500" />
</h1>
<p className="text-[11px] text-slate-500 mt-0.5">
</p>
</div>
<div className="flex items-center gap-2"> <div className="flex items-center gap-2">
<Button className="hover:bg-white text-xs" type="default"> <Tooltip title="在正式创建任务前,先小批量运行验证效果">
<Button size="small" className="hover:bg-white text-[11px]" type="default">
<Eye className="w-3 h-3 mr-1" /> <Eye className="w-3 h-3 mr-1" />
</Button> </Button>
</Tooltip>
</div> </div>
</div> </div>
<div className="space-y-4"> <div className="space-y-4">
{/* Chunking config */} {/* Step guide bar */}
<Card className="shadow-sm border"> <div className="flex items-center gap-3 px-3 py-2 rounded-lg bg-slate-50 border border-slate-100 text-[11px] text-slate-500">
<span className="inline-flex items-center justify-center w-5 h-5 rounded-full bg-indigo-600 text-white text-[10px] font-semibold">1</span>
<span></span>
<span className="text-slate-300">/</span>
<span className="inline-flex items-center justify-center w-5 h-5 rounded-full bg-indigo-600 text-white text-[10px] font-semibold">2</span>
<span></span>
<span className="text-slate-300">/</span>
<span className="inline-flex items-center justify-center w-5 h-5 rounded-full bg-indigo-600 text-white text-[10px] font-semibold">3</span>
<span></span>
<span className="text-slate-300">/</span>
<span className="inline-flex items-center justify-center w-5 h-5 rounded-full bg-indigo-600 text-white text-[10px] font-semibold">4</span>
<span></span>
</div>
{/* 1. Total synthesis cap */}
<div className="rounded-xl bg-white/90 border border-slate-100 px-4 py-3 shadow-[0_0_0_1px_rgba(148,163,184,0.12)]">
<div className="flex items-center justify-between mb-2">
<div className="flex items-center gap-2">
<span className="inline-flex items-center justify-center w-5 h-5 rounded-full bg-indigo-600 text-white text-[10px] font-semibold">1</span>
<span className="text-[12px] font-medium text-slate-800"></span>
</div>
<span className="text-[10px] text-slate-400"> QA </span>
</div>
<div className="flex items-center gap-3">
<InputNumber
className="w-40"
min={1}
max={100000}
size="small"
value={maxQaPairs}
placeholder="不填则不限制"
onChange={(v) => setMaxQaPairs(typeof v === "number" ? v : undefined)}
/>
<span className="text-[11px] text-slate-400"></span>
</div>
</div>
{/* 2. Text chunking config */}
<div className="rounded-xl bg-white/90 border border-slate-100 px-4 py-3 shadow-[0_0_0_1px_rgba(148,163,184,0.12)]">
<div className="flex items-center justify-between mb-2">
<div className="flex items-center gap-2">
<span className="inline-flex items-center justify-center w-5 h-5 rounded-full bg-indigo-600 text-white text-[10px] font-semibold">2</span>
<span className="text-[12px] font-medium text-slate-800"></span>
</div>
<span className="text-[10px] text-slate-400"></span>
</div>
<div className="grid grid-cols-3 gap-3"> <div className="grid grid-cols-3 gap-3">
<div> <div>
<span className="text-xs font-medium text-gray-600"></span> <span className="text-[11px] font-medium text-gray-600"></span>
<Select <Select
className="mt-1 w-full"
options={sliceOptions} options={sliceOptions}
value={sliceConfig.processType} value={sliceConfig.processType}
onChange={(v) => setSliceConfig((p) => ({ ...p, processType: v }))} onChange={(v) => setSliceConfig((p) => ({ ...p, processType: v }))}
size="small"
/> />
</div> </div>
<div> <div>
<span className="text-xs font-medium text-gray-600"></span> <span className="text-[11px] font-medium text-gray-600"></span>
<Input <Input
className="mt-1"
type="number" type="number"
min={1} min={1}
value={sliceConfig.chunkSize} value={sliceConfig.chunkSize}
onChange={(e) => setSliceConfig((p) => ({ ...p, chunkSize: Number(e.target.value) }))} onChange={(e) => setSliceConfig((p) => ({ ...p, chunkSize: Number(e.target.value) }))}
size="small"
/> />
</div> </div>
<div> <div>
<span className="text-xs font-medium text-gray-600"></span> <span className="text-[11px] font-medium text-gray-600"></span>
<Input <Input
className="mt-1"
type="number" type="number"
min={0} min={0}
value={sliceConfig.overlapSize} value={sliceConfig.overlapSize}
onChange={(e) => setSliceConfig((p) => ({ ...p, overlapSize: Number(e.target.value) }))} onChange={(e) => setSliceConfig((p) => ({ ...p, overlapSize: Number(e.target.value) }))}
size="small"
/> />
</div> </div>
</div> </div>
{sliceConfig.processType === "CUSTOM_SEPARATOR_CHUNK" && ( {sliceConfig.processType === "CUSTOM_SEPARATOR_CHUNK" && (
<div className="mt-3"> <div className="mt-3">
<span className="text-xs font-medium text-gray-600"></span> <span className="text-[11px] font-medium text-gray-600"></span>
<Input <Input
className="mt-1"
placeholder={"例如:\\n\\n 或 ###"} placeholder={"例如:\\n\\n 或 ###"}
value={sliceConfig.delimiter} value={sliceConfig.delimiter}
onChange={(e) => setSliceConfig((p) => ({ ...p, delimiter: e.target.value }))} onChange={(e) => setSliceConfig((p) => ({ ...p, delimiter: e.target.value }))}
size="small"
/> />
</div> </div>
)} )}
</Card> </div>
{/* Model selection */} {/* 3. Question synthesis config */}
<Card className="shadow-sm border"> <div className="rounded-xl bg-white/90 border border-slate-100 px-4 py-3 shadow-[0_0_0_1px_rgba(148,163,184,0.12)]">
<span className="text-xs font-medium text-gray-600"></span> <div className="flex items-center justify-between mb-2">
<div className="flex items-center gap-2">
<span className="inline-flex items-center justify-center w-5 h-5 rounded-full bg-indigo-600 text-white text-[10px] font-semibold">3</span>
<span className="text-[12px] font-medium text-slate-800"></span>
</div>
<span className="text-[10px] text-slate-400"> chunk </span>
</div>
<div className="grid grid-cols-12 gap-3 mb-3">
<div className="col-span-4">
<div className="flex flex-col gap-0.5">
<span className="text-[11px] font-medium text-gray-600"></span>
<span className="text-[10px] text-slate-400">tokens生成的问题条数</span>
</div>
<InputNumber
className="mt-1 w-full"
min={1}
max={20}
size="small"
value={questionConfig.number}
onChange={(v) =>
setQuestionConfig((p) => ({ ...p, number: typeof v === "number" ? v : 1 }))
}
/>
</div>
<div className="col-span-4">
<div className="flex flex-col gap-0.5">
<span className="text-[11px] font-medium text-gray-600"> (Temperature)</span>
<span className="text-[10px] text-slate-400"></span>
</div>
<InputNumber
className="mt-1 w-full"
min={0}
max={2}
step={0.1}
size="small"
value={questionConfig.temperature}
onChange={(v) =>
setQuestionConfig((p) => ({
...p,
temperature: typeof v === "number" ? v : 0.7,
}))
}
/>
</div>
<div className="col-span-4">
<div className="flex flex-col gap-0.5">
<span className="text-[11px] font-medium text-gray-600">使</span>
<span className="text-[10px] text-slate-400"></span>
</div>
<Select <Select
placeholder="选择模型" className="mt-1 w-full"
size="small"
options={modelOptions} options={modelOptions}
loading={modelsLoading} loading={modelsLoading}
value={selectedModel} value={questionModelId}
onChange={(value) => setSelectedModel(value)} onChange={(v) => setQuestionModelId(v)}
/> />
</Card> </div>
</div>
{/* Prompt config */} <span className="text-[11px] font-medium text-gray-600"> Prompt </span>
<Card className="shadow-sm border"> <p className="mt-0.5 text-[10px] text-slate-400">
<span className="text-xs font-medium text-gray-600">Prompt </span>
</p>
<TextArea <TextArea
value={promptTemplate} value={questionPrompt}
onChange={(e) => setPromptTemplate(e.target.value)} onChange={(e) => setQuestionPrompt(e.target.value)}
rows={8} rows={6}
className="resize-none text-xs font-mono" className="mt-1 resize-none text-[11px] font-mono rounded-lg border-slate-200 bg-slate-50/60 hover:bg-slate-50 focus:bg-white"
placeholder={taskType === "qa" ? "正在加载 QA 提示词模板..." : "正在加载 COT 提示词模板..."} placeholder={
taskType === "qa"
? "将根据 SFT 问答合成场景预填问题生成 Prompt,可按需微调"
: "将根据 COT 推理合成场景预填问题生成 Prompt,可按需微调"
}
/> />
</Card> </div>
{/* 4. Answer synthesis config */}
<div className="rounded-xl bg-white/90 border border-slate-100 px-4 py-3 shadow-[0_0_0_1px_rgba(148,163,184,0.12)]">
<div className="flex items-center justify-between mb-2">
<div className="flex items-center gap-2">
<span className="inline-flex items-center justify-center w-5 h-5 rounded-full bg-indigo-600 text-white text-[10px] font-semibold">4</span>
<span className="text-[12px] font-medium text-slate-800"></span>
</div>
<span className="text-[10px] text-slate-400"></span>
</div>
<div className="grid grid-cols-12 gap-3 mb-3">
<div className="col-span-4">
<div className="flex flex-col gap-0.5">
<span className="text-[11px] font-medium text-gray-600"> (Temperature)</span>
<span className="text-[10px] text-slate-400"></span>
</div>
<InputNumber
className="mt-1 w-full"
min={0}
max={2}
step={0.1}
size="small"
value={answerConfig.temperature}
onChange={(v) =>
setAnswerConfig((p) => ({
...p,
temperature: typeof v === "number" ? v : 0.7,
}))
}
/>
</div>
<div className="col-span-4">
<div className="flex flex-col gap-0.5">
<span className="text-[11px] font-medium text-gray-600">使</span>
<span className="text-[10px] text-slate-400"></span>
</div>
<Select
className="mt-1 w-full"
size="small"
options={modelOptions}
loading={modelsLoading}
value={answerModelId}
onChange={(v) => setAnswerModelId(v)}
/>
</div>
</div>
<span className="text-[11px] font-medium text-gray-600"> Prompt </span>
<p className="mt-0.5 text-[10px] text-slate-400">
</p>
<TextArea
value={answerPrompt}
onChange={(e) => setAnswerPrompt(e.target.value)}
rows={6}
className="mt-1 resize-none text-[11px] font-mono rounded-lg border-slate-200 bg-slate-50/60 hover:bg-slate-50 focus:bg-white"
placeholder={
taskType === "qa"
? "将根据 SFT 问答合成场景预填答案生成 Prompt,可按需微调"
: "将根据 COT 推理合成场景预填答案生成 Prompt,可按需微调"
}
/>
</div>
</div> </div>
{/* The unified action bar renders at the page bottom; no buttons here */} {/* The unified action bar renders at the page bottom; no buttons here */}
@@ -384,7 +711,7 @@ export default function SynthesisTaskCreate() {
</Link> </Link>
<h1 className="text-xl font-bold bg-clip-text"></h1> <h1 className="text-xl font-bold bg-clip-text"></h1>
</div> </div>
<Steps current={createStep - 1} size="small" items={[{ title: "基本信息" }, { title: "算子编排" }]} style={{ width: "50%", marginLeft: "auto" }} /> <Steps current={createStep - 1} size="small" items={[{ title: "基本信息" }, { title: "合成编排" }]} style={{ width: "50%", marginLeft: "auto" }} />
</div> </div>
<div className="border-card flex-overflow-auto"> <div className="border-card flex-overflow-auto">
{renderCreateTaskPage()} {renderCreateTaskPage()}
@@ -419,7 +746,8 @@ export default function SynthesisTaskCreate() {
!form.getFieldValue("name") || !form.getFieldValue("name") ||
!selectedDataset || !selectedDataset ||
selectedFiles.length === 0 || selectedFiles.length === 0 ||
!selectedModel !questionModelId ||
!answerModelId
} }
loading={submitting} loading={submitting}
className="px-6 py-2 text-sm font-semibold bg-purple-600 hover:bg-purple-700 shadow-lg" className="px-6 py-2 text-sm font-semibold bg-purple-600 hover:bg-purple-700 shadow-lg"

View File

@@ -1,66 +1,65 @@
import uuid import uuid
from xml.etree.ElementTree import tostring
from sqlalchemy import Column, String, Text, Integer, JSON, TIMESTAMP, ForeignKey, func from sqlalchemy import Column, String, Text, Integer, JSON, TIMESTAMP, func
from sqlalchemy.orm import relationship
from app.db.session import Base from app.db.session import Base
from app.module.generation.schema.generation import CreateSynthesisTaskRequest from app.module.generation.schema.generation import CreateSynthesisTaskRequest
async def save_synthesis_task(db_session, synthesis_task: CreateSynthesisTaskRequest): async def save_synthesis_task(db_session, synthesis_task: CreateSynthesisTaskRequest):
"""保存数据合成任务。""" """保存数据合成任务。
# 转换为模型实例
注意:当前 MySQL 表 `t_data_synth_instances` 结构中只包含 synth_type / synth_config 等字段,
没有 model_id、text_split_config、source_file_id、result_data_location 等列,因此这里只保存
与表结构一致的字段,其他信息由上层逻辑或其它表负责管理。
"""
gid = str(uuid.uuid4()) gid = str(uuid.uuid4())
synthesis_task_instance = DataSynthesisInstance(
# Compatibility with the old request shape: extract the required fields from the request,
# - synthesis type: synthesis_type -> synth_type
# - synthesis config: text_split_config + synthesis_config merged into synth_config
synth_task_instance = DataSynthInstance(
id=gid, id=gid,
name=synthesis_task.name, name=synthesis_task.name,
description=synthesis_task.description, description=synthesis_task.description,
status="pending", status="pending",
model_id=synthesis_task.model_id, synth_type=synthesis_task.synthesis_type.value,
synthesis_type=synthesis_task.synthesis_type.value,
progress=0, progress=0,
result_data_location=f"/dataset/synthesis_results/{gid}/", synth_config=synthesis_task.synth_config.model_dump(),
text_split_config=synthesis_task.text_split_config.model_dump(), total_files=len(synthesis_task.source_file_id or []),
synthesis_config=synthesis_task.synthesis_config.model_dump(),
source_file_id=synthesis_task.source_file_id,
total_files=len(synthesis_task.source_file_id),
processed_files=0, processed_files=0,
total_chunks=0, total_chunks=0,
processed_chunks=0, processed_chunks=0,
total_synthesis_data=0, total_synth_data=0,
created_at=func.now(), created_at=func.now(),
updated_at=func.now(), updated_at=func.now(),
created_by="system", created_by="system",
updated_by="system" updated_by="system",
) )
db_session.add(synthesis_task_instance) db_session.add(synth_task_instance)
await db_session.commit() await db_session.commit()
await db_session.refresh(synthesis_task_instance) await db_session.refresh(synth_task_instance)
return synthesis_task_instance return synth_task_instance
class DataSynthesisInstance(Base): class DataSynthInstance(Base):
"""数据合成任务表,对应表 t_data_synthesis_instances """数据合成任务表,对应表 t_data_synth_instances
create table if not exists t_data_synthesis_instances create table if not exists t_data_synth_instances
( (
id VARCHAR(36) CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci PRIMARY KEY COMMENT 'UUID', id VARCHAR(36) CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci PRIMARY KEY COMMENT 'UUID',
name VARCHAR(255) NOT NULL COMMENT '任务名称', name VARCHAR(255) NOT NULL COMMENT '任务名称',
description TEXT COMMENT '任务描述', description TEXT COMMENT '任务描述',
status VARCHAR(20) COMMENT '任务状态', status VARCHAR(20) COMMENT '任务状态',
synthesis_type VARCHAR(20) NOT NULL COMMENT '合成类型', synth_type VARCHAR(20) NOT NULL COMMENT '合成类型',
model_id VARCHAR(255) NOT NULL COMMENT '模型ID',
progress INT DEFAULT 0 COMMENT '任务进度(百分比)', progress INT DEFAULT 0 COMMENT '任务进度(百分比)',
result_data_location VARCHAR(1000) COMMENT '结果数据存储位置', synth_config JSON NOT NULL COMMENT '合成配置',
text_split_config JSON NOT NULL COMMENT '文本切片配置',
synthesis_config JSON NOT NULL COMMENT '合成配置',
source_file_id JSON NOT NULL COMMENT '原始文件ID列表',
total_files INT DEFAULT 0 COMMENT '总文件数', total_files INT DEFAULT 0 COMMENT '总文件数',
processed_files INT DEFAULT 0 COMMENT '已处理文件数', processed_files INT DEFAULT 0 COMMENT '已处理文件数',
total_chunks INT DEFAULT 0 COMMENT '总文本块数', total_chunks INT DEFAULT 0 COMMENT '总文本块数',
processed_chunks INT DEFAULT 0 COMMENT '已处理文本块数', processed_chunks INT DEFAULT 0 COMMENT '已处理文本块数',
total_synthesis_data INT DEFAULT 0 COMMENT '总合成数据量', total_synth_data INT DEFAULT 0 COMMENT '总合成数据量',
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP COMMENT '创建时间', created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP COMMENT '创建时间',
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT '更新时间', updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT '更新时间',
created_by VARCHAR(255) COMMENT '创建者', created_by VARCHAR(255) COMMENT '创建者',
@@ -68,27 +67,29 @@ class DataSynthesisInstance(Base):
) COMMENT='数据合成任务表(UUID 主键)'; ) COMMENT='数据合成任务表(UUID 主键)';
""" """
__tablename__ = "t_data_synthesis_instances" __tablename__ = "t_data_synth_instances"
id = Column(String(36), primary_key=True, index=True, comment="UUID") id = Column(String(36), primary_key=True, index=True, comment="UUID")
name = Column(String(255), nullable=False, comment="任务名称") name = Column(String(255), nullable=False, comment="任务名称")
description = Column(Text, nullable=True, comment="任务描述") description = Column(Text, nullable=True, comment="任务描述")
status = Column(String(20), nullable=True, comment="任务状态") status = Column(String(20), nullable=True, comment="任务状态")
synthesis_type = Column(String(20), nullable=False, comment="合成类型") # Keep in sync with the database fields: synth_type / synth_config
model_id = Column(String(255), nullable=False, comment="模型ID") synth_type = Column(String(20), nullable=False, comment="合成类型")
progress = Column(Integer, nullable=False, default=0, comment="任务进度(百分比)") progress = Column(Integer, nullable=False, default=0, comment="任务进度(百分比)")
result_data_location = Column(String(1000), nullable=True, comment="结果数据存储位置") synth_config = Column(JSON, nullable=False, comment="合成配置")
text_split_config = Column(JSON, nullable=False, comment="文本切片配置")
synthesis_config = Column(JSON, nullable=False, comment="合成配置")
source_file_id = Column(JSON, nullable=False, comment="原始文件ID列表")
total_files = Column(Integer, nullable=False, default=0, comment="总文件数") total_files = Column(Integer, nullable=False, default=0, comment="总文件数")
processed_files = Column(Integer, nullable=False, default=0, comment="已处理文件数") processed_files = Column(Integer, nullable=False, default=0, comment="已处理文件数")
total_chunks = Column(Integer, nullable=False, default=0, comment="总文本块数") total_chunks = Column(Integer, nullable=False, default=0, comment="总文本块数")
processed_chunks = Column(Integer, nullable=False, default=0, comment="已处理文本块数") processed_chunks = Column(Integer, nullable=False, default=0, comment="已处理文本块数")
total_synthesis_data = Column(Integer, nullable=False, default=0, comment="总合成数据量") total_synth_data = Column(Integer, nullable=False, default=0, comment="总合成数据量")
created_at = Column(TIMESTAMP, nullable=False, default=func.now(), comment="创建时间")
created_at = Column(TIMESTAMP, server_default=func.current_timestamp(), nullable=True, comment="创建时间") updated_at = Column(
updated_at = Column(TIMESTAMP, server_default=func.current_timestamp(), onupdate=func.current_timestamp(), nullable=True, comment="更新时间") TIMESTAMP,
nullable=False,
default=func.now(),
onupdate=func.now(),
comment="更新时间",
)
created_by = Column(String(255), nullable=True, comment="创建者") created_by = Column(String(255), nullable=True, comment="创建者")
updated_by = Column(String(255), nullable=True, comment="更新者") updated_by = Column(String(255), nullable=True, comment="更新者")
@@ -123,7 +124,7 @@ class DataSynthesisFileInstance(Base):
) )
file_name = Column(String(255), nullable=False, comment="文件名") file_name = Column(String(255), nullable=False, comment="文件名")
source_file_id = Column(String(255), nullable=False, comment="原始文件ID") source_file_id = Column(String(255), nullable=False, comment="原始文件ID")
target_file_location = Column(String(1000), nullable=False, comment="目标文件存储位置") target_file_location = Column(String(1000), nullable=True, comment="目标文件存储位置")
status = Column(String(20), nullable=True, comment="任务状态") status = Column(String(20), nullable=True, comment="任务状态")
total_chunks = Column(Integer, nullable=False, default=0, comment="总文本块数") total_chunks = Column(Integer, nullable=False, default=0, comment="总文本块数")
processed_chunks = Column(Integer, nullable=False, default=0, comment="已处理文本块数") processed_chunks = Column(Integer, nullable=False, default=0, comment="已处理文本块数")
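
Since model_id, text_split_config, source_file_id, and friends are no longer dedicated columns, callers now have to recover them from the consolidated JSON blob. A minimal sketch of that unpacking (the helper name is ours; key names mirror the request schema elsewhere in this commit):

```python
# Recover the former dedicated columns from the consolidated synth_config JSON.
def unpack_synth_config(synth_config: dict | None) -> dict:
    cfg = synth_config or {}
    return {
        "text_split_config": cfg.get("text_split_config") or {},
        "question_synth_config": cfg.get("question_synth_config") or {},
        "answer_synth_config": cfg.get("answer_synth_config") or {},
        "max_qa_pairs": cfg.get("max_qa_pairs"),  # None means "no cap"
    }

print(unpack_synth_config({"max_qa_pairs": 5000})["max_qa_pairs"])  # 5000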

View File

@@ -13,7 +13,7 @@ from app.db.models.data_synthesis import DataSynthesisFileInstance, SynthesisDat
from app.db.session import AsyncSessionLocal from app.db.session import AsyncSessionLocal
from app.module.evaluation.schema.evaluation import SourceType from app.module.evaluation.schema.evaluation import SourceType
from app.module.shared.schema import TaskStatus from app.module.shared.schema import TaskStatus
from app.module.shared.util.model_chat import call_openai_style_model, _extract_json_substring from app.module.shared.util.model_chat import call_openai_style_model, extract_json_substring
from app.module.evaluation.schema.prompt import get_prompt from app.module.evaluation.schema.prompt import get_prompt
from app.module.shared.util.structured_file import StructuredFileHandlerFactory from app.module.shared.util.structured_file import StructuredFileHandlerFactory
from app.module.system.service.common_service import get_model_by_id from app.module.system.service.common_service import get_model_by_id
@@ -36,8 +36,8 @@ class EvaluationExecutor:
.replace("{question}", eval_content.get("instruction"))) .replace("{question}", eval_content.get("instruction")))
.replace("{answer}", eval_content.get("output"))) .replace("{answer}", eval_content.get("output")))
if self.task.task_type == "COT": if self.task.task_type == "COT":
prompt_text = ((prompt_text.replace("{question}", eval_content.get("question")) prompt_text = ((prompt_text.replace("{question}", eval_content.get("instruction"))
.replace("{conclusion}", eval_content.get("conclusion"))) .replace("{conclusion}", eval_content.get("output")))
.replace("{chain_of_thought}", eval_content.get("chain_of_thought"))) .replace("{chain_of_thought}", eval_content.get("chain_of_thought")))
return prompt_text return prompt_text
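
The fix above points the COT branch at the same unified sample keys the QA branch already uses ("instruction"/"output") instead of the old "question"/"conclusion" keys. A small self-contained sketch of the corrected substitution:

```python
# Corrected COT placeholder substitution: unified keys plus chain_of_thought.
def build_cot_prompt(template: str, sample: dict) -> str:
    return (
        template.replace("{question}", sample.get("instruction", ""))
        .replace("{conclusion}", sample.get("output", ""))
        .replace("{chain_of_thought}", sample.get("chain_of_thought", ""))
    )

print(build_cot_prompt(
    "Q: {question}\nReasoning: {chain_of_thought}\nA: {conclusion}",
    {"instruction": "2+2?", "chain_of_thought": "Add 2 and 2.", "output": "4"},
))
```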
@@ -73,7 +73,7 @@ class EvaluationExecutor:
call_openai_style_model, model_config.base_url, model_config.api_key, model_config.model_name, call_openai_style_model, model_config.base_url, model_config.api_key, model_config.model_name,
prompt_text, prompt_text,
) )
resp_text = _extract_json_substring(resp_text) resp_text = extract_json_substring(resp_text)
try: try:
json.loads(resp_text) json.loads(resp_text)
except Exception as e: except Exception as e:
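
The renamed extract_json_substring helper guards json.loads against models that wrap their JSON in prose or markdown fencing. A rough stand-in for the idea (the real helper lives in app.module.shared.util.model_chat and may differ):

```python
# Trim any prose/markdown the model wraps around a JSON object before parsing.
def extract_json_substring_sketch(text: str) -> str:
    start, end = text.find("{"), text.rfind("}")
    return text[start : end + 1] if 0 <= start < end else text

assert extract_json_substring_sketch('Sure! ```json\n{"score": 5}\n```') == '{"score": 5}'
```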

View File

@@ -1,4 +1,5 @@
import uuid import uuid
from typing import cast
from fastapi import APIRouter, HTTPException, Depends, BackgroundTasks from fastapi import APIRouter, HTTPException, Depends, BackgroundTasks
from sqlalchemy import select, func, delete from sqlalchemy import select, func, delete
@@ -7,13 +8,12 @@ from sqlalchemy.ext.asyncio import AsyncSession
from app.core.logging import get_logger from app.core.logging import get_logger
from app.db.models.data_synthesis import ( from app.db.models.data_synthesis import (
save_synthesis_task, save_synthesis_task,
DataSynthesisInstance, DataSynthInstance,
DataSynthesisFileInstance, DataSynthesisFileInstance,
DataSynthesisChunkInstance, DataSynthesisChunkInstance,
SynthesisData, SynthesisData,
) )
from app.db.models.dataset_management import DatasetFiles from app.db.models.dataset_management import DatasetFiles
from app.db.models.model_config import get_model_by_id
from app.db.session import get_db from app.db.session import get_db
from app.module.generation.schema.generation import ( from app.module.generation.schema.generation import (
CreateSynthesisTaskRequest, CreateSynthesisTaskRequest,
@@ -28,9 +28,9 @@ from app.module.generation.schema.generation import (
SynthesisDataUpdateRequest, SynthesisDataUpdateRequest,
BatchDeleteSynthesisDataRequest, BatchDeleteSynthesisDataRequest,
) )
from app.module.generation.service.export_service import SynthesisDatasetExporter, SynthesisExportError
from app.module.generation.service.generation_service import GenerationService from app.module.generation.service.generation_service import GenerationService
from app.module.generation.service.prompt import get_prompt from app.module.generation.service.prompt import get_prompt
from app.module.generation.service.export_service import SynthesisDatasetExporter, SynthesisExportError
from app.module.shared.schema import StandardResponse from app.module.shared.schema import StandardResponse
router = APIRouter( router = APIRouter(
@@ -47,10 +47,6 @@ async def create_synthesis_task(
db: AsyncSession = Depends(get_db), db: AsyncSession = Depends(get_db),
): ):
"""创建数据合成任务""" """创建数据合成任务"""
result = await get_model_by_id(db, request.model_id)
if not result:
raise HTTPException(status_code=404, detail="Model not found")
# First look up the existing file info in DatasetFiles by source_file_id # First look up the existing file info in DatasetFiles by source_file_id
file_ids = request.source_file_id or [] file_ids = request.source_file_id or []
dataset_files = [] dataset_files = []
@@ -65,32 +61,48 @@ async def create_synthesis_task(
synthesis_task = await save_synthesis_task(db, request) synthesis_task = await save_synthesis_task(db, request)
# Save the existing DatasetFiles records into t_data_synthesis_file_instances # Save the existing DatasetFiles records into t_data_synthesis_file_instances
synth_files = []
for f in dataset_files: for f in dataset_files:
file_instance = DataSynthesisFileInstance( file_instance = DataSynthesisFileInstance(
id=str(uuid.uuid4()), # Use a fresh UUID as the file-task record's primary key to avoid clashing with the DatasetFiles PK id=str(uuid.uuid4()), # Use a fresh UUID as the file-task record's primary key to avoid clashing with the DatasetFiles PK
synthesis_instance_id=synthesis_task.id, synthesis_instance_id=synthesis_task.id,
file_name=f.file_name, file_name=f.file_name,
source_file_id=str(f.id), source_file_id=str(f.id),
target_file_location=synthesis_task.result_data_location or "",
status="pending", status="pending",
total_chunks=0, total_chunks=0,
processed_chunks=0, processed_chunks=0,
created_by="system", created_by="system",
updated_by="system", updated_by="system",
) )
db.add(file_instance) synth_files.append(file_instance)
if dataset_files: if dataset_files:
db.add_all(synth_files)
await db.commit() await db.commit()
generation_service = GenerationService(db) generation_service = GenerationService(db)
# Process asynchronously: pass only the task ID; the background task reloads the task object with a fresh DB session # Process asynchronously: pass only the task ID; the background task reloads the task object with a fresh DB session
background_tasks.add_task(generation_service.process_task, synthesis_task.id) background_tasks.add_task(generation_service.process_task, synthesis_task.id)
# Wrap the ORM object in a DataSynthesisTaskItem; for compatibility, new fields are restored from synth_config
task_item = DataSynthesisTaskItem(
id=synthesis_task.id,
name=synthesis_task.name,
description=synthesis_task.description,
status=synthesis_task.status,
synthesis_type=synthesis_task.synth_type,
total_files=synthesis_task.total_files,
created_at=synthesis_task.created_at,
updated_at=synthesis_task.updated_at,
created_by=synthesis_task.created_by,
updated_by=synthesis_task.updated_by,
)
return StandardResponse( return StandardResponse(
code=200, code=200,
message="success", message="success",
data=synthesis_task, data=task_item,
) )
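
The endpoint keeps a create-then-schedule pattern: persist the task, hand only the task ID to a background job (which opens its own DB session), and respond immediately. A minimal self-contained sketch of that pattern; names and the route path are illustrative, not from this commit:

```python
from fastapi import BackgroundTasks, FastAPI

app = FastAPI()

async def run_synthesis(task_id: str) -> None:
    # Stand-in for GenerationService.process_task, which reloads the task
    # with a fresh DB session rather than sharing the request's session.
    print(f"processing {task_id}")

@app.post("/tasks")
async def create_task(background_tasks: BackgroundTasks):
    task_id = "demo-id"  # stand-in for save_synthesis_task(...)
    background_tasks.add_task(run_synthesis, task_id)
    return {"id": task_id}
```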
@@ -100,14 +112,26 @@ async def get_synthesis_task(
db: AsyncSession = Depends(get_db) db: AsyncSession = Depends(get_db)
): ):
"""获取数据合成任务详情""" """获取数据合成任务详情"""
result = await db.get(DataSynthesisInstance, task_id) synthesis_task = await db.get(DataSynthInstance, task_id)
if not result: if not synthesis_task:
raise HTTPException(status_code=404, detail="Synthesis task not found") raise HTTPException(status_code=404, detail="Synthesis task not found")
task_item = DataSynthesisTaskItem(
id=synthesis_task.id,
name=synthesis_task.name,
description=synthesis_task.description,
status=synthesis_task.status,
synthesis_type=synthesis_task.synth_type,
total_files=synthesis_task.total_files,
created_at=synthesis_task.created_at,
updated_at=synthesis_task.updated_at,
created_by=synthesis_task.created_by,
updated_by=synthesis_task.updated_by,
)
return StandardResponse( return StandardResponse(
code=200, code=200,
message="success", message="success",
data=result, data=task_item,
) )
@@ -121,16 +145,16 @@ async def list_synthesis_tasks(
db: AsyncSession = Depends(get_db) db: AsyncSession = Depends(get_db)
): ):
"""分页列出所有数据合成任务,默认按创建时间倒序""" """分页列出所有数据合成任务,默认按创建时间倒序"""
query = select(DataSynthesisInstance) query = select(DataSynthInstance)
if synthesis_type: if synthesis_type:
query = query.filter(DataSynthesisInstance.synthesis_type == synthesis_type) query = query.filter(DataSynthInstance.synth_type == synthesis_type)
if status: if status:
query = query.filter(DataSynthesisInstance.status == status) query = query.filter(DataSynthInstance.status == status)
if name: if name:
query = query.filter(DataSynthesisInstance.name.like(f"%{name}%")) query = query.filter(DataSynthInstance.name.like(f"%{name}%"))
# Order by creation time, descending, by default # Order by creation time, descending, by default
query = query.order_by(DataSynthesisInstance.created_at.desc()) query = query.order_by(DataSynthInstance.created_at.desc())
count_q = select(func.count()).select_from(query.subquery()) count_q = select(func.count()).select_from(query.subquery())
total = (await db.execute(count_q)).scalar_one() total = (await db.execute(count_q)).scalar_one()
@@ -143,31 +167,39 @@ async def list_synthesis_tasks(
result = await db.execute(query.offset((page - 1) * page_size).limit(page_size)) result = await db.execute(query.offset((page - 1) * page_size).limit(page_size))
rows = result.scalars().all() rows = result.scalars().all()
task_items = [ task_items: list[DataSynthesisTaskItem] = []
for row in rows:
synth_cfg = getattr(row, "synth_config", {}) or {}
text_split_cfg = synth_cfg.get("text_split_config") or {}
synthesis_cfg = synth_cfg.get("synthesis_config") or {}
source_file_ids = synth_cfg.get("source_file_id") or []
model_id = synth_cfg.get("model_id")
result_location = synth_cfg.get("result_data_location")
task_items.append(
DataSynthesisTaskItem( DataSynthesisTaskItem(
id=row.id, id=str(row.id),
name=row.name, name=str(row.name),
description=row.description, description=cast(str | None, row.description),
status=row.status, status=cast(str | None, row.status),
synthesis_type=row.synthesis_type, synthesis_type=str(row.synth_type),
model_id=row.model_id, model_id=model_id or "",
progress=row.progress, progress=int(cast(int, row.progress)),
result_data_location=row.result_data_location, result_data_location=result_location,
text_split_config=row.text_split_config, text_split_config=text_split_cfg,
synthesis_config=row.synthesis_config, synthesis_config=synthesis_cfg,
source_file_id=row.source_file_id, source_file_id=list(source_file_ids),
total_files=row.total_files, total_files=int(cast(int, row.total_files)),
processed_files=row.processed_files, processed_files=int(cast(int, row.processed_files)),
total_chunks=row.total_chunks, total_chunks=int(cast(int, row.total_chunks)),
processed_chunks=row.processed_chunks, processed_chunks=int(cast(int, row.processed_chunks)),
total_synthesis_data=row.total_synthesis_data, total_synthesis_data=int(cast(int, row.total_synth_data)),
created_at=row.created_at, created_at=row.created_at,
updated_at=row.updated_at, updated_at=row.updated_at,
created_by=row.created_by, created_by=row.created_by,
updated_by=row.updated_by, updated_by=row.updated_by,
) )
for row in rows )
]
paged = PagedDataSynthesisTaskResponse( paged = PagedDataSynthesisTaskResponse(
content=task_items, content=task_items,
@@ -190,7 +222,7 @@ async def delete_synthesis_task(
db: AsyncSession = Depends(get_db) db: AsyncSession = Depends(get_db)
): ):
"""删除数据合成任务""" """删除数据合成任务"""
task = await db.get(DataSynthesisInstance, task_id) task = await db.get(DataSynthInstance, task_id)
if not task: if not task:
raise HTTPException(status_code=404, detail="Synthesis task not found") raise HTTPException(status_code=404, detail="Synthesis task not found")
@@ -241,7 +273,7 @@ async def delete_synthesis_file_task(
): ):
"""删除数据合成任务中的文件任务,同时刷新任务表中的文件/切片数量""" """删除数据合成任务中的文件任务,同时刷新任务表中的文件/切片数量"""
# First fetch the task and the file-task record # First fetch the task and the file-task record
task = await db.get(DataSynthesisInstance, task_id) task = await db.get(DataSynthInstance, task_id)
if not task: if not task:
raise HTTPException(status_code=404, detail="Synthesis task not found") raise HTTPException(status_code=404, detail="Synthesis task not found")
@@ -306,7 +338,7 @@ async def list_synthesis_file_tasks(
): ):
"""分页获取某个数据合成任务下的文件任务列表""" """分页获取某个数据合成任务下的文件任务列表"""
# First verify that the task exists # First verify that the task exists
task = await db.get(DataSynthesisInstance, task_id) task = await db.get(DataSynthInstance, task_id)
if not task: if not task:
raise HTTPException(status_code=404, detail="Synthesis task not found") raise HTTPException(status_code=404, detail="Synthesis task not found")
@@ -333,7 +365,6 @@ async def list_synthesis_file_tasks(
synthesis_instance_id=row.synthesis_instance_id, synthesis_instance_id=row.synthesis_instance_id,
file_name=row.file_name, file_name=row.file_name,
source_file_id=row.source_file_id, source_file_id=row.source_file_id,
target_file_location=row.target_file_location,
status=row.status, status=row.status,
total_chunks=row.total_chunks, total_chunks=row.total_chunks,
processed_chunks=row.processed_chunks, processed_chunks=row.processed_chunks,
@@ -523,7 +554,7 @@ async def delete_synthesis_data_by_chunk(
result = await db.execute( result = await db.execute(
delete(SynthesisData).where(SynthesisData.chunk_instance_id == chunk_id) delete(SynthesisData).where(SynthesisData.chunk_instance_id == chunk_id)
) )
deleted = result.rowcount or 0 deleted = int(getattr(result, "rowcount", 0) or 0)
await db.commit() await db.commit()
@@ -542,7 +573,7 @@ async def batch_delete_synthesis_data(
result = await db.execute( result = await db.execute(
delete(SynthesisData).where(SynthesisData.id.in_(request.ids)) delete(SynthesisData).where(SynthesisData.id.in_(request.ids))
) )
deleted = result.rowcount or 0 deleted = int(getattr(result, "rowcount", 0) or 0)
await db.commit() await db.commit()
return StandardResponse(code=200, message="success", data=deleted) return StandardResponse(code=200, message="success", data=deleted)
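
Both delete endpoints now read rowcount defensively, since the attribute can be absent or None depending on the result type and driver. The guard, isolated as a one-liner:

```python
# Coerce a possibly-missing rowcount to a plain int with a 0 fallback.
def normalize_rowcount(result: object) -> int:
    return int(getattr(result, "rowcount", 0) or 0)
```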

View File

@@ -11,33 +11,45 @@ class TextSplitConfig(BaseModel):
chunk_overlap: int = Field(..., description="重叠令牌数") chunk_overlap: int = Field(..., description="重叠令牌数")
class SynthesisConfig(BaseModel): class SyntheConfig(BaseModel):
"""合成配置""" """合成配置"""
prompt_template: str = Field(..., description="合成提示模板") model_id: str = Field(..., description="模型ID")
synthesis_count: int = Field(None, description="单个chunk合成的数据数量") prompt_template: str = Field(None, description="合成提示模板")
number: Optional[int] = Field(None, description="单个chunk合成的数据数量")
temperature: Optional[float] = Field(None, description="温度参数") temperature: Optional[float] = Field(None, description="温度参数")
class Config(BaseModel):
"""配置"""
text_split_config: TextSplitConfig = Field(None, description="文本切片配置")
question_synth_config: SyntheConfig = Field(None, description="问题合成配置")
answer_synth_config: SyntheConfig = Field(None, description="答案合成配置")
# New: cap on the total number of QA pairs the whole task may generate
max_qa_pairs: Optional[int] = Field(
default=None,
description="整个任务允许生成的 QA 对总量上限;为 None 或 <=0 表示不限制",
)
class SynthesisType(Enum): class SynthesisType(Enum):
"""合成类型""" """合成类型"""
QA = "QA" QA = "QA"
COT = "COT" COT = "COT"
QUESTION = "QUESTION"
class CreateSynthesisTaskRequest(BaseModel): class CreateSynthesisTaskRequest(BaseModel):
"""创建数据合成任务请求""" """创建数据合成任务请求"""
name: str = Field(..., description="合成任务名称") name: str = Field(..., description="合成任务名称")
description: Optional[str] = Field(None, description="合成任务描述") description: Optional[str] = Field(None, description="合成任务描述")
model_id: str = Field(..., description="模型ID")
source_file_id: list[str] = Field(..., description="原始文件ID列表")
text_split_config: TextSplitConfig = Field(None, description="文本切片配置")
synthesis_config: SynthesisConfig = Field(..., description="合成配置")
synthesis_type: SynthesisType = Field(..., description="合成类型") synthesis_type: SynthesisType = Field(..., description="合成类型")
source_file_id: list[str] = Field(..., description="原始文件ID列表")
synth_config: Config = Field(..., description="合成配置")
@field_validator("description") @field_validator("description")
@classmethod @classmethod
def empty_string_to_none(cls, v: Optional[str]) -> Optional[str]: def empty_string_to_none(cls, v: Optional[str]) -> Optional[str]:
"""前端如果传入空字符串,将其统一转为 None,避免存库时看起来像有描述但实际上为空。""" """前端如果传入空字符串,将其统一转为 None,避免存库时看起来像有描述但实际上为空。"""
if isinstance(v, str) and v.strip() == "": if isinstance(v, str) and v.strip() == "":
return None return None
return v return v
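
A small usage sketch of the reshaped schema above (example values only): synth_config now nests both sub-configs plus the optional global cap, and the field validator turns a blank description into None.

```python
req = CreateSynthesisTaskRequest(
    name="demo",
    description="   ",  # -> None via empty_string_to_none
    synthesis_type="QA",
    source_file_id=["file-1"],
    synth_config=Config(
        text_split_config=TextSplitConfig(chunk_size=3000, chunk_overlap=100),
        question_synth_config=SyntheConfig(
            model_id="m-q", prompt_template="...", number=2, temperature=0.7
        ),
        answer_synth_config=SyntheConfig(
            model_id="m-a", prompt_template="...", temperature=0.7
        ),
        max_qa_pairs=5000,
    ),
)
assert req.description is None
assert req.synth_config.max_qa_pairs == 5000
```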
@@ -50,17 +62,7 @@ class DataSynthesisTaskItem(BaseModel):
description: Optional[str] = None description: Optional[str] = None
status: Optional[str] = None status: Optional[str] = None
synthesis_type: str synthesis_type: str
model_id: str
progress: int
result_data_location: Optional[str] = None
text_split_config: Dict[str, Any]
synthesis_config: Dict[str, Any]
source_file_id: list[str]
total_files: int total_files: int
processed_files: int
total_chunks: int
processed_chunks: int
total_synthesis_data: int
created_at: Optional[datetime] = None created_at: Optional[datetime] = None
updated_at: Optional[datetime] = None updated_at: Optional[datetime] = None
created_by: Optional[str] = None created_by: Optional[str] = None
@@ -85,7 +87,6 @@ class DataSynthesisFileTaskItem(BaseModel):
synthesis_instance_id: str synthesis_instance_id: str
file_name: str file_name: str
source_file_id: str source_file_id: str
target_file_location: str
status: Optional[str] = None status: Optional[str] = None
total_chunks: int total_chunks: int
processed_chunks: int processed_chunks: int
@@ -108,7 +109,7 @@ class PagedDataSynthesisFileTaskResponse(BaseModel):
class DataSynthesisChunkItem(BaseModel): class DataSynthesisChunkItem(BaseModel):
"""数据合成文件下的 chunk 记录""" """数据合成任务下的 chunk 记录"""
id: str id: str
synthesis_file_instance_id: str synthesis_file_instance_id: str
chunk_index: Optional[int] = None chunk_index: Optional[int] = None

View File

@@ -9,7 +9,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
from app.core.logging import get_logger from app.core.logging import get_logger
from app.db.models.data_synthesis import ( from app.db.models.data_synthesis import (
DataSynthesisInstance, DataSynthInstance,
DataSynthesisFileInstance, DataSynthesisFileInstance,
SynthesisData, SynthesisData,
) )
@@ -43,7 +43,7 @@ class SynthesisDatasetExporter:
Optimized to process one file at a time to reduce memory usage. Optimized to process one file at a time to reduce memory usage.
""" """
task = await self._db.get(DataSynthesisInstance, task_id) task = await self._db.get(DataSynthInstance, task_id)
if not task: if not task:
raise SynthesisExportError(f"Synthesis task {task_id} not found") raise SynthesisExportError(f"Synthesis task {task_id} not found")

View File

@@ -1,138 +1,477 @@
import asyncio import asyncio
import json import json
import uuid import uuid
from pathlib import Path
from langchain_text_splitters import RecursiveCharacterTextSplitter from langchain_core.language_models import BaseChatModel
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from app.db.models.data_synthesis import ( from app.db.models.data_synthesis import (
DataSynthesisInstance, DataSynthInstance,
DataSynthesisFileInstance, DataSynthesisFileInstance,
DataSynthesisChunkInstance, DataSynthesisChunkInstance,
SynthesisData, SynthesisData,
) )
from app.db.models.dataset_management import DatasetFiles from app.db.models.dataset_management import DatasetFiles
from app.db.models.model_config import get_model_by_id
from app.db.session import logger from app.db.session import logger
from app.module.shared.util.model_chat import _extract_json_substring from app.module.generation.schema.generation import Config, SyntheConfig
from app.module.system.service.common_service import get_chat_client, chat from app.module.generation.service.prompt import (
from app.common.document_loaders import load_documents QUESTION_GENERATOR_PROMPT,
ANSWER_GENERATOR_PROMPT,
)
from app.module.shared.common.document_loaders import load_documents
from app.module.shared.common.text_split import DocumentSplitter
from app.module.shared.util.model_chat import extract_json_substring
from app.module.system.service.common_service import chat, get_model_by_id, get_chat_client
class GenerationService: class GenerationService:
def __init__(self, db: AsyncSession): def __init__(self, db: AsyncSession):
self.db = db self.db = db
# Global concurrency semaphores: at most 10 question calls and 100 answer calls in flight at any time
self.question_semaphore = asyncio.Semaphore(10)
self.answer_semaphore = asyncio.Semaphore(100)
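
The pattern behind the two semaphores, as a self-contained sketch: many chunks can be in flight at once, but at most N model calls actually run concurrently.

```python
import asyncio

async def call_model(prompt: str) -> str:  # stand-in for a real LLM call
    await asyncio.sleep(0.1)
    return f"answer to {prompt}"

async def bounded_call(sem: asyncio.Semaphore, prompt: str) -> str:
    async with sem:  # blocks while N calls are already running
        return await call_model(prompt)

async def main() -> None:
    sem = asyncio.Semaphore(10)  # mirrors question_semaphore
    answers = await asyncio.gather(*(bounded_call(sem, f"q{i}") for i in range(50)))
    print(len(answers))  # 50

asyncio.run(main())
```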
async def process_task(self, task_id: str): async def process_task(self, task_id: str):
"""处理数据合成任务入口:根据任务ID加载任务并逐个处理源文件。""" """处理数据合成任务入口:根据任务ID加载任务并逐个处理源文件。"""
synthesis_task: DataSynthesisInstance | None = await self.db.get(DataSynthesisInstance, task_id) synth_task: DataSynthInstance | None = await self.db.get(DataSynthInstance, task_id)
if not synthesis_task: if not synth_task:
logger.error(f"Synthesis task {task_id} not found, abort processing") logger.error(f"Synthesis task {task_id} not found, abort processing")
return return
logger.info(f"Processing synthesis task {task_id}") logger.info(f"Start processing synthe task {task_id}")
file_ids = synthesis_task.source_file_id or []
# 获取模型客户端 # 从 synth_config 中读取 max_qa_pairs,全局控制 QA 总量上限;<=0 或异常则视为不限制
model_result = await get_model_by_id(self.db, str(synthesis_task.model_id)) try:
if not model_result: cfg = Config(**(synth_task.synth_config or {}))
logger.error( max_qa_pairs = cfg.max_qa_pairs if (cfg and cfg.max_qa_pairs and cfg.max_qa_pairs > 0) else None
f"Model config not found for id={synthesis_task.model_id}, abort task {synthesis_task.id}" except Exception:
) max_qa_pairs = None
# 获取任务关联的文件原始ID列表
file_ids = await self._get_file_ids_for_task(task_id)
if not file_ids:
logger.warning(f"No files associated with task {task_id}, abort processing")
return return
chat_client = get_chat_client(model_result)
# 控制并发度的信号量(限制全任务范围内最多 10 个并发调用)
semaphore = asyncio.Semaphore(10)
# 逐个文件处理 # 逐个文件处理
for file_id in file_ids: for file_id in file_ids:
try: try:
success = await self._process_single_file( success = await self._process_single_file(synth_task, file_id, max_qa_pairs=max_qa_pairs)
synthesis_task=synthesis_task,
file_id=file_id,
chat_client=chat_client,
semaphore=semaphore,
)
except Exception as e: except Exception as e:
logger.exception(f"Unexpected error when processing file {file_id} for task {task_id}: {e}") logger.exception(f"Unexpected error when processing file {file_id} for task {task_id}: {e}")
# 确保对应文件任务状态标记为失败 # 确保对应文件任务状态标记为失败
await self._mark_file_failed(str(synthesis_task.id), file_id, str(e)) await self._mark_file_failed(str(synth_task.id), file_id, str(e))
success = False success = False
if success: if success:
# 每处理完一个文件,简单增加 processed_files 计数 # 每处理完一个文件,简单增加 processed_files 计数
synthesis_task.processed_files = (synthesis_task.processed_files or 0) + 1 synth_task.processed_files = (synth_task.processed_files or 0) + 1
await self.db.commit() await self.db.commit()
await self.db.refresh(synthesis_task) await self.db.refresh(synth_task)
logger.info(f"Finished processing synthesis task {synthesis_task.id}") logger.info(f"Finished processing synthesis task {synth_task.id}")
# ==================== 高层文件处理流程 ====================
    async def _process_single_file(
        self,
        synth_task: DataSynthInstance,
        file_id: str,
        max_qa_pairs: int | None = None,
    ) -> bool:
        """Process a single source file in streaming chunk batches.

        Flow:
        1. Split the file, persist all chunks to the DB, then free the in-memory copies;
        2. Load chunks back from the DB in ascending chunk_index batches;
        3. For each chunk in a batch: generate the configured number of questions, then answers for those questions;
        4. Update processed_chunks once per chunk, whether or not it produced any QA pair;
        5. Mark the file instance as completed once everything is done.
        """
        # Resolve file path and configuration
        file_path = await self._resolve_file_path(file_id)
        if not file_path:
            logger.warning(f"File path not found for file_id={file_id}, skip")
            await self._mark_file_failed(str(synth_task.id), file_id, "file_path_not_found")
            return False
        logger.info(f"Processing file_id={file_id}, path={file_path}")

        try:
            config = Config(**(synth_task.synth_config or {}))
        except Exception as e:
            logger.error(f"Invalid synth_config for task={synth_task.id}: {e}")
            await self._mark_file_failed(str(synth_task.id), file_id, "invalid_synth_config")
            return False

        # 1. Load and split (the only place the chunks are held in memory)
        chunks = self._load_and_split(
            file_path,
            config.text_split_config.chunk_size,
            config.text_split_config.chunk_overlap,
        )
        if not chunks:
            logger.warning(f"No chunks generated for file_id={file_id}")
            await self._mark_file_failed(str(synth_task.id), file_id, "no_chunks_generated")
            return False
        logger.info(f"File {file_id} split into {len(chunks)} chunks by LangChain")

        # 2. Fetch the file instance and persist the chunk records
        file_task = await self._get_or_create_file_instance(
            synthesis_task_id=str(synth_task.id),
            source_file_id=file_id,
        )
        if not file_task:
            logger.error(
                f"DataSynthesisFileInstance not found for task={synth_task.id}, file_id={file_id}"
            )
            await self._mark_file_failed(str(synth_task.id), file_id, "file_instance_not_found")
            return False

        await self._persist_chunks(synth_task, file_task, file_id, chunks)
        total_chunks = len(chunks)
        # Release the in-memory chunks
        del chunks

        # 3. Read the question/answer configuration
        question_cfg: SyntheConfig | None = config.question_synth_config
        answer_cfg: SyntheConfig | None = config.answer_synth_config
        if not question_cfg or not answer_cfg:
            logger.error(
                f"Question/Answer synth config missing for task={synth_task.id}, file={file_id}"
            )
            await self._mark_file_failed(str(synth_task.id), file_id, "qa_config_missing")
            return False
        logger.info(
            f"Start QA generation for task={synth_task.id}, file={file_id}, total_chunks={total_chunks}"
        )

        # Build the model clients for this file
        question_model = await get_model_by_id(self.db, question_cfg.model_id)
        answer_model = await get_model_by_id(self.db, answer_cfg.model_id)
        question_chat = get_chat_client(question_model)
        answer_chat = get_chat_client(answer_model)

        # Read chunks back from the DB and process them batch by batch
        batch_size = 20
        current_index = 1
        while current_index <= total_chunks:
            end_index = min(current_index + batch_size - 1, total_chunks)
            chunk_batch = await self._load_chunk_batch(
                file_task_id=file_task.id,
                start_index=current_index,
                end_index=end_index,
            )
            if not chunk_batch:
                logger.warning(
                    f"Empty chunk batch loaded for file={file_id}, range=[{current_index}, {end_index}]"
                )
                current_index = end_index + 1
                continue

            # Process every chunk in this batch concurrently (model calls throttled by the semaphores)
            async def process_one(chunk: DataSynthesisChunkInstance) -> bool:
                return await self._process_single_chunk_qa(
                    file_task=file_task,
                    chunk=chunk,
                    question_cfg=question_cfg,
                    answer_cfg=answer_cfg,
                    question_chat=question_chat,
                    answer_chat=answer_chat,
                    synth_task_id=str(synth_task.id),
                    max_qa_pairs=max_qa_pairs,
                )

            tasks = [process_one(chunk) for chunk in chunk_batch]
            await asyncio.gather(*tasks, return_exceptions=True)
            current_index = end_index + 1

        # All done
        file_task.status = "completed"
        await self.db.commit()
        await self.db.refresh(file_task)
        return True
    async def _process_single_chunk_qa(
        self,
        file_task: DataSynthesisFileInstance,
        chunk: DataSynthesisChunkInstance,
        question_cfg: SyntheConfig,
        answer_cfg: SyntheConfig,
        question_chat: BaseChatModel,
        answer_chat: BaseChatModel,
        synth_task_id: str,
        max_qa_pairs: int | None = None,
    ) -> bool:
        """Process one chunk: generate a question list, then generate and persist an answer per question.

        To enforce the global QA cap: at the start of this method, count the SynthesisData
        rows already persisted under synth_task_id; if that count >= max_qa_pairs, skip all
        QA generation for this chunk and mark the file task as completed with
        processed_chunks = total_chunks. Work already in flight (e.g. other coroutines
        generating answers) is allowed to run to completion.
        """
        # Without a global cap configured, keep the original behavior
        if max_qa_pairs is not None and max_qa_pairs > 0:
            # Count the QA pairs already generated across the whole task
            result = await self.db.execute(
                select(func.count(SynthesisData.id)).where(
                    SynthesisData.synthesis_file_instance_id.in_(
                        select(DataSynthesisFileInstance.id).where(
                            DataSynthesisFileInstance.synthesis_instance_id == synth_task_id
                        )
                    )
                )
            )
            current_qa_count = int(result.scalar() or 0)
            if current_qa_count >= max_qa_pairs:
                logger.info(
                    "max_qa_pairs reached: current=%s, max=%s, task_id=%s, file_task_id=%s, skip new QA generation for this chunk.",
                    current_qa_count,
                    max_qa_pairs,
                    synth_task_id,
                    file_task.id,
                )
                # Mark the file task as completed and treat all chunks as processed
                file_task.status = "completed"
                if file_task.total_chunks is not None:
                    file_task.processed_chunks = file_task.total_chunks
                await self.db.commit()
                await self.db.refresh(file_task)
                return False

        # ---- Original chunk-processing logic below ----
        chunk_index = chunk.chunk_index
        chunk_text = chunk.chunk_content or ""
        if not chunk_text.strip():
            logger.warning(
                f"Empty chunk text for file_task={file_task.id}, chunk_index={chunk_index}"
            )
            # The chunk counts as processed whether it succeeded or not
            try:
                await self._increment_processed_chunks(file_task.id, 1)
            except Exception as e:
                logger.exception(
                    f"Failed to increment processed_chunks for file_task={file_task.id}, chunk_index={chunk_index}: {e}"
                )
            return False

        success_any = False
        # 1. Generate questions
        try:
            questions = await self._generate_questions_for_one_chunk(
                chunk_text=chunk_text,
                question_cfg=question_cfg,
                question_chat=question_chat,
            )
        except Exception as e:
            logger.error(
                f"Generate questions failed for file_task={file_task.id}, chunk_index={chunk_index}: {e}"
            )
            questions = []

        if not questions:
            logger.info(
                f"No questions generated for file_task={file_task.id}, chunk_index={chunk_index}"
            )
        else:
            # 2. Generate an answer for every question and persist the results
            qa_success = await self._generate_answers_for_one_chunk(
                file_task=file_task,
                chunk=chunk,
                questions=questions,
                answer_cfg=answer_cfg,
                answer_chat=answer_chat,
            )
            success_any = bool(qa_success)

        # Increment processed_chunks whether or not this chunk succeeded, so the task never stalls
        try:
            await self._increment_processed_chunks(file_task.id, 1)
        except Exception as e:
            logger.exception(
                f"Failed to increment processed_chunks for file_task={file_task.id}, chunk_index={chunk_index}: {e}"
            )
        return success_any
    async def _generate_questions_for_one_chunk(
        self,
        chunk_text: str,
        question_cfg: SyntheConfig,
        question_chat: BaseChatModel,
    ) -> list[str]:
        """Call question_chat to generate a question list for a single chunk of text."""
        number = question_cfg.number or 5
        # Scale the question count with chunk length (baseline: `number` questions per 1000 chars), minimum 1
        number = max(int(len(chunk_text) / 1000 * number), 1)

        template = getattr(question_cfg, "prompt_template", QUESTION_GENERATOR_PROMPT)
        template = template if (template is not None and template.strip() != "") else QUESTION_GENERATOR_PROMPT
        prompt = (
            template
            .replace("{text}", chunk_text)
            .replace("{number}", str(number))
            .replace("{textLength}", str(len(chunk_text)))
        )

        async with self.question_semaphore:
            loop = asyncio.get_running_loop()
            raw_answer = await loop.run_in_executor(
                None,
                chat,
                question_chat,
                prompt,
            )
        # Parse the raw reply into a question list
        questions = self._parse_questions_from_answer(raw_answer)
        return questions
    async def _generate_answers_for_one_chunk(
        self,
        file_task: DataSynthesisFileInstance,
        chunk: DataSynthesisChunkInstance,
        questions: list[str],
        answer_cfg: SyntheConfig,
        answer_chat: BaseChatModel,
    ) -> bool:
        """Generate answers for all questions of a chunk and write them to SynthesisData.

        Returns: whether at least one QA pair was written successfully.
        """
        if not questions:
            return False

        chunk_text = chunk.chunk_content or ""
        template = getattr(answer_cfg, "prompt_template", ANSWER_GENERATOR_PROMPT)
        template = template if (template is not None and template.strip() != "") else ANSWER_GENERATOR_PROMPT
        extra_vars = getattr(answer_cfg, "extra_prompt_vars", {}) or {}

        success_flags: list[bool] = []

        async def process_single_question(question: str):
            prompt_local = template.replace("{text}", chunk_text).replace("{question}", question)
            for k, v in extra_vars.items():
                prompt_local = prompt_local.replace(f"{{{{{k}}}}}", str(v))

            async with self.answer_semaphore:
                loop = asyncio.get_running_loop()
                answer = await loop.run_in_executor(
                    None,
                    chat,
                    answer_chat,
                    prompt_local,
                )

            # Default structure: mirrors ANSWER_GENERATOR_PROMPT, plus an instruction field
            base_obj: dict[str, object] = {
                "input": chunk_text,
                "output": answer,
            }
            # If the model already returned JSON per ANSWER_GENERATOR_PROMPT, parse it and add instruction on top
            parsed_obj: dict[str, object] | None = None
            if isinstance(answer, str):
                cleaned = extract_json_substring(answer)
                try:
                    parsed = json.loads(cleaned)
                    if isinstance(parsed, dict):
                        parsed_obj = parsed
                except Exception:
                    parsed_obj = None

            if parsed_obj is not None:
                parsed_obj["instruction"] = question
                data_obj = parsed_obj
            else:
                base_obj["instruction"] = question
                data_obj = base_obj

            record = SynthesisData(
                id=str(uuid.uuid4()),
                data=data_obj,
                synthesis_file_instance_id=file_task.id,
                chunk_instance_id=chunk.id,
            )
            self.db.add(record)
            success_flags.append(True)

        tasks = [process_single_question(q) for q in questions]
        await asyncio.gather(*tasks, return_exceptions=True)

        if success_flags:
            await self.db.commit()
            return True
        return False
    @staticmethod
    def _parse_questions_from_answer(raw_answer: str) -> list[str]:
        """Parse the question array out of the raw model reply."""
        if not raw_answer:
            return []
        cleaned = extract_json_substring(raw_answer)
        try:
            data = json.loads(cleaned)
        except Exception as e:
            logger.error(f"Failed to parse question list JSON: {e}")
            return []
        if isinstance(data, list):
            return [str(q) for q in data if isinstance(q, str) and q.strip()]
        # Fallback: a single string is treated as one question
        if isinstance(data, str) and data.strip():
            return [data.strip()]
        return []
    # ==================== Original helpers (file path / splitting / persistence) ====================
    async def _resolve_file_path(self, file_id: str) -> str | None:
        """Look up t_dm_dataset_files by file ID and return its file_path."""
        result = await self.db.execute(
            select(DatasetFiles).where(DatasetFiles.id == file_id)
        )
        file_obj = result.scalar_one_or_none()
        if not file_obj:
            return None
        return file_obj.file_path

    @staticmethod
    def _load_and_split(file_path: str, chunk_size: int, chunk_overlap: int):
        """Load a file with LangChain and split it, returning the Document list.

        Args:
            file_path: path of the file to split
            chunk_size: maximum chunk size
            chunk_overlap: overlap between adjacent chunks
        """
        try:
            docs = load_documents(file_path)
            split_docs = DocumentSplitter.auto_split(docs, chunk_size, chunk_overlap)
            return split_docs
        except Exception as e:
            logger.error(f"Error loading or splitting file {file_path}: {e}")
            raise
    async def _persist_chunks(
        self,
        synthesis_task: DataSynthInstance,
        file_task: DataSynthesisFileInstance,
        file_id: str,
        chunks,
@@ -164,201 +503,10 @@ class GenerationService:
        await self.db.commit()
        await self.db.refresh(file_task)
    async def _get_or_create_file_instance(
        self,
        synthesis_task_id: str,
        source_file_id: str,
    ) -> DataSynthesisFileInstance | None:
        """Find or create the DataSynthesisFileInstance record for the given task ID and source file ID.
@@ -374,32 +522,9 @@ class GenerationService:
            )
        )
        file_task = result.scalar_one_or_none()
        return file_task

    async def _mark_file_failed(self, synth_task_id: str, file_id: str, reason: str | None = None) -> None:
        """Mark the single file task under the given task as failed; fallback error handling.
        - If a matching DataSynthesisFileInstance is found, set its status to "failed"
@@ -409,14 +534,14 @@ class GenerationService:
        try:
            result = await self.db.execute(
                select(DataSynthesisFileInstance).where(
                    DataSynthesisFileInstance.synthesis_instance_id == synth_task_id,
                    DataSynthesisFileInstance.source_file_id == file_id,
                )
            )
            file_task = result.scalar_one_or_none()
            if not file_task:
                logger.warning(
                    f"Failed to mark file as failed: no DataSynthesisFileInstance found for task={synth_task_id}, file_id={file_id}, reason={reason}"
                )
                return
@@ -424,10 +549,72 @@ class GenerationService:
            await self.db.commit()
            await self.db.refresh(file_task)
            logger.info(
                f"Marked file task as failed for task={synth_task_id}, file_id={file_id}, reason={reason}"
            )
        except Exception as e:
            # Last-resort logging; never let the exception propagate and disturb other files
            logger.exception(
                f"Unexpected error when marking file failed for task={synth_task_id}, file_id={file_id}, original_reason={reason}, error={e}"
            )
    async def _get_file_ids_for_task(self, synth_task_id: str):
        """Return the source file IDs associated with the given task ID."""
        result = await self.db.execute(
            select(DataSynthesisFileInstance.source_file_id)
            .where(DataSynthesisFileInstance.synthesis_instance_id == synth_task_id)
        )
        file_ids = result.scalars().all()
        return file_ids

    # ========== New helpers: chunk counting, batch loading, and safe processed_chunks updates ==========
    async def _count_chunks_for_file(self, synth_file_instance_id: str) -> int:
        """Count the chunks persisted under the given file instance."""
        result = await self.db.execute(
            select(func.count(DataSynthesisChunkInstance.id)).where(
                DataSynthesisChunkInstance.synthesis_file_instance_id == synth_file_instance_id
            )
        )
        return int(result.scalar() or 0)

    async def _load_chunk_batch(
        self,
        file_task_id: str,
        start_index: int,
        end_index: int,
    ) -> list[DataSynthesisChunkInstance]:
        """Load a batch of chunk records for the file task by index range (bounds inclusive)."""
        result = await self.db.execute(
            select(DataSynthesisChunkInstance)
            .where(
                DataSynthesisChunkInstance.synthesis_file_instance_id == file_task_id,
                DataSynthesisChunkInstance.chunk_index >= start_index,
                DataSynthesisChunkInstance.chunk_index <= end_index,
            )
            .order_by(DataSynthesisChunkInstance.chunk_index.asc())
        )
        return list(result.scalars().all())

    async def _increment_processed_chunks(self, file_task_id: str, delta: int) -> None:
        result = await self.db.execute(
            select(DataSynthesisFileInstance).where(
                DataSynthesisFileInstance.id == file_task_id,
            )
        )
        file_task = result.scalar_one_or_none()
        if not file_task:
            logger.error(f"Failed to increment processed_chunks: file_task {file_task_id} not found")
            return
        # Plain increment
        new_value = (file_task.processed_chunks or 0) + int(delta)
        # If total_chunks is known, cap at total_chunks so the counter never overshoots
        total = file_task.total_chunks
        if isinstance(total, int) and total >= 0:
            new_value = min(new_value, total)
        file_task.processed_chunks = new_value
        await self.db.commit()
        await self.db.refresh(file_task)

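The throttling idea above, a synchronous `chat` call pushed onto the default thread pool via `run_in_executor` and gated by an `asyncio.Semaphore`, is what keeps the event loop responsive while bounding concurrent model calls. A minimal self-contained sketch of that pattern; the `slow_chat` stub and the limit of 10 are illustrative stand-ins, not the project's real client:

```python
import asyncio
import time


def slow_chat(prompt: str) -> str:
    """Stand-in for the blocking chat(client, prompt) call."""
    time.sleep(0.5)  # simulate model latency
    return f"answer to: {prompt}"


semaphore = asyncio.Semaphore(10)  # same idea as question_semaphore


async def throttled_call(prompt: str) -> str:
    async with semaphore:  # at most 10 calls in flight
        loop = asyncio.get_running_loop()
        # run_in_executor keeps the blocking call off the event loop
        return await loop.run_in_executor(None, slow_chat, prompt)


async def main() -> None:
    answers = await asyncio.gather(*(throttled_call(f"q{i}") for i in range(50)))
    print(f"{len(answers)} answers")


if __name__ == "__main__":
    asyncio.run(main())
```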
View File

@@ -1,71 +1,138 @@
from app.module.generation.schema.generation import SynthesisType

QUESTION_GENERATOR_PROMPT=f"""# Role: 文本问题生成专家

## Profile:
- Description: 你是一名专业的文本分析与问题设计专家,能够从复杂文本中提炼关键信息并产出可用于模型微调的高质量问题集合。
- Input Length: {{textLength}}
- Output Goal: 生成不少于 {{number}} 个高质量问题,用于构建问答训练数据集。

## Skills:
1. 能够全面理解原文内容,识别核心概念、事实与逻辑结构。
2. 擅长设计具有明确答案指向性的问题,覆盖文本多个侧面。
3. 善于控制问题难度与类型,保证多样性与代表性。
4. 严格遵守格式规范,确保输出可直接用于程序化处理。

## Workflow:
1. **文本解析**:通读全文,分段识别关键实体、事件、数值与结论。
2. **问题设计**:基于信息密度和重要性选择最佳提问切入点。
3. **质量检查**:逐条校验问题,确保:
   - 问题答案可在原文中直接找到依据。
   - 问题之间主题不重复、角度不雷同。
   - 语言表述准确、无歧义且符合常规问句形式。

## Constraints:
1. 所有问题必须严格依据原文内容,不得添加外部信息或假设情境。
2. 问题需覆盖文本的不同主题、层级或视角,避免集中于单一片段。
3. 禁止输出与材料元信息相关的问题(如作者、章节、目录等)。
4. 提问时请假设没有相应的文章可供参考,因此不要在问题中使用"这个""这些"等指示代词,也不得包含“报告/文章/文献/表格中提到”等表述。
5. 输出不少于 {{number}} 个问题,问题语言与原文主要语言保持一致。

## Output Format:
- 使用合法的 JSON 数组,仅包含字符串元素。
- 字段必须使用英文双引号。
- 严格遵循以下结构:
```
["问题1", "问题2", "..."]
```

## Output Example:
```
["人工智能伦理框架应包含哪些核心要素?", "民法典对个人数据保护有哪些新规定?"]
```

## 参考原文:
{{text}}
"""

ANSWER_GENERATOR_PROMPT=f"""# Role: 微调数据生成专家

## Profile:
- Description: 你是一名微调数据生成专家,擅长基于给定内容生成准确对应的问题答案,确保答案的准确性、相关性和完整性,能够直接输出符合模型训练要求的结构化数据。

## Skills:
1. 严格基于给定内容生成答案,不添加任何外部信息
2. 答案需准确无误、逻辑通顺,与问题高度相关
3. 能够精准提取给定内容中的关键信息,并整合为自然流畅的完整答案
4. 输出结果必须符合指定的结构化格式要求

## Workflow:
1. 分析给定的参考内容,梳理核心信息和逻辑框架
2. 结合提出的具体问题,从参考内容中提取与之匹配的关键依据
3. 基于提取的依据,生成准确、详尽且符合逻辑的答案
4. 将依据内容和答案分别填入指定字段,形成结构化输出
5. 校验输出内容,确保格式正确、信息完整、无引用性表述

## Output Format:
输出格式为固定字典结构:
```json
{{
    "input": "此处填入回答问题所依据的完整参考内容",
    "output": "此处填入基于参考内容生成的准确答案"
}}
```

## Constraints:
1. `input`字段必须根据给定的参考内容填入回答问题的依据,不得更改原文含义
2. `output`字段的答案必须完全基于`input`中的内容,严禁编造、添加外部信息
3. 答案需充分详细,包含回答问题的所有必要信息,满足大模型微调训练的数据要求
4. 答案中不得出现「参考」「依据」「文献中提到」等任何引用性表述,仅呈现最终结论
5. 必须严格遵守指定的字典输出格式,不得额外添加其他内容

## Reference Content
------ 参考内容 Start ------
{{text}}
------ 参考内容 End ------

## Question
{{question}}
"""

COT_GENERATOR_PROMPT=f"""# Role: 微调数据生成专家

## Profile:
- Description: 你是一名微调数据生成专家,擅长基于给定参考内容,通过**思维链(COT)逐步推理**生成准确、完整的答案,输出符合大模型微调训练要求的结构化COT数据,还原从信息提取到结论推导的全思考路径。

## Skills:
1. 严格基于给定参考内容开展推理,不引入任何外部信息
2. 能够拆解问题逻辑,按步骤提取关键信息并推导,确保推理过程连贯、无跳跃
3. 生成的答案精准对应问题,逻辑通顺,与参考内容高度一致
4. 输出结果严格符合指定的结构化COT格式要求

## Workflow:
1. 分析给定参考内容,梳理核心信息、概念及逻辑关联
2. 结合具体问题,明确推理起点与目标,划定参考内容中的相关信息范围
3. 分步推导:提取关键信息→推导中间结论→结合更多信息完善逻辑→形成最终结论
4. 将完整推理过程、最终答案填入指定字段,生成结构化COT数据
5. 校验:确保推理每一步均对应参考内容,无编造信息,格式合规,无引用性表述

## Output Format:
输出固定JSON结构,包含思维链推理、最终答案两部分:
```json
{{
    "chain_of_thought": "基于参考内容逐步推理的完整思维链,详述每一步提取的信息和推导的逻辑过程",
    "output": "此处填入基于思维链推理得出的准确、详细的最终结论"
}}
```

## Constraints:
1. `chain_of_thought`字段需还原完整推理路径,每一步推导均需对应`Reference Content`中的具体内容,严禁主观臆断
2. `output`字段的答案必须完全来源于`Reference Content`和`chain_of_thought`的推导,不添加任何外部信息,满足大模型微调对数据质量的要求
3. 整个输出中不得出现「参考」「依据」「文献中提到」等引用性表述,仅呈现推理过程与结论
4. 必须严格遵守指定JSON格式,字段顺序固定,无额外解释或标记内容

## Reference Content
------ 参考内容 Start ------
{{text}}
------ 参考内容 End ------

## Question
{{question}}
"""


def get_prompt(synth_type: SynthesisType):
    if synth_type == SynthesisType.QA:
        return ANSWER_GENERATOR_PROMPT
    elif synth_type == SynthesisType.COT:
        return COT_GENERATOR_PROMPT
    elif synth_type == SynthesisType.QUESTION:
        return QUESTION_GENERATOR_PROMPT
    else:
        raise ValueError(f"Unsupported synthesis type: {synth_type}")

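Because the templates are f-strings, the doubled braces survive as literal `{text}`, `{number}`, `{textLength}` and `{question}` placeholders, which GenerationService fills with plain `str.replace` rather than `str.format`. A small sketch of that substitution; the sample chunk text is invented:

```python
from app.module.generation.service.prompt import QUESTION_GENERATOR_PROMPT

chunk_text = "民法典对个人数据保护作出了多项新规定……"  # invented sample chunk
number = 3

prompt = (
    QUESTION_GENERATOR_PROMPT
    .replace("{text}", chunk_text)
    .replace("{number}", str(number))
    .replace("{textLength}", str(len(chunk_text)))
)
print(prompt)
```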
View File

@@ -0,0 +1,169 @@
import os
from typing import List, Optional, Tuple

from langchain_core.documents import Document
from langchain_text_splitters import (
    RecursiveCharacterTextSplitter,
    MarkdownHeaderTextSplitter
)


class DocumentSplitter:
    """
    Document splitter, enhanced to identify the document type from metadata first.

    Key features:
    1. Detect Markdown from the metadata source field (file extension) first
    2. Fall back to content-based detection when metadata is missing
    3. Optimized for CJK (Chinese/Japanese/Korean) languages
    """

    def __init__(
        self,
        chunk_size: int = 2000,
        chunk_overlap: int = 200,
        is_cjk_language: bool = True,
        markdown_headers: Optional[List[Tuple[str, str]]] = None
    ):
        """
        Initialize the document splitter.

        Args:
            chunk_size: maximum length of each text chunk (default 2000 characters)
            chunk_overlap: overlap between adjacent chunks (default 200 characters)
            is_cjk_language: whether the text is a CJK-style language without word boundaries (default True)
            markdown_headers: Markdown header split rules (default: #/##/###/####)
        """
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap
        self.is_cjk_language = is_cjk_language
        # Default Markdown header split rules
        self.markdown_headers = markdown_headers or [
            ("#", "header_1"),
            ("##", "header_2"),
            ("###", "header_3"),
            ("####", "header_4"),
        ]
        # Initialize the underlying text splitter
        self.text_splitter = self._create_text_splitter()

    def _create_text_splitter(self) -> RecursiveCharacterTextSplitter:
        """Create the recursive character splitter (internal)."""
        # CJK separator list, ordered by priority
        if self.is_cjk_language:
            separators = [
                "\n\n", "\n",  # paragraphs / line breaks (highest priority)
                "。", ".",  # full stops (Chinese / English)
                "!", "!",  # exclamation marks (Chinese / English)
                "?", "?",  # question marks (Chinese / English)
                ";", ";",  # semicolons (Chinese / English)
                ",", ",",  # commas (Chinese / English)
                "、",  # enumeration comma (Chinese)
                ":", ":",  # colons (Chinese / English)
                " ",  # spaces
                "\u200b", "",  # zero-width space / character-level fallback
            ]
        else:
            separators = ["\n\n", "\n", " ", ".", "!", "?", ";", ":", ",", ""]
        return RecursiveCharacterTextSplitter(
            chunk_size=self.chunk_size,
            chunk_overlap=self.chunk_overlap,
            separators=separators,
            length_function=len,
            is_separator_regex=False
        )

    @staticmethod
    def _is_markdown(doc: Document) -> bool:
        """
        Decide whether the document is Markdown, metadata first.

        Rule: check whether the extension in the metadata source field is .md/.markdown/.mdx, etc.
        """
        # Read the source field (case-insensitive)
        source = doc.metadata.get("source", "").lower()
        if not source:
            return False
        # Extract the file extension
        ext = os.path.splitext(source)[-1].lower()
        # Common Markdown extensions
        md_ext = [".md", ".markdown", ".mdx", ".mkd", ".mkdown"]
        return ext in md_ext

    def split(self, documents: List[Document], is_markdown: bool = False) -> List[Document]:
        """
        Core split method.

        Args:
            documents: list of Documents to split
            is_markdown: whether the input is Markdown (default False)

        Returns:
            the list of split Documents
        """
        if not documents:
            return []
        # Markdown: split by headers first, then by characters
        if is_markdown:
            # Initialize the Markdown header splitter
            md_splitter = MarkdownHeaderTextSplitter(
                headers_to_split_on=self.markdown_headers,
                strip_headers=True,
                return_each_line=False
            )
            # Split by headers and inherit the original metadata
            md_chunks = []
            for doc in documents:
                chunks = md_splitter.split_text(doc.page_content)
                for chunk in chunks:
                    chunk.metadata.update(doc.metadata)
                md_chunks.extend(chunks)
            # Character-split the header-level chunks
            final_chunks = self.text_splitter.split_documents(md_chunks)
        # Plain text: split directly
        else:
            final_chunks = self.text_splitter.split_documents(documents)
        return final_chunks

    # Core auto-split entry point (metadata first)
    @classmethod
    def auto_split(
        cls,
        documents: List[Document],
        chunk_size: int = 2000,
        chunk_overlap: int = 200
    ) -> List[Document]:
        """
        Convenience method: auto-detect the document type and split (metadata first).

        Takes just three arguments; no instance setup required.

        Args:
            documents: list of Documents to split
            chunk_size: maximum length of each text chunk (default 2000 characters)
            chunk_overlap: overlap between adjacent chunks (default 200 characters)

        Returns:
            the list of split Documents
        """
        if not documents:
            return []
        # Build a splitter instance (CJK optimization on by default)
        splitter = cls(
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
            is_cjk_language=True
        )
        # Auto-detect the document type (metadata first)
        is_md = splitter._is_markdown(documents[0])
        # Split according to the detection result
        return splitter.split(documents, is_markdown=is_md)

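For reference, a minimal usage sketch of `auto_split`; the sample Document is invented, and the `source` metadata field is what drives Markdown detection:

```python
from langchain_core.documents import Document

from app.module.shared.common.text_split import DocumentSplitter

# Invented sample; metadata["source"] is what _is_markdown inspects
docs = [
    Document(
        page_content="# 标题\n\n" + "这是一段用于演示的中文文本。" * 80,
        metadata={"source": "demo.md"},
    )
]

chunks = DocumentSplitter.auto_split(docs, chunk_size=500, chunk_overlap=50)
print(len(chunks), chunks[0].metadata)  # source (and header) metadata is carried onto each chunk
```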
View File

@@ -14,7 +14,8 @@ def call_openai_style_model(base_url, api_key, model_name, prompt, **kwargs):
    )
    return response.choices[0].message.content


def extract_json_substring(raw: str) -> str:
    """Extract the most likely JSON fragment from a raw LLM reply.

    Approach:
@@ -22,11 +23,21 @@ def _extract_json_substring(raw: str) -> str:
    - look for the first '{' or '[' in the text as the JSON start;
    - then scan backwards for the last '}' or ']' as the end;
    - fall back to the original string when no suitable boundary is found.
    - some models embed internal reasoning as <think>...</think>, which is stripped before parsing.

    The extracted span is not guaranteed to be valid JSON, but this markedly improves the json.loads success rate.
    """
    if not raw:
        return raw
    # First remove every <think>...</think> block (including multi-line ones)
    try:
        import re
        raw = re.sub(r"<think>[\s\S]*?</think>", "", raw, flags=re.IGNORECASE)
    except Exception:
        # If the regex fails, keep going with the original text
        pass
    start = None
    end = None

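A quick illustration of the intended behavior, using a made-up model reply that mixes a `<think>` block, prose, and the JSON payload:

```python
import json

from app.module.shared.util.model_chat import extract_json_substring

# Made-up reply: <think> noise plus prose wrapped around a JSON array
raw = '<think>梳理要点……</think>好的,问题如下:\n["问题1", "问题2"] 希望有帮助。'

cleaned = extract_json_substring(raw)
print(json.loads(cleaned))  # ['问题1', '问题2']
```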
View File

@@ -2,24 +2,20 @@ USE datamate;
-- ===============================
-- t_data_synth_instances (数据合成任务表)
create table if not exists t_data_synth_instances
(
    id               VARCHAR(36) CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci PRIMARY KEY COMMENT 'UUID',
    name             VARCHAR(255) NOT NULL COMMENT '任务名称',
    description      TEXT COMMENT '任务描述',
    status           VARCHAR(20) COMMENT '任务状态',
    synth_type       VARCHAR(20) NOT NULL COMMENT '合成类型',
    progress         INT DEFAULT 0 COMMENT '任务进度(百分比)',
    synth_config     JSON NOT NULL COMMENT '合成配置',
    total_files      INT DEFAULT 0 COMMENT '总文件数',
    processed_files  INT DEFAULT 0 COMMENT '已处理文件数',
    total_chunks     INT DEFAULT 0 COMMENT '总文本块数',
    processed_chunks INT DEFAULT 0 COMMENT '已处理文本块数',
    total_synth_data INT DEFAULT 0 COMMENT '总合成数据量',
    created_at       TIMESTAMP DEFAULT CURRENT_TIMESTAMP COMMENT '创建时间',
    updated_at       TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT '更新时间',
    created_by       VARCHAR(255) COMMENT '创建者',
@@ -34,7 +30,7 @@ create table if not exists t_data_synthesis_file_instances
    synthesis_instance_id VARCHAR(36) CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci COMMENT '数据合成任务ID',
    file_name             VARCHAR(255) NOT NULL COMMENT '文件名',
    source_file_id        VARCHAR(255) NOT NULL COMMENT '原始文件ID',
    target_file_location  VARCHAR(1000) NULL COMMENT '目标文件存储位置',
    status                VARCHAR(20) COMMENT '任务状态',
    total_chunks          INT DEFAULT 0 COMMENT '总文本块数',
    processed_chunks      INT DEFAULT 0 COMMENT '已处理文本块数',
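The shape of the new `synth_config` JSON column is defined by the `Config` / `SyntheConfig` schemas; judging from how GenerationService reads it, a plausible payload looks like the sketch below (all values are invented examples):

```python
# Illustrative synth_config payload; field values are invented
synth_config = {
    "max_qa_pairs": 1000,  # global cap on generated QA pairs; <= 0 means unlimited
    "text_split_config": {"chunk_size": 800, "chunk_overlap": 50},
    "question_synth_config": {"model_id": "<model-uuid>", "number": 5},
    "answer_synth_config": {"model_id": "<model-uuid>"},
}
```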