Files
DataMate/runtime/datamate-python/app/module/kg_graphrag/test_cache.py
Jerry Yan 9b6ff59a11 feat(kg): 实现 Phase 3.3 性能优化
核心功能:
- Neo4j 索引优化(entityType, graphId, properties.name)
- Redis 缓存(Java 侧,3 个缓存区,TTL 可配置)
- LRU 缓存(Python 侧,KG + Embedding,线程安全)
- 细粒度缓存清除(graphId 前缀匹配)
- 失败路径缓存清除(finally 块)

新增文件(Java 侧,7 个):
- V2__PerformanceIndexes.java - Flyway 迁移,创建 3 个索引
- IndexHealthService.java - 索引健康监控
- RedisCacheConfig.java - Spring Cache + Redis 配置
- GraphCacheService.java - 缓存清除管理器
- CacheableIntegrationTest.java - 集成测试(10 tests)
- GraphCacheServiceTest.java - 单元测试(19 tests)
- V2__PerformanceIndexesTest.java, IndexHealthServiceTest.java

新增文件(Python 侧,2 个):
- cache.py - 内存 TTL+LRU 缓存(cachetools)
- test_cache.py - 单元测试(20 tests)

修改文件(Java 侧,9 个):
- GraphEntityService.java - 添加 @Cacheable,缓存清除
- GraphQueryService.java - 添加 @Cacheable(包含用户权限上下文)
- GraphRelationService.java - 添加缓存清除
- GraphSyncService.java - 添加缓存清除(finally 块,失败路径)
- KnowledgeGraphProperties.java - 添加 Cache 配置类
- application-knowledgegraph.yml - 添加 Redis 和缓存 TTL 配置
- GraphEntityServiceTest.java - 添加 verify(cacheService) 断言
- GraphRelationServiceTest.java - 添加 verify(cacheService) 断言
- GraphSyncServiceTest.java - 添加失败路径缓存清除测试

修改文件(Python 侧,5 个):
- kg_client.py - 集成缓存(fulltext_search, get_subgraph)
- interface.py - 添加 /cache/stats 和 /cache/clear 端点
- config.py - 添加缓存配置字段
- pyproject.toml - 添加 cachetools 依赖
- test_kg_client.py - 添加 _disable_cache fixture

安全修复(3 轮迭代):
- P0: 缓存 key 用户隔离(防止跨用户数据泄露)
- P1-1: 同步子步骤后的缓存清除(18 个方法)
- P1-2: 实体创建后的搜索缓存清除
- P1-3: 失败路径缓存清除(finally 块)
- P2-1: 细粒度缓存清除(graphId 前缀匹配,避免跨图谱冲刷)
- P2-2: 服务层测试添加 verify(cacheService) 断言

测试结果:
- Java: 280 tests pass (270 → 280, +10 new)
- Python: 154 tests pass (140 → 154, +14 new)

缓存配置:
- kg:entities - 实体缓存,TTL 1h
- kg:queries - 查询结果缓存,TTL 5min
- kg:search - 全文搜索缓存,TTL 3min
- KG cache (Python) - 256 entries, 5min TTL
- Embedding cache (Python) - 512 entries, 10min TTL
2026-02-20 18:28:33 +08:00

184 lines
5.8 KiB
Python

"""GraphRAG 缓存的单元测试。"""
from __future__ import annotations
import time
from app.module.kg_graphrag.cache import CacheStats, GraphRAGCache, make_cache_key
# ---------------------------------------------------------------------------
# CacheStats
# ---------------------------------------------------------------------------
class TestCacheStats:
    """Tests for the hit/miss bookkeeping exposed by ``CacheStats``."""

    def test_hit_rate_no_access(self):
        # With no recorded lookups the ratio must be 0.0 rather than a
        # division-by-zero error.
        assert CacheStats().hit_rate == 0.0

    def test_hit_rate_all_hits(self):
        only_hits = CacheStats(hits=10, misses=0)
        assert only_hits.hit_rate == 1.0

    def test_hit_rate_mixed(self):
        mixed = CacheStats(hits=3, misses=7)
        # Expected 3 / (3 + 7); compared with a tolerance since it is a float.
        delta = mixed.hit_rate - 0.3
        assert -1e-9 < delta < 1e-9

    def test_to_dict_contains_all_fields(self):
        as_dict = CacheStats(hits=5, misses=3, evictions=1).to_dict()
        expected_counters = {"hits": 5, "misses": 3, "evictions": 1}
        for field_name, expected in expected_counters.items():
            assert as_dict[field_name] == expected
        # The derived ratio is only checked for presence, not value.
        assert "hit_rate" in as_dict
# ---------------------------------------------------------------------------
# GraphRAGCache — KG 缓存
# ---------------------------------------------------------------------------
class TestKGCache:
    """Tests for the KG cache (full-text search and subgraph export results)."""

    def test_get_miss_returns_none(self):
        cache = GraphRAGCache(kg_maxsize=10, kg_ttl=60)
        assert cache.get_kg("nonexistent") is None

    def test_set_then_get_hit(self):
        cache = GraphRAGCache(kg_maxsize=10, kg_ttl=60)
        cache.set_kg("key1", {"entities": [1, 2, 3]})
        result = cache.get_kg("key1")
        assert result == {"entities": [1, 2, 3]}

    def test_stats_count_hits_and_misses(self):
        cache = GraphRAGCache(kg_maxsize=10, kg_ttl=60)
        cache.set_kg("a", "value-a")
        cache.get_kg("a")  # hit
        cache.get_kg("a")  # hit
        cache.get_kg("b")  # miss
        stats = cache.stats()
        assert stats["kg"]["hits"] == 2
        assert stats["kg"]["misses"] == 1

    def test_maxsize_evicts_oldest(self):
        cache = GraphRAGCache(kg_maxsize=2, kg_ttl=60)
        cache.set_kg("a", 1)
        cache.set_kg("b", 2)
        cache.set_kg("c", 3)  # exceeds maxsize=2, so the oldest entry "a" should go
        assert cache.get_kg("a") is None
        # Exactly one entry should be evicted: both newer entries must survive.
        # Without this check, a cache that dropped everything on overflow
        # (and then stored "c") would also pass.
        assert cache.get_kg("b") == 2
        assert cache.get_kg("c") == 3

    def test_ttl_expiry(self):
        # Uses a real sleep just past the 1-second TTL, so this test is
        # deliberately slow (~1.1 s).
        cache = GraphRAGCache(kg_maxsize=10, kg_ttl=1)
        cache.set_kg("ephemeral", "data")
        assert cache.get_kg("ephemeral") == "data"
        time.sleep(1.1)
        assert cache.get_kg("ephemeral") is None

    def test_clear_removes_all(self):
        cache = GraphRAGCache(kg_maxsize=10, kg_ttl=60)
        cache.set_kg("x", 1)
        cache.set_kg("y", 2)
        cache.clear()
        assert cache.get_kg("x") is None
        assert cache.get_kg("y") is None
# ---------------------------------------------------------------------------
# GraphRAGCache — Embedding 缓存
# ---------------------------------------------------------------------------
class TestEmbeddingCache:
    """Tests for the embedding-vector cache."""

    @staticmethod
    def _new_cache():
        """Build a cache with room for 10 embeddings and a 60-unit TTL."""
        return GraphRAGCache(embedding_maxsize=10, embedding_ttl=60)

    def test_get_miss_returns_none(self):
        assert self._new_cache().get_embedding("query-1") is None

    def test_set_then_get_hit(self):
        cache = self._new_cache()
        embedding = [0.1, 0.2, 0.3, 0.4]
        cache.set_embedding("query-1", embedding)
        assert cache.get_embedding("query-1") == embedding

    def test_stats_count_hits_and_misses(self):
        cache = self._new_cache()
        cache.set_embedding("q1", [1.0])
        cache.get_embedding("q1")  # stored above -> counted as a hit
        cache.get_embedding("q2")  # never stored -> counted as a miss
        embedding_stats = cache.stats()["embedding"]
        assert embedding_stats["hits"] == 1
        assert embedding_stats["misses"] == 1
# ---------------------------------------------------------------------------
# GraphRAGCache — 整体统计
# ---------------------------------------------------------------------------
class TestCacheOverallStats:
    """Tests for the combined statistics returned by ``GraphRAGCache.stats()``."""

    def test_stats_structure(self):
        cache = GraphRAGCache(kg_maxsize=5, kg_ttl=60, embedding_maxsize=10, embedding_ttl=60)
        stats = cache.stats()
        # Both sub-caches are expected to report the same counters. Checking
        # only the "kg" section would let a malformed "embedding" section
        # go unnoticed.
        for section in ("kg", "embedding"):
            assert section in stats
            for field_name in ("size", "maxsize", "hits", "misses"):
                assert field_name in stats[section], f"{section!r} stats missing {field_name!r}"

    def test_zero_maxsize_disables_caching(self):
        """With maxsize=0, every set is a no-op and later gets miss."""
        cache = GraphRAGCache(kg_maxsize=0, kg_ttl=1, embedding_maxsize=0, embedding_ttl=1)
        cache.set_kg("key", "value")
        assert cache.get_kg("key") is None
        cache.set_embedding("key", [1.0])
        assert cache.get_embedding("key") is None
# ---------------------------------------------------------------------------
# make_cache_key
# ---------------------------------------------------------------------------
class TestMakeCacheKey:
    """Tests for ``make_cache_key`` (cache key generation)."""

    # Lower-case hexadecimal alphabet produced by a hex digest.
    _HEX_DIGITS = frozenset("0123456789abcdef")

    @classmethod
    def _assert_sha256_hex(cls, key):
        """Assert *key* looks like a SHA-256 hex digest: 64 hex characters.

        Checking only the length would accept any 64-character string.
        """
        assert len(key) == 64
        assert set(key.lower()) <= cls._HEX_DIGITS

    def test_deterministic(self):
        key1 = make_cache_key("fulltext", "graph-1", "hello", 10)
        key2 = make_cache_key("fulltext", "graph-1", "hello", 10)
        assert key1 == key2

    def test_different_args_different_keys(self):
        key1 = make_cache_key("fulltext", "graph-1", "hello", 10)
        key2 = make_cache_key("fulltext", "graph-1", "world", 10)
        assert key1 != key2

    def test_order_matters(self):
        key1 = make_cache_key("a", "b")
        key2 = make_cache_key("b", "a")
        assert key1 != key2

    def test_handles_unicode(self):
        key = make_cache_key("用户行为数据", "图谱")
        self._assert_sha256_hex(key)

    def test_handles_list_args(self):
        key = make_cache_key("subgraph", ["id-1", "id-2"], 2)
        self._assert_sha256_hex(key)
        # Equal list arguments must map to the same key, or cache lookups
        # for subgraph queries would never hit.
        assert key == make_cache_key("subgraph", ["id-1", "id-2"], 2)