feat: optimize the question generation process and COT data generation (#169)

* fix(chart): update Helm chart helpers and values for improved configuration

* feat(SynthesisTaskTab): enhance task table with tooltip support and improved column widths

* feat(CreateTask, SynthFileTask): improve task creation and detail view with enhanced payload handling and UI updates

* feat(SynthFileTask): enhance file display with progress tracking and delete action

* feat(SynthFileTask): enhance file display with progress tracking and delete action

* feat(SynthDataDetail): add delete action for chunks with confirmation prompt

* feat(SynthDataDetail): update edit and delete buttons to icon-only format

* feat(SynthDataDetail): add confirmation modals for chunk and synthesis data deletion

* feat(DocumentSplitter): add enhanced document splitting functionality with CJK support and metadata detection

* feat(DataSynthesis): refactor data synthesis models and update task handling logic

* feat(DataSynthesis): streamline synthesis task handling and enhance chunk processing logic

* feat(DataSynthesis): refactor data synthesis models and update task handling logic

* fix(generation_service): ensure processed chunks are incremented regardless of question generation success

* feat(CreateTask): enhance task creation with new synthesis templates and improved configuration options

* feat(CreateTask): enhance task creation with new synthesis templates and improved configuration options

* feat(CreateTask): enhance task creation with new synthesis templates and improved configuration options

* feat(CreateTask): enhance task creation with new synthesis templates and improved configuration options
Author: Dallas98
Date: 2025-12-18 16:51:18 +08:00 (committed by GitHub)
parent 761f7f6a51
commit e0e9b1d94d
14 changed files with 1362 additions and 571 deletions


@@ -1,138 +1,477 @@
import asyncio
import json
import uuid
from pathlib import Path
from langchain_core.language_models import BaseChatModel
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.db.models.data_synthesis import (
DataSynthInstance,
DataSynthesisFileInstance,
DataSynthesisChunkInstance,
SynthesisData,
)
from app.db.models.dataset_management import DatasetFiles
from app.db.session import logger
from app.module.generation.schema.generation import Config, SyntheConfig
from app.module.generation.service.prompt import (
QUESTION_GENERATOR_PROMPT,
ANSWER_GENERATOR_PROMPT,
)
from app.module.shared.common.document_loaders import load_documents
from app.module.shared.common.text_split import DocumentSplitter
from app.module.shared.util.model_chat import extract_json_substring
from app.module.system.service.common_service import chat, get_model_by_id, get_chat_client
class GenerationService:
def __init__(self, db: AsyncSession):
self.db = db
# Global concurrency semaphores: at most 10 concurrent question calls and 100 concurrent answer calls at any time
self.question_semaphore = asyncio.Semaphore(10)
self.answer_semaphore = asyncio.Semaphore(100)
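# Note: the blocking `chat` calls below are dispatched through run_in_executor onto the
# default thread pool; these semaphores bound how many such calls are in flight at once.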
async def process_task(self, task_id: str):
"""处理数据合成任务入口:根据任务ID加载任务并逐个处理源文件。"""
synthesis_task: DataSynthesisInstance | None = await self.db.get(DataSynthesisInstance, task_id)
if not synthesis_task:
synth_task: DataSynthInstance | None = await self.db.get(DataSynthInstance, task_id)
if not synth_task:
logger.error(f"Synthesis task {task_id} not found, abort processing")
return
logger.info(f"Processing synthesis task {task_id}")
file_ids = synthesis_task.source_file_id or []
logger.info(f"Start processing synthe task {task_id}")
# Read max_qa_pairs from synth_config as a global cap on total QA pairs; <= 0 or a parse error means no limit
try:
cfg = Config(**(synth_task.synth_config or {}))
max_qa_pairs = cfg.max_qa_pairs if (cfg and cfg.max_qa_pairs and cfg.max_qa_pairs > 0) else None
except Exception:
max_qa_pairs = None
# Fetch the source file IDs associated with this task
file_ids = await self._get_file_ids_for_task(task_id)
if not file_ids:
logger.warning(f"No files associated with task {task_id}, abort processing")
return
# Process files one by one
for file_id in file_ids:
try:
success = await self._process_single_file(synth_task, file_id, max_qa_pairs=max_qa_pairs)
except Exception as e:
logger.exception(f"Unexpected error when processing file {file_id} for task {task_id}: {e}")
# Make sure the corresponding file task is marked as failed
await self._mark_file_failed(str(synth_task.id), file_id, str(e))
success = False
if success:
# After each finished file, simply increment the processed_files counter
synth_task.processed_files = (synth_task.processed_files or 0) + 1
await self.db.commit()
await self.db.refresh(synth_task)
logger.info(f"Finished processing synthesis task {synth_task.id}")
# ==================== High-level file processing flow ====================
async def _process_single_file(
self,
synth_task: DataSynthInstance,
file_id: str,
max_qa_pairs: int | None = None,
) -> bool:
"""处理单个源文件:解析路径、切片、保存分块并触发 LLM 调用。"""
"""按 chunk 批量流式处理单个源文件。
流程:
1. 切片并将所有 chunk 持久化到 DB 后释放内存;
2. 从 DB 按 chunk_index 升序批量读取 chunk;
3. 对批次中的每个 chunk:先生成指定数量的问题,再基于这些问题生成答案;
4. 每成功处理完一个 chunk(即该 chunk 至少生成一条 QA)就更新一次 processed_chunks;
5. 全部完成后将文件实例标记为 completed。
"""
# 解析文件路径与配置
file_path = await self._resolve_file_path(file_id)
if not file_path:
logger.warning(f"File path not found for file_id={file_id}, skip")
await self._mark_file_failed(str(synth_task.id), file_id, "file_path_not_found")
return False
logger.info(f"Processing file_id={file_id}, path={file_path}")
try:
config = Config(**(synth_task.synth_config or {}))
except Exception as e:
logger.error(f"Invalid synth_config for task={synth_task.id}: {e}")
await self._mark_file_failed(str(synth_task.id), file_id, "invalid_synth_config")
return False
# 1. Load and split the file (the only step that holds chunks in memory)
chunks = self._load_and_split(
file_path,
config.text_split_config.chunk_size,
config.text_split_config.chunk_overlap,
)
if not chunks:
logger.warning(f"No chunks generated for file_id={file_id}")
await self._mark_file_failed(str(synth_task.id), file_id, "no_chunks_generated")
return False
logger.info(f"File {file_id} split into {len(chunks)} chunks by LangChain")
# 2. Get the file instance and persist the chunk records
file_task = await self._get_or_create_file_instance(
synthesis_task_id=str(synth_task.id),
source_file_id=file_id,
file_path=file_path,
)
if not file_task:
logger.error(
f"DataSynthesisFileInstance not found for task={synth_task.id}, file_id={file_id}"
)
await self._mark_file_failed(str(synth_task.id), file_id, "file_instance_not_found")
return False
await self._persist_chunks(synth_task, file_task, file_id, chunks)
total_chunks = len(chunks)
# Release the in-memory chunks
del chunks
# 3. Read the question/answer generation configs
question_cfg: SyntheConfig | None = config.question_synth_config
answer_cfg: SyntheConfig | None = config.answer_synth_config
if not question_cfg or not answer_cfg:
logger.error(
f"Question/Answer synth config missing for task={synth_task.id}, file={file_id}"
)
await self._mark_file_failed(str(synth_task.id), file_id, "qa_config_missing")
return False
logger.info(
f"Start QA generation for task={synth_task.id}, file={file_id}, total_chunks={total_chunks}"
)
# Build model clients for this file
question_model = await get_model_by_id(self.db, question_cfg.model_id)
answer_model = await get_model_by_id(self.db, answer_cfg.model_id)
question_chat = get_chat_client(question_model)
answer_chat = get_chat_client(answer_model)
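# Question and answer generation may use different models: each SyntheConfig carries its own model_id.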
# Read and process chunks from the DB in batches
batch_size = 20
current_index = 1
while current_index <= total_chunks:
end_index = min(current_index + batch_size - 1, total_chunks)
chunk_batch = await self._load_chunk_batch(
file_task_id=file_task.id,
start_index=current_index,
end_index=end_index,
)
if not chunk_batch:
logger.warning(
f"Empty chunk batch loaded for file={file_id}, range=[{current_index}, {end_index}]"
)
current_index = end_index + 1
continue
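# chunk_index is 1-based, and each window [current_index, end_index] is inclusive on both
# ends, matching the >= / <= filters in _load_chunk_batch.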
# Process each chunk in this batch concurrently (throttled internally by the semaphores)
async def process_one(chunk: DataSynthesisChunkInstance) -> bool:
return await self._process_single_chunk_qa(
file_task=file_task,
chunk=chunk,
question_cfg=question_cfg,
answer_cfg=answer_cfg,
question_chat=question_chat,
answer_chat=answer_chat,
synth_task_id=str(synth_task.id),
max_qa_pairs=max_qa_pairs,
)
tasks = [process_one(chunk) for chunk in chunk_batch]
await asyncio.gather(*tasks, return_exceptions=True)
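# return_exceptions=True keeps one failed chunk from cancelling the rest of the batch;
# failures are logged inside _process_single_chunk_qa.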
current_index = end_index + 1
# All chunks processed
file_task.status = "completed"
await self.db.commit()
await self.db.refresh(file_task)
return True
async def _process_single_chunk_qa(
self,
file_task: DataSynthesisFileInstance,
chunk: DataSynthesisChunkInstance,
question_cfg: SyntheConfig,
answer_cfg: SyntheConfig,
question_chat: BaseChatModel,
answer_chat: BaseChatModel,
synth_task_id: str,
max_qa_pairs: int | None = None,
) -> bool:
"""处理单个 chunk:生成问题列表,然后为每个问题生成答案并落库。
为了全局控制 QA 总量:在本方法开始处,根据 synth_task_id 查询当前已落盘的
SynthesisData 条数,如果 >= max_qa_pairs,则不再对当前 chunk 做任何 QA 生成,
并将当前文件任务标记为 completed,processed_chunks = total_chunks。
已经进入后续流程的任务(例如其它协程正在生成答案)允许自然执行完。
"""
# 如果没有全局上限配置,维持原有行为
if max_qa_pairs is not None and max_qa_pairs > 0:
from sqlalchemy import func
# Count the QA pairs already generated across the whole task
result = await self.db.execute(
select(func.count(SynthesisData.id)).where(
SynthesisData.synthesis_file_instance_id.in_(
select(DataSynthesisFileInstance.id).where(
DataSynthesisFileInstance.synthesis_instance_id == synth_task_id
)
)
)
)
current_qa_count = int(result.scalar() or 0)
if current_qa_count >= max_qa_pairs:
logger.info(
"max_qa_pairs reached: current=%s, max=%s, task_id=%s, file_task_id=%s, skip new QA generation for this chunk.",
current_qa_count,
max_qa_pairs,
synth_task_id,
file_task.id,
)
# Mark the file task as completed and treat all chunks as processed
file_task.status = "completed"
if file_task.total_chunks is not None:
file_task.processed_chunks = file_task.total_chunks
await self.db.commit()
await self.db.refresh(file_task)
return False
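# Note: this cap check is best-effort; chunks that passed the check concurrently may still
# add a few QA pairs, so the final total can slightly exceed max_qa_pairs.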
# ---- The original logic below is unchanged ----
chunk_index = chunk.chunk_index
chunk_text = chunk.chunk_content or ""
if not chunk_text.strip():
logger.warning(
f"Empty chunk text for file_task={file_task.id}, chunk_index={chunk_index}"
)
# Whether it succeeds or fails, this chunk counts as processed
try:
await self._increment_processed_chunks(file_task.id, 1)
except Exception as e:
logger.exception(
f"Failed to increment processed_chunks for file_task={file_task.id}, chunk_index={chunk_index}: {e}"
)
return False
success_any = False
# 1. Generate questions
try:
questions = await self._generate_questions_for_one_chunk(
chunk_text=chunk_text,
question_cfg=question_cfg,
question_chat=question_chat,
)
except Exception as e:
logger.error(
f"Generate questions failed for file_task={file_task.id}, chunk_index={chunk_index}: {e}"
)
questions = []
if not questions:
logger.info(
f"No questions generated for file_task={file_task.id}, chunk_index={chunk_index}"
)
else:
# 2. Generate an answer for each question and persist it
qa_success = await self._generate_answers_for_one_chunk(
file_task=file_task,
chunk=chunk,
questions=questions,
answer_cfg=answer_cfg,
answer_chat=answer_chat,
)
success_any = bool(qa_success)
# Increment processed_chunks whether or not this chunk succeeded, so the task does not stall
try:
await self._increment_processed_chunks(file_task.id, 1)
except Exception as e:
logger.exception(
f"Failed to increment processed_chunks for file_task={file_task.id}, chunk_index={chunk_index}: {e}"
)
return success_any
async def _generate_questions_for_one_chunk(
self,
chunk_text: str,
question_cfg: SyntheConfig,
question_chat: BaseChatModel,
) -> list[str]:
"""针对单个 chunk 文本,调用 question_chat 生成问题列表。"""
number = question_cfg.number or 5
# Scale the requested count by chunk length (baseline 1000 characters), with a floor of 1
number = max(int(len(chunk_text) / 1000 * number), 1)
template = getattr(question_cfg, "prompt_template", QUESTION_GENERATOR_PROMPT)
template = template if (template is not None and template.strip() != "") else QUESTION_GENERATOR_PROMPT
prompt = (
template
.replace("{text}", chunk_text)
.replace("{number}", str(number))
.replace("{textLength}", str(len(chunk_text)))
)
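# The template uses literal {text}, {number} and {textLength} placeholders substituted via
# str.replace, so stray braces inside the chunk text cannot break the prompt (as str.format could).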
async with self.question_semaphore:
loop = asyncio.get_running_loop()
raw_answer = await loop.run_in_executor(
None,
chat,
question_chat,
prompt,
)
# Parse into a list of questions
questions = self._parse_questions_from_answer(
raw_answer,
)
return questions
async def _generate_answers_for_one_chunk(
self,
file_task: DataSynthesisFileInstance,
chunk: DataSynthesisChunkInstance,
questions: list[str],
answer_cfg: SyntheConfig,
answer_chat: BaseChatModel,
) -> bool:
"""为一个 chunk 的所有问题生成答案并写入 SynthesisData。
返回:是否至少成功写入一条 QA。
"""
if not questions:
return False
chunk_text = chunk.chunk_content or ""
template = getattr(answer_cfg, "prompt_template", ANSWER_GENERATOR_PROMPT)
template = template if (template is not None and template.strip() != "") else ANSWER_GENERATOR_PROMPT
extra_vars = getattr(answer_cfg, "extra_prompt_vars", {}) or {}
success_flags: list[bool] = []
async def process_single_question(question: str):
prompt_local = template.replace("{text}", chunk_text).replace("{question}", question)
# Substitute any extra prompt variables of the form {{key}}
for k, v in extra_vars.items():
prompt_local = prompt_local.replace(f"{{{{{k}}}}}", str(v))
async with self.answer_semaphore:
loop = asyncio.get_running_loop()
answer = await loop.run_in_executor(
None,
chat,
answer_chat,
prompt_local,
)
# Default structure: mirrors ANSWER_GENERATOR_PROMPT, plus an instruction field
base_obj: dict[str, object] = {
"input": chunk_text,
"output": answer,
}
# If the model already returned JSON per ANSWER_GENERATOR_PROMPT, try to parse it and add instruction on top
parsed_obj: dict[str, object] | None = None
if isinstance(answer, str):
cleaned = extract_json_substring(answer)
try:
parsed = json.loads(cleaned)
if isinstance(parsed, dict):
parsed_obj = parsed
except Exception:
parsed_obj = None
if parsed_obj is not None:
parsed_obj["instruction"] = question
data_obj = parsed_obj
else:
base_obj["instruction"] = question
data_obj = base_obj
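# Resulting record shape, e.g.: {"input": <chunk text>, "output": <answer>, "instruction": <question>}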
record = SynthesisData(
id=str(uuid.uuid4()),
data=data_obj,
synthesis_file_instance_id=file_task.id,
chunk_instance_id=chunk.id,
)
self.db.add(record)
success_flags.append(True)
tasks = [process_single_question(q) for q in questions]
await asyncio.gather(*tasks, return_exceptions=True)
if success_flags:
await self.db.commit()
return True
return False
@staticmethod
def _parse_questions_from_answer(
raw_answer: str,
) -> list[str]:
"""从大模型返回中解析问题数组。"""
if not raw_answer:
return []
cleaned = extract_json_substring(raw_answer)
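# extract_json_substring presumably strips markdown fences and surrounding prose so that
# json.loads below receives a bare JSON value.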
try:
data = json.loads(cleaned)
except Exception as e:
logger.error(
f"Failed to parse question list JSON: {e}"
)
return []
if isinstance(data, list):
return [str(q) for q in data if isinstance(q, str) and q.strip()]
# Fallback: a single string
if isinstance(data, str) and data.strip():
return [data.strip()]
return []
# ==================== Pre-existing helpers (file path / splitting / persistence) ====================
async def _resolve_file_path(self, file_id: str) -> str | None:
"""根据文件ID查询 t_dm_dataset_files 并返回 file_path(仅 ACTIVE 文件)。"""
result = await self.db.execute(
select(DatasetFiles).where(DatasetFiles.id == file_id)
)
file_obj = result.scalar_one_or_none()
if not file_obj:
return None
return file_obj.file_path
@staticmethod
def _load_and_split(file_path: str, chunk_size: int, chunk_overlap: int):
"""使用 LangChain 加载文本并进行切片,直接返回 Document 列表。
Args:
file_path: 待切片的文件路径
chunk_size: 切片大小
chunk_overlap: 切片重叠大小
"""
try:
docs = load_documents(file_path)
split_docs = DocumentSplitter.auto_split(docs, chunk_size, chunk_overlap)
return split_docs
except Exception as e:
logger.error(f"Error loading or splitting file {file_path}: {e}")
raise
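# Note: per the commit messages above, DocumentSplitter.auto_split is introduced in this PR
# with CJK-aware splitting and metadata detection.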
async def _persist_chunks(
self,
synthesis_task: DataSynthInstance,
file_task: DataSynthesisFileInstance,
file_id: str,
chunks,
@@ -164,201 +503,10 @@ class GenerationService:
await self.db.commit()
await self.db.refresh(file_task)
async def _get_or_create_file_instance(
self,
synthesis_task_id: str,
source_file_id: str,
file_path: str,
) -> DataSynthesisFileInstance | None:
"""根据任务ID和原始文件ID,查找或创建对应的 DataSynthesisFileInstance 记录。
@@ -374,32 +522,9 @@ class GenerationService:
)
)
file_task = result.scalar_one_or_none()
if file_task is not None:
return file_task
async def _mark_file_failed(self, synth_task_id: str, file_id: str, reason: str | None = None) -> None:
"""将指定任务下的单个文件任务标记为失败状态,兜底错误处理。
- 如果找到对应的 DataSynthesisFileInstance,则更新其 status="failed"
@@ -409,14 +534,14 @@ class GenerationService:
try:
result = await self.db.execute(
select(DataSynthesisFileInstance).where(
DataSynthesisFileInstance.synthesis_instance_id == synth_task_id,
DataSynthesisFileInstance.source_file_id == file_id,
)
)
file_task = result.scalar_one_or_none()
if not file_task:
logger.warning(
f"Failed to mark file as failed: no DataSynthesisFileInstance found for task={synthesis_task_id}, file_id={file_id}, reason={reason}"
f"Failed to mark file as failed: no DataSynthesisFileInstance found for task={synth_task_id}, file_id={file_id}, reason={reason}"
)
return
@@ -424,10 +549,72 @@ class GenerationService:
await self.db.commit()
await self.db.refresh(file_task)
logger.info(
f"Marked file task as failed for task={synthesis_task_id}, file_id={file_id}, reason={reason}"
f"Marked file task as failed for task={synth_task_id}, file_id={file_id}, reason={reason}"
)
except Exception as e:
# Last-resort logging: keep the exception from propagating and affecting other files
logger.exception(
f"Unexpected error when marking file failed for task={synthesis_task_id}, file_id={file_id}, original_reason={reason}, error={e}"
f"Unexpected error when marking file failed for task={synth_task_id}, file_id={file_id}, original_reason={reason}, error={e}"
)
async def _get_file_ids_for_task(self, synth_task_id: str):
"""根据任务ID查询关联的文件原始ID列表"""
result = await self.db.execute(
select(DataSynthesisFileInstance.source_file_id)
.where(DataSynthesisFileInstance.synthesis_instance_id == synth_task_id)
)
file_ids = result.scalars().all()
return file_ids
# ========== New helpers: chunk counting, batch loading, and safe processed_chunks updates ==========
async def _count_chunks_for_file(self, synth_file_instance_id: str) -> int:
"""统计指定任务与文件下的 chunk 总数。"""
from sqlalchemy import func
result = await self.db.execute(
select(func.count(DataSynthesisChunkInstance.id)).where(
DataSynthesisChunkInstance.synthesis_file_instance_id == synth_file_instance_id
)
)
return int(result.scalar() or 0)
async def _load_chunk_batch(
self,
file_task_id: str,
start_index: int,
end_index: int,
) -> list[DataSynthesisChunkInstance]:
"""按索引范围加载指定文件任务下的一批 chunk 记录(含边界)。"""
result = await self.db.execute(
select(DataSynthesisChunkInstance)
.where(
DataSynthesisChunkInstance.synthesis_file_instance_id == file_task_id,
DataSynthesisChunkInstance.chunk_index >= start_index,
DataSynthesisChunkInstance.chunk_index <= end_index,
)
.order_by(DataSynthesisChunkInstance.chunk_index.asc())
)
return list(result.scalars().all())
async def _increment_processed_chunks(self, file_task_id: str, delta: int) -> None:
result = await self.db.execute(
select(DataSynthesisFileInstance).where(
DataSynthesisFileInstance.id == file_task_id,
)
)
file_task = result.scalar_one_or_none()
if not file_task:
logger.error(f"Failed to increment processed_chunks: file_task {file_task_id} not found")
return
# Plain increment
new_value = (file_task.processed_chunks or 0) + int(delta)
# If total_chunks is set, cap at total_chunks so the counter never overshoots
total = file_task.total_chunks
if isinstance(total, int) and total >= 0:
new_value = min(new_value, total)
file_task.processed_chunks = new_value
await self.db.commit()
await self.db.refresh(file_task)
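# A minimal usage sketch (illustrative only: `async_session_factory` is an assumed name for
# the app's async session factory, not part of this change):
#
#   from app.db.session import async_session_factory
#
#   async def run(task_id: str) -> None:
#       async with async_session_factory() as db:
#           await GenerationService(db).process_task(task_id)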