"""Falcon application exposing a text auto-completion endpoint."""

import json
import sys
import time
from threading import Thread

import falcon
from falcon.http_status import HTTPStatus
import requests

from Model import generate_completion

# (request JSON key, generate_completion keyword, default used when the key
# is absent from the request body).
_GENERATION_OPTIONS = (
    ("gen_length", "length", 20),
    ("max_time", "max_time", -1),
    ("model_size", "model_name", "small"),
    ("temperature", "temperature", 0.7),
    ("max_tokens", "max_tokens", 256),
    ("top_p", "top_p", 0.95),
    ("top_k", "top_k", 40),
    # CTRL
    ("repetition_penalty", "repetition_penalty", 0.02),
    # PPLM
    ("step_size", "stepsize", 0.02),
    ("bow_or_discrim", "bag_of_words_or_discrim", "kitchen"),
    ("gm_scale", "gm_scale", None),
    ("kl_scale", "kl_scale", None),
    ("num_iterations", "num_iterations", None),
    ("use_sampling", "use_sampling", None),
)


class AutoComplete(object):
    """Falcon resource that returns completions for a posted context."""

    def on_post(self, req, resp, single_endpoint=True, x=None, y=None):
        """Handle POST: parse the JSON body and respond with completions.

        The body must be a JSON object with a string ``context`` field.
        Every key listed in ``_GENERATION_OPTIONS`` is optional and falls
        back to its default when absent.

        On success the response body is a JSON object with ``sentences``
        (the value returned by ``generate_completion``) and ``time`` (seconds
        elapsed since the body was parsed).

        Args:
            req: Falcon request; its bounded stream holds the JSON body.
            resp: Falcon response, mutated in place.
            single_endpoint: Not used by this handler.
            x: Value of the optional ``{x}`` route segment; not used here.
            y: Value of the optional ``{y}`` route segment; not used here.
        """
        try:
            json_data = json.loads(req.bounded_stream.read())
        except ValueError:
            # Covers JSONDecodeError (empty/malformed body) and undecodable
            # bytes, both of which would otherwise escape the handler.
            resp.body = "The request body must be valid JSON"
            resp.status = falcon.HTTP_400
            return

        if not isinstance(json_data, dict):
            resp.body = "The request body must be a JSON object"
            resp.status = falcon.HTTP_422
            return

        start = time.time()

        context = json_data.get("context")
        if context is None:
            resp.body = "The context field is required"
            resp.status = falcon.HTTP_422
            return
        if not isinstance(context, str):
            resp.body = "The context field must be a string"
            resp.status = falcon.HTTP_422
            return
        context = context.rstrip()

        # NOTE(review): requests may carry a 'samples' field, but it is not
        # forwarded to generate_completion here; confirm whether it should be.
        options = {
            keyword: json_data.get(key, default)
            for key, keyword, default in _GENERATION_OPTIONS
        }

        print(json_data)

        sentences = generate_completion(context, **options)

        resp.body = json.dumps({"sentences": sentences, 'time': time.time() - start})
        resp.status = falcon.HTTP_200
        # Push the print() above out of the stdout buffer once the work is done.
        sys.stdout.flush()


class Request(Thread):
    """Thread that POSTs a JSON payload to an endpoint and keeps the response."""

    def __init__(self, end_point, data):
        """Store the target URL and the JSON-serialisable payload."""
        super().__init__()
        self.end_point = end_point
        self.data = data
        self.ret = None

    def run(self):
        """Send the POST request and store the response on ``self.ret``."""
        print("Requesting with url", self.end_point)
        self.ret = requests.post(url=self.end_point, json=self.data)

    def join(self, timeout=None):
        """Wait for the thread and return the response text.

        Args:
            timeout: Optional number of seconds to wait, as for
                ``Thread.join``.

        Returns:
            The response body text, or ``None`` when no response is available
            (the request raised inside ``run`` or the wait timed out).
        """
        super().join(timeout)
        if self.ret is None:
            return None
        return self.ret.text


class HandleCORS(object):
    """Middleware that adds permissive CORS headers to every response."""

    def process_request(self, req, resp):
        """Set the CORS headers and answer OPTIONS requests immediately."""
        resp.set_header('Access-Control-Allow-Origin', '*')
        resp.set_header('Access-Control-Allow-Methods', '*')
        resp.set_header('Access-Control-Allow-Headers', '*')
        if req.method == 'OPTIONS':
            # Short-circuit OPTIONS requests with a 200 before routing.
            raise HTTPStatus(falcon.HTTP_200, body='\n')


autocomplete = AutoComplete()
app = falcon.API(middleware=[HandleCORS()])
# The same resource serves the bare path and the one-/two-segment variants.
app.add_route('/autocomplete', autocomplete)
app.add_route('/autocomplete/{x}', autocomplete)
app.add_route('/autocomplete/{x}/{y}', autocomplete)

# Second module-level name for the same app object
# (presumably the name the WSGI server is configured to load -- confirm).
application = app