Files
DataMate/runtime/datamate-python/app/module/dataset/service/pdf_extract.py
Jerry Yan f77fd99c31 feat(data-management): 扩展文档解析功能支持DOC和DOCX格式
- 添加对DOC和DOCX文件类型的常量定义和支持
- 将文件类型验证逻辑从仅PDF扩展为PDF/DOC/DOCX
- 集成Docx2txtLoader用于处理Word文档解析
- 更新错误消息为中文描述以提升用户体验
- 重构文件解析方法以支持多种文档格式
- 添加解析器元数据记录以追踪使用的解析工具
- 更新文件路径验证和构建逻辑以适配新的文件类型
2026-01-29 13:05:58 +08:00

215 lines
8.7 KiB
Python

import datetime
import os
from pathlib import Path
from fastapi import HTTPException
from langchain_community.document_loaders import Docx2txtLoader, PyPDFLoader
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.logging import get_logger
from app.db.models import Dataset, DatasetFiles
from app.module.dataset.schema.pdf_extract import PdfTextExtractResponse
logger = get_logger(__name__)
# Lowercase file-type identifiers as stored in DatasetFiles.file_type.
PDF_FILE_TYPE = "pdf"
DOC_FILE_TYPE = "doc"
DOCX_FILE_TYPE = "docx"
TEXT_FILE_TYPE = "txt"

# Suffix appended to the source file name to form the derived text file name.
TEXT_FILE_EXTENSION = ".txt"

# Input document types accepted for text extraction.
SUPPORTED_FILE_TYPES = {PDF_FILE_TYPE, DOC_FILE_TYPE, DOCX_FILE_TYPE}

# Loader class name recorded per file type for provenance metadata.
# NOTE(review): legacy .doc is mapped to Docx2txtLoader, which targets the
# .docx format — confirm real .doc inputs actually parse with it.
PARSER_BY_FILE_TYPE = {
    PDF_FILE_TYPE: "PyPDFLoader",
    DOC_FILE_TYPE: "Docx2txtLoader",
    DOCX_FILE_TYPE: "Docx2txtLoader",
}

# Fallback extensions used when a source record carries no stored file name.
DEFAULT_EXTENSION_BY_TYPE = {
    PDF_FILE_TYPE: ".pdf",
    DOC_FILE_TYPE: ".doc",
    DOCX_FILE_TYPE: ".docx",
}

# Metadata keys linking the derived .txt record back to its source document.
DERIVED_METADATA_KEY = "derived_from_file_id"
DERIVED_METADATA_NAME_KEY = "derived_from_file_name"
DERIVED_METADATA_TYPE_KEY = "derived_from_file_type"
DERIVED_METADATA_PARSER_KEY = "parser"
class PdfTextExtractService:
    def __init__(self, db: AsyncSession):
        # Async SQLAlchemy session used for all dataset/file queries and commits.
        self.db = db
async def extract_pdf_to_text(self, dataset_id: str, file_id: str) -> PdfTextExtractResponse:
dataset = await self._get_dataset(dataset_id)
file_record = await self._get_file_record(dataset_id, file_id)
self._validate_dataset_and_file(dataset, file_record)
file_type = str(getattr(file_record, "file_type", "") or "").lower()
source_path = self._resolve_source_path(file_record)
dataset_path = self._resolve_dataset_path(dataset)
target_path = self._resolve_target_path(dataset_path, source_path, file_record, file_id)
existing_record = await self._find_existing_text_record(dataset_id, target_path)
if existing_record:
return self._build_response(dataset_id, file_id, existing_record)
if target_path.exists():
file_size = self._get_file_size(target_path)
parser_name = PARSER_BY_FILE_TYPE.get(file_type, "")
record = await self._create_text_file_record(dataset, file_record, target_path, file_size, parser_name)
return self._build_response(dataset_id, file_id, record)
text_content, parser_name = self._parse_document(source_path, file_type)
assert isinstance(text_content, str)
self._write_text_file(target_path, text_content)
file_size = self._get_file_size(target_path)
record = await self._create_text_file_record(dataset, file_record, target_path, file_size, parser_name)
return self._build_response(dataset_id, file_id, record)
async def _get_dataset(self, dataset_id: str) -> Dataset:
result = await self.db.execute(select(Dataset).where(Dataset.id == dataset_id))
dataset = result.scalar_one_or_none()
if not dataset:
raise HTTPException(status_code=404, detail=f"数据集不存在: {dataset_id}")
return dataset
async def _get_file_record(self, dataset_id: str, file_id: str) -> DatasetFiles:
result = await self.db.execute(
select(DatasetFiles).where(
DatasetFiles.id == file_id,
DatasetFiles.dataset_id == dataset_id,
)
)
file_record = result.scalar_one_or_none()
if not file_record:
raise HTTPException(status_code=404, detail=f"文件不存在: {file_id}")
return file_record
@staticmethod
def _validate_dataset_and_file(dataset: Dataset, file_record: DatasetFiles) -> None:
dataset_type = str(getattr(dataset, "dataset_type", "") or "").upper()
if dataset_type != "TEXT":
raise HTTPException(status_code=400, detail="仅支持文本类型数据集")
file_type = str(getattr(file_record, "file_type", "") or "").lower()
if file_type not in SUPPORTED_FILE_TYPES:
raise HTTPException(status_code=400, detail="仅支持 PDF/DOC/DOCX 文件解析")
@staticmethod
def _resolve_source_path(file_record: DatasetFiles) -> Path:
source_path = Path(str(file_record.file_path)).expanduser().resolve()
if not source_path.exists():
raise HTTPException(status_code=404, detail="源文件不存在")
return source_path
@staticmethod
def _resolve_dataset_path(dataset: Dataset) -> Path:
dataset_path_value = str(getattr(dataset, "path", "") or "").strip()
if not dataset_path_value:
raise HTTPException(status_code=500, detail="数据集路径为空")
dataset_path = Path(dataset_path_value).expanduser().resolve()
dataset_path.mkdir(parents=True, exist_ok=True)
return dataset_path
@staticmethod
def _build_output_filename(file_record: DatasetFiles, file_id: str) -> str:
original_name = str(getattr(file_record, "file_name", "") or "").strip()
if not original_name:
file_type = str(getattr(file_record, "file_type", "") or "").lower()
default_extension = DEFAULT_EXTENSION_BY_TYPE.get(file_type, f".{file_type}")
original_name = f"{file_id}{default_extension}"
return f"{original_name}{TEXT_FILE_EXTENSION}"
def _resolve_target_path(
self,
dataset_path: Path,
source_path: Path,
file_record: DatasetFiles,
file_id: str,
) -> Path:
output_name = self._build_output_filename(file_record, file_id)
if dataset_path in source_path.parents:
target_dir = source_path.parent
else:
target_dir = dataset_path
target_dir = target_dir.resolve()
if target_dir != dataset_path and dataset_path not in target_dir.parents:
raise HTTPException(status_code=400, detail="解析文件路径超出数据集目录")
target_dir.mkdir(parents=True, exist_ok=True)
return target_dir / output_name
async def _find_existing_text_record(self, dataset_id: str, target_path: Path) -> DatasetFiles | None:
result = await self.db.execute(
select(DatasetFiles).where(
DatasetFiles.dataset_id == dataset_id,
DatasetFiles.file_path == str(target_path),
)
)
return result.scalar_one_or_none()
@staticmethod
def _parse_document(source_path: Path, file_type: str) -> tuple[str, str]:
if file_type == PDF_FILE_TYPE:
loader = PyPDFLoader(str(source_path))
parser_name = PARSER_BY_FILE_TYPE[PDF_FILE_TYPE]
else:
loader = Docx2txtLoader(str(source_path))
parser_name = PARSER_BY_FILE_TYPE.get(file_type, "Docx2txtLoader")
docs = loader.load()
contents = [doc.page_content for doc in docs if doc.page_content]
return "\n\n".join(contents), parser_name
@staticmethod
def _write_text_file(target_path: Path, content: str) -> None:
with open(target_path, "w", encoding="utf-8") as handle:
handle.write(content or "")
@staticmethod
def _get_file_size(path: Path) -> int:
try:
return int(os.path.getsize(path))
except OSError:
return 0
async def _create_text_file_record(
self,
dataset: Dataset,
source_file: DatasetFiles,
target_path: Path,
file_size: int,
parser_name: str,
) -> DatasetFiles:
assert parser_name
metadata = {
DERIVED_METADATA_KEY: str(getattr(source_file, "id", "")),
DERIVED_METADATA_NAME_KEY: str(getattr(source_file, "file_name", "")),
DERIVED_METADATA_TYPE_KEY: str(getattr(source_file, "file_type", "")),
DERIVED_METADATA_PARSER_KEY: parser_name,
}
record = DatasetFiles(
dataset_id=dataset.id, # type: ignore[arg-type]
file_name=target_path.name,
file_path=str(target_path),
file_type=TEXT_FILE_TYPE,
file_size=file_size,
dataset_filemetadata=metadata,
last_access_time=datetime.datetime.now(datetime.UTC),
)
self.db.add(record)
dataset.file_count = (dataset.file_count or 0) + 1
dataset.size_bytes = (dataset.size_bytes or 0) + file_size
dataset.status = "ACTIVE"
await self.db.commit()
await self.db.refresh(record)
return record
@staticmethod
def _build_response(dataset_id: str, file_id: str, record: DatasetFiles) -> PdfTextExtractResponse:
return PdfTextExtractResponse(
datasetId=dataset_id,
sourceFileId=file_id,
textFileId=str(record.id),
textFileName=str(record.file_name),
textFilePath=str(record.file_path),
textFileSize=int(record.file_size or 0),
)