import asyncio
import json
import uuid
from pathlib import Path

from langchain_text_splitters import RecursiveCharacterTextSplitter
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession

from app.db.models.data_synthesis import (
    DataSynthesisInstance,
    DataSynthesisFileInstance,
    DataSynthesisChunkInstance,
    SynthesisData,
)
from app.db.models.dataset_management import DatasetFiles
from app.db.models.model_config import get_model_by_id
from app.db.session import logger
from app.module.shared.util.model_chat import _extract_json_substring
from app.module.system.service.common_service import get_chat_client, chat
from app.common.document_loaders import load_documents


class GenerationService:
    def __init__(self, db: AsyncSession):
        self.db = db
        # A single AsyncSession must not be used from several coroutines at once;
        # this lock serializes the DB writes issued by concurrent LLM callbacks.
        self._db_lock = asyncio.Lock()

    async def process_task(self, task_id: str):
        """Entry point for a data-synthesis task: load the task by ID and process its source files one by one."""
        synthesis_task: DataSynthesisInstance | None = await self.db.get(DataSynthesisInstance, task_id)
        if not synthesis_task:
            logger.error(f"Synthesis task {task_id} not found, abort processing")
            return

        logger.info(f"Processing synthesis task {task_id}")
        file_ids = synthesis_task.source_file_id or []

        # Resolve the model config and build the chat client.
        model_result = await get_model_by_id(self.db, str(synthesis_task.model_id))
        if not model_result:
            logger.error(
                f"Model config not found for id={synthesis_task.model_id}, abort task {synthesis_task.id}"
            )
            return
        chat_client = get_chat_client(model_result)

        # Semaphore bounding concurrency (at most 10 in-flight LLM calls across the whole task).
        semaphore = asyncio.Semaphore(10)

        # Process the files one at a time.
        for file_id in file_ids:
            try:
                success = await self._process_single_file(
                    synthesis_task=synthesis_task,
                    file_id=file_id,
                    chat_client=chat_client,
                    semaphore=semaphore,
                )
            except Exception as e:
                logger.exception(f"Unexpected error when processing file {file_id} for task {task_id}: {e}")
                # Make sure the corresponding file task ends up marked as failed.
                await self._mark_file_failed(str(synthesis_task.id), file_id, str(e))
                success = False

            if success:
                # Bump the processed_files counter after each completed file.
                synthesis_task.processed_files = (synthesis_task.processed_files or 0) + 1
                await self.db.commit()
                await self.db.refresh(synthesis_task)

        logger.info(f"Finished processing synthesis task {synthesis_task.id}")

    async def _process_single_file(
        self,
        synthesis_task: DataSynthesisInstance,
        file_id: str,
        chat_client,
        semaphore: asyncio.Semaphore,
    ) -> bool:
        """Process one source file: resolve its path, split it, persist the chunks, and trigger the LLM calls."""
        file_path = await self._resolve_file_path(file_id)
        if not file_path:
            logger.warning(f"File path not found for file_id={file_id}, skip")
            await self._mark_file_failed(str(synthesis_task.id), file_id, "file_path_not_found")
            return False

        logger.info(f"Processing file_id={file_id}, path={file_path}")
        split_cfg = synthesis_task.text_split_config or {}
        synthesis_cfg = synthesis_task.synthesis_config or {}
        chunk_size = int(split_cfg.get("chunk_size", 800))
        chunk_overlap = int(split_cfg.get("chunk_overlap", 50))

        # Load and split the document.
        try:
            chunks = self._load_and_split(file_path, chunk_size, chunk_overlap)
        except Exception as e:
            logger.error(f"Failed to load/split file {file_path}: {e}")
            await self._mark_file_failed(str(synthesis_task.id), file_id, f"load_split_error: {e}")
            return False

        if not chunks:
            logger.warning(f"No chunks generated for file_id={file_id}")
            await self._mark_file_failed(str(synthesis_task.id), file_id, "no_chunks_generated")
            return False
        logger.info(f"File {file_id} split into {len(chunks)} chunks by LangChain")

        # Persist the file-task record plus one record per chunk.
        file_task = await self._get_or_create_file_instance(
            synthesis_task_id=str(synthesis_task.id),
            source_file_id=file_id,
            file_path=file_path,
        )
        await self._persist_chunks(synthesis_task, file_task, file_id, chunks)

        # Call the LLM concurrently, one call per chunk.
        await self._invoke_llm_for_chunks(
            synthesis_task=synthesis_task,
            file_id=file_id,
            chunks=chunks,
            synthesis_cfg=synthesis_cfg,
            chat_client=chat_client,
            semaphore=semaphore,
        )

        # Reaching this point means splitting and the LLM invocation flow raised no
        # exceptions for this file, so mark it as completed.
        file_task.status = "completed"
        await self.db.commit()
        await self.db.refresh(file_task)
        return True
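    # Example shapes of the two configs consumed above and in _build_qa_prompt
    # (a sketch covering only the keys this service actually reads; the real JSON
    # columns may carry additional fields):
    #
    #   text_split_config = {"chunk_size": 800, "chunk_overlap": 50}
    #   synthesis_config  = {
    #       "prompt_template": "...{document}...{synthesis_count}...",
    #       "synthesis_count": 5,
    #   }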
    async def _persist_chunks(
        self,
        synthesis_task: DataSynthesisInstance,
        file_task: DataSynthesisFileInstance,
        file_id: str,
        chunks,
    ) -> None:
        """Save the split results to t_data_synthesis_chunk_instances and update the file-level chunk count."""
        for idx, doc in enumerate(chunks, start=1):
            # Copy the original Document.metadata first, then layer the task-related
            # fields on top, so existing metadata is not overwritten.
            base_metadata = dict(getattr(doc, "metadata", {}) or {})
            base_metadata.update(
                {
                    "task_id": str(synthesis_task.id),
                    "file_id": file_id,
                }
            )
            chunk_record = DataSynthesisChunkInstance(
                id=str(uuid.uuid4()),
                synthesis_file_instance_id=file_task.id,
                chunk_index=idx,
                chunk_content=doc.page_content,
                chunk_metadata=base_metadata,
            )
            self.db.add(chunk_record)

        # Update the chunk count on the file task.
        file_task.total_chunks = len(chunks)
        file_task.status = "processing"
        await self.db.commit()
        await self.db.refresh(file_task)

    async def _invoke_llm_for_chunks(
        self,
        synthesis_task: DataSynthesisInstance,
        file_id: str,
        chunks,
        synthesis_cfg: dict,
        chat_client,
        semaphore: asyncio.Semaphore,
    ) -> None:
        """Call the LLM concurrently, once per chunk."""
        # Each answer must be linked back to its chunk, so chunk_index is kept here.
        tasks = [
            self._call_llm(doc, file_id, idx, synthesis_task, synthesis_cfg, chat_client, semaphore)
            for idx, doc in enumerate(chunks, start=1)
        ]
        await asyncio.gather(*tasks, return_exceptions=True)

    async def _call_llm(
        self,
        doc,
        file_id: str,
        idx: int,
        synthesis_task,
        synthesis_cfg: dict,
        chat_client,
        semaphore: asyncio.Semaphore,
    ):
        """One LLM call, with concurrency control.

        Notes:
        - A semaphore caps global concurrency (currently 10).
        - The synchronous chat call runs in a thread pool so the event loop is not blocked.
        - Once the LLM returns, the answer is parsed as JSON and written to SynthesisData
          in bulk, and the file-level processed_chunks / progress fields are updated.
        """
        async with semaphore:
            prompt = self._build_qa_prompt(doc.page_content, synthesis_cfg)
            try:
                loop = asyncio.get_running_loop()
                answer = await loop.run_in_executor(None, chat, chat_client, prompt)
                logger.debug(
                    f"Generated QA for task={synthesis_task.id}, file={file_id}, chunk={idx}"
                )
                await self._handle_llm_answer(
                    synthesis_task_id=str(synthesis_task.id),
                    file_id=file_id,
                    chunk_index=idx,
                    raw_answer=answer,
                )
                return answer
            except Exception as e:
                logger.error(
                    f"LLM generation failed for task={synthesis_task.id}, file={file_id}, chunk={idx}: {e}"
                )
                return None

    async def _resolve_file_path(self, file_id: str) -> str | None:
        """Look up t_dm_dataset_files by file ID and return its file_path, or None when the record is missing."""
        result = await self.db.execute(
            select(DatasetFiles).where(DatasetFiles.id == file_id)
        )
        file_obj = result.scalar_one_or_none()
        if not file_obj:
            return None
        return file_obj.file_path

    def _load_and_split(self, file_path: str, chunk_size: int, chunk_overlap: int):
        """Load a file with LangChain, split it, and return the resulting Document list.

        Current implementation:
        - load_documents handles text-based files such as plain text / Markdown / JSON
        - RecursiveCharacterTextSplitter performs recursive character-based splitting

        Each Document keeps its metadata so fields such as the file ID and chunk index
        can be appended later.
        """
        docs = load_documents(file_path)
        splitter = RecursiveCharacterTextSplitter(
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
            # Try these separators in order (paragraphs, lines, CJK sentence punctuation,
            # then half-width punctuation) before degrading to character-level splits.
            separators=["\n\n", "\n", "。", "!", "?", "!", "?", "\t", " "],
        )
        split_docs = splitter.split_documents(docs)
        return split_docs
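    # How the splitter defaults behave (illustrative, assuming plain running text):
    # the separators are tried in order, so text is split on blank lines first, then
    # single newlines, then sentence punctuation, with raw character cuts only as a
    # last resort. With chunk_size=800 and chunk_overlap=50, adjacent chunks share up
    # to 50 characters, so a ~2,000-character document usually yields about three
    # Documents.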
    @staticmethod
    def _build_qa_prompt(chunk: str, synthesis_cfg: dict) -> str:
        """Build the prompt for QA data synthesis.

        Contract:
        - synthesis_cfg["prompt_template"] is a string containing a {document} placeholder;
        - the current chunk content is substituted for {document}.

        When no template is provided, or the template is invalid, a built-in default
        template is used.
        """
        if not isinstance(synthesis_cfg, dict):
            synthesis_cfg = {}
        template = synthesis_cfg.get("prompt_template")
        synthesis_count = synthesis_cfg.get("synthesis_count") or 5

        # Built-in default template (assumed wording) for when no usable template is
        # configured; literal braces in the JSON example are escaped for str.format.
        if not isinstance(template, str) or not template.strip():
            template = (
                "请根据以下文档内容生成 {synthesis_count} 条问答数据,以 JSON 数组返回,"
                '每个元素形如 {{"question": "...", "answer": "..."}}。\n\n'
                "文档内容:\n{document}"
            )

        try:
            prompt = template.format(document=chunk, synthesis_count=synthesis_count)
        except Exception:
            # Defensive fallback: if format() raises (e.g. stray braces in the template),
            # fall back to simple concatenation.
            prompt = (
                f"{template}\n\n文档内容:{chunk}\n\n"
                f"请根据文档内容生成 {synthesis_count} 条符合要求的问答数据。"
            )
        return prompt

    async def _handle_llm_answer(
        self,
        synthesis_task_id: str,
        file_id: str,
        chunk_index: int,
        raw_answer: str,
    ) -> None:
        """Parse the LLM answer as JSON, bulk-save it to SynthesisData, and update file-task progress.

        Conventions:
        - raw_answer is a JSON string that may be:
          1) a single object: {"question": ..., "answer": ...}
          2) an array of objects: [{}, {}, ...]
        - It is normalized into a list; each element is written as one SynthesisData.data row.
        - The matching DataSynthesisChunkInstance is located via
          synthesis_task_id + file_id + chunk_index so that chunk_instance_id and
          synthesis_file_instance_id can be set.
        - After each chunk, the corresponding DataSynthesisFileInstance.processed_chunks
          is incremented and the progress updated proportionally.
        """
        if not raw_answer:
            return

        # 1. Pre-process the raw answer: try to carve out the most plausible JSON fragment.
        cleaned = _extract_json_substring(raw_answer)

        # 2. Parse the JSON and normalize it into a list.
        try:
            parsed = json.loads(cleaned)
        except Exception as e:
            logger.error(
                f"Failed to parse LLM answer as JSON for task={synthesis_task_id}, "
                f"file={file_id}, chunk={chunk_index}: {e}. Raw answer: {raw_answer!r}"
            )
            return

        if isinstance(parsed, dict):
            items = [parsed]
        elif isinstance(parsed, list):
            items = [p for p in parsed if isinstance(p, dict)]
        else:
            logger.error(
                f"Unexpected JSON structure from LLM answer for task={synthesis_task_id}, "
                f"file={file_id}, chunk={chunk_index}: {type(parsed)}"
            )
            return
        if not items:
            return

        # This coroutine runs concurrently (one per chunk), so take the lock before
        # touching the shared AsyncSession.
        async with self._db_lock:
            # 3. Locate the matching chunk record (one record per chunk_index).
            chunk_result = await self.db.execute(
                select(DataSynthesisChunkInstance, DataSynthesisFileInstance)
                .join(
                    DataSynthesisFileInstance,
                    DataSynthesisFileInstance.id == DataSynthesisChunkInstance.synthesis_file_instance_id,
                )
                .where(
                    DataSynthesisFileInstance.synthesis_instance_id == synthesis_task_id,
                    DataSynthesisFileInstance.source_file_id == file_id,
                    DataSynthesisChunkInstance.chunk_index == chunk_index,
                )
            )
            row = chunk_result.first()
            if not row:
                logger.error(
                    f"Chunk record not found for task={synthesis_task_id}, file={file_id}, "
                    f"chunk_index={chunk_index}, skip saving SynthesisData."
                )
                return
            chunk_instance, file_instance = row

            # 4. Bulk-insert the SynthesisData rows.
            for data_obj in items:
                record = SynthesisData(
                    id=str(uuid.uuid4()),
                    data=data_obj,
                    synthesis_file_instance_id=file_instance.id,
                    chunk_instance_id=chunk_instance.id,
                )
                self.db.add(record)

            # 5. Update the file-level processed_chunks / progress.
            file_instance.processed_chunks = (file_instance.processed_chunks or 0) + 1
            await self.db.commit()
            await self.db.refresh(file_instance)
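    # The two raw_answer shapes _handle_llm_answer accepts, per the convention above
    # (the field names inside each object are dictated by the prompt, not by this code):
    #
    #   {"question": "...", "answer": "..."}
    #   [{"question": "...", "answer": "..."}, {"question": "...", "answer": "..."}]
    #
    # Any other top-level JSON type is logged and dropped, and non-dict elements inside
    # a list are filtered out before anything is saved.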
    async def _get_or_create_file_instance(
        self,
        synthesis_task_id: str,
        source_file_id: str,
        file_path: str,
    ) -> DataSynthesisFileInstance:
        """Find or create the DataSynthesisFileInstance for a given task ID and source file ID.

        - If a record already exists (same task + same source_file_id), return it directly.
        - Otherwise create a new file-task record; file_name comes from the file path, and
          target_file_location reuses the task's result_data_location for now.
        """
        # Look for an existing file-task record first.
        result = await self.db.execute(
            select(DataSynthesisFileInstance).where(
                DataSynthesisFileInstance.synthesis_instance_id == synthesis_task_id,
                DataSynthesisFileInstance.source_file_id == source_file_id,
            )
        )
        file_task = result.scalar_one_or_none()
        if file_task is not None:
            return file_task

        # Fetch the task to obtain result_data_location.
        task = await self.db.get(DataSynthesisInstance, synthesis_task_id)
        target_location = task.result_data_location if task else ""

        # Create a new file-task record with an initial status of "processing".
        file_task = DataSynthesisFileInstance(
            id=str(uuid.uuid4()),
            synthesis_instance_id=synthesis_task_id,
            file_name=Path(file_path).name,
            source_file_id=source_file_id,
            target_file_location=target_location or "",
            status="processing",
            total_chunks=0,
            processed_chunks=0,
            created_by="system",
            updated_by="system",
        )
        self.db.add(file_task)
        await self.db.commit()
        await self.db.refresh(file_task)
        return file_task

    async def _mark_file_failed(self, synthesis_task_id: str, file_id: str, reason: str | None = None) -> None:
        """Mark a single file task under the given task as failed; last-resort error handling.

        - If the matching DataSynthesisFileInstance is found, set its status to "failed".
        - If it is not found, log a warning and return silently.
        - reason is used for logging only, to ease troubleshooting.
        """
        try:
            result = await self.db.execute(
                select(DataSynthesisFileInstance).where(
                    DataSynthesisFileInstance.synthesis_instance_id == synthesis_task_id,
                    DataSynthesisFileInstance.source_file_id == file_id,
                )
            )
            file_task = result.scalar_one_or_none()
            if not file_task:
                logger.warning(
                    f"Failed to mark file as failed: no DataSynthesisFileInstance found for "
                    f"task={synthesis_task_id}, file_id={file_id}, reason={reason}"
                )
                return
            file_task.status = "failed"
            await self.db.commit()
            await self.db.refresh(file_task)
            logger.info(
                f"Marked file task as failed for task={synthesis_task_id}, file_id={file_id}, reason={reason}"
            )
        except Exception as e:
            # Last-resort logging so the exception does not propagate and disturb other files.
            logger.exception(
                f"Unexpected error when marking file failed for task={synthesis_task_id}, "
                f"file_id={file_id}, original_reason={reason}, error={e}"
            )
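
# Minimal usage sketch, kept as a comment: `async_session_factory` is a placeholder
# for whatever async sessionmaker this app exposes (not an import that exists in this
# module), and the task ID must reference an existing DataSynthesisInstance row.
#
#     async def run_synthesis(task_id: str) -> None:
#         async with async_session_factory() as session:
#             await GenerationService(session).process_task(task_id)
#
#     asyncio.run(run_synthesis("<task-id>"))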