diff --git a/runtime/datamate-python/README.md b/runtime/datamate-python/README.md index 41349a3..d806bbe 100644 --- a/runtime/datamate-python/README.md +++ b/runtime/datamate-python/README.md @@ -19,21 +19,26 @@ python -m venv .venv source .venv/bin/activate ``` -3. 安装依赖: +3. 安装依赖 +由于项目使用poetry管理依赖,你可以使用以下命令安装: ```bash -pip install -r requirements.txt +pip install poetry +poetry install +``` +或者直接使用pip安装(如果poetry不可用): + +```bash +pip install -e . ``` -4. 准备环境变量(示例) +4. 配置环境变量 +复制环境变量示例文件并配置: -创建 `.env` 并设置必要的变量,例如: - -- DATABASE_URL(或根据项目配置使用具体变量) -- LABEL_STUDIO_BASE_URL -- LABEL_STUDIO_USER_TOKEN - -(具体变量请参考 `.env.example`) +```bash +cp .env.example .env +``` +编辑.env文件,设置必要的环境变量,如数据库连接、Label Studio配置等。 5. 数据库迁移(开发环境): diff --git a/runtime/datamate-python/app/__init__.py b/runtime/datamate-python/app/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/runtime/datamate-python/app/core/config.py b/runtime/datamate-python/app/core/config.py index b0ec44d..580987c 100644 --- a/runtime/datamate-python/app/core/config.py +++ b/runtime/datamate-python/app/core/config.py @@ -17,7 +17,7 @@ class Settings(BaseSettings): host: str = "0.0.0.0" port: int = 18000 - + # CORS # allowed_origins: List[str] = ["*"] # allowed_methods: List[str] = ["*"] @@ -36,7 +36,7 @@ class Settings(BaseSettings): mysql_database: str = "datamate" database_url: str = "" # Will be overridden by build_database_url() if not provided - + @model_validator(mode='after') def build_database_url(self): """如果没有提供 database_url,则根据 MySQL 配置构建""" diff --git a/runtime/datamate-python/app/db/models/data_synthesis.py b/runtime/datamate-python/app/db/models/data_synthesis.py new file mode 100644 index 0000000..809cd25 --- /dev/null +++ b/runtime/datamate-python/app/db/models/data_synthesis.py @@ -0,0 +1,197 @@ +import uuid +from xml.etree.ElementTree import tostring + +from sqlalchemy import Column, String, Text, Integer, JSON, TIMESTAMP, ForeignKey, func +from sqlalchemy.orm import relationship + +from app.db.session import Base +from app.module.generation.schema.generation import CreateSynthesisTaskRequest + + +async def save_synthesis_task(db_session, synthesis_task: CreateSynthesisTaskRequest): + """保存数据合成任务。""" + # 转换为模型实例 + gid = str(uuid.uuid4()) + synthesis_task_instance = DataSynthesisInstance( + id=gid, + name=synthesis_task.name, + description=synthesis_task.description, + status="pending", + model_id=synthesis_task.model_id, + synthesis_type=synthesis_task.synthesis_type.value, + progress=0, + result_data_location=f"/dataset/synthesis_results/{gid}/", + text_split_config=synthesis_task.text_split_config.model_dump(), + synthesis_config=synthesis_task.synthesis_config.model_dump(), + source_file_id=synthesis_task.source_file_id, + total_files=len(synthesis_task.source_file_id), + processed_files=0, + total_chunks=0, + processed_chunks=0, + total_synthesis_data=0, + created_at=func.now(), + updated_at=func.now(), + created_by="system", + updated_by="system" + ) + db_session.add(synthesis_task_instance) + await db_session.commit() + await db_session.refresh(synthesis_task_instance) + return synthesis_task_instance + + +class DataSynthesisInstance(Base): + """数据合成任务表,对应表 t_data_synthesis_instances + + create table if not exists t_data_synthesis_instances + ( + id VARCHAR(36) CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci PRIMARY KEY COMMENT 'UUID', + name VARCHAR(255) NOT NULL COMMENT '任务名称', + description TEXT COMMENT '任务描述', + status VARCHAR(20) COMMENT '任务状态', + synthesis_type VARCHAR(20) NOT NULL COMMENT 
'合成类型', + model_id VARCHAR(255) NOT NULL COMMENT '模型ID', + progress INT DEFAULT 0 COMMENT '任务进度(百分比)', + result_data_location VARCHAR(1000) COMMENT '结果数据存储位置', + text_split_config JSON NOT NULL COMMENT '文本切片配置', + synthesis_config JSON NOT NULL COMMENT '合成配置', + source_file_id JSON NOT NULL COMMENT '原始文件ID列表', + total_files INT DEFAULT 0 COMMENT '总文件数', + processed_files INT DEFAULT 0 COMMENT '已处理文件数', + total_chunks INT DEFAULT 0 COMMENT '总文本块数', + processed_chunks INT DEFAULT 0 COMMENT '已处理文本块数', + total_synthesis_data INT DEFAULT 0 COMMENT '总合成数据量', + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP COMMENT '创建时间', + updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT '更新时间', + created_by VARCHAR(255) COMMENT '创建者', + updated_by VARCHAR(255) COMMENT '更新者' + ) COMMENT='数据合成任务表(UUID 主键)'; + """ + + __tablename__ = "t_data_synthesis_instances" + + id = Column(String(36), primary_key=True, index=True, comment="UUID") + name = Column(String(255), nullable=False, comment="任务名称") + description = Column(Text, nullable=True, comment="任务描述") + status = Column(String(20), nullable=True, comment="任务状态") + synthesis_type = Column(String(20), nullable=False, comment="合成类型") + model_id = Column(String(255), nullable=False, comment="模型ID") + progress = Column(Integer, nullable=False, default=0, comment="任务进度(百分比)") + result_data_location = Column(String(1000), nullable=True, comment="结果数据存储位置") + text_split_config = Column(JSON, nullable=False, comment="文本切片配置") + synthesis_config = Column(JSON, nullable=False, comment="合成配置") + source_file_id = Column(JSON, nullable=False, comment="原始文件ID列表") + total_files = Column(Integer, nullable=False, default=0, comment="总文件数") + processed_files = Column(Integer, nullable=False, default=0, comment="已处理文件数") + total_chunks = Column(Integer, nullable=False, default=0, comment="总文本块数") + processed_chunks = Column(Integer, nullable=False, default=0, comment="已处理文本块数") + total_synthesis_data = Column(Integer, nullable=False, default=0, comment="总合成数据量") + + created_at = Column(TIMESTAMP, server_default=func.current_timestamp(), nullable=True, comment="创建时间") + updated_at = Column(TIMESTAMP, server_default=func.current_timestamp(), onupdate=func.current_timestamp(), nullable=True, comment="更新时间") + created_by = Column(String(255), nullable=True, comment="创建者") + updated_by = Column(String(255), nullable=True, comment="更新者") + + +class DataSynthesisFileInstance(Base): + """数据合成文件任务表,对应表 t_data_synthesis_file_instances + + create table if not exists t_data_synthesis_file_instances ( + id VARCHAR(36) PRIMARY KEY COMMENT 'UUID', + synthesis_instance_id VARCHAR(36) COMMENT '数据合成任务ID', + file_name VARCHAR(255) NOT NULL COMMENT '文件名', + source_file_id VARCHAR(255) NOT NULL COMMENT '原始文件ID', + target_file_location VARCHAR(1000) NOT NULL COMMENT '目标文件存储位置', + status VARCHAR(20) COMMENT '任务状态', + total_chunks INT DEFAULT 0 COMMENT '总文本块数', + processed_chunks INT DEFAULT 0 COMMENT '已处理文本块数', + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP COMMENT '创建时间', + updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT '更新时间', + created_by VARCHAR(255) COMMENT '创建者', + updated_by VARCHAR(255) COMMENT '更新者' + ) COMMENT='数据合成文件任务表(UUID 主键)'; + """ + + __tablename__ = "t_data_synthesis_file_instances" + + id = Column(String(36), primary_key=True, index=True, comment="UUID") + synthesis_instance_id = Column( + String(36), + nullable=False, + comment="数据合成任务ID", + index=True, + ) + file_name = Column(String(255), nullable=False, 
comment="文件名") + source_file_id = Column(String(255), nullable=False, comment="原始文件ID") + target_file_location = Column(String(1000), nullable=False, comment="目标文件存储位置") + status = Column(String(20), nullable=True, comment="任务状态") + total_chunks = Column(Integer, nullable=False, default=0, comment="总文本块数") + processed_chunks = Column(Integer, nullable=False, default=0, comment="已处理文本块数") + + created_at = Column(TIMESTAMP, server_default=func.current_timestamp(), nullable=True, comment="创建时间") + updated_at = Column( + TIMESTAMP, + server_default=func.current_timestamp(), + onupdate=func.current_timestamp(), + nullable=True, + comment="更新时间", + ) + created_by = Column(String(255), nullable=True, comment="创建者") + updated_by = Column(String(255), nullable=True, comment="更新者") + + +class DataSynthesisChunkInstance(Base): + """数据合成分块任务表,对应表 t_data_synthesis_chunk_instances + + create table if not exists t_data_synthesis_chunk_instances ( + id VARCHAR(36) PRIMARY KEY COMMENT 'UUID', + synthesis_file_instance_id VARCHAR(36) COMMENT '数据合成文件任务ID', + chunk_index INT COMMENT '分块索引', + chunk_content TEXT COMMENT '分块内容', + metadata JSON COMMENT '分块元数据' + ) COMMENT='数据合成分块任务表(UUID 主键)'; + """ + + __tablename__ = "t_data_synthesis_chunk_instances" + + id = Column(String(36), primary_key=True, index=True, comment="UUID") + synthesis_file_instance_id = Column( + String(36), + nullable=False, + comment="数据合成文件任务ID", + index=True, + ) + chunk_index = Column(Integer, nullable=True, comment="分块索引") + chunk_content = Column(Text, nullable=True, comment="分块内容") + # SQLAlchemy Declarative 保留了属性名 'metadata',这里使用 chunk_metadata 作为属性名, + # 底层列名仍为 'metadata' 以保持与表结构兼容。 + chunk_metadata = Column("metadata", JSON, nullable=True, comment="分块元数据") + + +class SynthesisData(Base): + """数据合成结果表,对应表 t_synthesis_data + + create table if not exists t_synthesis_data ( + id VARCHAR(36) PRIMARY KEY COMMENT 'UUID', + data json COMMENT '合成的数据', + synthesis_file_instance_id VARCHAR(36) COMMENT '数据合成文件任务ID', + chunk_instance_id VARCHAR(36) COMMENT '分块任务ID' + ) COMMENT='数据合成任务队列表(UUID 主键)'; + """ + + __tablename__ = "t_data_synthesis_data" + + id = Column(String(36), primary_key=True, index=True, comment="UUID") + data = Column(JSON, nullable=True, comment="合成的数据") + synthesis_file_instance_id = Column( + String(36), + nullable=False, + comment="数据合成文件任务ID", + index=True, + ) + chunk_instance_id = Column( + String(36), + nullable=False, + comment="分块任务ID", + index=True, + ) diff --git a/runtime/datamate-python/app/db/models/model_config.py b/runtime/datamate-python/app/db/models/model_config.py new file mode 100644 index 0000000..be75043 --- /dev/null +++ b/runtime/datamate-python/app/db/models/model_config.py @@ -0,0 +1,57 @@ +from sqlalchemy import Column, String, Integer, TIMESTAMP, select + +from app.db.session import Base + + +async def get_model_by_id(db_session, model_id: str): + """根据 ID 获取单个模型配置。""" + result =await db_session.execute(select(ModelConfig).where(ModelConfig.id == model_id)) + model_config = result.scalar_one_or_none() + return model_config + +class ModelConfig(Base): + """模型配置表,对应表 t_model_config + + CREATE TABLE IF NOT EXISTS t_model_config ( + id VARCHAR(36) PRIMARY KEY COMMENT '主键ID', + model_name VARCHAR(100) NOT NULL COMMENT '模型名称(如 qwen2)', + provider VARCHAR(50) NOT NULL COMMENT '模型提供商(如 Ollama、OpenAI、DeepSeek)', + base_url VARCHAR(255) NOT NULL COMMENT 'API 基础地址', + api_key VARCHAR(512) DEFAULT '' COMMENT 'API 密钥(无密钥则为空)', + type VARCHAR(50) NOT NULL COMMENT '模型类型(如 chat、embedding)', + is_enabled TINYINT 
DEFAULT 1 COMMENT '是否启用:1-启用,0-禁用', + is_default TINYINT DEFAULT 0 COMMENT '是否默认:1-默认,0-非默认', + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP COMMENT '创建时间', + updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT '更新时间', + created_by VARCHAR(255) COMMENT '创建者', + updated_by VARCHAR(255) COMMENT '更新者', + UNIQUE KEY uk_model_provider (model_name, provider) + ) ENGINE = InnoDB DEFAULT CHARSET = utf8mb4 COMMENT ='模型配置表'; + """ + + __tablename__ = "t_model_config" + + id = Column(String(36), primary_key=True, index=True, comment="主键ID") + model_name = Column(String(100), nullable=False, comment="模型名称(如 qwen2)") + provider = Column(String(50), nullable=False, comment="模型提供商(如 Ollama、OpenAI、DeepSeek)") + base_url = Column(String(255), nullable=False, comment="API 基础地址") + api_key = Column(String(512), nullable=False, default="", comment="API 密钥(无密钥则为空)") + type = Column(String(50), nullable=False, comment="模型类型(如 chat、embedding)") + + # 使用 Integer 存储 TINYINT,后续可在业务层将 0/1 转为 bool + is_enabled = Column(Integer, nullable=False, default=1, comment="是否启用:1-启用,0-禁用") + is_default = Column(Integer, nullable=False, default=0, comment="是否默认:1-默认,0-非默认") + + created_at = Column(TIMESTAMP, nullable=True, comment="创建时间") + updated_at = Column(TIMESTAMP, nullable=True, comment="更新时间") + created_by = Column(String(255), nullable=True, comment="创建者") + updated_by = Column(String(255), nullable=True, comment="更新者") + + __table_args__ = ( + # 与 DDL 中的 uk_model_provider 保持一致 + { + "mysql_engine": "InnoDB", + "mysql_charset": "utf8mb4", + "comment": "模型配置表", + }, + ) diff --git a/runtime/datamate-python/app/db/session.py b/runtime/datamate-python/app/db/session.py index 8ec2db8..c70b196 100644 --- a/runtime/datamate-python/app/db/session.py +++ b/runtime/datamate-python/app/db/session.py @@ -15,8 +15,8 @@ engine = create_async_engine( # 创建会话工厂 AsyncSessionLocal = async_sessionmaker( - engine, - class_=AsyncSession, + engine, + class_=AsyncSession, expire_on_commit=False ) @@ -29,4 +29,3 @@ async def get_db() -> AsyncGenerator[AsyncSession, None]: yield session finally: await session.close() - \ No newline at end of file diff --git a/runtime/datamate-python/app/main.py b/runtime/datamate-python/app/main.py index 0e7f69b..5411ece 100644 --- a/runtime/datamate-python/app/main.py +++ b/runtime/datamate-python/app/main.py @@ -24,7 +24,7 @@ logger = get_logger(__name__) @asynccontextmanager async def lifespan(app: FastAPI): - + # @startup logger.info("DataMate Python Backend starting...") @@ -43,7 +43,7 @@ async def lifespan(app: FastAPI): logger.info(f"Label Studio: {settings.label_studio_base_url}") yield - + # @shutdown logger.info("DataMate Python Backend shutting down ...\n\n") @@ -105,11 +105,11 @@ async def root(): if __name__ == "__main__": import uvicorn - + uvicorn.run( "app.main:app", host=settings.host, port=settings.port, reload=settings.debug, log_level=settings.log_level.lower() - ) \ No newline at end of file + ) diff --git a/runtime/datamate-python/app/module/__init__.py b/runtime/datamate-python/app/module/__init__.py index b0e1a52..0295370 100644 --- a/runtime/datamate-python/app/module/__init__.py +++ b/runtime/datamate-python/app/module/__init__.py @@ -3,6 +3,7 @@ from fastapi import APIRouter from .system.interface import router as system_router from .annotation.interface import router as annotation_router from .synthesis.interface import router as ratio_router +from .generation.interface import router as generation_router router = APIRouter( prefix="/api" @@ -11,5 
+12,6 @@ router = APIRouter( router.include_router(system_router) router.include_router(annotation_router) router.include_router(ratio_router) +router.include_router(generation_router) __all__ = ["router"] diff --git a/runtime/datamate-python/app/module/generation/__init__.py b/runtime/datamate-python/app/module/generation/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/runtime/datamate-python/app/module/generation/interface/__init__.py b/runtime/datamate-python/app/module/generation/interface/__init__.py new file mode 100644 index 0000000..4437a1b --- /dev/null +++ b/runtime/datamate-python/app/module/generation/interface/__init__.py @@ -0,0 +1,11 @@ +from fastapi import APIRouter + +router = APIRouter( + prefix="/synth", + tags = ["synth"] +) + +# Include sub-routers +from .generation_api import router as generation_router_router + +router.include_router(generation_router_router) diff --git a/runtime/datamate-python/app/module/generation/interface/generation_api.py b/runtime/datamate-python/app/module/generation/interface/generation_api.py new file mode 100644 index 0000000..ae8d0d2 --- /dev/null +++ b/runtime/datamate-python/app/module/generation/interface/generation_api.py @@ -0,0 +1,260 @@ +import uuid + +from fastapi import APIRouter, HTTPException, Depends, BackgroundTasks +from sqlalchemy import select, func, delete +from sqlalchemy.ext.asyncio import AsyncSession + +from app.core.logging import get_logger +from app.db.models.data_synthesis import ( + save_synthesis_task, + DataSynthesisInstance, + DataSynthesisFileInstance, + DataSynthesisChunkInstance, + SynthesisData, +) +from app.db.models.dataset_management import DatasetFiles +from app.db.models.model_config import get_model_by_id +from app.db.session import get_db +from app.module.generation.schema.generation import ( + CreateSynthesisTaskRequest, + DataSynthesisTaskItem, + PagedDataSynthesisTaskResponse, SynthesisType) +from app.module.generation.service.generation_service import GenerationService +from app.module.generation.service.prompt import get_prompt +from app.module.shared.schema import StandardResponse + +router = APIRouter( + prefix="/gen", + tags=["gen"] +) + +logger = get_logger(__name__) + +@router.post("/task", response_model=StandardResponse[DataSynthesisTaskItem]) +async def create_synthesis_task( + request: CreateSynthesisTaskRequest, + background_tasks: BackgroundTasks, + db: AsyncSession = Depends(get_db), +): + """创建数据合成任务""" + result = await get_model_by_id(db, request.model_id) + if not result: + raise HTTPException(status_code=404, detail="Model not found") + + # 先根据 source_file_id 在 DatasetFiles 中查出已有文件信息 + file_ids = request.source_file_id or [] + dataset_files = [] + if file_ids: + ds_result = await db.execute( + select(DatasetFiles).where(DatasetFiles.id.in_(file_ids)) + ) + dataset_files = ds_result.scalars().all() + + # 保存任务到数据库 + request.source_file_id = [str(f.id) for f in dataset_files] + synthesis_task = await save_synthesis_task(db, request) + + # 将已有的 DatasetFiles 记录保存到 t_data_synthesis_file_instances + for f in dataset_files: + file_instance = DataSynthesisFileInstance( + id=str(uuid.uuid4()), # 使用新的 UUID 作为文件任务记录的主键,避免与 DatasetFiles 主键冲突 + synthesis_instance_id=synthesis_task.id, + file_name=f.file_name, + source_file_id=str(f.id), + target_file_location=synthesis_task.result_data_location or "", + status="pending", + total_chunks=0, + processed_chunks=0, + created_by="system", + updated_by="system", + ) + db.add(file_instance) + + if dataset_files: + await 
db.commit() + + generation_service = GenerationService(db) + # 异步处理任务:只传任务 ID,后台任务中使用新的 DB 会话重新加载任务对象 + background_tasks.add_task(generation_service.process_task, synthesis_task.id) + + return StandardResponse( + code=200, + message="success", + data=synthesis_task, + ) + + +@router.get("/task/{task_id}", response_model=StandardResponse[DataSynthesisTaskItem]) +async def get_synthesis_task( + task_id: str, + db: AsyncSession = Depends(get_db) +): + """获取数据合成任务详情""" + result = await db.get(DataSynthesisInstance, task_id) + if not result: + raise HTTPException(status_code=404, detail="Synthesis task not found") + + return StandardResponse( + code=200, + message="success", + data=result, + ) + + +@router.get("/tasks", response_model=StandardResponse[PagedDataSynthesisTaskResponse], status_code=200) +async def list_synthesis_tasks( + page: int = 1, + page_size: int = 10, + synthesis_type: str | None = None, + status: str | None = None, + name: str | None = None, + db: AsyncSession = Depends(get_db) +): + """分页列出所有数据合成任务""" + query = select(DataSynthesisInstance) + if synthesis_type: + query = query.filter(DataSynthesisInstance.synthesis_type == synthesis_type) + if status: + query = query.filter(DataSynthesisInstance.status == status) + if name: + query = query.filter(DataSynthesisInstance.name.like(f"%{name}%")) + + count_q = select(func.count()).select_from(query.subquery()) + total = (await db.execute(count_q)).scalar_one() + + if page < 1: + page = 1 + if page_size < 1: + page_size = 10 + + result = await db.execute(query.offset((page - 1) * page_size).limit(page_size)) + rows = result.scalars().all() + + task_items = [ + DataSynthesisTaskItem( + id=row.id, + name=row.name, + description=row.description, + status=row.status, + synthesis_type=row.synthesis_type, + model_id=row.model_id, + progress=row.progress, + result_data_location=row.result_data_location, + text_split_config=row.text_split_config, + synthesis_config=row.synthesis_config, + source_file_id=row.source_file_id, + total_files=row.total_files, + processed_files=row.processed_files, + total_chunks=row.total_chunks, + processed_chunks=row.processed_chunks, + total_synthesis_data=row.total_synthesis_data, + created_at=row.created_at, + updated_at=row.updated_at, + created_by=row.created_by, + updated_by=row.updated_by, + ) + for row in rows + ] + + paged = PagedDataSynthesisTaskResponse( + content=task_items, + totalElements=total, + totalPages=(total + page_size - 1) // page_size, + page=page, + size=page_size, + ) + + return StandardResponse( + code=200, + message="Success", + data=paged, + ) + + +@router.delete("/task/{task_id}", response_model=StandardResponse[None]) +async def delete_synthesis_task( + task_id: str, + db: AsyncSession = Depends(get_db) +): + """删除数据合成任务""" + task = await db.get(DataSynthesisInstance, task_id) + if not task: + raise HTTPException(status_code=404, detail="Synthesis task not found") + + # 1. 
删除与该任务相关的 SynthesisData、Chunk、File 记录
+    # 先查出所有文件任务 ID
+    file_result = await db.execute(
+        select(DataSynthesisFileInstance.id).where(
+            DataSynthesisFileInstance.synthesis_instance_id == task_id
+        )
+    )
+    file_ids = [row[0] for row in file_result.all()]
+
+    if file_ids:
+        # 删除 SynthesisData(根据文件任务ID)
+        await db.execute(
+            delete(SynthesisData).where(
+                SynthesisData.synthesis_file_instance_id.in_(file_ids)
+            )
+        )
+
+        # 删除 Chunk 记录
+        await db.execute(
+            delete(DataSynthesisChunkInstance).where(
+                DataSynthesisChunkInstance.synthesis_file_instance_id.in_(file_ids)
+            )
+        )
+
+        # 删除文件任务记录
+        await db.execute(
+            delete(DataSynthesisFileInstance).where(
+                DataSynthesisFileInstance.id.in_(file_ids)
+            )
+        )
+
+    # 2. 删除任务本身
+    await db.delete(task)
+    await db.commit()
+
+    return StandardResponse(
+        code=200,
+        message="success",
+        data=None,
+    )
+
+@router.delete("/task/{task_id}/{file_id}", response_model=StandardResponse[None])
+async def delete_synthesis_file_task(
+    task_id: str,
+    file_id: str,
+    db: AsyncSession = Depends(get_db)
+):
+    """删除数据合成任务中的文件任务"""
+    file_task = await db.get(DataSynthesisFileInstance, file_id)
+    if not file_task:
+        raise HTTPException(status_code=404, detail="Synthesis file task not found")
+
+    # 删除 SynthesisData(根据文件任务ID)
+    await db.execute(
+        delete(SynthesisData).where(
+            SynthesisData.synthesis_file_instance_id == file_id
+        )
+    )
+
+    # 删除 Chunk 记录
+    await db.execute(
+        delete(DataSynthesisChunkInstance).where(
+            DataSynthesisChunkInstance.synthesis_file_instance_id == file_id
+        )
+    )
+
+    # 删除文件任务记录
+    await db.execute(
+        delete(DataSynthesisFileInstance).where(
+            DataSynthesisFileInstance.id == file_id
+        )
+    )
+
+    # 提交删除并返回统一响应
+    await db.commit()
+
+    return StandardResponse(
+        code=200,
+        message="success",
+        data=None,
+    )
+
+@router.get("/prompt", response_model=StandardResponse[str])
+async def get_prompt_by_type(
+    synth_type: SynthesisType,
+):
+    prompt = get_prompt(synth_type)
+    return StandardResponse(
+        code=200,
+        message="Success",
+        data=prompt,
+    )
diff --git a/runtime/datamate-python/app/module/generation/schema/__init__.py b/runtime/datamate-python/app/module/generation/schema/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/runtime/datamate-python/app/module/generation/schema/generation.py b/runtime/datamate-python/app/module/generation/schema/generation.py
new file mode 100644
index 0000000..7317c02
--- /dev/null
+++ b/runtime/datamate-python/app/module/generation/schema/generation.py
@@ -0,0 +1,76 @@
+from datetime import datetime
+from enum import Enum
+from typing import List, Optional, Dict, Any
+
+from pydantic import BaseModel, Field
+
+
+class TextSplitConfig(BaseModel):
+    """文本切片配置"""
+    chunk_size: int = Field(..., description="最大令牌数")
+    chunk_overlap: int = Field(..., description="重叠令牌数")
+
+
+class SynthesisConfig(BaseModel):
+    """合成配置"""
+    prompt_template: str = Field(..., description="合成提示模板")
+    synthesis_count: Optional[int] = Field(None, description="单个chunk合成的数据数量")
+    temperature: Optional[float] = Field(None, description="温度参数")
+
+
+class SynthesisType(Enum):
+    """合成类型"""
+    QA = "QA"
+    COT = "COT"
+
+
+class CreateSynthesisTaskRequest(BaseModel):
+    """创建数据合成任务请求"""
+    name: str = Field(..., description="合成任务名称")
+    description: Optional[str] = Field(None, description="合成任务描述")
+    model_id: str = Field(..., description="模型ID")
+    source_file_id: list[str] = Field(..., description="原始文件ID列表")
+    text_split_config: Optional[TextSplitConfig] = Field(None, description="文本切片配置")
+    synthesis_config: SynthesisConfig = Field(..., description="合成配置")
+    synthesis_type: SynthesisType = Field(..., description="合成类型")
+
+
+class DataSynthesisTaskItem(BaseModel):
+    """数据合成任务列表/详情项"""
+    id: str
+ name: str + description: Optional[str] = None + status: Optional[str] = None + synthesis_type: str + model_id: str + progress: int + result_data_location: Optional[str] = None + text_split_config: Dict[str, Any] + synthesis_config: Dict[str, Any] + source_file_id: list[str] + total_files: int + processed_files: int + total_chunks: int + processed_chunks: int + total_synthesis_data: int + created_at: Optional[datetime] = None + updated_at: Optional[datetime] = None + created_by: Optional[str] = None + updated_by: Optional[str] = None + + class Config: + orm_mode = True + + +class PagedDataSynthesisTaskResponse(BaseModel): + """分页数据合成任务响应""" + content: List[DataSynthesisTaskItem] + totalElements: int + totalPages: int + page: int + size: int + +class ChatRequest(BaseModel): + """聊天请求参数""" + model_id: str + prompt: str diff --git a/runtime/datamate-python/app/module/generation/service/__init__.py b/runtime/datamate-python/app/module/generation/service/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/runtime/datamate-python/app/module/generation/service/generation_service.py b/runtime/datamate-python/app/module/generation/service/generation_service.py new file mode 100644 index 0000000..e158366 --- /dev/null +++ b/runtime/datamate-python/app/module/generation/service/generation_service.py @@ -0,0 +1,544 @@ +import asyncio +import uuid +import json +from pathlib import Path + +from langchain_community.document_loaders import ( + TextLoader, + CSVLoader, + JSONLoader, + UnstructuredMarkdownLoader, + UnstructuredHTMLLoader, + UnstructuredFileLoader, + PyPDFLoader, + UnstructuredWordDocumentLoader, + UnstructuredPowerPointLoader, + UnstructuredExcelLoader, +) +from langchain_text_splitters import RecursiveCharacterTextSplitter +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from app.db.models.data_synthesis import ( + DataSynthesisInstance, + DataSynthesisFileInstance, + DataSynthesisChunkInstance, + SynthesisData, +) +from app.db.models.dataset_management import DatasetFiles +from app.db.models.model_config import get_model_by_id +from app.db.session import logger +from app.module.system.service.common_service import get_chat_client, chat + + +class GenerationService: + def __init__(self, db: AsyncSession): + self.db = db + + async def process_task(self, task_id: str): + """处理数据合成任务入口:根据任务ID加载任务并逐个处理源文件。""" + synthesis_task: DataSynthesisInstance | None = await self.db.get(DataSynthesisInstance, task_id) + if not synthesis_task: + logger.error(f"Synthesis task {task_id} not found, abort processing") + return + + logger.info(f"Processing synthesis task {task_id}") + file_ids = synthesis_task.source_file_id or [] + + # 获取模型客户端 + model_result = await get_model_by_id(self.db, str(synthesis_task.model_id)) + if not model_result: + logger.error( + f"Model config not found for id={synthesis_task.model_id}, abort task {synthesis_task.id}" + ) + return + chat_client = get_chat_client(model_result) + + # 控制并发度的信号量(限制全任务范围内最多 10 个并发调用) + semaphore = asyncio.Semaphore(10) + + # 逐个文件处理 + for file_id in file_ids: + try: + success = await self._process_single_file( + synthesis_task=synthesis_task, + file_id=file_id, + chat_client=chat_client, + semaphore=semaphore, + ) + except Exception as e: + logger.exception(f"Unexpected error when processing file {file_id} for task {task_id}: {e}") + # 确保对应文件任务状态标记为失败 + await self._mark_file_failed(str(synthesis_task.id), file_id, str(e)) + success = False + + if success: + # 每处理完一个文件,简单增加 processed_files 计数 + 
synthesis_task.processed_files = (synthesis_task.processed_files or 0) + 1
+                await self.db.commit()
+                await self.db.refresh(synthesis_task)
+
+        logger.info(f"Finished processing synthesis task {synthesis_task.id}")
+
+    async def _process_single_file(
+        self,
+        synthesis_task: DataSynthesisInstance,
+        file_id: str,
+        chat_client,
+        semaphore: asyncio.Semaphore,
+    ) -> bool:
+        """处理单个源文件:解析路径、切片、保存分块并触发 LLM 调用。"""
+        file_path = await self._resolve_file_path(file_id)
+        if not file_path:
+            logger.warning(f"File path not found for file_id={file_id}, skip")
+            await self._mark_file_failed(str(synthesis_task.id), file_id, "file_path_not_found")
+            return False
+
+        logger.info(f"Processing file_id={file_id}, path={file_path}")
+
+        split_cfg = synthesis_task.text_split_config or {}
+        synthesis_cfg = synthesis_task.synthesis_config or {}
+        chunk_size = int(split_cfg.get("chunk_size", 800))
+        chunk_overlap = int(split_cfg.get("chunk_overlap", 50))
+        # 加载并切片
+        try:
+            chunks = self._load_and_split(file_path, chunk_size, chunk_overlap)
+        except Exception as e:
+            logger.error(f"Failed to load/split file {file_path}: {e}")
+            await self._mark_file_failed(str(synthesis_task.id), file_id, f"load_split_error: {e}")
+            return False
+
+        if not chunks:
+            logger.warning(f"No chunks generated for file_id={file_id}")
+            await self._mark_file_failed(str(synthesis_task.id), file_id, "no_chunks_generated")
+            return False
+
+        logger.info(f"File {file_id} split into {len(chunks)} chunks by LangChain")
+
+        # 保存文件任务记录 + 分块记录
+        file_task = await self._get_or_create_file_instance(
+            synthesis_task_id=str(synthesis_task.id),
+            source_file_id=file_id,
+            file_path=file_path,
+        )
+        await self._persist_chunks(synthesis_task, file_task, file_id, chunks)
+
+        # 针对每个切片并发调用大模型
+        await self._invoke_llm_for_chunks(
+            synthesis_task=synthesis_task,
+            file_id=file_id,
+            chunks=chunks,
+            synthesis_cfg=synthesis_cfg,
+            chat_client=chat_client,
+            semaphore=semaphore,
+        )
+
+        # 如果执行到此处,说明该文件的切片与 LLM 调用流程均未抛出异常,标记为完成
+        file_task.status = "completed"
+        await self.db.commit()
+        await self.db.refresh(file_task)
+
+        return True
+
+    async def _persist_chunks(
+        self,
+        synthesis_task: DataSynthesisInstance,
+        file_task: DataSynthesisFileInstance,
+        file_id: str,
+        chunks,
+    ) -> None:
+        """将切片结果保存到 t_data_synthesis_chunk_instances,并更新文件级分块计数。"""
+        for idx, doc in enumerate(chunks, start=1):
+            # 先复制原始 Document.metadata,再在其上追加任务相关字段,避免覆盖原有元数据
+            base_metadata = dict(getattr(doc, "metadata", {}) or {})
+            base_metadata.update(
+                {
+                    "task_id": str(synthesis_task.id),
+                    "file_id": file_id
+                }
+            )
+
+            chunk_record = DataSynthesisChunkInstance(
+                id=str(uuid.uuid4()),
+                synthesis_file_instance_id=file_task.id,
+                chunk_index=idx,
+                chunk_content=doc.page_content,
+                chunk_metadata=base_metadata,
+            )
+            self.db.add(chunk_record)
+
+        # 更新文件任务的分块数量
+        file_task.total_chunks = len(chunks)
+        file_task.status = "processing"
+
+        await self.db.commit()
+        await self.db.refresh(file_task)
+
+    async def _invoke_llm_for_chunks(
+        self,
+        synthesis_task: DataSynthesisInstance,
+        file_id: str,
+        chunks,
+        synthesis_cfg: dict,
+        chat_client,
+        semaphore: asyncio.Semaphore,
+    ) -> None:
+        """针对每个分片并发调用大模型生成数据。"""
+        # 需要将 answer 和对应 chunk 建立关系,因此这里保留 chunk_index
+        tasks = [
+            self._call_llm(doc, file_id, idx, synthesis_task, synthesis_cfg, chat_client, semaphore)
+            for idx, doc in enumerate(chunks, start=1)
+        ]
+        await asyncio.gather(*tasks, return_exceptions=True)
+
+    async def _call_llm(
+        self,
+        doc,
+        file_id: str,
+        idx: int,
+        synthesis_task,
+        synthesis_cfg: dict,
chat_client, + semaphore: asyncio.Semaphore, + ): + """单次大模型调用逻辑,带并发控制。 + + 说明: + - 使用信号量限制全局并发量(当前为 10)。 + - 使用线程池执行同步的 chat 调用,避免阻塞事件循环。 + - 在拿到 LLM 返回后,解析为 JSON 并批量写入 SynthesisData, + 同时更新文件级 processed_chunks / 进度等信息。 + """ + async with semaphore: + prompt = self._build_qa_prompt(doc.page_content, synthesis_cfg) + try: + loop = asyncio.get_running_loop() + answer = await loop.run_in_executor(None, chat, chat_client, prompt) + logger.debug( + f"Generated QA for task={synthesis_task.id}, file={file_id}, chunk={idx}" + ) + await self._handle_llm_answer( + synthesis_task_id=str(synthesis_task.id), + file_id=file_id, + chunk_index=idx, + raw_answer=answer, + ) + return answer + except Exception as e: + logger.error( + f"LLM generation failed for task={synthesis_task.id}, file={file_id}, chunk={idx}: {e}" + ) + return None + + async def _resolve_file_path(self, file_id: str) -> str | None: + """根据文件ID查询 t_dm_dataset_files 并返回 file_path(仅 ACTIVE 文件)。""" + result = await self.db.execute( + select(DatasetFiles).where(DatasetFiles.id == file_id) + ) + file_obj = result.scalar_one_or_none() + if not file_obj: + return None + return file_obj.file_path + + def _load_and_split(self, file_path: str, chunk_size: int, chunk_overlap: int): + """使用 LangChain 加载文本并进行切片,直接返回 Document 列表。 + + 当前实现: + - 使用 TextLoader 加载纯文本/Markdown/JSON 等文本文件 + - 使用 RecursiveCharacterTextSplitter 做基于字符的递归切片 + + 保留每个 Document 的 metadata,方便后续追加例如文件ID、chunk序号等信息。 + """ + loader = self._build_loader(file_path) + docs = loader.load() + + splitter = RecursiveCharacterTextSplitter( + chunk_size=chunk_size, + chunk_overlap=chunk_overlap, + # 尝试按这些分隔符优先切分,再退化到字符级 + separators=["\n\n", "\n", "。", "!", "?", "!", "?", "。\n", "\t", " "] + ) + split_docs = splitter.split_documents(docs) + return split_docs + + @staticmethod + def _build_loader(file_path: str): + """根据文件扩展名选择合适的 LangChain 文本加载器,尽量覆盖常见泛文本格式。 + + 优先按格式选择专门的 Loader,找不到匹配时退回到 TextLoader。 + """ + path = Path(file_path) + suffix = path.suffix.lower() + path_str = str(path) + + # 1. 纯文本类 + if suffix in {".txt", "", ".log"}: # "" 兼容无扩展名 + return TextLoader(path_str, encoding="utf-8") + + # 2. Markdown + if suffix in {".md", ".markdown"}: + # UnstructuredMarkdownLoader 会保留更多结构信息 + return UnstructuredMarkdownLoader(path_str) + + # 3. HTML / HTM + if suffix in {".html", ".htm"}: + return UnstructuredHTMLLoader(path_str) + + # 4. JSON + if suffix == ".json": + # 使用 JSONLoader 将 JSON 中的内容展开成文档 + # 这里使用默认 jq_schema,后续需要更精细地提取可以在此调整 + return JSONLoader(file_path=path_str, jq_schema=".") + + # 5. CSV / TSV + if suffix in {".csv", ".tsv"}: + # CSVLoader 默认将每一行作为一条 Document + return CSVLoader(file_path=path_str) + + # 6. YAML + if suffix in {".yaml", ".yml"}: + # 暂时按纯文本加载 + return TextLoader(path_str, encoding="utf-8") + + # 7. PDF + if suffix == ".pdf": + return PyPDFLoader(path_str) + + # 8. Word 文档 + if suffix in {".docx", ".doc"}: + # UnstructuredWordDocumentLoader 支持 .docx/.doc 文本抽取 + return UnstructuredWordDocumentLoader(path_str) + + # 9. PowerPoint + if suffix in {".ppt", ".pptx"}: + return UnstructuredPowerPointLoader(path_str) + + # 10. Excel + if suffix in {".xls", ".xlsx"}: + return UnstructuredExcelLoader(path_str) + + # 11. 
兜底:使用 UnstructuredFileLoader 或 TextLoader 作为纯文本 + try: + return UnstructuredFileLoader(path_str) + except Exception: + return TextLoader(path_str, encoding="utf-8") + + @staticmethod + def _build_qa_prompt(chunk: str, synthesis_cfg: dict) -> str: + """构造 QA 数据合成的提示词。 + + 要求: + - synthesis_cfg["prompt_template"] 是一个字符串,其中包含 {document} 占位符; + - 将当前切片内容替换到 {document}。 + 如果未提供或模板非法,则使用内置默认模板。 + """ + template = None + if isinstance(synthesis_cfg, dict): + template = synthesis_cfg.get("prompt_template") + synthesis_count = synthesis_cfg["synthesis_count"] if ("synthesis_count" in synthesis_cfg and synthesis_cfg["synthesis_count"]) else 5 + try: + prompt = template.format(document=chunk, synthesis_count=synthesis_count) + except Exception: + # 防御性处理:如果 format 出现异常,则退回到简单拼接 + prompt = f"{template}\n\n文档内容:{chunk}\n\n请根据文档内容生成 {synthesis_count} 条符合要求的问答数据。" + return prompt + + async def _handle_llm_answer( + self, + synthesis_task_id: str, + file_id: str, + chunk_index: int, + raw_answer: str, + ) -> None: + """解析 LLM 返回内容为 JSON,批量保存到 SynthesisData,并更新文件任务进度。 + + 约定: + - LLM 返回的 raw_answer 是 JSON 字符串,可以是: + 1)单个对象:{"question": ..., "answer": ...} + 2)对象数组:[{}, {}, ...] + - 我们将其规范化为列表,每个元素作为一条 SynthesisData.data 写入。 + - 根据 synthesis_task_id + file_id + chunk_index 找到对应的 DataSynthesisChunkInstance, + 以便设置 chunk_instance_id 和 synthesis_file_instance_id。 + - 每处理完一个 chunk,递增对应 DataSynthesisFileInstance.processed_chunks,并按比例更新进度。 + """ + if not raw_answer: + return + + # 1. 预处理原始回答:尝试从中截取出最可能的 JSON 片段 + cleaned = self._extract_json_substring(raw_answer) + + # 2. 解析 JSON,统一成列表结构 + try: + parsed = json.loads(cleaned) + except Exception as e: + logger.error( + f"Failed to parse LLM answer as JSON for task={synthesis_task_id}, file={file_id}, chunk={chunk_index}: {e}. Raw answer: {raw_answer!r}" + ) + return + + if isinstance(parsed, dict): + items = [parsed] + elif isinstance(parsed, list): + items = [p for p in parsed if isinstance(p, dict)] + else: + logger.error(f"Unexpected JSON structure from LLM answer for task={synthesis_task_id}, file={file_id}, chunk={chunk_index}: {type(parsed)}") + return + + if not items: + return + + # 3. 找到对应的 chunk 记录(一个 chunk_index 对应一条记录) + chunk_result = await self.db.execute( + select(DataSynthesisChunkInstance, DataSynthesisFileInstance) + .join( + DataSynthesisFileInstance, + DataSynthesisFileInstance.id == DataSynthesisChunkInstance.synthesis_file_instance_id, + ) + .where( + DataSynthesisFileInstance.synthesis_instance_id == synthesis_task_id, + DataSynthesisFileInstance.source_file_id == file_id, + DataSynthesisChunkInstance.chunk_index == chunk_index, + ) + ) + row = chunk_result.first() + if not row: + logger.error( + f"Chunk record not found for task={synthesis_task_id}, file={file_id}, chunk_index={chunk_index}, skip saving SynthesisData." + ) + return + + chunk_instance, file_instance = row + + # 4. 批量写入 SynthesisData + for data_obj in items: + record = SynthesisData( + id=str(uuid.uuid4()), + data=data_obj, + synthesis_file_instance_id=file_instance.id, + chunk_instance_id=chunk_instance.id, + ) + self.db.add(record) + + # 5. 
更新文件级 processed_chunks / 进度 + file_instance.processed_chunks = (file_instance.processed_chunks or 0) + 1 + + + await self.db.commit() + await self.db.refresh(file_instance) + + @staticmethod + def _extract_json_substring(raw: str) -> str: + """从 LLM 的原始回答中提取最可能的 JSON 字符串片段。 + + 处理思路: + - 原始回答可能是:说明文字 + JSON + 说明文字,甚至带有 Markdown 代码块。 + - 优先在文本中查找第一个 '{' 或 '[' 作为 JSON 起始; + - 再从后向前找最后一个 '}' 或 ']' 作为结束; + - 如果找不到合适的边界,就退回原始字符串。 + 该方法不会保证截取的一定是合法 JSON,但能显著提高 json.loads 的成功率。 + """ + if not raw: + return raw + + start = None + end = None + + # 查找第一个 JSON 起始符号 + for i, ch in enumerate(raw): + if ch in "[{": + start = i + break + + # 查找最后一个 JSON 结束符号 + for i in range(len(raw) - 1, -1, -1): + if raw[i] in "]}": + end = i + 1 # 切片是左闭右开 + break + + if start is not None and end is not None and start < end: + return raw[start:end].strip() + + # 兜底:去掉常见 Markdown 包裹(```json ... ```) + stripped = raw.strip() + if stripped.startswith("```"): + # 去掉首尾 ``` 标记 + stripped = stripped.strip("`") + return stripped + + async def _get_or_create_file_instance( + self, + synthesis_task_id: str, + source_file_id: str, + file_path: str, + ) -> DataSynthesisFileInstance: + """根据任务ID和原始文件ID,查找或创建对应的 DataSynthesisFileInstance 记录。 + + - 如果已存在(同一任务 + 同一 source_file_id),直接返回; + - 如果不存在,则创建一条新的文件任务记录,file_name 来自文件路径, + target_file_location 先复用任务的 result_data_location。 + """ + # 尝试查询已有文件任务记录 + result = await self.db.execute( + select(DataSynthesisFileInstance).where( + DataSynthesisFileInstance.synthesis_instance_id == synthesis_task_id, + DataSynthesisFileInstance.source_file_id == source_file_id, + ) + ) + file_task = result.scalar_one_or_none() + if file_task is not None: + return file_task + + # 查询任务以获取 result_data_location + task = await self.db.get(DataSynthesisInstance, synthesis_task_id) + target_location = task.result_data_location if task else "" + + # 创建新的文件任务记录,初始状态为 processing + file_task = DataSynthesisFileInstance( + id=str(uuid.uuid4()), + synthesis_instance_id=synthesis_task_id, + file_name=Path(file_path).name, + source_file_id=source_file_id, + target_file_location=target_location or "", + status="processing", + total_chunks=0, + processed_chunks=0, + created_by="system", + updated_by="system", + ) + self.db.add(file_task) + await self.db.commit() + await self.db.refresh(file_task) + return file_task + + async def _mark_file_failed(self, synthesis_task_id: str, file_id: str, reason: str | None = None) -> None: + """将指定任务下的单个文件任务标记为失败状态,兜底错误处理。 + + - 如果找到对应的 DataSynthesisFileInstance,则更新其 status="failed"。 + - 如果未找到,则静默返回,仅记录日志。 + - reason 参数仅用于日志记录,方便排查。 + """ + try: + result = await self.db.execute( + select(DataSynthesisFileInstance).where( + DataSynthesisFileInstance.synthesis_instance_id == synthesis_task_id, + DataSynthesisFileInstance.source_file_id == file_id, + ) + ) + file_task = result.scalar_one_or_none() + if not file_task: + logger.warning( + f"Failed to mark file as failed: no DataSynthesisFileInstance found for task={synthesis_task_id}, file_id={file_id}, reason={reason}" + ) + return + + file_task.status = "failed" + await self.db.commit() + await self.db.refresh(file_task) + logger.info( + f"Marked file task as failed for task={synthesis_task_id}, file_id={file_id}, reason={reason}" + ) + except Exception as e: + # 兜底日志,避免异常向外传播影响其它文件处理 + logger.exception( + f"Unexpected error when marking file failed for task={synthesis_task_id}, file_id={file_id}, original_reason={reason}, error={e}" + ) diff --git a/runtime/datamate-python/app/module/generation/service/prompt.py 
b/runtime/datamate-python/app/module/generation/service/prompt.py new file mode 100644 index 0000000..0da60e5 --- /dev/null +++ b/runtime/datamate-python/app/module/generation/service/prompt.py @@ -0,0 +1,73 @@ +from app.module.generation.schema.generation import SynthesisType + +QA_PROMPT=""" +# 角色 +你是一位专业的AI助手,擅长从给定的文本中提取关键信息并创建用于教学和测试的问答对。 + +# 任务 +请根据用户提供的原始文档,生成一系列高质量、多样化的问答对。 + +# 输入文档 +{document} + +# 要求与指令 +1. **问题类型**:生成{synthesis_count - 1}-{synthesis_count + 1}个问答对。问题类型应多样化,包括但不限于: + * **事实性**:基于文本中明确提到的事实。 + * **理解性**:需要理解上下文和概念。 + * **归纳性**:需要总结或归纳多个信息点。 +2. **答案来源**:所有答案必须严格基于提供的文档内容,不得编造原文不存在的信息。 +3. **语言**:请根据输入文档的主要语言进行提问和回答。 +4. **问题质量**:问题应清晰、无歧义,并且是读完文档后自然会产生的问题。 +5. **答案质量**:答案应准确、简洁、完整。 + +# 输出格式 +请严格按照以下JSON格式输出,确保没有额外的解释或标记: +[ +{{"instruction": "问题1","input": "参考内容1","output": "答案1"}}, +{{"instruction": "问题2","input": "参考内容1","output": "答案2"}}, + ... +] +""" + + +COT_PROMPT=""" +# 角色 +你是一位专业的数据合成专家,擅长基于给定的原始文档和 COT(Chain of Thought,思维链)逻辑,生成高质量、符合实际应用场景的 COT 数据。COT 数据需包含清晰的问题、逐步推理过程和最终结论,能完整还原解决问题的思考路径。 + +# 任务 +请根据用户提供的原始文档,生成一系列高质量、多样化的 COT 数据。每个 COT 数据需围绕文档中的关键信息、核心问题或逻辑关联点展开,确保推理过程贴合文档内容,结论准确可靠。 + +# 输入文档 +{document} + +# 要求与指令 +1. **数量要求**:生成 {min\_count}-{max\_count} 条 COT 数据(min\_count={synthesis\_count-1},max\_count={synthesis\_count+1})。 +2. **内容要求**: + * 每条 COT 数据需包含 “问题”“思维链推理”“最终结论” 三部分,逻辑闭环,推理步骤清晰、连贯,不跳跃关键环节。 + * 问题需基于文档中的事实信息、概念关联或逻辑疑问,是读完文档后自然产生的有价值问题(避免无意义或过于简单的问题)。 + * 思维链推理需严格依据文档内容,逐步推导,每一步推理都能对应文档中的具体信息,不编造原文不存在的内容,不主观臆断。 + * 最终结论需简洁、准确,是思维链推理的合理结果,与文档核心信息一致。 +3. **多样化要求**: + * 问题类型多样化,包括但不限于事实查询类、逻辑分析类、原因推导类、方案对比类、结论归纳类。 + * 推理角度多样化,可从不同角色(如项目参与者、需求方、测试人员)或不同维度(如功能实现、进度推进、问题解决)展开推理。 +4. **语言要求**: + * 语言通顺、表达清晰,无歧义,推理过程口语化但不随意,符合正常思考逻辑,最终结论简洁规范。 + * 请根据输入文档的主要语言进行提问和回答。 + +# 输出格式 +请严格按照以下 JSON 格式输出,确保没有额外的解释或标记,每条 COT 数据独立成项: +[ +{{"question": "具体问题","chain_of_thought": "步骤 1:明确问题核心,定位文档中相关信息范围;步骤 2:提取文档中与问题相关的关键信息 1;步骤 3:结合关键信息 1 推导中间结论 1;步骤 4:提取文档中与问题相关的关键信息 2;步骤 5:结合中间结论 1 和关键信息 2 推导中间结论 2;...(逐步推进);步骤 N:汇总所有中间结论,得出最终结论","conclusion": "简洁准确的最终结论"}}, + +{{"question": "具体问题","chain_of_thought": "步骤 1:明确问题核心,定位文档中相关信息范围;步骤 2:提取文档中与问题相关的关键信息 1;步骤 3:结合关键信息 1 推导中间结论 1;步骤 4:提取文档中与问题相关的关键信息 2;步骤 5:结合中间结论 1 和关键信息 2 推导中间结论 2;...(逐步推进);步骤 N:汇总所有中间结论,得出最终结论","conclusion": "简洁准确的最终结论"}}, +... 
+] +""" + +def get_prompt(synth_type: SynthesisType): + if synth_type == SynthesisType.QA: + return QA_PROMPT + elif synth_type == SynthesisType.COT: + return COT_PROMPT + else: + raise ValueError(f"Unsupported synthesis type: {synth_type}") diff --git a/runtime/datamate-python/app/module/system/service/__init__.py b/runtime/datamate-python/app/module/system/service/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/runtime/datamate-python/app/module/system/service/common_service.py b/runtime/datamate-python/app/module/system/service/common_service.py new file mode 100644 index 0000000..22ed6ec --- /dev/null +++ b/runtime/datamate-python/app/module/system/service/common_service.py @@ -0,0 +1,29 @@ +from typing import Optional + +from langchain_core.language_models import BaseChatModel +from langchain_openai import ChatOpenAI +from pydantic import SecretStr +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from app.db.models.model_config import ModelConfig + + +async def get_model_by_id(db: AsyncSession, model_id: str) -> Optional[ModelConfig]: + """根据模型ID获取 ModelConfig 记录。""" + result = await db.execute(select(ModelConfig).where(ModelConfig.id == model_id)) + return result.scalar_one_or_none() + + +def get_chat_client(model: ModelConfig) -> BaseChatModel: + return ChatOpenAI( + model=model.model_name, + base_url=model.base_url, + api_key=SecretStr(model.api_key), + ) + + +def chat(model: BaseChatModel, prompt: str) -> str: + """使用指定模型进行聊天""" + response = model.invoke(prompt) + return response.content diff --git a/runtime/datamate-python/poetry.lock b/runtime/datamate-python/poetry.lock index 599b16b..1e393c0 100644 --- a/runtime/datamate-python/poetry.lock +++ b/runtime/datamate-python/poetry.lock @@ -152,7 +152,6 @@ description = "Lightweight in-process concurrent programming" optional = false python-versions = ">=3.9" groups = ["main"] -markers = "platform_machine == \"aarch64\" or platform_machine == \"ppc64le\" or platform_machine == \"x86_64\" or platform_machine == \"amd64\" or platform_machine == \"AMD64\" or platform_machine == \"win32\" or platform_machine == \"WIN32\"" files = [ {file = "greenlet-3.2.4-cp310-cp310-macosx_11_0_universal2.whl", hash = "sha256:8c68325b0d0acf8d91dde4e6f930967dd52a5302cd4062932a6b2e7c2969f47c"}, {file = "greenlet-3.2.4-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:94385f101946790ae13da500603491f04a76b6e4c059dab271b3ce2e283b2590"}, @@ -353,6 +352,25 @@ files = [ [package.extras] all = ["flake8 (>=7.1.1)", "mypy (>=1.11.2)", "pytest (>=8.3.2)", "ruff (>=0.6.2)"] +[[package]] +name = "loguru" +version = "0.7.2" +description = "Python logging made (stupidly) simple" +optional = false +python-versions = ">=3.5" +groups = ["main"] +files = [ + {file = "loguru-0.7.2-py3-none-any.whl", hash = "sha256:003d71e3d3ed35f0f8984898359d65b79e5b21943f78af86aa5491210429b8eb"}, + {file = "loguru-0.7.2.tar.gz", hash = "sha256:e671a53522515f34fd406340ee968cb9ecafbc4b36c679da03c18fd8d0bd51ac"}, +] + +[package.dependencies] +colorama = {version = ">=0.3.4", markers = "sys_platform == \"win32\""} +win32-setctime = {version = ">=1.0.0", markers = "sys_platform == \"win32\""} + +[package.extras] +dev = ["Sphinx (==7.2.5) ; python_version >= \"3.9\"", "colorama (==0.4.5) ; python_version < \"3.8\"", "colorama (==0.4.6) ; python_version >= \"3.8\"", "exceptiongroup (==1.1.3) ; python_version >= \"3.7\" and python_version < \"3.11\"", "freezegun (==1.1.0) ; python_version < \"3.8\"", "freezegun 
(==1.2.2) ; python_version >= \"3.8\"", "mypy (==v0.910) ; python_version < \"3.6\"", "mypy (==v0.971) ; python_version == \"3.6\"", "mypy (==v1.4.1) ; python_version == \"3.7\"", "mypy (==v1.5.1) ; python_version >= \"3.8\"", "pre-commit (==3.4.0) ; python_version >= \"3.8\"", "pytest (==6.1.2) ; python_version < \"3.8\"", "pytest (==7.4.0) ; python_version >= \"3.8\"", "pytest-cov (==2.12.1) ; python_version < \"3.8\"", "pytest-cov (==4.1.0) ; python_version >= \"3.8\"", "pytest-mypy-plugins (==1.9.3) ; python_version >= \"3.6\" and python_version < \"3.8\"", "pytest-mypy-plugins (==3.0.0) ; python_version >= \"3.8\"", "sphinx-autobuild (==2021.3.14) ; python_version >= \"3.9\"", "sphinx-rtd-theme (==1.3.0) ; python_version >= \"3.9\"", "tox (==3.27.1) ; python_version < \"3.8\"", "tox (==4.11.0) ; python_version >= \"3.8\""] + [[package]] name = "pydantic" version = "2.12.4" @@ -1132,7 +1150,23 @@ files = [ {file = "websockets-15.0.1.tar.gz", hash = "sha256:82544de02076bafba038ce055ee6412d68da13ab47f0c60cab827346de828dee"}, ] +[[package]] +name = "win32-setctime" +version = "1.2.0" +description = "A small Python utility to set file creation time on Windows" +optional = false +python-versions = ">=3.5" +groups = ["main"] +markers = "sys_platform == \"win32\"" +files = [ + {file = "win32_setctime-1.2.0-py3-none-any.whl", hash = "sha256:95d644c4e708aba81dc3704a116d8cbc974d70b3bdb8be1d150e36be6e9d1390"}, + {file = "win32_setctime-1.2.0.tar.gz", hash = "sha256:ae1fdf948f5640aae05c511ade119313fb6a30d7eabe25fef9764dca5873c4c0"}, +] + +[package.extras] +dev = ["black (>=19.3b0) ; python_version >= \"3.6\"", "pytest (>=4.6.2)"] + [metadata] lock-version = "2.1" python-versions = ">=3.12" -content-hash = "36f9e7212af4fa5832884a2d39a2e7dfbf668e79949d7e90fc12a9e8c96195a7" +content-hash = "a47a488ea25f1fa4db5439b36e3f00e797788c48ff2b9623e9bae72a61154df8" diff --git a/runtime/datamate-python/pyproject.toml b/runtime/datamate-python/pyproject.toml index 296b50a..01c38b7 100644 --- a/runtime/datamate-python/pyproject.toml +++ b/runtime/datamate-python/pyproject.toml @@ -21,7 +21,13 @@ dependencies = [ "python-multipart (>=0.0.20,<0.0.21)", "python-dotenv (>=1.2.1,<2.0.0)", "python-dateutil (>=2.9.0.post0,<3.0.0)", - "pyyaml (>=6.0.3,<7.0.0)" + "pyyaml (>=6.0.3,<7.0.0)", + "greenlet (>=3.2.4,<4.0.0)", + "loguru (>=0.7.2,<0.7.3)", + "langchain (>=1.0.0)", + "langchain-community (>0.4,<0.4.1)", + "unstructured[all]", + "markdown" ] diff --git a/scripts/db/data-synthesis-init.sql b/scripts/db/data-synthesis-init.sql new file mode 100644 index 0000000..0909d6b --- /dev/null +++ b/scripts/db/data-synthesis-init.sql @@ -0,0 +1,64 @@ +USE datamate; + +-- =============================== +-- t_data_synthesis_instances (数据合成任务表) +create table if not exists t_data_synthesis_instances +( + id VARCHAR(36) CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci PRIMARY KEY COMMENT 'UUID', + name VARCHAR(255) NOT NULL COMMENT '任务名称', + description TEXT COMMENT '任务描述', + status VARCHAR(20) COMMENT '任务状态', + synthesis_type VARCHAR(20) NOT NULL COMMENT '合成类型', + model_id VARCHAR(255) NOT NULL COMMENT '模型ID', + progress INT DEFAULT 0 COMMENT '任务进度(百分比)', + result_data_location VARCHAR(1000) COMMENT '结果数据存储位置', + text_split_config JSON NOT NULL COMMENT '文本切片配置', + synthesis_config JSON NOT NULL COMMENT '合成配置', + source_file_id JSON NOT NULL COMMENT '原始文件ID列表', + total_files INT DEFAULT 0 COMMENT '总文件数', + processed_files INT DEFAULT 0 COMMENT '已处理文件数', + total_chunks INT DEFAULT 0 COMMENT '总文本块数', + processed_chunks INT DEFAULT 0 
COMMENT '已处理文本块数', + total_synthesis_data INT DEFAULT 0 COMMENT '总合成数据量', + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP COMMENT '创建时间', + updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT '更新时间', + created_by VARCHAR(255) COMMENT '创建者', + updated_by VARCHAR(255) COMMENT '更新者' +) COMMENT='数据合成任务表(UUID 主键)'; + +-- =============================== +-- t_data_synthesis_file_instances (数据合成文件任务表) +create table if not exists t_data_synthesis_file_instances +( + id VARCHAR(36) CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci PRIMARY KEY COMMENT 'UUID', + synthesis_instance_id VARCHAR(36) CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci COMMENT '数据合成任务ID', + file_name VARCHAR(255) NOT NULL COMMENT '文件名', + source_file_id VARCHAR(255) NOT NULL COMMENT '原始文件ID', + target_file_location VARCHAR(1000) NOT NULL COMMENT '目标文件存储位置', + status VARCHAR(20) COMMENT '任务状态', + total_chunks INT DEFAULT 0 COMMENT '总文本块数', + processed_chunks INT DEFAULT 0 COMMENT '已处理文本块数', + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP COMMENT '创建时间', + updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT '更新时间', + created_by VARCHAR(255) COMMENT '创建者', + updated_by VARCHAR(255) COMMENT '更新者' +) COMMENT='数据合成文件任务表(UUID 主键)'; + + +create table if not exists t_data_synthesis_chunk_instances +( + id VARCHAR(36) CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci PRIMARY KEY COMMENT 'UUID', + synthesis_file_instance_id VARCHAR(36) CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci COMMENT '数据合成文件任务ID', + chunk_index INT COMMENT '分块索引', + chunk_content TEXT COMMENT '分块内容', + metadata JSON COMMENT '分块元数据' +) COMMENT='数据合成分块任务表(UUID 主键)'; + + +create table if not exists t_data_synthesis_data +( + id VARCHAR(36) CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci PRIMARY KEY COMMENT 'UUID', + data json COMMENT '合成的数据', + synthesis_file_instance_id VARCHAR(36) CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci COMMENT '数据合成文件任务ID', + chunk_instance_id VARCHAR(36) CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci COMMENT '分块任务ID' +) COMMENT='数据合成任务队列表(UUID 主键)'; \ No newline at end of file
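
For quick manual verification of the endpoints added above, the following is a minimal client sketch, not part of the change itself. It assumes the service is running locally on the default `host`/`port` from `Settings` (`http://localhost:18000`), that the routers are mounted under `/api/synth/gen` as wired in `app/module/__init__.py`, and that `<model-config-uuid>` and `<dataset-file-uuid>` are placeholders for rows that already exist in `t_model_config` and `t_dm_dataset_files`. Only the standard library is used so no extra dependency is implied.

```python
import json
import urllib.request

BASE = "http://localhost:18000/api/synth/gen"  # /api + /synth + /gen router prefixes


def call(url: str, payload: dict | None = None) -> dict:
    """POST JSON when a payload is given, otherwise GET; return the parsed StandardResponse body."""
    data = json.dumps(payload).encode("utf-8") if payload is not None else None
    req = urllib.request.Request(
        url,
        data=data,
        headers={"Content-Type": "application/json"},
        method="POST" if payload is not None else "GET",
    )
    with urllib.request.urlopen(req) as resp:
        return json.loads(resp.read())


# 1. Fetch the built-in QA prompt template (GET /prompt)
qa_template = call(f"{BASE}/prompt?synth_type=QA")["data"]

# 2. Create a synthesis task (POST /task); IDs below are placeholders
task = call(
    f"{BASE}/task",
    {
        "name": "demo-qa-task",
        "description": "QA synthesis smoke test",
        "model_id": "<model-config-uuid>",          # row in t_model_config
        "source_file_id": ["<dataset-file-uuid>"],  # rows in t_dm_dataset_files
        "synthesis_type": "QA",
        "text_split_config": {"chunk_size": 800, "chunk_overlap": 50},
        "synthesis_config": {"prompt_template": qa_template, "synthesis_count": 5},
    },
)["data"]
print(task["id"], task["status"])  # status starts as "pending"; files are processed in the background

# 3. Poll the task detail (GET /task/{task_id}) to watch progress
detail = call(f"{BASE}/task/{task['id']}")["data"]
print(detail["processed_files"], "/", detail["total_files"])
```

The response envelope (`code` / `message` / `data`) mirrors the `StandardResponse` wrapper used by the new endpoints; field names in the payload follow `CreateSynthesisTaskRequest`.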