# ytseg_demo / generate_text_api.py
# ScientiaEtVeritas
# initial commit
# c57bf8a
import json
import aiohttp
class TextGenerator:
    """Client for a text-generation HTTP server.

    Requests are sent as JSON of the form
    ``{"inputs": prompt, "parameters": {...}}`` to ``<host>/generate``
    (single JSON response) or ``<host>/generate_stream`` (a line-based
    stream whose payload lines start with ``data:``).

    NOTE(review): the request/response shape appears to match Hugging Face
    text-generation-inference (TGI); that is an assumption, not something
    this class verifies.
    """

    def __init__(self, host_url):
        """Derive the two endpoint URLs from *host_url*.

        Args:
            host_url: Base server URL. Trailing slashes are stripped before
                the endpoint path is appended.
        """
        base_url = host_url.rstrip("/")
        self.host_url = base_url + "/generate"
        self.host_url_stream = base_url + "/generate_stream"

    @staticmethod
    def _build_payload(prompt, max_new_tokens, do_sample, temperature, **extra_parameters):
        """Return the JSON-serialisable request body shared by every endpoint.

        ``extra_parameters`` are added to ``parameters`` (in the order given)
        after the three common sampling parameters.
        """
        parameters = {
            'max_new_tokens': max_new_tokens,
            'do_sample': do_sample,
            'temperature': temperature,
        }
        parameters.update(extra_parameters)
        return {'inputs': prompt, 'parameters': parameters}

    @staticmethod
    def _parse_stream_line(line):
        """Extract the token text from one raw line of the streaming response.

        Args:
            line: A ``bytes`` line as produced by ``Response.iter_lines()``.

        Returns:
            The token's ``text`` for a ``data:`` line whose token is not
            marked ``special``. Returns ``None`` for empty lines, lines
            without the ``data:`` prefix, and special tokens. A non-special
            empty-string token is returned as ``""``, not ``None``.
        """
        if not line:
            return None
        decoded = line.decode('utf-8')
        if not decoded.startswith('data:'):
            return None
        # json.loads tolerates the optional space after "data:".
        token = json.loads(decoded[len('data:'):])['token']
        if token['special']:
            return None
        return token['text']

    async def generate_text_async(self, prompt, max_new_tokens=100, do_sample=True, temperature=0.8):
        """Asynchronously request a full completion from the ``/generate`` endpoint.

        Args:
            prompt: Input text sent as ``inputs``.
            max_new_tokens: Upper bound on generated tokens.
            do_sample: Whether the server should sample rather than decode greedily.
            temperature: Sampling temperature.

        Returns:
            The ``generated_text`` field of the response body on HTTP 200.
            For any other status it returns ``None`` and discards the error body.
        """
        payload = self._build_payload(prompt, max_new_tokens, do_sample, temperature)
        headers = {'Content-Type': 'application/json'}
        # NOTE(review): a fresh ClientSession per call; a shared session
        # would allow connection reuse across calls.
        async with aiohttp.ClientSession() as session:
            async with session.post(self.host_url, data=json.dumps(payload), headers=headers) as response:
                if response.status != 200:
                    # Errors are reported to the caller only as a None result.
                    return None
                data = await response.json()
                return data["generated_text"]

    def generate_text(self, prompt, max_new_tokens=100, do_sample=True, temperature=0.8):
        """Synchronously request a full completion from the ``/generate`` endpoint.

        Args:
            prompt: Input text sent as ``inputs``.
            max_new_tokens: Upper bound on generated tokens.
            do_sample: Whether the server should sample rather than decode greedily.
            temperature: Sampling temperature.

        Returns:
            The ``generated_text`` field of the JSON response.

        Raises:
            KeyError: if the response JSON has no ``generated_text`` field.
                The HTTP status is not checked, so this is presumably how an
                error response surfaces.
        """
        import requests  # local import: only the synchronous methods need requests

        payload = self._build_payload(prompt, max_new_tokens, do_sample, temperature)
        headers = {'Content-Type': 'application/json'}
        response = requests.post(self.host_url, data=json.dumps(payload), headers=headers).json()
        return response["generated_text"]

    def generate_text_stream(self, prompt, max_new_tokens=100, do_sample=True, temperature=0.8, stop=None, best_of=1):
        """Stream a completion from ``/generate_stream``, yielding token texts.

        Args:
            prompt: Input text sent as ``inputs``.
            max_new_tokens: Upper bound on generated tokens.
            do_sample: Whether the server should sample rather than decode greedily.
            temperature: Sampling temperature.
            stop: Stop sequences forwarded as ``parameters.stop``. ``None``
                (the default) sends an empty list.
            best_of: Forwarded as ``parameters.best_of``.

        Yields:
            The ``text`` of each non-special token, in arrival order.

        The HTTP response is closed when the generator is exhausted, closed,
        or garbage-collected. This also covers a consumer that stops early.
        """
        import requests  # local import: only the synchronous methods need requests

        # Avoid a shared mutable default; [] matches the previous default.
        stop_sequences = [] if stop is None else stop
        payload = self._build_payload(
            prompt, max_new_tokens, do_sample, temperature,
            stop=stop_sequences, best_of=best_of,
        )
        headers = {
            'Content-Type': 'application/json',
            'Cache-Control': 'no-cache',
            'Connection': 'keep-alive'
        }
        with requests.post(self.host_url_stream, data=json.dumps(payload), headers=headers, stream=True) as response:
            for line in response.iter_lines():
                token = self._parse_stream_line(line)
                # Compare with None: an empty-string token is still a real token.
                if token is not None:
                    yield token