petals-playground / chat_client.py
slush0's picture
The basics work, but this is still WIP; separators and examples need to be updated from BLOOM to Llama-2-related models.
c461bd0
raw
history blame contribute delete
No virus
2.53 kB
#!/usr/bin/env python
import json
import sys
# pip install websocket-client
import websocket
class ModelClient:
    """Minimal WebSocket client for a Petals chat server's generation API.

    A session is opened with ``open_session``, which connects to
    ``endpoint_url`` and sends an ``open_inference_session`` request.
    ``generate`` then streams outputs for one prompt at a time over that same
    connection. The client can also be used as a context manager; the session
    is closed when the ``with`` block exits.

    Attributes:
        endpoint_url: WebSocket URL of the generation endpoint.
        ws: The open connection, or ``None`` when no session is open.
        model: Model name passed to the most recent ``open_session`` call.
    """

    def __init__(self, endpoint_url):
        self.endpoint_url = endpoint_url
        self.ws = None
        self.model = None

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close_session()
        return False  # never suppress exceptions

    @staticmethod
    def _check_ok(data):
        """Raise RuntimeError carrying the server traceback unless ``data["ok"]`` is truthy."""
        if not data.get("ok"):
            raise RuntimeError(
                data.get("traceback", "server reported an error without a traceback")
            )

    def open_session(self, model, max_length):
        """Connect to the server and open an inference session.

        Any session that is already open is closed first so its socket is
        not leaked.

        Args:
            model: Model name to request from the server.
            max_length: Sent as the session's ``max_length`` field.
                NOTE(review): presumably the total token budget of the
                session -- confirm against the server.

        Raises:
            RuntimeError: If the server does not acknowledge the session with
                ``ok``. The connection is closed before raising.
        """
        self.close_session()
        self.ws = websocket.create_connection(self.endpoint_url, enable_multithread=True)
        self.model = model
        payload = {
            "type": "open_inference_session",
            "model": self.model,
            "max_length": max_length,
        }
        try:
            self.ws.send(json.dumps(payload))
            self._check_ok(json.loads(self.ws.recv()))
        except BaseException:
            # Don't leave a half-open connection behind on failure.
            self.close_session()
            raise

    def is_session(self):
        """Return True if a session connection is currently open."""
        return self.ws is not None

    def close_session(self):
        """Close the session's connection, if any. Safe to call repeatedly."""
        # Detach first so the client is marked closed even if close() raises.
        ws, self.ws = self.ws, None
        if ws is not None:
            ws.close()

    def generate(self, prompt, **kwargs):
        """Stream generated output chunks for ``prompt``.

        This is a generator, so nothing is sent until iteration starts.
        Keyword arguments override the default request fields (see
        ``_generate``).

        The session is closed if streaming ends abnormally. That covers a
        server-reported error, a connection error, an interrupt, or the
        caller closing or abandoning the generator before it is exhausted.
        In those cases the connection may still hold unread messages for
        this request.

        Yields:
            The ``outputs`` field of each server message.
        """
        try:
            # ``yield from`` keeps this try active while messages are
            # exchanged. Returning ``self._generate(...)`` directly would
            # leave the try before any I/O happened.
            yield from self._generate(prompt, **kwargs)
        except BaseException:
            self.close_session()
            raise

    def _generate(self, prompt, **kwargs):
        r"""Send one ``generate`` request and yield each response's outputs.

        Default fields can be overridden through ``kwargs``. The default
        stop sequence is ``"</s>"`` for models whose name contains
        ``"bloomz"`` and ``"\n\n"`` otherwise.

        Raises:
            RuntimeError: If no session is open, or if the server reports an
                error (its traceback is used as the message).
        """
        if self.ws is None:
            raise RuntimeError("No open session; call open_session() first")
        payload = {
            "type": "generate",
            "inputs": prompt,
            "max_new_tokens": 1,
            "do_sample": 0,
            "temperature": 1,
            "stop_sequence": "</s>" if "bloomz" in self.model else "\n\n",
        }
        payload.update(kwargs)
        self.ws.send(json.dumps(payload))
        while True:
            data = json.loads(self.ws.recv())
            self._check_ok(data)
            yield data["outputs"]
            if data["stop"]:
                break
def main(argv=None):
    r"""Stream a sampled completion from the public Petals chat endpoint.

    The prompt is ``argv[1]`` when given, with ``"\n\n"`` appended if it does
    not already end that way. Otherwise a built-in example prompt is used.
    Output chunks are printed as they arrive. The session is closed
    afterwards even if generation fails or is interrupted.

    Args:
        argv: Command-line argument list; defaults to ``sys.argv``.
    """
    if argv is None:
        argv = sys.argv
    # For a locally running server, use "ws://localhost:8000/api/v2/generate".
    client = ModelClient("wss://chat.petals.dev/api/v2/generate")
    try:
        client.open_session("stabilityai/StableBeluga2", 128)
        if len(argv) > 1:
            prompt = argv[1]
            # End the prompt with a blank line, like the default prompt below.
            # NOTE(review): this "\n\n" separator follows the BLOOM-era
            # convention (BLOOMZ uses "</s>" as its EOS token instead). It may
            # need updating for Llama-2-related models.
            if not prompt.endswith("\n\n"):
                prompt += "\n\n"
        else:
            prompt = "The SQL command to extract all the users whose name starts with A is: \n\n"
        print(f"Prompt: {prompt}")
        # NOTE(review): petals.client.routing.sequence_manager.MissingBlocksError
        # was noted at this point. Presumably the server can report it (as the
        # error traceback) when no peer serves some model blocks -- confirm.
        for out in client.generate(prompt, do_sample=True, temperature=0.75, top_p=0.9):
            print(out, end="", flush=True)
        # Terminate the streamed line so the shell prompt starts on a new line.
        print()
    finally:
        client.close_session()


if __name__ == '__main__':
    main()