"""GraphRAG retrieval cache.

Provides in-process LRU + TTL caching (backed by ``cachetools.TTLCache``) for
KG service responses and embedding vectors, to reduce repeated network calls.

Default cache policy (overridable via settings, see
``GraphRAGCache.from_settings``):

- KG results (full-text search AND subgraph export share one region):
  TTL 5 minutes, at most 256 entries combined
- Embedding vectors: TTL 10 minutes, at most 512 entries (embeddings are
  expensive to compute)

Writes are owned by the Java side and Python only reads, so no
write-invalidation is needed. Entries simply expire after their TTL, which
gives eventual consistency.
"""

from __future__ import annotations

import hashlib
import json
import threading
from dataclasses import dataclass, field
from typing import Any

from cachetools import TTLCache

from app.core.logging import get_logger

logger = get_logger(__name__)


@dataclass
class CacheStats:
    """Hit / miss / eviction counters for one cache region."""

    hits: int = 0
    misses: int = 0
    # Capacity-driven (LRU) evictions only; TTL expirations and explicit
    # clear() calls are not counted.
    evictions: int = 0

    @property
    def hit_rate(self) -> float:
        """Fraction of lookups that hit; 0.0 before any lookup."""
        total = self.hits + self.misses
        return self.hits / total if total > 0 else 0.0

    def to_dict(self) -> dict[str, Any]:
        """Return the counters plus ``hit_rate`` rounded to 4 decimals."""
        return {
            "hits": self.hits,
            "misses": self.misses,
            "evictions": self.evictions,
            "hit_rate": round(self.hit_rate, 4),
        }


class _EvictionCountingTTLCache(TTLCache):
    """TTLCache that records capacity evictions into a ``CacheStats``.

    cachetools calls ``popitem()`` when an insert would exceed ``maxsize``,
    so overriding it is the hook for observing evictions. The owner must hold
    its region lock while mutating the cache (cachetools is not thread-safe).
    """

    def __init__(self, maxsize: int, ttl: float, stats: CacheStats) -> None:
        super().__init__(maxsize=maxsize, ttl=ttl)
        self._eviction_stats = stats
        self._count_evictions = True

    def popitem(self) -> tuple[Any, Any]:
        item = super().popitem()
        if self._count_evictions:
            self._eviction_stats.evictions += 1
        return item

    def clear(self) -> None:
        # MutableMapping.clear() drains the cache via popitem(); an explicit
        # clear is not an eviction, so suspend counting while it runs.
        self._count_evictions = False
        try:
            super().clear()
        finally:
            self._count_evictions = True


class GraphRAGCache:
    """GraphRAG retrieval result cache.

    Thread-safe: each region (KG, embedding) has its own ``threading.Lock``
    guarding both the TTLCache and its stats. Lookups take the lock too,
    because cachetools caches are not thread-safe.
    """

    def __init__(
        self,
        *,
        kg_maxsize: int = 256,
        kg_ttl: int = 300,
        embedding_maxsize: int = 512,
        embedding_ttl: int = 600,
    ) -> None:
        """Create the two cache regions.

        Args:
            kg_maxsize: Max entries in the shared KG region (0 disables it).
            kg_ttl: KG entry lifetime in seconds.
            embedding_maxsize: Max entries in the embedding region
                (0 disables it).
            embedding_ttl: Embedding entry lifetime in seconds.
        """
        self._kg_stats = CacheStats()
        self._embedding_stats = CacheStats()
        self._kg_cache: TTLCache = _EvictionCountingTTLCache(
            kg_maxsize, kg_ttl, self._kg_stats
        )
        self._embedding_cache: TTLCache = _EvictionCountingTTLCache(
            embedding_maxsize, embedding_ttl, self._embedding_stats
        )
        self._kg_lock = threading.Lock()
        self._embedding_lock = threading.Lock()

    @classmethod
    def from_settings(cls) -> GraphRAGCache:
        """Build a cache from ``app.core.config.settings``.

        When ``graphrag_cache_enabled`` is false, both regions get
        ``maxsize=0``: every get misses and every set is a no-op.
        """
        from app.core.config import settings

        if not settings.graphrag_cache_enabled:
            return cls(kg_maxsize=0, kg_ttl=1, embedding_maxsize=0, embedding_ttl=1)

        return cls(
            kg_maxsize=settings.graphrag_cache_kg_maxsize,
            kg_ttl=settings.graphrag_cache_kg_ttl,
            embedding_maxsize=settings.graphrag_cache_embedding_maxsize,
            embedding_ttl=settings.graphrag_cache_embedding_ttl,
        )

    # ------------------------------------------------------------------
    # Shared helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _lookup(
        cache: TTLCache, lock: threading.Lock, stats: CacheStats, key: str
    ) -> Any | None:
        """Get ``key`` from ``cache`` under ``lock``, updating hit/miss stats."""
        with lock:
            val = cache.get(key)
            if val is None:
                stats.misses += 1
            else:
                stats.hits += 1
            return val

    @staticmethod
    def _store(cache: TTLCache, lock: threading.Lock, key: str, value: Any) -> None:
        """Put ``key`` into ``cache`` under ``lock``; skip when disabled or None."""
        # maxsize <= 0 means the region is disabled (TTLCache would raise
        # ValueError("value too large") on insert). A stored None would be
        # read back as a miss anyway, so it would only evict live entries.
        if value is None or cache.maxsize <= 0:
            return
        with lock:
            cache[key] = value

    # ------------------------------------------------------------------
    # KG cache (full-text search + subgraph export)
    # ------------------------------------------------------------------

    def get_kg(self, key: str) -> Any | None:
        """Look up a KG result. Returns None on a miss."""
        return self._lookup(self._kg_cache, self._kg_lock, self._kg_stats, key)

    def set_kg(self, key: str, value: Any) -> None:
        """Store a KG result (no-op when the region is disabled or value is None)."""
        self._store(self._kg_cache, self._kg_lock, key, value)

    # ------------------------------------------------------------------
    # Embedding cache
    # ------------------------------------------------------------------

    def get_embedding(self, key: str) -> list[float] | None:
        """Look up an embedding vector. Returns None on a miss."""
        return self._lookup(
            self._embedding_cache, self._embedding_lock, self._embedding_stats, key
        )

    def set_embedding(self, key: str, value: list[float]) -> None:
        """Store an embedding vector (no-op when the region is disabled)."""
        self._store(self._embedding_cache, self._embedding_lock, key, value)

    # ------------------------------------------------------------------
    # Stats & management
    # ------------------------------------------------------------------

    def stats(self) -> dict[str, Any]:
        """Return per-region counters, current size and maxsize.

        Each region is snapshotted under its own lock, so its counters and
        size are mutually consistent.
        """
        with self._kg_lock:
            kg = {
                **self._kg_stats.to_dict(),
                "size": len(self._kg_cache),
                "maxsize": self._kg_cache.maxsize,
            }
        with self._embedding_lock:
            embedding = {
                **self._embedding_stats.to_dict(),
                "size": len(self._embedding_cache),
                "maxsize": self._embedding_cache.maxsize,
            }
        return {"kg": kg, "embedding": embedding}

    def clear(self) -> None:
        """Drop all cached entries (stats counters are kept)."""
        with self._kg_lock:
            self._kg_cache.clear()
        with self._embedding_lock:
            self._embedding_cache.clear()
        logger.info("GraphRAG cache cleared")


def make_cache_key(*args: Any) -> str:
    """Build a stable cache key from arbitrary arguments.

    The arguments are JSON-serialized (dict keys sorted; non-JSON values
    fall back to ``str()``) and hashed with SHA-256. The result is a
    fixed-length hex string with no special characters.
    """
    raw = json.dumps(args, sort_keys=True, ensure_ascii=False, default=str)
    return hashlib.sha256(raw.encode("utf-8")).hexdigest()


# Global singleton (lazily initialized).
_cache: GraphRAGCache | None = None
_cache_lock = threading.Lock()


def get_cache() -> GraphRAGCache:
    """Return the process-wide cache, creating it from settings on first use."""
    global _cache
    if _cache is None:
        # Double-checked so concurrent first callers share one instance.
        with _cache_lock:
            if _cache is None:
                _cache = GraphRAGCache.from_settings()
    return _cache