# Scraped Hugging Face file-page header (not part of the program):
# author: XciD (HF staff) · commit 8969f81 "initial commit" · size 4.21 kB · scan: no virus
from threading import Thread
import falcon
from falcon.http_status import HTTPStatus
import json
import requests
import time
from Model import generate_completion
import sys
class AutoComplete(object):
    """Falcon resource that returns completions from ``generate_completion``.

    Only POST is handled. The route parameters ``x``/``y`` and the
    ``single_endpoint`` keyword are accepted but not used.
    """

    # Optional request-JSON fields, as
    # (JSON field name, keyword argument of generate_completion, default
    # used when the field is absent from the request).
    _OPTIONAL_FIELDS = (
        ('gen_length', 'length', 20),
        ('max_time', 'max_time', -1),
        ('model_size', 'model_name', "small"),
        ('temperature', 'temperature', 0.7),
        ('max_tokens', 'max_tokens', 256),
        ('top_p', 'top_p', 0.95),
        ('top_k', 'top_k', 40),
        # CTRL
        ('repetition_penalty', 'repetition_penalty', 0.02),
        # PPLM
        ('step_size', 'stepsize', 0.02),
        ('bow_or_discrim', 'bag_of_words_or_discrim', "kitchen"),
        ('gm_scale', 'gm_scale', None),
        ('kl_scale', 'kl_scale', None),
        ('num_iterations', 'num_iterations', None),
        ('use_sampling', 'use_sampling', None),
    )

    def on_post(self, req, resp, single_endpoint=True, x=None, y=None):
        """Parse the JSON request, run generation and write a JSON response.

        The body must be a JSON object with a string ``context`` (trailing
        whitespace is stripped). Every other field is optional; see
        ``_OPTIONAL_FIELDS`` for names and defaults. A ``samples`` field is
        accepted but ignored, because it is not forwarded to
        ``generate_completion``. NOTE(review): confirm that is intended.

        Responses:
            200 with ``{"sentences": <generate_completion result>,
            "time": <elapsed seconds>}`` on success.
            400 when the body is not valid JSON or not a JSON object.
            422 when ``context`` is missing or is not a string.
        """
        try:
            json_data = json.loads(req.bounded_stream.read())
        except ValueError:
            # Covers json.JSONDecodeError and UnicodeDecodeError (both are
            # ValueError subclasses); previously these escaped as a 500.
            resp.body = "The request body must be valid JSON"
            resp.status = falcon.HTTP_400
            return
        if not isinstance(json_data, dict):
            resp.body = "The request body must be a JSON object"
            resp.status = falcon.HTTP_400
            return

        # Elapsed time reported to the client starts after body parsing.
        start = time.time()

        if "context" not in json_data:
            resp.body = "The context field is required"
            resp.status = falcon.HTTP_422
            return
        if not isinstance(json_data["context"], str):
            resp.body = "The context field must be a string"
            resp.status = falcon.HTTP_422
            return
        context = json_data["context"].rstrip()

        options = {
            kwarg: json_data.get(field, default)
            for field, kwarg, default in self._OPTIONAL_FIELDS
        }

        print(json_data)
        sentences = generate_completion(context, **options)
        resp.body = json.dumps({"sentences": sentences, 'time': time.time() - start})
        resp.status = falcon.HTTP_200
        # Push the debug print out promptly (stdout may be block-buffered
        # under a WSGI server).
        sys.stdout.flush()
class Request(Thread):
    """Thread that POSTs ``data`` as JSON to ``end_point``.

    Usage: ``start()`` the thread, then ``join()`` it. ``join`` returns the
    response body text and re-raises any exception raised by the request.
    """

    def __init__(self, end_point, data):
        super().__init__()
        self.end_point = end_point
        self.data = data
        self.ret = None    # response object, set by run() on success
        self.error = None  # exception raised inside run(), re-raised by join()

    def run(self):
        print("Requesting with url", self.end_point)
        try:
            self.ret = requests.post(url=self.end_point, json=self.data)
        except Exception as err:
            # An exception escaping run() is only reported by threading and
            # then lost. Keep it so join() can surface it to the caller.
            self.error = err

    def join(self, timeout=None):
        """Wait for the request and return the response body text.

        ``timeout`` matches ``threading.Thread.join``. If the thread is still
        running when the timeout expires, ``None`` is returned.

        Raises:
            Whatever exception ``run()`` captured while making the request.
        """
        Thread.join(self, timeout)
        if self.error is not None:
            raise self.error
        if self.ret is None:
            return None
        return self.ret.text
class HandleCORS(object):
    """Falcon middleware adding permissive CORS headers to every request.

    OPTIONS (CORS preflight) requests are answered directly with status 200
    and a one-newline body.
    """

    # (header name, value) pairs set on every response.
    _CORS_HEADERS = (
        ('Access-Control-Allow-Origin', '*'),
        ('Access-Control-Allow-Methods', '*'),
        ('Access-Control-Allow-Headers', '*'),
    )

    def process_request(self, req, resp):
        for header_name, header_value in self._CORS_HEADERS:
            resp.set_header(header_name, header_value)
        if req.method != 'OPTIONS':
            return
        # Raising HTTPStatus ends handling of the preflight with this status.
        raise HTTPStatus(falcon.HTTP_200, body='\n')
# A single shared AutoComplete resource serves the bare route and both
# path-parameter variants; CORS middleware wraps every request.
autocomplete = AutoComplete()
app = falcon.API(middleware=[HandleCORS()])
for _route in ('/autocomplete', '/autocomplete/{x}', '/autocomplete/{x}/{y}'):
    app.add_route(_route, autocomplete)
del _route  # keep the module namespace free of the loop variable
# Alias, presumably for WSGI servers that look up ``application`` by default.
application = app