from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlalchemy import func
from typing import Optional, List, Dict, Any
from datetime import datetime

from app.core.config import settings
from app.core.logging import get_logger
from app.db.models import Dataset, DatasetFiles
from ..schema import DatasetResponse, PagedDatasetFileResponse, DatasetFileResponse

logger = get_logger(__name__)


class Service:
    """Data-management (DM) service client that accesses the database directly."""

    def __init__(self, db: AsyncSession):
        """
        Initialize the DM client.

        Args:
            db: Database session used for every query issued by this service.
        """
        self.db = db
        logger.debug("Initialize DM service client (Database mode)")

    async def get_dataset(self, dataset_id: str) -> Optional[DatasetResponse]:
        """
        Fetch the details of a single dataset.

        Args:
            dataset_id: ID of the dataset to look up.

        Returns:
            The dataset converted to a ``DatasetResponse``, or ``None`` when
            the dataset does not exist or the lookup raised an exception
            (the error is logged, not propagated).
        """
        try:
            logger.debug(f"Getting dataset detail: {dataset_id} ...")

            result = await self.db.execute(
                select(Dataset).where(Dataset.id == dataset_id)
            )
            dataset = result.scalar_one_or_none()

            if not dataset:
                logger.error(f"Dataset not found: {dataset_id}")
                return None

            # Convert the database model into the response model.
            # "type: ignore" silences SQLAlchemy column-typing complaints.
            return DatasetResponse(
                id=dataset.id,  # type: ignore
                name=dataset.name,  # type: ignore
                description=dataset.description or "",  # type: ignore
                datasetType=dataset.dataset_type,  # type: ignore
                status=dataset.status,  # type: ignore
                fileCount=dataset.file_count or 0,  # type: ignore
                totalSize=dataset.size_bytes or 0,  # type: ignore
                createdAt=dataset.created_at,  # type: ignore
                updatedAt=dataset.updated_at,  # type: ignore
                createdBy=dataset.created_by  # type: ignore
            )
        except Exception as e:
            logger.error(f"Failed to get dataset {dataset_id}: {e}")
            return None

    async def get_dataset_files(
        self,
        dataset_id: str,
        page: int = 0,
        size: int = 100,
        file_type: Optional[str] = None,
        status: Optional[str] = None
    ) -> Optional[PagedDatasetFileResponse]:
        """
        List the files of a dataset, one page at a time.

        Args:
            dataset_id: ID of the dataset whose files are listed.
            page: Zero-based page index (the row offset is ``page * size``).
            size: Maximum number of files per page.
            file_type: When truthy, only files with this file type are returned.
            status: When truthy, only files with this status are returned.

        Returns:
            A ``PagedDatasetFileResponse`` with the page content and totals,
            or ``None`` when the query raised an exception (the error is
            logged, not propagated).
        """
        try:
            logger.debug(f"Get dataset files: dataset={dataset_id}, page={page}, size={size}")

            # Build the filter conditions once so the count query and the
            # page query are guaranteed to agree.
            filters = [DatasetFiles.dataset_id == dataset_id]
            if file_type:
                filters.append(DatasetFiles.file_type == file_type)
            if status:
                filters.append(DatasetFiles.status == status)

            # Total number of matching files (ignores pagination).
            count_query = select(func.count()).select_from(DatasetFiles).where(*filters)
            count_result = await self.db.execute(count_query)
            total = count_result.scalar_one()

            # Requested page, newest files first.
            query = (
                select(DatasetFiles)
                .where(*filters)
                .order_by(DatasetFiles.created_at.desc())
                .offset(page * size)
                .limit(size)
            )
            result = await self.db.execute(query)
            files = result.scalars().all()

            # Convert the database models into response models.
            # "type: ignore" silences SQLAlchemy column-typing complaints.
            content = [
                DatasetFileResponse(
                    id=f.id,  # type: ignore
                    fileName=f.file_name,  # type: ignore
                    fileType=f.file_type or "",  # type: ignore
                    filePath=f.file_path,  # type: ignore
                    originalName=f.file_name,  # type: ignore
                    size=f.file_size,  # type: ignore
                    status=f.status,  # type: ignore
                    uploadedAt=f.upload_time,  # type: ignore
                    description=None,
                    uploadedBy=None,
                    lastAccessTime=f.last_access_time,  # type: ignore
                    tags=f.tags,  # type: ignore
                    tags_updated_at=f.tags_updated_at  # type: ignore
                )
                for f in files
            ]

            # Ceiling division; guarded so a non-positive size yields 0 pages
            # instead of a ZeroDivisionError.
            total_pages = (total + size - 1) // size if size > 0 else 0

            return PagedDatasetFileResponse(
                content=content,
                totalElements=total,
                totalPages=total_pages,
                page=page,
                size=size
            )
        except Exception as e:
            logger.error(f"Failed to get dataset files for {dataset_id}: {e}")
            return None

    async def download_file(self, dataset_id: str, file_id: str) -> Optional[bytes]:
        """
        Download file content (kept for interface compatibility only).

        In database mode this method does not read any file: it logs a
        deprecation warning and always returns ``None``. Use
        ``get_file_download_url`` to obtain the file location instead.

        Args:
            dataset_id: Unused.
            file_id: Unused.

        Returns:
            Always ``None``.
        """
        logger.warning("download_file is deprecated when using database mode. Use get_file_download_url instead.")
        return None

    async def get_file_download_url(self, dataset_id: str, file_id: str) -> Optional[str]:
        """
        Get the stored path of a file that belongs to the given dataset.

        Args:
            dataset_id: ID of the dataset the file must belong to.
            file_id: ID of the file.

        Returns:
            The file's ``file_path`` value, or ``None`` when no matching file
            exists or the lookup raised an exception (the error is logged,
            not propagated).
        """
        try:
            result = await self.db.execute(
                select(DatasetFiles).where(
                    DatasetFiles.id == file_id,
                    DatasetFiles.dataset_id == dataset_id
                )
            )
            file = result.scalar_one_or_none()

            if not file:
                logger.error(f"File not found: {file_id} in dataset {dataset_id}")
                return None

            # The stored path is returned as-is.
            # NOTE(review): presumably a local path or an object-storage URL —
            # confirm against whatever writes DatasetFiles.file_path.
            return file.file_path  # type: ignore
        except Exception as e:
            logger.error(f"Failed to get file path for {file_id}: {e}")
            return None

    async def close(self):
        """Close the client; in database mode this only logs a message."""
        logger.info("DM service client closed (Database mode)")

    @staticmethod
    def _merge_tags(
        existing_tags: List[Dict[str, Any]],
        new_tags: List[Dict[str, Any]]
    ) -> List[Dict[str, Any]]:
        """Return a new list with ``new_tags`` merged into ``existing_tags`` by ``'id'``.

        A new tag whose ``'id'`` matches an entry already in the list replaces
        that entry in place; any other tag (including tags without an id) is
        appended. ``existing_tags`` itself is never mutated.
        """
        merged = list(existing_tags)
        index_by_id = {
            tag.get('id'): idx for idx, tag in enumerate(merged) if tag.get('id')
        }

        for new_tag in new_tags:
            tag_id = new_tag.get('id')
            if tag_id and tag_id in index_by_id:
                merged[index_by_id[tag_id]] = new_tag
                logger.debug(f"Updated existing tag with id: {tag_id}")
            else:
                merged.append(new_tag)
                if tag_id:
                    # Remember the appended position so a later entry with the
                    # same id in this batch replaces it instead of duplicating.
                    index_by_id[tag_id] = len(merged) - 1
                logger.debug(f"Added new tag with id: {tag_id}")

        return merged

    async def update_file_tags_partial(
        self,
        file_id: str,
        new_tags: List[Dict[str, Any]]
    ) -> tuple[bool, Optional[str], Optional[datetime]]:
        """
        Partially update the tags of a file.

        Tags in ``new_tags`` replace stored tags with the same ``'id'``;
        all other tags are appended. Stored tags that are not mentioned
        in ``new_tags`` are kept unchanged.

        Args:
            file_id: ID of the file whose tags are updated.
            new_tags: Tags to merge into the stored tag list.

        Returns:
            A ``(success, error_message, update_time)`` tuple. On success the
            error message is ``None``; on failure the update time is ``None``.
            On an exception the session is rolled back before returning.
        """
        try:
            logger.info(f"Partial updating tags for file: {file_id}")

            # Load the file record.
            result = await self.db.execute(
                select(DatasetFiles).where(DatasetFiles.id == file_id)
            )
            file_record = result.scalar_one_or_none()

            if not file_record:
                logger.error(f"File not found: {file_id}")
                return False, f"File not found: {file_id}", None

            # Build a NEW list rather than mutating the loaded one. When the
            # same (in-place mutated) object is assigned back, the ORM compares
            # it with the originally loaded value, finds them equal, and can
            # skip the UPDATE for this column.
            # NOTE(review): assumes DatasetFiles.tags is a plain JSON column
            # without mutation tracking — confirm against the model.
            merged_tags = self._merge_tags(file_record.tags or [], new_tags)  # type: ignore

            # Persist the merged tags together with the update timestamp.
            # NOTE(review): naive UTC timestamp kept as-is; whether an aware
            # datetime can be used depends on the column type.
            update_time = datetime.utcnow()
            file_record.tags = merged_tags  # type: ignore
            file_record.tags_updated_at = update_time  # type: ignore

            await self.db.commit()
            await self.db.refresh(file_record)

            logger.info(f"Successfully updated tags for file: {file_id}")
            return True, None, update_time

        except Exception as e:
            logger.error(f"Failed to update tags for file {file_id}: {e}")
            await self.db.rollback()
            return False, str(e), None