You've already forked DataMate
refactor: modify data collection to python implementation (#214)
* feature: LabelStudio jumps without login * refactor: modify data collection to python implementation * refactor: modify data collection to python implementation * refactor: modify data collection to python implementation * refactor: modify data collection to python implementation * refactor: modify data collection to python implementation * refactor: modify data collection to python implementation * fix: remove terrabase dependency * feature: add the collection task executions page and the collection template page * fix: fix the collection task creation * fix: fix the collection task creation
This commit is contained in:
@@ -0,0 +1,15 @@
|
||||
from fastapi import APIRouter
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/data-collection",
|
||||
tags = ["data-collection"]
|
||||
)
|
||||
|
||||
# Include sub-routers
|
||||
from .collection import router as collection_router
|
||||
from .execution import router as execution_router
|
||||
from .template import router as template_router
|
||||
|
||||
router.include_router(collection_router)
|
||||
router.include_router(execution_router)
|
||||
router.include_router(template_router)
|
||||
@@ -0,0 +1,157 @@
|
||||
import math
|
||||
import uuid
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from sqlalchemy import select, func
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.logging import get_logger
|
||||
from app.db.models.data_collection import CollectionTask, TaskExecution, CollectionTemplate
|
||||
from app.db.session import get_db
|
||||
from app.module.collection.client.datax_client import DataxClient
|
||||
from app.module.collection.schema.collection import CollectionTaskBase, CollectionTaskCreate, converter_to_response, \
|
||||
convert_for_create
|
||||
from app.module.collection.service.collection import CollectionTaskService
|
||||
from app.module.shared.schema import StandardResponse, PaginatedData
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/tasks",
|
||||
tags=["data-collection/tasks"],
|
||||
)
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@router.post("", response_model=StandardResponse[CollectionTaskBase])
|
||||
async def create_task(
|
||||
request: CollectionTaskCreate,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""创建归集任务"""
|
||||
try:
|
||||
template = await db.execute(select(CollectionTemplate).where(CollectionTemplate.id == request.template_id))
|
||||
template = template.scalar_one_or_none()
|
||||
if not template:
|
||||
raise HTTPException(status_code=400, detail="Template not found")
|
||||
|
||||
task_id = str(uuid.uuid4())
|
||||
DataxClient.generate_datx_config(request.config, template, f"/dataset/local/{task_id}")
|
||||
task = convert_for_create(request, task_id)
|
||||
task.template_name = template.name
|
||||
|
||||
task_service = CollectionTaskService(db)
|
||||
task = await task_service.create_task(task)
|
||||
|
||||
task = await db.execute(select(CollectionTask).where(CollectionTask.id == task.id))
|
||||
task = task.scalar_one_or_none()
|
||||
await db.commit()
|
||||
|
||||
return StandardResponse(
|
||||
code=200,
|
||||
message="Success",
|
||||
data=converter_to_response(task)
|
||||
)
|
||||
except HTTPException:
|
||||
await db.rollback()
|
||||
raise
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(f"Failed to create collection task: {str(e)}", e)
|
||||
raise HTTPException(status_code=500, detail="Internal server error")
|
||||
|
||||
|
||||
@router.get("", response_model=StandardResponse[PaginatedData[CollectionTaskBase]])
|
||||
async def list_tasks(
|
||||
page: int = 1,
|
||||
size: int = 20,
|
||||
name: Optional[str] = Query(None, description="任务名称模糊查询"),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""分页查询归集任务"""
|
||||
try:
|
||||
# 构建查询条件
|
||||
page = page if page > 0 else 1
|
||||
size = size if size > 0 else 20
|
||||
query = select(CollectionTask)
|
||||
|
||||
if name:
|
||||
query = query.where(CollectionTask.name.ilike(f"%{name}%"))
|
||||
|
||||
# 获取总数
|
||||
count_query = select(func.count()).select_from(query.subquery())
|
||||
total = (await db.execute(count_query)).scalar_one()
|
||||
|
||||
# 分页查询
|
||||
offset = (page - 1) * size
|
||||
tasks = (await db.execute(
|
||||
query.order_by(CollectionTask.created_at.desc())
|
||||
.offset(offset)
|
||||
.limit(size)
|
||||
)).scalars().all()
|
||||
|
||||
# 转换为响应模型
|
||||
items = [converter_to_response(task) for task in tasks]
|
||||
total_pages = math.ceil(total / size) if total > 0 else 0
|
||||
|
||||
return StandardResponse(
|
||||
code=200,
|
||||
message="Success",
|
||||
data=PaginatedData(
|
||||
content=items,
|
||||
total_elements=total,
|
||||
total_pages=total_pages,
|
||||
page=page,
|
||||
size=size,
|
||||
)
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to list evaluation tasks: {str(e)}", e)
|
||||
raise HTTPException(status_code=500, detail="Internal server error")
|
||||
|
||||
|
||||
@router.delete("", response_model=StandardResponse[str], status_code=200)
|
||||
async def delete_collection_tasks(
|
||||
ids: list[str] = Query(..., description="要删除的任务ID列表"),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
删除归集任务
|
||||
|
||||
Args:
|
||||
ids: 任务ID
|
||||
db: 数据库会话
|
||||
|
||||
Returns:
|
||||
StandardResponse[str]: 删除结果
|
||||
"""
|
||||
try:
|
||||
# 检查任务是否存在
|
||||
task_id = ids[0]
|
||||
task = await db.get(CollectionTask, task_id)
|
||||
if not task:
|
||||
raise HTTPException(status_code=404, detail="Collection task not found")
|
||||
|
||||
# 删除任务执行记录
|
||||
await db.execute(
|
||||
TaskExecution.__table__.delete()
|
||||
.where(TaskExecution.task_id == task_id)
|
||||
)
|
||||
|
||||
# 删除任务
|
||||
await db.delete(task)
|
||||
await db.commit()
|
||||
|
||||
return StandardResponse(
|
||||
code=200,
|
||||
message="Collection task deleted successfully",
|
||||
data="success"
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
await db.rollback()
|
||||
raise
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(f"Failed to delete collection task: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail="Internal server error")
|
||||
@@ -0,0 +1,120 @@
|
||||
|
||||
import math
|
||||
import os
|
||||
from pathlib import Path
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from fastapi.responses import FileResponse
|
||||
from sqlalchemy import select, func
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.logging import get_logger
|
||||
from app.db.models.data_collection import TaskExecution
|
||||
from app.db.session import get_db
|
||||
from app.module.collection.schema.collection import TaskExecutionBase, converter_execution_to_response
|
||||
from app.module.shared.schema import StandardResponse, PaginatedData
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/executions",
|
||||
tags=["data-collection/executions"],
|
||||
)
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@router.get("", response_model=StandardResponse[PaginatedData[TaskExecutionBase]])
|
||||
async def list_executions(
|
||||
page: int = 1,
|
||||
size: int = 20,
|
||||
task_id: Optional[str] = Query(None, description="任务ID"),
|
||||
task_name: Optional[str] = Query(None, description="任务名称模糊查询"),
|
||||
start_time: Optional[datetime] = Query(None, description="开始执行时间范围-起(started_at >= start_time)"),
|
||||
end_time: Optional[datetime] = Query(None, description="开始执行时间范围-止(started_at <= end_time)"),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""分页查询归集任务执行记录"""
|
||||
try:
|
||||
query = select(TaskExecution)
|
||||
|
||||
if task_id:
|
||||
query = query.where(TaskExecution.task_id == task_id)
|
||||
|
||||
if task_name:
|
||||
query = query.where(TaskExecution.task_name.ilike(f"%{task_name}%"))
|
||||
|
||||
if start_time:
|
||||
query = query.where(TaskExecution.started_at >= start_time)
|
||||
|
||||
if end_time:
|
||||
query = query.where(TaskExecution.started_at <= end_time)
|
||||
|
||||
count_query = select(func.count()).select_from(query.subquery())
|
||||
total = (await db.execute(count_query)).scalar_one()
|
||||
|
||||
offset = (page - 1) * size
|
||||
executions = (await db.execute(
|
||||
query.order_by(TaskExecution.created_at.desc())
|
||||
.offset(offset)
|
||||
.limit(size)
|
||||
)).scalars().all()
|
||||
|
||||
items = [converter_execution_to_response(exe) for exe in executions]
|
||||
total_pages = math.ceil(total / size) if total > 0 else 0
|
||||
|
||||
return StandardResponse(
|
||||
code=200,
|
||||
message="Success",
|
||||
data=PaginatedData(
|
||||
content=items,
|
||||
total_elements=total,
|
||||
total_pages=total_pages,
|
||||
page=page,
|
||||
size=size,
|
||||
)
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to list task executions: {str(e)}", e)
|
||||
raise HTTPException(status_code=500, detail="Internal server error")
|
||||
|
||||
|
||||
@router.get("/{execution_id}/log")
|
||||
async def get_execution_log(
|
||||
execution_id: str,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""获取执行记录对应的日志文件内容"""
|
||||
try:
|
||||
execution = await db.get(TaskExecution, execution_id)
|
||||
if not execution:
|
||||
raise HTTPException(status_code=404, detail="Execution record not found")
|
||||
|
||||
log_path = getattr(execution, "log_path", None)
|
||||
if not log_path:
|
||||
raise HTTPException(status_code=404, detail="Log path not found")
|
||||
|
||||
path = Path(str(log_path))
|
||||
if not path.is_absolute():
|
||||
path = Path(os.getcwd()) / path
|
||||
path = path.resolve()
|
||||
|
||||
if not path.exists() or not path.is_file():
|
||||
raise HTTPException(status_code=404, detail="Log file not found")
|
||||
|
||||
filename = path.name
|
||||
headers = {
|
||||
"Content-Disposition": f'inline; filename="{filename}"'
|
||||
}
|
||||
return FileResponse(
|
||||
path=str(path),
|
||||
media_type="text/plain; charset=utf-8",
|
||||
filename=filename,
|
||||
headers=headers,
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get execution log: {str(e)}", e)
|
||||
raise HTTPException(status_code=500, detail="Internal server error")
|
||||
@@ -0,0 +1,67 @@
|
||||
|
||||
import math
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from sqlalchemy import select, func
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.logging import get_logger
|
||||
from app.db.models.data_collection import CollectionTemplate
|
||||
from app.db.session import get_db
|
||||
from app.module.collection.schema.collection import CollectionTemplateBase, converter_template_to_response
|
||||
from app.module.shared.schema import StandardResponse, PaginatedData
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/templates",
|
||||
tags=["data-collection/templates"],
|
||||
)
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@router.get("", response_model=StandardResponse[PaginatedData[CollectionTemplateBase]])
|
||||
async def list_templates(
|
||||
page: int = 1,
|
||||
size: int = 20,
|
||||
name: Optional[str] = Query(None, description="模板名称模糊查询"),
|
||||
built_in: Optional[bool] = Query(None, description="是否系统内置模板"),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""分页查询归集任务模板"""
|
||||
try:
|
||||
query = select(CollectionTemplate)
|
||||
|
||||
if name:
|
||||
query = query.where(CollectionTemplate.name.ilike(f"%{name}%"))
|
||||
|
||||
if built_in is not None:
|
||||
query = query.where(CollectionTemplate.built_in == built_in)
|
||||
|
||||
count_query = select(func.count()).select_from(query.subquery())
|
||||
total = (await db.execute(count_query)).scalar_one()
|
||||
|
||||
offset = (page - 1) * size
|
||||
templates = (await db.execute(
|
||||
query.order_by(CollectionTemplate.created_at.desc())
|
||||
.offset(offset)
|
||||
.limit(size)
|
||||
)).scalars().all()
|
||||
|
||||
items = [converter_template_to_response(tpl) for tpl in templates]
|
||||
total_pages = math.ceil(total / size) if total > 0 else 0
|
||||
|
||||
return StandardResponse(
|
||||
code=200,
|
||||
message="Success",
|
||||
data=PaginatedData(
|
||||
content=items,
|
||||
total_elements=total,
|
||||
total_pages=total_pages,
|
||||
page=page,
|
||||
size=size,
|
||||
)
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to list collection templates: {str(e)}", e)
|
||||
raise HTTPException(status_code=500, detail="Internal server error")
|
||||
Reference in New Issue
Block a user