Spaces:
Sleeping
Sleeping
Handle token by token generation
Browse files- generate_text_api.py +113 -114
generate_text_api.py
CHANGED
@@ -1,114 +1,113 @@
|
|
1 |
-
import json
|
2 |
-
|
3 |
-
import aiohttp
|
4 |
-
|
5 |
-
|
6 |
-
class TextGenerator:
|
7 |
-
def __init__(self, host_url):
|
8 |
-
self.host_url = host_url.rstrip("/") + "/generate"
|
9 |
-
self.host_url_stream = host_url.rstrip("/") + "/generate_stream"
|
10 |
-
|
11 |
-
async def generate_text_async(self, prompt, max_new_tokens=100, do_sample=True, temperature=0.8):
|
12 |
-
payload = {
|
13 |
-
'inputs': prompt,
|
14 |
-
'parameters': {
|
15 |
-
'max_new_tokens': max_new_tokens,
|
16 |
-
'do_sample': do_sample,
|
17 |
-
'temperature': temperature,
|
18 |
-
}
|
19 |
-
}
|
20 |
-
|
21 |
-
headers = {
|
22 |
-
'Content-Type': 'application/json'
|
23 |
-
}
|
24 |
-
|
25 |
-
async with aiohttp.ClientSession() as session:
|
26 |
-
async with session.post(self.host_url, data=json.dumps(payload), headers=headers) as response:
|
27 |
-
if response.status == 200:
|
28 |
-
data = await response.json()
|
29 |
-
text = data["generated_text"]
|
30 |
-
return text
|
31 |
-
else:
|
32 |
-
# Handle error responses here
|
33 |
-
return None
|
34 |
-
|
35 |
-
def generate_text(self, prompt, max_new_tokens=100, do_sample=True, temperature=0.8):
|
36 |
-
import requests
|
37 |
-
|
38 |
-
payload = {
|
39 |
-
'inputs': prompt,
|
40 |
-
'parameters': {
|
41 |
-
'max_new_tokens': max_new_tokens,
|
42 |
-
'do_sample': do_sample,
|
43 |
-
'temperature': temperature,
|
44 |
-
}
|
45 |
-
}
|
46 |
-
|
47 |
-
headers = {
|
48 |
-
'Content-Type': 'application/json'
|
49 |
-
}
|
50 |
-
|
51 |
-
response = requests.post(self.host_url, data=json.dumps(payload), headers=headers).json()
|
52 |
-
text = response["generated_text"]
|
53 |
-
return text
|
54 |
-
|
55 |
-
def generate_text_stream(self, prompt, max_new_tokens=100, do_sample=True, temperature=0.8, stop=[], best_of=1):
|
56 |
-
import requests
|
57 |
-
|
58 |
-
payload = {
|
59 |
-
'inputs': prompt,
|
60 |
-
'parameters': {
|
61 |
-
'max_new_tokens': max_new_tokens,
|
62 |
-
'do_sample': do_sample,
|
63 |
-
'temperature': temperature,
|
64 |
-
'stop': stop,
|
65 |
-
'best_of': best_of,
|
66 |
-
}
|
67 |
-
}
|
68 |
-
|
69 |
-
headers = {
|
70 |
-
'Content-Type': 'application/json',
|
71 |
-
'Cache-Control': 'no-cache',
|
72 |
-
'Connection': 'keep-alive'
|
73 |
-
}
|
74 |
-
|
75 |
-
response = requests.post(self.host_url_stream, data=json.dumps(payload), headers=headers, stream=True)
|
76 |
-
|
77 |
-
for line in response.iter_lines():
|
78 |
-
if line:
|
79 |
-
print(line)
|
80 |
-
json_data = line.decode('utf-8')
|
81 |
-
if json_data.startswith('data:'):
|
82 |
-
print(json_data)
|
83 |
-
json_data = json_data[5:]
|
84 |
-
token_data = json.loads(json_data)
|
85 |
-
token = token_data['token']['text']
|
86 |
-
if not token_data['token']['special']:
|
87 |
-
yield token
|
88 |
-
|
89 |
-
class SummarizerGenerator:
|
90 |
-
def __init__(self, api):
|
91 |
-
self.api = api
|
92 |
-
|
93 |
-
def generate_summary_stream(self, text):
|
94 |
-
import requests
|
95 |
-
payload = {"text": text}
|
96 |
-
|
97 |
-
headers = {
|
98 |
-
'Content-Type': 'application/json',
|
99 |
-
'Cache-Control': 'no-cache',
|
100 |
-
'Connection': 'keep-alive'
|
101 |
-
}
|
102 |
-
|
103 |
-
response = requests.post(self.api, data=json.dumps(payload), headers=headers, stream=True)
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
data
|
110 |
-
|
111 |
-
|
112 |
-
|
113 |
-
yield data
|
114 |
-
i += 1
|
|
|
1 |
+
import json
|
2 |
+
|
3 |
+
import aiohttp
|
4 |
+
|
5 |
+
|
6 |
+
class TextGenerator:
    """HTTP client for a text-generation-inference style server.

    Offers blocking (`generate_text`), async (`generate_text_async`), and
    token-streaming (`generate_text_stream`) generation against the server's
    ``/generate`` and ``/generate_stream`` endpoints.
    """

    def __init__(self, host_url):
        # Normalize the base URL once so callers may pass it with or
        # without a trailing slash.
        self.host_url = host_url.rstrip("/") + "/generate"
        self.host_url_stream = host_url.rstrip("/") + "/generate_stream"

    async def generate_text_async(self, prompt, max_new_tokens=100, do_sample=True, temperature=0.8):
        """Asynchronously generate a completion for *prompt*.

        Returns the server's ``generated_text`` string, or ``None`` on any
        non-200 response (callers are expected to handle the None case).
        """
        payload = {
            'inputs': prompt,
            'parameters': {
                'max_new_tokens': max_new_tokens,
                'do_sample': do_sample,
                'temperature': temperature,
            }
        }
        headers = {
            'Content-Type': 'application/json'
        }
        async with aiohttp.ClientSession() as session:
            async with session.post(self.host_url, data=json.dumps(payload), headers=headers) as response:
                if response.status == 200:
                    data = await response.json()
                    return data["generated_text"]
                # Non-200: signal failure to the caller instead of raising.
                return None

    def generate_text(self, prompt, max_new_tokens=100, do_sample=True, temperature=0.8):
        """Synchronously generate a completion for *prompt* and return it."""
        import requests

        payload = {
            'inputs': prompt,
            'parameters': {
                'max_new_tokens': max_new_tokens,
                'do_sample': do_sample,
                'temperature': temperature,
            }
        }
        headers = {
            'Content-Type': 'application/json'
        }
        response = requests.post(self.host_url, data=json.dumps(payload), headers=headers).json()
        return response["generated_text"]

    def generate_text_stream(self, prompt, max_new_tokens=100, do_sample=True, temperature=0.8, stop=None, best_of=1):
        """Stream generated tokens one at a time.

        The server emits server-sent-event style ``data:`` lines; each line's
        JSON payload carries one token. Yields the text of every non-special
        token as it arrives.

        *stop* defaults to an empty stop-sequence list (``None`` avoids the
        shared mutable-default pitfall of the previous ``stop=[]`` signature;
        the payload sent to the server is unchanged).
        """
        import requests

        payload = {
            'inputs': prompt,
            'parameters': {
                'max_new_tokens': max_new_tokens,
                'do_sample': do_sample,
                'temperature': temperature,
                'stop': stop if stop is not None else [],
                'best_of': best_of,
            }
        }
        headers = {
            'Content-Type': 'application/json',
            'Cache-Control': 'no-cache',
            'Connection': 'keep-alive'
        }
        response = requests.post(self.host_url_stream, data=json.dumps(payload), headers=headers, stream=True)

        for line in response.iter_lines():
            # iter_lines yields b'' keep-alive lines; skip them.
            if not line:
                continue
            decoded = line.decode('utf-8')
            if decoded.startswith('data:'):
                token_data = json.loads(decoded[5:])  # strip the 'data:' prefix
                token = token_data['token']['text']
                if not token_data['token']['special']:
                    yield token
89 |
+
class SummarizerGenerator:
    """Streams summary text line by line from a summarization endpoint."""

    def __init__(self, api):
        # Full URL of the summarizer endpoint.
        self.api = api

    def generate_summary_stream(self, text):
        """POST *text* to the API and yield cleaned response lines.

        For each streamed line: a trailing ``<|eot_id|>`` end-of-turn marker
        is removed, leading ``•`` bullets are normalized to ``-`` (all bullets
        in that line are replaced), and bullet lines are prefixed with a blank
        line so consumers render each bullet as its own paragraph.
        """
        import requests

        payload = {"text": text}

        headers = {
            'Content-Type': 'application/json',
            'Cache-Control': 'no-cache',
            'Connection': 'keep-alive'
        }

        response = requests.post(self.api, data=json.dumps(payload), headers=headers, stream=True)

        for line in response.iter_lines():
            # iter_lines yields b'' keep-alive lines; skip them.
            if not line:
                continue
            data = line.decode('utf-8').removesuffix('<|eot_id|>')
            if data.startswith("•"):
                data = data.replace("•", "-")
            if data.startswith("-"):
                data = "\n\n" + data
            yield data
|
|