feat: optimize the question generation process and COT data generation (#169)
* fix(chart): update Helm chart helpers and values for improved configuration
* feat(SynthesisTaskTab): enhance task table with tooltip support and improved column widths
* feat(CreateTask, SynthFileTask): improve task creation and detail view with enhanced payload handling and UI updates
* feat(SynthFileTask): enhance file display with progress tracking and delete action
* feat(SynthDataDetail): add delete action for chunks with confirmation prompt
* feat(SynthDataDetail): update edit and delete buttons to icon-only format
* feat(SynthDataDetail): add confirmation modals for chunk and synthesis data deletion
* feat(DocumentSplitter): add enhanced document splitting functionality with CJK support and metadata detection (see the sketch after this list)
* feat(DataSynthesis): refactor data synthesis models and update task handling logic
* feat(DataSynthesis): streamline synthesis task handling and enhance chunk processing logic
* fix(generation_service): ensure processed chunks are incremented regardless of question generation success
* feat(CreateTask): enhance task creation with new synthesis templates and improved configuration options
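The DocumentSplitter work is only named in the commit message above, so here is a rough sketch of what CJK-aware sentence splitting typically involves; the real DocumentSplitter API is not part of this diff, and the regex and function below are illustrative only:

```python
import re

# Split after full-width CJK terminators (no following space required),
# or on whitespace after ASCII terminators. Hypothetical helper, not the
# actual DocumentSplitter from this PR.
_SENTENCE_BOUNDARY = re.compile(r"(?<=[。!?;])|(?<=[.!?])\s+")

def split_sentences(text: str) -> list[str]:
    return [s.strip() for s in _SENTENCE_BOUNDARY.split(text) if s.strip()]

print(split_sentences("这是第一句。这是第二句!And an English one. Another."))
# ['这是第一句。', '这是第二句!', 'And an English one.', 'Another.']
```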
@@ -1,4 +1,5 @@
 import uuid
+from typing import cast
 
 from fastapi import APIRouter, HTTPException, Depends, BackgroundTasks
 from sqlalchemy import select, func, delete
@@ -7,13 +8,12 @@ from sqlalchemy.ext.asyncio import AsyncSession
 from app.core.logging import get_logger
 from app.db.models.data_synthesis import (
     save_synthesis_task,
-    DataSynthesisInstance,
+    DataSynthInstance,
     DataSynthesisFileInstance,
     DataSynthesisChunkInstance,
     SynthesisData,
 )
 from app.db.models.dataset_management import DatasetFiles
 from app.db.models.model_config import get_model_by_id
 from app.db.session import get_db
 from app.module.generation.schema.generation import (
     CreateSynthesisTaskRequest,
@@ -28,9 +28,9 @@ from app.module.generation.schema.generation import (
     SynthesisDataUpdateRequest,
     BatchDeleteSynthesisDataRequest,
 )
+from app.module.generation.service.export_service import SynthesisDatasetExporter, SynthesisExportError
 from app.module.generation.service.generation_service import GenerationService
 from app.module.generation.service.prompt import get_prompt
-from app.module.generation.service.export_service import SynthesisDatasetExporter, SynthesisExportError
 from app.module.shared.schema import StandardResponse
 
 router = APIRouter(
@@ -47,10 +47,6 @@ async def create_synthesis_task(
     db: AsyncSession = Depends(get_db),
 ):
     """Create a data synthesis task"""
     result = await get_model_by_id(db, request.model_id)
     if not result:
         raise HTTPException(status_code=404, detail="Model not found")
 
     # Look up the existing file records in DatasetFiles by source_file_id first
     file_ids = request.source_file_id or []
     dataset_files = []
@@ -65,32 +61,48 @@ async def create_synthesis_task(
     synthesis_task = await save_synthesis_task(db, request)
 
     # Save the existing DatasetFiles records into t_data_synthesis_file_instances
     synth_files = []
     for f in dataset_files:
         file_instance = DataSynthesisFileInstance(
             id=str(uuid.uuid4()),  # use a fresh UUID as the file-task record's primary key to avoid clashing with the DatasetFiles primary key
             synthesis_instance_id=synthesis_task.id,
             file_name=f.file_name,
             source_file_id=str(f.id),
             target_file_location=synthesis_task.result_data_location or "",
             status="pending",
             total_chunks=0,
             processed_chunks=0,
             created_by="system",
             updated_by="system",
         )
         db.add(file_instance)
         synth_files.append(file_instance)
 
     if dataset_files:
         db.add_all(synth_files)
     await db.commit()
 
     generation_service = GenerationService(db)
     # Process the task asynchronously: pass only the task ID; the background task reloads the task object with a fresh DB session
     background_tasks.add_task(generation_service.process_task, synthesis_task.id)
 
+    # Wrap the ORM object in a DataSynthesisTaskItem; new fields are restored from synth_config for compatibility
+
+    task_item = DataSynthesisTaskItem(
+        id=synthesis_task.id,
+        name=synthesis_task.name,
+        description=synthesis_task.description,
+        status=synthesis_task.status,
+        synthesis_type=synthesis_task.synth_type,
+        total_files=synthesis_task.total_files,
+        created_at=synthesis_task.created_at,
+        updated_at=synthesis_task.updated_at,
+        created_by=synthesis_task.created_by,
+        updated_by=synthesis_task.updated_by,
+    )
+
     return StandardResponse(
         code=200,
         message="success",
-        data=synthesis_task,
+        data=task_item,
     )
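Passing only `synthesis_task.id` to `background_tasks.add_task` matters because the request-scoped `AsyncSession` is closed once the response has been sent. A minimal sketch of the reload-in-a-fresh-session pattern the comment describes; `async_session_factory` and the engine URL below are hypothetical stand-ins for whatever `app.db.session` actually provides:

```python
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine

from app.db.models.data_synthesis import DataSynthInstance

# Hypothetical engine/session setup; the real factory lives in app.db.session.
engine = create_async_engine("sqlite+aiosqlite:///:memory:")
async_session_factory = async_sessionmaker(engine, expire_on_commit=False)

async def process_task(task_id: str) -> None:
    # The request-scoped session is already closed by the time
    # BackgroundTasks runs this coroutine, so open a fresh one.
    async with async_session_factory() as db:
        task = await db.get(DataSynthInstance, task_id)
        if task is None:
            return  # task was deleted before the worker picked it up
        ...  # chunk splitting and question generation happen here
```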
@@ -100,14 +112,26 @@ async def get_synthesis_task(
     db: AsyncSession = Depends(get_db)
 ):
     """Get the details of a data synthesis task"""
-    result = await db.get(DataSynthesisInstance, task_id)
-    if not result:
+    synthesis_task = await db.get(DataSynthInstance, task_id)
+    if not synthesis_task:
         raise HTTPException(status_code=404, detail="Synthesis task not found")
 
+    task_item = DataSynthesisTaskItem(
+        id=synthesis_task.id,
+        name=synthesis_task.name,
+        description=synthesis_task.description,
+        status=synthesis_task.status,
+        synthesis_type=synthesis_task.synth_type,
+        total_files=synthesis_task.total_files,
+        created_at=synthesis_task.created_at,
+        updated_at=synthesis_task.updated_at,
+        created_by=synthesis_task.created_by,
+        updated_by=synthesis_task.updated_by,
+    )
     return StandardResponse(
         code=200,
         message="success",
-        data=result,
+        data=task_item,
     )
@@ -121,16 +145,16 @@ async def list_synthesis_tasks(
     db: AsyncSession = Depends(get_db)
 ):
     """List all data synthesis tasks with pagination, sorted by creation time descending by default"""
-    query = select(DataSynthesisInstance)
+    query = select(DataSynthInstance)
     if synthesis_type:
-        query = query.filter(DataSynthesisInstance.synthesis_type == synthesis_type)
+        query = query.filter(DataSynthInstance.synth_type == synthesis_type)
     if status:
-        query = query.filter(DataSynthesisInstance.status == status)
+        query = query.filter(DataSynthInstance.status == status)
     if name:
-        query = query.filter(DataSynthesisInstance.name.like(f"%{name}%"))
+        query = query.filter(DataSynthInstance.name.like(f"%{name}%"))
 
     # Sort by creation time descending by default
-    query = query.order_by(DataSynthesisInstance.created_at.desc())
+    query = query.order_by(DataSynthInstance.created_at.desc())
 
     count_q = select(func.count()).select_from(query.subquery())
     total = (await db.execute(count_q)).scalar_one()
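The filtering above chains optional `.filter(...)` calls onto a single `select()`, then derives the total from a subquery before applying the page window. A compact sketch of that count-then-page pattern; the helper name is illustrative, not part of this PR:

```python
from sqlalchemy import Select, func, select

def paged_queries(base: Select, page: int, page_size: int) -> tuple[Select, Select]:
    # Count the filtered rows via a subquery, then window the same query.
    count_q = select(func.count()).select_from(base.subquery())
    page_q = base.offset((page - 1) * page_size).limit(page_size)
    return count_q, page_q
```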
@@ -143,31 +167,39 @@ async def list_synthesis_tasks(
     result = await db.execute(query.offset((page - 1) * page_size).limit(page_size))
     rows = result.scalars().all()
 
-    task_items = [
-        DataSynthesisTaskItem(
-            id=row.id,
-            name=row.name,
-            description=row.description,
-            status=row.status,
-            synthesis_type=row.synthesis_type,
-            model_id=row.model_id,
-            progress=row.progress,
-            result_data_location=row.result_data_location,
-            text_split_config=row.text_split_config,
-            synthesis_config=row.synthesis_config,
-            source_file_id=row.source_file_id,
-            total_files=row.total_files,
-            processed_files=row.processed_files,
-            total_chunks=row.total_chunks,
-            processed_chunks=row.processed_chunks,
-            total_synthesis_data=row.total_synthesis_data,
-            created_at=row.created_at,
-            updated_at=row.updated_at,
-            created_by=row.created_by,
-            updated_by=row.updated_by,
+    task_items: list[DataSynthesisTaskItem] = []
+    for row in rows:
+        synth_cfg = getattr(row, "synth_config", {}) or {}
+        text_split_cfg = synth_cfg.get("text_split_config") or {}
+        synthesis_cfg = synth_cfg.get("synthesis_config") or {}
+        source_file_ids = synth_cfg.get("source_file_id") or []
+        model_id = synth_cfg.get("model_id")
+        result_location = synth_cfg.get("result_data_location")
+
+        task_items.append(
+            DataSynthesisTaskItem(
+                id=str(row.id),
+                name=str(row.name),
+                description=cast(str | None, row.description),
+                status=cast(str | None, row.status),
+                synthesis_type=str(row.synth_type),
+                model_id=model_id or "",
+                progress=int(cast(int, row.progress)),
+                result_data_location=result_location,
+                text_split_config=text_split_cfg,
+                synthesis_config=synthesis_cfg,
+                source_file_id=list(source_file_ids),
+                total_files=int(cast(int, row.total_files)),
+                processed_files=int(cast(int, row.processed_files)),
+                total_chunks=int(cast(int, row.total_chunks)),
+                processed_chunks=int(cast(int, row.processed_chunks)),
+                total_synthesis_data=int(cast(int, row.total_synth_data)),
+                created_at=row.created_at,
+                updated_at=row.updated_at,
+                created_by=row.created_by,
+                updated_by=row.updated_by,
+            )
         )
-        for row in rows
-    ]
 
     paged = PagedDataSynthesisTaskResponse(
         content=task_items,
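The loop above rebuilds fields that used to be flat columns (`text_split_config`, `synthesis_config`, `source_file_id`, `model_id`, `result_data_location`) from the consolidated `synth_config` JSON column. The `or {}` / `or []` fallbacks matter because a JSON column can hold `null` at any level. The same defensive pattern in isolation, as a standalone sketch rather than code from this PR:

```python
from typing import Any

def restore_synth_config(synth_config: dict[str, Any] | None) -> dict[str, Any]:
    """Unpack a consolidated JSON config column, tolerating missing keys
    and explicit nulls at every level (mirrors list_synthesis_tasks above)."""
    cfg = synth_config or {}
    return {
        "text_split_config": cfg.get("text_split_config") or {},
        "synthesis_config": cfg.get("synthesis_config") or {},
        "source_file_id": list(cfg.get("source_file_id") or []),
        "model_id": cfg.get("model_id") or "",
        "result_data_location": cfg.get("result_data_location"),
    }

# Both a missing column and stored nulls fall back to safe defaults:
assert restore_synth_config(None)["source_file_id"] == []
assert restore_synth_config({"model_id": None})["model_id"] == ""
```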
@@ -190,7 +222,7 @@ async def delete_synthesis_task(
     db: AsyncSession = Depends(get_db)
 ):
     """Delete a data synthesis task"""
-    task = await db.get(DataSynthesisInstance, task_id)
+    task = await db.get(DataSynthInstance, task_id)
     if not task:
         raise HTTPException(status_code=404, detail="Synthesis task not found")
 
@@ -241,7 +273,7 @@ async def delete_synthesis_file_task(
 ):
     """Delete a file task from a data synthesis task and refresh the file/chunk counts on the task row"""
     # Fetch the task and the file-task record first
-    task = await db.get(DataSynthesisInstance, task_id)
+    task = await db.get(DataSynthInstance, task_id)
     if not task:
         raise HTTPException(status_code=404, detail="Synthesis task not found")
 
@@ -306,7 +338,7 @@ async def list_synthesis_file_tasks(
 ):
     """Get the paginated list of file tasks under a data synthesis task"""
     # Validate that the task exists first
-    task = await db.get(DataSynthesisInstance, task_id)
+    task = await db.get(DataSynthInstance, task_id)
     if not task:
         raise HTTPException(status_code=404, detail="Synthesis task not found")
 
@@ -333,7 +365,6 @@ async def list_synthesis_file_tasks(
             synthesis_instance_id=row.synthesis_instance_id,
             file_name=row.file_name,
             source_file_id=row.source_file_id,
             target_file_location=row.target_file_location,
             status=row.status,
             total_chunks=row.total_chunks,
             processed_chunks=row.processed_chunks,
@@ -523,7 +554,7 @@ async def delete_synthesis_data_by_chunk(
     result = await db.execute(
         delete(SynthesisData).where(SynthesisData.chunk_instance_id == chunk_id)
     )
-    deleted = result.rowcount or 0
+    deleted = int(getattr(result, "rowcount", 0) or 0)
 
     await db.commit()
 
@@ -542,7 +573,7 @@ async def batch_delete_synthesis_data(
     result = await db.execute(
         delete(SynthesisData).where(SynthesisData.id.in_(request.ids))
     )
-    deleted = result.rowcount or 0
+    deleted = int(getattr(result, "rowcount", 0) or 0)
     await db.commit()
 
     return StandardResponse(code=200, message="success", data=deleted)
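Both delete endpoints now compute the deleted count the same way. `getattr(result, "rowcount", 0)` tolerates result types that don't expose `rowcount`, and the trailing `or 0` maps `None` to zero before the `int()` conversion. A small self-contained illustration of the guard; the fake result class is purely for demonstration:

```python
class _FakeResult:
    """Stand-in for a SQLAlchemy result object, for illustration only."""
    def __init__(self, **attrs):
        self.__dict__.update(attrs)

def deleted_count(result) -> int:
    # Mirrors the endpoints above: tolerate both a missing rowcount
    # attribute and an explicit None before converting to int.
    return int(getattr(result, "rowcount", 0) or 0)

assert deleted_count(_FakeResult(rowcount=3)) == 3
assert deleted_count(_FakeResult(rowcount=None)) == 0
assert deleted_count(_FakeResult()) == 0
```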