import asyncio
import uuid
import json
from pathlib import Path

from langchain_community.document_loaders import (
    TextLoader,
    CSVLoader,
    JSONLoader,
    UnstructuredMarkdownLoader,
    UnstructuredHTMLLoader,
    UnstructuredFileLoader,
    PyPDFLoader,
    UnstructuredWordDocumentLoader,
    UnstructuredPowerPointLoader,
    UnstructuredExcelLoader,
)
from langchain_text_splitters import RecursiveCharacterTextSplitter
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession

from app.db.models.data_synthesis import (
    DataSynthesisInstance,
    DataSynthesisFileInstance,
    DataSynthesisChunkInstance,
    SynthesisData,
)
from app.db.models.dataset_management import DatasetFiles
from app.db.models.model_config import get_model_by_id
from app.db.session import logger
from app.module.shared.util.model_chat import _extract_json_substring
from app.module.system.service.common_service import get_chat_client, chat

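# GenerationService drives LLM-based QA data synthesis: it loads a
# DataSynthesisInstance, splits each source file into chunks with LangChain,
# asks the configured chat model to generate QA pairs per chunk, and persists
# the results (file/chunk instances and SynthesisData rows) to the database.
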
class GenerationService:
    def __init__(self, db: AsyncSession):
        self.db = db

    async def process_task(self, task_id: str):
        """Entry point for a data synthesis task: load the task by ID and process its source files one by one."""
        synthesis_task: DataSynthesisInstance | None = await self.db.get(DataSynthesisInstance, task_id)
        if not synthesis_task:
            logger.error(f"Synthesis task {task_id} not found, abort processing")
            return

        logger.info(f"Processing synthesis task {task_id}")
        file_ids = synthesis_task.source_file_id or []

        # Fetch the model client
        model_result = await get_model_by_id(self.db, str(synthesis_task.model_id))
        if not model_result:
            logger.error(
                f"Model config not found for id={synthesis_task.model_id}, abort task {synthesis_task.id}"
            )
            return
        chat_client = get_chat_client(model_result)

        # Semaphore limiting concurrency (at most 10 concurrent LLM calls across the whole task)
        semaphore = asyncio.Semaphore(10)

        # Process files one at a time
        for file_id in file_ids:
            try:
                success = await self._process_single_file(
                    synthesis_task=synthesis_task,
                    file_id=file_id,
                    chat_client=chat_client,
                    semaphore=semaphore,
                )
            except Exception as e:
                logger.exception(f"Unexpected error when processing file {file_id} for task {task_id}: {e}")
                # Make sure the corresponding file task is marked as failed
                await self._mark_file_failed(str(synthesis_task.id), file_id, str(e))
                success = False

            if success:
                # After each completed file, simply increment the processed_files counter
                synthesis_task.processed_files = (synthesis_task.processed_files or 0) + 1
                await self.db.commit()
                await self.db.refresh(synthesis_task)

        logger.info(f"Finished processing synthesis task {synthesis_task.id}")

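    # Illustrative usage (a sketch, not part of the service): assumes an
    # AsyncSession and the id of a DataSynthesisInstance created elsewhere.
    #
    #     async def run_synthesis(session: AsyncSession, task_id: str) -> None:
    #         service = GenerationService(session)
    #         await service.process_task(task_id)
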
    async def _process_single_file(
        self,
        synthesis_task: DataSynthesisInstance,
        file_id: str,
        chat_client,
        semaphore: asyncio.Semaphore,
    ) -> bool:
        """Process a single source file: resolve its path, split it into chunks, persist the chunks, and trigger the LLM calls."""
        file_path = await self._resolve_file_path(file_id)
        if not file_path:
            logger.warning(f"File path not found for file_id={file_id}, skip")
            await self._mark_file_failed(str(synthesis_task.id), file_id, "file_path_not_found")
            return False

        logger.info(f"Processing file_id={file_id}, path={file_path}")

        split_cfg = synthesis_task.text_split_config or {}
        synthesis_cfg = synthesis_task.synthesis_config or {}
        chunk_size = int(split_cfg.get("chunk_size", 800))
        chunk_overlap = int(split_cfg.get("chunk_overlap", 50))
        # Load and split
        try:
            chunks = self._load_and_split(file_path, chunk_size, chunk_overlap)
        except Exception as e:
            logger.error(f"Failed to load/split file {file_path}: {e}")
            await self._mark_file_failed(str(synthesis_task.id), file_id, f"load_split_error: {e}")
            return False

        if not chunks:
            logger.warning(f"No chunks generated for file_id={file_id}")
            await self._mark_file_failed(str(synthesis_task.id), file_id, "no_chunks_generated")
            return False

        logger.info(f"File {file_id} split into {len(chunks)} chunks by LangChain")

        # Persist the file task record and its chunk records
        file_task = await self._get_or_create_file_instance(
            synthesis_task_id=str(synthesis_task.id),
            source_file_id=file_id,
            file_path=file_path,
        )
        await self._persist_chunks(synthesis_task, file_task, file_id, chunks)

        # Call the LLM concurrently for each chunk
        await self._invoke_llm_for_chunks(
            synthesis_task=synthesis_task,
            file_id=file_id,
            chunks=chunks,
            synthesis_cfg=synthesis_cfg,
            chat_client=chat_client,
            semaphore=semaphore,
        )

        # Reaching this point means splitting and the LLM calls raised no exception, so mark the file as completed
        file_task.status = "completed"
        await self.db.commit()
        await self.db.refresh(file_task)

        return True

    async def _persist_chunks(
        self,
        synthesis_task: DataSynthesisInstance,
        file_task: DataSynthesisFileInstance,
        file_id: str,
        chunks,
    ) -> None:
        """Save the split chunks to t_data_synthesis_chunk_instances and update the file-level chunk count."""
        for idx, doc in enumerate(chunks, start=1):
            # Copy the original Document.metadata first, then append the task-related fields,
            # so the loader's own metadata is not overwritten
            base_metadata = dict(getattr(doc, "metadata", {}) or {})
            base_metadata.update(
                {
                    "task_id": str(synthesis_task.id),
                    "file_id": file_id,
                }
            )

            chunk_record = DataSynthesisChunkInstance(
                id=str(uuid.uuid4()),
                synthesis_file_instance_id=file_task.id,
                chunk_index=idx,
                chunk_content=doc.page_content,
                chunk_metadata=base_metadata,
            )
            self.db.add(chunk_record)

        # Update the file task's chunk count
        file_task.total_chunks = len(chunks)
        file_task.status = "processing"

        await self.db.commit()
        await self.db.refresh(file_task)

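    # For reference, a persisted chunk_metadata typically contains the loader's own
    # fields plus the two keys added above, e.g. (illustrative values):
    #     {"source": "/datasets/report.txt", "task_id": "<task uuid>", "file_id": "<file id>"}
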
    async def _invoke_llm_for_chunks(
        self,
        synthesis_task: DataSynthesisInstance,
        file_id: str,
        chunks,
        synthesis_cfg: dict,
        chat_client,
        semaphore: asyncio.Semaphore,
    ) -> None:
        """Call the LLM concurrently for each chunk to generate data."""
        # The answers must be mapped back to their chunks, so keep the chunk_index here
        tasks = [
            self._call_llm(doc, file_id, idx, synthesis_task, synthesis_cfg, chat_client, semaphore)
            for idx, doc in enumerate(chunks, start=1)
        ]
        await asyncio.gather(*tasks, return_exceptions=True)

    async def _call_llm(
        self,
        doc,
        file_id: str,
        idx: int,
        synthesis_task,
        synthesis_cfg: dict,
        chat_client,
        semaphore: asyncio.Semaphore,
    ):
        """Single LLM call with concurrency control.

        Notes:
        - A semaphore caps global concurrency (currently 10).
        - The synchronous chat call runs in a thread pool so it does not block the event loop.
        - Once the LLM response arrives, it is parsed as JSON and written to SynthesisData in bulk,
          and the file-level processed_chunks count is updated.
        """
        async with semaphore:
            prompt = self._build_qa_prompt(doc.page_content, synthesis_cfg)
            try:
                loop = asyncio.get_running_loop()
                answer = await loop.run_in_executor(None, chat, chat_client, prompt)
                logger.debug(
                    f"Generated QA for task={synthesis_task.id}, file={file_id}, chunk={idx}"
                )
                await self._handle_llm_answer(
                    synthesis_task_id=str(synthesis_task.id),
                    file_id=file_id,
                    chunk_index=idx,
                    raw_answer=answer,
                )
                return answer
            except Exception as e:
                logger.error(
                    f"LLM generation failed for task={synthesis_task.id}, file={file_id}, chunk={idx}: {e}"
                )
                return None

    async def _resolve_file_path(self, file_id: str) -> str | None:
        """Look up t_dm_dataset_files by file ID and return its file_path, or None if the file is not found."""
        result = await self.db.execute(
            select(DatasetFiles).where(DatasetFiles.id == file_id)
        )
        file_obj = result.scalar_one_or_none()
        if not file_obj:
            return None
        return file_obj.file_path

    def _load_and_split(self, file_path: str, chunk_size: int, chunk_overlap: int):
        """Load the file with LangChain and split it, returning the resulting Document list.

        Current implementation:
        - the file is loaded via the format-specific loader chosen in _build_loader
        - RecursiveCharacterTextSplitter performs character-based recursive splitting

        Each Document's metadata is preserved so that fields such as the file ID and chunk index can be appended later.
        """
        loader = self._build_loader(file_path)
        docs = loader.load()

        splitter = RecursiveCharacterTextSplitter(
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
            # Try these separators first, in order, before falling back to character-level splits
            separators=["\n\n", "\n", "。", "!", "?", "!", "?", "。\n", "\t", " "],
        )
        split_docs = splitter.split_documents(docs)
        return split_docs

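    # Rough illustration: with the defaults chunk_size=800 and chunk_overlap=50, a
    # 2000-character document yields about three overlapping Document chunks, each
    # carrying over the loader metadata (e.g. {"source": ...}) from the original file.
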
    @staticmethod
    def _build_loader(file_path: str):
        """Pick a suitable LangChain loader based on the file extension, covering the common text-like formats.

        A format-specific loader is preferred; when nothing matches, fall back to TextLoader.
        """
        path = Path(file_path)
        suffix = path.suffix.lower()
        path_str = str(path)

        # 1. Plain text
        if suffix in {".txt", "", ".log"}:  # "" covers files without an extension
            return TextLoader(path_str, encoding="utf-8")

        # 2. Markdown
        if suffix in {".md", ".markdown"}:
            # UnstructuredMarkdownLoader preserves more structural information
            return UnstructuredMarkdownLoader(path_str)

        # 3. HTML / HTM
        if suffix in {".html", ".htm"}:
            return UnstructuredHTMLLoader(path_str)

        # 4. JSON
        if suffix == ".json":
            # JSONLoader expands the JSON content into documents.
            # The default jq_schema is used here; adjust it for finer-grained extraction later.
            return JSONLoader(file_path=path_str, jq_schema=".")

        # 5. CSV / TSV
        if suffix in {".csv", ".tsv"}:
            # CSVLoader turns each row into one Document by default
            return CSVLoader(file_path=path_str)

        # 6. YAML
        if suffix in {".yaml", ".yml"}:
            # Load as plain text for now
            return TextLoader(path_str, encoding="utf-8")

        # 7. PDF
        if suffix == ".pdf":
            return PyPDFLoader(path_str)

        # 8. Word documents
        if suffix in {".docx", ".doc"}:
            # UnstructuredWordDocumentLoader extracts text from .docx/.doc files
            return UnstructuredWordDocumentLoader(path_str)

        # 9. PowerPoint
        if suffix in {".ppt", ".pptx"}:
            return UnstructuredPowerPointLoader(path_str)

        # 10. Excel
        if suffix in {".xls", ".xlsx"}:
            return UnstructuredExcelLoader(path_str)

        # 11. Fallback: UnstructuredFileLoader, or TextLoader as plain text
        try:
            return UnstructuredFileLoader(path_str)
        except Exception:
            return TextLoader(path_str, encoding="utf-8")

    @staticmethod
    def _build_qa_prompt(chunk: str, synthesis_cfg: dict) -> str:
        """Build the prompt for QA data synthesis.

        Conventions:
        - synthesis_cfg["prompt_template"] is a string containing a {document} placeholder;
        - the current chunk content is substituted into {document}.
        If no template is provided, or formatting fails, a built-in fallback prompt is used.
        """
        template = None
        synthesis_count = 5
        if isinstance(synthesis_cfg, dict):
            template = synthesis_cfg.get("prompt_template")
            synthesis_count = synthesis_cfg.get("synthesis_count") or 5

        if not isinstance(template, str) or not template:
            # No usable template was provided: fall back to the built-in prompt
            return (
                f"文档内容:{chunk}\n\n"
                f"请根据文档内容生成 {synthesis_count} 条符合要求的问答数据。"
            )
        try:
            prompt = template.format(document=chunk, synthesis_count=synthesis_count)
        except Exception:
            # Defensive handling: if format() raises, fall back to simple concatenation
            prompt = (
                f"{template}\n\n文档内容:{chunk}\n\n"
                f"请根据文档内容生成 {synthesis_count} 条符合要求的问答数据。"
            )
        return prompt

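    # Example prompt_template (hypothetical; the real template comes from
    # synthesis_config["prompt_template"] and should contain {document}):
    #     "Read the document below and generate {synthesis_count} QA pairs as a JSON array.\n\n{document}"
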
    async def _handle_llm_answer(
        self,
        synthesis_task_id: str,
        file_id: str,
        chunk_index: int,
        raw_answer: str,
    ) -> None:
        """Parse the LLM response as JSON, save the items to SynthesisData in bulk, and update file task progress.

        Conventions:
        - raw_answer is expected to be a JSON string, either:
          1) a single object: {"question": ..., "answer": ...}
          2) an array of objects: [{}, {}, ...]
        - The result is normalized to a list; each element is written as one SynthesisData.data row.
        - The matching DataSynthesisChunkInstance is located via synthesis_task_id + file_id + chunk_index,
          so that chunk_instance_id and synthesis_file_instance_id can be set.
        - After each chunk is handled, the corresponding DataSynthesisFileInstance.processed_chunks is incremented.
        """
        if not raw_answer:
            return

        # 1. Pre-process the raw answer: try to extract the most likely JSON fragment
        cleaned = _extract_json_substring(raw_answer)

        # 2. Parse the JSON and normalize it into a list
        try:
            parsed = json.loads(cleaned)
        except Exception as e:
            logger.error(
                f"Failed to parse LLM answer as JSON for task={synthesis_task_id}, file={file_id}, chunk={chunk_index}: {e}. Raw answer: {raw_answer!r}"
            )
            return

        if isinstance(parsed, dict):
            items = [parsed]
        elif isinstance(parsed, list):
            items = [p for p in parsed if isinstance(p, dict)]
        else:
            logger.error(
                f"Unexpected JSON structure from LLM answer for task={synthesis_task_id}, file={file_id}, chunk={chunk_index}: {type(parsed)}"
            )
            return

        if not items:
            return

        # 3. Find the matching chunk record (one record per chunk_index)
        chunk_result = await self.db.execute(
            select(DataSynthesisChunkInstance, DataSynthesisFileInstance)
            .join(
                DataSynthesisFileInstance,
                DataSynthesisFileInstance.id == DataSynthesisChunkInstance.synthesis_file_instance_id,
            )
            .where(
                DataSynthesisFileInstance.synthesis_instance_id == synthesis_task_id,
                DataSynthesisFileInstance.source_file_id == file_id,
                DataSynthesisChunkInstance.chunk_index == chunk_index,
            )
        )
        row = chunk_result.first()
        if not row:
            logger.error(
                f"Chunk record not found for task={synthesis_task_id}, file={file_id}, chunk_index={chunk_index}, skip saving SynthesisData."
            )
            return

        chunk_instance, file_instance = row

        # 4. Write the SynthesisData rows in bulk
        for data_obj in items:
            record = SynthesisData(
                id=str(uuid.uuid4()),
                data=data_obj,
                synthesis_file_instance_id=file_instance.id,
                chunk_instance_id=chunk_instance.id,
            )
            self.db.add(record)

        # 5. Update the file-level processed_chunks count
        file_instance.processed_chunks = (file_instance.processed_chunks or 0) + 1

        await self.db.commit()
        await self.db.refresh(file_instance)

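    # Example (illustrative): a raw_answer like
    #     '[{"question": "What does section 2 cover?", "answer": "..."}]'
    # possibly wrapped in extra text or Markdown code fences, is narrowed to the JSON
    # fragment by _extract_json_substring, normalized to a list of dicts, and each
    # dict is stored as one SynthesisData row.
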
    async def _get_or_create_file_instance(
        self,
        synthesis_task_id: str,
        source_file_id: str,
        file_path: str,
    ) -> DataSynthesisFileInstance:
        """Find or create the DataSynthesisFileInstance record for the given task ID and source file ID.

        - If a record already exists (same task + same source_file_id), return it directly;
        - otherwise create a new file task record, taking file_name from the file path and
          reusing the task's result_data_location as target_file_location for now.
        """
        # Look for an existing file task record
        result = await self.db.execute(
            select(DataSynthesisFileInstance).where(
                DataSynthesisFileInstance.synthesis_instance_id == synthesis_task_id,
                DataSynthesisFileInstance.source_file_id == source_file_id,
            )
        )
        file_task = result.scalar_one_or_none()
        if file_task is not None:
            return file_task

        # Load the task to obtain result_data_location
        task = await self.db.get(DataSynthesisInstance, synthesis_task_id)
        target_location = task.result_data_location if task else ""

        # Create a new file task record with an initial status of processing
        file_task = DataSynthesisFileInstance(
            id=str(uuid.uuid4()),
            synthesis_instance_id=synthesis_task_id,
            file_name=Path(file_path).name,
            source_file_id=source_file_id,
            target_file_location=target_location or "",
            status="processing",
            total_chunks=0,
            processed_chunks=0,
            created_by="system",
            updated_by="system",
        )
        self.db.add(file_task)
        await self.db.commit()
        await self.db.refresh(file_task)
        return file_task

    async def _mark_file_failed(self, synthesis_task_id: str, file_id: str, reason: str | None = None) -> None:
        """Mark a single file task under the given synthesis task as failed; used as last-resort error handling.

        - If a matching DataSynthesisFileInstance is found, its status is set to "failed".
        - If none is found, return silently after logging a warning.
        - The reason parameter is only used for logging, to help with troubleshooting.
        """
        try:
            result = await self.db.execute(
                select(DataSynthesisFileInstance).where(
                    DataSynthesisFileInstance.synthesis_instance_id == synthesis_task_id,
                    DataSynthesisFileInstance.source_file_id == file_id,
                )
            )
            file_task = result.scalar_one_or_none()
            if not file_task:
                logger.warning(
                    f"Failed to mark file as failed: no DataSynthesisFileInstance found for task={synthesis_task_id}, file_id={file_id}, reason={reason}"
                )
                return

            file_task.status = "failed"
            await self.db.commit()
            await self.db.refresh(file_task)
            logger.info(
                f"Marked file task as failed for task={synthesis_task_id}, file_id={file_id}, reason={reason}"
            )
        except Exception as e:
            # Last-resort logging so the exception does not propagate and affect other files
            logger.exception(
                f"Unexpected error when marking file failed for task={synthesis_task_id}, file_id={file_id}, original_reason={reason}, error={e}"
            )