| """
|
| Redis Cache Implementation for Production
|
| """
|
|
|
| import json
|
| import hashlib
|
| from typing import Any, Optional, Union
|
| from datetime import timedelta
|
| import redis.asyncio as aioredis
|
|
|
| from src.core.config import settings
|
| from src.core.logging import logger
|
| from src.core.exceptions import CacheError
|
|
|
|
|
class RedisCache:
    """Async Redis-backed cache with JSON values and rate-limit counters.

    Every data operation short-circuits (returning ``None``, ``False`` or
    ``0``) while the cache is disabled or no client is connected. Redis
    errors raised inside data operations are logged and swallowed so that a
    cache failure does not propagate to the caller; only :meth:`connect`
    raises.
    """

    def __init__(self):
        # Client is created by connect(); None means "not connected".
        self.redis: Optional[aioredis.Redis] = None
        # Feature flag read once at construction; connect() clears it on failure.
        self.enabled = settings.CACHE_PREDICTIONS

    async def connect(self):
        """Create the Redis client and verify the connection with a PING.

        Does nothing (apart from logging) when the cache is disabled.

        Raises:
            CacheError: If the client cannot be created or the PING fails.
                The cache is marked disabled before raising, and the
                original exception is attached as ``__cause__``.
        """
        if not self.enabled:
            logger.info("Redis cache is disabled")
            return

        try:
            self.redis = await aioredis.from_url(
                settings.REDIS_URL,
                encoding="utf-8",
                decode_responses=True,
                max_connections=50,
            )
            await self.redis.ping()
            logger.info(f"Connected to Redis at {settings.REDIS_HOST}:{settings.REDIS_PORT}")
        except Exception as e:
            logger.error(f"Failed to connect to Redis: {e}")
            self.enabled = False
            # Chain the cause so the original traceback is preserved.
            raise CacheError(f"Redis connection failed: {e}") from e

    async def disconnect(self):
        """Close the Redis client, if one was created, and forget it."""
        if self.redis:
            await self.redis.close()
            # Drop the reference so later operations short-circuit instead
            # of going through a client that has already been closed.
            self.redis = None
            logger.info("Disconnected from Redis")

    def _generate_cache_key(self, prefix: str, data: Union[str, dict]) -> str:
        """Build a cache key of the form ``"<prefix>:<16 hex chars>"``.

        Args:
            prefix: Namespace placed before the hash.
            data: Payload to hash. Dicts are serialized as JSON with sorted
                keys so that key order does not affect the result; anything
                else is converted with ``str()``.

        Returns:
            The prefix joined to the first 16 hex digits of the SHA-256
            digest of the serialized payload.
        """
        if isinstance(data, dict):
            data_str = json.dumps(data, sort_keys=True)
        else:
            data_str = str(data)

        hash_value = hashlib.sha256(data_str.encode()).hexdigest()[:16]
        return f"{prefix}:{hash_value}"

    async def get(self, key: str) -> Optional[Any]:
        """Fetch ``key`` and decode its JSON payload.

        Returns:
            The decoded value, or ``None`` on a miss, when the cache is
            disabled/unconnected, or when any error occurs (logged).
        """
        if not self.enabled or not self.redis:
            return None

        try:
            value = await self.redis.get(key)
            if value:
                logger.debug(f"Cache hit: {key}")
                return json.loads(value)
            logger.debug(f"Cache miss: {key}")
            return None
        except Exception as e:
            logger.warning(f"Cache get error for {key}: {e}")
            return None

    async def set(
        self,
        key: str,
        value: Any,
        ttl: Optional[int] = None
    ) -> bool:
        """Store ``value`` as JSON under ``key`` with an expiry.

        Args:
            key: Redis key to write.
            value: Value to store; it is serialized with ``json.dumps``.
            ttl: Lifetime in seconds. A falsy value (``None`` or ``0``)
                falls back to ``settings.CACHE_TTL``.

        Returns:
            ``True`` if the value was written; ``False`` when the cache is
            disabled/unconnected or an error occurred (logged).
        """
        if not self.enabled or not self.redis:
            return False

        try:
            ttl = ttl or settings.CACHE_TTL
            value_json = json.dumps(value)
            await self.redis.setex(key, ttl, value_json)
            logger.debug(f"Cache set: {key} (TTL: {ttl}s)")
            return True
        except Exception as e:
            logger.warning(f"Cache set error for {key}: {e}")
            return False

    async def delete(self, key: str) -> bool:
        """Delete ``key``.

        Returns:
            ``True`` if the DELETE command was issued without error (whether
            or not the key existed); ``False`` when the cache is
            disabled/unconnected or an error occurred (logged).
        """
        if not self.enabled or not self.redis:
            return False

        try:
            await self.redis.delete(key)
            logger.debug(f"Cache delete: {key}")
            return True
        except Exception as e:
            logger.warning(f"Cache delete error for {key}: {e}")
            return False

    async def get_prediction(
        self,
        model_type: str,
        input_data: Union[str, dict]
    ) -> Optional[dict]:
        """Look up a cached prediction keyed by model type and input."""
        key = self._generate_cache_key(f"pred:{model_type}", input_data)
        return await self.get(key)

    async def set_prediction(
        self,
        model_type: str,
        input_data: Union[str, dict],
        result: dict,
        ttl: Optional[int] = None
    ) -> bool:
        """Cache a prediction under the key used by :meth:`get_prediction`."""
        key = self._generate_cache_key(f"pred:{model_type}", input_data)
        return await self.set(key, result, ttl)

    async def increment_rate_limit(
        self,
        identifier: str,
        window_seconds: int
    ) -> int:
        """Increment the counter for ``identifier`` within a fixed window.

        The expiry is applied only when the counter has none, so the window
        starts at the first request and is not extended by later ones.
        Refreshing the expiry on every increment would keep the counter
        alive for as long as requests keep arriving.

        Args:
            identifier: Caller identity; stored under ``ratelimit:<identifier>``.
            window_seconds: Window length in seconds.

        Returns:
            The counter value after the increment, or ``0`` when the cache
            is disabled/unconnected or an error occurred (logged).
        """
        if not self.enabled or not self.redis:
            return 0

        try:
            key = f"ratelimit:{identifier}"
            pipe = self.redis.pipeline()
            pipe.incr(key)
            pipe.ttl(key)
            count, remaining = await pipe.execute()
            # TTL reports -1 for a key that exists without an expiry: either
            # INCR just created it, or an earlier EXPIRE never happened.
            if remaining == -1:
                await self.redis.expire(key, window_seconds)
            logger.debug(f"Rate limit count for {identifier}: {count}")
            return count
        except Exception as e:
            logger.warning(f"Rate limit increment error: {e}")
            return 0

    async def get_rate_limit_count(self, identifier: str) -> int:
        """Return the current counter for ``identifier``.

        Returns:
            The stored count, or ``0`` when there is no counter, the cache
            is disabled/unconnected, or an error occurred (logged).
        """
        if not self.enabled or not self.redis:
            return 0

        try:
            key = f"ratelimit:{identifier}"
            count = await self.redis.get(key)
            return int(count) if count else 0
        except Exception as e:
            logger.warning(f"Rate limit get error: {e}")
            return 0

    async def clear_all(self) -> bool:
        """Flush the entire current Redis database (use with caution!).

        This issues FLUSHDB, so it removes every key in the selected
        database, not only keys written through this class.

        Returns:
            ``True`` on success; ``False`` when the cache is
            disabled/unconnected or an error occurred (logged).
        """
        if not self.enabled or not self.redis:
            return False

        try:
            await self.redis.flushdb()
            logger.warning("All cache cleared!")
            return True
        except Exception as e:
            logger.error(f"Cache clear error: {e}")
            return False
|
|
|
|
|
|
|
# Shared module-level cache instance used by the ``cached`` decorator and by
# importers of this module. Its operations short-circuit until connect() has
# been awaited.
cache: RedisCache = RedisCache()
|
|
|
|
|
|
|
def cached(prefix: str, ttl: Optional[int] = None):
    """Decorator that caches an async function's result in the shared cache.

    The cache key is derived from ``prefix`` and the string forms of the
    positional and keyword arguments, so the arguments need a stable
    ``str()`` representation for lookups to hit.

    Args:
        prefix: Namespace for the generated cache keys.
        ttl: Lifetime in seconds passed through to ``cache.set``.

    Returns:
        A decorator for coroutine functions. The wrapped function keeps its
        original ``__name__``, ``__doc__`` and other metadata.
    """
    # Function-scope import keeps this block self-contained.
    from functools import wraps

    def decorator(func):
        @wraps(func)
        async def wrapper(*args, **kwargs):
            cache_data = {"args": str(args), "kwargs": str(kwargs)}
            cache_key = cache._generate_cache_key(prefix, cache_data)

            # A None from the cache is treated as a miss, so the function is
            # called again in that case.
            cached_result = await cache.get(cache_key)
            if cached_result is not None:
                return cached_result

            result = await func(*args, **kwargs)

            await cache.set(cache_key, result, ttl)

            return result
        return wrapper
    return decorator
|
|
|
|
|
if __name__ == "__main__":
    import asyncio

    async def test_cache():
        """Manual smoke test: exercise set/get, predictions and rate limiting.

        Needs a Redis server reachable with the configured settings. The
        client is always disconnected, even when one of the steps raises.
        """
        try:
            await cache.connect()

            # Plain key/value round trip.
            await cache.set("test_key", {"value": 123}, ttl=60)
            result = await cache.get("test_key")
            print(f"Retrieved: {result}")

            # Prediction round trip, keyed by model type and input.
            await cache.set_prediction(
                "deepfake",
                {"image": "test.jpg"},
                {"prediction": "FAKE", "confidence": 0.95},
                ttl=300
            )

            cached_pred = await cache.get_prediction("deepfake", {"image": "test.jpg"})
            print(f"Cached prediction: {cached_pred}")

            # Rate-limit counter over a 60-second window.
            for i in range(5):
                count = await cache.increment_rate_limit("user:123", 60)
                print(f"Request {i+1}: Rate limit count = {count}")
        finally:
            # disconnect() is a no-op when no client was created.
            await cache.disconnect()

    asyncio.run(test_cache())
|
|
|