feat: optimize the question generation process and COT data generation (#169)

* fix(chart): update Helm chart helpers and values for improved configuration

* feat(SynthesisTaskTab): enhance task table with tooltip support and improved column widths

* feat(CreateTask, SynthFileTask): improve task creation and detail view with enhanced payload handling and UI updates

* feat(SynthFileTask): enhance file display with progress tracking and delete action

* feat(SynthFileTask): enhance file display with progress tracking and delete action

* feat(SynthDataDetail): add delete action for chunks with confirmation prompt

* feat(SynthDataDetail): update edit and delete buttons to icon-only format

* feat(SynthDataDetail): add confirmation modals for chunk and synthesis data deletion

* feat(DocumentSplitter): add enhanced document splitting functionality with CJK support and metadata detection

* feat(DataSynthesis): refactor data synthesis models and update task handling logic

* feat(DataSynthesis): streamline synthesis task handling and enhance chunk processing logic

* feat(DataSynthesis): refactor data synthesis models and update task handling logic

* fix(generation_service): ensure processed chunks are incremented regardless of question generation success

* feat(CreateTask): enhance task creation with new synthesis templates and improved configuration options

* feat(CreateTask): enhance task creation with new synthesis templates and improved configuration options

* feat(CreateTask): enhance task creation with new synthesis templates and improved configuration options

* feat(CreateTask): enhance task creation with new synthesis templates and improved configuration options
This commit is contained in:
Dallas98
2025-12-18 16:51:18 +08:00
committed by GitHub
parent 761f7f6a51
commit e0e9b1d94d
14 changed files with 1362 additions and 571 deletions

View File

@@ -1,66 +1,65 @@
import uuid
from xml.etree.ElementTree import tostring
from sqlalchemy import Column, String, Text, Integer, JSON, TIMESTAMP, ForeignKey, func
from sqlalchemy.orm import relationship
from sqlalchemy import Column, String, Text, Integer, JSON, TIMESTAMP, func
from app.db.session import Base
from app.module.generation.schema.generation import CreateSynthesisTaskRequest
async def save_synthesis_task(db_session, synthesis_task: CreateSynthesisTaskRequest):
"""保存数据合成任务。"""
# 转换为模型实例
"""保存数据合成任务。
注意:当前 MySQL 表 `t_data_synth_instances` 结构中只包含 synth_type / synth_config 等字段,
没有 model_id、text_split_config、source_file_id、result_data_location 等列,因此这里只保存
与表结构一致的字段,其他信息由上层逻辑或其它表负责管理。
"""
gid = str(uuid.uuid4())
synthesis_task_instance = DataSynthesisInstance(
# 兼容旧请求结构:从请求对象中提取必要字段,
# - 合成类型:synthesis_type -> synth_type
# - 合成配置:text_split_config + synthesis_config 合并后写入 synth_config
synth_task_instance = DataSynthInstance(
id=gid,
name=synthesis_task.name,
description=synthesis_task.description,
status="pending",
model_id=synthesis_task.model_id,
synthesis_type=synthesis_task.synthesis_type.value,
synth_type=synthesis_task.synthesis_type.value,
progress=0,
result_data_location=f"/dataset/synthesis_results/{gid}/",
text_split_config=synthesis_task.text_split_config.model_dump(),
synthesis_config=synthesis_task.synthesis_config.model_dump(),
source_file_id=synthesis_task.source_file_id,
total_files=len(synthesis_task.source_file_id),
synth_config=synthesis_task.synth_config.model_dump(),
total_files=len(synthesis_task.source_file_id or []),
processed_files=0,
total_chunks=0,
processed_chunks=0,
total_synthesis_data=0,
total_synth_data=0,
created_at=func.now(),
updated_at=func.now(),
created_by="system",
updated_by="system"
updated_by="system",
)
db_session.add(synthesis_task_instance)
db_session.add(synth_task_instance)
await db_session.commit()
await db_session.refresh(synthesis_task_instance)
return synthesis_task_instance
await db_session.refresh(synth_task_instance)
return synth_task_instance
class DataSynthesisInstance(Base):
"""数据合成任务表,对应表 t_data_synthesis_instances
class DataSynthInstance(Base):
"""数据合成任务表,对应表 t_data_synth_instances
create table if not exists t_data_synthesis_instances
create table if not exists t_data_synth_instances
(
id VARCHAR(36) CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci PRIMARY KEY COMMENT 'UUID',
name VARCHAR(255) NOT NULL COMMENT '任务名称',
description TEXT COMMENT '任务描述',
status VARCHAR(20) COMMENT '任务状态',
synthesis_type VARCHAR(20) NOT NULL COMMENT '合成类型',
model_id VARCHAR(255) NOT NULL COMMENT '模型ID',
synth_type VARCHAR(20) NOT NULL COMMENT '合成类型',
progress INT DEFAULT 0 COMMENT '任务进度(百分比)',
result_data_location VARCHAR(1000) COMMENT '结果数据存储位置',
text_split_config JSON NOT NULL COMMENT '文本切片配置',
synthesis_config JSON NOT NULL COMMENT '合成配置',
source_file_id JSON NOT NULL COMMENT '原始文件ID列表',
synth_config JSON NOT NULL COMMENT '合成配置',
total_files INT DEFAULT 0 COMMENT '总文件数',
processed_files INT DEFAULT 0 COMMENT '已处理文件数',
total_chunks INT DEFAULT 0 COMMENT '总文本块数',
processed_chunks INT DEFAULT 0 COMMENT '已处理文本块数',
total_synthesis_data INT DEFAULT 0 COMMENT '总合成数据量',
total_synth_data INT DEFAULT 0 COMMENT '总合成数据量',
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP COMMENT '创建时间',
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT '更新时间',
created_by VARCHAR(255) COMMENT '创建者',
@@ -68,27 +67,29 @@ class DataSynthesisInstance(Base):
) COMMENT='数据合成任务表(UUID 主键)';
"""
__tablename__ = "t_data_synthesis_instances"
__tablename__ = "t_data_synth_instances"
id = Column(String(36), primary_key=True, index=True, comment="UUID")
name = Column(String(255), nullable=False, comment="任务名称")
description = Column(Text, nullable=True, comment="任务描述")
status = Column(String(20), nullable=True, comment="任务状态")
synthesis_type = Column(String(20), nullable=False, comment="合成类型")
model_id = Column(String(255), nullable=False, comment="模型ID")
# 与数据库字段保持一致:synth_type / synth_config
synth_type = Column(String(20), nullable=False, comment="合成类型")
progress = Column(Integer, nullable=False, default=0, comment="任务进度(百分比)")
result_data_location = Column(String(1000), nullable=True, comment="结果数据存储位置")
text_split_config = Column(JSON, nullable=False, comment="文本切片配置")
synthesis_config = Column(JSON, nullable=False, comment="合成配置")
source_file_id = Column(JSON, nullable=False, comment="原始文件ID列表")
synth_config = Column(JSON, nullable=False, comment="合成配置")
total_files = Column(Integer, nullable=False, default=0, comment="总文件数")
processed_files = Column(Integer, nullable=False, default=0, comment="已处理文件数")
total_chunks = Column(Integer, nullable=False, default=0, comment="总文本块数")
processed_chunks = Column(Integer, nullable=False, default=0, comment="已处理文本块数")
total_synthesis_data = Column(Integer, nullable=False, default=0, comment="总合成数据量")
created_at = Column(TIMESTAMP, server_default=func.current_timestamp(), nullable=True, comment="创建时间")
updated_at = Column(TIMESTAMP, server_default=func.current_timestamp(), onupdate=func.current_timestamp(), nullable=True, comment="更新时间")
total_synth_data = Column(Integer, nullable=False, default=0, comment="总合成数据量")
created_at = Column(TIMESTAMP, nullable=False, default=func.now(), comment="创建时间")
updated_at = Column(
TIMESTAMP,
nullable=False,
default=func.now(),
onupdate=func.now(),
comment="更新时间",
)
created_by = Column(String(255), nullable=True, comment="创建者")
updated_by = Column(String(255), nullable=True, comment="更新者")
@@ -123,7 +124,7 @@ class DataSynthesisFileInstance(Base):
)
file_name = Column(String(255), nullable=False, comment="文件名")
source_file_id = Column(String(255), nullable=False, comment="原始文件ID")
target_file_location = Column(String(1000), nullable=False, comment="目标文件存储位置")
target_file_location = Column(String(1000), nullable=True, comment="目标文件存储位置")
status = Column(String(20), nullable=True, comment="任务状态")
total_chunks = Column(Integer, nullable=False, default=0, comment="总文本块数")
processed_chunks = Column(Integer, nullable=False, default=0, comment="已处理文本块数")
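
For reference, here is a minimal sketch (not part of this change) of how the reworked `save_synthesis_task` is expected to be called. The placeholder UUIDs, the `AsyncSessionLocal` usage, and the concrete config values are assumptions; only the model and schema names come from this diff.

```python
import asyncio

from app.db.session import AsyncSessionLocal
from app.db.models.data_synthesis import save_synthesis_task
from app.module.generation.schema.generation import (
    Config, CreateSynthesisTaskRequest, SyntheConfig, SynthesisType, TextSplitConfig,
)

async def demo() -> None:
    request = CreateSynthesisTaskRequest(
        name="manual-qa-task",
        description="",  # empty strings are normalized to None by the field validator
        synthesis_type=SynthesisType.QA,
        source_file_id=["<dataset-file-uuid>"],          # hypothetical file ID
        synth_config=Config(
            text_split_config=TextSplitConfig(chunk_size=800, chunk_overlap=50),
            question_synth_config=SyntheConfig(model_id="<model-uuid>", number=5),
            answer_synth_config=SyntheConfig(model_id="<model-uuid>"),
            max_qa_pairs=200,
        ),
    )
    async with AsyncSessionLocal() as session:
        task = await save_synthesis_task(session, request)
        # Only columns that exist on t_data_synth_instances are written;
        # model IDs and split settings live inside the synth_config JSON.
        print(task.id, task.synth_type, task.total_files)

asyncio.run(demo())
```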

View File

@@ -13,7 +13,7 @@ from app.db.models.data_synthesis import DataSynthesisFileInstance, SynthesisDat
from app.db.session import AsyncSessionLocal
from app.module.evaluation.schema.evaluation import SourceType
from app.module.shared.schema import TaskStatus
from app.module.shared.util.model_chat import call_openai_style_model, _extract_json_substring
from app.module.shared.util.model_chat import call_openai_style_model, extract_json_substring
from app.module.evaluation.schema.prompt import get_prompt
from app.module.shared.util.structured_file import StructuredFileHandlerFactory
from app.module.system.service.common_service import get_model_by_id
@@ -36,8 +36,8 @@ class EvaluationExecutor:
.replace("{question}", eval_content.get("instruction")))
.replace("{answer}", eval_content.get("output")))
if self.task.task_type == "COT":
prompt_text = ((prompt_text.replace("{question}", eval_content.get("question"))
.replace("{conclusion}", eval_content.get("conclusion")))
prompt_text = ((prompt_text.replace("{question}", eval_content.get("instruction"))
.replace("{conclusion}", eval_content.get("output")))
.replace("{chain_of_thought}", eval_content.get("chain_of_thought")))
return prompt_text
@@ -73,7 +73,7 @@ class EvaluationExecutor:
call_openai_style_model, model_config.base_url, model_config.api_key, model_config.model_name,
prompt_text,
)
resp_text = _extract_json_substring(resp_text)
resp_text = extract_json_substring(resp_text)
try:
json.loads(resp_text)
except Exception as e:
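
For clarity, a small illustrative snippet of the COT placeholder remapping above. The `eval_content` values and the template text are made up; only the field names (`instruction`, `output`, `chain_of_thought`) come from this diff.

```python
# Hypothetical COT record shape assumed by the updated placeholder mapping.
eval_content = {
    "instruction": "What does the report conclude about X?",
    "chain_of_thought": "Step 1: ... Step 2: ...",
    "output": "The report concludes that ...",
}

prompt_template = "Q: {question}\nReasoning: {chain_of_thought}\nConclusion: {conclusion}"

prompt_text = (
    prompt_template
    .replace("{question}", eval_content["instruction"])
    .replace("{conclusion}", eval_content["output"])
    .replace("{chain_of_thought}", eval_content["chain_of_thought"])
)
print(prompt_text)
```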

View File

@@ -1,4 +1,5 @@
import uuid
from typing import cast
from fastapi import APIRouter, HTTPException, Depends, BackgroundTasks
from sqlalchemy import select, func, delete
@@ -7,13 +8,12 @@ from sqlalchemy.ext.asyncio import AsyncSession
from app.core.logging import get_logger
from app.db.models.data_synthesis import (
save_synthesis_task,
DataSynthesisInstance,
DataSynthInstance,
DataSynthesisFileInstance,
DataSynthesisChunkInstance,
SynthesisData,
)
from app.db.models.dataset_management import DatasetFiles
from app.db.models.model_config import get_model_by_id
from app.db.session import get_db
from app.module.generation.schema.generation import (
CreateSynthesisTaskRequest,
@@ -28,9 +28,9 @@ from app.module.generation.schema.generation import (
SynthesisDataUpdateRequest,
BatchDeleteSynthesisDataRequest,
)
from app.module.generation.service.export_service import SynthesisDatasetExporter, SynthesisExportError
from app.module.generation.service.generation_service import GenerationService
from app.module.generation.service.prompt import get_prompt
from app.module.generation.service.export_service import SynthesisDatasetExporter, SynthesisExportError
from app.module.shared.schema import StandardResponse
router = APIRouter(
@@ -47,10 +47,6 @@ async def create_synthesis_task(
db: AsyncSession = Depends(get_db),
):
"""创建数据合成任务"""
result = await get_model_by_id(db, request.model_id)
if not result:
raise HTTPException(status_code=404, detail="Model not found")
# 先根据 source_file_id 在 DatasetFiles 中查出已有文件信息
file_ids = request.source_file_id or []
dataset_files = []
@@ -65,32 +61,48 @@ async def create_synthesis_task(
synthesis_task = await save_synthesis_task(db, request)
# 将已有的 DatasetFiles 记录保存到 t_data_synthesis_file_instances
synth_files = []
for f in dataset_files:
file_instance = DataSynthesisFileInstance(
id=str(uuid.uuid4()), # 使用新的 UUID 作为文件任务记录的主键,避免与 DatasetFiles 主键冲突
synthesis_instance_id=synthesis_task.id,
file_name=f.file_name,
source_file_id=str(f.id),
target_file_location=synthesis_task.result_data_location or "",
status="pending",
total_chunks=0,
processed_chunks=0,
created_by="system",
updated_by="system",
)
db.add(file_instance)
synth_files.append(file_instance)
if dataset_files:
db.add_all(synth_files)
await db.commit()
generation_service = GenerationService(db)
# 异步处理任务:只传任务 ID,后台任务中使用新的 DB 会话重新加载任务对象
background_tasks.add_task(generation_service.process_task, synthesis_task.id)
# 将 ORM 对象包装成 DataSynthesisTaskItem,兼容新字段从 synth_config 还原
task_item = DataSynthesisTaskItem(
id=synthesis_task.id,
name=synthesis_task.name,
description=synthesis_task.description,
status=synthesis_task.status,
synthesis_type=synthesis_task.synth_type,
total_files=synthesis_task.total_files,
created_at=synthesis_task.created_at,
updated_at=synthesis_task.updated_at,
created_by=synthesis_task.created_by,
updated_by=synthesis_task.updated_by,
)
return StandardResponse(
code=200,
message="success",
data=synthesis_task,
data=task_item,
)
@@ -100,14 +112,26 @@ async def get_synthesis_task(
db: AsyncSession = Depends(get_db)
):
"""获取数据合成任务详情"""
result = await db.get(DataSynthesisInstance, task_id)
if not result:
synthesis_task = await db.get(DataSynthInstance, task_id)
if not synthesis_task:
raise HTTPException(status_code=404, detail="Synthesis task not found")
task_item = DataSynthesisTaskItem(
id=synthesis_task.id,
name=synthesis_task.name,
description=synthesis_task.description,
status=synthesis_task.status,
synthesis_type=synthesis_task.synth_type,
total_files=synthesis_task.total_files,
created_at=synthesis_task.created_at,
updated_at=synthesis_task.updated_at,
created_by=synthesis_task.created_by,
updated_by=synthesis_task.updated_by,
)
return StandardResponse(
code=200,
message="success",
data=result,
data=task_item,
)
@@ -121,16 +145,16 @@ async def list_synthesis_tasks(
db: AsyncSession = Depends(get_db)
):
"""分页列出所有数据合成任务,默认按创建时间倒序"""
query = select(DataSynthesisInstance)
query = select(DataSynthInstance)
if synthesis_type:
query = query.filter(DataSynthesisInstance.synthesis_type == synthesis_type)
query = query.filter(DataSynthInstance.synth_type == synthesis_type)
if status:
query = query.filter(DataSynthesisInstance.status == status)
query = query.filter(DataSynthInstance.status == status)
if name:
query = query.filter(DataSynthesisInstance.name.like(f"%{name}%"))
query = query.filter(DataSynthInstance.name.like(f"%{name}%"))
# 默认按创建时间倒序排列
query = query.order_by(DataSynthesisInstance.created_at.desc())
query = query.order_by(DataSynthInstance.created_at.desc())
count_q = select(func.count()).select_from(query.subquery())
total = (await db.execute(count_q)).scalar_one()
@@ -143,31 +167,39 @@ async def list_synthesis_tasks(
result = await db.execute(query.offset((page - 1) * page_size).limit(page_size))
rows = result.scalars().all()
task_items = [
DataSynthesisTaskItem(
id=row.id,
name=row.name,
description=row.description,
status=row.status,
synthesis_type=row.synthesis_type,
model_id=row.model_id,
progress=row.progress,
result_data_location=row.result_data_location,
text_split_config=row.text_split_config,
synthesis_config=row.synthesis_config,
source_file_id=row.source_file_id,
total_files=row.total_files,
processed_files=row.processed_files,
total_chunks=row.total_chunks,
processed_chunks=row.processed_chunks,
total_synthesis_data=row.total_synthesis_data,
created_at=row.created_at,
updated_at=row.updated_at,
created_by=row.created_by,
updated_by=row.updated_by,
task_items: list[DataSynthesisTaskItem] = []
for row in rows:
synth_cfg = getattr(row, "synth_config", {}) or {}
text_split_cfg = synth_cfg.get("text_split_config") or {}
synthesis_cfg = synth_cfg.get("synthesis_config") or {}
source_file_ids = synth_cfg.get("source_file_id") or []
model_id = synth_cfg.get("model_id")
result_location = synth_cfg.get("result_data_location")
task_items.append(
DataSynthesisTaskItem(
id=str(row.id),
name=str(row.name),
description=cast(str | None, row.description),
status=cast(str | None, row.status),
synthesis_type=str(row.synth_type),
model_id=model_id or "",
progress=int(cast(int, row.progress)),
result_data_location=result_location,
text_split_config=text_split_cfg,
synthesis_config=synthesis_cfg,
source_file_id=list(source_file_ids),
total_files=int(cast(int, row.total_files)),
processed_files=int(cast(int, row.processed_files)),
total_chunks=int(cast(int, row.total_chunks)),
processed_chunks=int(cast(int, row.processed_chunks)),
total_synthesis_data=int(cast(int, row.total_synth_data)),
created_at=row.created_at,
updated_at=row.updated_at,
created_by=row.created_by,
updated_by=row.updated_by,
)
)
for row in rows
]
paged = PagedDataSynthesisTaskResponse(
content=task_items,
@@ -190,7 +222,7 @@ async def delete_synthesis_task(
db: AsyncSession = Depends(get_db)
):
"""删除数据合成任务"""
task = await db.get(DataSynthesisInstance, task_id)
task = await db.get(DataSynthInstance, task_id)
if not task:
raise HTTPException(status_code=404, detail="Synthesis task not found")
@@ -241,7 +273,7 @@ async def delete_synthesis_file_task(
):
"""删除数据合成任务中的文件任务,同时刷新任务表中的文件/切片数量"""
# 先获取任务和文件任务记录
task = await db.get(DataSynthesisInstance, task_id)
task = await db.get(DataSynthInstance, task_id)
if not task:
raise HTTPException(status_code=404, detail="Synthesis task not found")
@@ -306,7 +338,7 @@ async def list_synthesis_file_tasks(
):
"""分页获取某个数据合成任务下的文件任务列表"""
# 先校验任务是否存在
task = await db.get(DataSynthesisInstance, task_id)
task = await db.get(DataSynthInstance, task_id)
if not task:
raise HTTPException(status_code=404, detail="Synthesis task not found")
@@ -333,7 +365,6 @@ async def list_synthesis_file_tasks(
synthesis_instance_id=row.synthesis_instance_id,
file_name=row.file_name,
source_file_id=row.source_file_id,
target_file_location=row.target_file_location,
status=row.status,
total_chunks=row.total_chunks,
processed_chunks=row.processed_chunks,
@@ -523,7 +554,7 @@ async def delete_synthesis_data_by_chunk(
result = await db.execute(
delete(SynthesisData).where(SynthesisData.chunk_instance_id == chunk_id)
)
deleted = result.rowcount or 0
deleted = int(getattr(result, "rowcount", 0) or 0)
await db.commit()
@@ -542,7 +573,7 @@ async def batch_delete_synthesis_data(
result = await db.execute(
delete(SynthesisData).where(SynthesisData.id.in_(request.ids))
)
deleted = result.rowcount or 0
deleted = int(getattr(result, "rowcount", 0) or 0)
await db.commit()
return StandardResponse(code=200, message="success", data=deleted)
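
For reference, a hedged example of the request body the updated create-task endpoint now expects. The host, URL prefix, and placeholder IDs are assumptions, since the router prefix is not shown in this diff.

```python
import httpx

payload = {
    "name": "qa-from-manual",
    "description": "",
    "synthesis_type": "QA",
    "source_file_id": ["<dataset-file-uuid>"],
    "synth_config": {
        "text_split_config": {"chunk_size": 800, "chunk_overlap": 50},
        "question_synth_config": {"model_id": "<model-uuid>", "number": 5},
        "answer_synth_config": {"model_id": "<model-uuid>"},
        "max_qa_pairs": 200,
    },
}

# "<synthesis-router-prefix>" is a placeholder; use the actual router prefix.
resp = httpx.post("http://localhost:8000/<synthesis-router-prefix>/tasks", json=payload)
print(resp.json()["data"]["id"])  # response data is the slimmed-down DataSynthesisTaskItem
```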

View File

@@ -11,33 +11,45 @@ class TextSplitConfig(BaseModel):
chunk_overlap: int = Field(..., description="重叠令牌数")
class SynthesisConfig(BaseModel):
class SyntheConfig(BaseModel):
"""合成配置"""
prompt_template: str = Field(..., description="合成提示模板")
synthesis_count: int = Field(None, description="单个chunk合成的数据数量")
model_id: str = Field(..., description="模型ID")
prompt_template: str = Field(None, description="合成提示模板")
number: Optional[int] = Field(None, description="单个chunk合成的数据数量")
temperature: Optional[float] = Field(None, description="温度参数")
class Config(BaseModel):
"""配置"""
text_split_config: TextSplitConfig = Field(None, description="文本切片配置")
question_synth_config: SyntheConfig = Field(None, description="问题合成配置")
answer_synth_config: SyntheConfig = Field(None, description="答案合成配置")
# 新增:整个任务允许生成的 QA 总上限(问题/答案对数量)
max_qa_pairs: Optional[int] = Field(
default=None,
description="整个任务允许生成的 QA 对总量上限;为 None 或 <=0 表示不限制",
)
class SynthesisType(Enum):
"""合成类型"""
QA = "QA"
COT = "COT"
QUESTION = "QUESTION"
class CreateSynthesisTaskRequest(BaseModel):
"""创建数据合成任务请求"""
name: str = Field(..., description="合成任务名称")
description: Optional[str] = Field(None, description="合成任务描述")
model_id: str = Field(..., description="模型ID")
source_file_id: list[str] = Field(..., description="原始文件ID列表")
text_split_config: TextSplitConfig = Field(None, description="文本切片配置")
synthesis_config: SynthesisConfig = Field(..., description="合成配置")
synthesis_type: SynthesisType = Field(..., description="合成类型")
source_file_id: list[str] = Field(..., description="原始文件ID列表")
synth_config: Config = Field(..., description="合成配置")
@field_validator("description")
@classmethod
def empty_string_to_none(cls, v: Optional[str]) -> Optional[str]:
"""前端如果传入空字符串,将其统一转为 None,避免存库时看起来像有描述但实际上为空。"""
"""前端如果传入空字符串,将其统一转为 None,避免存库时看起来像有描述但实际上为空。"""
if isinstance(v, str) and v.strip() == "":
return None
return v
@@ -50,17 +62,7 @@ class DataSynthesisTaskItem(BaseModel):
description: Optional[str] = None
status: Optional[str] = None
synthesis_type: str
model_id: str
progress: int
result_data_location: Optional[str] = None
text_split_config: Dict[str, Any]
synthesis_config: Dict[str, Any]
source_file_id: list[str]
total_files: int
processed_files: int
total_chunks: int
processed_chunks: int
total_synthesis_data: int
created_at: Optional[datetime] = None
updated_at: Optional[datetime] = None
created_by: Optional[str] = None
@@ -85,7 +87,6 @@ class DataSynthesisFileTaskItem(BaseModel):
synthesis_instance_id: str
file_name: str
source_file_id: str
target_file_location: str
status: Optional[str] = None
total_chunks: int
processed_chunks: int
@@ -108,7 +109,7 @@ class PagedDataSynthesisFileTaskResponse(BaseModel):
class DataSynthesisChunkItem(BaseModel):
"""数据合成文件下的 chunk 记录"""
"""数据合成任务下的 chunk 记录"""
id: str
synthesis_file_instance_id: str
chunk_index: Optional[int] = None
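
A short sketch of how the new nested config models are meant to be consumed, mirroring the `Config(**synth_config)` and `max_qa_pairs` handling in `generation_service`; the concrete values are illustrative.

```python
from app.module.generation.schema.generation import Config

raw_synth_config = {
    "text_split_config": {"chunk_size": 800, "chunk_overlap": 50},
    "question_synth_config": {"model_id": "<model-uuid>", "number": 5, "temperature": 0.7},
    "answer_synth_config": {"model_id": "<model-uuid>"},
    "max_qa_pairs": 0,  # <=0 (or None) means "no global QA limit"
}

cfg = Config(**raw_synth_config)
max_qa_pairs = cfg.max_qa_pairs if (cfg.max_qa_pairs and cfg.max_qa_pairs > 0) else None
print(cfg.question_synth_config.model_id, max_qa_pairs)  # "<model-uuid>" None
```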

View File

@@ -9,7 +9,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
from app.core.logging import get_logger
from app.db.models.data_synthesis import (
DataSynthesisInstance,
DataSynthInstance,
DataSynthesisFileInstance,
SynthesisData,
)
@@ -43,7 +43,7 @@ class SynthesisDatasetExporter:
Optimized to process one file at a time to reduce memory usage.
"""
task = await self._db.get(DataSynthesisInstance, task_id)
task = await self._db.get(DataSynthInstance, task_id)
if not task:
raise SynthesisExportError(f"Synthesis task {task_id} not found")

View File

@@ -1,138 +1,477 @@
import asyncio
import json
import uuid
from pathlib import Path
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_core.language_models import BaseChatModel
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.db.models.data_synthesis import (
DataSynthesisInstance,
DataSynthInstance,
DataSynthesisFileInstance,
DataSynthesisChunkInstance,
SynthesisData,
)
from app.db.models.dataset_management import DatasetFiles
from app.db.models.model_config import get_model_by_id
from app.db.session import logger
from app.module.shared.util.model_chat import _extract_json_substring
from app.module.system.service.common_service import get_chat_client, chat
from app.common.document_loaders import load_documents
from app.module.generation.schema.generation import Config, SyntheConfig
from app.module.generation.service.prompt import (
QUESTION_GENERATOR_PROMPT,
ANSWER_GENERATOR_PROMPT,
)
from app.module.shared.common.document_loaders import load_documents
from app.module.shared.common.text_split import DocumentSplitter
from app.module.shared.util.model_chat import extract_json_substring
from app.module.system.service.common_service import chat, get_model_by_id, get_chat_client
class GenerationService:
def __init__(self, db: AsyncSession):
self.db = db
# 全局并发信号量:保证任意时刻最多 10 次模型调用
self.question_semaphore = asyncio.Semaphore(10)
self.answer_semaphore = asyncio.Semaphore(100)
async def process_task(self, task_id: str):
"""处理数据合成任务入口:根据任务ID加载任务并逐个处理源文件。"""
synthesis_task: DataSynthesisInstance | None = await self.db.get(DataSynthesisInstance, task_id)
if not synthesis_task:
synth_task: DataSynthInstance | None = await self.db.get(DataSynthInstance, task_id)
if not synth_task:
logger.error(f"Synthesis task {task_id} not found, abort processing")
return
logger.info(f"Processing synthesis task {task_id}")
file_ids = synthesis_task.source_file_id or []
logger.info(f"Start processing synthe task {task_id}")
# 获取模型客户端
model_result = await get_model_by_id(self.db, str(synthesis_task.model_id))
if not model_result:
logger.error(
f"Model config not found for id={synthesis_task.model_id}, abort task {synthesis_task.id}"
)
# 从 synth_config 中读取 max_qa_pairs,全局控制 QA 总量上限;<=0 或异常则视为不限制
try:
cfg = Config(**(synth_task.synth_config or {}))
max_qa_pairs = cfg.max_qa_pairs if (cfg and cfg.max_qa_pairs and cfg.max_qa_pairs > 0) else None
except Exception:
max_qa_pairs = None
# 获取任务关联的文件原始ID列表
file_ids = await self._get_file_ids_for_task(task_id)
if not file_ids:
logger.warning(f"No files associated with task {task_id}, abort processing")
return
chat_client = get_chat_client(model_result)
# 控制并发度的信号量(限制全任务范围内最多 10 个并发调用)
semaphore = asyncio.Semaphore(10)
# 逐个文件处理
for file_id in file_ids:
try:
success = await self._process_single_file(
synthesis_task=synthesis_task,
file_id=file_id,
chat_client=chat_client,
semaphore=semaphore,
)
success = await self._process_single_file(synth_task, file_id, max_qa_pairs=max_qa_pairs)
except Exception as e:
logger.exception(f"Unexpected error when processing file {file_id} for task {task_id}: {e}")
# 确保对应文件任务状态标记为失败
await self._mark_file_failed(str(synthesis_task.id), file_id, str(e))
await self._mark_file_failed(str(synth_task.id), file_id, str(e))
success = False
if success:
# 每处理完一个文件,简单增加 processed_files 计数
synthesis_task.processed_files = (synthesis_task.processed_files or 0) + 1
synth_task.processed_files = (synth_task.processed_files or 0) + 1
await self.db.commit()
await self.db.refresh(synthesis_task)
await self.db.refresh(synth_task)
logger.info(f"Finished processing synthesis task {synthesis_task.id}")
logger.info(f"Finished processing synthesis task {synth_task.id}")
# ==================== 高层文件处理流程 ====================
async def _process_single_file(
self,
synthesis_task: DataSynthesisInstance,
synth_task: DataSynthInstance,
file_id: str,
chat_client,
semaphore: asyncio.Semaphore,
max_qa_pairs: int | None = None,
) -> bool:
"""处理单个源文件:解析路径、切片、保存分块并触发 LLM 调用。"""
"""按 chunk 批量流式处理单个源文件。
流程:
1. 切片并将所有 chunk 持久化到 DB 后释放内存;
2. 从 DB 按 chunk_index 升序批量读取 chunk;
3. 对批次中的每个 chunk:先生成指定数量的问题,再基于这些问题生成答案;
4. 每成功处理完一个 chunk(即该 chunk 至少生成一条 QA)就更新一次 processed_chunks;
5. 全部完成后将文件实例标记为 completed。
"""
# 解析文件路径与配置
file_path = await self._resolve_file_path(file_id)
if not file_path:
logger.warning(f"File path not found for file_id={file_id}, skip")
await self._mark_file_failed(str(synthesis_task.id), file_id, "file_path_not_found")
await self._mark_file_failed(str(synth_task.id), file_id, "file_path_not_found")
return False
logger.info(f"Processing file_id={file_id}, path={file_path}")
split_cfg = synthesis_task.text_split_config or {}
synthesis_cfg = synthesis_task.synthesis_config or {}
chunk_size = int(split_cfg.get("chunk_size", 800))
chunk_overlap = int(split_cfg.get("chunk_overlap", 50))
# 加载并切片
try:
chunks = self._load_and_split(file_path, chunk_size, chunk_overlap)
config = Config(**(synth_task.synth_config or {}))
except Exception as e:
logger.error(f"Failed to load/split file {file_path}: {e}")
await self._mark_file_failed(str(synthesis_task.id), file_id, f"load_split_error: {e}")
logger.error(f"Invalid synth_config for task={synth_task.id}: {e}")
await self._mark_file_failed(str(synth_task.id), file_id, "invalid_synth_config")
return False
# 1. 加载并切片(仅在此处占用内存)
chunks = self._load_and_split(
file_path,
config.text_split_config.chunk_size,
config.text_split_config.chunk_overlap,
)
if not chunks:
logger.warning(f"No chunks generated for file_id={file_id}")
await self._mark_file_failed(str(synthesis_task.id), file_id, "no_chunks_generated")
await self._mark_file_failed(str(synth_task.id), file_id, "no_chunks_generated")
return False
logger.info(f"File {file_id} split into {len(chunks)} chunks by LangChain")
# 保存文件任务记录 + 分块记录
# 2. 获取文件实例并持久化 chunk 记录
file_task = await self._get_or_create_file_instance(
synthesis_task_id=str(synthesis_task.id),
synthesis_task_id=str(synth_task.id),
source_file_id=file_id,
file_path=file_path,
)
await self._persist_chunks(synthesis_task, file_task, file_id, chunks)
if not file_task:
logger.error(
f"DataSynthesisFileInstance not found for task={synth_task.id}, file_id={file_id}"
)
await self._mark_file_failed(str(synth_task.id), file_id, "file_instance_not_found")
return False
# 针对每个切片并发调用大模型
await self._invoke_llm_for_chunks(
synthesis_task=synthesis_task,
file_id=file_id,
chunks=chunks,
synthesis_cfg=synthesis_cfg,
chat_client=chat_client,
semaphore=semaphore,
await self._persist_chunks(synth_task, file_task, file_id, chunks)
total_chunks = len(chunks)
# 释放内存中的切片
del chunks
# 3. 读取问答配置
question_cfg: SyntheConfig | None = config.question_synth_config
answer_cfg: SyntheConfig | None = config.answer_synth_config
if not question_cfg or not answer_cfg:
logger.error(
f"Question/Answer synth config missing for task={synth_task.id}, file={file_id}"
)
await self._mark_file_failed(str(synth_task.id), file_id, "qa_config_missing")
return False
logger.info(
f"Start QA generation for task={synth_task.id}, file={file_id}, total_chunks={total_chunks}"
)
# 如果执行到此处,说明该文件的切片与 LLM 调用流程均未抛出异常,标记为完成
# 为本文件构建模型 client
question_model = await get_model_by_id(self.db, question_cfg.model_id)
answer_model = await get_model_by_id(self.db, answer_cfg.model_id)
question_chat = get_chat_client(question_model)
answer_chat = get_chat_client(answer_model)
# 分批次从 DB 读取并处理 chunk
batch_size = 20
current_index = 1
while current_index <= total_chunks:
end_index = min(current_index + batch_size - 1, total_chunks)
chunk_batch = await self._load_chunk_batch(
file_task_id=file_task.id,
start_index=current_index,
end_index=end_index,
)
if not chunk_batch:
logger.warning(
f"Empty chunk batch loaded for file={file_id}, range=[{current_index}, {end_index}]"
)
current_index = end_index + 1
continue
# 对本批中的每个 chunk 并发处理(内部受 semaphore 限流)
async def process_one(chunk: DataSynthesisChunkInstance) -> bool:
return await self._process_single_chunk_qa(
file_task=file_task,
chunk=chunk,
question_cfg=question_cfg,
answer_cfg=answer_cfg,
question_chat=question_chat,
answer_chat=answer_chat,
synth_task_id=str(synth_task.id),
max_qa_pairs=max_qa_pairs,
)
tasks = [process_one(chunk) for chunk in chunk_batch]
await asyncio.gather(*tasks, return_exceptions=True)
current_index = end_index + 1
# 全部完成
file_task.status = "completed"
await self.db.commit()
await self.db.refresh(file_task)
return True
async def _process_single_chunk_qa(
self,
file_task: DataSynthesisFileInstance,
chunk: DataSynthesisChunkInstance,
question_cfg: SyntheConfig,
answer_cfg: SyntheConfig,
question_chat: BaseChatModel,
answer_chat: BaseChatModel,
synth_task_id: str,
max_qa_pairs: int | None = None,
) -> bool:
"""处理单个 chunk:生成问题列表,然后为每个问题生成答案并落库。
为了全局控制 QA 总量:在本方法开始处,根据 synth_task_id 查询当前已落盘的
SynthesisData 条数,如果 >= max_qa_pairs,则不再对当前 chunk 做任何 QA 生成,
并将当前文件任务标记为 completed,processed_chunks = total_chunks。
已经进入后续流程的任务(例如其它协程正在生成答案)允许自然执行完。
"""
# 如果没有全局上限配置,维持原有行为
if max_qa_pairs is not None and max_qa_pairs > 0:
from sqlalchemy import func
# 统计当前整个任务下已生成的 QA 总数
result = await self.db.execute(
select(func.count(SynthesisData.id)).where(
SynthesisData.synthesis_file_instance_id.in_(
select(DataSynthesisFileInstance.id).where(
DataSynthesisFileInstance.synthesis_instance_id == synth_task_id
)
)
)
)
current_qa_count = int(result.scalar() or 0)
if current_qa_count >= max_qa_pairs:
logger.info(
"max_qa_pairs reached: current=%s, max=%s, task_id=%s, file_task_id=%s, skip new QA generation for this chunk.",
current_qa_count,
max_qa_pairs,
synth_task_id,
file_task.id,
)
# 将文件任务标记为已完成,并认为所有 chunk 均已处理
file_task.status = "completed"
if file_task.total_chunks is not None:
file_task.processed_chunks = file_task.total_chunks
await self.db.commit()
await self.db.refresh(file_task)
return False
# ---- 下面保持原有逻辑不变 ----
chunk_index = chunk.chunk_index
chunk_text = chunk.chunk_content or ""
if not chunk_text.strip():
logger.warning(
f"Empty chunk text for file_task={file_task.id}, chunk_index={chunk_index}"
)
# 无论成功或失败,均视为该 chunk 已处理完成
try:
await self._increment_processed_chunks(file_task.id, 1)
except Exception as e:
logger.exception(
f"Failed to increment processed_chunks for file_task={file_task.id}, chunk_index={chunk_index}: {e}"
)
return False
success_any = False
# 1. 生成问题
try:
questions = await self._generate_questions_for_one_chunk(
chunk_text=chunk_text,
question_cfg=question_cfg,
question_chat=question_chat,
)
except Exception as e:
logger.error(
f"Generate questions failed for file_task={file_task.id}, chunk_index={chunk_index}: {e}"
)
questions = []
if not questions:
logger.info(
f"No questions generated for file_task={file_task.id}, chunk_index={chunk_index}"
)
else:
# 2. 针对每个问题生成答案并入库
qa_success = await self._generate_answers_for_one_chunk(
file_task=file_task,
chunk=chunk,
questions=questions,
answer_cfg=answer_cfg,
answer_chat=answer_chat,
)
success_any = bool(qa_success)
# 无论本 chunk 处理是否成功,都增加 processed_chunks 计数,避免任务长时间卡住
try:
await self._increment_processed_chunks(file_task.id, 1)
except Exception as e:
logger.exception(
f"Failed to increment processed_chunks for file_task={file_task.id}, chunk_index={chunk_index}: {e}"
)
return success_any
async def _generate_questions_for_one_chunk(
self,
chunk_text: str,
question_cfg: SyntheConfig,
question_chat: BaseChatModel,
) -> list[str]:
"""针对单个 chunk 文本,调用 question_chat 生成问题列表。"""
number = question_cfg.number or 5
number = max(int(len(chunk_text) / 1000 * number), 1)
template = getattr(question_cfg, "prompt_template", QUESTION_GENERATOR_PROMPT)
template = template if (template is not None and template.strip() != "") else QUESTION_GENERATOR_PROMPT
prompt = (
template
.replace("{text}", chunk_text)
.replace("{number}", str(number))
.replace("{textLength}", str(len(chunk_text)))
)
async with self.question_semaphore:
loop = asyncio.get_running_loop()
raw_answer = await loop.run_in_executor(
None,
chat,
question_chat,
prompt,
)
# 解析为问题列表
questions = self._parse_questions_from_answer(
raw_answer,
)
return questions
async def _generate_answers_for_one_chunk(
self,
file_task: DataSynthesisFileInstance,
chunk: DataSynthesisChunkInstance,
questions: list[str],
answer_cfg: SyntheConfig,
answer_chat: BaseChatModel,
) -> bool:
"""为一个 chunk 的所有问题生成答案并写入 SynthesisData。
返回:是否至少成功写入一条 QA。
"""
if not questions:
return False
chunk_text = chunk.chunk_content or ""
template = getattr(answer_cfg, "prompt_template", ANSWER_GENERATOR_PROMPT)
template = template if (template is not None and template.strip() != "") else ANSWER_GENERATOR_PROMPT
extra_vars = getattr(answer_cfg, "extra_prompt_vars", {}) or {}
success_flags: list[bool] = []
async def process_single_question(question: str):
# str.replace 返回新字符串,必须重新赋值;extra_prompt_vars 使用 {{key}} 形式的占位符
prompt_local = template.replace("{text}", chunk_text).replace("{question}", question)
for k, v in extra_vars.items():
prompt_local = prompt_local.replace(f"{{{{{k}}}}}", str(v))
async with self.answer_semaphore:
loop = asyncio.get_running_loop()
answer = await loop.run_in_executor(
None,
chat,
answer_chat,
prompt_local,
)
# 默认结构:与 ANSWER_GENERATOR_PROMPT 一致,并补充 instruction 字段
base_obj: dict[str, object] = {
"input": chunk_text,
"output": answer,
}
# 如果模型已经按照 ANSWER_GENERATOR_PROMPT 返回了 JSON,则尝试解析并在其上增加 instruction
parsed_obj: dict[str, object] | None = None
if isinstance(answer, str):
cleaned = extract_json_substring(answer)
try:
parsed = json.loads(cleaned)
if isinstance(parsed, dict):
parsed_obj = parsed
except Exception:
parsed_obj = None
if parsed_obj is not None:
parsed_obj["instruction"] = question
data_obj = parsed_obj
else:
base_obj["instruction"] = question
data_obj = base_obj
record = SynthesisData(
id=str(uuid.uuid4()),
data=data_obj,
synthesis_file_instance_id=file_task.id,
chunk_instance_id=chunk.id,
)
self.db.add(record)
success_flags.append(True)
tasks = [process_single_question(q) for q in questions]
await asyncio.gather(*tasks, return_exceptions=True)
if success_flags:
await self.db.commit()
return True
return False
@staticmethod
def _parse_questions_from_answer(
raw_answer: str,
) -> list[str]:
"""从大模型返回中解析问题数组。"""
if not raw_answer:
return []
cleaned = extract_json_substring(raw_answer)
try:
data = json.loads(cleaned)
except Exception as e:
logger.error(
f"Failed to parse question list JSON for task: {e}. "
)
return []
if isinstance(data, list):
return [str(q) for q in data if isinstance(q, str) and q.strip()]
# 容错:如果是单个字符串
if isinstance(data, str) and data.strip():
return [data.strip()]
return []
# ==================== 原有辅助方法(文件路径/切片/持久化等) ====================
async def _resolve_file_path(self, file_id: str) -> str | None:
"""根据文件ID查询 t_dm_dataset_files 并返回 file_path(仅 ACTIVE 文件)。"""
result = await self.db.execute(
select(DatasetFiles).where(DatasetFiles.id == file_id)
)
file_obj = result.scalar_one_or_none()
if not file_obj:
return None
return file_obj.file_path
@staticmethod
def _load_and_split(file_path: str, chunk_size: int, chunk_overlap: int):
"""使用 LangChain 加载文本并进行切片,直接返回 Document 列表。
Args:
file_path: 待切片的文件路径
chunk_size: 切片大小
chunk_overlap: 切片重叠大小
"""
try:
docs = load_documents(file_path)
split_docs = DocumentSplitter.auto_split(docs, chunk_size, chunk_overlap)
return split_docs
except Exception as e:
logger.error(f"Error loading or splitting file {file_path}: {e}")
raise
async def _persist_chunks(
self,
synthesis_task: DataSynthesisInstance,
synthesis_task: DataSynthInstance,
file_task: DataSynthesisFileInstance,
file_id: str,
chunks,
@@ -164,201 +503,10 @@ class GenerationService:
await self.db.commit()
await self.db.refresh(file_task)
async def _invoke_llm_for_chunks(
self,
synthesis_task: DataSynthesisInstance,
file_id: str,
chunks,
synthesis_cfg: dict,
chat_client,
semaphore: asyncio.Semaphore,
) -> None:
"""针对每个分片并发调用大模型生成数据。"""
# 需要将 answer 和对应 chunk 建立关系,因此这里保留 chunk_index
tasks = [
self._call_llm(doc, file_id, idx, synthesis_task, synthesis_cfg, chat_client, semaphore)
for idx, doc in enumerate(chunks, start=1)
]
await asyncio.gather(*tasks, return_exceptions=True)
async def _call_llm(
self,
doc,
file_id: str,
idx: int,
synthesis_task,
synthesis_cfg: dict,
chat_client,
semaphore: asyncio.Semaphore,
):
"""单次大模型调用逻辑,带并发控制。
说明:
- 使用信号量限制全局并发量(当前为 10)。
- 使用线程池执行同步的 chat 调用,避免阻塞事件循环。
- 在拿到 LLM 返回后,解析为 JSON 并批量写入 SynthesisData,
同时更新文件级 processed_chunks / 进度等信息。
"""
async with semaphore:
prompt = self._build_qa_prompt(doc.page_content, synthesis_cfg)
try:
loop = asyncio.get_running_loop()
answer = await loop.run_in_executor(None, chat, chat_client, prompt)
logger.debug(
f"Generated QA for task={synthesis_task.id}, file={file_id}, chunk={idx}"
)
await self._handle_llm_answer(
synthesis_task_id=str(synthesis_task.id),
file_id=file_id,
chunk_index=idx,
raw_answer=answer,
)
return answer
except Exception as e:
logger.error(
f"LLM generation failed for task={synthesis_task.id}, file={file_id}, chunk={idx}: {e}"
)
return None
async def _resolve_file_path(self, file_id: str) -> str | None:
"""根据文件ID查询 t_dm_dataset_files 并返回 file_path(仅 ACTIVE 文件)。"""
result = await self.db.execute(
select(DatasetFiles).where(DatasetFiles.id == file_id)
)
file_obj = result.scalar_one_or_none()
if not file_obj:
return None
return file_obj.file_path
def _load_and_split(self, file_path: str, chunk_size: int, chunk_overlap: int):
"""使用 LangChain 加载文本并进行切片,直接返回 Document 列表。
当前实现:
- 使用 TextLoader 加载纯文本/Markdown/JSON 等文本文件
- 使用 RecursiveCharacterTextSplitter 做基于字符的递归切片
保留每个 Document 的 metadata,方便后续追加例如文件ID、chunk序号等信息。
"""
docs = load_documents(file_path)
splitter = RecursiveCharacterTextSplitter(
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
# 尝试按这些分隔符优先切分,再退化到字符级
separators=["\n\n", "\n", "", "", "", "!", "?", "\n", "\t", " "]
)
split_docs = splitter.split_documents(docs)
return split_docs
@staticmethod
def _build_qa_prompt(chunk: str, synthesis_cfg: dict) -> str:
"""构造 QA 数据合成的提示词。
要求:
- synthesis_cfg["prompt_template"] 是一个字符串,其中包含 {document} 占位符;
- 将当前切片内容替换到 {document}
如果未提供或模板非法,则使用内置默认模板。
"""
template = None
if isinstance(synthesis_cfg, dict):
template = synthesis_cfg.get("prompt_template")
synthesis_count = synthesis_cfg["synthesis_count"] if ("synthesis_count" in synthesis_cfg and synthesis_cfg["synthesis_count"]) else 5
try:
prompt = template.format(document=chunk, synthesis_count=synthesis_count)
except Exception:
# 防御性处理:如果 format 出现异常,则退回到简单拼接
prompt = f"{template}\n\n文档内容:{chunk}\n\n请根据文档内容生成 {synthesis_count} 条符合要求的问答数据。"
return prompt
async def _handle_llm_answer(
self,
synthesis_task_id: str,
file_id: str,
chunk_index: int,
raw_answer: str,
) -> None:
"""解析 LLM 返回内容为 JSON,批量保存到 SynthesisData,并更新文件任务进度。
约定:
- LLM 返回的 raw_answer 是 JSON 字符串,可以是:
1)单个对象:{"question": ..., "answer": ...}
2)对象数组:[{}, {}, ...]
- 我们将其规范化为列表,每个元素作为一条 SynthesisData.data 写入。
- 根据 synthesis_task_id + file_id + chunk_index 找到对应的 DataSynthesisChunkInstance,
以便设置 chunk_instance_id 和 synthesis_file_instance_id。
- 每处理完一个 chunk,递增对应 DataSynthesisFileInstance.processed_chunks,并按比例更新进度。
"""
if not raw_answer:
return
# 1. 预处理原始回答:尝试从中截取出最可能的 JSON 片段
cleaned = _extract_json_substring(raw_answer)
# 2. 解析 JSON,统一成列表结构
try:
parsed = json.loads(cleaned)
except Exception as e:
logger.error(
f"Failed to parse LLM answer as JSON for task={synthesis_task_id}, file={file_id}, chunk={chunk_index}: {e}. Raw answer: {raw_answer!r}"
)
return
if isinstance(parsed, dict):
items = [parsed]
elif isinstance(parsed, list):
items = [p for p in parsed if isinstance(p, dict)]
else:
logger.error(f"Unexpected JSON structure from LLM answer for task={synthesis_task_id}, file={file_id}, chunk={chunk_index}: {type(parsed)}")
return
if not items:
return
# 3. 找到对应的 chunk 记录(一个 chunk_index 对应一条记录)
chunk_result = await self.db.execute(
select(DataSynthesisChunkInstance, DataSynthesisFileInstance)
.join(
DataSynthesisFileInstance,
DataSynthesisFileInstance.id == DataSynthesisChunkInstance.synthesis_file_instance_id,
)
.where(
DataSynthesisFileInstance.synthesis_instance_id == synthesis_task_id,
DataSynthesisFileInstance.source_file_id == file_id,
DataSynthesisChunkInstance.chunk_index == chunk_index,
)
)
row = chunk_result.first()
if not row:
logger.error(
f"Chunk record not found for task={synthesis_task_id}, file={file_id}, chunk_index={chunk_index}, skip saving SynthesisData."
)
return
chunk_instance, file_instance = row
# 4. 批量写入 SynthesisData
for data_obj in items:
record = SynthesisData(
id=str(uuid.uuid4()),
data=data_obj,
synthesis_file_instance_id=file_instance.id,
chunk_instance_id=chunk_instance.id,
)
self.db.add(record)
# 5. 更新文件级 processed_chunks / 进度
file_instance.processed_chunks = (file_instance.processed_chunks or 0) + 1
await self.db.commit()
await self.db.refresh(file_instance)
async def _get_or_create_file_instance(
self,
synthesis_task_id: str,
source_file_id: str,
file_path: str,
) -> DataSynthesisFileInstance:
"""根据任务ID和原始文件ID,查找或创建对应的 DataSynthesisFileInstance 记录。
@@ -374,32 +522,9 @@ class GenerationService:
)
)
file_task = result.scalar_one_or_none()
if file_task is not None:
return file_task
# 查询任务以获取 result_data_location
task = await self.db.get(DataSynthesisInstance, synthesis_task_id)
target_location = task.result_data_location if task else ""
# 创建新的文件任务记录,初始状态为 processing
file_task = DataSynthesisFileInstance(
id=str(uuid.uuid4()),
synthesis_instance_id=synthesis_task_id,
file_name=Path(file_path).name,
source_file_id=source_file_id,
target_file_location=target_location or "",
status="processing",
total_chunks=0,
processed_chunks=0,
created_by="system",
updated_by="system",
)
self.db.add(file_task)
await self.db.commit()
await self.db.refresh(file_task)
return file_task
async def _mark_file_failed(self, synthesis_task_id: str, file_id: str, reason: str | None = None) -> None:
async def _mark_file_failed(self, synth_task_id: str, file_id: str, reason: str | None = None) -> None:
"""将指定任务下的单个文件任务标记为失败状态,兜底错误处理。
- 如果找到对应的 DataSynthesisFileInstance,则更新其 status="failed"
@@ -409,14 +534,14 @@ class GenerationService:
try:
result = await self.db.execute(
select(DataSynthesisFileInstance).where(
DataSynthesisFileInstance.synthesis_instance_id == synthesis_task_id,
DataSynthesisFileInstance.synthesis_instance_id == synth_task_id,
DataSynthesisFileInstance.source_file_id == file_id,
)
)
file_task = result.scalar_one_or_none()
if not file_task:
logger.warning(
f"Failed to mark file as failed: no DataSynthesisFileInstance found for task={synthesis_task_id}, file_id={file_id}, reason={reason}"
f"Failed to mark file as failed: no DataSynthesisFileInstance found for task={synth_task_id}, file_id={file_id}, reason={reason}"
)
return
@@ -424,10 +549,72 @@ class GenerationService:
await self.db.commit()
await self.db.refresh(file_task)
logger.info(
f"Marked file task as failed for task={synthesis_task_id}, file_id={file_id}, reason={reason}"
f"Marked file task as failed for task={synth_task_id}, file_id={file_id}, reason={reason}"
)
except Exception as e:
# 兜底日志,避免异常向外传播影响其它文件处理
logger.exception(
f"Unexpected error when marking file failed for task={synthesis_task_id}, file_id={file_id}, original_reason={reason}, error={e}"
f"Unexpected error when marking file failed for task={synth_task_id}, file_id={file_id}, original_reason={reason}, error={e}"
)
async def _get_file_ids_for_task(self, synth_task_id: str):
"""根据任务ID查询关联的文件原始ID列表"""
result = await self.db.execute(
select(DataSynthesisFileInstance.source_file_id)
.where(DataSynthesisFileInstance.synthesis_instance_id == synth_task_id)
)
file_ids = result.scalars().all()
return file_ids
# ========== 新增:chunk 计数与批量加载、processed_chunks 安全更新辅助方法 ==========
async def _count_chunks_for_file(self, synth_file_instance_id: str) -> int:
"""统计指定任务与文件下的 chunk 总数。"""
from sqlalchemy import func
result = await self.db.execute(
select(func.count(DataSynthesisChunkInstance.id)).where(
DataSynthesisChunkInstance.synthesis_file_instance_id == synth_file_instance_id
)
)
return int(result.scalar() or 0)
async def _load_chunk_batch(
self,
file_task_id: str,
start_index: int,
end_index: int,
) -> list[DataSynthesisChunkInstance]:
"""按索引范围加载指定文件任务下的一批 chunk 记录(含边界)。"""
result = await self.db.execute(
select(DataSynthesisChunkInstance)
.where(
DataSynthesisChunkInstance.synthesis_file_instance_id == file_task_id,
DataSynthesisChunkInstance.chunk_index >= start_index,
DataSynthesisChunkInstance.chunk_index <= end_index,
)
.order_by(DataSynthesisChunkInstance.chunk_index.asc())
)
return list(result.scalars().all())
async def _increment_processed_chunks(self, file_task_id: str, delta: int) -> None:
result = await self.db.execute(
select(DataSynthesisFileInstance).where(
DataSynthesisFileInstance.id == file_task_id,
)
)
file_task = result.scalar_one_or_none()
if not file_task:
logger.error(f"Failed to increment processed_chunks: file_task {file_task_id} not found")
return
# 原始自增
new_value = (file_task.processed_chunks or 0) + int(delta)
# 如果存在 total_chunks,上限为 total_chunks,避免超过
total = file_task.total_chunks
if isinstance(total, int) and total >= 0:
new_value = min(new_value, total)
file_task.processed_chunks = new_value
await self.db.commit()
await self.db.refresh(file_task)
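
A self-contained sketch of the concurrency pattern used above (a small question semaphore, a larger answer semaphore, and `run_in_executor` around a blocking `chat` call). The stub `chat` function and plain-string chunks are stand-ins, not the real service code.

```python
import asyncio

def chat(client: object, prompt: str) -> str:
    # Stand-in for the synchronous model call used by the service.
    return f"echo: {prompt[:40]}"

question_semaphore = asyncio.Semaphore(10)
answer_semaphore = asyncio.Semaphore(100)

async def generate_questions(chunk_text: str) -> list[str]:
    async with question_semaphore:
        loop = asyncio.get_running_loop()
        raw = await loop.run_in_executor(None, chat, "question-model", chunk_text)
    return [raw]  # the real service parses a JSON array of questions here

async def generate_answer(question: str, chunk_text: str) -> str:
    async with answer_semaphore:
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, chat, "answer-model", f"{question}\n{chunk_text}")

async def process_chunk(chunk_text: str) -> list[str]:
    # Questions first, then answers fan out under the larger semaphore.
    questions = await generate_questions(chunk_text)
    answers = await asyncio.gather(*(generate_answer(q, chunk_text) for q in questions))
    return list(answers)

print(asyncio.run(process_chunk("some chunk of source text")))
```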

View File

@@ -1,71 +1,138 @@
from app.module.generation.schema.generation import SynthesisType
QA_PROMPT="""# 角色
你是一位专业的AI助手,擅长从给定的文本中提取关键信息并创建用于教学和测试的问答对。
QUESTION_GENERATOR_PROMPT=f"""# Role: 文本问题生成专家
## Profile:
- Description: 你是一名专业的文本分析与问题设计专家,能够从复杂文本中提炼关键信息并产出可用于模型微调的高质量问题集合。
- Input Length: {{textLength}}
- Output Goal: 生成不少于 {{number}} 个高质量问题,用于构建问答训练数据集。
# 任务
请根据用户提供的原始文档,生成一系列高质量、多样化的问答对
## Skills:
1. 能够全面理解原文内容,识别核心概念、事实与逻辑结构
2. 擅长设计具有明确答案指向性的问题,覆盖文本多个侧面。
3. 善于控制问题难度与类型,保证多样性与代表性。
4. 严格遵守格式规范,确保输出可直接用于程序化处理。
# 输入文档
{document}
## Workflow:
1. **文本解析**:通读全文,分段识别关键实体、事件、数值与结论。
2. **问题设计**:基于信息密度和重要性选择最佳提问切入点。
3. **质量检查**:逐条校验问题,确保:
- 问题答案可在原文中直接找到依据。
- 问题之间主题不重复、角度不雷同。
- 语言表述准确、无歧义且符合常规问句形式。
# 要求与指令
1. **问题类型**:生成 {synthesis_count} 个左右的问答对。问题类型应多样化,包括但不限于:
* **事实性**:基于文本中明确提到的事实
* **理解性**:需要理解上下文和概念
* **归纳性**:需要总结或归纳多个信息点
2. **答案来源**:所有答案必须严格基于提供的文档内容,不得编造原文不存在的信息
3. **语言**:请根据输入文档的主要语言进行提问和回答。
4. **问题质量**:问题应清晰、无歧义,并且是读完文档后自然会产生的问题。
5. **答案质量**:答案应准确、简洁、完整。
## Constraints:
1. 所有问题必须严格依据原文内容,不得添加外部信息或假设情境。
2. 问题需覆盖文本的不同主题、层级或视角,避免集中于单一片段
3. 禁止输出与材料元信息相关的问题(如作者、章节、目录等)
4. 提问时请假设没有相应的文章可供参考,因此不要在问题中使用"这个""这些"等指示代词,也不得包含“报告/文章/文献/表格中提到”等表述
5. 输出不少于 {{number}} 个问题,问题语言与原文主要语言保持一致
# 输出格式
请严格按照以下JSON格式输出,保持字段顺序,确保没有额外的解释或标记:
[
{{"instruction": "问题1","input": "参考内容1","output": "答案1"}},
{{"instruction": "问题2","input": "参考内容1","output": "答案2"}},
...
]
## Output Format:
- 使用合法的 JSON 数组,仅包含字符串元素。
- 字段必须使用英文双引号。
- 严格遵循以下结构:
```
["问题1", "问题2", "..."]
```
## Output Example:
```
["人工智能伦理框架应包含哪些核心要素", "民法典对个人数据保护有哪些新规定?"]
```
## 参考原文:
{{text}}
"""
ANSWER_GENERATOR_PROMPT=f"""# Role: 微调数据生成专家
## Profile:
- Description: 你是一名微调数据生成专家,擅长基于给定内容生成准确对应的问题答案,确保答案的准确性、相关性和完整性,能够直接输出符合模型训练要求的结构化数据。
COT_PROMPT="""# 角色
你是一位专业的数据合成专家,擅长基于给定的原始文档和 COT(Chain of Thought,思维链)逻辑,生成高质量、符合实际应用场景的 COT 数据。COT 数据需包含清晰的问题、逐步推理过程和最终结论,能完整还原解决问题的思考路径。
## Skills:
1. 严格基于给定内容生成答案,不添加任何外部信息
2. 答案需准确无误、逻辑通顺,与问题高度相关
3. 能够精准提取给定内容中的关键信息,并整合为自然流畅的完整答案
4. 输出结果必须符合指定的结构化格式要求
# 任务
请根据用户提供的原始文档,生成一系列高质量、多样化的 COT 数据。每个 COT 数据需围绕文档中的关键信息、核心问题或逻辑关联点展开,确保推理过程贴合文档内容,结论准确可靠。
## Workflow:
1. 分析给定的参考内容,梳理核心信息和逻辑框架
2. 结合提出的具体问题,从参考内容中提取与之匹配的关键依据
3. 基于提取的依据,生成准确、详尽且符合逻辑的答案
4. 将依据内容和答案分别填入指定字段,形成结构化输出
5. 校验输出内容,确保格式正确、信息完整、无引用性表述
# 输入文档
{document}
## Output Format:
输出格式为固定字典结构:
```json
{{
"input": "此处填入回答问题所依据的完整参考内容",
"output": "此处填入基于参考内容生成的准确答案"
}}
```
# 要求与指令
1. **数量要求**:生成 {synthesis_count} 条左右的 COT 数据。
2. **内容要求**:
* 每条 COT 数据需包含 “问题”“思维链推理”“最终结论” 三部分,逻辑闭环,推理步骤清晰、连贯,不跳跃关键环节。
* 问题需基于文档中的事实信息、概念关联或逻辑疑问,是读完文档后自然产生的有价值问题(避免无意义或过于简单的问题)。
* 思维链推理需严格依据文档内容,逐步推导,每一步推理都能对应文档中的具体信息,不编造原文不存在的内容,不主观臆断。
* 最终结论需简洁、准确,是思维链推理的合理结果,与文档核心信息一致。
3. **多样化要求**:
* 问题类型多样化,包括但不限于事实查询类、逻辑分析类、原因推导类、方案对比类、结论归纳类。
* 推理角度多样化,可从不同角色(如项目参与者、需求方、测试人员)或不同维度(如功能实现、进度推进、问题解决)展开推理。
4. **语言要求**:
* 语言通顺、表达清晰,无歧义,推理过程口语化但不随意,符合正常思考逻辑,最终结论简洁规范。
* 请根据输入文档的主要语言进行提问和回答。
## Constraints:
1. `input`字段必须根据给定的参考内容填入回答问题的依据,不得更改原文含义
2. `output`字段的答案必须完全基于`input`中的内容,严禁编造、添加外部信息
3. 答案需充分详细,包含回答问题的所有必要信息,满足大模型微调训练的数据要求
4. 答案中不得出现「参考」「依据」「文献中提到」等任何引用性表述,仅呈现最终结论
5. 必须严格遵守指定的字典输出格式,不得额外添加其他内容
# 输出格式
请严格按照以下 JSON 格式输出,保持字段顺序,确保没有额外的解释或标记,每条 COT 数据独立成项:
[
{{"question": "具体问题","chain_of_thought": "步骤 1:明确问题核心,定位文档中相关信息范围;步骤 2:提取文档中与问题相关的关键信息 1;步骤 3:结合关键信息 1 推导中间结论 1;步骤 4:提取文档中与问题相关的关键信息 2;步骤 5:结合中间结论 1 和关键信息 2 推导中间结论 2;...(逐步推进);步骤 N:汇总所有中间结论,得出最终结论","conclusion": "简洁准确的最终结论"}},
## Reference Content
------ 参考内容 Start ------
{{text}}
------ 参考内容 End ------
{{"question": "具体问题","chain_of_thought": "步骤 1:明确问题核心,定位文档中相关信息范围;步骤 2:提取文档中与问题相关的关键信息 1;步骤 3:结合关键信息 1 推导中间结论 1;步骤 4:提取文档中与问题相关的关键信息 2;步骤 5:结合中间结论 1 和关键信息 2 推导中间结论 2;...(逐步推进);步骤 N:汇总所有中间结论,得出最终结论","conclusion": "简洁准确的最终结论"}},
...
]
## Question
{{question}}
"""
COT_GENERATOR_PROMPT=f"""# Role: 微调数据生成专家
## Profile:
- Description: 你是一名微调数据生成专家,擅长基于给定参考内容,通过**思维链(COT)逐步推理**生成准确、完整的答案,输出符合大模型微调训练要求的结构化COT数据,还原从信息提取到结论推导的全思考路径。
## Skills:
1. 严格基于给定参考内容开展推理,不引入任何外部信息
2. 能够拆解问题逻辑,按步骤提取关键信息并推导,确保推理过程连贯、无跳跃
3. 生成的答案精准对应问题,逻辑通顺,与参考内容高度一致
4. 输出结果严格符合指定的结构化COT格式要求
## Workflow:
1. 分析给定参考内容,梳理核心信息、概念及逻辑关联
2. 结合具体问题,明确推理起点与目标,划定参考内容中的相关信息范围
3. 分步推导:提取关键信息→推导中间结论→结合更多信息完善逻辑→形成最终结论
4. 将完整推理过程、最终答案填入指定字段,生成结构化COT数据
5. 校验:确保推理每一步均对应参考内容,无编造信息,格式合规,无引用性表述
## Output Format:
输出固定JSON结构,包含思维链推理、最终答案两部分:
```json
{{
"chain_of_thought": "基于参考内容逐步推理的完整思维链,详述每一步提取的信息和推导的逻辑过程",
"output": "此处填入基于思维链推理得出的准确、详细的最终结论"
}}
```
## Constraints:
2. `chain_of_thought`字段需还原完整推理路径,每一步推导均需对应`Reference Content`中的具体内容,严禁主观臆断
3. `output`字段的答案必须完全来源于`Reference Content`和`chain_of_thought`的推导,不添加任何外部信息,满足大模型微调对数据质量的要求
4. 整个输出中不得出现「参考」「依据」「文献中提到」等引用性表述,仅呈现推理过程与结论
5. 必须严格遵守指定JSON格式,字段顺序固定,无额外解释或标记内容
## Reference Content
------ 参考内容 Start ------
{{text}}
------ 参考内容 End ------
## Question
{{question}}
"""
def get_prompt(synth_type: SynthesisType):
if synth_type == SynthesisType.QA:
return QA_PROMPT
return ANSWER_GENERATOR_PROMPT
elif synth_type == SynthesisType.COT:
return COT_PROMPT
return COT_GENERATOR_PROMPT
elif synth_type == SynthesisType.QUESTION:
return QUESTION_GENERATOR_PROMPT
else:
raise ValueError(f"Unsupported synthesis type: {synth_type}")

View File

@@ -0,0 +1,169 @@
import os
from typing import List, Optional, Tuple
from langchain_core.documents import Document
from langchain_text_splitters import (
RecursiveCharacterTextSplitter,
MarkdownHeaderTextSplitter
)
class DocumentSplitter:
"""
文档分割器类 - 增强版,优先通过元数据识别文档类型
核心特性:
1. 优先从metadata的source字段(文件扩展名)识别Markdown
2. 元数据缺失时,通过内容特征降级检测
3. 支持CJK(中日韩)语言优化
"""
def __init__(
self,
chunk_size: int = 2000,
chunk_overlap: int = 200,
is_cjk_language: bool = True,
markdown_headers: Optional[List[Tuple[str, str]]] = None
):
"""
初始化文档分割器
Args:
chunk_size: 每个文本块的最大长度(默认2000字符)
chunk_overlap: 文本块之间的重叠长度(默认200字符)
is_cjk_language: 是否处理中日韩等无词边界语言(默认True)
markdown_headers: Markdown标题分割规则(默认:#/##/###/####)
"""
self.chunk_size = chunk_size
self.chunk_overlap = chunk_overlap
self.is_cjk_language = is_cjk_language
# 默认Markdown标题分割规则
self.markdown_headers = markdown_headers or [
("#", "header_1"),
("##", "header_2"),
("###", "header_3"),
("####", "header_4"),
]
# 初始化基础文本分割器
self.text_splitter = self._create_text_splitter()
def _create_text_splitter(self) -> RecursiveCharacterTextSplitter:
"""创建递归字符分割器(内部方法)"""
# 优化后的CJK分隔符列表(修复语法错误,调整优先级)
if self.is_cjk_language:
separators = [
"\n\n", "\n", # 段落/换行(最高优先级)
"", ".", # 句号(中文/英文)
"", "!", # 感叹号(中文/英文)
"", "?", # 问号(中文/英文)
"", ";", # 分号(中文/英文)
"", ",", # 逗号(中文/英文)
"", # 顿号(中文)
"", ":", # 冒号(中文/英文)
" ", # 空格
"\u200b", "", # 零宽空格/兜底
]
else:
separators = ["\n\n", "\n", " ", ".", "!", "?", ";", ":", ",", ""]
return RecursiveCharacterTextSplitter(
chunk_size=self.chunk_size,
chunk_overlap=self.chunk_overlap,
separators=separators,
length_function=len,
is_separator_regex=False
)
@staticmethod
def _is_markdown(doc: Document) -> bool:
"""
优先从元数据判断是否为Markdown
规则:检查metadata中的source字段扩展名是否为.md/.markdown/.mdx等
"""
# 获取source字段(忽略大小写)
source = doc.metadata.get("source", "").lower()
if not source:
return False
# 获取文件扩展名
ext = os.path.splitext(source)[-1].lower()
# Markdown常见扩展名列表
md_ext = [".md", ".markdown", ".mdx", ".mkd", ".mkdown"]
return ext in md_ext
def split(self, documents: List[Document], is_markdown: bool = False) -> List[Document]:
"""
核心分割方法
Args:
documents: 待分割的Document列表
is_markdown: 是否为Markdown文档(默认False)
Returns:
分割后的Document列表
"""
if not documents:
return []
# Markdown文档处理:先按标题分割,再按字符分割
if is_markdown:
# 初始化Markdown标题分割器
md_splitter = MarkdownHeaderTextSplitter(
headers_to_split_on=self.markdown_headers,
strip_headers=True,
return_each_line=False
)
# 按标题分割并继承元数据
md_chunks = []
for doc in documents:
chunks = md_splitter.split_text(doc.page_content)
for chunk in chunks:
chunk.metadata.update(doc.metadata)
md_chunks.extend(chunks)
# 对标题分割后的内容进行字符分割
final_chunks = self.text_splitter.split_documents(md_chunks)
# 普通文本直接分割
else:
final_chunks = self.text_splitter.split_documents(documents)
return final_chunks
# 核心自动分割方法(元数据优先)
@classmethod
def auto_split(
cls,
documents: List[Document],
chunk_size: int = 2000,
chunk_overlap: int = 200
) -> List[Document]:
"""
极简快捷方法:自动识别文档类型并分割(元数据优先)
仅需传入3个参数,无需初始化类实例
Args:
documents: 待分割的Document列表
chunk_size: 每个文本块的最大长度(默认2000字符)
chunk_overlap: 文本块之间的重叠长度(默认200字符)
Returns:
分割后的Document列表
"""
if not documents:
return []
# 初始化分割器实例(使用CJK默认优化)
splitter = cls(
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
is_cjk_language=True
)
# 自动检测文档类型(元数据优先)
is_md = splitter._is_markdown(documents[0])
# 根据检测结果选择分割方式
return splitter.split(documents, is_markdown=is_md)
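
Usage sketch for `DocumentSplitter.auto_split`; the `Document` content here is illustrative, and a `.md` `source` in the metadata is what triggers the Markdown header pass before character splitting.

```python
from langchain_core.documents import Document
from app.module.shared.common.text_split import DocumentSplitter

docs = [
    Document(
        page_content="# 标题\n\n第一段内容。\n\n## 小节\n\n第二段内容。",
        metadata={"source": "/data/example.md"},
    )
]

chunks = DocumentSplitter.auto_split(docs, chunk_size=200, chunk_overlap=20)
for chunk in chunks:
    # Header text is lifted into metadata keys header_1 / header_2 by the Markdown pass.
    print(chunk.metadata.get("header_1"), len(chunk.page_content))
```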

View File

@@ -14,7 +14,8 @@ def call_openai_style_model(base_url, api_key, model_name, prompt, **kwargs):
)
return response.choices[0].message.content
def _extract_json_substring(raw: str) -> str:
def extract_json_substring(raw: str) -> str:
"""从 LLM 的原始回答中提取最可能的 JSON 字符串片段。
处理思路:
@@ -22,11 +23,21 @@ def _extract_json_substring(raw: str) -> str:
- 优先在文本中查找第一个 '{' 或 '[' 作为 JSON 起始;
- 再从后向前找最后一个 '}' 或 ']' 作为结束;
- 如果找不到合适的边界,就退回原始字符串。
- 部分模型可能会在回复中加入 `<think>...</think>` 内部思考内容,应在解析前先去除。
该方法不会保证截取的一定是合法 JSON,但能显著提高 json.loads 的成功率。
"""
if not raw:
return raw
# 先移除所有 <think>...</think> 段落(包括跨多行的情况)
try:
import re
raw = re.sub(r"<think>[\s\S]*?</think>", "", raw, flags=re.IGNORECASE)
except Exception:
# 正则异常时不影响后续逻辑,继续使用原始文本
pass
start = None
end = None
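
Usage sketch for the renamed `extract_json_substring`, showing the `<think>` stripping plus JSON-boundary trimming described above; the raw model output is made up.

```python
import json

from app.module.shared.util.model_chat import extract_json_substring

raw = (
    "<think>先分析原文,再决定问题数量……</think>\n"
    "好的,以下是生成的问题:\n"
    '["问题一是什么?", "问题二是什么?"]\n'
    "希望对你有帮助。"
)

cleaned = extract_json_substring(raw)   # drops the <think> block and surrounding prose
questions = json.loads(cleaned)
print(questions)  # ['问题一是什么?', '问题二是什么?']
```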