feat: Implement data synthesis task management with database models and API endpoints (#122)
@@ -19,21 +19,26 @@ python -m venv .venv
source .venv/bin/activate
```

3. 安装依赖

由于项目使用 poetry 管理依赖,可以使用以下命令安装:

```bash
pip install poetry
poetry install
```

或者直接使用 pip 安装(如果 poetry 不可用):

```bash
pip install -e .
```

4. 配置环境变量

复制环境变量示例文件并配置:

```bash
cp .env.example .env
```

编辑 .env 文件,设置必要的环境变量,如数据库连接(DATABASE_URL,或根据项目配置使用具体变量)、Label Studio 配置(LABEL_STUDIO_BASE_URL、LABEL_STUDIO_USER_TOKEN)等,具体变量请参考 `.env.example`。
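下面是一个示意性的 `.env` 片段(变量名来自上文说明,取值均为占位示例;数据库连接串的驱动前缀取决于项目实际使用的异步驱动,此处仅为假设):

```bash
# 示例值,请按实际环境替换;驱动前缀 mysql+aiomysql 仅为假设
DATABASE_URL=mysql+aiomysql://user:password@localhost:3306/datamate
LABEL_STUDIO_BASE_URL=http://localhost:8080
LABEL_STUDIO_USER_TOKEN=<your-label-studio-token>
```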

5. 数据库迁移(开发环境):
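该 hunk 在此处截断,未展示具体迁移命令;下面是一个示意(假设使用 Alembic 管理迁移,若项目采用其他迁移方式请以实际为准):

```bash
alembic upgrade head
```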
0 runtime/datamate-python/app/__init__.py (Normal file)
@@ -17,7 +17,7 @@ class Settings(BaseSettings):
|
||||
|
||||
host: str = "0.0.0.0"
|
||||
port: int = 18000
|
||||
|
||||
|
||||
# CORS
|
||||
# allowed_origins: List[str] = ["*"]
|
||||
# allowed_methods: List[str] = ["*"]
|
||||
@@ -36,7 +36,7 @@ class Settings(BaseSettings):
|
||||
mysql_database: str = "datamate"
|
||||
|
||||
database_url: str = "" # Will be overridden by build_database_url() if not provided
|
||||
|
||||
|
||||
@model_validator(mode='after')
|
||||
def build_database_url(self):
|
||||
"""如果没有提供 database_url,则根据 MySQL 配置构建"""
|
||||
|
||||
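上面的 hunk 在 validator 处被截断;下面给出一个最小示意(非源码原文,mysql_host/mysql_port/mysql_user/mysql_password 等字段名与 mysql+aiomysql 驱动前缀均为假设,具体以 config.py 实际实现为准):

```python
from pydantic import model_validator
from pydantic_settings import BaseSettings


class MySQLSettingsSketch(BaseSettings):
    # 字段名仅为演示拼接逻辑的假设
    mysql_host: str = "localhost"
    mysql_port: int = 3306
    mysql_user: str = "root"
    mysql_password: str = ""
    mysql_database: str = "datamate"
    database_url: str = ""

    @model_validator(mode="after")
    def build_database_url(self):
        """如果没有提供 database_url,则根据 MySQL 配置构建"""
        if not self.database_url:
            self.database_url = (
                f"mysql+aiomysql://{self.mysql_user}:{self.mysql_password}"
                f"@{self.mysql_host}:{self.mysql_port}/{self.mysql_database}"
            )
        return self
```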
197 runtime/datamate-python/app/db/models/data_synthesis.py (Normal file)
@@ -0,0 +1,197 @@
|
||||
import uuid
|
||||
|
||||
from sqlalchemy import Column, String, Text, Integer, JSON, TIMESTAMP, ForeignKey, func
|
||||
from sqlalchemy.orm import relationship
|
||||
|
||||
from app.db.session import Base
|
||||
from app.module.generation.schema.generation import CreateSynthesisTaskRequest
|
||||
|
||||
|
||||
async def save_synthesis_task(db_session, synthesis_task: CreateSynthesisTaskRequest):
|
||||
"""保存数据合成任务。"""
|
||||
# 转换为模型实例
|
||||
gid = str(uuid.uuid4())
|
||||
synthesis_task_instance = DataSynthesisInstance(
|
||||
id=gid,
|
||||
name=synthesis_task.name,
|
||||
description=synthesis_task.description,
|
||||
status="pending",
|
||||
model_id=synthesis_task.model_id,
|
||||
synthesis_type=synthesis_task.synthesis_type.value,
|
||||
progress=0,
|
||||
result_data_location=f"/dataset/synthesis_results/{gid}/",
|
||||
text_split_config=synthesis_task.text_split_config.model_dump(),
|
||||
synthesis_config=synthesis_task.synthesis_config.model_dump(),
|
||||
source_file_id=synthesis_task.source_file_id,
|
||||
total_files=len(synthesis_task.source_file_id),
|
||||
processed_files=0,
|
||||
total_chunks=0,
|
||||
processed_chunks=0,
|
||||
total_synthesis_data=0,
|
||||
created_at=func.now(),
|
||||
updated_at=func.now(),
|
||||
created_by="system",
|
||||
updated_by="system"
|
||||
)
|
||||
db_session.add(synthesis_task_instance)
|
||||
await db_session.commit()
|
||||
await db_session.refresh(synthesis_task_instance)
|
||||
return synthesis_task_instance
|
||||
|
||||
|
||||
class DataSynthesisInstance(Base):
|
||||
"""数据合成任务表,对应表 t_data_synthesis_instances
|
||||
|
||||
create table if not exists t_data_synthesis_instances
|
||||
(
|
||||
id VARCHAR(36) CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci PRIMARY KEY COMMENT 'UUID',
|
||||
name VARCHAR(255) NOT NULL COMMENT '任务名称',
|
||||
description TEXT COMMENT '任务描述',
|
||||
status VARCHAR(20) COMMENT '任务状态',
|
||||
synthesis_type VARCHAR(20) NOT NULL COMMENT '合成类型',
|
||||
model_id VARCHAR(255) NOT NULL COMMENT '模型ID',
|
||||
progress INT DEFAULT 0 COMMENT '任务进度(百分比)',
|
||||
result_data_location VARCHAR(1000) COMMENT '结果数据存储位置',
|
||||
text_split_config JSON NOT NULL COMMENT '文本切片配置',
|
||||
synthesis_config JSON NOT NULL COMMENT '合成配置',
|
||||
source_file_id JSON NOT NULL COMMENT '原始文件ID列表',
|
||||
total_files INT DEFAULT 0 COMMENT '总文件数',
|
||||
processed_files INT DEFAULT 0 COMMENT '已处理文件数',
|
||||
total_chunks INT DEFAULT 0 COMMENT '总文本块数',
|
||||
processed_chunks INT DEFAULT 0 COMMENT '已处理文本块数',
|
||||
total_synthesis_data INT DEFAULT 0 COMMENT '总合成数据量',
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP COMMENT '创建时间',
|
||||
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT '更新时间',
|
||||
created_by VARCHAR(255) COMMENT '创建者',
|
||||
updated_by VARCHAR(255) COMMENT '更新者'
|
||||
) COMMENT='数据合成任务表(UUID 主键)';
|
||||
"""
|
||||
|
||||
__tablename__ = "t_data_synthesis_instances"
|
||||
|
||||
id = Column(String(36), primary_key=True, index=True, comment="UUID")
|
||||
name = Column(String(255), nullable=False, comment="任务名称")
|
||||
description = Column(Text, nullable=True, comment="任务描述")
|
||||
status = Column(String(20), nullable=True, comment="任务状态")
|
||||
synthesis_type = Column(String(20), nullable=False, comment="合成类型")
|
||||
model_id = Column(String(255), nullable=False, comment="模型ID")
|
||||
progress = Column(Integer, nullable=False, default=0, comment="任务进度(百分比)")
|
||||
result_data_location = Column(String(1000), nullable=True, comment="结果数据存储位置")
|
||||
text_split_config = Column(JSON, nullable=False, comment="文本切片配置")
|
||||
synthesis_config = Column(JSON, nullable=False, comment="合成配置")
|
||||
source_file_id = Column(JSON, nullable=False, comment="原始文件ID列表")
|
||||
total_files = Column(Integer, nullable=False, default=0, comment="总文件数")
|
||||
processed_files = Column(Integer, nullable=False, default=0, comment="已处理文件数")
|
||||
total_chunks = Column(Integer, nullable=False, default=0, comment="总文本块数")
|
||||
processed_chunks = Column(Integer, nullable=False, default=0, comment="已处理文本块数")
|
||||
total_synthesis_data = Column(Integer, nullable=False, default=0, comment="总合成数据量")
|
||||
|
||||
created_at = Column(TIMESTAMP, server_default=func.current_timestamp(), nullable=True, comment="创建时间")
|
||||
updated_at = Column(TIMESTAMP, server_default=func.current_timestamp(), onupdate=func.current_timestamp(), nullable=True, comment="更新时间")
|
||||
created_by = Column(String(255), nullable=True, comment="创建者")
|
||||
updated_by = Column(String(255), nullable=True, comment="更新者")
|
||||
|
||||
|
||||
class DataSynthesisFileInstance(Base):
|
||||
"""数据合成文件任务表,对应表 t_data_synthesis_file_instances
|
||||
|
||||
create table if not exists t_data_synthesis_file_instances (
|
||||
id VARCHAR(36) PRIMARY KEY COMMENT 'UUID',
|
||||
synthesis_instance_id VARCHAR(36) COMMENT '数据合成任务ID',
|
||||
file_name VARCHAR(255) NOT NULL COMMENT '文件名',
|
||||
source_file_id VARCHAR(255) NOT NULL COMMENT '原始文件ID',
|
||||
target_file_location VARCHAR(1000) NOT NULL COMMENT '目标文件存储位置',
|
||||
status VARCHAR(20) COMMENT '任务状态',
|
||||
total_chunks INT DEFAULT 0 COMMENT '总文本块数',
|
||||
processed_chunks INT DEFAULT 0 COMMENT '已处理文本块数',
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP COMMENT '创建时间',
|
||||
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT '更新时间',
|
||||
created_by VARCHAR(255) COMMENT '创建者',
|
||||
updated_by VARCHAR(255) COMMENT '更新者'
|
||||
) COMMENT='数据合成文件任务表(UUID 主键)';
|
||||
"""
|
||||
|
||||
__tablename__ = "t_data_synthesis_file_instances"
|
||||
|
||||
id = Column(String(36), primary_key=True, index=True, comment="UUID")
|
||||
synthesis_instance_id = Column(
|
||||
String(36),
|
||||
nullable=False,
|
||||
comment="数据合成任务ID",
|
||||
index=True,
|
||||
)
|
||||
file_name = Column(String(255), nullable=False, comment="文件名")
|
||||
source_file_id = Column(String(255), nullable=False, comment="原始文件ID")
|
||||
target_file_location = Column(String(1000), nullable=False, comment="目标文件存储位置")
|
||||
status = Column(String(20), nullable=True, comment="任务状态")
|
||||
total_chunks = Column(Integer, nullable=False, default=0, comment="总文本块数")
|
||||
processed_chunks = Column(Integer, nullable=False, default=0, comment="已处理文本块数")
|
||||
|
||||
created_at = Column(TIMESTAMP, server_default=func.current_timestamp(), nullable=True, comment="创建时间")
|
||||
updated_at = Column(
|
||||
TIMESTAMP,
|
||||
server_default=func.current_timestamp(),
|
||||
onupdate=func.current_timestamp(),
|
||||
nullable=True,
|
||||
comment="更新时间",
|
||||
)
|
||||
created_by = Column(String(255), nullable=True, comment="创建者")
|
||||
updated_by = Column(String(255), nullable=True, comment="更新者")
|
||||
|
||||
|
||||
class DataSynthesisChunkInstance(Base):
|
||||
"""数据合成分块任务表,对应表 t_data_synthesis_chunk_instances
|
||||
|
||||
create table if not exists t_data_synthesis_chunk_instances (
|
||||
id VARCHAR(36) PRIMARY KEY COMMENT 'UUID',
|
||||
synthesis_file_instance_id VARCHAR(36) COMMENT '数据合成文件任务ID',
|
||||
chunk_index INT COMMENT '分块索引',
|
||||
chunk_content TEXT COMMENT '分块内容',
|
||||
metadata JSON COMMENT '分块元数据'
|
||||
) COMMENT='数据合成分块任务表(UUID 主键)';
|
||||
"""
|
||||
|
||||
__tablename__ = "t_data_synthesis_chunk_instances"
|
||||
|
||||
id = Column(String(36), primary_key=True, index=True, comment="UUID")
|
||||
synthesis_file_instance_id = Column(
|
||||
String(36),
|
||||
nullable=False,
|
||||
comment="数据合成文件任务ID",
|
||||
index=True,
|
||||
)
|
||||
chunk_index = Column(Integer, nullable=True, comment="分块索引")
|
||||
chunk_content = Column(Text, nullable=True, comment="分块内容")
|
||||
# SQLAlchemy Declarative 保留了属性名 'metadata',这里使用 chunk_metadata 作为属性名,
|
||||
# 底层列名仍为 'metadata' 以保持与表结构兼容。
|
||||
chunk_metadata = Column("metadata", JSON, nullable=True, comment="分块元数据")
|
||||
|
||||
|
||||
class SynthesisData(Base):
|
||||
"""数据合成结果表,对应表 t_synthesis_data
|
||||
|
||||
create table if not exists t_synthesis_data (
|
||||
id VARCHAR(36) PRIMARY KEY COMMENT 'UUID',
|
||||
data json COMMENT '合成的数据',
|
||||
synthesis_file_instance_id VARCHAR(36) COMMENT '数据合成文件任务ID',
|
||||
chunk_instance_id VARCHAR(36) COMMENT '分块任务ID'
|
||||
) COMMENT='数据合成任务队列表(UUID 主键)';
|
||||
"""
|
||||
|
||||
__tablename__ = "t_data_synthesis_data"
|
||||
|
||||
id = Column(String(36), primary_key=True, index=True, comment="UUID")
|
||||
data = Column(JSON, nullable=True, comment="合成的数据")
|
||||
synthesis_file_instance_id = Column(
|
||||
String(36),
|
||||
nullable=False,
|
||||
comment="数据合成文件任务ID",
|
||||
index=True,
|
||||
)
|
||||
chunk_instance_id = Column(
|
||||
String(36),
|
||||
nullable=False,
|
||||
comment="分块任务ID",
|
||||
index=True,
|
||||
)
|
||||
57 runtime/datamate-python/app/db/models/model_config.py (Normal file)
@@ -0,0 +1,57 @@
|
||||
from sqlalchemy import Column, String, Integer, TIMESTAMP, select, UniqueConstraint
|
||||
|
||||
from app.db.session import Base
|
||||
|
||||
|
||||
async def get_model_by_id(db_session, model_id: str):
|
||||
"""根据 ID 获取单个模型配置。"""
|
||||
result = await db_session.execute(select(ModelConfig).where(ModelConfig.id == model_id))
|
||||
model_config = result.scalar_one_or_none()
|
||||
return model_config
|
||||
|
||||
class ModelConfig(Base):
|
||||
"""模型配置表,对应表 t_model_config
|
||||
|
||||
CREATE TABLE IF NOT EXISTS t_model_config (
|
||||
id VARCHAR(36) PRIMARY KEY COMMENT '主键ID',
|
||||
model_name VARCHAR(100) NOT NULL COMMENT '模型名称(如 qwen2)',
|
||||
provider VARCHAR(50) NOT NULL COMMENT '模型提供商(如 Ollama、OpenAI、DeepSeek)',
|
||||
base_url VARCHAR(255) NOT NULL COMMENT 'API 基础地址',
|
||||
api_key VARCHAR(512) DEFAULT '' COMMENT 'API 密钥(无密钥则为空)',
|
||||
type VARCHAR(50) NOT NULL COMMENT '模型类型(如 chat、embedding)',
|
||||
is_enabled TINYINT DEFAULT 1 COMMENT '是否启用:1-启用,0-禁用',
|
||||
is_default TINYINT DEFAULT 0 COMMENT '是否默认:1-默认,0-非默认',
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP COMMENT '创建时间',
|
||||
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT '更新时间',
|
||||
created_by VARCHAR(255) COMMENT '创建者',
|
||||
updated_by VARCHAR(255) COMMENT '更新者',
|
||||
UNIQUE KEY uk_model_provider (model_name, provider)
|
||||
) ENGINE = InnoDB DEFAULT CHARSET = utf8mb4 COMMENT ='模型配置表';
|
||||
"""
|
||||
|
||||
__tablename__ = "t_model_config"
|
||||
|
||||
id = Column(String(36), primary_key=True, index=True, comment="主键ID")
|
||||
model_name = Column(String(100), nullable=False, comment="模型名称(如 qwen2)")
|
||||
provider = Column(String(50), nullable=False, comment="模型提供商(如 Ollama、OpenAI、DeepSeek)")
|
||||
base_url = Column(String(255), nullable=False, comment="API 基础地址")
|
||||
api_key = Column(String(512), nullable=False, default="", comment="API 密钥(无密钥则为空)")
|
||||
type = Column(String(50), nullable=False, comment="模型类型(如 chat、embedding)")
|
||||
|
||||
# 使用 Integer 存储 TINYINT,后续可在业务层将 0/1 转为 bool
|
||||
is_enabled = Column(Integer, nullable=False, default=1, comment="是否启用:1-启用,0-禁用")
|
||||
is_default = Column(Integer, nullable=False, default=0, comment="是否默认:1-默认,0-非默认")
|
||||
|
||||
created_at = Column(TIMESTAMP, nullable=True, comment="创建时间")
|
||||
updated_at = Column(TIMESTAMP, nullable=True, comment="更新时间")
|
||||
created_by = Column(String(255), nullable=True, comment="创建者")
|
||||
updated_by = Column(String(255), nullable=True, comment="更新者")
|
||||
|
||||
__table_args__ = (
# 与 DDL 中的 uk_model_provider 保持一致
UniqueConstraint("model_name", "provider", name="uk_model_provider"),
{
"mysql_engine": "InnoDB",
"mysql_charset": "utf8mb4",
"comment": "模型配置表",
},
)
|
||||
@@ -15,8 +15,8 @@ engine = create_async_engine(
|
||||
|
||||
# 创建会话工厂
|
||||
AsyncSessionLocal = async_sessionmaker(
|
||||
engine,
|
||||
class_=AsyncSession,
|
||||
|
||||
expire_on_commit=False
|
||||
)
|
||||
|
||||
@@ -29,4 +29,3 @@ async def get_db() -> AsyncGenerator[AsyncSession, None]:
|
||||
yield session
|
||||
finally:
|
||||
await session.close()
|
||||
|
||||
@@ -24,7 +24,7 @@ logger = get_logger(__name__)
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
|
||||
|
||||
# @startup
|
||||
logger.info("DataMate Python Backend starting...")
|
||||
|
||||
@@ -43,7 +43,7 @@ async def lifespan(app: FastAPI):
|
||||
logger.info(f"Label Studio: {settings.label_studio_base_url}")
|
||||
|
||||
yield
|
||||
|
||||
|
||||
# @shutdown
|
||||
logger.info("DataMate Python Backend shutting down ...\n\n")
|
||||
|
||||
@@ -105,11 +105,11 @@ async def root():
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
|
||||
|
||||
uvicorn.run(
|
||||
"app.main:app",
|
||||
host=settings.host,
|
||||
port=settings.port,
|
||||
reload=settings.debug,
|
||||
log_level=settings.log_level.lower()
|
||||
)
|
||||
|
||||
@@ -3,6 +3,7 @@ from fastapi import APIRouter
|
||||
from .system.interface import router as system_router
|
||||
from .annotation.interface import router as annotation_router
|
||||
from .synthesis.interface import router as ratio_router
|
||||
from .generation.interface import router as generation_router
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/api"
|
||||
@@ -11,5 +12,6 @@ router = APIRouter(
|
||||
router.include_router(system_router)
|
||||
router.include_router(annotation_router)
|
||||
router.include_router(ratio_router)
|
||||
router.include_router(generation_router)
|
||||
|
||||
__all__ = ["router"]
|
||||
|
||||
@@ -0,0 +1,11 @@
|
||||
from fastapi import APIRouter
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/synth",
|
||||
tags=["synth"]
|
||||
)
|
||||
|
||||
# Include sub-routers
|
||||
from .generation_api import router as generation_api_router
|
||||
|
||||
router.include_router(generation_api_router)
|
||||
@@ -0,0 +1,260 @@
|
||||
import uuid
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Depends, BackgroundTasks
|
||||
from sqlalchemy import select, func, delete
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.logging import get_logger
|
||||
from app.db.models.data_synthesis import (
|
||||
save_synthesis_task,
|
||||
DataSynthesisInstance,
|
||||
DataSynthesisFileInstance,
|
||||
DataSynthesisChunkInstance,
|
||||
SynthesisData,
|
||||
)
|
||||
from app.db.models.dataset_management import DatasetFiles
|
||||
from app.db.models.model_config import get_model_by_id
|
||||
from app.db.session import get_db
|
||||
from app.module.generation.schema.generation import (
|
||||
CreateSynthesisTaskRequest,
|
||||
DataSynthesisTaskItem,
|
||||
PagedDataSynthesisTaskResponse, SynthesisType)
|
||||
from app.module.generation.service.generation_service import GenerationService
|
||||
from app.module.generation.service.prompt import get_prompt
|
||||
from app.module.shared.schema import StandardResponse
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/gen",
|
||||
tags=["gen"]
|
||||
)
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@router.post("/task", response_model=StandardResponse[DataSynthesisTaskItem])
|
||||
async def create_synthesis_task(
|
||||
request: CreateSynthesisTaskRequest,
|
||||
background_tasks: BackgroundTasks,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""创建数据合成任务"""
|
||||
result = await get_model_by_id(db, request.model_id)
|
||||
if not result:
|
||||
raise HTTPException(status_code=404, detail="Model not found")
|
||||
|
||||
# 先根据 source_file_id 在 DatasetFiles 中查出已有文件信息
|
||||
file_ids = request.source_file_id or []
|
||||
dataset_files = []
|
||||
if file_ids:
|
||||
ds_result = await db.execute(
|
||||
select(DatasetFiles).where(DatasetFiles.id.in_(file_ids))
|
||||
)
|
||||
dataset_files = ds_result.scalars().all()
|
||||
|
||||
# 保存任务到数据库
|
||||
request.source_file_id = [str(f.id) for f in dataset_files]
|
||||
synthesis_task = await save_synthesis_task(db, request)
|
||||
|
||||
# 将已有的 DatasetFiles 记录保存到 t_data_synthesis_file_instances
|
||||
for f in dataset_files:
|
||||
file_instance = DataSynthesisFileInstance(
|
||||
id=str(uuid.uuid4()), # 使用新的 UUID 作为文件任务记录的主键,避免与 DatasetFiles 主键冲突
|
||||
synthesis_instance_id=synthesis_task.id,
|
||||
file_name=f.file_name,
|
||||
source_file_id=str(f.id),
|
||||
target_file_location=synthesis_task.result_data_location or "",
|
||||
status="pending",
|
||||
total_chunks=0,
|
||||
processed_chunks=0,
|
||||
created_by="system",
|
||||
updated_by="system",
|
||||
)
|
||||
db.add(file_instance)
|
||||
|
||||
if dataset_files:
|
||||
await db.commit()
|
||||
|
||||
generation_service = GenerationService(db)
|
||||
# 异步处理任务:只传任务 ID,后台任务中使用新的 DB 会话重新加载任务对象
|
||||
background_tasks.add_task(generation_service.process_task, synthesis_task.id)
|
||||
|
||||
return StandardResponse(
|
||||
code=200,
|
||||
message="success",
|
||||
data=synthesis_task,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/task/{task_id}", response_model=StandardResponse[DataSynthesisTaskItem])
|
||||
async def get_synthesis_task(
|
||||
task_id: str,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""获取数据合成任务详情"""
|
||||
result = await db.get(DataSynthesisInstance, task_id)
|
||||
if not result:
|
||||
raise HTTPException(status_code=404, detail="Synthesis task not found")
|
||||
|
||||
return StandardResponse(
|
||||
code=200,
|
||||
message="success",
|
||||
data=result,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/tasks", response_model=StandardResponse[PagedDataSynthesisTaskResponse], status_code=200)
|
||||
async def list_synthesis_tasks(
|
||||
page: int = 1,
|
||||
page_size: int = 10,
|
||||
synthesis_type: str | None = None,
|
||||
status: str | None = None,
|
||||
name: str | None = None,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""分页列出所有数据合成任务"""
|
||||
query = select(DataSynthesisInstance)
|
||||
if synthesis_type:
|
||||
query = query.filter(DataSynthesisInstance.synthesis_type == synthesis_type)
|
||||
if status:
|
||||
query = query.filter(DataSynthesisInstance.status == status)
|
||||
if name:
|
||||
query = query.filter(DataSynthesisInstance.name.like(f"%{name}%"))
|
||||
|
||||
count_q = select(func.count()).select_from(query.subquery())
|
||||
total = (await db.execute(count_q)).scalar_one()
|
||||
|
||||
if page < 1:
|
||||
page = 1
|
||||
if page_size < 1:
|
||||
page_size = 10
|
||||
|
||||
result = await db.execute(query.offset((page - 1) * page_size).limit(page_size))
|
||||
rows = result.scalars().all()
|
||||
|
||||
task_items = [
|
||||
DataSynthesisTaskItem(
|
||||
id=row.id,
|
||||
name=row.name,
|
||||
description=row.description,
|
||||
status=row.status,
|
||||
synthesis_type=row.synthesis_type,
|
||||
model_id=row.model_id,
|
||||
progress=row.progress,
|
||||
result_data_location=row.result_data_location,
|
||||
text_split_config=row.text_split_config,
|
||||
synthesis_config=row.synthesis_config,
|
||||
source_file_id=row.source_file_id,
|
||||
total_files=row.total_files,
|
||||
processed_files=row.processed_files,
|
||||
total_chunks=row.total_chunks,
|
||||
processed_chunks=row.processed_chunks,
|
||||
total_synthesis_data=row.total_synthesis_data,
|
||||
created_at=row.created_at,
|
||||
updated_at=row.updated_at,
|
||||
created_by=row.created_by,
|
||||
updated_by=row.updated_by,
|
||||
)
|
||||
for row in rows
|
||||
]
|
||||
|
||||
paged = PagedDataSynthesisTaskResponse(
|
||||
content=task_items,
|
||||
totalElements=total,
|
||||
totalPages=(total + page_size - 1) // page_size,
|
||||
page=page,
|
||||
size=page_size,
|
||||
)
|
||||
|
||||
return StandardResponse(
|
||||
code=200,
|
||||
message="Success",
|
||||
data=paged,
|
||||
)
|
||||
|
||||
|
||||
@router.delete("/task/{task_id}", response_model=StandardResponse[None])
|
||||
async def delete_synthesis_task(
|
||||
task_id: str,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""删除数据合成任务"""
|
||||
task = await db.get(DataSynthesisInstance, task_id)
|
||||
if not task:
|
||||
raise HTTPException(status_code=404, detail="Synthesis task not found")
|
||||
|
||||
# 1. 删除与该任务相关的 SynthesisData、Chunk、File 记录
|
||||
# 先查出所有文件任务 ID
|
||||
file_result = await db.execute(
|
||||
select(DataSynthesisFileInstance.id).where(
|
||||
DataSynthesisFileInstance.synthesis_instance_id == task_id
|
||||
)
|
||||
)
|
||||
file_ids = [row[0] for row in file_result.all()]
|
||||
|
||||
if file_ids:
|
||||
# 删除 SynthesisData(根据文件任务ID)
|
||||
await db.execute(delete(SynthesisData).where(
|
||||
SynthesisData.synthesis_file_instance_id.in_(file_ids)
|
||||
)
|
||||
)
|
||||
|
||||
# 删除 Chunk 记录
|
||||
await db.execute(delete(DataSynthesisChunkInstance).where(
|
||||
DataSynthesisChunkInstance.synthesis_file_instance_id.in_(file_ids)
|
||||
)
|
||||
)
|
||||
|
||||
# 删除文件任务记录
|
||||
await db.execute(delete(DataSynthesisFileInstance).where(
|
||||
DataSynthesisFileInstance.id.in_(file_ids)
|
||||
)
|
||||
)
|
||||
|
||||
# 2. 删除任务本身
|
||||
await db.delete(task)
|
||||
await db.commit()
|
||||
|
||||
return StandardResponse(
|
||||
code=200,
|
||||
message="success",
|
||||
data=None,
|
||||
)
|
||||
|
||||
@router.delete("/task/{task_id}/{file_id}", response_model=StandardResponse[None])
|
||||
async def delete_synthesis_file_task(
|
||||
task_id: str,
|
||||
file_id: str,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""删除数据合成任务中的文件任务"""
|
||||
file_task = await db.get(DataSynthesisFileInstance, file_id)
|
||||
if not file_task:
|
||||
raise HTTPException(status_code=404, detail="Synthesis file task not found")
|
||||
|
||||
# 删除 SynthesisData(根据文件任务ID)
|
||||
await db.execute(delete(SynthesisData).where(
|
||||
SynthesisData.synthesis_file_instance_id == file_id
|
||||
)
|
||||
)
|
||||
|
||||
# 删除 Chunk 记录
|
||||
await db.execute(delete(DataSynthesisChunkInstance).where(
|
||||
DataSynthesisChunkInstance.synthesis_file_instance_id == file_id
|
||||
)
|
||||
)
|
||||
|
||||
# 删除文件任务记录
|
||||
await db.execute(delete(DataSynthesisFileInstance).where(
|
||||
DataSynthesisFileInstance.id == file_id
|
||||
)
|
||||
)

await db.commit()

return StandardResponse(
code=200,
message="success",
data=None,
)

@router.get("/prompt", response_model=StandardResponse[str])
|
||||
async def get_prompt_by_type(
|
||||
synth_type: SynthesisType,
|
||||
):
|
||||
prompt = get_prompt(synth_type)
|
||||
return StandardResponse(
|
||||
code=200,
|
||||
message="Success",
|
||||
data=prompt,
|
||||
)
|
||||
@@ -0,0 +1,76 @@
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import List, Optional, Dict, Any
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
|
||||
class TextSplitConfig(BaseModel):
|
||||
"""文本切片配置"""
|
||||
chunk_size: int = Field(..., description="最大令牌数")
|
||||
chunk_overlap: int = Field(..., description="重叠令牌数")
|
||||
|
||||
|
||||
class SynthesisConfig(BaseModel):
|
||||
"""合成配置"""
|
||||
prompt_template: str = Field(..., description="合成提示模板")
|
||||
synthesis_count: Optional[int] = Field(None, description="单个chunk合成的数据数量")
|
||||
temperature: Optional[float] = Field(None, description="温度参数")
|
||||
|
||||
|
||||
class SynthesisType(Enum):
|
||||
"""合成类型"""
|
||||
QA = "QA"
|
||||
COT = "COT"
|
||||
|
||||
|
||||
class CreateSynthesisTaskRequest(BaseModel):
|
||||
"""创建数据合成任务请求"""
|
||||
name: str = Field(..., description="合成任务名称")
|
||||
description: Optional[str] = Field(None, description="合成任务描述")
|
||||
model_id: str = Field(..., description="模型ID")
|
||||
source_file_id: list[str] = Field(..., description="原始文件ID列表")
|
||||
text_split_config: TextSplitConfig = Field(..., description="文本切片配置")
|
||||
synthesis_config: SynthesisConfig = Field(..., description="合成配置")
|
||||
synthesis_type: SynthesisType = Field(..., description="合成类型")
|
||||
|
||||
|
||||
class DataSynthesisTaskItem(BaseModel):
|
||||
"""数据合成任务列表/详情项"""
|
||||
id: str
|
||||
name: str
|
||||
description: Optional[str] = None
|
||||
status: Optional[str] = None
|
||||
synthesis_type: str
|
||||
model_id: str
|
||||
progress: int
|
||||
result_data_location: Optional[str] = None
|
||||
text_split_config: Dict[str, Any]
|
||||
synthesis_config: Dict[str, Any]
|
||||
source_file_id: list[str]
|
||||
total_files: int
|
||||
processed_files: int
|
||||
total_chunks: int
|
||||
processed_chunks: int
|
||||
total_synthesis_data: int
|
||||
created_at: Optional[datetime] = None
|
||||
updated_at: Optional[datetime] = None
|
||||
created_by: Optional[str] = None
|
||||
updated_by: Optional[str] = None
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class PagedDataSynthesisTaskResponse(BaseModel):
|
||||
"""分页数据合成任务响应"""
|
||||
content: List[DataSynthesisTaskItem]
|
||||
totalElements: int
|
||||
totalPages: int
|
||||
page: int
|
||||
size: int
|
||||
|
||||
class ChatRequest(BaseModel):
|
||||
"""聊天请求参数"""
|
||||
model_id: str
|
||||
prompt: str
|
||||
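一个按上述 CreateSynthesisTaskRequest 结构构造请求体的示意(字段取值均为占位示例,接口路径由 /api 与 /gen/task 前缀拼接而来):

```python
# 示例请求体:POST /api/gen/task(取值均为占位示例)
payload = {
    "name": "demo-qa-task",
    "description": "基于产品文档合成 QA 数据",
    "model_id": "<model-config-uuid>",
    "source_file_id": ["<dataset-file-uuid-1>", "<dataset-file-uuid-2>"],
    "text_split_config": {"chunk_size": 800, "chunk_overlap": 50},
    "synthesis_config": {
        "prompt_template": "<可通过 GET /api/gen/prompt?synth_type=QA 获取>",
        "synthesis_count": 5,
        "temperature": 0.7,
    },
    "synthesis_type": "QA",
}
```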
@@ -0,0 +1,544 @@
|
||||
import asyncio
|
||||
import uuid
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
from langchain_community.document_loaders import (
|
||||
TextLoader,
|
||||
CSVLoader,
|
||||
JSONLoader,
|
||||
UnstructuredMarkdownLoader,
|
||||
UnstructuredHTMLLoader,
|
||||
UnstructuredFileLoader,
|
||||
PyPDFLoader,
|
||||
UnstructuredWordDocumentLoader,
|
||||
UnstructuredPowerPointLoader,
|
||||
UnstructuredExcelLoader,
|
||||
)
|
||||
from langchain_text_splitters import RecursiveCharacterTextSplitter
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db.models.data_synthesis import (
|
||||
DataSynthesisInstance,
|
||||
DataSynthesisFileInstance,
|
||||
DataSynthesisChunkInstance,
|
||||
SynthesisData,
|
||||
)
|
||||
from app.db.models.dataset_management import DatasetFiles
|
||||
from app.db.models.model_config import get_model_by_id
|
||||
from app.db.session import logger
|
||||
from app.module.system.service.common_service import get_chat_client, chat
|
||||
|
||||
|
||||
class GenerationService:
|
||||
def __init__(self, db: AsyncSession):
|
||||
self.db = db
|
||||
|
||||
async def process_task(self, task_id: str):
|
||||
"""处理数据合成任务入口:根据任务ID加载任务并逐个处理源文件。"""
|
||||
synthesis_task: DataSynthesisInstance | None = await self.db.get(DataSynthesisInstance, task_id)
|
||||
if not synthesis_task:
|
||||
logger.error(f"Synthesis task {task_id} not found, abort processing")
|
||||
return
|
||||
|
||||
logger.info(f"Processing synthesis task {task_id}")
|
||||
file_ids = synthesis_task.source_file_id or []
|
||||
|
||||
# 获取模型客户端
|
||||
model_result = await get_model_by_id(self.db, str(synthesis_task.model_id))
|
||||
if not model_result:
|
||||
logger.error(
|
||||
f"Model config not found for id={synthesis_task.model_id}, abort task {synthesis_task.id}"
|
||||
)
|
||||
return
|
||||
chat_client = get_chat_client(model_result)
|
||||
|
||||
# 控制并发度的信号量(限制全任务范围内最多 10 个并发调用)
|
||||
semaphore = asyncio.Semaphore(10)
|
||||
|
||||
# 逐个文件处理
|
||||
for file_id in file_ids:
|
||||
try:
|
||||
success = await self._process_single_file(
|
||||
synthesis_task=synthesis_task,
|
||||
file_id=file_id,
|
||||
chat_client=chat_client,
|
||||
semaphore=semaphore,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(f"Unexpected error when processing file {file_id} for task {task_id}: {e}")
|
||||
# 确保对应文件任务状态标记为失败
|
||||
await self._mark_file_failed(str(synthesis_task.id), file_id, str(e))
|
||||
success = False
|
||||
|
||||
if success:
|
||||
# 每处理完一个文件,简单增加 processed_files 计数
|
||||
synthesis_task.processed_files = (synthesis_task.processed_files or 0) + 1
|
||||
await self.db.commit()
|
||||
await self.db.refresh(synthesis_task)
|
||||
|
||||
logger.info(f"Finished processing synthesis task {synthesis_task.id}")
|
||||
|
||||
async def _process_single_file(
|
||||
self,
|
||||
synthesis_task: DataSynthesisInstance,
|
||||
file_id: str,
|
||||
chat_client,
|
||||
semaphore: asyncio.Semaphore,
|
||||
) -> bool:
|
||||
"""处理单个源文件:解析路径、切片、保存分块并触发 LLM 调用。"""
|
||||
file_path = await self._resolve_file_path(file_id)
|
||||
if not file_path:
|
||||
logger.warning(f"File path not found for file_id={file_id}, skip")
|
||||
await self._mark_file_failed(str(synthesis_task.id), file_id, "file_path_not_found")
|
||||
return False
|
||||
|
||||
logger.info(f"Processing file_id={file_id}, path={file_path}")
|
||||
|
||||
split_cfg = synthesis_task.text_split_config or {}
|
||||
synthesis_cfg = synthesis_task.synthesis_config or {}
|
||||
chunk_size = int(split_cfg.get("chunk_size", 800))
|
||||
chunk_overlap = int(split_cfg.get("chunk_overlap", 50))
|
||||
# 加载并切片
|
||||
try:
|
||||
chunks = self._load_and_split(file_path, chunk_size, chunk_overlap)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load/split file {file_path}: {e}")
|
||||
await self._mark_file_failed(str(synthesis_task.id), file_id, f"load_split_error: {e}")
|
||||
return False
|
||||
|
||||
if not chunks:
|
||||
logger.warning(f"No chunks generated for file_id={file_id}")
|
||||
await self._mark_file_failed(str(synthesis_task.id), file_id, "no_chunks_generated")
|
||||
return False
|
||||
|
||||
logger.info(f"File {file_id} split into {len(chunks)} chunks by LangChain")
|
||||
|
||||
# 保存文件任务记录 + 分块记录
|
||||
file_task = await self._get_or_create_file_instance(
|
||||
synthesis_task_id=str(synthesis_task.id),
|
||||
source_file_id=file_id,
|
||||
file_path=file_path,
|
||||
)
|
||||
await self._persist_chunks(synthesis_task, file_task, file_id, chunks)
|
||||
|
||||
# 针对每个切片并发调用大模型
|
||||
await self._invoke_llm_for_chunks(
|
||||
synthesis_task=synthesis_task,
|
||||
file_id=file_id,
|
||||
chunks=chunks,
|
||||
synthesis_cfg=synthesis_cfg,
|
||||
chat_client=chat_client,
|
||||
semaphore=semaphore,
|
||||
)
|
||||
|
||||
# 如果执行到此处,说明该文件的切片与 LLM 调用流程均未抛出异常,标记为完成
|
||||
file_task.status = "completed"
|
||||
await self.db.commit()
|
||||
await self.db.refresh(file_task)
|
||||
|
||||
return True
|
||||
|
||||
async def _persist_chunks(
|
||||
self,
|
||||
synthesis_task: DataSynthesisInstance,
|
||||
file_task: DataSynthesisFileInstance,
|
||||
file_id: str,
|
||||
chunks,
|
||||
) -> None:
|
||||
"""将切片结果保存到 t_data_synthesis_chunk_instances,并更新文件级分块计数。"""
|
||||
for idx, doc in enumerate(chunks, start=1):
|
||||
# 先复制原始 Document.metadata,再在其上追加任务相关字段,避免覆盖原有元数据
|
||||
base_metadata = dict(getattr(doc, "metadata", {}) or {})
|
||||
base_metadata.update(
|
||||
{
|
||||
"task_id": str(synthesis_task.id),
|
||||
"file_id": file_id
|
||||
}
|
||||
)
|
||||
|
||||
chunk_record = DataSynthesisChunkInstance(
|
||||
id=str(uuid.uuid4()),
|
||||
synthesis_file_instance_id=file_task.id,
|
||||
chunk_index=idx,
|
||||
chunk_content=doc.page_content,
|
||||
chunk_metadata=base_metadata,
|
||||
)
|
||||
self.db.add(chunk_record)
|
||||
|
||||
# 更新文件任务的分块数量
|
||||
file_task.total_chunks = len(chunks)
file_task.status = "processing"

await self.db.commit()
await self.db.refresh(file_task)
|
||||
|
||||
async def _invoke_llm_for_chunks(
|
||||
self,
|
||||
synthesis_task: DataSynthesisInstance,
|
||||
file_id: str,
|
||||
chunks,
|
||||
synthesis_cfg: dict,
|
||||
chat_client,
|
||||
semaphore: asyncio.Semaphore,
|
||||
) -> None:
|
||||
"""针对每个分片并发调用大模型生成数据。"""
|
||||
# 需要将 answer 和对应 chunk 建立关系,因此这里保留 chunk_index
|
||||
tasks = [
|
||||
self._call_llm(doc, file_id, idx, synthesis_task, synthesis_cfg, chat_client, semaphore)
|
||||
for idx, doc in enumerate(chunks, start=1)
|
||||
]
|
||||
await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
async def _call_llm(
|
||||
self,
|
||||
doc,
|
||||
file_id: str,
|
||||
idx: int,
|
||||
synthesis_task,
|
||||
synthesis_cfg: dict,
|
||||
chat_client,
|
||||
semaphore: asyncio.Semaphore,
|
||||
):
|
||||
"""单次大模型调用逻辑,带并发控制。
|
||||
|
||||
说明:
|
||||
- 使用信号量限制全局并发量(当前为 10)。
|
||||
- 使用线程池执行同步的 chat 调用,避免阻塞事件循环。
|
||||
- 在拿到 LLM 返回后,解析为 JSON 并批量写入 SynthesisData,
|
||||
同时更新文件级 processed_chunks / 进度等信息。
|
||||
"""
|
||||
async with semaphore:
|
||||
prompt = self._build_qa_prompt(doc.page_content, synthesis_cfg)
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
answer = await loop.run_in_executor(None, chat, chat_client, prompt)
|
||||
logger.debug(
|
||||
f"Generated QA for task={synthesis_task.id}, file={file_id}, chunk={idx}"
|
||||
)
|
||||
await self._handle_llm_answer(
|
||||
synthesis_task_id=str(synthesis_task.id),
|
||||
file_id=file_id,
|
||||
chunk_index=idx,
|
||||
raw_answer=answer,
|
||||
)
|
||||
return answer
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"LLM generation failed for task={synthesis_task.id}, file={file_id}, chunk={idx}: {e}"
|
||||
)
|
||||
return None
|
||||
|
||||
async def _resolve_file_path(self, file_id: str) -> str | None:
|
||||
"""根据文件ID查询 t_dm_dataset_files 并返回 file_path(仅 ACTIVE 文件)。"""
|
||||
result = await self.db.execute(
|
||||
select(DatasetFiles).where(DatasetFiles.id == file_id)
|
||||
)
|
||||
file_obj = result.scalar_one_or_none()
|
||||
if not file_obj:
|
||||
return None
|
||||
return file_obj.file_path
|
||||
|
||||
def _load_and_split(self, file_path: str, chunk_size: int, chunk_overlap: int):
|
||||
"""使用 LangChain 加载文本并进行切片,直接返回 Document 列表。
|
||||
|
||||
当前实现:
|
||||
- 使用 TextLoader 加载纯文本/Markdown/JSON 等文本文件
|
||||
- 使用 RecursiveCharacterTextSplitter 做基于字符的递归切片
|
||||
|
||||
保留每个 Document 的 metadata,方便后续追加例如文件ID、chunk序号等信息。
|
||||
"""
|
||||
loader = self._build_loader(file_path)
|
||||
docs = loader.load()
|
||||
|
||||
splitter = RecursiveCharacterTextSplitter(
|
||||
chunk_size=chunk_size,
|
||||
chunk_overlap=chunk_overlap,
|
||||
# 尝试按这些分隔符优先切分,再退化到字符级
|
||||
separators=["\n\n", "\n", "。", "!", "?", "!", "?", "。\n", "\t", " "]
|
||||
)
|
||||
split_docs = splitter.split_documents(docs)
|
||||
return split_docs
|
||||
|
||||
@staticmethod
|
||||
def _build_loader(file_path: str):
|
||||
"""根据文件扩展名选择合适的 LangChain 文本加载器,尽量覆盖常见泛文本格式。
|
||||
|
||||
优先按格式选择专门的 Loader,找不到匹配时退回到 TextLoader。
|
||||
"""
|
||||
path = Path(file_path)
|
||||
suffix = path.suffix.lower()
|
||||
path_str = str(path)
|
||||
|
||||
# 1. 纯文本类
|
||||
if suffix in {".txt", "", ".log"}: # "" 兼容无扩展名
|
||||
return TextLoader(path_str, encoding="utf-8")
|
||||
|
||||
# 2. Markdown
|
||||
if suffix in {".md", ".markdown"}:
|
||||
# UnstructuredMarkdownLoader 会保留更多结构信息
|
||||
return UnstructuredMarkdownLoader(path_str)
|
||||
|
||||
# 3. HTML / HTM
|
||||
if suffix in {".html", ".htm"}:
|
||||
return UnstructuredHTMLLoader(path_str)
|
||||
|
||||
# 4. JSON
|
||||
if suffix == ".json":
|
||||
# 使用 JSONLoader 将 JSON 中的内容展开成文档
|
||||
# 这里使用默认 jq_schema,后续需要更精细地提取可以在此调整
|
||||
return JSONLoader(file_path=path_str, jq_schema=".", text_content=False)
|
||||
|
||||
# 5. CSV / TSV
|
||||
if suffix in {".csv", ".tsv"}:
|
||||
# CSVLoader 默认将每一行作为一条 Document
|
||||
return CSVLoader(file_path=path_str)
|
||||
|
||||
# 6. YAML
|
||||
if suffix in {".yaml", ".yml"}:
|
||||
# 暂时按纯文本加载
|
||||
return TextLoader(path_str, encoding="utf-8")
|
||||
|
||||
# 7. PDF
|
||||
if suffix == ".pdf":
|
||||
return PyPDFLoader(path_str)
|
||||
|
||||
# 8. Word 文档
|
||||
if suffix in {".docx", ".doc"}:
|
||||
# UnstructuredWordDocumentLoader 支持 .docx/.doc 文本抽取
|
||||
return UnstructuredWordDocumentLoader(path_str)
|
||||
|
||||
# 9. PowerPoint
|
||||
if suffix in {".ppt", ".pptx"}:
|
||||
return UnstructuredPowerPointLoader(path_str)
|
||||
|
||||
# 10. Excel
|
||||
if suffix in {".xls", ".xlsx"}:
|
||||
return UnstructuredExcelLoader(path_str)
|
||||
|
||||
# 11. 兜底:使用 UnstructuredFileLoader 或 TextLoader 作为纯文本
|
||||
try:
|
||||
return UnstructuredFileLoader(path_str)
|
||||
except Exception:
|
||||
return TextLoader(path_str, encoding="utf-8")
|
||||
|
||||
@staticmethod
|
||||
def _build_qa_prompt(chunk: str, synthesis_cfg: dict) -> str:
|
||||
"""构造 QA 数据合成的提示词。
|
||||
|
||||
要求:
|
||||
- synthesis_cfg["prompt_template"] 是一个字符串,其中包含 {document} 占位符;
|
||||
- 将当前切片内容替换到 {document}。
|
||||
如果未提供模板或 format 失败,则退回到简单拼接的兜底提示词。
|
||||
"""
|
||||
template = None
synthesis_count = 5
if isinstance(synthesis_cfg, dict):
    template = synthesis_cfg.get("prompt_template")
    synthesis_count = synthesis_cfg.get("synthesis_count") or 5
try:
    prompt = template.format(document=chunk, synthesis_count=synthesis_count)
except Exception:
    # 防御性处理:模板缺失或 format 出现异常时,退回到简单拼接
    prompt = f"{template or ''}\n\n文档内容:{chunk}\n\n请根据文档内容生成 {synthesis_count} 条符合要求的问答数据。"
|
||||
return prompt
|
||||
|
||||
async def _handle_llm_answer(
|
||||
self,
|
||||
synthesis_task_id: str,
|
||||
file_id: str,
|
||||
chunk_index: int,
|
||||
raw_answer: str,
|
||||
) -> None:
|
||||
"""解析 LLM 返回内容为 JSON,批量保存到 SynthesisData,并更新文件任务进度。
|
||||
|
||||
约定:
|
||||
- LLM 返回的 raw_answer 是 JSON 字符串,可以是:
|
||||
1)单个对象:{"question": ..., "answer": ...}
|
||||
2)对象数组:[{}, {}, ...]
|
||||
- 我们将其规范化为列表,每个元素作为一条 SynthesisData.data 写入。
|
||||
- 根据 synthesis_task_id + file_id + chunk_index 找到对应的 DataSynthesisChunkInstance,
|
||||
以便设置 chunk_instance_id 和 synthesis_file_instance_id。
|
||||
- 每处理完一个 chunk,递增对应 DataSynthesisFileInstance.processed_chunks,并按比例更新进度。
|
||||
"""
|
||||
if not raw_answer:
|
||||
return
|
||||
|
||||
# 1. 预处理原始回答:尝试从中截取出最可能的 JSON 片段
|
||||
cleaned = self._extract_json_substring(raw_answer)
|
||||
|
||||
# 2. 解析 JSON,统一成列表结构
|
||||
try:
|
||||
parsed = json.loads(cleaned)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to parse LLM answer as JSON for task={synthesis_task_id}, file={file_id}, chunk={chunk_index}: {e}. Raw answer: {raw_answer!r}"
|
||||
)
|
||||
return
|
||||
|
||||
if isinstance(parsed, dict):
|
||||
items = [parsed]
|
||||
elif isinstance(parsed, list):
|
||||
items = [p for p in parsed if isinstance(p, dict)]
|
||||
else:
|
||||
logger.error(f"Unexpected JSON structure from LLM answer for task={synthesis_task_id}, file={file_id}, chunk={chunk_index}: {type(parsed)}")
|
||||
return
|
||||
|
||||
if not items:
|
||||
return
|
||||
|
||||
# 3. 找到对应的 chunk 记录(一个 chunk_index 对应一条记录)
|
||||
chunk_result = await self.db.execute(
|
||||
select(DataSynthesisChunkInstance, DataSynthesisFileInstance)
|
||||
.join(
|
||||
DataSynthesisFileInstance,
|
||||
DataSynthesisFileInstance.id == DataSynthesisChunkInstance.synthesis_file_instance_id,
|
||||
)
|
||||
.where(
|
||||
DataSynthesisFileInstance.synthesis_instance_id == synthesis_task_id,
|
||||
DataSynthesisFileInstance.source_file_id == file_id,
|
||||
DataSynthesisChunkInstance.chunk_index == chunk_index,
|
||||
)
|
||||
)
|
||||
row = chunk_result.first()
|
||||
if not row:
|
||||
logger.error(
|
||||
f"Chunk record not found for task={synthesis_task_id}, file={file_id}, chunk_index={chunk_index}, skip saving SynthesisData."
|
||||
)
|
||||
return
|
||||
|
||||
chunk_instance, file_instance = row
|
||||
|
||||
# 4. 批量写入 SynthesisData
|
||||
for data_obj in items:
|
||||
record = SynthesisData(
|
||||
id=str(uuid.uuid4()),
|
||||
data=data_obj,
|
||||
synthesis_file_instance_id=file_instance.id,
|
||||
chunk_instance_id=chunk_instance.id,
|
||||
)
|
||||
self.db.add(record)
|
||||
|
||||
# 5. 更新文件级 processed_chunks / 进度
|
||||
file_instance.processed_chunks = (file_instance.processed_chunks or 0) + 1
|
||||
|
||||
|
||||
await self.db.commit()
|
||||
await self.db.refresh(file_instance)
|
||||
|
||||
@staticmethod
|
||||
def _extract_json_substring(raw: str) -> str:
|
||||
"""从 LLM 的原始回答中提取最可能的 JSON 字符串片段。
|
||||
|
||||
处理思路:
|
||||
- 原始回答可能是:说明文字 + JSON + 说明文字,甚至带有 Markdown 代码块。
|
||||
- 优先在文本中查找第一个 '{' 或 '[' 作为 JSON 起始;
|
||||
- 再从后向前找最后一个 '}' 或 ']' 作为结束;
|
||||
- 如果找不到合适的边界,就退回原始字符串。
|
||||
该方法不会保证截取的一定是合法 JSON,但能显著提高 json.loads 的成功率。
|
||||
"""
|
||||
if not raw:
|
||||
return raw
|
||||
|
||||
start = None
|
||||
end = None
|
||||
|
||||
# 查找第一个 JSON 起始符号
|
||||
for i, ch in enumerate(raw):
|
||||
if ch in "[{":
|
||||
start = i
|
||||
break
|
||||
|
||||
# 查找最后一个 JSON 结束符号
|
||||
for i in range(len(raw) - 1, -1, -1):
|
||||
if raw[i] in "]}":
|
||||
end = i + 1 # 切片是左闭右开
|
||||
break
|
||||
|
||||
if start is not None and end is not None and start < end:
|
||||
return raw[start:end].strip()
|
||||
|
||||
# 兜底:去掉常见 Markdown 包裹(```json ... ```)
|
||||
stripped = raw.strip()
|
||||
if stripped.startswith("```"):
|
||||
# 去掉首尾 ``` 标记
|
||||
stripped = stripped.strip("`")
|
||||
return stripped
|
||||
|
||||
async def _get_or_create_file_instance(
|
||||
self,
|
||||
synthesis_task_id: str,
|
||||
source_file_id: str,
|
||||
file_path: str,
|
||||
) -> DataSynthesisFileInstance:
|
||||
"""根据任务ID和原始文件ID,查找或创建对应的 DataSynthesisFileInstance 记录。
|
||||
|
||||
- 如果已存在(同一任务 + 同一 source_file_id),直接返回;
|
||||
- 如果不存在,则创建一条新的文件任务记录,file_name 来自文件路径,
|
||||
target_file_location 先复用任务的 result_data_location。
|
||||
"""
|
||||
# 尝试查询已有文件任务记录
|
||||
result = await self.db.execute(
|
||||
select(DataSynthesisFileInstance).where(
|
||||
DataSynthesisFileInstance.synthesis_instance_id == synthesis_task_id,
|
||||
DataSynthesisFileInstance.source_file_id == source_file_id,
|
||||
)
|
||||
)
|
||||
file_task = result.scalar_one_or_none()
|
||||
if file_task is not None:
|
||||
return file_task
|
||||
|
||||
# 查询任务以获取 result_data_location
|
||||
task = await self.db.get(DataSynthesisInstance, synthesis_task_id)
|
||||
target_location = task.result_data_location if task else ""
|
||||
|
||||
# 创建新的文件任务记录,初始状态为 processing
|
||||
file_task = DataSynthesisFileInstance(
|
||||
id=str(uuid.uuid4()),
|
||||
synthesis_instance_id=synthesis_task_id,
|
||||
file_name=Path(file_path).name,
|
||||
source_file_id=source_file_id,
|
||||
target_file_location=target_location or "",
|
||||
status="processing",
|
||||
total_chunks=0,
|
||||
processed_chunks=0,
|
||||
created_by="system",
|
||||
updated_by="system",
|
||||
)
|
||||
self.db.add(file_task)
|
||||
await self.db.commit()
|
||||
await self.db.refresh(file_task)
|
||||
return file_task
|
||||
|
||||
async def _mark_file_failed(self, synthesis_task_id: str, file_id: str, reason: str | None = None) -> None:
|
||||
"""将指定任务下的单个文件任务标记为失败状态,兜底错误处理。
|
||||
|
||||
- 如果找到对应的 DataSynthesisFileInstance,则更新其 status="failed"。
|
||||
- 如果未找到,则静默返回,仅记录日志。
|
||||
- reason 参数仅用于日志记录,方便排查。
|
||||
"""
|
||||
try:
|
||||
result = await self.db.execute(
|
||||
select(DataSynthesisFileInstance).where(
|
||||
DataSynthesisFileInstance.synthesis_instance_id == synthesis_task_id,
|
||||
DataSynthesisFileInstance.source_file_id == file_id,
|
||||
)
|
||||
)
|
||||
file_task = result.scalar_one_or_none()
|
||||
if not file_task:
|
||||
logger.warning(
|
||||
f"Failed to mark file as failed: no DataSynthesisFileInstance found for task={synthesis_task_id}, file_id={file_id}, reason={reason}"
|
||||
)
|
||||
return
|
||||
|
||||
file_task.status = "failed"
|
||||
await self.db.commit()
|
||||
await self.db.refresh(file_task)
|
||||
logger.info(
|
||||
f"Marked file task as failed for task={synthesis_task_id}, file_id={file_id}, reason={reason}"
|
||||
)
|
||||
except Exception as e:
|
||||
# 兜底日志,避免异常向外传播影响其它文件处理
|
||||
logger.exception(
|
||||
f"Unexpected error when marking file failed for task={synthesis_task_id}, file_id={file_id}, original_reason={reason}, error={e}"
|
||||
)
|
||||
@@ -0,0 +1,73 @@
|
||||
from app.module.generation.schema.generation import SynthesisType
|
||||
|
||||
QA_PROMPT="""
|
||||
# 角色
|
||||
你是一位专业的AI助手,擅长从给定的文本中提取关键信息并创建用于教学和测试的问答对。
|
||||
|
||||
# 任务
|
||||
请根据用户提供的原始文档,生成一系列高质量、多样化的问答对。
|
||||
|
||||
# 输入文档
|
||||
{document}
|
||||
|
||||
# 要求与指令
|
||||
1. **问题类型**:生成 {synthesis_count} 个左右的问答对。问题类型应多样化,包括但不限于:
|
||||
* **事实性**:基于文本中明确提到的事实。
|
||||
* **理解性**:需要理解上下文和概念。
|
||||
* **归纳性**:需要总结或归纳多个信息点。
|
||||
2. **答案来源**:所有答案必须严格基于提供的文档内容,不得编造原文不存在的信息。
|
||||
3. **语言**:请根据输入文档的主要语言进行提问和回答。
|
||||
4. **问题质量**:问题应清晰、无歧义,并且是读完文档后自然会产生的问题。
|
||||
5. **答案质量**:答案应准确、简洁、完整。
|
||||
|
||||
# 输出格式
|
||||
请严格按照以下JSON格式输出,确保没有额外的解释或标记:
|
||||
[
|
||||
{{"instruction": "问题1","input": "参考内容1","output": "答案1"}},
|
||||
{{"instruction": "问题2","input": "参考内容1","output": "答案2"}},
|
||||
...
|
||||
]
|
||||
"""
|
||||
|
||||
|
||||
COT_PROMPT="""
|
||||
# 角色
|
||||
你是一位专业的数据合成专家,擅长基于给定的原始文档和 COT(Chain of Thought,思维链)逻辑,生成高质量、符合实际应用场景的 COT 数据。COT 数据需包含清晰的问题、逐步推理过程和最终结论,能完整还原解决问题的思考路径。
|
||||
|
||||
# 任务
|
||||
请根据用户提供的原始文档,生成一系列高质量、多样化的 COT 数据。每个 COT 数据需围绕文档中的关键信息、核心问题或逻辑关联点展开,确保推理过程贴合文档内容,结论准确可靠。
|
||||
|
||||
# 输入文档
|
||||
{document}
|
||||
|
||||
# 要求与指令
|
||||
1. **数量要求**:生成 {synthesis_count} 条左右的 COT 数据。
|
||||
2. **内容要求**:
|
||||
* 每条 COT 数据需包含 “问题”“思维链推理”“最终结论” 三部分,逻辑闭环,推理步骤清晰、连贯,不跳跃关键环节。
|
||||
* 问题需基于文档中的事实信息、概念关联或逻辑疑问,是读完文档后自然产生的有价值问题(避免无意义或过于简单的问题)。
|
||||
* 思维链推理需严格依据文档内容,逐步推导,每一步推理都能对应文档中的具体信息,不编造原文不存在的内容,不主观臆断。
|
||||
* 最终结论需简洁、准确,是思维链推理的合理结果,与文档核心信息一致。
|
||||
3. **多样化要求**:
|
||||
* 问题类型多样化,包括但不限于事实查询类、逻辑分析类、原因推导类、方案对比类、结论归纳类。
|
||||
* 推理角度多样化,可从不同角色(如项目参与者、需求方、测试人员)或不同维度(如功能实现、进度推进、问题解决)展开推理。
|
||||
4. **语言要求**:
|
||||
* 语言通顺、表达清晰,无歧义,推理过程口语化但不随意,符合正常思考逻辑,最终结论简洁规范。
|
||||
* 请根据输入文档的主要语言进行提问和回答。
|
||||
|
||||
# 输出格式
|
||||
请严格按照以下 JSON 格式输出,确保没有额外的解释或标记,每条 COT 数据独立成项:
|
||||
[
|
||||
{{"question": "具体问题","chain_of_thought": "步骤 1:明确问题核心,定位文档中相关信息范围;步骤 2:提取文档中与问题相关的关键信息 1;步骤 3:结合关键信息 1 推导中间结论 1;步骤 4:提取文档中与问题相关的关键信息 2;步骤 5:结合中间结论 1 和关键信息 2 推导中间结论 2;...(逐步推进);步骤 N:汇总所有中间结论,得出最终结论","conclusion": "简洁准确的最终结论"}},
|
||||
|
||||
{{"question": "具体问题","chain_of_thought": "步骤 1:明确问题核心,定位文档中相关信息范围;步骤 2:提取文档中与问题相关的关键信息 1;步骤 3:结合关键信息 1 推导中间结论 1;步骤 4:提取文档中与问题相关的关键信息 2;步骤 5:结合中间结论 1 和关键信息 2 推导中间结论 2;...(逐步推进);步骤 N:汇总所有中间结论,得出最终结论","conclusion": "简洁准确的最终结论"}},
|
||||
...
|
||||
]
|
||||
"""
|
||||
|
||||
def get_prompt(synth_type: SynthesisType):
|
||||
if synth_type == SynthesisType.QA:
|
||||
return QA_PROMPT
|
||||
elif synth_type == SynthesisType.COT:
|
||||
return COT_PROMPT
|
||||
else:
|
||||
raise ValueError(f"Unsupported synthesis type: {synth_type}")
|
||||
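一个调用 get_prompt 并填充占位符的示意(模板中的 {document} 与 {synthesis_count} 由调用方负责替换):

```python
# 示意用法:获取 QA 模板并填充占位符
from app.module.generation.schema.generation import SynthesisType
from app.module.generation.service.prompt import get_prompt

template = get_prompt(SynthesisType.QA)
prompt = template.format(document="这里是某个文档切片的内容", synthesis_count=5)
```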
@@ -0,0 +1,29 @@
|
||||
from typing import Optional
|
||||
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
from langchain_openai import ChatOpenAI
|
||||
from pydantic import SecretStr
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db.models.model_config import ModelConfig
|
||||
|
||||
|
||||
async def get_model_by_id(db: AsyncSession, model_id: str) -> Optional[ModelConfig]:
|
||||
"""根据模型ID获取 ModelConfig 记录。"""
|
||||
result = await db.execute(select(ModelConfig).where(ModelConfig.id == model_id))
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
|
||||
def get_chat_client(model: ModelConfig) -> BaseChatModel:
|
||||
return ChatOpenAI(
|
||||
model=model.model_name,
|
||||
base_url=model.base_url,
|
||||
api_key=SecretStr(model.api_key),
|
||||
)
|
||||
|
||||
|
||||
def chat(model: BaseChatModel, prompt: str) -> str:
|
||||
"""使用指定模型进行聊天"""
|
||||
response = model.invoke(prompt)
|
||||
return response.content
|
||||
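get_chat_client 与 chat 的一个使用示意(model-config-uuid 仅为示例;服务层实际通过线程池调用同步的 chat,以避免阻塞事件循环):

```python
# 示意用法:根据模型配置构造 LangChain 客户端并发起一次调用
async def demo(db):  # db 为 AsyncSession
    model_cfg = await get_model_by_id(db, "<model-config-uuid>")
    if model_cfg is None:
        return None
    client = get_chat_client(model_cfg)
    return chat(client, "用一句话介绍 DataMate。")
```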
38 runtime/datamate-python/poetry.lock (generated)
@@ -152,7 +152,6 @@ description = "Lightweight in-process concurrent programming"
|
||||
optional = false
|
||||
python-versions = ">=3.9"
|
||||
groups = ["main"]
|
||||
markers = "platform_machine == \"aarch64\" or platform_machine == \"ppc64le\" or platform_machine == \"x86_64\" or platform_machine == \"amd64\" or platform_machine == \"AMD64\" or platform_machine == \"win32\" or platform_machine == \"WIN32\""
|
||||
files = [
|
||||
{file = "greenlet-3.2.4-cp310-cp310-macosx_11_0_universal2.whl", hash = "sha256:8c68325b0d0acf8d91dde4e6f930967dd52a5302cd4062932a6b2e7c2969f47c"},
|
||||
{file = "greenlet-3.2.4-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:94385f101946790ae13da500603491f04a76b6e4c059dab271b3ce2e283b2590"},
|
||||
@@ -353,6 +352,25 @@ files = [
|
||||
[package.extras]
|
||||
all = ["flake8 (>=7.1.1)", "mypy (>=1.11.2)", "pytest (>=8.3.2)", "ruff (>=0.6.2)"]
|
||||
|
||||
[[package]]
|
||||
name = "loguru"
|
||||
version = "0.7.2"
|
||||
description = "Python logging made (stupidly) simple"
|
||||
optional = false
|
||||
python-versions = ">=3.5"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "loguru-0.7.2-py3-none-any.whl", hash = "sha256:003d71e3d3ed35f0f8984898359d65b79e5b21943f78af86aa5491210429b8eb"},
|
||||
{file = "loguru-0.7.2.tar.gz", hash = "sha256:e671a53522515f34fd406340ee968cb9ecafbc4b36c679da03c18fd8d0bd51ac"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
colorama = {version = ">=0.3.4", markers = "sys_platform == \"win32\""}
|
||||
win32-setctime = {version = ">=1.0.0", markers = "sys_platform == \"win32\""}
|
||||
|
||||
[package.extras]
|
||||
dev = ["Sphinx (==7.2.5) ; python_version >= \"3.9\"", "colorama (==0.4.5) ; python_version < \"3.8\"", "colorama (==0.4.6) ; python_version >= \"3.8\"", "exceptiongroup (==1.1.3) ; python_version >= \"3.7\" and python_version < \"3.11\"", "freezegun (==1.1.0) ; python_version < \"3.8\"", "freezegun (==1.2.2) ; python_version >= \"3.8\"", "mypy (==v0.910) ; python_version < \"3.6\"", "mypy (==v0.971) ; python_version == \"3.6\"", "mypy (==v1.4.1) ; python_version == \"3.7\"", "mypy (==v1.5.1) ; python_version >= \"3.8\"", "pre-commit (==3.4.0) ; python_version >= \"3.8\"", "pytest (==6.1.2) ; python_version < \"3.8\"", "pytest (==7.4.0) ; python_version >= \"3.8\"", "pytest-cov (==2.12.1) ; python_version < \"3.8\"", "pytest-cov (==4.1.0) ; python_version >= \"3.8\"", "pytest-mypy-plugins (==1.9.3) ; python_version >= \"3.6\" and python_version < \"3.8\"", "pytest-mypy-plugins (==3.0.0) ; python_version >= \"3.8\"", "sphinx-autobuild (==2021.3.14) ; python_version >= \"3.9\"", "sphinx-rtd-theme (==1.3.0) ; python_version >= \"3.9\"", "tox (==3.27.1) ; python_version < \"3.8\"", "tox (==4.11.0) ; python_version >= \"3.8\""]
|
||||
|
||||
[[package]]
|
||||
name = "pydantic"
|
||||
version = "2.12.4"
|
||||
@@ -1132,7 +1150,23 @@ files = [
|
||||
{file = "websockets-15.0.1.tar.gz", hash = "sha256:82544de02076bafba038ce055ee6412d68da13ab47f0c60cab827346de828dee"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "win32-setctime"
|
||||
version = "1.2.0"
|
||||
description = "A small Python utility to set file creation time on Windows"
|
||||
optional = false
|
||||
python-versions = ">=3.5"
|
||||
groups = ["main"]
|
||||
markers = "sys_platform == \"win32\""
|
||||
files = [
|
||||
{file = "win32_setctime-1.2.0-py3-none-any.whl", hash = "sha256:95d644c4e708aba81dc3704a116d8cbc974d70b3bdb8be1d150e36be6e9d1390"},
|
||||
{file = "win32_setctime-1.2.0.tar.gz", hash = "sha256:ae1fdf948f5640aae05c511ade119313fb6a30d7eabe25fef9764dca5873c4c0"},
|
||||
]
|
||||
|
||||
[package.extras]
|
||||
dev = ["black (>=19.3b0) ; python_version >= \"3.6\"", "pytest (>=4.6.2)"]
|
||||
|
||||
[metadata]
|
||||
lock-version = "2.1"
|
||||
python-versions = ">=3.12"
|
||||
content-hash = "36f9e7212af4fa5832884a2d39a2e7dfbf668e79949d7e90fc12a9e8c96195a7"
|
||||
content-hash = "a47a488ea25f1fa4db5439b36e3f00e797788c48ff2b9623e9bae72a61154df8"
|
||||
|
||||
@@ -21,7 +21,13 @@ dependencies = [
|
||||
"python-multipart (>=0.0.20,<0.0.21)",
|
||||
"python-dotenv (>=1.2.1,<2.0.0)",
|
||||
"python-dateutil (>=2.9.0.post0,<3.0.0)",
|
||||
"pyyaml (>=6.0.3,<7.0.0)"
|
||||
"pyyaml (>=6.0.3,<7.0.0)",
|
||||
"greenlet (>=3.2.4,<4.0.0)",
|
||||
"loguru (>=0.7.2,<0.7.3)",
|
||||
"langchain (>=1.0.0)",
|
||||
"langchain-community (>0.4,<0.4.1)",
|
||||
"unstructured[all]",
|
||||
"markdown"
|
||||
]