File size: 6,510 Bytes
a4b70d9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
from __future__ import annotations

# curl_cffi is an optional dependency. Import it when possible and record the
# outcome in ``has_curl_cffi``; otherwise define placeholder classes so that the
# module-level names ``AsyncSession`` and ``Response`` always exist.
try:
    from curl_cffi.requests import AsyncSession, Response
    has_curl_cffi = True
except ImportError:
    # Fallback for systems where curl_cffi cannot be imported.
    # NOTE(review): only ImportError is handled here. A fatal crash during the
    # import (e.g. an illegal-instruction signal) terminates the process before
    # any ``except`` clause can run, so this does not guard against that case.
    from typing import Any
    class AsyncSession:
        """Placeholder for curl_cffi's AsyncSession; instantiating it raises ImportError."""
        def __init__(self, *args, **kwargs):
            raise ImportError("curl_cffi is not available on this platform")
    class Response:
        """Placeholder for curl_cffi's Response type; it carries no behavior."""
        pass
    has_curl_cffi = False

# Probe optional curl_cffi features one by one: ``CurlMime`` (multipart bodies,
# used by FormData / StreamSession.request) and ``CurlWsFlag`` (WebSocket frame
# flags, used by WebSocket.send_str). Each flag is False when its name cannot be
# imported (presumably older curl_cffi releases -- confirm), and both are False
# when curl_cffi itself is unavailable.
if has_curl_cffi:
    try:
        from curl_cffi import CurlMime
        has_curl_mime = True
    except ImportError:
        has_curl_mime = False
    try:
        from curl_cffi import CurlWsFlag
        has_curl_ws = True
    except ImportError:
        has_curl_ws = False
else:
    has_curl_mime = False
    has_curl_ws = False
from typing import AsyncGenerator, Any
from functools import partialmethod
import json

if has_curl_cffi:
    class StreamResponse:
        """
        A wrapper class for handling asynchronous streaming responses.

        The wrapped ``inner`` object is whatever was passed to the constructor
        (StreamSession passes the not-yet-awaited result of
        ``AsyncSession.request``). Entering the object as an async context
        manager awaits it, stores the resolved response back into ``inner`` and
        copies common response attributes onto this wrapper.

        Attributes:
            inner (Response): The original Response object.
        """

        def __init__(self, inner: Response) -> None:
            """Initialize the StreamResponse with the provided Response object."""
            self.inner: Response = inner

        async def text(self) -> str:
            """Asynchronously get the response text via ``inner.atext()``."""
            return await self.inner.atext()

        def raise_for_status(self) -> None:
            """Delegate to the wrapped response's ``raise_for_status``."""
            self.inner.raise_for_status()

        async def json(self, **kwargs) -> Any:
            """Asynchronously read the whole body and parse it as JSON.

            Keyword arguments are forwarded to :func:`json.loads`.
            """
            return json.loads(await self.inner.acontent(), **kwargs)

        def iter_lines(self) -> AsyncGenerator[bytes, None]:
            """Asynchronously iterate over the lines of the response."""
            return self.inner.aiter_lines()

        def iter_content(self) -> AsyncGenerator[bytes, None]:
            """Asynchronously iterate over the response content."""
            return self.inner.aiter_content()

        async def sse(self) -> AsyncGenerator[dict, None]:
            """Asynchronously iterate over the Server-Sent Events of the response.

            Yields the JSON-decoded payload of every ``data:`` line. As in the
            SSE specification, a single space after the colon is optional, so
            both ``data: {...}`` and ``data:{...}`` are accepted. Iteration
            stops at a ``[DONE]`` payload; lines that are not ``data:`` fields
            and payloads that cannot be decoded as JSON are skipped.
            """
            async for line in self.iter_lines():
                if not line.startswith(b"data:"):
                    continue
                chunk = line[5:]
                # SSE strips exactly one leading space from a field value.
                if chunk.startswith(b" "):
                    chunk = chunk[1:]
                if chunk == b"[DONE]":
                    break
                try:
                    yield json.loads(chunk)
                except (json.JSONDecodeError, UnicodeDecodeError):
                    # Malformed or non-UTF-8 payloads are skipped, not fatal.
                    continue

        async def __aenter__(self):
            """Await the pending request and expose the resolved response's fields."""
            inner: Response = await self.inner
            self.inner = inner
            self.url = inner.url
            self.method = inner.request.method
            self.request = inner.request
            self.status: int = inner.status_code
            self.reason: str = inner.reason
            self.ok: bool = inner.ok
            self.headers = inner.headers
            self.cookies = inner.cookies
            return self

        async def __aexit__(self, *args):
            """Close the underlying response via ``inner.aclose()``."""
            await self.inner.aclose()

    class StreamSession(AsyncSession):
        """
        An asynchronous session class for handling HTTP requests with streaming.

        Inherits from AsyncSession. Every request is issued with ``stream=True``
        and wrapped in a :class:`StreamResponse`, which is used as an async
        context manager to obtain the resolved response.
        """

        def request(
            self, method: str, url: str, ssl=None, **kwargs
        ) -> StreamResponse:
            """Create and return a StreamResponse object for the given HTTP request.

            Args:
                method: HTTP method name, e.g. ``"GET"``.
                url: Target URL.
                ssl: Forwarded to ``AsyncSession.request`` as ``verify``.
                **kwargs: Forwarded to ``AsyncSession.request``. A ``data`` value
                    that is a ``CurlMime`` instance is sent as ``multipart``
                    instead.

            Returns:
                A StreamResponse wrapping the pending (not yet awaited) request;
                it is awaited when the StreamResponse is entered with
                ``async with``.
            """
            data = kwargs.get("data")
            # Multipart bodies must be passed to curl_cffi as ``multipart``.
            if has_curl_mime and data and isinstance(data, CurlMime):
                kwargs["multipart"] = kwargs.pop("data")
            return StreamResponse(super().request(method, url, stream=True, verify=ssl, **kwargs))

        def ws_connect(self, url, *args, **kwargs):
            """Return a WebSocket wrapper for ``url``; open it with ``async with``.

            Extra positional arguments are accepted but ignored; keyword
            arguments are handed to the WebSocket wrapper.
            """
            return WebSocket(self, url, **kwargs)

        def _ws_connect(self, url, **kwargs):
            """Call the parent ``AsyncSession.ws_connect`` (used by WebSocket.__aenter__)."""
            return super().ws_connect(url, **kwargs)

        # Defining HTTP methods as partial methods of the request method.
        head = partialmethod(request, "HEAD")
        get = partialmethod(request, "GET")
        post = partialmethod(request, "POST")
        put = partialmethod(request, "PUT")
        patch = partialmethod(request, "PATCH")
        delete = partialmethod(request, "DELETE")
        options = partialmethod(request, "OPTIONS")

else:
    # Fallback classes when curl_cffi is not available
    class StreamResponse:
        """Stand-in used when curl_cffi is missing; constructing it always fails."""

        def __init__(self, *_args, **_kwargs):
            reason = "curl_cffi is not available on this platform"
            raise ImportError(reason)

    class StreamSession:
        """Placeholder session for platforms without curl_cffi.

        Any attempt to construct it raises ImportError immediately.
        """

        def __init__(self, *_args, **_kwargs):
            raise ImportError(
                "curl_cffi is not available on this platform"
            )

if has_curl_cffi and has_curl_mime:
    class FormData(CurlMime):
        """Multipart form builder on top of curl_cffi's ``CurlMime``."""

        def add_field(
            self,
            name,
            data=None,
            content_type: str | None = None,
            filename: str | None = None,
        ) -> None:
            """Add one part to the form via ``CurlMime.addpart``.

            Args:
                name: Name of the form field.
                data: Payload of the part, passed through as ``data``.
                content_type: Optional content type for the part.
                filename: Optional filename for the part.
            """
            self.addpart(name, content_type=content_type, filename=filename, data=data)
else:
    class FormData:
        """Placeholder used when curl_cffi's ``CurlMime`` is unavailable.

        Constructing it always raises RuntimeError. It accepts arbitrary
        arguments, like the other fallback classes in this module, so callers
        get the intended RuntimeError instead of a TypeError about the call
        signature.
        """

        def __init__(self, *args, **kwargs) -> None:
            raise RuntimeError("curl_cffi FormData is not available on this platform")

if has_curl_cffi and has_curl_ws:
    class WebSocket:
        """Async context manager around a WebSocket opened through a StreamSession.

        The connection is created lazily in ``__aenter__`` via
        ``session._ws_connect`` and closed in ``__aexit__``. The connection's
        async methods (``arecv``/``asend``/``aclose``) are preferred when
        present; otherwise ``recv``/``send``/``close`` are awaited.
        """

        def __init__(self, session, url, **kwargs) -> None:
            self.session: StreamSession = session
            self.url: str = url
            # ``autoping`` is optional and is never forwarded to the connection;
            # drop it if present (presumably accepted for aiohttp-style callers).
            kwargs.pop("autoping", None)
            # Remaining keyword arguments are passed to ``session._ws_connect``.
            self.options: dict = kwargs

        async def __aenter__(self):
            """Open the connection and return this wrapper."""
            self.inner = await self.session._ws_connect(self.url, **self.options)
            return self

        async def __aexit__(self, *args):
            """Close the connection, preferring ``aclose`` when it exists."""
            close = self.inner.aclose if hasattr(self.inner, "aclose") else self.inner.close
            await close()

        async def receive_str(self, **kwargs) -> str:
            """Receive one message and return its payload decoded as text.

            Keyword arguments are accepted but ignored. Bytes that cannot be
            decoded are dropped (``errors="ignore"``).
            """
            recv = self.inner.arecv if hasattr(self.inner, "arecv") else self.inner.recv
            payload, _ = await recv()
            return payload.decode(errors="ignore")

        async def send_str(self, data: str):
            """Encode ``data`` (UTF-8) and send it with the ``CurlWsFlag.TEXT`` flag."""
            send = self.inner.asend if hasattr(self.inner, "asend") else self.inner.send
            await send(data.encode(), CurlWsFlag.TEXT)
else:
    class WebSocket:
        """Placeholder used when curl_cffi WebSocket support is unavailable.

        Constructing it always raises RuntimeError.
        """

        def __init__(self, *_args, **_kwargs) -> None:
            message = "curl_cffi WebSocket is not available on this platform"
            raise RuntimeError(message)