You've already forked DataMate
refactor: modify data collection to python implementation (#214)
* feature: LabelStudio jumps without login * refactor: modify data collection to python implementation * refactor: modify data collection to python implementation * refactor: modify data collection to python implementation * refactor: modify data collection to python implementation * refactor: modify data collection to python implementation * refactor: modify data collection to python implementation * fix: remove terrabase dependency * feature: add the collection task executions page and the collection template page * fix: fix the collection task creation * fix: fix the collection task creation
This commit is contained in:
66
runtime/datamate-python/app/db/models/data_collection.py
Normal file
66
runtime/datamate-python/app/db/models/data_collection.py
Normal file
@@ -0,0 +1,66 @@
|
||||
import uuid
|
||||
from sqlalchemy import Column, String, Text, TIMESTAMP, Integer, BigInteger, Numeric, JSON, Boolean
|
||||
from sqlalchemy.sql import func
|
||||
|
||||
from app.db.session import Base
|
||||
|
||||
class CollectionTemplate(Base):
    """Collection template table (UUID primary key) -> t_dc_collection_templates."""

    __tablename__ = "t_dc_collection_templates"

    # UUID generated client-side so the id is known before the INSERT.
    id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()), comment="模板ID(UUID)")

    # Descriptive fields.
    name = Column(String(255), nullable=False, comment="模板名称")
    description = Column(Text, nullable=True, comment="模板描述")

    # Source / target datasource identification.
    source_type = Column(String(64), nullable=False, comment="源数据源类型")
    source_name = Column(String(64), nullable=False, comment="源数据源名称")
    target_type = Column(String(64), nullable=False, comment="目标数据源类型")
    target_name = Column(String(64), nullable=False, comment="目标数据源名称")

    # JSON body of the template; built_in marks system-provided templates.
    template_content = Column(JSON, nullable=False, comment="模板内容")
    built_in = Column(Boolean, default=False, comment="是否系统内置模板")

    # Audit columns; timestamps are maintained by the database server.
    created_at = Column(TIMESTAMP, server_default=func.current_timestamp(), comment="创建时间")
    updated_at = Column(TIMESTAMP, server_default=func.current_timestamp(), onupdate=func.current_timestamp(), comment="更新时间")
    created_by = Column(String(255), nullable=True, comment="创建者")
    updated_by = Column(String(255), nullable=True, comment="更新者")
|
||||
|
||||
class CollectionTask(Base):
    """Collection task table (UUID primary key) -> t_dc_collection_tasks."""

    __tablename__ = "t_dc_collection_tasks"

    # UUID generated client-side so the id is known before the INSERT.
    id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()), comment="UUID")

    # Descriptive fields.
    name = Column(String(255), nullable=False, comment="任务名称")
    description = Column(Text, nullable=True, comment="任务描述")

    # Scheduling: ONCE runs immediately, SCHEDULED uses schedule_expression.
    sync_mode = Column(String(20), nullable=False, server_default="ONCE", comment="同步模式:ONCE/SCHEDULED")

    # Template linkage (name denormalized for display).
    template_id = Column(String(36), nullable=False, comment="归集模板ID")
    template_name = Column(String(255), nullable=False, comment="归集模板名称")

    # Execution configuration.
    target_path = Column(String(1000), nullable=True, server_default="", comment="目标存储路径")
    config = Column(JSON, nullable=False, comment="归集配置(DataX配置),包含源端和目标端配置信息")
    schedule_expression = Column(String(255), nullable=True, comment="Cron调度表达式")
    status = Column(String(20), nullable=True, server_default="DRAFT", comment="任务状态:DRAFT/READY/RUNNING/SUCCESS/FAILED/STOPPED")
    retry_count = Column(Integer, nullable=True, server_default="3", comment="重试次数")
    timeout_seconds = Column(Integer, nullable=True, server_default="3600", comment="超时时间(秒)")
    last_execution_id = Column(String(36), nullable=True, comment="最后执行ID(UUID)")

    # Audit columns; timestamps are maintained by the database server.
    created_at = Column(TIMESTAMP, server_default=func.current_timestamp(), comment="创建时间")
    updated_at = Column(TIMESTAMP, server_default=func.current_timestamp(), onupdate=func.current_timestamp(), comment="更新时间")
    created_by = Column(String(255), nullable=True, comment="创建者")
    updated_by = Column(String(255), nullable=True, comment="更新者")
||||
class TaskExecution(Base):
    """Task execution record table (UUID primary key) -> t_dc_task_executions."""

    __tablename__ = "t_dc_task_executions"

    # UUID generated client-side so the id is known before the INSERT.
    id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()), comment="执行记录ID(UUID)")

    # Parent task (name denormalized for display).
    task_id = Column(String(36), nullable=False, comment="任务ID")
    task_name = Column(String(255), nullable=False, comment="任务名称")

    # Run outcome and artifacts.
    status = Column(String(20), nullable=True, server_default="RUNNING", comment="执行状态:RUNNING/SUCCESS/FAILED/STOPPED")
    log_path = Column(String(1000), nullable=True, server_default="", comment="日志文件路径")
    started_at = Column(TIMESTAMP, nullable=True, comment="开始时间")
    completed_at = Column(TIMESTAMP, nullable=True, comment="完成时间")
    duration_seconds = Column(Integer, nullable=True, server_default="0", comment="执行时长(秒)")
    error_message = Column(Text, nullable=True, comment="错误信息")

    # Audit columns; timestamps are maintained by the database server.
    created_at = Column(TIMESTAMP, server_default=func.current_timestamp(), comment="创建时间")
    updated_at = Column(TIMESTAMP, server_default=func.current_timestamp(), onupdate=func.current_timestamp(), comment="更新时间")
    created_by = Column(String(255), nullable=True, comment="创建者")
    updated_by = Column(String(255), nullable=True, comment="更新者")
|
||||
@@ -5,6 +5,7 @@ from .annotation.interface import router as annotation_router
|
||||
from .ratio.interface import router as ratio_router
|
||||
from .generation.interface import router as generation_router
|
||||
from .evaluation.interface import router as evaluation_router
|
||||
from .collection.interface import router as collection_route
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/api"
|
||||
@@ -15,5 +16,6 @@ router.include_router(annotation_router)
|
||||
router.include_router(ratio_router)
|
||||
router.include_router(generation_router)
|
||||
router.include_router(evaluation_router)
|
||||
router.include_router(collection_route)
|
||||
|
||||
__all__ = ["router"]
|
||||
|
||||
@@ -0,0 +1,200 @@
|
||||
import json
|
||||
import threading
|
||||
import subprocess
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Dict, Any
|
||||
|
||||
from app.core.logging import get_logger
|
||||
from app.db.models.data_collection import CollectionTask, TaskExecution, CollectionTemplate
|
||||
from app.module.collection.schema.collection import CollectionConfig, SyncMode
|
||||
from app.module.shared.schema import TaskStatus
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
class DataxClient:
    """Run a DataX job for one task execution via a ``datax.py`` subprocess.

    Responsibilities: write the generated job config file, spawn the DataX
    process, stream its output into the execution's log file, and record
    status / timing on the ``TaskExecution`` record.
    """

    def __init__(self, task: CollectionTask, execution: TaskExecution):
        self.execution = execution
        self.task = task
        # Per-task working directory; created eagerly so both the config
        # file and the execution log can be written into it.
        self.config_file_path = f"/flow/data-collection/{task.id}/config.json"
        self.python_path = "python"
        self.datax_main = "/opt/datax/bin/datax.py"
        Path(self.config_file_path).parent.mkdir(parents=True, exist_ok=True)

    def validate_json_string(self) -> Dict[str, Any]:
        """Parse and minimally validate the task's JSON config string.

        Returns:
            The parsed configuration dict.

        Raises:
            ValueError: if the config is not valid JSON or lacks the
                required ``job`` / ``job.content`` structure.
        """
        try:
            config = json.loads(self.task.config)

            # Structural sanity checks required by the DataX engine.
            if 'job' not in config:
                raise ValueError("JSON 必须包含 'job' 字段")

            if 'content' not in config.get('job', {}):
                raise ValueError("job 必须包含 'content' 字段")

            logger.info("JSON 配置验证通过")
            return config

        except json.JSONDecodeError as e:
            raise ValueError(f"JSON 格式错误: {e}") from e
        except Exception as e:
            raise ValueError(f"配置验证失败: {e}") from e

    @staticmethod
    def generate_datx_config(task_config: CollectionConfig, template: CollectionTemplate, target_path: str):
        """Build the DataX ``job`` section in-place on *task_config*.

        NOTE(review): method name keeps the original ``datx`` spelling for
        caller compatibility.
        """
        # Reader/writer parameters: shared template parameters overlaid with
        # the side-specific overrides; the writer additionally receives the
        # destination path.
        reader_parameter = {
            **(task_config.parameter if task_config.parameter else {}),
            **(task_config.reader if task_config.reader else {})
        }
        writer_parameter = {
            **(task_config.parameter if task_config.parameter else {}),
            **(task_config.writer if task_config.writer else {}),
            "destPath": target_path
        }
        # Assemble the runnable DataX job configuration.
        job_config = {
            "content": [
                {
                    "reader": {
                        "name": template.source_type,
                        "parameter": reader_parameter
                    },
                    "writer": {
                        "name": template.target_type,
                        "parameter": writer_parameter
                    }
                }
            ],
            "setting": {
                "speed": {
                    "channel": 2
                }
            }
        }
        task_config.job = job_config

    def create__config_file(self) -> str:
        """Validate the task config and write it to the config file.

        NOTE(review): double-underscore name kept for caller compatibility.

        Returns:
            The path of the written config file.
        """
        # Validate (and parse) first so we never write a broken file.
        config = self.validate_json_string()

        with open(self.config_file_path, 'w', encoding='utf-8') as f:
            json.dump(config, f, indent=2, ensure_ascii=False)

        logger.debug(f"创建配置文件: {self.config_file_path}")
        return self.config_file_path

    def run_datax_job(self):
        """Execute the DataX job and record the outcome on the execution.

        Updates ``status``, ``completed_at``, ``duration_seconds`` and
        ``error_message`` on ``self.execution``; for ONCE tasks the final
        status is mirrored onto the task itself.
        """
        # Write the job config before launching anything.
        self.create__config_file()
        try:
            cmd = [self.python_path, str(self.datax_main), str(self.config_file_path)]
            cmd_str = ' '.join(cmd)
            logger.info(f"执行命令: {cmd_str}")
            if not self.execution.started_at:
                self.execution.started_at = datetime.now()
            # Run the process with all output captured into the log file.
            with open(self.execution.log_path, 'w', encoding='utf-8') as log_f:
                self.write_header_log(cmd_str, log_f)
                exit_code = self._run_process(cmd, log_f)
                self.execution.completed_at = datetime.now()
                self.execution.duration_seconds = (self.execution.completed_at - self.execution.started_at).total_seconds()
                self.write_tail_log(exit_code, log_f)
            if exit_code == 0:
                logger.info(f"DataX 任务执行成功: {self.execution.id}")
                logger.info(f"执行耗时: {self.execution.duration_seconds:.2f} 秒")
                self.execution.status = TaskStatus.COMPLETED.name
            else:
                # Keep a more specific message (e.g. the timeout note) if
                # one was already recorded by _run_process.
                self.execution.error_message = self.execution.error_message or f"DataX 任务执行失败,退出码: {exit_code}"
                self.execution.status = TaskStatus.FAILED.name
                logger.error(self.execution.error_message)
        except Exception as e:
            self.execution.completed_at = datetime.now()
            # started_at is normally set above, but guard against a failure
            # that happened before it was assigned.
            if self.execution.started_at:
                self.execution.duration_seconds = (self.execution.completed_at - self.execution.started_at).total_seconds()
            self.execution.error_message = f"执行异常: {e}"
            self.execution.status = TaskStatus.FAILED.name
            logger.error(f"执行异常: {e}", exc_info=True)
        if self.task.sync_mode == SyncMode.ONCE:
            self.task.status = self.execution.status

    def _run_process(self, cmd: list[str], log_f) -> int:
        """Spawn the DataX process, stream its output to *log_f*, wait.

        Returns the process exit code, or -1 on timeout (the process is
        killed and an error message is recorded on the execution).
        """
        # Merge stderr into stdout so a single reader thread owns log_f;
        # the previous two-thread design wrote to the same file object
        # concurrently, which can interleave partial lines.
        process = subprocess.Popen(
            cmd,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            text=True,
            encoding='utf-8',
            bufsize=1,
        )

        reader_thread = threading.Thread(target=lambda: self.read_stream(process.stdout, log_f))
        reader_thread.start()

        try:
            exit_code = process.wait(timeout=self.task.timeout_seconds)
        except subprocess.TimeoutExpired:
            process.kill()
            exit_code = -1
            self.execution.error_message = f"任务执行超时({self.task.timeout_seconds}秒)"
            logger.error(f"任务执行超时({self.task.timeout_seconds}秒)")

        # Give the reader a bounded time to drain the remaining output.
        reader_thread.join(timeout=5)
        return exit_code

    def write_tail_log(self, exit_code: int, log_f):
        """Append the summary footer (timing, exit code, status) to the log."""
        log_f.write("\n" + "=" * 100 + "\n")
        log_f.write(f"End Time: {self.execution.completed_at}\n")
        log_f.write(f"Execution Time: {self.execution.duration_seconds:.2f} seconds\n")
        log_f.write(f"Exit Code: {exit_code}\n")
        log_f.write(f"Status: {'SUCCESS' if exit_code == 0 else 'FAILED'}\n")

    def write_header_log(self, cmd: str, log_f):
        """Write the header banner (job id, start time, command) to the log."""
        log_f.write("DataX Task Execution Log\n")
        log_f.write(f"Job ID: {self.execution.id}\n")
        log_f.write(f"Start Time: {self.execution.started_at}\n")
        log_f.write("Config Source: JSON String\n")
        log_f.write(f"Command: {cmd}\n")
        log_f.write("=" * 100 + "\n\n")

    @staticmethod
    def read_stream(stream, log_f):
        """Copy non-empty lines from *stream* into the log file, flushing each."""
        for line in stream:
            line = line.rstrip('\n')
            if line:
                log_f.write(f"{line}\n")
                log_f.flush()
|
||||
@@ -0,0 +1,15 @@
|
||||
from fastapi import APIRouter

# Parent router for all data-collection endpoints.
router = APIRouter(
    prefix="/data-collection",
    tags = ["data-collection"]
)

# Sub-routers are imported after the parent router exists to avoid
# circular imports at module load time.
from .collection import router as collection_router
from .execution import router as execution_router
from .template import router as template_router

for _sub_router in (collection_router, execution_router, template_router):
    router.include_router(_sub_router)
|
||||
@@ -0,0 +1,157 @@
|
||||
import math
|
||||
import uuid
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from sqlalchemy import select, func
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.logging import get_logger
|
||||
from app.db.models.data_collection import CollectionTask, TaskExecution, CollectionTemplate
|
||||
from app.db.session import get_db
|
||||
from app.module.collection.client.datax_client import DataxClient
|
||||
from app.module.collection.schema.collection import CollectionTaskBase, CollectionTaskCreate, converter_to_response, \
|
||||
convert_for_create
|
||||
from app.module.collection.service.collection import CollectionTaskService
|
||||
from app.module.shared.schema import StandardResponse, PaginatedData
|
||||
|
||||
# Router for collection-task endpoints; mounted under the parent
# /data-collection router.
router = APIRouter(
    prefix="/tasks",
    tags=["data-collection/tasks"],
)

logger = get_logger(__name__)
|
||||
|
||||
|
||||
@router.post("", response_model=StandardResponse[CollectionTaskBase])
async def create_task(
    request: CollectionTaskCreate,
    db: AsyncSession = Depends(get_db)
):
    """Create a collection task.

    Validates the referenced template, builds the DataX job config on the
    request, persists the task (a ONCE task is launched immediately by the
    service) and returns the freshly reloaded task.

    Raises:
        HTTPException: 400 if the template does not exist, 500 on any
            unexpected failure (transaction is rolled back).
    """
    try:
        template = (await db.execute(
            select(CollectionTemplate).where(CollectionTemplate.id == request.template_id)
        )).scalar_one_or_none()
        if not template:
            raise HTTPException(status_code=400, detail="Template not found")

        task_id = str(uuid.uuid4())
        # Fills request.config.job in place from the template.
        DataxClient.generate_datx_config(request.config, template, f"/dataset/local/{task_id}")
        task = convert_for_create(request, task_id)
        task.template_name = template.name

        task_service = CollectionTaskService(db)
        task = await task_service.create_task(task)

        # Reload so server-side defaults (timestamps, status) are populated.
        task = (await db.execute(
            select(CollectionTask).where(CollectionTask.id == task.id)
        )).scalar_one_or_none()
        await db.commit()

        return StandardResponse(
            code=200,
            message="Success",
            data=converter_to_response(task)
        )
    except HTTPException:
        await db.rollback()
        raise
    except Exception as e:
        await db.rollback()
        # exc_info=True attaches the traceback; the previous positional
        # exception argument was treated as an unused %-format argument.
        logger.error(f"Failed to create collection task: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail="Internal server error")
|
||||
|
||||
|
||||
@router.get("", response_model=StandardResponse[PaginatedData[CollectionTaskBase]])
async def list_tasks(
    page: int = 1,
    size: int = 20,
    name: Optional[str] = Query(None, description="任务名称模糊查询"),
    db: AsyncSession = Depends(get_db)
):
    """List collection tasks with pagination and optional name filter.

    Returns a page of tasks ordered by creation time (newest first).
    """
    try:
        # Normalize out-of-range paging parameters instead of failing.
        page = page if page > 0 else 1
        size = size if size > 0 else 20
        query = select(CollectionTask)

        if name:
            # Case-insensitive substring match on the task name.
            query = query.where(CollectionTask.name.ilike(f"%{name}%"))

        # Total row count over the same filtered query.
        count_query = select(func.count()).select_from(query.subquery())
        total = (await db.execute(count_query)).scalar_one()

        offset = (page - 1) * size
        tasks = (await db.execute(
            query.order_by(CollectionTask.created_at.desc())
            .offset(offset)
            .limit(size)
        )).scalars().all()

        items = [converter_to_response(task) for task in tasks]
        total_pages = math.ceil(total / size) if total > 0 else 0

        return StandardResponse(
            code=200,
            message="Success",
            data=PaginatedData(
                content=items,
                total_elements=total,
                total_pages=total_pages,
                page=page,
                size=size,
            )
        )

    except Exception as e:
        # Fixed copy-paste: this endpoint lists collection tasks, not
        # evaluation tasks; exc_info=True replaces the stray positional arg.
        logger.error(f"Failed to list collection tasks: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail="Internal server error")
|
||||
|
||||
|
||||
@router.delete("", response_model=StandardResponse[str], status_code=200)
async def delete_collection_tasks(
    ids: list[str] = Query(..., description="要删除的任务ID列表"),
    db: AsyncSession = Depends(get_db),
):
    """
    删除归集任务

    Deletes every task in *ids* together with its execution records, in a
    single transaction.

    Args:
        ids: 任务ID
        db: 数据库会话

    Returns:
        StandardResponse[str]: 删除结果

    Raises:
        HTTPException: 404 if any id does not exist (nothing is deleted),
            500 on unexpected failure.
    """
    try:
        # Fixed: previously only ids[0] was deleted even though the
        # endpoint accepts a list of ids.
        for task_id in ids:
            task = await db.get(CollectionTask, task_id)
            if not task:
                raise HTTPException(status_code=404, detail="Collection task not found")

            # Remove execution records first (no DB-level cascade visible).
            await db.execute(
                TaskExecution.__table__.delete()
                .where(TaskExecution.task_id == task_id)
            )

            await db.delete(task)

        # Single commit so the whole batch is atomic.
        await db.commit()

        return StandardResponse(
            code=200,
            message="Collection task deleted successfully",
            data="success"
        )

    except HTTPException:
        await db.rollback()
        raise
    except Exception as e:
        await db.rollback()
        logger.error(f"Failed to delete collection task: {str(e)}")
        raise HTTPException(status_code=500, detail="Internal server error")
|
||||
@@ -0,0 +1,120 @@
|
||||
|
||||
import math
|
||||
import os
|
||||
from pathlib import Path
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from fastapi.responses import FileResponse
|
||||
from sqlalchemy import select, func
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.logging import get_logger
|
||||
from app.db.models.data_collection import TaskExecution
|
||||
from app.db.session import get_db
|
||||
from app.module.collection.schema.collection import TaskExecutionBase, converter_execution_to_response
|
||||
from app.module.shared.schema import StandardResponse, PaginatedData
|
||||
|
||||
# Router for execution-record endpoints; mounted under the parent
# /data-collection router.
router = APIRouter(
    prefix="/executions",
    tags=["data-collection/executions"],
)

logger = get_logger(__name__)
|
||||
|
||||
|
||||
@router.get("", response_model=StandardResponse[PaginatedData[TaskExecutionBase]])
async def list_executions(
    page: int = 1,
    size: int = 20,
    task_id: Optional[str] = Query(None, description="任务ID"),
    task_name: Optional[str] = Query(None, description="任务名称模糊查询"),
    start_time: Optional[datetime] = Query(None, description="开始执行时间范围-起(started_at >= start_time)"),
    end_time: Optional[datetime] = Query(None, description="开始执行时间范围-止(started_at <= end_time)"),
    db: AsyncSession = Depends(get_db)
):
    """List task execution records with pagination and optional filters.

    Filters: exact task id, fuzzy task name, and a [start_time, end_time]
    window applied to ``started_at``. Ordered newest-created first.
    """
    try:
        query = select(TaskExecution)

        if task_id:
            query = query.where(TaskExecution.task_id == task_id)

        if task_name:
            query = query.where(TaskExecution.task_name.ilike(f"%{task_name}%"))

        if start_time:
            query = query.where(TaskExecution.started_at >= start_time)

        if end_time:
            query = query.where(TaskExecution.started_at <= end_time)

        # Total row count over the same filtered query.
        count_query = select(func.count()).select_from(query.subquery())
        total = (await db.execute(count_query)).scalar_one()

        offset = (page - 1) * size
        executions = (await db.execute(
            query.order_by(TaskExecution.created_at.desc())
            .offset(offset)
            .limit(size)
        )).scalars().all()

        items = [converter_execution_to_response(exe) for exe in executions]
        total_pages = math.ceil(total / size) if total > 0 else 0

        return StandardResponse(
            code=200,
            message="Success",
            data=PaginatedData(
                content=items,
                total_elements=total,
                total_pages=total_pages,
                page=page,
                size=size,
            )
        )

    except Exception as e:
        # exc_info=True attaches the traceback; the previous positional
        # exception argument was treated as an unused %-format argument.
        logger.error(f"Failed to list task executions: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail="Internal server error")
|
||||
|
||||
|
||||
@router.get("/{execution_id}/log")
async def get_execution_log(
    execution_id: str,
    db: AsyncSession = Depends(get_db)
):
    """Return the log file of an execution record as a text response.

    Raises:
        HTTPException: 404 when the execution, its log path, or the log
            file itself is missing; 500 on unexpected failure.
    """
    try:
        execution = await db.get(TaskExecution, execution_id)
        if not execution:
            raise HTTPException(status_code=404, detail="Execution record not found")

        log_path = getattr(execution, "log_path", None)
        if not log_path:
            raise HTTPException(status_code=404, detail="Log path not found")

        # Resolve relative paths against the current working directory.
        path = Path(str(log_path))
        if not path.is_absolute():
            path = Path(os.getcwd()) / path
        path = path.resolve()

        if not path.exists() or not path.is_file():
            raise HTTPException(status_code=404, detail="Log file not found")

        filename = path.name
        # Fixed: the header previously hard-coded a placeholder filename
        # instead of interpolating the actual log file name.
        headers = {
            "Content-Disposition": f'inline; filename="{filename}"'
        }
        return FileResponse(
            path=str(path),
            media_type="text/plain; charset=utf-8",
            filename=filename,
            headers=headers,
        )

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Failed to get execution log: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail="Internal server error")
|
||||
@@ -0,0 +1,67 @@
|
||||
|
||||
import math
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from sqlalchemy import select, func
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.logging import get_logger
|
||||
from app.db.models.data_collection import CollectionTemplate
|
||||
from app.db.session import get_db
|
||||
from app.module.collection.schema.collection import CollectionTemplateBase, converter_template_to_response
|
||||
from app.module.shared.schema import StandardResponse, PaginatedData
|
||||
|
||||
# Router for collection-template endpoints; mounted under the parent
# /data-collection router.
router = APIRouter(
    prefix="/templates",
    tags=["data-collection/templates"],
)

logger = get_logger(__name__)
|
||||
|
||||
|
||||
@router.get("", response_model=StandardResponse[PaginatedData[CollectionTemplateBase]])
async def list_templates(
    page: int = 1,
    size: int = 20,
    name: Optional[str] = Query(None, description="模板名称模糊查询"),
    built_in: Optional[bool] = Query(None, description="是否系统内置模板"),
    db: AsyncSession = Depends(get_db)
):
    """List collection templates with pagination and optional filters.

    Filters: fuzzy template name and built-in flag. Ordered newest first.
    """
    try:
        query = select(CollectionTemplate)

        if name:
            query = query.where(CollectionTemplate.name.ilike(f"%{name}%"))

        # Explicit None check: built_in=False is a valid filter value.
        if built_in is not None:
            query = query.where(CollectionTemplate.built_in == built_in)

        # Total row count over the same filtered query.
        count_query = select(func.count()).select_from(query.subquery())
        total = (await db.execute(count_query)).scalar_one()

        offset = (page - 1) * size
        templates = (await db.execute(
            query.order_by(CollectionTemplate.created_at.desc())
            .offset(offset)
            .limit(size)
        )).scalars().all()

        items = [converter_template_to_response(tpl) for tpl in templates]
        total_pages = math.ceil(total / size) if total > 0 else 0

        return StandardResponse(
            code=200,
            message="Success",
            data=PaginatedData(
                content=items,
                total_elements=total,
                total_pages=total_pages,
                page=page,
                size=size,
            )
        )

    except Exception as e:
        # exc_info=True attaches the traceback; the previous positional
        # exception argument was treated as an unused %-format argument.
        logger.error(f"Failed to list collection templates: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail="Internal server error")
|
||||
@@ -0,0 +1,182 @@
|
||||
import json
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, Field, validator, ConfigDict
|
||||
from pydantic.alias_generators import to_camel
|
||||
|
||||
from app.db.models.data_collection import CollectionTask, TaskExecution, CollectionTemplate
|
||||
from app.module.shared.schema import TaskStatus
|
||||
|
||||
|
||||
class SyncMode(str, Enum):
    """How a collection task is scheduled."""

    # Run exactly once, immediately after creation.
    ONCE = "ONCE"
    # Run repeatedly according to a cron expression.
    SCHEDULED = "SCHEDULED"
|
||||
|
||||
class CollectionConfig(BaseModel):
    """DataX job configuration carried by a collection task.

    ``parameter`` holds template-wide values that are merged into both the
    reader and writer parameter sets; ``job`` is the assembled DataX job
    section produced by ``DataxClient.generate_datx_config``.
    """

    parameter: Optional[dict] = Field(None, description="模板参数")
    reader: Optional[dict] = Field(None, description="reader参数")
    writer: Optional[dict] = Field(None, description="writer参数")
    job: Optional[dict] = Field(None, description="任务配置")
|
||||
|
||||
class CollectionTaskBase(BaseModel):
    """API response model for a collection task (camelCase aliases)."""

    # Identity and description.
    id: str = Field(..., description="任务id")
    name: str = Field(..., description="任务名称")
    description: Optional[str] = Field(None, description="任务描述")

    # Execution configuration.
    target_path: str = Field(..., description="目标存放路径")
    config: CollectionConfig = Field(..., description="任务配置")
    template_id: str = Field(..., description="模板ID")
    template_name: Optional[str] = Field(None, description="模板名称")
    status: TaskStatus = Field(..., description="任务状态")
    sync_mode: SyncMode = Field(default=SyncMode.ONCE, description="同步方式")
    schedule_expression: Optional[str] = Field(None, description="调度表达式(cron)")
    retry_count: int = Field(default=3, description="重试次数")
    timeout_seconds: int = Field(default=3600, description="超时时间")
    last_execution_id: Optional[str] = Field(None, description="最后执行id")

    # Audit fields.
    created_at: Optional[datetime] = Field(None, description="创建时间")
    updated_at: Optional[datetime] = Field(None, description="更新时间")
    created_by: Optional[str] = Field(None, description="创建人")
    updated_by: Optional[str] = Field(None, description="更新人")

    # Serialize with camelCase aliases while still accepting snake_case.
    model_config = ConfigDict(
        alias_generator=to_camel,
        populate_by_name=True
    )
|
||||
|
||||
class CollectionTaskCreate(BaseModel):
    """API request body for creating a collection task (camelCase aliases)."""

    name: str = Field(..., description="任务名称")
    description: Optional[str] = Field(None, description="任务描述")
    sync_mode: SyncMode = Field(default=SyncMode.ONCE, description="同步方式")
    schedule_expression: Optional[str] = Field(None, description="调度表达式(cron)")
    config: CollectionConfig = Field(..., description="任务配置")
    template_id: str = Field(..., description="模板ID")

    # Serialize with camelCase aliases while still accepting snake_case.
    model_config = ConfigDict(
        alias_generator=to_camel,
        populate_by_name=True
    )
|
||||
|
||||
def converter_to_response(task: CollectionTask) -> CollectionTaskBase:
    """Map a ``CollectionTask`` ORM row to its API response model.

    The stored ``config`` column is a JSON string and is parsed back into
    a dict for the response model.
    """
    fields = dict(
        id=task.id,
        name=task.name,
        description=task.description,
        sync_mode=task.sync_mode,
        template_id=task.template_id,
        template_name=task.template_name,
        target_path=task.target_path,
        config=json.loads(task.config),
        schedule_expression=task.schedule_expression,
        status=task.status,
        retry_count=task.retry_count,
        timeout_seconds=task.timeout_seconds,
        last_execution_id=task.last_execution_id,
        created_at=task.created_at,
        updated_at=task.updated_at,
        created_by=task.created_by,
        updated_by=task.updated_by,
    )
    return CollectionTaskBase(**fields)
|
||||
|
||||
def convert_for_create(task: CollectionTaskCreate, task_id: str) -> CollectionTask:
    """Build a new ``CollectionTask`` ORM row from a create request.

    The config is serialized to a JSON string for the ``config`` column;
    the target path is derived from the generated task id; the task starts
    in PENDING status.
    """
    return CollectionTask(
        id=task_id,
        name=task.name,
        description=task.description,
        sync_mode=task.sync_mode,
        template_id=task.template_id,
        target_path=f"/dataset/local/{task_id}",
        # model_dump() replaces the pydantic-v1 .dict(), which is removed /
        # deprecated under pydantic v2 (this file uses v2's ConfigDict).
        config=json.dumps(task.config.model_dump()),
        schedule_expression=task.schedule_expression,
        status=TaskStatus.PENDING.name
    )
|
||||
|
||||
def create_execute_record(task: CollectionTask) -> TaskExecution:
    """Create a fresh RUNNING ``TaskExecution`` row for *task*.

    The log path is derived from the task id and the newly generated
    execution id, matching the directory used by ``DataxClient``.
    """
    execution_id = str(uuid.uuid4())
    record = TaskExecution(
        id=execution_id,
        task_id=task.id,
        task_name=task.name,
        status=TaskStatus.RUNNING.name,
        started_at=datetime.now(),
        log_path=f"/flow/data-collection/{task.id}/{execution_id}.log"
    )
    return record
|
||||
|
||||
|
||||
class TaskExecutionBase(BaseModel):
    """API response model for a task execution record (camelCase aliases)."""

    # Identity and parent-task linkage.
    id: str = Field(..., description="执行记录ID")
    task_id: str = Field(..., description="任务ID")
    task_name: str = Field(..., description="任务名称")

    # Run outcome and artifacts.
    status: Optional[str] = Field(None, description="执行状态")
    log_path: Optional[str] = Field(None, description="日志文件路径")
    started_at: Optional[datetime] = Field(None, description="开始时间")
    completed_at: Optional[datetime] = Field(None, description="完成时间")
    duration_seconds: Optional[int] = Field(None, description="执行时长(秒)")
    error_message: Optional[str] = Field(None, description="错误信息")

    # Audit fields.
    created_at: Optional[datetime] = Field(None, description="创建时间")
    updated_at: Optional[datetime] = Field(None, description="更新时间")
    created_by: Optional[str] = Field(None, description="创建者")
    updated_by: Optional[str] = Field(None, description="更新者")

    # Serialize with camelCase aliases while still accepting snake_case.
    model_config = ConfigDict(
        alias_generator=to_camel,
        populate_by_name=True
    )
|
||||
|
||||
|
||||
def converter_execution_to_response(execution: TaskExecution) -> TaskExecutionBase:
    """Map a ``TaskExecution`` ORM row to its API response model."""
    fields = dict(
        id=execution.id,
        task_id=execution.task_id,
        task_name=execution.task_name,
        status=execution.status,
        log_path=execution.log_path,
        started_at=execution.started_at,
        completed_at=execution.completed_at,
        duration_seconds=execution.duration_seconds,
        error_message=execution.error_message,
        created_at=execution.created_at,
        updated_at=execution.updated_at,
        created_by=execution.created_by,
        updated_by=execution.updated_by,
    )
    return TaskExecutionBase(**fields)
|
||||
|
||||
|
||||
class CollectionTemplateBase(BaseModel):
    """API response model for a collection template (camelCase aliases)."""

    # Identity and description.
    id: str = Field(..., description="模板ID")
    name: str = Field(..., description="模板名称")
    description: Optional[str] = Field(None, description="模板描述")

    # Source / target datasource identification.
    source_type: str = Field(..., description="源数据源类型")
    source_name: str = Field(..., description="源数据源名称")
    target_type: str = Field(..., description="目标数据源类型")
    target_name: str = Field(..., description="目标数据源名称")

    # Template payload and origin flag.
    template_content: dict = Field(..., description="模板内容")
    built_in: Optional[bool] = Field(None, description="是否系统内置模板")

    # Audit fields.
    created_at: Optional[datetime] = Field(None, description="创建时间")
    updated_at: Optional[datetime] = Field(None, description="更新时间")
    created_by: Optional[str] = Field(None, description="创建者")
    updated_by: Optional[str] = Field(None, description="更新者")

    # Serialize with camelCase aliases while still accepting snake_case.
    model_config = ConfigDict(
        alias_generator=to_camel,
        populate_by_name=True
    )
|
||||
|
||||
|
||||
def converter_template_to_response(template: CollectionTemplate) -> CollectionTemplateBase:
    """Map a ``CollectionTemplate`` ORM row to its API response model."""
    fields = dict(
        id=template.id,
        name=template.name,
        description=template.description,
        source_type=template.source_type,
        source_name=template.source_name,
        target_type=template.target_type,
        target_name=template.target_name,
        template_content=template.template_content,
        built_in=template.built_in,
        created_at=template.created_at,
        updated_at=template.updated_at,
        created_by=template.created_by,
        updated_by=template.updated_by,
    )
    return CollectionTemplateBase(**fields)
|
||||
@@ -0,0 +1,70 @@
|
||||
import asyncio
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Optional
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.logging import get_logger
|
||||
from app.db.models.data_collection import CollectionTask, CollectionTemplate
|
||||
from app.db.session import AsyncSessionLocal
|
||||
from app.module.collection.client.datax_client import DataxClient
|
||||
from app.module.collection.schema.collection import SyncMode, create_execute_record
|
||||
from app.module.shared.schema import TaskStatus
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class _RuntimeTask:
|
||||
id: str
|
||||
config: str
|
||||
timeout_seconds: int
|
||||
sync_mode: str
|
||||
status: Optional[str] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class _RuntimeExecution:
|
||||
id: str
|
||||
log_path: str
|
||||
started_at: Optional[Any] = None
|
||||
completed_at: Optional[Any] = None
|
||||
duration_seconds: Optional[float] = None
|
||||
error_message: Optional[str] = None
|
||||
status: Optional[str] = None
|
||||
|
||||
class CollectionTaskService:
    """Persists collection tasks and launches ONCE tasks asynchronously."""

    def __init__(self, db: AsyncSession):
        self.db = db

    async def create_task(self, task: CollectionTask) -> CollectionTask:
        """Add *task* to the session; a ONCE task is started immediately.

        For ONCE tasks the status is set to RUNNING, the session is
        committed and the DataX run is scheduled in the background.
        """
        self.db.add(task)

        # If it's a one-time task, execute it immediately
        if task.sync_mode == SyncMode.ONCE:
            task.status = TaskStatus.RUNNING.name
            await self.db.commit()
            # Hold a reference so the fire-and-forget task cannot be
            # garbage-collected before it finishes.
            background = asyncio.create_task(CollectionTaskService.run_async(task.id))
            background.add_done_callback(lambda t: t.exception())
        return task

    @staticmethod
    async def run_async(task_id: str):
        """Load the task, record an execution and run the DataX job.

        Runs in its own DB session; the blocking subprocess work is pushed
        to a worker thread so the event loop is not blocked.
        """
        logger.info(f"start to execute task {task_id}")
        async with AsyncSessionLocal() as session:
            task = await session.execute(select(CollectionTask).where(CollectionTask.id == task_id))
            task = task.scalar_one_or_none()
            if not task:
                logger.error(f"task {task_id} not exist")
                return
            # Fixed: the Result object is always truthy — the row must be
            # extracted before the existence check.
            template = (await session.execute(
                select(CollectionTemplate).where(CollectionTemplate.id == task.template_id)
            )).scalar_one_or_none()
            if not template:
                logger.error(f"template {task.template_name} not exist")
                return
            task_execution = create_execute_record(task)
            session.add(task_execution)
            await session.commit()
            # run_datax_job mutates task/execution in place; persist after.
            await asyncio.to_thread(
                DataxClient(execution=task_execution, task=task).run_datax_job
            )
            await session.commit()
|
||||
Reference in New Issue
Block a user