Files
DataMate/runtime/datamate-python/app/module/collection/interface/collection.py
hefanli 63f4e3e447 refactor: modify data collection to python implementation (#214)
* feature: LabelStudio jumps without login

* refactor: modify data collection to python implementation

* refactor: modify data collection to python implementation

* refactor: modify data collection to python implementation

* refactor: modify data collection to python implementation

* refactor: modify data collection to python implementation

* refactor: modify data collection to python implementation

* fix: remove terrabase dependency

* feature: add the collection task executions page and the collection template page

* fix: fix the collection task creation

* fix: fix the collection task creation
2025-12-30 18:48:43 +08:00

158 lines
5.0 KiB
Python

import math
import uuid
from typing import Optional
from fastapi import APIRouter, Depends, HTTPException, Query
from sqlalchemy import select, func
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.logging import get_logger
from app.db.models.data_collection import CollectionTask, TaskExecution, CollectionTemplate
from app.db.session import get_db
from app.module.collection.client.datax_client import DataxClient
from app.module.collection.schema.collection import CollectionTaskBase, CollectionTaskCreate, converter_to_response, \
convert_for_create
from app.module.collection.service.collection import CollectionTaskService
from app.module.shared.schema import StandardResponse, PaginatedData
# Module-level logger, used by the endpoint error handlers below.
logger = get_logger(__name__)
# Router grouping every collection-task endpoint under the "/tasks" prefix.
router = APIRouter(prefix="/tasks", tags=["data-collection/tasks"])
@router.post("", response_model=StandardResponse[CollectionTaskBase])
async def create_task(
    request: CollectionTaskCreate,
    db: AsyncSession = Depends(get_db)
):
    """Create a collection task.

    Looks up the template referenced by ``request.template_id``, generates the
    DataX configuration for a freshly generated task ID, persists the task via
    ``CollectionTaskService`` and returns the task re-read from the session.

    Args:
        request: Task creation payload.
        db: Database session injected by FastAPI.

    Returns:
        StandardResponse wrapping ``converter_to_response(task)``.

    Raises:
        HTTPException: 400 if the template does not exist; 500 on any other
            failure. The session is rolled back in both cases.
    """
    try:
        template_result = await db.execute(
            select(CollectionTemplate).where(CollectionTemplate.id == request.template_id)
        )
        template = template_result.scalar_one_or_none()
        if not template:
            raise HTTPException(status_code=400, detail="Template not found")

        task_id = str(uuid.uuid4())
        DataxClient.generate_datx_config(request.config, template, f"/dataset/local/{task_id}")

        task = convert_for_create(request, task_id)
        task.template_name = template.name
        task_service = CollectionTaskService(db)
        task = await task_service.create_task(task)

        # Re-read the persisted row (presumably so DB-populated columns are loaded).
        task_result = await db.execute(select(CollectionTask).where(CollectionTask.id == task.id))
        task = task_result.scalar_one_or_none()

        # Build the response BEFORE committing: if the AsyncSession expires
        # objects on commit, reading attributes afterwards would trigger an
        # implicit lazy load, which async sessions do not support.
        data = converter_to_response(task)
        await db.commit()
        return StandardResponse(
            code=200,
            message="Success",
            data=data
        )
    except HTTPException:
        await db.rollback()
        raise
    except Exception as e:
        await db.rollback()
        # Message only (no extra positional args): passing `e` as an argument
        # without a placeholder breaks log formatting. exception() adds the traceback.
        logger.exception(f"Failed to create collection task: {e}")
        raise HTTPException(status_code=500, detail="Internal server error") from e
@router.get("", response_model=StandardResponse[PaginatedData[CollectionTaskBase]])
async def list_tasks(
    page: int = 1,
    size: int = 20,
    name: Optional[str] = Query(None, description="任务名称模糊查询"),
    db: AsyncSession = Depends(get_db)
):
    """List collection tasks with pagination.

    Args:
        page: 1-based page number; non-positive values fall back to 1.
        size: Page size; non-positive values fall back to 20.
        name: Optional case-insensitive substring filter on the task name.
        db: Database session injected by FastAPI.

    Returns:
        StandardResponse wrapping a PaginatedData of tasks, newest first.

    Raises:
        HTTPException: 500 on any failure.
    """
    try:
        # Fall back to defaults for non-positive paging values.
        page = page if page > 0 else 1
        size = size if size > 0 else 20

        query = select(CollectionTask)
        if name:
            # NOTE(review): `name` is not escaped, so "%" and "_" in it act as
            # LIKE wildcards.
            query = query.where(CollectionTask.name.ilike(f"%{name}%"))

        # Total count over the same filtered query.
        count_query = select(func.count()).select_from(query.subquery())
        total = (await db.execute(count_query)).scalar_one()

        offset = (page - 1) * size
        tasks = (await db.execute(
            # `id` breaks ties between equal `created_at` values so that page
            # boundaries are deterministic.
            query.order_by(CollectionTask.created_at.desc(), CollectionTask.id)
            .offset(offset)
            .limit(size)
        )).scalars().all()

        items = [converter_to_response(task) for task in tasks]
        # ceil(0 / size) == 0, so no special case for an empty result is needed.
        total_pages = math.ceil(total / size)
        return StandardResponse(
            code=200,
            message="Success",
            data=PaginatedData(
                content=items,
                total_elements=total,
                total_pages=total_pages,
                page=page,
                size=size,
            )
        )
    except Exception as e:
        logger.exception(f"Failed to list collection tasks: {e}")
        raise HTTPException(status_code=500, detail="Internal server error") from e
@router.delete("", response_model=StandardResponse[str], status_code=200)
async def delete_collection_tasks(
    ids: list[str] = Query(..., description="要删除的任务ID列表"),
    db: AsyncSession = Depends(get_db),
):
    """
    Delete collection tasks.

    Every requested task and its execution records are removed in a single
    transaction. If any requested ID does not match an existing task, nothing
    is deleted.

    Args:
        ids: IDs of the tasks to delete (duplicates are ignored).
        db: Database session.

    Returns:
        StandardResponse[str]: deletion result.

    Raises:
        HTTPException: 400 if no ID is given; 404 if any task does not exist;
            500 on any other failure. The session is rolled back in all cases.
    """
    try:
        # Process every requested ID, de-duplicated with order preserved.
        task_ids = list(dict.fromkeys(ids))
        if not task_ids:
            raise HTTPException(status_code=400, detail="No task IDs provided")

        # Check that all tasks exist before deleting anything.
        tasks = (await db.execute(
            select(CollectionTask).where(CollectionTask.id.in_(task_ids))
        )).scalars().all()
        found_ids = {task.id for task in tasks}
        if any(task_id not in found_ids for task_id in task_ids):
            raise HTTPException(status_code=404, detail="Collection task not found")

        # Delete the execution records of all affected tasks.
        await db.execute(
            TaskExecution.__table__.delete()
            .where(TaskExecution.task_id.in_(task_ids))
        )
        # Delete the tasks themselves.
        for task in tasks:
            await db.delete(task)
        await db.commit()
        return StandardResponse(
            code=200,
            message="Collection task deleted successfully",
            data="success"
        )
    except HTTPException:
        await db.rollback()
        raise
    except Exception as e:
        await db.rollback()
        logger.exception(f"Failed to delete collection task: {str(e)}")
        raise HTTPException(status_code=500, detail="Internal server error") from e