File size: 4,319 Bytes
7db0ae4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
# +-----------------------------------------------+
# |                                               |
# |           Give Feedback / Get Help            |
# | https://github.com/BerriAI/litellm/issues/new |
# |                                               |
# +-----------------------------------------------+
#
#  Thank you users! We ❤️ you! - Krrish & Ishaan

"""
Module containing "timeout" decorator for sync and async callables.
"""

import asyncio

from concurrent import futures
from inspect import iscoroutinefunction
from functools import wraps
from threading import Thread
from litellm.exceptions import Timeout


def timeout(timeout_duration: float = 0.0, exception_to_raise=Timeout):
    """
    Wraps a function to raise the specified exception if execution time
    is greater than the specified timeout.

    Works with both synchronous and asynchronous callables, but with synchronous ones will introduce
    some overhead due to the backend use of threads and asyncio.

        :param float timeout_duration: Timeout duration in seconds. If ``None`` the
            callable won't time out. May be overridden per-call via the
            ``force_timeout`` or ``request_timeout`` keyword arguments.
        :param exception_to_raise: Exception class to raise when the callable times
            out. Defaults to litellm's ``Timeout``.
        :return: The decorated function.
        :rtype: callable
    """

    def _effective_timeout(kwargs):
        # Per-call overrides: force_timeout wins over request_timeout,
        # which wins over the decorator-level default.
        if kwargs.get("force_timeout") is not None:
            return kwargs["force_timeout"]
        if kwargs.get("request_timeout") is not None:
            return kwargs["request_timeout"]
        return timeout_duration

    def _raise_timeout(args, kwargs, duration):
        # By convention the model name is the first positional argument,
        # falling back to the ``model`` keyword argument.
        model = args[0] if len(args) > 0 else kwargs["model"]
        raise exception_to_raise(
            f"A timeout error occurred. The function call took longer than {duration} second(s).",
            model=model,  # [TODO]: replace with logic for parsing out llm provider from model name
            llm_provider="openai",
        )

    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            # Run the sync callable on a dedicated event-loop thread so this
            # thread can enforce the timeout via the returned future.
            async def async_func():
                return func(*args, **kwargs)

            thread = _LoopWrapper()
            thread.start()
            future = asyncio.run_coroutine_threadsafe(async_func(), thread.loop)
            local_timeout_duration = _effective_timeout(kwargs)
            try:
                result = future.result(timeout=local_timeout_duration)
            except futures.TimeoutError:
                thread.stop_loop()
                _raise_timeout(args, kwargs, local_timeout_duration)
            thread.stop_loop()
            return result

        @wraps(func)
        async def async_wrapper(*args, **kwargs):
            local_timeout_duration = _effective_timeout(kwargs)
            try:
                # BUG FIX: previously passed ``timeout_duration`` here, which
                # ignored per-call force_timeout/request_timeout overrides.
                return await asyncio.wait_for(
                    func(*args, **kwargs), timeout=local_timeout_duration
                )
            except asyncio.TimeoutError:
                _raise_timeout(args, kwargs, local_timeout_duration)

        if iscoroutinefunction(func):
            return async_wrapper
        return wrapper

    return decorator


class _LoopWrapper(Thread):
    def __init__(self):
        super().__init__(daemon=True)
        self.loop = asyncio.new_event_loop()

    def run(self) -> None:
        try:
            self.loop.run_forever()
            self.loop.call_soon_threadsafe(self.loop.close)
        except Exception as e:
            # Log exception here
            pass
        finally:
            self.loop.close()
            asyncio.set_event_loop(None)

    def stop_loop(self):
        for task in asyncio.all_tasks(self.loop):
            task.cancel()
        self.loop.call_soon_threadsafe(self.loop.stop)