User1342 commited on
Commit
1cd7d07
1 Parent(s): fdb36dd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +2165 -317
app.py CHANGED
@@ -1,333 +1,2181 @@
1
- #!/usr/bin/env python
2
- # coding: utf-8
3
- import json
4
- import flask
5
- import os
6
- import re
7
- import time
8
- from random import random
9
  import socket
10
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  from threading import Thread
12
- from time import sleep
13
-
14
- test_html = '''
15
- <!-- Header -->
16
- <header class="w3-display-container w3-content w3-wide" style="max-width:1500px;" id="home">
17
- <img class="w3-image" src="https://cdn.pixabay.com/photo/2018/12/10/16/22/city-3867295_960_720.png" alt="Architecture" width="1500" height="800">
18
- <div class="w3-display-middle w3-margin-top w3-center">
19
- <h1 class="w3-xxlarge w3-text-white"><span class="w3-padding w3-black w3-opacity-min"><b>WATCH</b></span> <span class="w3-hide-small w3-text-dark-grey">Tower</span></h1>
20
- </div>
21
- </header>
22
-
23
- <!-- Container (About Section) -->
24
- <div class="w3-content w3-container w3-padding-64" id="about">
25
- <h3 class="w3-center">Block Violent Content Before It Reaches Your Feed</h3>
26
- <p class="w3-center"><em>WatchTower identifies, blocks, and filters out violent and radical content before it reaches your Twitter feed.
27
- </em></p>
28
- <br>
29
- <p>WatchTower works to protect you from violent, misinformation, hate speech and other malicious communication by using a suite of machine learning models to identify user accounts that post content that commonly falls into these categories. WatchTower is broken down into two components, the first utilises the Twitter streaming API and applies a suite of machine learning models to identify users that commonly post malicious information, while the second element provides a web UI where users can authenticaate with Twitter and tailor the types and thresholds for the accounts they block. </p>
30
- <br>
31
- <p> WatchTower was developed solely by James Stevenson and primarily uses Pinpoint, a machine learning model also developed by James. The future roadmap sees WatchTower incoperate other models for identifying contrent such as misinformation and hate speech. More on Pinpoint and the model WatchTower uses to identify violent extremism can be seen below.</p>
32
-
33
- <p class="w3-large w3-center w3-padding-16">Model Accuracy:</p>
34
- <p class="w3-center"><em>Machine learning models can be validated based on several statistics. These statistics for Pinpoint the main ML model used by WatchTower can be seen below. </p>
35
- <br>
36
- <p class="w3-wide"><i class="fa fa-camera"></i>Accuracy</p>
37
- <div class="w3-light-grey">
38
- <div class="w3-container w3-padding-small w3-dark-grey w3-center" style="width:73%">73%</div>
39
- </div>
40
- <p class="w3-wide"><i class="fa fa-laptop"></i>Recall</p>
41
- <div class="w3-light-grey">
42
- <div class="w3-container w3-padding-small w3-dark-grey w3-center" style="width:62%">62%</div>
43
- </div>
44
- <p class="w3-wide"><i class="fa fa-photo"></i>Precision</p>
45
- <div class="w3-light-grey">
46
- <div class="w3-container w3-padding-small w3-dark-grey w3-center" style="width:78%">78%</div>
47
- </div>
48
- <p class="w3-wide"><i class="fa fa-photo"></i>F-Measure</p>
49
- <div class="w3-light-grey">
50
- <div class="w3-container w3-padding-small w3-dark-grey w3-center" style="width:69%">69%</div>
51
- </div>
52
- </div>
53
-
54
- <div class="w3-row w3-center w3-dark-grey w3-padding-16">
55
- <div class="w3-quarter w3-section">
56
- <span class="w3-xlarge">14+</span><br>
57
- Partners
58
- </div>
59
- <div class="w3-quarter w3-section">
60
- <span class="w3-xlarge">55+</span><br>
61
- Projects Done
62
- </div>
63
- <div class="w3-quarter w3-section">
64
- <span class="w3-xlarge">89+</span><br>
65
- Happy Clients
66
- </div>
67
- <div class="w3-quarter w3-section">
68
- <span class="w3-xlarge">150+</span><br>
69
- Meetings
70
- </div>
71
- </div>
72
- <br>
73
- <!-- Container (Portfolio Section) -->
74
- <div class="w3-content w3-container w3-padding-64" id="portfolio">
75
- <h3 class="w3-center">Chirp Development Challenge 2022</h3>
76
- <p class="w3-center"><em>WatchTower was developed for the Chirp 2022 Twitter API Developer Challenge</em></p>
77
- </div><p> Watchtower was developed solely by James Stevenson for the Chirp 2022 Twitter API Developer Challenge. More infomration of this can be found below.</p>
78
- <br>
79
- <img class="w3-image" src="https://cdn.cms-twdigitalassets.com/content/dam/developer-twitter/redesign-2021-images/blog2022/chirp/Chirp-Hero-Banner.jpg.twimg.1920.jpg" alt="Architecture" width="1500" height="800">
80
- <br>
81
- <!-- Modal for full size images on click-->
82
- <div id="modal01" class="w3-modal w3-black" onclick="this.style.display='none'">
83
- <span class="w3-button w3-large w3-black w3-display-topright" title="Close Modal Image"><i class="fa fa-remove"></i></span>
84
- <div class="w3-modal-content w3-animate-zoom w3-center w3-transparent w3-padding-64">
85
- <img id="img01" class="w3-image">
86
- <p id="caption" class="w3-opacity w3-large"></p>
87
- </div>
88
- </div>
89
-
90
- <script>
91
- // Modal Image Gallery
92
- function onClick(element) {
93
- document.getElementById("img01").src = element.src;
94
- document.getElementById("modal01").style.display = "block";
95
- var captionText = document.getElementById("caption");
96
- captionText.innerHTML = element.alt;
97
- }
98
-
99
- // Change style of navbar on scroll
100
- window.onscroll = function() {myFunction()};
101
- function myFunction() {
102
- var navbar = document.getElementById("myNavbar");
103
- if (document.body.scrollTop > 100 || document.documentElement.scrollTop > 100) {
104
- navbar.className = "w3-bar" + " w3-card" + " w3-animate-top" + " w3-white";
105
- } else {
106
- navbar.className = navbar.className.replace(" w3-card w3-animate-top w3-white", "");
107
- }
108
- }
109
-
110
- // Used to toggle the menu on small screens when clicking on the menu button
111
- function toggleFunction() {
112
- var x = document.getElementById("navDemo");
113
- if (x.className.indexOf("w3-show") == -1) {
114
- x.className += " w3-show";
115
- } else {
116
- x.className = x.className.replace(" w3-show", "");
117
- }
118
- }
119
- </script>
120
-
121
- </body>
122
- </html>
123
-
124
-
125
-
126
- '''
127
-
128
- import gradio as gr
129
- import tweepy
130
- from fastapi import FastAPI, Request
131
-
132
- consumer_token = os.getenv('CONSUMER_TOKEN')
133
- consumer_secret = os.getenv('CONSUMER_SECRET')
134
- my_access_token = os.getenv('ACCESS_TOKEN')
135
- my_access_secret = os.getenv('ACCESS_SECRET')
136
- bearer = os.getenv('BEARER')
137
- client_id = os.getenv('CLIENT_ID')
138
- client_secret = os.getenv('CLIENT_SECRET')
139
-
140
- oauth1_user_handler = tweepy.OAuth2UserHandler(client_id=client_id,
141
- client_secret=client_secret,
142
- redirect_uri="https://hf.space/embed/User1342/WatchTower/",
143
- scope=["block.write"]
144
 
 
 
 
 
 
 
 
 
145
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
146
 
147
- target_website = oauth1_user_handler.get_authorization_url()
148
- print(target_website)
149
-
150
- block = gr.Blocks(css=".container { max-width: 800px; margin: auto; }")
151
-
152
- chat_history = []
153
-
154
-
155
- def get_client_from_tokens():
156
- oauth1_user_handler = tweepy.OAuth2UserHandler(client_id=client_id,
157
- client_secret=client_secret,
158
- redirect_uri="https://hf.space/embed/User1342/WatchTower/",
159
- scope=["block.write"])
160
-
161
- for connection in block.server.server_state.connections:
162
- # connection_app_id = connection.app.app.blocks.app_id
163
- # if active_app_id == connection_app_id:
164
- # print("Its a match")
165
- url = None
166
- if connection.headers != None:
167
- for header in connection.headers:
168
- header = header[1].decode()
169
- if "code" in header:
170
- url = header
171
- print(header)
172
- print("urls {}".format(url))
173
- their_client = None
174
- if url != None:
175
- access_token = oauth1_user_handler.fetch_token(
176
- url
 
 
 
 
 
 
 
 
 
 
177
  )
178
- their_client = tweepy.Client(access_token)
179
 
180
- return their_client
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
181
 
182
- def block_users(client, threshold, dataset):
183
- num_users_blocked = 0
 
 
184
 
185
- for filename in os.listdir("users"):
186
- filename = os.path.join("users", filename)
 
 
 
187
 
188
- user_file = open(filename, "r")
189
- users = json.load(user_file)
 
190
 
191
- for user in users:
192
- if threshold >= user["threshold"]:
 
 
193
 
194
- user_id = str(user["username"])
 
 
 
195
 
196
- finished = False
197
- while not finished:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
198
  try:
199
- client.block(target_user_id=user_id)
200
- except tweepy.errors.TooManyRequests as e:
201
- print(e)
202
- time.sleep(240)
203
- continue
204
- finished = True
205
- me = client.get_me()
206
- print("{} blocked {}".format(me.data["username"], user))
207
- num_users_blocked = num_users_blocked + 1
208
-
209
- return num_users_blocked
210
-
211
- username_populated = False
212
-
213
-
214
- def chat(radio_score=None, selected_option=None):
215
- client = get_client_from_tokens()
216
- history = []
217
-
218
- if radio_score != None and selected_option != None:
219
- response = "no blocking"
220
- if client != None:
221
- chat_history.append(
222
- ["Model tuned to a '{}%' threshold and is using the '{}' dataset.".format(radio_score, selected_option),
223
- "{} Account blocking initialised".format(selected_option.capitalize())])
224
- num_users_blocked = block_users(client, radio_score, selected_option)
225
- chat_history.append(
226
- ["Blocked {} user account(s).".format(num_users_blocked), "Thank you for using Watchtower."])
227
- elif radio_score != None or selected_option != None:
228
- chat_history.append(["Initialisation error!", "Please tune the model by using the above options"])
229
-
230
- return chat_history
231
-
232
-
233
- def infer(prompt):
234
- pass
235
-
236
-
237
- have_initialised = False
238
- client = None
239
- name = None
240
-
241
-
242
- def changed_tab():
243
- global have_initialised
244
- global chatbot
245
- global chat_history
246
- global client
247
- global name
248
-
249
- name = "no username"
250
-
251
- chat_history = [
252
- ["Welcome to Watchtower.".format(name), "Log in via Twitter and configure your blocking options above."]]
253
-
254
- if client != None and name != "no username":
255
- chat_history = [["Welcome {}".format(name), "Initialising WatchTower"]]
256
-
257
- print("changed tabs - {}".format(name))
258
- chatbot.value = chat_history
259
- chatbot.update(value=chat_history)
260
-
261
-
262
- client = get_client_from_tokens()
263
- name = client.get_me().data.name
264
- have_initialised = True
265
- chat_history = [["Welcome {}".format(name), "Initialising WatchTower"]]
266
-
267
- chatbot.value = chat_history
268
- chatbot.update(value=chat_history)
269
-
270
- with block:
271
- gr.HTML('''
272
-
273
- <meta name="viewport" content="width=device-width, initial-scale=1">
274
- <link rel="stylesheet" href="https://www.w3schools.com/w3css/4/w3.css">
275
- <!-- Navbar (sit on top) -->
276
- <div class="w3-top">
277
- <div class="w3-bar w3-white w3-wide w3-padding w3-card">
278
- <p class="w3-bar-item w3-button"><b>WATCH</b> Tower</p>
279
- </div>
280
- </div>
281
- ''')
282
- gr.HTML("<center><p><br></p></center>")
283
-
284
- # todo check if user signed in
285
-
286
- user_message = "Log in via Twitter and configure your blocking options above."
287
-
288
- chat_history.append(["Welcome to Watchtower.", user_message])
289
- tabs = gr.Tabs()
290
- with tabs:
291
- intro_tab = gr.TabItem("Introduction")
292
- with intro_tab:
293
- gr.HTML(test_html)
294
-
295
- prediction_tab = gr.TabItem("Getting Started")
296
- with prediction_tab:
297
- gr.HTML('''
298
- <header class="w3-display-container w3-content w3-wide" style="max-height:250px;" id="home">
299
- <img class="w3-image" src="https://cdn.pixabay.com/photo/2018/12/10/16/22/city-3867295_960_720.png" alt="Architecture" width="1500" height="800">
300
- <div class="w3-display-middle w3-margin-top w3-center">
301
- <h1 class="w3-xxlarge w3-text-white"><span class="w3-padding w3-black w3-opacity-min"><b>WATCH</b></span> <span class="w3-hide-small w3-text-dark-grey">Tower</span></h1>
302
- </div>
303
- </header>
304
- ''')
305
- with gr.Group():
306
- with gr.Box():
307
- with gr.Row().style(mobile_collapse=False, equal_height=True):
308
- gr.HTML(
309
- value='<a href={}><img src="https://cdn.cms-twdigitalassets.com/content/dam/developer-twitter/auth-docs/sign-in-with-twitter-gray.png.twimg.1920.png" alt="Log In With Twitter"></a><br>'.format(
310
- target_website))
311
- with gr.Row().style(mobile_collapse=False, equal_height=True):
312
- radio = gr.CheckboxGroup(value="Violent", choices=["Violent", "Hate Speech", "Misinformation"],
313
- interactive=False, label="Behaviour To Block")
314
-
315
- slider = gr.Slider(value=80, label="Threshold Certainty Tolerance")
316
-
317
- chatbot = gr.Chatbot(value=chat_history, label="Watchtower Output").style()
318
- btn = gr.Button("Run WatchTower").style(full_width=True)
319
-
320
-
321
- btn.click(fn=chat, inputs=[slider, radio], outputs=chatbot)
322
- tabs.change(fn=changed_tab, inputs=None, outputs=None)
323
-
324
- gr.Markdown(
325
- """___
326
- <p style='text-align: center'>
327
- Created by <a href="https://twitter.com/borisdayma" target="_blank"James Stevenson</a> et al. 2021-2022
328
- <br/>
329
- <a href="https://github.com/CartographerLabs/Pinpoint" target="_blank">GitHub</a>
330
- </p>"""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
331
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
332
 
333
- block.launch(enable_queue=False)
 
1
+ import array
2
+ import asyncio
3
+ import concurrent.futures
4
+ import math
 
 
 
 
5
  import socket
6
+ import sys
7
+ from asyncio.base_events import _run_until_complete_cb # type: ignore[attr-defined]
8
+ from collections import OrderedDict, deque
9
+ from concurrent.futures import Future
10
+ from contextvars import Context, copy_context
11
+ from dataclasses import dataclass
12
+ from functools import partial, wraps
13
+ from inspect import (
14
+ CORO_RUNNING,
15
+ CORO_SUSPENDED,
16
+ GEN_RUNNING,
17
+ GEN_SUSPENDED,
18
+ getcoroutinestate,
19
+ getgeneratorstate,
20
+ )
21
+ from io import IOBase
22
+ from os import PathLike
23
+ from queue import Queue
24
+ from socket import AddressFamily, SocketKind
25
  from threading import Thread
26
+ from types import TracebackType
27
+ from typing import (
28
+ IO,
29
+ Any,
30
+ AsyncGenerator,
31
+ Awaitable,
32
+ Callable,
33
+ Collection,
34
+ Coroutine,
35
+ Deque,
36
+ Dict,
37
+ Generator,
38
+ Iterable,
39
+ List,
40
+ Mapping,
41
+ Optional,
42
+ Sequence,
43
+ Set,
44
+ Tuple,
45
+ Type,
46
+ TypeVar,
47
+ Union,
48
+ cast,
49
+ )
50
+ from weakref import WeakKeyDictionary
51
+
52
+ import sniffio
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53
 
54
+ from .. import CapacityLimiterStatistics, EventStatistics, TaskInfo, abc
55
+ from .._core._compat import DeprecatedAsyncContextManager, DeprecatedAwaitable
56
+ from .._core._eventloop import claim_worker_thread, threadlocals
57
+ from .._core._exceptions import (
58
+ BrokenResourceError,
59
+ BusyResourceError,
60
+ ClosedResourceError,
61
+ EndOfStream,
62
  )
63
+ from .._core._exceptions import ExceptionGroup as BaseExceptionGroup
64
+ from .._core._exceptions import WouldBlock
65
+ from .._core._sockets import GetAddrInfoReturnType, convert_ipv6_sockaddr
66
+ from .._core._synchronization import CapacityLimiter as BaseCapacityLimiter
67
+ from .._core._synchronization import Event as BaseEvent
68
+ from .._core._synchronization import ResourceGuard
69
+ from .._core._tasks import CancelScope as BaseCancelScope
70
+ from ..abc import IPSockAddrType, UDPPacketType
71
+ from ..lowlevel import RunVar
72
+
73
+ if sys.version_info >= (3, 8):
74
+
75
+ def get_coro(task: asyncio.Task) -> Union[Generator, Awaitable[Any]]:
76
+ return task.get_coro()
77
+
78
+ else:
79
+
80
+ def get_coro(task: asyncio.Task) -> Union[Generator, Awaitable[Any]]:
81
+ return task._coro
82
+
83
+
84
+ if sys.version_info >= (3, 7):
85
+ from asyncio import all_tasks, create_task, current_task, get_running_loop
86
+ from asyncio import run as native_run
87
+
88
+ def _get_task_callbacks(task: asyncio.Task) -> Iterable[Callable]:
89
+ return [cb for cb, context in task._callbacks] # type: ignore[attr-defined]
90
+
91
+ else:
92
+ _T = TypeVar("_T")
93
+
94
+ def _get_task_callbacks(task: asyncio.Task) -> Iterable[Callable]:
95
+ return task._callbacks
96
+
97
+ def native_run(main, *, debug=False):
98
+ # Snatched from Python 3.7
99
+ from asyncio import coroutines, events, tasks
100
+
101
+ def _cancel_all_tasks(loop):
102
+ to_cancel = all_tasks(loop)
103
+ if not to_cancel:
104
+ return
105
+
106
+ for task in to_cancel:
107
+ task.cancel()
108
+
109
+ loop.run_until_complete(
110
+ tasks.gather(*to_cancel, loop=loop, return_exceptions=True)
111
+ )
112
+
113
+ for task in to_cancel:
114
+ if task.cancelled():
115
+ continue
116
+ if task.exception() is not None:
117
+ loop.call_exception_handler(
118
+ {
119
+ "message": "unhandled exception during asyncio.run() shutdown",
120
+ "exception": task.exception(),
121
+ "task": task,
122
+ }
123
+ )
124
+
125
+ if events._get_running_loop() is not None:
126
+ raise RuntimeError(
127
+ "asyncio.run() cannot be called from a running event loop"
128
+ )
129
+
130
+ if not coroutines.iscoroutine(main):
131
+ raise ValueError(f"a coroutine was expected, got {main!r}")
132
+
133
+ loop = events.new_event_loop()
134
+ try:
135
+ events.set_event_loop(loop)
136
+ loop.set_debug(debug)
137
+ return loop.run_until_complete(main)
138
+ finally:
139
+ try:
140
+ _cancel_all_tasks(loop)
141
+ loop.run_until_complete(loop.shutdown_asyncgens())
142
+ finally:
143
+ events.set_event_loop(None)
144
+ loop.close()
145
+
146
+ def create_task(
147
+ coro: Union[Generator[Any, None, _T], Awaitable[_T]], *, name: object = None
148
+ ) -> asyncio.Task:
149
+ return get_running_loop().create_task(coro)
150
+
151
+ def get_running_loop() -> asyncio.AbstractEventLoop:
152
+ loop = asyncio._get_running_loop()
153
+ if loop is not None:
154
+ return loop
155
+ else:
156
+ raise RuntimeError("no running event loop")
157
+
158
+ def all_tasks(
159
+ loop: Optional[asyncio.AbstractEventLoop] = None,
160
+ ) -> Set[asyncio.Task]:
161
+ """Return a set of all tasks for the loop."""
162
+ from asyncio import Task
163
+
164
+ if loop is None:
165
+ loop = get_running_loop()
166
+
167
+ return {t for t in Task.all_tasks(loop) if not t.done()}
168
+
169
+ def current_task(
170
+ loop: Optional[asyncio.AbstractEventLoop] = None,
171
+ ) -> Optional[asyncio.Task]:
172
+ if loop is None:
173
+ loop = get_running_loop()
174
+
175
+ return asyncio.Task.current_task(loop)
176
+
177
+
178
+ T_Retval = TypeVar("T_Retval")
179
+
180
+ # Check whether there is native support for task names in asyncio (3.8+)
181
+ _native_task_names = hasattr(asyncio.Task, "get_name")
182
+
183
+
184
+ _root_task: RunVar[Optional[asyncio.Task]] = RunVar("_root_task")
185
+
186
+
187
+ def find_root_task() -> asyncio.Task:
188
+ root_task = _root_task.get(None)
189
+ if root_task is not None and not root_task.done():
190
+ return root_task
191
+
192
+ # Look for a task that has been started via run_until_complete()
193
+ for task in all_tasks():
194
+ if task._callbacks and not task.done():
195
+ for cb in _get_task_callbacks(task):
196
+ if (
197
+ cb is _run_until_complete_cb
198
+ or getattr(cb, "__module__", None) == "uvloop.loop"
199
+ ):
200
+ _root_task.set(task)
201
+ return task
202
+
203
+ # Look up the topmost task in the AnyIO task tree, if possible
204
+ task = cast(asyncio.Task, current_task())
205
+ state = _task_states.get(task)
206
+ if state:
207
+ cancel_scope = state.cancel_scope
208
+ while cancel_scope and cancel_scope._parent_scope is not None:
209
+ cancel_scope = cancel_scope._parent_scope
210
+
211
+ if cancel_scope is not None:
212
+ return cast(asyncio.Task, cancel_scope._host_task)
213
+
214
+ return task
215
+
216
+
217
+ def get_callable_name(func: Callable) -> str:
218
+ module = getattr(func, "__module__", None)
219
+ qualname = getattr(func, "__qualname__", None)
220
+ return ".".join([x for x in (module, qualname) if x])
221
+
222
+
223
+ #
224
+ # Event loop
225
+ #
226
+
227
+ _run_vars = (
228
+ WeakKeyDictionary()
229
+ ) # type: WeakKeyDictionary[asyncio.AbstractEventLoop, Any]
230
+
231
+ current_token = get_running_loop
232
+
233
+
234
+ def _task_started(task: asyncio.Task) -> bool:
235
+ """Return ``True`` if the task has been started and has not finished."""
236
+ coro = cast(Coroutine[Any, Any, Any], get_coro(task))
237
+ try:
238
+ return getcoroutinestate(coro) in (CORO_RUNNING, CORO_SUSPENDED)
239
+ except AttributeError:
240
+ try:
241
+ return getgeneratorstate(cast(Generator, coro)) in (
242
+ GEN_RUNNING,
243
+ GEN_SUSPENDED,
244
+ )
245
+ except AttributeError:
246
+ # task coro is async_genenerator_asend https://bugs.python.org/issue37771
247
+ raise Exception(f"Cannot determine if task {task} has started or not")
248
+
249
+
250
+ def _maybe_set_event_loop_policy(
251
+ policy: Optional[asyncio.AbstractEventLoopPolicy], use_uvloop: bool
252
+ ) -> None:
253
+ # On CPython, use uvloop when possible if no other policy has been given and if not
254
+ # explicitly disabled
255
+ if policy is None and use_uvloop and sys.implementation.name == "cpython":
256
+ try:
257
+ import uvloop
258
+ except ImportError:
259
+ pass
260
+ else:
261
+ # Test for missing shutdown_default_executor() (uvloop 0.14.0 and earlier)
262
+ if not hasattr(
263
+ asyncio.AbstractEventLoop, "shutdown_default_executor"
264
+ ) or hasattr(uvloop.loop.Loop, "shutdown_default_executor"):
265
+ policy = uvloop.EventLoopPolicy()
266
+
267
+ if policy is not None:
268
+ asyncio.set_event_loop_policy(policy)
269
+
270
+
271
+ def run(
272
+ func: Callable[..., Awaitable[T_Retval]],
273
+ *args: object,
274
+ debug: bool = False,
275
+ use_uvloop: bool = False,
276
+ policy: Optional[asyncio.AbstractEventLoopPolicy] = None,
277
+ ) -> T_Retval:
278
+ @wraps(func)
279
+ async def wrapper() -> T_Retval:
280
+ task = cast(asyncio.Task, current_task())
281
+ task_state = TaskState(None, get_callable_name(func), None)
282
+ _task_states[task] = task_state
283
+ if _native_task_names:
284
+ task.set_name(task_state.name)
285
+
286
+ try:
287
+ return await func(*args)
288
+ finally:
289
+ del _task_states[task]
290
+
291
+ _maybe_set_event_loop_policy(policy, use_uvloop)
292
+ return native_run(wrapper(), debug=debug)
293
+
294
+
295
+ #
296
+ # Miscellaneous
297
+ #
298
+
299
+ sleep = asyncio.sleep
300
+
301
+
302
+ #
303
+ # Timeouts and cancellation
304
+ #
305
+
306
+ CancelledError = asyncio.CancelledError
307
+
308
+
309
+ class CancelScope(BaseCancelScope):
310
+ def __new__(
311
+ cls, *, deadline: float = math.inf, shield: bool = False
312
+ ) -> "CancelScope":
313
+ return object.__new__(cls)
314
+
315
+ def __init__(self, deadline: float = math.inf, shield: bool = False):
316
+ self._deadline = deadline
317
+ self._shield = shield
318
+ self._parent_scope: Optional[CancelScope] = None
319
+ self._cancel_called = False
320
+ self._active = False
321
+ self._timeout_handle: Optional[asyncio.TimerHandle] = None
322
+ self._cancel_handle: Optional[asyncio.Handle] = None
323
+ self._tasks: Set[asyncio.Task] = set()
324
+ self._host_task: Optional[asyncio.Task] = None
325
+ self._timeout_expired = False
326
+
327
+ def __enter__(self) -> "CancelScope":
328
+ if self._active:
329
+ raise RuntimeError(
330
+ "Each CancelScope may only be used for a single 'with' block"
331
+ )
332
+
333
+ self._host_task = host_task = cast(asyncio.Task, current_task())
334
+ self._tasks.add(host_task)
335
+ try:
336
+ task_state = _task_states[host_task]
337
+ except KeyError:
338
+ task_name = host_task.get_name() if _native_task_names else None
339
+ task_state = TaskState(None, task_name, self)
340
+ _task_states[host_task] = task_state
341
+ else:
342
+ self._parent_scope = task_state.cancel_scope
343
+ task_state.cancel_scope = self
344
+
345
+ self._timeout()
346
+ self._active = True
347
+ return self
348
+
349
+ def __exit__(
350
+ self,
351
+ exc_type: Optional[Type[BaseException]],
352
+ exc_val: Optional[BaseException],
353
+ exc_tb: Optional[TracebackType],
354
+ ) -> Optional[bool]:
355
+ if not self._active:
356
+ raise RuntimeError("This cancel scope is not active")
357
+ if current_task() is not self._host_task:
358
+ raise RuntimeError(
359
+ "Attempted to exit cancel scope in a different task than it was "
360
+ "entered in"
361
+ )
362
+
363
+ assert self._host_task is not None
364
+ host_task_state = _task_states.get(self._host_task)
365
+ if host_task_state is None or host_task_state.cancel_scope is not self:
366
+ raise RuntimeError(
367
+ "Attempted to exit a cancel scope that isn't the current tasks's "
368
+ "current cancel scope"
369
+ )
370
+
371
+ self._active = False
372
+ if self._timeout_handle:
373
+ self._timeout_handle.cancel()
374
+ self._timeout_handle = None
375
+
376
+ self._tasks.remove(self._host_task)
377
+
378
+ host_task_state.cancel_scope = self._parent_scope
379
+
380
+ # Restart the cancellation effort in the farthest directly cancelled parent scope if this
381
+ # one was shielded
382
+ if self._shield:
383
+ self._deliver_cancellation_to_parent()
384
+
385
+ if exc_val is not None:
386
+ exceptions = (
387
+ exc_val.exceptions if isinstance(exc_val, ExceptionGroup) else [exc_val]
388
+ )
389
+ if all(isinstance(exc, CancelledError) for exc in exceptions):
390
+ if self._timeout_expired:
391
+ return True
392
+ elif not self._cancel_called:
393
+ # Task was cancelled natively
394
+ return None
395
+ elif not self._parent_cancelled():
396
+ # This scope was directly cancelled
397
+ return True
398
+
399
+ return None
400
+
401
+ def _timeout(self) -> None:
402
+ if self._deadline != math.inf:
403
+ loop = get_running_loop()
404
+ if loop.time() >= self._deadline:
405
+ self._timeout_expired = True
406
+ self.cancel()
407
+ else:
408
+ self._timeout_handle = loop.call_at(self._deadline, self._timeout)
409
+
410
+ def _deliver_cancellation(self) -> None:
411
+ """
412
+ Deliver cancellation to directly contained tasks and nested cancel scopes.
413
+
414
+ Schedule another run at the end if we still have tasks eligible for cancellation.
415
+ """
416
+ should_retry = False
417
+ current = current_task()
418
+ for task in self._tasks:
419
+ if task._must_cancel: # type: ignore[attr-defined]
420
+ continue
421
+
422
+ # The task is eligible for cancellation if it has started and is not in a cancel
423
+ # scope shielded from this one
424
+ cancel_scope = _task_states[task].cancel_scope
425
+ while cancel_scope is not self:
426
+ if cancel_scope is None or cancel_scope._shield:
427
+ break
428
+ else:
429
+ cancel_scope = cancel_scope._parent_scope
430
+ else:
431
+ should_retry = True
432
+ if task is not current and (
433
+ task is self._host_task or _task_started(task)
434
+ ):
435
+ task.cancel()
436
+
437
+ # Schedule another callback if there are still tasks left
438
+ if should_retry:
439
+ self._cancel_handle = get_running_loop().call_soon(
440
+ self._deliver_cancellation
441
+ )
442
+ else:
443
+ self._cancel_handle = None
444
+
445
+ def _deliver_cancellation_to_parent(self) -> None:
446
+ """Start cancellation effort in the farthest directly cancelled parent scope"""
447
+ scope = self._parent_scope
448
+ scope_to_cancel: Optional[CancelScope] = None
449
+ while scope is not None:
450
+ if scope._cancel_called and scope._cancel_handle is None:
451
+ scope_to_cancel = scope
452
+
453
+ # No point in looking beyond any shielded scope
454
+ if scope._shield:
455
+ break
456
+
457
+ scope = scope._parent_scope
458
+
459
+ if scope_to_cancel is not None:
460
+ scope_to_cancel._deliver_cancellation()
461
+
462
+ def _parent_cancelled(self) -> bool:
463
+ # Check whether any parent has been cancelled
464
+ cancel_scope = self._parent_scope
465
+ while cancel_scope is not None and not cancel_scope._shield:
466
+ if cancel_scope._cancel_called:
467
+ return True
468
+ else:
469
+ cancel_scope = cancel_scope._parent_scope
470
+
471
+ return False
472
+
473
+ def cancel(self) -> DeprecatedAwaitable:
474
+ if not self._cancel_called:
475
+ if self._timeout_handle:
476
+ self._timeout_handle.cancel()
477
+ self._timeout_handle = None
478
+
479
+ self._cancel_called = True
480
+ self._deliver_cancellation()
481
+
482
+ return DeprecatedAwaitable(self.cancel)
483
+
484
+ @property
485
+ def deadline(self) -> float:
486
+ return self._deadline
487
+
488
+ @deadline.setter
489
+ def deadline(self, value: float) -> None:
490
+ self._deadline = float(value)
491
+ if self._timeout_handle is not None:
492
+ self._timeout_handle.cancel()
493
+ self._timeout_handle = None
494
+
495
+ if self._active and not self._cancel_called:
496
+ self._timeout()
497
+
498
+ @property
499
+ def cancel_called(self) -> bool:
500
+ return self._cancel_called
501
+
502
+ @property
503
+ def shield(self) -> bool:
504
+ return self._shield
505
+
506
+ @shield.setter
507
+ def shield(self, value: bool) -> None:
508
+ if self._shield != value:
509
+ self._shield = value
510
+ if not value:
511
+ self._deliver_cancellation_to_parent()
512
+
513
+
514
+ async def checkpoint() -> None:
515
+ await sleep(0)
516
+
517
+
518
+ async def checkpoint_if_cancelled() -> None:
519
+ task = current_task()
520
+ if task is None:
521
+ return
522
+
523
+ try:
524
+ cancel_scope = _task_states[task].cancel_scope
525
+ except KeyError:
526
+ return
527
+
528
+ while cancel_scope:
529
+ if cancel_scope.cancel_called:
530
+ await sleep(0)
531
+ elif cancel_scope.shield:
532
+ break
533
+ else:
534
+ cancel_scope = cancel_scope._parent_scope
535
+
536
+
537
+ async def cancel_shielded_checkpoint() -> None:
538
+ with CancelScope(shield=True):
539
+ await sleep(0)
540
+
541
+
542
+ def current_effective_deadline() -> float:
543
+ try:
544
+ cancel_scope = _task_states[current_task()].cancel_scope # type: ignore[index]
545
+ except KeyError:
546
+ return math.inf
547
+
548
+ deadline = math.inf
549
+ while cancel_scope:
550
+ deadline = min(deadline, cancel_scope.deadline)
551
+ if cancel_scope.shield:
552
+ break
553
+ else:
554
+ cancel_scope = cancel_scope._parent_scope
555
+
556
+ return deadline
557
+
558
+
559
+ def current_time() -> float:
560
+ return get_running_loop().time()
561
+
562
+
563
+ #
564
+ # Task states
565
+ #
566
+
567
+
568
+ class TaskState:
569
+ """
570
+ Encapsulates auxiliary task information that cannot be added to the Task instance itself
571
+ because there are no guarantees about its implementation.
572
+ """
573
+
574
+ __slots__ = "parent_id", "name", "cancel_scope"
575
+
576
+ def __init__(
577
+ self,
578
+ parent_id: Optional[int],
579
+ name: Optional[str],
580
+ cancel_scope: Optional[CancelScope],
581
+ ):
582
+ self.parent_id = parent_id
583
+ self.name = name
584
+ self.cancel_scope = cancel_scope
585
+
586
+
587
+ _task_states = WeakKeyDictionary() # type: WeakKeyDictionary[asyncio.Task, TaskState]
588
+
589
+
590
+ #
591
+ # Task groups
592
+ #
593
+
594
+
595
+ class ExceptionGroup(BaseExceptionGroup):
596
+ def __init__(self, exceptions: List[BaseException]):
597
+ super().__init__()
598
+ self.exceptions = exceptions
599
+
600
+
601
+ class _AsyncioTaskStatus(abc.TaskStatus):
602
+ def __init__(self, future: asyncio.Future, parent_id: int):
603
+ self._future = future
604
+ self._parent_id = parent_id
605
+
606
+ def started(self, value: object = None) -> None:
607
+ try:
608
+ self._future.set_result(value)
609
+ except asyncio.InvalidStateError:
610
+ raise RuntimeError(
611
+ "called 'started' twice on the same task status"
612
+ ) from None
613
+
614
+ task = cast(asyncio.Task, current_task())
615
+ _task_states[task].parent_id = self._parent_id
616
+
617
+
618
+ class TaskGroup(abc.TaskGroup):
619
+ def __init__(self) -> None:
620
+ self.cancel_scope: CancelScope = CancelScope()
621
+ self._active = False
622
+ self._exceptions: List[BaseException] = []
623
+
624
+ async def __aenter__(self) -> "TaskGroup":
625
+ self.cancel_scope.__enter__()
626
+ self._active = True
627
+ return self
628
+
629
+ async def __aexit__(
630
+ self,
631
+ exc_type: Optional[Type[BaseException]],
632
+ exc_val: Optional[BaseException],
633
+ exc_tb: Optional[TracebackType],
634
+ ) -> Optional[bool]:
635
+ ignore_exception = self.cancel_scope.__exit__(exc_type, exc_val, exc_tb)
636
+ if exc_val is not None:
637
+ self.cancel_scope.cancel()
638
+ self._exceptions.append(exc_val)
639
+
640
+ while self.cancel_scope._tasks:
641
+ try:
642
+ await asyncio.wait(self.cancel_scope._tasks)
643
+ except asyncio.CancelledError:
644
+ self.cancel_scope.cancel()
645
+
646
+ self._active = False
647
+ if not self.cancel_scope._parent_cancelled():
648
+ exceptions = self._filter_cancellation_errors(self._exceptions)
649
+ else:
650
+ exceptions = self._exceptions
651
+
652
+ try:
653
+ if len(exceptions) > 1:
654
+ if all(
655
+ isinstance(e, CancelledError) and not e.args for e in exceptions
656
+ ):
657
+ # Tasks were cancelled natively, without a cancellation message
658
+ raise CancelledError
659
+ else:
660
+ raise ExceptionGroup(exceptions)
661
+ elif exceptions and exceptions[0] is not exc_val:
662
+ raise exceptions[0]
663
+ except BaseException as exc:
664
+ # Clear the context here, as it can only be done in-flight.
665
+ # If the context is not cleared, it can result in recursive tracebacks (see #145).
666
+ exc.__context__ = None
667
+ raise
668
+
669
+ return ignore_exception
670
+
671
+ @staticmethod
672
+ def _filter_cancellation_errors(
673
+ exceptions: Sequence[BaseException],
674
+ ) -> List[BaseException]:
675
+ filtered_exceptions: List[BaseException] = []
676
+ for exc in exceptions:
677
+ if isinstance(exc, ExceptionGroup):
678
+ new_exceptions = TaskGroup._filter_cancellation_errors(exc.exceptions)
679
+ if len(new_exceptions) > 1:
680
+ filtered_exceptions.append(exc)
681
+ elif len(new_exceptions) == 1:
682
+ filtered_exceptions.append(new_exceptions[0])
683
+ elif new_exceptions:
684
+ new_exc = ExceptionGroup(new_exceptions)
685
+ new_exc.__cause__ = exc.__cause__
686
+ new_exc.__context__ = exc.__context__
687
+ new_exc.__traceback__ = exc.__traceback__
688
+ filtered_exceptions.append(new_exc)
689
+ elif not isinstance(exc, CancelledError) or exc.args:
690
+ filtered_exceptions.append(exc)
691
+
692
+ return filtered_exceptions
693
+
694
+ async def _run_wrapped_task(
695
+ self, coro: Coroutine, task_status_future: Optional[asyncio.Future]
696
+ ) -> None:
697
+ # This is the code path for Python 3.6 and 3.7 on which asyncio freaks out if a task raises
698
+ # a BaseException.
699
+ __traceback_hide__ = __tracebackhide__ = True # noqa: F841
700
+ task = cast(asyncio.Task, current_task())
701
+ try:
702
+ await coro
703
+ except BaseException as exc:
704
+ if task_status_future is None or task_status_future.done():
705
+ self._exceptions.append(exc)
706
+ self.cancel_scope.cancel()
707
+ else:
708
+ task_status_future.set_exception(exc)
709
+ else:
710
+ if task_status_future is not None and not task_status_future.done():
711
+ task_status_future.set_exception(
712
+ RuntimeError("Child exited without calling task_status.started()")
713
+ )
714
+ finally:
715
+ if task in self.cancel_scope._tasks:
716
+ self.cancel_scope._tasks.remove(task)
717
+ del _task_states[task]
718
+
719
+ def _spawn(
720
+ self,
721
+ func: Callable[..., Coroutine],
722
+ args: tuple,
723
+ name: object,
724
+ task_status_future: Optional[asyncio.Future] = None,
725
+ ) -> asyncio.Task:
726
+ def task_done(_task: asyncio.Task) -> None:
727
+ # This is the code path for Python 3.8+
728
+ assert _task in self.cancel_scope._tasks
729
+ self.cancel_scope._tasks.remove(_task)
730
+ del _task_states[_task]
731
+
732
+ try:
733
+ exc = _task.exception()
734
+ except CancelledError as e:
735
+ while isinstance(e.__context__, CancelledError):
736
+ e = e.__context__
737
+
738
+ exc = e
739
+
740
+ if exc is not None:
741
+ if task_status_future is None or task_status_future.done():
742
+ self._exceptions.append(exc)
743
+ self.cancel_scope.cancel()
744
+ else:
745
+ task_status_future.set_exception(exc)
746
+ elif task_status_future is not None and not task_status_future.done():
747
+ task_status_future.set_exception(
748
+ RuntimeError("Child exited without calling task_status.started()")
749
+ )
750
+
751
+ if not self._active:
752
+ raise RuntimeError(
753
+ "This task group is not active; no new tasks can be started."
754
+ )
755
+
756
+ options = {}
757
+ name = get_callable_name(func) if name is None else str(name)
758
+ if _native_task_names:
759
+ options["name"] = name
760
+
761
+ kwargs = {}
762
+ if task_status_future:
763
+ parent_id = id(current_task())
764
+ kwargs["task_status"] = _AsyncioTaskStatus(
765
+ task_status_future, id(self.cancel_scope._host_task)
766
+ )
767
+ else:
768
+ parent_id = id(self.cancel_scope._host_task)
769
+
770
+ coro = func(*args, **kwargs)
771
+ if not asyncio.iscoroutine(coro):
772
+ raise TypeError(
773
+ f"Expected an async function, but {func} appears to be synchronous"
774
+ )
775
+
776
+ foreign_coro = not hasattr(coro, "cr_frame") and not hasattr(coro, "gi_frame")
777
+ if foreign_coro or sys.version_info < (3, 8):
778
+ coro = self._run_wrapped_task(coro, task_status_future)
779
+
780
+ task = create_task(coro, **options)
781
+ if not foreign_coro and sys.version_info >= (3, 8):
782
+ task.add_done_callback(task_done)
783
+
784
+ # Make the spawned task inherit the task group's cancel scope
785
+ _task_states[task] = TaskState(
786
+ parent_id=parent_id, name=name, cancel_scope=self.cancel_scope
787
+ )
788
+ self.cancel_scope._tasks.add(task)
789
+ return task
790
+
791
+ def start_soon(
792
+ self, func: Callable[..., Coroutine], *args: object, name: object = None
793
+ ) -> None:
794
+ self._spawn(func, args, name)
795
+
796
+ async def start(
797
+ self, func: Callable[..., Coroutine], *args: object, name: object = None
798
+ ) -> None:
799
+ future: asyncio.Future = asyncio.Future()
800
+ task = self._spawn(func, args, name, future)
801
+
802
+ # If the task raises an exception after sending a start value without a switch point
803
+ # between, the task group is cancelled and this method never proceeds to process the
804
+ # completed future. That's why we have to have a shielded cancel scope here.
805
+ with CancelScope(shield=True):
806
+ try:
807
+ return await future
808
+ except CancelledError:
809
+ task.cancel()
810
+ raise
811
+
812
+
813
+ #
814
+ # Threads
815
+ #
816
+
817
+ _Retval_Queue_Type = Tuple[Optional[T_Retval], Optional[BaseException]]
818
+
819
+
820
+ class WorkerThread(Thread):
821
+ MAX_IDLE_TIME = 10 # seconds
822
+
823
+ def __init__(
824
+ self,
825
+ root_task: asyncio.Task,
826
+ workers: Set["WorkerThread"],
827
+ idle_workers: Deque["WorkerThread"],
828
+ ):
829
+ super().__init__(name="AnyIO worker thread")
830
+ self.root_task = root_task
831
+ self.workers = workers
832
+ self.idle_workers = idle_workers
833
+ self.loop = root_task._loop
834
+ self.queue: Queue[
835
+ Union[Tuple[Context, Callable, tuple, asyncio.Future], None]
836
+ ] = Queue(2)
837
+ self.idle_since = current_time()
838
+ self.stopping = False
839
+
840
+ def _report_result(
841
+ self, future: asyncio.Future, result: Any, exc: Optional[BaseException]
842
+ ) -> None:
843
+ self.idle_since = current_time()
844
+ if not self.stopping:
845
+ self.idle_workers.append(self)
846
+
847
+ if not future.cancelled():
848
+ if exc is not None:
849
+ future.set_exception(exc)
850
+ else:
851
+ future.set_result(result)
852
+
853
+ def run(self) -> None:
854
+ with claim_worker_thread("asyncio"):
855
+ threadlocals.loop = self.loop
856
+ while True:
857
+ item = self.queue.get()
858
+ if item is None:
859
+ # Shutdown command received
860
+ return
861
+
862
+ context, func, args, future = item
863
+ if not future.cancelled():
864
+ result = None
865
+ exception: Optional[BaseException] = None
866
+ try:
867
+ result = context.run(func, *args)
868
+ except BaseException as exc:
869
+ exception = exc
870
+
871
+ if not self.loop.is_closed():
872
+ self.loop.call_soon_threadsafe(
873
+ self._report_result, future, result, exception
874
+ )
875
+
876
+ self.queue.task_done()
877
+
878
+ def stop(self, f: Optional[asyncio.Task] = None) -> None:
879
+ self.stopping = True
880
+ self.queue.put_nowait(None)
881
+ self.workers.discard(self)
882
+ try:
883
+ self.idle_workers.remove(self)
884
+ except ValueError:
885
+ pass
886
+
887
+
888
+ _threadpool_idle_workers: RunVar[Deque[WorkerThread]] = RunVar(
889
+ "_threadpool_idle_workers"
890
+ )
891
+ _threadpool_workers: RunVar[Set[WorkerThread]] = RunVar("_threadpool_workers")
892
+
893
+
894
+ async def run_sync_in_worker_thread(
895
+ func: Callable[..., T_Retval],
896
+ *args: object,
897
+ cancellable: bool = False,
898
+ limiter: Optional["CapacityLimiter"] = None,
899
+ ) -> T_Retval:
900
+ await checkpoint()
901
+
902
+ # If this is the first run in this event loop thread, set up the necessary variables
903
+ try:
904
+ idle_workers = _threadpool_idle_workers.get()
905
+ workers = _threadpool_workers.get()
906
+ except LookupError:
907
+ idle_workers = deque()
908
+ workers = set()
909
+ _threadpool_idle_workers.set(idle_workers)
910
+ _threadpool_workers.set(workers)
911
+
912
+ async with (limiter or current_default_thread_limiter()):
913
+ with CancelScope(shield=not cancellable):
914
+ future: asyncio.Future = asyncio.Future()
915
+ root_task = find_root_task()
916
+ if not idle_workers:
917
+ worker = WorkerThread(root_task, workers, idle_workers)
918
+ worker.start()
919
+ workers.add(worker)
920
+ root_task.add_done_callback(worker.stop)
921
+ else:
922
+ worker = idle_workers.pop()
923
+
924
+ # Prune any other workers that have been idle for MAX_IDLE_TIME seconds or longer
925
+ now = current_time()
926
+ while idle_workers:
927
+ if now - idle_workers[0].idle_since < WorkerThread.MAX_IDLE_TIME:
928
+ break
929
+
930
+ expired_worker = idle_workers.popleft()
931
+ expired_worker.root_task.remove_done_callback(expired_worker.stop)
932
+ expired_worker.stop()
933
+
934
+ context = copy_context()
935
+ context.run(sniffio.current_async_library_cvar.set, None)
936
+ worker.queue.put_nowait((context, func, args, future))
937
+ return await future
938
+
939
+
940
+ def run_sync_from_thread(
941
+ func: Callable[..., T_Retval],
942
+ *args: object,
943
+ loop: Optional[asyncio.AbstractEventLoop] = None,
944
+ ) -> T_Retval:
945
+ @wraps(func)
946
+ def wrapper() -> None:
947
+ try:
948
+ f.set_result(func(*args))
949
+ except BaseException as exc:
950
+ f.set_exception(exc)
951
+ if not isinstance(exc, Exception):
952
+ raise
953
+
954
+ f: concurrent.futures.Future[T_Retval] = Future()
955
+ loop = loop or threadlocals.loop
956
+ if sys.version_info < (3, 7):
957
+ loop.call_soon_threadsafe(copy_context().run, wrapper)
958
+ else:
959
+ loop.call_soon_threadsafe(wrapper)
960
+
961
+ return f.result()
962
+
963
+
964
+ def run_async_from_thread(
965
+ func: Callable[..., Coroutine[Any, Any, T_Retval]], *args: object
966
+ ) -> T_Retval:
967
+ f: concurrent.futures.Future[T_Retval] = asyncio.run_coroutine_threadsafe(
968
+ func(*args), threadlocals.loop
969
+ )
970
+ return f.result()
971
+
972
+
973
+ class BlockingPortal(abc.BlockingPortal):
974
+ def __new__(cls) -> "BlockingPortal":
975
+ return object.__new__(cls)
976
+
977
+ def __init__(self) -> None:
978
+ super().__init__()
979
+ self._loop = get_running_loop()
980
+
981
+ def _spawn_task_from_thread(
982
+ self,
983
+ func: Callable,
984
+ args: tuple,
985
+ kwargs: Dict[str, Any],
986
+ name: object,
987
+ future: Future,
988
+ ) -> None:
989
+ run_sync_from_thread(
990
+ partial(self._task_group.start_soon, name=name),
991
+ self._call_func,
992
+ func,
993
+ args,
994
+ kwargs,
995
+ future,
996
+ loop=self._loop,
997
+ )
998
+
999
+
1000
+ #
1001
+ # Subprocesses
1002
+ #
1003
+
1004
+
1005
+ @dataclass(eq=False)
1006
+ class StreamReaderWrapper(abc.ByteReceiveStream):
1007
+ _stream: asyncio.StreamReader
1008
+
1009
+ async def receive(self, max_bytes: int = 65536) -> bytes:
1010
+ data = await self._stream.read(max_bytes)
1011
+ if data:
1012
+ return data
1013
+ else:
1014
+ raise EndOfStream
1015
+
1016
+ async def aclose(self) -> None:
1017
+ self._stream.feed_eof()
1018
+
1019
+
1020
+ @dataclass(eq=False)
1021
+ class StreamWriterWrapper(abc.ByteSendStream):
1022
+ _stream: asyncio.StreamWriter
1023
+
1024
+ async def send(self, item: bytes) -> None:
1025
+ self._stream.write(item)
1026
+ await self._stream.drain()
1027
+
1028
+ async def aclose(self) -> None:
1029
+ self._stream.close()
1030
+
1031
+
1032
+ @dataclass(eq=False)
1033
+ class Process(abc.Process):
1034
+ _process: asyncio.subprocess.Process
1035
+ _stdin: Optional[StreamWriterWrapper]
1036
+ _stdout: Optional[StreamReaderWrapper]
1037
+ _stderr: Optional[StreamReaderWrapper]
1038
+
1039
+ async def aclose(self) -> None:
1040
+ if self._stdin:
1041
+ await self._stdin.aclose()
1042
+ if self._stdout:
1043
+ await self._stdout.aclose()
1044
+ if self._stderr:
1045
+ await self._stderr.aclose()
1046
+
1047
+ await self.wait()
1048
+
1049
+ async def wait(self) -> int:
1050
+ return await self._process.wait()
1051
+
1052
+ def terminate(self) -> None:
1053
+ self._process.terminate()
1054
+
1055
+ def kill(self) -> None:
1056
+ self._process.kill()
1057
+
1058
+ def send_signal(self, signal: int) -> None:
1059
+ self._process.send_signal(signal)
1060
+
1061
+ @property
1062
+ def pid(self) -> int:
1063
+ return self._process.pid
1064
+
1065
+ @property
1066
+ def returncode(self) -> Optional[int]:
1067
+ return self._process.returncode
1068
+
1069
+ @property
1070
+ def stdin(self) -> Optional[abc.ByteSendStream]:
1071
+ return self._stdin
1072
 
1073
+ @property
1074
+ def stdout(self) -> Optional[abc.ByteReceiveStream]:
1075
+ return self._stdout
1076
+
1077
+ @property
1078
+ def stderr(self) -> Optional[abc.ByteReceiveStream]:
1079
+ return self._stderr
1080
+
1081
+
1082
+ async def open_process(
1083
+ command: Union[str, bytes, Sequence[Union[str, bytes]]],
1084
+ *,
1085
+ shell: bool,
1086
+ stdin: Union[int, IO[Any], None],
1087
+ stdout: Union[int, IO[Any], None],
1088
+ stderr: Union[int, IO[Any], None],
1089
+ cwd: Union[str, bytes, PathLike, None] = None,
1090
+ env: Optional[Mapping[str, str]] = None,
1091
+ start_new_session: bool = False,
1092
+ ) -> Process:
1093
+ await checkpoint()
1094
+ if shell:
1095
+ process = await asyncio.create_subprocess_shell(
1096
+ cast(Union[str, bytes], command),
1097
+ stdin=stdin,
1098
+ stdout=stdout,
1099
+ stderr=stderr,
1100
+ cwd=cwd,
1101
+ env=env,
1102
+ start_new_session=start_new_session,
1103
+ )
1104
+ else:
1105
+ process = await asyncio.create_subprocess_exec(
1106
+ *command,
1107
+ stdin=stdin,
1108
+ stdout=stdout,
1109
+ stderr=stderr,
1110
+ cwd=cwd,
1111
+ env=env,
1112
+ start_new_session=start_new_session,
1113
  )
 
1114
 
1115
+ stdin_stream = StreamWriterWrapper(process.stdin) if process.stdin else None
1116
+ stdout_stream = StreamReaderWrapper(process.stdout) if process.stdout else None
1117
+ stderr_stream = StreamReaderWrapper(process.stderr) if process.stderr else None
1118
+ return Process(process, stdin_stream, stdout_stream, stderr_stream)
1119
+
1120
+
1121
+ def _forcibly_shutdown_process_pool_on_exit(
1122
+ workers: Set[Process], _task: object
1123
+ ) -> None:
1124
+ """
1125
+ Forcibly shuts down worker processes belonging to this event loop."""
1126
+ child_watcher: Optional[asyncio.AbstractChildWatcher]
1127
+ try:
1128
+ child_watcher = asyncio.get_event_loop_policy().get_child_watcher()
1129
+ except NotImplementedError:
1130
+ child_watcher = None
1131
+
1132
+ # Close as much as possible (w/o async/await) to avoid warnings
1133
+ for process in workers:
1134
+ if process.returncode is None:
1135
+ continue
1136
+
1137
+ process._stdin._stream._transport.close() # type: ignore[union-attr]
1138
+ process._stdout._stream._transport.close() # type: ignore[union-attr]
1139
+ process._stderr._stream._transport.close() # type: ignore[union-attr]
1140
+ process.kill()
1141
+ if child_watcher:
1142
+ child_watcher.remove_child_handler(process.pid)
1143
+
1144
+
1145
+ async def _shutdown_process_pool_on_exit(workers: Set[Process]) -> None:
1146
+ """
1147
+ Shuts down worker processes belonging to this event loop.
1148
+
1149
+ NOTE: this only works when the event loop was started using asyncio.run() or anyio.run().
1150
+
1151
+ """
1152
+ process: Process
1153
+ try:
1154
+ await sleep(math.inf)
1155
+ except asyncio.CancelledError:
1156
+ for process in workers:
1157
+ if process.returncode is None:
1158
+ process.kill()
1159
+
1160
+ for process in workers:
1161
+ await process.aclose()
1162
+
1163
+
1164
+ def setup_process_pool_exit_at_shutdown(workers: Set[Process]) -> None:
1165
+ kwargs = {"name": "AnyIO process pool shutdown task"} if _native_task_names else {}
1166
+ create_task(_shutdown_process_pool_on_exit(workers), **kwargs)
1167
+ find_root_task().add_done_callback(
1168
+ partial(_forcibly_shutdown_process_pool_on_exit, workers)
1169
+ )
1170
+
1171
+
1172
+ #
1173
+ # Sockets and networking
1174
+ #
1175
+
1176
+
1177
+ class StreamProtocol(asyncio.Protocol):
1178
+ read_queue: Deque[bytes]
1179
+ read_event: asyncio.Event
1180
+ write_event: asyncio.Event
1181
+ exception: Optional[Exception] = None
1182
+
1183
+ def connection_made(self, transport: asyncio.BaseTransport) -> None:
1184
+ self.read_queue = deque()
1185
+ self.read_event = asyncio.Event()
1186
+ self.write_event = asyncio.Event()
1187
+ self.write_event.set()
1188
+ cast(asyncio.Transport, transport).set_write_buffer_limits(0)
1189
+
1190
+ def connection_lost(self, exc: Optional[Exception]) -> None:
1191
+ if exc:
1192
+ self.exception = BrokenResourceError()
1193
+ self.exception.__cause__ = exc
1194
+
1195
+ self.read_event.set()
1196
+ self.write_event.set()
1197
+
1198
+ def data_received(self, data: bytes) -> None:
1199
+ self.read_queue.append(data)
1200
+ self.read_event.set()
1201
+
1202
+ def eof_received(self) -> Optional[bool]:
1203
+ self.read_event.set()
1204
+ return True
1205
+
1206
+ def pause_writing(self) -> None:
1207
+ self.write_event = asyncio.Event()
1208
+
1209
+ def resume_writing(self) -> None:
1210
+ self.write_event.set()
1211
+
1212
+
1213
+ class DatagramProtocol(asyncio.DatagramProtocol):
1214
+ read_queue: Deque[Tuple[bytes, IPSockAddrType]]
1215
+ read_event: asyncio.Event
1216
+ write_event: asyncio.Event
1217
+ exception: Optional[Exception] = None
1218
+
1219
+ def connection_made(self, transport: asyncio.BaseTransport) -> None:
1220
+ self.read_queue = deque(maxlen=100) # arbitrary value
1221
+ self.read_event = asyncio.Event()
1222
+ self.write_event = asyncio.Event()
1223
+ self.write_event.set()
1224
+
1225
+ def connection_lost(self, exc: Optional[Exception]) -> None:
1226
+ self.read_event.set()
1227
+ self.write_event.set()
1228
+
1229
+ def datagram_received(self, data: bytes, addr: IPSockAddrType) -> None:
1230
+ addr = convert_ipv6_sockaddr(addr)
1231
+ self.read_queue.append((data, addr))
1232
+ self.read_event.set()
1233
+
1234
+ def error_received(self, exc: Exception) -> None:
1235
+ self.exception = exc
1236
+
1237
+ def pause_writing(self) -> None:
1238
+ self.write_event.clear()
1239
+
1240
+ def resume_writing(self) -> None:
1241
+ self.write_event.set()
1242
+
1243
+
1244
+ class SocketStream(abc.SocketStream):
1245
+ def __init__(self, transport: asyncio.Transport, protocol: StreamProtocol):
1246
+ self._transport = transport
1247
+ self._protocol = protocol
1248
+ self._receive_guard = ResourceGuard("reading from")
1249
+ self._send_guard = ResourceGuard("writing to")
1250
+ self._closed = False
1251
+
1252
+ @property
1253
+ def _raw_socket(self) -> socket.socket:
1254
+ return self._transport.get_extra_info("socket")
1255
+
1256
+ async def receive(self, max_bytes: int = 65536) -> bytes:
1257
+ with self._receive_guard:
1258
+ await checkpoint()
1259
+
1260
+ if (
1261
+ not self._protocol.read_event.is_set()
1262
+ and not self._transport.is_closing()
1263
+ ):
1264
+ self._transport.resume_reading()
1265
+ await self._protocol.read_event.wait()
1266
+ self._transport.pause_reading()
1267
+
1268
+ try:
1269
+ chunk = self._protocol.read_queue.popleft()
1270
+ except IndexError:
1271
+ if self._closed:
1272
+ raise ClosedResourceError from None
1273
+ elif self._protocol.exception:
1274
+ raise self._protocol.exception
1275
+ else:
1276
+ raise EndOfStream from None
1277
+
1278
+ if len(chunk) > max_bytes:
1279
+ # Split the oversized chunk
1280
+ chunk, leftover = chunk[:max_bytes], chunk[max_bytes:]
1281
+ self._protocol.read_queue.appendleft(leftover)
1282
+
1283
+ # If the read queue is empty, clear the flag so that the next call will block until
1284
+ # data is available
1285
+ if not self._protocol.read_queue:
1286
+ self._protocol.read_event.clear()
1287
+
1288
+ return chunk
1289
+
1290
+ async def send(self, item: bytes) -> None:
1291
+ with self._send_guard:
1292
+ await checkpoint()
1293
+
1294
+ if self._closed:
1295
+ raise ClosedResourceError
1296
+ elif self._protocol.exception is not None:
1297
+ raise self._protocol.exception
1298
+
1299
+ try:
1300
+ self._transport.write(item)
1301
+ except RuntimeError as exc:
1302
+ if self._transport.is_closing():
1303
+ raise BrokenResourceError from exc
1304
+ else:
1305
+ raise
1306
+
1307
+ await self._protocol.write_event.wait()
1308
+
1309
+ async def send_eof(self) -> None:
1310
+ try:
1311
+ self._transport.write_eof()
1312
+ except OSError:
1313
+ pass
1314
+
1315
+ async def aclose(self) -> None:
1316
+ if not self._transport.is_closing():
1317
+ self._closed = True
1318
+ try:
1319
+ self._transport.write_eof()
1320
+ except OSError:
1321
+ pass
1322
+
1323
+ self._transport.close()
1324
+ await sleep(0)
1325
+ self._transport.abort()
1326
+
1327
 
1328
+ class UNIXSocketStream(abc.SocketStream):
1329
+ _receive_future: Optional[asyncio.Future] = None
1330
+ _send_future: Optional[asyncio.Future] = None
1331
+ _closing = False
1332
 
1333
+ def __init__(self, raw_socket: socket.socket):
1334
+ self.__raw_socket = raw_socket
1335
+ self._loop = get_running_loop()
1336
+ self._receive_guard = ResourceGuard("reading from")
1337
+ self._send_guard = ResourceGuard("writing to")
1338
 
1339
+ @property
1340
+ def _raw_socket(self) -> socket.socket:
1341
+ return self.__raw_socket
1342
 
1343
+ def _wait_until_readable(self, loop: asyncio.AbstractEventLoop) -> asyncio.Future:
1344
+ def callback(f: object) -> None:
1345
+ del self._receive_future
1346
+ loop.remove_reader(self.__raw_socket)
1347
 
1348
+ f = self._receive_future = asyncio.Future()
1349
+ self._loop.add_reader(self.__raw_socket, f.set_result, None)
1350
+ f.add_done_callback(callback)
1351
+ return f
1352
 
1353
+ def _wait_until_writable(self, loop: asyncio.AbstractEventLoop) -> asyncio.Future:
1354
+ def callback(f: object) -> None:
1355
+ del self._send_future
1356
+ loop.remove_writer(self.__raw_socket)
1357
+
1358
+ f = self._send_future = asyncio.Future()
1359
+ self._loop.add_writer(self.__raw_socket, f.set_result, None)
1360
+ f.add_done_callback(callback)
1361
+ return f
1362
+
1363
+ async def send_eof(self) -> None:
1364
+ with self._send_guard:
1365
+ self._raw_socket.shutdown(socket.SHUT_WR)
1366
+
1367
+ async def receive(self, max_bytes: int = 65536) -> bytes:
1368
+ loop = get_running_loop()
1369
+ await checkpoint()
1370
+ with self._receive_guard:
1371
+ while True:
1372
+ try:
1373
+ data = self.__raw_socket.recv(max_bytes)
1374
+ except BlockingIOError:
1375
+ await self._wait_until_readable(loop)
1376
+ except OSError as exc:
1377
+ if self._closing:
1378
+ raise ClosedResourceError from None
1379
+ else:
1380
+ raise BrokenResourceError from exc
1381
+ else:
1382
+ if not data:
1383
+ raise EndOfStream
1384
+
1385
+ return data
1386
+
1387
+ async def send(self, item: bytes) -> None:
1388
+ loop = get_running_loop()
1389
+ await checkpoint()
1390
+ with self._send_guard:
1391
+ view = memoryview(item)
1392
+ while view:
1393
+ try:
1394
+ bytes_sent = self.__raw_socket.send(item)
1395
+ except BlockingIOError:
1396
+ await self._wait_until_writable(loop)
1397
+ except OSError as exc:
1398
+ if self._closing:
1399
+ raise ClosedResourceError from None
1400
+ else:
1401
+ raise BrokenResourceError from exc
1402
+ else:
1403
+ view = view[bytes_sent:]
1404
+
1405
+ async def receive_fds(self, msglen: int, maxfds: int) -> Tuple[bytes, List[int]]:
1406
+ if not isinstance(msglen, int) or msglen < 0:
1407
+ raise ValueError("msglen must be a non-negative integer")
1408
+ if not isinstance(maxfds, int) or maxfds < 1:
1409
+ raise ValueError("maxfds must be a positive integer")
1410
+
1411
+ loop = get_running_loop()
1412
+ fds = array.array("i")
1413
+ await checkpoint()
1414
+ with self._receive_guard:
1415
+ while True:
1416
+ try:
1417
+ message, ancdata, flags, addr = self.__raw_socket.recvmsg(
1418
+ msglen, socket.CMSG_LEN(maxfds * fds.itemsize)
1419
+ )
1420
+ except BlockingIOError:
1421
+ await self._wait_until_readable(loop)
1422
+ except OSError as exc:
1423
+ if self._closing:
1424
+ raise ClosedResourceError from None
1425
+ else:
1426
+ raise BrokenResourceError from exc
1427
+ else:
1428
+ if not message and not ancdata:
1429
+ raise EndOfStream
1430
+
1431
+ break
1432
+
1433
+ for cmsg_level, cmsg_type, cmsg_data in ancdata:
1434
+ if cmsg_level != socket.SOL_SOCKET or cmsg_type != socket.SCM_RIGHTS:
1435
+ raise RuntimeError(
1436
+ f"Received unexpected ancillary data; message = {message!r}, "
1437
+ f"cmsg_level = {cmsg_level}, cmsg_type = {cmsg_type}"
1438
+ )
1439
+
1440
+ fds.frombytes(cmsg_data[: len(cmsg_data) - (len(cmsg_data) % fds.itemsize)])
1441
+
1442
+ return message, list(fds)
1443
+
1444
+ async def send_fds(
1445
+ self, message: bytes, fds: Collection[Union[int, IOBase]]
1446
+ ) -> None:
1447
+ if not message:
1448
+ raise ValueError("message must not be empty")
1449
+ if not fds:
1450
+ raise ValueError("fds must not be empty")
1451
+
1452
+ loop = get_running_loop()
1453
+ filenos: List[int] = []
1454
+ for fd in fds:
1455
+ if isinstance(fd, int):
1456
+ filenos.append(fd)
1457
+ elif isinstance(fd, IOBase):
1458
+ filenos.append(fd.fileno())
1459
+
1460
+ fdarray = array.array("i", filenos)
1461
+ await checkpoint()
1462
+ with self._send_guard:
1463
+ while True:
1464
+ try:
1465
+ # The ignore can be removed after mypy picks up
1466
+ # https://github.com/python/typeshed/pull/5545
1467
+ self.__raw_socket.sendmsg(
1468
+ [message], [(socket.SOL_SOCKET, socket.SCM_RIGHTS, fdarray)]
1469
+ )
1470
+ break
1471
+ except BlockingIOError:
1472
+ await self._wait_until_writable(loop)
1473
+ except OSError as exc:
1474
+ if self._closing:
1475
+ raise ClosedResourceError from None
1476
+ else:
1477
+ raise BrokenResourceError from exc
1478
+
1479
+ async def aclose(self) -> None:
1480
+ if not self._closing:
1481
+ self._closing = True
1482
+ if self.__raw_socket.fileno() != -1:
1483
+ self.__raw_socket.close()
1484
+
1485
+ if self._receive_future:
1486
+ self._receive_future.set_result(None)
1487
+ if self._send_future:
1488
+ self._send_future.set_result(None)
1489
+
1490
+
1491
+ class TCPSocketListener(abc.SocketListener):
1492
+ _accept_scope: Optional[CancelScope] = None
1493
+ _closed = False
1494
+
1495
+ def __init__(self, raw_socket: socket.socket):
1496
+ self.__raw_socket = raw_socket
1497
+ self._loop = cast(asyncio.BaseEventLoop, get_running_loop())
1498
+ self._accept_guard = ResourceGuard("accepting connections from")
1499
+
1500
+ @property
1501
+ def _raw_socket(self) -> socket.socket:
1502
+ return self.__raw_socket
1503
+
1504
+ async def accept(self) -> abc.SocketStream:
1505
+ if self._closed:
1506
+ raise ClosedResourceError
1507
+
1508
+ with self._accept_guard:
1509
+ await checkpoint()
1510
+ with CancelScope() as self._accept_scope:
1511
+ try:
1512
+ client_sock, _addr = await self._loop.sock_accept(self._raw_socket)
1513
+ except asyncio.CancelledError:
1514
+ # Workaround for https://bugs.python.org/issue41317
1515
  try:
1516
+ self._loop.remove_reader(self._raw_socket)
1517
+ except (ValueError, NotImplementedError):
1518
+ pass
1519
+
1520
+ if self._closed:
1521
+ raise ClosedResourceError from None
1522
+
1523
+ raise
1524
+ finally:
1525
+ self._accept_scope = None
1526
+
1527
+ client_sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
1528
+ transport, protocol = await self._loop.connect_accepted_socket(
1529
+ StreamProtocol, client_sock
1530
+ )
1531
+ return SocketStream(cast(asyncio.Transport, transport), protocol)
1532
+
1533
+ async def aclose(self) -> None:
1534
+ if self._closed:
1535
+ return
1536
+
1537
+ self._closed = True
1538
+ if self._accept_scope:
1539
+ # Workaround for https://bugs.python.org/issue41317
1540
+ try:
1541
+ self._loop.remove_reader(self._raw_socket)
1542
+ except (ValueError, NotImplementedError):
1543
+ pass
1544
+
1545
+ self._accept_scope.cancel()
1546
+ await sleep(0)
1547
+
1548
+ self._raw_socket.close()
1549
+
1550
+
1551
+ class UNIXSocketListener(abc.SocketListener):
1552
+ def __init__(self, raw_socket: socket.socket):
1553
+ self.__raw_socket = raw_socket
1554
+ self._loop = get_running_loop()
1555
+ self._accept_guard = ResourceGuard("accepting connections from")
1556
+ self._closed = False
1557
+
1558
+ async def accept(self) -> abc.SocketStream:
1559
+ await checkpoint()
1560
+ with self._accept_guard:
1561
+ while True:
1562
+ try:
1563
+ client_sock, _ = self.__raw_socket.accept()
1564
+ client_sock.setblocking(False)
1565
+ return UNIXSocketStream(client_sock)
1566
+ except BlockingIOError:
1567
+ f: asyncio.Future = asyncio.Future()
1568
+ self._loop.add_reader(self.__raw_socket, f.set_result, None)
1569
+ f.add_done_callback(
1570
+ lambda _: self._loop.remove_reader(self.__raw_socket)
1571
+ )
1572
+ await f
1573
+ except OSError as exc:
1574
+ if self._closed:
1575
+ raise ClosedResourceError from None
1576
+ else:
1577
+ raise BrokenResourceError from exc
1578
+
1579
+ async def aclose(self) -> None:
1580
+ self._closed = True
1581
+ self.__raw_socket.close()
1582
+
1583
+ @property
1584
+ def _raw_socket(self) -> socket.socket:
1585
+ return self.__raw_socket
1586
+
1587
+
1588
+ class UDPSocket(abc.UDPSocket):
1589
+ def __init__(
1590
+ self, transport: asyncio.DatagramTransport, protocol: DatagramProtocol
1591
+ ):
1592
+ self._transport = transport
1593
+ self._protocol = protocol
1594
+ self._receive_guard = ResourceGuard("reading from")
1595
+ self._send_guard = ResourceGuard("writing to")
1596
+ self._closed = False
1597
+
1598
+ @property
1599
+ def _raw_socket(self) -> socket.socket:
1600
+ return self._transport.get_extra_info("socket")
1601
+
1602
+ async def aclose(self) -> None:
1603
+ if not self._transport.is_closing():
1604
+ self._closed = True
1605
+ self._transport.close()
1606
+
1607
+ async def receive(self) -> Tuple[bytes, IPSockAddrType]:
1608
+ with self._receive_guard:
1609
+ await checkpoint()
1610
+
1611
+ # If the buffer is empty, ask for more data
1612
+ if not self._protocol.read_queue and not self._transport.is_closing():
1613
+ self._protocol.read_event.clear()
1614
+ await self._protocol.read_event.wait()
1615
+
1616
+ try:
1617
+ return self._protocol.read_queue.popleft()
1618
+ except IndexError:
1619
+ if self._closed:
1620
+ raise ClosedResourceError from None
1621
+ else:
1622
+ raise BrokenResourceError from None
1623
+
1624
+ async def send(self, item: UDPPacketType) -> None:
1625
+ with self._send_guard:
1626
+ await checkpoint()
1627
+ await self._protocol.write_event.wait()
1628
+ if self._closed:
1629
+ raise ClosedResourceError
1630
+ elif self._transport.is_closing():
1631
+ raise BrokenResourceError
1632
+ else:
1633
+ self._transport.sendto(*item)
1634
+
1635
+
1636
+ class ConnectedUDPSocket(abc.ConnectedUDPSocket):
1637
+ def __init__(
1638
+ self, transport: asyncio.DatagramTransport, protocol: DatagramProtocol
1639
+ ):
1640
+ self._transport = transport
1641
+ self._protocol = protocol
1642
+ self._receive_guard = ResourceGuard("reading from")
1643
+ self._send_guard = ResourceGuard("writing to")
1644
+ self._closed = False
1645
+
1646
+ @property
1647
+ def _raw_socket(self) -> socket.socket:
1648
+ return self._transport.get_extra_info("socket")
1649
+
1650
+ async def aclose(self) -> None:
1651
+ if not self._transport.is_closing():
1652
+ self._closed = True
1653
+ self._transport.close()
1654
+
1655
+ async def receive(self) -> bytes:
1656
+ with self._receive_guard:
1657
+ await checkpoint()
1658
+
1659
+ # If the buffer is empty, ask for more data
1660
+ if not self._protocol.read_queue and not self._transport.is_closing():
1661
+ self._protocol.read_event.clear()
1662
+ await self._protocol.read_event.wait()
1663
+
1664
+ try:
1665
+ packet = self._protocol.read_queue.popleft()
1666
+ except IndexError:
1667
+ if self._closed:
1668
+ raise ClosedResourceError from None
1669
+ else:
1670
+ raise BrokenResourceError from None
1671
+
1672
+ return packet[0]
1673
+
1674
+ async def send(self, item: bytes) -> None:
1675
+ with self._send_guard:
1676
+ await checkpoint()
1677
+ await self._protocol.write_event.wait()
1678
+ if self._closed:
1679
+ raise ClosedResourceError
1680
+ elif self._transport.is_closing():
1681
+ raise BrokenResourceError
1682
+ else:
1683
+ self._transport.sendto(item)
1684
+
1685
+
1686
+ async def connect_tcp(
1687
+ host: str, port: int, local_addr: Optional[Tuple[str, int]] = None
1688
+ ) -> SocketStream:
1689
+ transport, protocol = cast(
1690
+ Tuple[asyncio.Transport, StreamProtocol],
1691
+ await get_running_loop().create_connection(
1692
+ StreamProtocol, host, port, local_addr=local_addr
1693
+ ),
1694
+ )
1695
+ transport.pause_reading()
1696
+ return SocketStream(transport, protocol)
1697
+
1698
+
1699
+ async def connect_unix(path: str) -> UNIXSocketStream:
1700
+ await checkpoint()
1701
+ loop = get_running_loop()
1702
+ raw_socket = socket.socket(socket.AF_UNIX)
1703
+ raw_socket.setblocking(False)
1704
+ while True:
1705
+ try:
1706
+ raw_socket.connect(path)
1707
+ except BlockingIOError:
1708
+ f: asyncio.Future = asyncio.Future()
1709
+ loop.add_writer(raw_socket, f.set_result, None)
1710
+ f.add_done_callback(lambda _: loop.remove_writer(raw_socket))
1711
+ await f
1712
+ except BaseException:
1713
+ raw_socket.close()
1714
+ raise
1715
+ else:
1716
+ return UNIXSocketStream(raw_socket)
1717
+
1718
+
1719
+ async def create_udp_socket(
1720
+ family: socket.AddressFamily,
1721
+ local_address: Optional[IPSockAddrType],
1722
+ remote_address: Optional[IPSockAddrType],
1723
+ reuse_port: bool,
1724
+ ) -> Union[UDPSocket, ConnectedUDPSocket]:
1725
+ result = await get_running_loop().create_datagram_endpoint(
1726
+ DatagramProtocol,
1727
+ local_addr=local_address,
1728
+ remote_addr=remote_address,
1729
+ family=family,
1730
+ reuse_port=reuse_port,
1731
  )
1732
+ transport = cast(asyncio.DatagramTransport, result[0])
1733
+ protocol = result[1]
1734
+ if protocol.exception:
1735
+ transport.close()
1736
+ raise protocol.exception
1737
+
1738
+ if not remote_address:
1739
+ return UDPSocket(transport, protocol)
1740
+ else:
1741
+ return ConnectedUDPSocket(transport, protocol)
1742
+
1743
+
1744
+ async def getaddrinfo(
1745
+ host: Union[bytes, str],
1746
+ port: Union[str, int, None],
1747
+ *,
1748
+ family: Union[int, AddressFamily] = 0,
1749
+ type: Union[int, SocketKind] = 0,
1750
+ proto: int = 0,
1751
+ flags: int = 0,
1752
+ ) -> GetAddrInfoReturnType:
1753
+ # https://github.com/python/typeshed/pull/4304
1754
+ result = await get_running_loop().getaddrinfo(
1755
+ host, port, family=family, type=type, proto=proto, flags=flags
1756
+ )
1757
+ return cast(GetAddrInfoReturnType, result)
1758
+
1759
+
1760
+ async def getnameinfo(sockaddr: IPSockAddrType, flags: int = 0) -> Tuple[str, str]:
1761
+ return await get_running_loop().getnameinfo(sockaddr, flags)
1762
+
1763
+
1764
+ _read_events: RunVar[Dict[Any, asyncio.Event]] = RunVar("read_events")
1765
+ _write_events: RunVar[Dict[Any, asyncio.Event]] = RunVar("write_events")
1766
+
1767
+
1768
+ async def wait_socket_readable(sock: socket.socket) -> None:
1769
+ await checkpoint()
1770
+ try:
1771
+ read_events = _read_events.get()
1772
+ except LookupError:
1773
+ read_events = {}
1774
+ _read_events.set(read_events)
1775
+
1776
+ if read_events.get(sock):
1777
+ raise BusyResourceError("reading from") from None
1778
+
1779
+ loop = get_running_loop()
1780
+ event = read_events[sock] = asyncio.Event()
1781
+ loop.add_reader(sock, event.set)
1782
+ try:
1783
+ await event.wait()
1784
+ finally:
1785
+ if read_events.pop(sock, None) is not None:
1786
+ loop.remove_reader(sock)
1787
+ readable = True
1788
+ else:
1789
+ readable = False
1790
+
1791
+ if not readable:
1792
+ raise ClosedResourceError
1793
+
1794
+
1795
+ async def wait_socket_writable(sock: socket.socket) -> None:
1796
+ await checkpoint()
1797
+ try:
1798
+ write_events = _write_events.get()
1799
+ except LookupError:
1800
+ write_events = {}
1801
+ _write_events.set(write_events)
1802
+
1803
+ if write_events.get(sock):
1804
+ raise BusyResourceError("writing to") from None
1805
+
1806
+ loop = get_running_loop()
1807
+ event = write_events[sock] = asyncio.Event()
1808
+ loop.add_writer(sock.fileno(), event.set)
1809
+ try:
1810
+ await event.wait()
1811
+ finally:
1812
+ if write_events.pop(sock, None) is not None:
1813
+ loop.remove_writer(sock)
1814
+ writable = True
1815
+ else:
1816
+ writable = False
1817
+
1818
+ if not writable:
1819
+ raise ClosedResourceError
1820
+
1821
+
1822
+ #
1823
+ # Synchronization
1824
+ #
1825
+
1826
+
1827
+ class Event(BaseEvent):
1828
+ def __new__(cls) -> "Event":
1829
+ return object.__new__(cls)
1830
+
1831
+ def __init__(self) -> None:
1832
+ self._event = asyncio.Event()
1833
+
1834
+ def set(self) -> DeprecatedAwaitable:
1835
+ self._event.set()
1836
+ return DeprecatedAwaitable(self.set)
1837
+
1838
+ def is_set(self) -> bool:
1839
+ return self._event.is_set()
1840
+
1841
+ async def wait(self) -> None:
1842
+ if await self._event.wait():
1843
+ await checkpoint()
1844
+
1845
+ def statistics(self) -> EventStatistics:
1846
+ return EventStatistics(len(self._event._waiters)) # type: ignore[attr-defined]
1847
+
1848
+
1849
+ class CapacityLimiter(BaseCapacityLimiter):
1850
+ _total_tokens: float = 0
1851
+
1852
+ def __new__(cls, total_tokens: float) -> "CapacityLimiter":
1853
+ return object.__new__(cls)
1854
+
1855
+ def __init__(self, total_tokens: float):
1856
+ self._borrowers: Set[Any] = set()
1857
+ self._wait_queue: Dict[Any, asyncio.Event] = OrderedDict()
1858
+ self.total_tokens = total_tokens
1859
+
1860
+ async def __aenter__(self) -> None:
1861
+ await self.acquire()
1862
+
1863
+ async def __aexit__(
1864
+ self,
1865
+ exc_type: Optional[Type[BaseException]],
1866
+ exc_val: Optional[BaseException],
1867
+ exc_tb: Optional[TracebackType],
1868
+ ) -> None:
1869
+ self.release()
1870
+
1871
+ @property
1872
+ def total_tokens(self) -> float:
1873
+ return self._total_tokens
1874
+
1875
+ @total_tokens.setter
1876
+ def total_tokens(self, value: float) -> None:
1877
+ if not isinstance(value, int) and not math.isinf(value):
1878
+ raise TypeError("total_tokens must be an int or math.inf")
1879
+ if value < 1:
1880
+ raise ValueError("total_tokens must be >= 1")
1881
+
1882
+ old_value = self._total_tokens
1883
+ self._total_tokens = value
1884
+ events = []
1885
+ for event in self._wait_queue.values():
1886
+ if value <= old_value:
1887
+ break
1888
+
1889
+ if not event.is_set():
1890
+ events.append(event)
1891
+ old_value += 1
1892
+
1893
+ for event in events:
1894
+ event.set()
1895
+
1896
+ @property
1897
+ def borrowed_tokens(self) -> int:
1898
+ return len(self._borrowers)
1899
+
1900
+ @property
1901
+ def available_tokens(self) -> float:
1902
+ return self._total_tokens - len(self._borrowers)
1903
+
1904
+ def acquire_nowait(self) -> DeprecatedAwaitable:
1905
+ self.acquire_on_behalf_of_nowait(current_task())
1906
+ return DeprecatedAwaitable(self.acquire_nowait)
1907
+
1908
+ def acquire_on_behalf_of_nowait(self, borrower: object) -> DeprecatedAwaitable:
1909
+ if borrower in self._borrowers:
1910
+ raise RuntimeError(
1911
+ "this borrower is already holding one of this CapacityLimiter's "
1912
+ "tokens"
1913
+ )
1914
+
1915
+ if self._wait_queue or len(self._borrowers) >= self._total_tokens:
1916
+ raise WouldBlock
1917
+
1918
+ self._borrowers.add(borrower)
1919
+ return DeprecatedAwaitable(self.acquire_on_behalf_of_nowait)
1920
+
1921
+ async def acquire(self) -> None:
1922
+ return await self.acquire_on_behalf_of(current_task())
1923
+
1924
+ async def acquire_on_behalf_of(self, borrower: object) -> None:
1925
+ await checkpoint_if_cancelled()
1926
+ try:
1927
+ self.acquire_on_behalf_of_nowait(borrower)
1928
+ except WouldBlock:
1929
+ event = asyncio.Event()
1930
+ self._wait_queue[borrower] = event
1931
+ try:
1932
+ await event.wait()
1933
+ except BaseException:
1934
+ self._wait_queue.pop(borrower, None)
1935
+ raise
1936
+
1937
+ self._borrowers.add(borrower)
1938
+ else:
1939
+ try:
1940
+ await cancel_shielded_checkpoint()
1941
+ except BaseException:
1942
+ self.release()
1943
+ raise
1944
+
1945
+ def release(self) -> None:
1946
+ self.release_on_behalf_of(current_task())
1947
+
1948
+ def release_on_behalf_of(self, borrower: object) -> None:
1949
+ try:
1950
+ self._borrowers.remove(borrower)
1951
+ except KeyError:
1952
+ raise RuntimeError(
1953
+ "this borrower isn't holding any of this CapacityLimiter's " "tokens"
1954
+ ) from None
1955
+
1956
+ # Notify the next task in line if this limiter has free capacity now
1957
+ if self._wait_queue and len(self._borrowers) < self._total_tokens:
1958
+ event = self._wait_queue.popitem()[1]
1959
+ event.set()
1960
+
1961
+ def statistics(self) -> CapacityLimiterStatistics:
1962
+ return CapacityLimiterStatistics(
1963
+ self.borrowed_tokens,
1964
+ self.total_tokens,
1965
+ tuple(self._borrowers),
1966
+ len(self._wait_queue),
1967
+ )
1968
+
1969
+
1970
+ _default_thread_limiter: RunVar[CapacityLimiter] = RunVar("_default_thread_limiter")
1971
+
1972
+
1973
+ def current_default_thread_limiter() -> CapacityLimiter:
1974
+ try:
1975
+ return _default_thread_limiter.get()
1976
+ except LookupError:
1977
+ limiter = CapacityLimiter(40)
1978
+ _default_thread_limiter.set(limiter)
1979
+ return limiter
1980
+
1981
+
1982
+ #
1983
+ # Operating system signals
1984
+ #
1985
+
1986
+
1987
+ class _SignalReceiver(DeprecatedAsyncContextManager["_SignalReceiver"]):
1988
+ def __init__(self, signals: Tuple[int, ...]):
1989
+ self._signals = signals
1990
+ self._loop = get_running_loop()
1991
+ self._signal_queue: Deque[int] = deque()
1992
+ self._future: asyncio.Future = asyncio.Future()
1993
+ self._handled_signals: Set[int] = set()
1994
+
1995
+ def _deliver(self, signum: int) -> None:
1996
+ self._signal_queue.append(signum)
1997
+ if not self._future.done():
1998
+ self._future.set_result(None)
1999
+
2000
+ def __enter__(self) -> "_SignalReceiver":
2001
+ for sig in set(self._signals):
2002
+ self._loop.add_signal_handler(sig, self._deliver, sig)
2003
+ self._handled_signals.add(sig)
2004
+
2005
+ return self
2006
+
2007
+ def __exit__(
2008
+ self,
2009
+ exc_type: Optional[Type[BaseException]],
2010
+ exc_val: Optional[BaseException],
2011
+ exc_tb: Optional[TracebackType],
2012
+ ) -> Optional[bool]:
2013
+ for sig in self._handled_signals:
2014
+ self._loop.remove_signal_handler(sig)
2015
+ return None
2016
+
2017
+ def __aiter__(self) -> "_SignalReceiver":
2018
+ return self
2019
+
2020
+ async def __anext__(self) -> int:
2021
+ await checkpoint()
2022
+ if not self._signal_queue:
2023
+ self._future = asyncio.Future()
2024
+ await self._future
2025
+
2026
+ return self._signal_queue.popleft()
2027
+
2028
+
2029
+ def open_signal_receiver(*signals: int) -> _SignalReceiver:
2030
+ return _SignalReceiver(signals)
2031
+
2032
+
2033
+ #
2034
+ # Testing and debugging
2035
+ #
2036
+
2037
+
2038
+ def _create_task_info(task: asyncio.Task) -> TaskInfo:
2039
+ task_state = _task_states.get(task)
2040
+ if task_state is None:
2041
+ name = task.get_name() if _native_task_names else None
2042
+ parent_id = None
2043
+ else:
2044
+ name = task_state.name
2045
+ parent_id = task_state.parent_id
2046
+
2047
+ return TaskInfo(id(task), parent_id, name, get_coro(task))
2048
+
2049
+
2050
+ def get_current_task() -> TaskInfo:
2051
+ return _create_task_info(current_task()) # type: ignore[arg-type]
2052
+
2053
+
2054
+ def get_running_tasks() -> List[TaskInfo]:
2055
+ return [_create_task_info(task) for task in all_tasks() if not task.done()]
2056
+
2057
+
2058
+ async def wait_all_tasks_blocked() -> None:
2059
+ await checkpoint()
2060
+ this_task = current_task()
2061
+ while True:
2062
+ for task in all_tasks():
2063
+ if task is this_task:
2064
+ continue
2065
+
2066
+ if task._fut_waiter is None or task._fut_waiter.done(): # type: ignore[attr-defined]
2067
+ await sleep(0.1)
2068
+ break
2069
+ else:
2070
+ return
2071
+
2072
+
2073
+ class TestRunner(abc.TestRunner):
2074
+ def __init__(
2075
+ self,
2076
+ debug: bool = False,
2077
+ use_uvloop: bool = False,
2078
+ policy: Optional[asyncio.AbstractEventLoopPolicy] = None,
2079
+ ):
2080
+ self._exceptions: List[BaseException] = []
2081
+ _maybe_set_event_loop_policy(policy, use_uvloop)
2082
+ self._loop = asyncio.new_event_loop()
2083
+ self._loop.set_debug(debug)
2084
+ self._loop.set_exception_handler(self._exception_handler)
2085
+ asyncio.set_event_loop(self._loop)
2086
+
2087
+ def _cancel_all_tasks(self) -> None:
2088
+ to_cancel = all_tasks(self._loop)
2089
+ if not to_cancel:
2090
+ return
2091
+
2092
+ for task in to_cancel:
2093
+ task.cancel()
2094
+
2095
+ self._loop.run_until_complete(
2096
+ asyncio.gather(*to_cancel, return_exceptions=True)
2097
+ )
2098
+
2099
+ for task in to_cancel:
2100
+ if task.cancelled():
2101
+ continue
2102
+ if task.exception() is not None:
2103
+ raise cast(BaseException, task.exception())
2104
+
2105
+ def _exception_handler(
2106
+ self, loop: asyncio.AbstractEventLoop, context: Dict[str, Any]
2107
+ ) -> None:
2108
+ if isinstance(context.get("exception"), Exception):
2109
+ self._exceptions.append(context["exception"])
2110
+ else:
2111
+ loop.default_exception_handler(context)
2112
+
2113
+ def _raise_async_exceptions(self) -> None:
2114
+ # Re-raise any exceptions raised in asynchronous callbacks
2115
+ if self._exceptions:
2116
+ exceptions, self._exceptions = self._exceptions, []
2117
+ if len(exceptions) == 1:
2118
+ raise exceptions[0]
2119
+ elif exceptions:
2120
+ raise ExceptionGroup(exceptions)
2121
+
2122
+ def close(self) -> None:
2123
+ try:
2124
+ self._cancel_all_tasks()
2125
+ self._loop.run_until_complete(self._loop.shutdown_asyncgens())
2126
+ finally:
2127
+ asyncio.set_event_loop(None)
2128
+ self._loop.close()
2129
+
2130
+ def run_asyncgen_fixture(
2131
+ self,
2132
+ fixture_func: Callable[..., AsyncGenerator[T_Retval, Any]],
2133
+ kwargs: Dict[str, Any],
2134
+ ) -> Iterable[T_Retval]:
2135
+ async def fixture_runner() -> None:
2136
+ agen = fixture_func(**kwargs)
2137
+ try:
2138
+ retval = await agen.asend(None)
2139
+ self._raise_async_exceptions()
2140
+ except BaseException as exc:
2141
+ f.set_exception(exc)
2142
+ return
2143
+ else:
2144
+ f.set_result(retval)
2145
+
2146
+ await event.wait()
2147
+ try:
2148
+ await agen.asend(None)
2149
+ except StopAsyncIteration:
2150
+ pass
2151
+ else:
2152
+ await agen.aclose()
2153
+ raise RuntimeError("Async generator fixture did not stop")
2154
+
2155
+ f = self._loop.create_future()
2156
+ event = asyncio.Event()
2157
+ fixture_task = self._loop.create_task(fixture_runner())
2158
+ self._loop.run_until_complete(f)
2159
+ yield f.result()
2160
+ event.set()
2161
+ self._loop.run_until_complete(fixture_task)
2162
+ self._raise_async_exceptions()
2163
+
2164
+ def run_fixture(
2165
+ self,
2166
+ fixture_func: Callable[..., Coroutine[Any, Any, T_Retval]],
2167
+ kwargs: Dict[str, Any],
2168
+ ) -> T_Retval:
2169
+ retval = self._loop.run_until_complete(fixture_func(**kwargs))
2170
+ self._raise_async_exceptions()
2171
+ return retval
2172
+
2173
+ def run_test(
2174
+ self, test_func: Callable[..., Coroutine[Any, Any, Any]], kwargs: Dict[str, Any]
2175
+ ) -> None:
2176
+ try:
2177
+ self._loop.run_until_complete(test_func(**kwargs))
2178
+ except Exception as exc:
2179
+ self._exceptions.append(exc)
2180
 
2181
+ self._raise_async_exceptions()