diff --git a/backend/open_webui/utils/redis.py b/backend/open_webui/utils/redis.py index 70ae18f115..22a86f4796 100644 --- a/backend/open_webui/utils/redis.py +++ b/backend/open_webui/utils/redis.py @@ -1,6 +1,68 @@ -import socketio +import inspect from urllib.parse import urlparse -from typing import Optional + +import redis + + +MAX_RETRY_COUNT = 2 + +class SentinelRedisProxy: + def __init__(self, sentinel, service, *, async_mode: bool = True, **kw): + self._sentinel = sentinel + self._service = service + self._kw = kw + self._async_mode = async_mode + + def _master(self): + return self._sentinel.master_for(self._service, **self._kw) + + def __getattr__(self, item): + master = self._master() + orig_attr = getattr(master, item) + + if not callable(orig_attr): + return orig_attr + + FACTORY_METHODS = {"pipeline", "pubsub", "monitor", "client", "transaction"} + if item in FACTORY_METHODS: + return orig_attr + + if self._async_mode: + + async def _wrapped(*args, **kwargs): + for i in range(MAX_RETRY_COUNT): + try: + method = getattr(self._master(), item) + result = method(*args, **kwargs) + if inspect.iscoroutine(result): + return await result + return result + except ( + redis.exceptions.ConnectionError, + redis.exceptions.ReadOnlyError, + ) as e: + if i < MAX_RETRY_COUNT - 1: + continue + raise e from e + + return _wrapped + + else: + + def _wrapped(*args, **kwargs): + for i in range(MAX_RETRY_COUNT): + try: + method = getattr(self._master(), item) + return method(*args, **kwargs) + except ( + redis.exceptions.ConnectionError, + redis.exceptions.ReadOnlyError, + ) as e: + if i < MAX_RETRY_COUNT - 1: + continue + raise e from e + + return _wrapped def parse_redis_service_url(redis_url): @@ -34,7 +96,11 @@ def get_redis_connection( password=redis_config["password"], decode_responses=decode_responses, ) - return sentinel.master_for(redis_config["service"]) + return SentinelRedisProxy( + sentinel, + redis_config["service"], + async_mode=async_mode, + ) elif redis_url: return redis.from_url(redis_url, decode_responses=decode_responses) else: @@ -52,7 +118,11 @@ def get_redis_connection( password=redis_config["password"], decode_responses=decode_responses, ) - return sentinel.master_for(redis_config["service"]) + return SentinelRedisProxy( + sentinel, + redis_config["service"], + async_mode=async_mode, + ) elif redis_url: return redis.Redis.from_url(redis_url, decode_responses=decode_responses) else: