Core features:
- Three-tier retrieval strategy: vector retrieval (Milvus) + graph retrieval (KG service) + fusion ranking; the composition is sketched after this list
- LLM generation: synchronous and streaming (SSE) responses
- Knowledge-base access control: knowledge_base_id ownership check + collection_name binding verification
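
A minimal sketch of how the three tiers compose; `vector_search`, `graph_search`, and `fuse` are illustrative stubs standing in for the real clients in milvus_client.py, kg_client.py, and retriever.py:

```python
import asyncio

# Illustrative stubs; the real implementations live in milvus_client.py,
# kg_client.py, and retriever.py.
async def vector_search(query: str, kb_id: str) -> list[dict]: ...
async def graph_search(query: str, kb_id: str) -> list[dict]: ...
def fuse(vec_hits: list[dict], graph_hits: list[dict]) -> list[dict]: ...

async def three_tier_retrieve(query: str, kb_id: str) -> list[dict]:
    # Tiers 1 and 2 run concurrently; tier 3 merges and re-ranks their hits.
    vec_hits, graph_hits = await asyncio.gather(
        vector_search(query, kb_id),   # tier 1: Milvus top-k chunks
        graph_search(query, kb_id),    # tier 2: KG full-text + subgraph export
    )
    return fuse(vec_hits, graph_hits)  # tier 3: weighted fusion ranking
```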
New modules (9 files):
- models.py: request/response models (GraphRAGQueryRequest, RetrievalStrategy, GraphContext, etc.)
- milvus_client.py: Milvus vector-retrieval client (OpenAI Embeddings + asyncio.to_thread)
- kg_client.py: REST client for the KG service (full-text search + subgraph export, fail-open)
- context_builder.py: triple verbalization (10 relation templates) + context assembly
- generator.py: LLM generation (ChatOpenAI, synchronous and streaming)
- retriever.py: retrieval orchestration (parallel retrieval + fusion ranking)
- kb_access.py: knowledge-base access checks (ownership verification + collection binding, fail-close)
- interface.py: FastAPI endpoints (/query, /retrieve, /query/stream); the streaming endpoint is sketched after this list
- __init__.py: module entry point
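
A hedged sketch of how /query/stream can adapt GraphRAGGenerator.generate_stream (shown in full below) into SSE frames; the request model, import path, and context assembly here are simplified assumptions, not the actual interface.py:

```python
from collections.abc import AsyncIterator

from fastapi import APIRouter
from fastapi.responses import StreamingResponse
from pydantic import BaseModel

from app.module.kg_graphrag.generator import GraphRAGGenerator  # import path is an assumption

router = APIRouter()


class QueryRequest(BaseModel):  # stand-in for models.GraphRAGQueryRequest
    query: str
    knowledge_base_id: str


@router.post("/query/stream")
async def query_stream(req: QueryRequest) -> StreamingResponse:
    context = "..."  # in the real endpoint: retriever + context_builder output
    gen = GraphRAGGenerator.from_settings()

    async def sse() -> AsyncIterator[str]:
        # One SSE "data:" frame per LLM token, plus a closing sentinel.
        async for token in gen.generate_stream(req.query, context):
            yield f"data: {token}\n\n"
        yield "data: [DONE]\n\n"

    return StreamingResponse(sse(), media_type="text/event-stream")
```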
Modified files (3):
- app/core/config.py: add 13 graphrag_* settings; a few are sketched after this list
- app/module/__init__.py: register kg_graphrag_router
- pyproject.toml: add the pymilvus dependency
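
For orientation, the five graphrag_llm_* fields consumed by generator.py (below) might look like this in a pydantic-settings class; the fusion-weight names and all defaults are assumptions, not the actual config.py contents:

```python
from pydantic import SecretStr
from pydantic_settings import BaseSettings


class Settings(BaseSettings):
    # An empty string / "EMPTY" means "unset": from_settings() then falls
    # back to the corresponding kg_llm_* value.
    graphrag_llm_model: str = ""
    graphrag_llm_base_url: str = ""
    graphrag_llm_api_key: SecretStr = SecretStr("EMPTY")
    graphrag_llm_temperature: float = 0.1
    graphrag_llm_timeout_seconds: int = 60
    # Fusion weights referenced in the design notes (names assumed).
    graphrag_vector_weight: float = 0.7
    graphrag_graph_weight: float = 0.3
```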
Test coverage (79 tests):
- test_context_builder.py: 13 tests (triple verbalization + context assembly)
- test_kg_client.py: 14 tests (KG response parsing + PagedResponse + edge-field mapping)
- test_milvus_client.py: 8 tests (vector retrieval + asyncio.to_thread)
- test_retriever.py: 11 tests (parallel retrieval + fusion ranking + fail-open)
- test_kb_access.py: 18 tests (ownership checks + collection binding + cross-user negative cases)
- test_interface.py: 15 tests (endpoint-level regression + 403 short-circuit); one such negative case is sketched after this list
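
A sketch of the cross-user negative case covered by test_kb_access.py / test_interface.py; the fixtures (app, user_b_token), endpoint path, and payload shape are assumptions:

```python
import pytest
from httpx import ASGITransport, AsyncClient


@pytest.mark.asyncio  # requires pytest-asyncio
async def test_cross_user_kb_access_denied(app, user_b_token):
    # Fail-close: user B must not be able to query a KB owned by user A,
    # and the 403 should short-circuit before any retrieval happens.
    transport = ASGITransport(app=app)
    async with AsyncClient(transport=transport, base_url="http://test") as client:
        resp = await client.post(
            "/query",
            json={"query": "q", "knowledge_base_id": "kb-owned-by-user-a"},
            headers={"Authorization": f"Bearer {user_b_token}"},
        )
    assert resp.status_code == 403
```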
Key design points:
- Fail-open: Milvus/KG service failures do not block the pipeline; empty results are returned instead
- Fail-close: access-control failures reject the request, preventing authorization bypass
- Parallel retrieval: asyncio.gather() runs vector and graph retrieval concurrently
- Fusion ranking: min-max normalization + weighted fusion (vector_weight/graph_weight); sketched after this list
- Lazy initialization: all clients are initialized on the first request
- Config fallback: when graphrag_llm_* is unset, fall back to kg_llm_*
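
A minimal sketch of the fusion-ranking step, assuming each hit carries "id" and "score" fields and that a constant-score list normalizes to all ones; the real retriever.py may differ in these details:

```python
def min_max(scores: list[float]) -> list[float]:
    # Normalize raw scores into [0, 1] so the two sources are comparable.
    lo, hi = min(scores), max(scores)
    if hi == lo:
        return [1.0] * len(scores)  # degenerate case: treat all hits as equal
    return [(s - lo) / (hi - lo) for s in scores]


def fuse(
    vec_hits: list[dict],
    graph_hits: list[dict],
    vector_weight: float = 0.7,
    graph_weight: float = 0.3,
) -> list[str]:
    # Normalize each source separately, then blend with the configured weights;
    # a document found by both sources accumulates score from each.
    fused: dict[str, float] = {}
    for hits, weight in ((vec_hits, vector_weight), (graph_hits, graph_weight)):
        if not hits:
            continue  # fail-open: an empty source simply contributes nothing
        norms = min_max([h["score"] for h in hits])
        for hit, norm in zip(hits, norms):
            fused[hit["id"]] = fused.get(hit["id"], 0.0) + weight * norm
    return sorted(fused, key=fused.get, reverse=True)
```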
Security fixes:
- P1-1: KG response parsing (PagedResponse.content)
- P1-2: subgraph edge-field mapping (sourceEntityId/targetEntityId)
- P1-3: collection_name privilege-escalation risk (ownership check + binding verification)
- P1-4: synchronous Milvus I/O (wrapped in asyncio.to_thread); the fix pattern is sketched after this list
- P1-5: test coverage (79 tests, including security negative cases)
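
The P1-4 fix pattern, sketched with pymilvus's MilvusClient (the client/collection wiring is illustrative):

```python
import asyncio

from pymilvus import MilvusClient


async def search_async(
    client: MilvusClient, collection: str, vec: list[float], top_k: int
) -> list:
    # MilvusClient.search() is blocking; running it via asyncio.to_thread
    # keeps the event loop free to serve other requests (the P1-4 fix).
    return await asyncio.to_thread(
        client.search,
        collection_name=collection,
        data=[vec],
        limit=top_k,
    )
```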
Test results: 79 tests pass ✅
"""LLM 生成器。
|
|
|
|
基于增强上下文(向量 + 图谱)调用 LLM 生成回答,
|
|
支持同步和流式两种模式。
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
from collections.abc import AsyncIterator
|
|
|
|
from pydantic import SecretStr
|
|
|
|
from app.core.logging import get_logger
|
|
|
|
logger = get_logger(__name__)
|
|
|
|
_SYSTEM_PROMPT = (
|
|
"你是 DataMate 数据管理平台的智能助手。请根据以下上下文信息回答用户的问题。\n"
|
|
"如果上下文中没有相关信息,请明确说明。不要编造信息。"
|
|
)
|
|
|
|
|
|
class GraphRAGGenerator:
|
|
"""GraphRAG LLM 生成器。"""
|
|
|
|
def __init__(
|
|
self,
|
|
*,
|
|
model: str = "gpt-4o-mini",
|
|
base_url: str | None = None,
|
|
api_key: SecretStr = SecretStr("EMPTY"),
|
|
temperature: float = 0.1,
|
|
timeout: int = 60,
|
|
) -> None:
|
|
self._model = model
|
|
self._base_url = base_url
|
|
self._api_key = api_key
|
|
self._temperature = temperature
|
|
self._timeout = timeout
|
|
self._llm = None
|
|
|
|
@property
|
|
def model_name(self) -> str:
|
|
return self._model
|
|
|
|
@classmethod
|
|
def from_settings(cls) -> GraphRAGGenerator:
|
|
from app.core.config import settings
|
|
|
|
model = settings.graphrag_llm_model or settings.kg_llm_model
|
|
base_url = settings.graphrag_llm_base_url or settings.kg_llm_base_url
|
|
api_key = (
|
|
settings.graphrag_llm_api_key
|
|
if settings.graphrag_llm_api_key.get_secret_value() != "EMPTY"
|
|
else settings.kg_llm_api_key
|
|
)
|
|
return cls(
|
|
model=model,
|
|
base_url=base_url,
|
|
api_key=api_key,
|
|
temperature=settings.graphrag_llm_temperature,
|
|
timeout=settings.graphrag_llm_timeout_seconds,
|
|
)
|
|
|
|
def _get_llm(self):
|
|
if self._llm is None:
|
|
from langchain_openai import ChatOpenAI
|
|
|
|
self._llm = ChatOpenAI(
|
|
model=self._model,
|
|
base_url=self._base_url,
|
|
api_key=self._api_key,
|
|
temperature=self._temperature,
|
|
timeout=self._timeout,
|
|
)
|
|
return self._llm
|
|
|
|
def _build_messages(self, query: str, context: str) -> list[dict[str, str]]:
|
|
return [
|
|
{"role": "system", "content": _SYSTEM_PROMPT},
|
|
{
|
|
"role": "user",
|
|
"content": f"{context}\n\n用户问题: {query}\n\n请基于上下文中的信息回答。",
|
|
},
|
|
]
|
|
|
|
async def generate(self, query: str, context: str) -> str:
|
|
"""基于增强上下文生成回答。"""
|
|
messages = self._build_messages(query, context)
|
|
llm = self._get_llm()
|
|
response = await llm.ainvoke(messages)
|
|
return str(response.content)
|
|
|
|
async def generate_stream(self, query: str, context: str) -> AsyncIterator[str]:
|
|
"""基于增强上下文流式生成回答,逐 token 返回。"""
|
|
messages = self._build_messages(query, context)
|
|
llm = self._get_llm()
|
|
async for chunk in llm.astream(messages):
|
|
content = chunk.content
|
|
if content:
|
|
yield str(content)
|
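
A hypothetical usage sketch; the question string and context placeholder are illustrative:

```python
import asyncio


async def main() -> None:
    gen = GraphRAGGenerator.from_settings()
    context = "..."  # augmented context from retriever + context_builder
    async for token in gen.generate_stream(
        "Which retrieval strategies does DataMate support?", context
    ):
        print(token, end="", flush=True)


asyncio.run(main())
```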