import csv
import datetime
import os
from io import StringIO
from pathlib import Path

from fastapi import HTTPException
from langchain_community.document_loaders import Docx2txtLoader, PyPDFLoader
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession

from app.core.logging import get_logger
from app.db.models import Dataset, DatasetFiles
from app.module.dataset.schema.pdf_extract import PdfTextExtractResponse

logger = get_logger(__name__)

PDF_FILE_TYPE = "pdf"
DOC_FILE_TYPE = "doc"
DOCX_FILE_TYPE = "docx"
XLS_FILE_TYPE = "xls"
XLSX_FILE_TYPE = "xlsx"
CSV_FILE_TYPE = "csv"
TEXT_FILE_TYPE = "txt"
CSV_FILE_EXTENSION = ".csv"
TEXT_FILE_EXTENSION = ".txt"
EXCEL_FILE_TYPES = {XLS_FILE_TYPE, XLSX_FILE_TYPE}
SUPPORTED_FILE_TYPES = {PDF_FILE_TYPE, DOC_FILE_TYPE, DOCX_FILE_TYPE, XLS_FILE_TYPE, XLSX_FILE_TYPE}

# Maps each source file type to the parser used to extract its content.
PARSER_BY_FILE_TYPE = {
    PDF_FILE_TYPE: "PyPDFLoader",
    DOC_FILE_TYPE: "Docx2txtLoader",
    DOCX_FILE_TYPE: "Docx2txtLoader",
    XLS_FILE_TYPE: "xlrd",
    XLSX_FILE_TYPE: "openpyxl",
}
DEFAULT_EXTENSION_BY_TYPE = {
    PDF_FILE_TYPE: ".pdf",
    DOC_FILE_TYPE: ".doc",
    DOCX_FILE_TYPE: ".docx",
    XLS_FILE_TYPE: ".xls",
    XLSX_FILE_TYPE: ".xlsx",
}
# Extension and file type of the derived (extracted) file for each source type.
DERIVED_EXTENSION_BY_TYPE = {
    PDF_FILE_TYPE: TEXT_FILE_EXTENSION,
    DOC_FILE_TYPE: TEXT_FILE_EXTENSION,
    DOCX_FILE_TYPE: TEXT_FILE_EXTENSION,
    XLS_FILE_TYPE: CSV_FILE_EXTENSION,
    XLSX_FILE_TYPE: CSV_FILE_EXTENSION,
}
DERIVED_FILE_TYPE_BY_SOURCE = {
    PDF_FILE_TYPE: TEXT_FILE_TYPE,
    DOC_FILE_TYPE: TEXT_FILE_TYPE,
    DOCX_FILE_TYPE: TEXT_FILE_TYPE,
    XLS_FILE_TYPE: CSV_FILE_TYPE,
    XLSX_FILE_TYPE: CSV_FILE_TYPE,
}
DERIVED_METADATA_KEY = "derived_from_file_id"
DERIVED_METADATA_NAME_KEY = "derived_from_file_name"
DERIVED_METADATA_TYPE_KEY = "derived_from_file_type"
DERIVED_METADATA_PARSER_KEY = "parser"


class PdfTextExtractService:
    """Extracts plain text (or CSV) from PDF/DOC/DOCX/XLS/XLSX dataset files."""

    def __init__(self, db: AsyncSession):
        self.db = db

    async def extract_pdf_to_text(self, dataset_id: str, file_id: str) -> PdfTextExtractResponse:
        dataset = await self._get_dataset(dataset_id)
        file_record = await self._get_file_record(dataset_id, file_id)
        self._validate_dataset_and_file(dataset, file_record)

        file_type = str(getattr(file_record, "file_type", "") or "").lower()
        derived_file_type = DERIVED_FILE_TYPE_BY_SOURCE.get(file_type, TEXT_FILE_TYPE)
        assert derived_file_type

        source_path = self._resolve_source_path(file_record)
        dataset_path = self._resolve_dataset_path(dataset)
        target_path = self._resolve_target_path(dataset_path, source_path, file_record, file_id, file_type)

        # Reuse an existing derived-file record if one already points at the target path.
        existing_record = await self._find_existing_text_record(dataset_id, target_path)
        if existing_record:
            return self._build_response(dataset_id, file_id, existing_record)

        # The derived file exists on disk but has no record yet: register it without re-parsing.
        if target_path.exists():
            file_size = self._get_file_size(target_path)
            parser_name = PARSER_BY_FILE_TYPE.get(file_type, "")
            record = await self._create_text_file_record(
                dataset, file_record, target_path, file_size, parser_name, derived_file_type
            )
            return self._build_response(dataset_id, file_id, record)

        text_content, parser_name = self._parse_document(source_path, file_type)
        assert isinstance(text_content, str)
        self._write_text_file(target_path, text_content)
        file_size = self._get_file_size(target_path)
        record = await self._create_text_file_record(
            dataset, file_record, target_path, file_size, parser_name, derived_file_type
        )
        return self._build_response(dataset_id, file_id, record)

    async def _get_dataset(self, dataset_id: str) -> Dataset:
        result = await self.db.execute(select(Dataset).where(Dataset.id == dataset_id))
        dataset = result.scalar_one_or_none()
        if not dataset:
            raise HTTPException(status_code=404, detail=f"Dataset not found: {dataset_id}")
        return dataset

    async def _get_file_record(self, dataset_id: str, file_id: str) -> DatasetFiles:
        result = await self.db.execute(
            select(DatasetFiles).where(
                DatasetFiles.id == file_id,
                DatasetFiles.dataset_id == dataset_id,
            )
        )
        file_record = result.scalar_one_or_none()
        if not file_record:
            raise HTTPException(status_code=404, detail=f"File not found: {file_id}")
        return file_record

    @staticmethod
    def _validate_dataset_and_file(dataset: Dataset, file_record: DatasetFiles) -> None:
        dataset_type = str(getattr(dataset, "dataset_type", "") or "").upper()
        if dataset_type != "TEXT":
            raise HTTPException(status_code=400, detail="Only TEXT datasets are supported")
        file_type = str(getattr(file_record, "file_type", "") or "").lower()
        if file_type not in SUPPORTED_FILE_TYPES:
            raise HTTPException(status_code=400, detail="Only PDF/DOC/DOCX/XLS/XLSX files can be parsed")

    @staticmethod
    def _resolve_source_path(file_record: DatasetFiles) -> Path:
        source_path = Path(str(file_record.file_path)).expanduser().resolve()
        if not source_path.exists():
            raise HTTPException(status_code=404, detail="Source file does not exist")
        return source_path

    @staticmethod
    def _resolve_dataset_path(dataset: Dataset) -> Path:
        dataset_path_value = str(getattr(dataset, "path", "") or "").strip()
        if not dataset_path_value:
            raise HTTPException(status_code=500, detail="Dataset path is empty")
        dataset_path = Path(dataset_path_value).expanduser().resolve()
        dataset_path.mkdir(parents=True, exist_ok=True)
        return dataset_path

    @staticmethod
    def _build_output_filename(file_record: DatasetFiles, file_id: str, file_type: str) -> str:
        original_name = str(getattr(file_record, "file_name", "") or "").strip()
        if not original_name:
            default_extension = DEFAULT_EXTENSION_BY_TYPE.get(file_type, f".{file_type}")
            original_name = f"{file_id}{default_extension}"
        derived_extension = DERIVED_EXTENSION_BY_TYPE.get(file_type, TEXT_FILE_EXTENSION)
        return f"{original_name}{derived_extension}"

    def _resolve_target_path(
        self,
        dataset_path: Path,
        source_path: Path,
        file_record: DatasetFiles,
        file_id: str,
        file_type: str,
    ) -> Path:
        output_name = self._build_output_filename(file_record, file_id, file_type)
        # Write next to the source file when it lives inside the dataset directory,
        # otherwise fall back to the dataset root.
        if dataset_path in source_path.parents:
            target_dir = source_path.parent
        else:
            target_dir = dataset_path
        target_dir = target_dir.resolve()
        if target_dir != dataset_path and dataset_path not in target_dir.parents:
            raise HTTPException(status_code=400, detail="Derived file path is outside the dataset directory")
        target_dir.mkdir(parents=True, exist_ok=True)
        return target_dir / output_name

    async def _find_existing_text_record(self, dataset_id: str, target_path: Path) -> DatasetFiles | None:
        result = await self.db.execute(
            select(DatasetFiles).where(
                DatasetFiles.dataset_id == dataset_id,
                DatasetFiles.file_path == str(target_path),
            )
        )
        return result.scalar_one_or_none()

    @staticmethod
    def _parse_document(source_path: Path, file_type: str) -> tuple[str, str]:
        if file_type == PDF_FILE_TYPE:
            loader = PyPDFLoader(str(source_path))
            parser_name = PARSER_BY_FILE_TYPE[PDF_FILE_TYPE]
        elif file_type in EXCEL_FILE_TYPES:
            parser_name = PARSER_BY_FILE_TYPE.get(file_type, "excel")
            csv_content = PdfTextExtractService._parse_excel_to_csv(source_path, file_type)
            return csv_content, parser_name
        else:
            loader = Docx2txtLoader(str(source_path))
            parser_name = PARSER_BY_FILE_TYPE.get(file_type, "Docx2txtLoader")
        docs = loader.load()
        contents = [doc.page_content for doc in docs if doc.page_content]
        return "\n\n".join(contents), parser_name

    @staticmethod
    def _parse_excel_to_csv(source_path: Path, file_type: str) -> str:
        output = StringIO(newline="")
        writer = csv.writer(output)
        try:
            if file_type == XLSX_FILE_TYPE:
                try:
                    from openpyxl import load_workbook
                except ImportError as exc:
                    raise HTTPException(status_code=500, detail="Missing openpyxl dependency") from exc
                workbook = load_workbook(filename=str(source_path), read_only=True, data_only=True)
                try:
                    sheet_names = workbook.sheetnames
                    # Prefix each row with its sheet name only when the workbook has multiple sheets.
                    include_sheet_name = len(sheet_names) > 1
                    for sheet_name in sheet_names:
                        sheet = workbook[sheet_name]
                        for row in sheet.iter_rows(values_only=True):
                            row_values = list(row)
                            if include_sheet_name:
                                row_values.insert(0, sheet_name)
                            writer.writerow(row_values)
                finally:
                    workbook.close()
            elif file_type == XLS_FILE_TYPE:
                try:
                    import xlrd
                except ImportError as exc:
                    raise HTTPException(status_code=500, detail="Missing xlrd dependency") from exc
                workbook = xlrd.open_workbook(str(source_path))
                sheet_names = workbook.sheet_names()
                include_sheet_name = len(sheet_names) > 1
                for sheet_index in range(workbook.nsheets):
                    sheet = workbook.sheet_by_index(sheet_index)
                    for row_index in range(sheet.nrows):
                        row_values = sheet.row_values(row_index)
                        if include_sheet_name:
                            row_values = [sheet.name, *row_values]
                        writer.writerow(row_values)
            else:
                raise HTTPException(status_code=400, detail="Unsupported Excel file format")
        except HTTPException:
            raise
        except Exception as exc:
            logger.error("Excel to CSV conversion failed: %s", exc)
            raise HTTPException(status_code=500, detail="Excel to CSV conversion failed") from exc
        return output.getvalue()

    @staticmethod
    def _write_text_file(target_path: Path, content: str) -> None:
        with open(target_path, "w", encoding="utf-8") as handle:
            handle.write(content or "")

    @staticmethod
    def _get_file_size(path: Path) -> int:
        try:
            return int(os.path.getsize(path))
        except OSError:
            return 0

    async def _create_text_file_record(
        self,
        dataset: Dataset,
        source_file: DatasetFiles,
        target_path: Path,
        file_size: int,
        parser_name: str,
        derived_file_type: str,
    ) -> DatasetFiles:
        assert parser_name
        assert derived_file_type
        # Record where the derived file came from and which parser produced it.
        metadata = {
            DERIVED_METADATA_KEY: str(getattr(source_file, "id", "")),
            DERIVED_METADATA_NAME_KEY: str(getattr(source_file, "file_name", "")),
            DERIVED_METADATA_TYPE_KEY: str(getattr(source_file, "file_type", "")),
            DERIVED_METADATA_PARSER_KEY: parser_name,
        }
        record = DatasetFiles(
            dataset_id=dataset.id,  # type: ignore[arg-type]
            file_name=target_path.name,
            file_path=str(target_path),
            file_type=derived_file_type,
            file_size=file_size,
            dataset_filemetadata=metadata,
            last_access_time=datetime.datetime.now(datetime.UTC),
        )
        self.db.add(record)
        dataset.file_count = (dataset.file_count or 0) + 1
        dataset.size_bytes = (dataset.size_bytes or 0) + file_size
        dataset.status = "ACTIVE"
        await self.db.commit()
        await self.db.refresh(record)
        return record

    @staticmethod
    def _build_response(dataset_id: str, file_id: str, record: DatasetFiles) -> PdfTextExtractResponse:
        return PdfTextExtractResponse(
            datasetId=dataset_id,
            sourceFileId=file_id,
            textFileId=str(record.id),
            textFileName=str(record.file_name),
            textFilePath=str(record.file_path),
            textFileSize=int(record.file_size or 0),
        )
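

# --- Usage sketch (not part of the original module) ---
# A minimal example of how this service might be wired into a FastAPI route.
# The router prefix, the route path, and the `get_db_session` dependency are
# assumptions for illustration; adapt them to the project's actual wiring.
#
# from fastapi import APIRouter, Depends
#
# router = APIRouter(prefix="/datasets", tags=["datasets"])
#
# @router.post(
#     "/{dataset_id}/files/{file_id}/extract-text",
#     response_model=PdfTextExtractResponse,
# )
# async def extract_text(
#     dataset_id: str,
#     file_id: str,
#     db: AsyncSession = Depends(get_db_session),  # hypothetical async session dependency
# ) -> PdfTextExtractResponse:
#     service = PdfTextExtractService(db)
#     return await service.extract_pdf_to_text(dataset_id, file_id)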