retkowski committed on
Commit
d7545dc
·
1 Parent(s): cb71ef5

Handle token-by-token generation

Browse files
Files changed (1) hide show
  1. generate_text_api.py +113 -114
generate_text_api.py CHANGED
@@ -1,114 +1,113 @@
1
- import json
2
-
3
- import aiohttp
4
-
5
-
6
- class TextGenerator:
7
- def __init__(self, host_url):
8
- self.host_url = host_url.rstrip("/") + "/generate"
9
- self.host_url_stream = host_url.rstrip("/") + "/generate_stream"
10
-
11
- async def generate_text_async(self, prompt, max_new_tokens=100, do_sample=True, temperature=0.8):
12
- payload = {
13
- 'inputs': prompt,
14
- 'parameters': {
15
- 'max_new_tokens': max_new_tokens,
16
- 'do_sample': do_sample,
17
- 'temperature': temperature,
18
- }
19
- }
20
-
21
- headers = {
22
- 'Content-Type': 'application/json'
23
- }
24
-
25
- async with aiohttp.ClientSession() as session:
26
- async with session.post(self.host_url, data=json.dumps(payload), headers=headers) as response:
27
- if response.status == 200:
28
- data = await response.json()
29
- text = data["generated_text"]
30
- return text
31
- else:
32
- # Handle error responses here
33
- return None
34
-
35
- def generate_text(self, prompt, max_new_tokens=100, do_sample=True, temperature=0.8):
36
- import requests
37
-
38
- payload = {
39
- 'inputs': prompt,
40
- 'parameters': {
41
- 'max_new_tokens': max_new_tokens,
42
- 'do_sample': do_sample,
43
- 'temperature': temperature,
44
- }
45
- }
46
-
47
- headers = {
48
- 'Content-Type': 'application/json'
49
- }
50
-
51
- response = requests.post(self.host_url, data=json.dumps(payload), headers=headers).json()
52
- text = response["generated_text"]
53
- return text
54
-
55
- def generate_text_stream(self, prompt, max_new_tokens=100, do_sample=True, temperature=0.8, stop=[], best_of=1):
56
- import requests
57
-
58
- payload = {
59
- 'inputs': prompt,
60
- 'parameters': {
61
- 'max_new_tokens': max_new_tokens,
62
- 'do_sample': do_sample,
63
- 'temperature': temperature,
64
- 'stop': stop,
65
- 'best_of': best_of,
66
- }
67
- }
68
-
69
- headers = {
70
- 'Content-Type': 'application/json',
71
- 'Cache-Control': 'no-cache',
72
- 'Connection': 'keep-alive'
73
- }
74
-
75
- response = requests.post(self.host_url_stream, data=json.dumps(payload), headers=headers, stream=True)
76
-
77
- for line in response.iter_lines():
78
- if line:
79
- print(line)
80
- json_data = line.decode('utf-8')
81
- if json_data.startswith('data:'):
82
- print(json_data)
83
- json_data = json_data[5:]
84
- token_data = json.loads(json_data)
85
- token = token_data['token']['text']
86
- if not token_data['token']['special']:
87
- yield token
88
-
89
- class SummarizerGenerator:
90
- def __init__(self, api):
91
- self.api = api
92
-
93
- def generate_summary_stream(self, text):
94
- import requests
95
- payload = {"text": text}
96
-
97
- headers = {
98
- 'Content-Type': 'application/json',
99
- 'Cache-Control': 'no-cache',
100
- 'Connection': 'keep-alive'
101
- }
102
-
103
- response = requests.post(self.api, data=json.dumps(payload), headers=headers, stream=True)
104
-
105
- i = 1
106
- for line in response.iter_lines():
107
- if line:
108
- print(line)
109
- data = line.decode('utf-8').removesuffix('<|eot_id|>')
110
- if data.startswith("•"):
111
- data = data.replace("•", "-")
112
- data += "\n\n" if i < 3 else ""
113
- yield data
114
- i += 1
 
1
+ import json
2
+
3
+ import aiohttp
4
+
5
+
6
+ class TextGenerator:
7
+ def __init__(self, host_url):
8
+ self.host_url = host_url.rstrip("/") + "/generate"
9
+ self.host_url_stream = host_url.rstrip("/") + "/generate_stream"
10
+
11
+ async def generate_text_async(self, prompt, max_new_tokens=100, do_sample=True, temperature=0.8):
12
+ payload = {
13
+ 'inputs': prompt,
14
+ 'parameters': {
15
+ 'max_new_tokens': max_new_tokens,
16
+ 'do_sample': do_sample,
17
+ 'temperature': temperature,
18
+ }
19
+ }
20
+
21
+ headers = {
22
+ 'Content-Type': 'application/json'
23
+ }
24
+
25
+ async with aiohttp.ClientSession() as session:
26
+ async with session.post(self.host_url, data=json.dumps(payload), headers=headers) as response:
27
+ if response.status == 200:
28
+ data = await response.json()
29
+ text = data["generated_text"]
30
+ return text
31
+ else:
32
+ # Handle error responses here
33
+ return None
34
+
35
+ def generate_text(self, prompt, max_new_tokens=100, do_sample=True, temperature=0.8):
36
+ import requests
37
+
38
+ payload = {
39
+ 'inputs': prompt,
40
+ 'parameters': {
41
+ 'max_new_tokens': max_new_tokens,
42
+ 'do_sample': do_sample,
43
+ 'temperature': temperature,
44
+ }
45
+ }
46
+
47
+ headers = {
48
+ 'Content-Type': 'application/json'
49
+ }
50
+
51
+ response = requests.post(self.host_url, data=json.dumps(payload), headers=headers).json()
52
+ text = response["generated_text"]
53
+ return text
54
+
55
+ def generate_text_stream(self, prompt, max_new_tokens=100, do_sample=True, temperature=0.8, stop=[], best_of=1):
56
+ import requests
57
+
58
+ payload = {
59
+ 'inputs': prompt,
60
+ 'parameters': {
61
+ 'max_new_tokens': max_new_tokens,
62
+ 'do_sample': do_sample,
63
+ 'temperature': temperature,
64
+ 'stop': stop,
65
+ 'best_of': best_of,
66
+ }
67
+ }
68
+
69
+ headers = {
70
+ 'Content-Type': 'application/json',
71
+ 'Cache-Control': 'no-cache',
72
+ 'Connection': 'keep-alive'
73
+ }
74
+
75
+ response = requests.post(self.host_url_stream, data=json.dumps(payload), headers=headers, stream=True)
76
+
77
+ for line in response.iter_lines():
78
+ if line:
79
+ print(line)
80
+ json_data = line.decode('utf-8')
81
+ if json_data.startswith('data:'):
82
+ print(json_data)
83
+ json_data = json_data[5:]
84
+ token_data = json.loads(json_data)
85
+ token = token_data['token']['text']
86
+ if not token_data['token']['special']:
87
+ yield token
88
+
89
+ class SummarizerGenerator:
90
+ def __init__(self, api):
91
+ self.api = api
92
+
93
+ def generate_summary_stream(self, text):
94
+ import requests
95
+ payload = {"text": text}
96
+
97
+ headers = {
98
+ 'Content-Type': 'application/json',
99
+ 'Cache-Control': 'no-cache',
100
+ 'Connection': 'keep-alive'
101
+ }
102
+
103
+ response = requests.post(self.api, data=json.dumps(payload), headers=headers, stream=True)
104
+
105
+ for line in response.iter_lines():
106
+ if line:
107
+ print(line)
108
+ data = line.decode('utf-8').removesuffix('<|eot_id|>')
109
+ if data.startswith("•"):
110
+ data = data.replace("•", "-")
111
+ if data.startswith("-"):
112
+ data = "\n\n" + data
113
+ yield data