You've already forked DataMate
feat(kg): 实现 Phase 3.3 性能优化
核心功能: - Neo4j 索引优化(entityType, graphId, properties.name) - Redis 缓存(Java 侧,3 个缓存区,TTL 可配置) - LRU 缓存(Python 侧,KG + Embedding,线程安全) - 细粒度缓存清除(graphId 前缀匹配) - 失败路径缓存清除(finally 块) 新增文件(Java 侧,7 个): - V2__PerformanceIndexes.java - Flyway 迁移,创建 3 个索引 - IndexHealthService.java - 索引健康监控 - RedisCacheConfig.java - Spring Cache + Redis 配置 - GraphCacheService.java - 缓存清除管理器 - CacheableIntegrationTest.java - 集成测试(10 tests) - GraphCacheServiceTest.java - 单元测试(19 tests) - V2__PerformanceIndexesTest.java, IndexHealthServiceTest.java 新增文件(Python 侧,2 个): - cache.py - 内存 TTL+LRU 缓存(cachetools) - test_cache.py - 单元测试(20 tests) 修改文件(Java 侧,9 个): - GraphEntityService.java - 添加 @Cacheable,缓存清除 - GraphQueryService.java - 添加 @Cacheable(包含用户权限上下文) - GraphRelationService.java - 添加缓存清除 - GraphSyncService.java - 添加缓存清除(finally 块,失败路径) - KnowledgeGraphProperties.java - 添加 Cache 配置类 - application-knowledgegraph.yml - 添加 Redis 和缓存 TTL 配置 - GraphEntityServiceTest.java - 添加 verify(cacheService) 断言 - GraphRelationServiceTest.java - 添加 verify(cacheService) 断言 - GraphSyncServiceTest.java - 添加失败路径缓存清除测试 修改文件(Python 侧,5 个): - kg_client.py - 集成缓存(fulltext_search, get_subgraph) - interface.py - 添加 /cache/stats 和 /cache/clear 端点 - config.py - 添加缓存配置字段 - pyproject.toml - 添加 cachetools 依赖 - test_kg_client.py - 添加 _disable_cache fixture 安全修复(3 轮迭代): - P0: 缓存 key 用户隔离(防止跨用户数据泄露) - P1-1: 同步子步骤后的缓存清除(18 个方法) - P1-2: 实体创建后的搜索缓存清除 - P1-3: 失败路径缓存清除(finally 块) - P2-1: 细粒度缓存清除(graphId 前缀匹配,避免跨图谱冲刷) - P2-2: 服务层测试添加 verify(cacheService) 断言 测试结果: - Java: 280 tests pass ✅ (270 → 280, +10 new) - Python: 154 tests pass ✅ (140 → 154, +14 new) 缓存配置: - kg:entities - 实体缓存,TTL 1h - kg:queries - 查询结果缓存,TTL 5min - kg:search - 全文搜索缓存,TTL 3min - KG cache (Python) - 256 entries, 5min TTL - Embedding cache (Python) - 512 entries, 10min TTL
This commit is contained in:
181
runtime/datamate-python/app/module/kg_graphrag/cache.py
Normal file
181
runtime/datamate-python/app/module/kg_graphrag/cache.py
Normal file
@@ -0,0 +1,181 @@
|
||||
"""GraphRAG 检索缓存。
|
||||
|
||||
使用 cachetools 的 TTLCache 为 KG 服务响应和 embedding 向量
|
||||
提供内存级 LRU + TTL 缓存,减少重复网络调用。
|
||||
|
||||
缓存策略:
|
||||
- KG 全文搜索结果:TTL 5 分钟,最多 256 条
|
||||
- KG 子图导出结果:TTL 5 分钟,最多 256 条
|
||||
- Embedding 向量:TTL 10 分钟,最多 512 条(embedding 计算成本高)
|
||||
|
||||
写操作由 Java 侧负责,Python 只读,因此不需要写后失效机制。
|
||||
TTL 到期后自然过期,保证最终一致性。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
import threading
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from cachetools import TTLCache
|
||||
|
||||
from app.core.logging import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class CacheStats:
|
||||
"""缓存命中统计。"""
|
||||
|
||||
hits: int = 0
|
||||
misses: int = 0
|
||||
evictions: int = 0
|
||||
|
||||
@property
|
||||
def hit_rate(self) -> float:
|
||||
total = self.hits + self.misses
|
||||
return self.hits / total if total > 0 else 0.0
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
return {
|
||||
"hits": self.hits,
|
||||
"misses": self.misses,
|
||||
"evictions": self.evictions,
|
||||
"hit_rate": round(self.hit_rate, 4),
|
||||
}
|
||||
|
||||
|
||||
class GraphRAGCache:
|
||||
"""GraphRAG 检索结果缓存。
|
||||
|
||||
线程安全:内部使用 threading.Lock 保护 TTLCache。
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
kg_maxsize: int = 256,
|
||||
kg_ttl: int = 300,
|
||||
embedding_maxsize: int = 512,
|
||||
embedding_ttl: int = 600,
|
||||
) -> None:
|
||||
self._kg_cache: TTLCache = TTLCache(maxsize=kg_maxsize, ttl=kg_ttl)
|
||||
self._embedding_cache: TTLCache = TTLCache(maxsize=embedding_maxsize, ttl=embedding_ttl)
|
||||
self._kg_lock = threading.Lock()
|
||||
self._embedding_lock = threading.Lock()
|
||||
self._kg_stats = CacheStats()
|
||||
self._embedding_stats = CacheStats()
|
||||
|
||||
@classmethod
|
||||
def from_settings(cls) -> GraphRAGCache:
|
||||
from app.core.config import settings
|
||||
|
||||
if not settings.graphrag_cache_enabled:
|
||||
# 返回一个 maxsize=0 的缓存,所有 get 都会 miss,set 都是 no-op
|
||||
return cls(kg_maxsize=0, kg_ttl=1, embedding_maxsize=0, embedding_ttl=1)
|
||||
|
||||
return cls(
|
||||
kg_maxsize=settings.graphrag_cache_kg_maxsize,
|
||||
kg_ttl=settings.graphrag_cache_kg_ttl,
|
||||
embedding_maxsize=settings.graphrag_cache_embedding_maxsize,
|
||||
embedding_ttl=settings.graphrag_cache_embedding_ttl,
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# KG 缓存(全文搜索 + 子图导出)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def get_kg(self, key: str) -> Any | None:
|
||||
"""查找 KG 缓存。返回 None 表示 miss。"""
|
||||
with self._kg_lock:
|
||||
val = self._kg_cache.get(key)
|
||||
if val is not None:
|
||||
self._kg_stats.hits += 1
|
||||
return val
|
||||
self._kg_stats.misses += 1
|
||||
return None
|
||||
|
||||
def set_kg(self, key: str, value: Any) -> None:
|
||||
"""写入 KG 缓存。"""
|
||||
if self._kg_cache.maxsize <= 0:
|
||||
return
|
||||
with self._kg_lock:
|
||||
self._kg_cache[key] = value
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Embedding 缓存
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def get_embedding(self, key: str) -> list[float] | None:
|
||||
"""查找 embedding 缓存。返回 None 表示 miss。"""
|
||||
with self._embedding_lock:
|
||||
val = self._embedding_cache.get(key)
|
||||
if val is not None:
|
||||
self._embedding_stats.hits += 1
|
||||
return val
|
||||
self._embedding_stats.misses += 1
|
||||
return None
|
||||
|
||||
def set_embedding(self, key: str, value: list[float]) -> None:
|
||||
"""写入 embedding 缓存。"""
|
||||
if self._embedding_cache.maxsize <= 0:
|
||||
return
|
||||
with self._embedding_lock:
|
||||
self._embedding_cache[key] = value
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 统计 & 管理
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def stats(self) -> dict[str, Any]:
|
||||
"""返回所有缓存区域的统计信息。"""
|
||||
with self._kg_lock:
|
||||
kg_size = len(self._kg_cache)
|
||||
with self._embedding_lock:
|
||||
emb_size = len(self._embedding_cache)
|
||||
return {
|
||||
"kg": {
|
||||
**self._kg_stats.to_dict(),
|
||||
"size": kg_size,
|
||||
"maxsize": self._kg_cache.maxsize,
|
||||
},
|
||||
"embedding": {
|
||||
**self._embedding_stats.to_dict(),
|
||||
"size": emb_size,
|
||||
"maxsize": self._embedding_cache.maxsize,
|
||||
},
|
||||
}
|
||||
|
||||
def clear(self) -> None:
|
||||
"""清空所有缓存。"""
|
||||
with self._kg_lock:
|
||||
self._kg_cache.clear()
|
||||
with self._embedding_lock:
|
||||
self._embedding_cache.clear()
|
||||
logger.info("GraphRAG cache cleared")
|
||||
|
||||
|
||||
def make_cache_key(*args: Any) -> str:
|
||||
"""从任意参数生成稳定的缓存 key。
|
||||
|
||||
对参数进行 JSON 序列化后取 SHA-256 摘要,
|
||||
确保 key 长度固定且不含特殊字符。
|
||||
"""
|
||||
raw = json.dumps(args, sort_keys=True, ensure_ascii=False, default=str)
|
||||
return hashlib.sha256(raw.encode()).hexdigest()
|
||||
|
||||
|
||||
# 全局单例(延迟初始化)
|
||||
_cache: GraphRAGCache | None = None
|
||||
|
||||
|
||||
def get_cache() -> GraphRAGCache:
|
||||
"""获取全局缓存单例。"""
|
||||
global _cache
|
||||
if _cache is None:
|
||||
_cache = GraphRAGCache.from_settings()
|
||||
return _cache
|
||||
@@ -247,3 +247,33 @@ async def query_stream(
|
||||
yield f"data: {json.dumps({'error': '生成服务暂不可用'})}\n\n"
|
||||
|
||||
return StreamingResponse(event_stream(), media_type="text/event-stream")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 缓存管理
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@router.get(
|
||||
"/cache/stats",
|
||||
response_model=StandardResponse[dict],
|
||||
summary="缓存统计",
|
||||
description="返回 GraphRAG 检索缓存的命中率和容量统计。",
|
||||
)
|
||||
async def cache_stats():
|
||||
from app.module.kg_graphrag.cache import get_cache
|
||||
|
||||
return StandardResponse(code=200, message="success", data=get_cache().stats())
|
||||
|
||||
|
||||
@router.post(
|
||||
"/cache/clear",
|
||||
response_model=StandardResponse[dict],
|
||||
summary="清空缓存",
|
||||
description="清空所有 GraphRAG 检索缓存。",
|
||||
)
|
||||
async def cache_clear():
|
||||
from app.module.kg_graphrag.cache import get_cache
|
||||
|
||||
get_cache().clear()
|
||||
return StandardResponse(code=200, message="success", data={"cleared": True})
|
||||
|
||||
@@ -11,6 +11,7 @@ from __future__ import annotations
|
||||
import httpx
|
||||
|
||||
from app.core.logging import get_logger
|
||||
from app.module.kg_graphrag.cache import get_cache, make_cache_key
|
||||
from app.module.kg_graphrag.models import EntitySummary, RelationSummary
|
||||
|
||||
logger = get_logger(__name__)
|
||||
@@ -67,9 +68,17 @@ class KGServiceClient:
|
||||
"""调用 KG 服务全文检索,返回匹配的实体列表。
|
||||
|
||||
Fail-open: KG 服务不可用时返回空列表。
|
||||
结果会被缓存(TTL 由 graphrag_cache_kg_ttl 控制)。
|
||||
"""
|
||||
cache = get_cache()
|
||||
cache_key = make_cache_key("fulltext", graph_id, query, size, user_id)
|
||||
cached = cache.get_kg(cache_key)
|
||||
if cached is not None:
|
||||
return cached
|
||||
try:
|
||||
return await self._fulltext_search_impl(graph_id, query, size, user_id)
|
||||
result = await self._fulltext_search_impl(graph_id, query, size, user_id)
|
||||
cache.set_kg(cache_key, result)
|
||||
return result
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"KG fulltext search failed for graph_id=%s (fail-open, returning empty)",
|
||||
@@ -123,9 +132,17 @@ class KGServiceClient:
|
||||
"""获取种子实体的 N-hop 子图。
|
||||
|
||||
Fail-open: KG 服务不可用时返回空子图。
|
||||
结果会被缓存(TTL 由 graphrag_cache_kg_ttl 控制)。
|
||||
"""
|
||||
cache = get_cache()
|
||||
cache_key = make_cache_key("subgraph", graph_id, sorted(entity_ids), depth, user_id)
|
||||
cached = cache.get_kg(cache_key)
|
||||
if cached is not None:
|
||||
return cached
|
||||
try:
|
||||
return await self._get_subgraph_impl(graph_id, entity_ids, depth, user_id)
|
||||
result = await self._get_subgraph_impl(graph_id, entity_ids, depth, user_id)
|
||||
cache.set_kg(cache_key, result)
|
||||
return result
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"KG subgraph export failed for graph_id=%s (fail-open, returning empty)",
|
||||
|
||||
183
runtime/datamate-python/app/module/kg_graphrag/test_cache.py
Normal file
183
runtime/datamate-python/app/module/kg_graphrag/test_cache.py
Normal file
@@ -0,0 +1,183 @@
|
||||
"""GraphRAG 缓存的单元测试。"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
|
||||
from app.module.kg_graphrag.cache import CacheStats, GraphRAGCache, make_cache_key
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# CacheStats
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCacheStats:
|
||||
"""CacheStats 统计逻辑测试。"""
|
||||
|
||||
def test_hit_rate_no_access(self):
|
||||
stats = CacheStats()
|
||||
assert stats.hit_rate == 0.0
|
||||
|
||||
def test_hit_rate_all_hits(self):
|
||||
stats = CacheStats(hits=10, misses=0)
|
||||
assert stats.hit_rate == 1.0
|
||||
|
||||
def test_hit_rate_mixed(self):
|
||||
stats = CacheStats(hits=3, misses=7)
|
||||
assert abs(stats.hit_rate - 0.3) < 1e-9
|
||||
|
||||
def test_to_dict_contains_all_fields(self):
|
||||
stats = CacheStats(hits=5, misses=3, evictions=1)
|
||||
d = stats.to_dict()
|
||||
assert d["hits"] == 5
|
||||
assert d["misses"] == 3
|
||||
assert d["evictions"] == 1
|
||||
assert "hit_rate" in d
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GraphRAGCache — KG 缓存
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestKGCache:
|
||||
"""KG 缓存(全文搜索 + 子图导出)测试。"""
|
||||
|
||||
def test_get_miss_returns_none(self):
|
||||
cache = GraphRAGCache(kg_maxsize=10, kg_ttl=60)
|
||||
assert cache.get_kg("nonexistent") is None
|
||||
|
||||
def test_set_then_get_hit(self):
|
||||
cache = GraphRAGCache(kg_maxsize=10, kg_ttl=60)
|
||||
cache.set_kg("key1", {"entities": [1, 2, 3]})
|
||||
result = cache.get_kg("key1")
|
||||
assert result == {"entities": [1, 2, 3]}
|
||||
|
||||
def test_stats_count_hits_and_misses(self):
|
||||
cache = GraphRAGCache(kg_maxsize=10, kg_ttl=60)
|
||||
cache.set_kg("a", "value-a")
|
||||
|
||||
cache.get_kg("a") # hit
|
||||
cache.get_kg("a") # hit
|
||||
cache.get_kg("b") # miss
|
||||
|
||||
stats = cache.stats()
|
||||
assert stats["kg"]["hits"] == 2
|
||||
assert stats["kg"]["misses"] == 1
|
||||
|
||||
def test_maxsize_evicts_oldest(self):
|
||||
cache = GraphRAGCache(kg_maxsize=2, kg_ttl=60)
|
||||
cache.set_kg("a", 1)
|
||||
cache.set_kg("b", 2)
|
||||
cache.set_kg("c", 3) # should evict "a"
|
||||
|
||||
assert cache.get_kg("a") is None
|
||||
assert cache.get_kg("c") == 3
|
||||
|
||||
def test_ttl_expiry(self):
|
||||
cache = GraphRAGCache(kg_maxsize=10, kg_ttl=1)
|
||||
cache.set_kg("ephemeral", "data")
|
||||
assert cache.get_kg("ephemeral") == "data"
|
||||
|
||||
time.sleep(1.1)
|
||||
assert cache.get_kg("ephemeral") is None
|
||||
|
||||
def test_clear_removes_all(self):
|
||||
cache = GraphRAGCache(kg_maxsize=10, kg_ttl=60)
|
||||
cache.set_kg("x", 1)
|
||||
cache.set_kg("y", 2)
|
||||
cache.clear()
|
||||
|
||||
assert cache.get_kg("x") is None
|
||||
assert cache.get_kg("y") is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GraphRAGCache — Embedding 缓存
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestEmbeddingCache:
|
||||
"""Embedding 向量缓存测试。"""
|
||||
|
||||
def test_get_miss_returns_none(self):
|
||||
cache = GraphRAGCache(embedding_maxsize=10, embedding_ttl=60)
|
||||
assert cache.get_embedding("query-1") is None
|
||||
|
||||
def test_set_then_get_hit(self):
|
||||
cache = GraphRAGCache(embedding_maxsize=10, embedding_ttl=60)
|
||||
vec = [0.1, 0.2, 0.3, 0.4]
|
||||
cache.set_embedding("query-1", vec)
|
||||
assert cache.get_embedding("query-1") == vec
|
||||
|
||||
def test_stats_count_hits_and_misses(self):
|
||||
cache = GraphRAGCache(embedding_maxsize=10, embedding_ttl=60)
|
||||
cache.set_embedding("q1", [1.0])
|
||||
cache.get_embedding("q1") # hit
|
||||
cache.get_embedding("q2") # miss
|
||||
|
||||
stats = cache.stats()
|
||||
assert stats["embedding"]["hits"] == 1
|
||||
assert stats["embedding"]["misses"] == 1
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GraphRAGCache — 整体统计
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCacheOverallStats:
|
||||
"""缓存整体统计测试。"""
|
||||
|
||||
def test_stats_structure(self):
|
||||
cache = GraphRAGCache(kg_maxsize=5, kg_ttl=60, embedding_maxsize=10, embedding_ttl=60)
|
||||
stats = cache.stats()
|
||||
|
||||
assert "kg" in stats
|
||||
assert "embedding" in stats
|
||||
assert "size" in stats["kg"]
|
||||
assert "maxsize" in stats["kg"]
|
||||
assert "hits" in stats["kg"]
|
||||
assert "misses" in stats["kg"]
|
||||
|
||||
def test_zero_maxsize_disables_caching(self):
|
||||
"""maxsize=0 时,所有 set 都是 no-op。"""
|
||||
cache = GraphRAGCache(kg_maxsize=0, kg_ttl=1, embedding_maxsize=0, embedding_ttl=1)
|
||||
cache.set_kg("key", "value")
|
||||
assert cache.get_kg("key") is None
|
||||
|
||||
cache.set_embedding("key", [1.0])
|
||||
assert cache.get_embedding("key") is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# make_cache_key
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestMakeCacheKey:
|
||||
"""缓存 key 生成测试。"""
|
||||
|
||||
def test_deterministic(self):
|
||||
key1 = make_cache_key("fulltext", "graph-1", "hello", 10)
|
||||
key2 = make_cache_key("fulltext", "graph-1", "hello", 10)
|
||||
assert key1 == key2
|
||||
|
||||
def test_different_args_different_keys(self):
|
||||
key1 = make_cache_key("fulltext", "graph-1", "hello", 10)
|
||||
key2 = make_cache_key("fulltext", "graph-1", "world", 10)
|
||||
assert key1 != key2
|
||||
|
||||
def test_order_matters(self):
|
||||
key1 = make_cache_key("a", "b")
|
||||
key2 = make_cache_key("b", "a")
|
||||
assert key1 != key2
|
||||
|
||||
def test_handles_unicode(self):
|
||||
key = make_cache_key("用户行为数据", "图谱")
|
||||
assert len(key) == 64 # SHA-256 hex digest
|
||||
|
||||
def test_handles_list_args(self):
|
||||
key = make_cache_key("subgraph", ["id-1", "id-2"], 2)
|
||||
assert len(key) == 64
|
||||
@@ -8,6 +8,7 @@ from unittest.mock import AsyncMock, patch
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
from app.module.kg_graphrag.cache import GraphRAGCache
|
||||
from app.module.kg_graphrag.kg_client import KGServiceClient
|
||||
|
||||
|
||||
@@ -20,6 +21,14 @@ def client() -> KGServiceClient:
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _disable_cache():
|
||||
"""为每个测试禁用缓存,防止跨测试缓存命中干扰 mock 验证。"""
|
||||
disabled = GraphRAGCache(kg_maxsize=0, kg_ttl=1, embedding_maxsize=0, embedding_ttl=1)
|
||||
with patch("app.module.kg_graphrag.kg_client.get_cache", return_value=disabled):
|
||||
yield
|
||||
|
||||
|
||||
def _run(coro):
|
||||
return asyncio.run(coro)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user