"""GraphRAG 缓存的单元测试。""" from __future__ import annotations import time from app.module.kg_graphrag.cache import CacheStats, GraphRAGCache, make_cache_key # --------------------------------------------------------------------------- # CacheStats # --------------------------------------------------------------------------- class TestCacheStats: """CacheStats 统计逻辑测试。""" def test_hit_rate_no_access(self): stats = CacheStats() assert stats.hit_rate == 0.0 def test_hit_rate_all_hits(self): stats = CacheStats(hits=10, misses=0) assert stats.hit_rate == 1.0 def test_hit_rate_mixed(self): stats = CacheStats(hits=3, misses=7) assert abs(stats.hit_rate - 0.3) < 1e-9 def test_to_dict_contains_all_fields(self): stats = CacheStats(hits=5, misses=3, evictions=1) d = stats.to_dict() assert d["hits"] == 5 assert d["misses"] == 3 assert d["evictions"] == 1 assert "hit_rate" in d # --------------------------------------------------------------------------- # GraphRAGCache — KG 缓存 # --------------------------------------------------------------------------- class TestKGCache: """KG 缓存(全文搜索 + 子图导出)测试。""" def test_get_miss_returns_none(self): cache = GraphRAGCache(kg_maxsize=10, kg_ttl=60) assert cache.get_kg("nonexistent") is None def test_set_then_get_hit(self): cache = GraphRAGCache(kg_maxsize=10, kg_ttl=60) cache.set_kg("key1", {"entities": [1, 2, 3]}) result = cache.get_kg("key1") assert result == {"entities": [1, 2, 3]} def test_stats_count_hits_and_misses(self): cache = GraphRAGCache(kg_maxsize=10, kg_ttl=60) cache.set_kg("a", "value-a") cache.get_kg("a") # hit cache.get_kg("a") # hit cache.get_kg("b") # miss stats = cache.stats() assert stats["kg"]["hits"] == 2 assert stats["kg"]["misses"] == 1 def test_maxsize_evicts_oldest(self): cache = GraphRAGCache(kg_maxsize=2, kg_ttl=60) cache.set_kg("a", 1) cache.set_kg("b", 2) cache.set_kg("c", 3) # should evict "a" assert cache.get_kg("a") is None assert cache.get_kg("c") == 3 def test_ttl_expiry(self): cache = GraphRAGCache(kg_maxsize=10, kg_ttl=1) cache.set_kg("ephemeral", "data") assert cache.get_kg("ephemeral") == "data" time.sleep(1.1) assert cache.get_kg("ephemeral") is None def test_clear_removes_all(self): cache = GraphRAGCache(kg_maxsize=10, kg_ttl=60) cache.set_kg("x", 1) cache.set_kg("y", 2) cache.clear() assert cache.get_kg("x") is None assert cache.get_kg("y") is None # --------------------------------------------------------------------------- # GraphRAGCache — Embedding 缓存 # --------------------------------------------------------------------------- class TestEmbeddingCache: """Embedding 向量缓存测试。""" def test_get_miss_returns_none(self): cache = GraphRAGCache(embedding_maxsize=10, embedding_ttl=60) assert cache.get_embedding("query-1") is None def test_set_then_get_hit(self): cache = GraphRAGCache(embedding_maxsize=10, embedding_ttl=60) vec = [0.1, 0.2, 0.3, 0.4] cache.set_embedding("query-1", vec) assert cache.get_embedding("query-1") == vec def test_stats_count_hits_and_misses(self): cache = GraphRAGCache(embedding_maxsize=10, embedding_ttl=60) cache.set_embedding("q1", [1.0]) cache.get_embedding("q1") # hit cache.get_embedding("q2") # miss stats = cache.stats() assert stats["embedding"]["hits"] == 1 assert stats["embedding"]["misses"] == 1 # --------------------------------------------------------------------------- # GraphRAGCache — 整体统计 # --------------------------------------------------------------------------- class TestCacheOverallStats: """缓存整体统计测试。""" def test_stats_structure(self): cache = GraphRAGCache(kg_maxsize=5, kg_ttl=60, embedding_maxsize=10, embedding_ttl=60) stats = cache.stats() assert "kg" in stats assert "embedding" in stats assert "size" in stats["kg"] assert "maxsize" in stats["kg"] assert "hits" in stats["kg"] assert "misses" in stats["kg"] def test_zero_maxsize_disables_caching(self): """maxsize=0 时,所有 set 都是 no-op。""" cache = GraphRAGCache(kg_maxsize=0, kg_ttl=1, embedding_maxsize=0, embedding_ttl=1) cache.set_kg("key", "value") assert cache.get_kg("key") is None cache.set_embedding("key", [1.0]) assert cache.get_embedding("key") is None # --------------------------------------------------------------------------- # make_cache_key # --------------------------------------------------------------------------- class TestMakeCacheKey: """缓存 key 生成测试。""" def test_deterministic(self): key1 = make_cache_key("fulltext", "graph-1", "hello", 10) key2 = make_cache_key("fulltext", "graph-1", "hello", 10) assert key1 == key2 def test_different_args_different_keys(self): key1 = make_cache_key("fulltext", "graph-1", "hello", 10) key2 = make_cache_key("fulltext", "graph-1", "world", 10) assert key1 != key2 def test_order_matters(self): key1 = make_cache_key("a", "b") key2 = make_cache_key("b", "a") assert key1 != key2 def test_handles_unicode(self): key = make_cache_key("用户行为数据", "图谱") assert len(key) == 64 # SHA-256 hex digest def test_handles_list_args(self): key = make_cache_key("subgraph", ["id-1", "id-2"], 2) assert len(key) == 64