File size: 3,820 Bytes
3400791 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 |
import argparse
import json
import numpy as np
import string
import re
import requests
from distutils.util import strtobool
import tornado
from tornado.web import Application, RequestHandler, StaticFileHandler, HTTPServer
def str2title(str):
    """Normalise a space-separated generated title into display form.

    The text is first title-cased with ``string.capwords`` (which also collapses
    runs of whitespace to single spaces). Then spacing artefacts are repaired:

    * ``' - - - '`` becomes an em dash and ``' - - '`` an en dash;
    * spaces just inside parentheses are removed;
    * ``word - word`` is joined into ``word-word``;
    * whitespace before a colon (after a word character or a quote) is removed;
    * whitespace just inside a pair of double quotes is removed.

    The spaced-out punctuation presumably comes from the model's tokenised
    output (NOTE(review): confirm against the model's detokeniser).

    :param str: raw title text (the parameter name shadows the builtin but is
        kept for interface compatibility).
    :return: the cleaned-up title string.
    """
    title = string.capwords(str)

    # Literal substitutions. Order matters: the three-dash form must be
    # handled before the two-dash form, which is a substring of it.
    literal_fixes = (
        (' - - - ', ' — '),
        (' - - ', ' – '),
        ('( ', '('),
        (' )', ')'),
    )
    for needle, replacement in literal_fixes:
        title = title.replace(needle, replacement)

    # Regex-based spacing repairs, applied in the same order as before.
    pattern_fixes = (
        (r'(\w)\s+-\s+(\w)', r'\1-\2'),
        (r'(\w|")\s+:', r'\1:'),
        (r'"\s+([^"]+)\s+"', r'"\1"'),
    )
    for pattern, replacement in pattern_fixes:
        title = re.sub(pattern, replacement, title)

    return title
class GenerateTitleHandler(RequestHandler):
    """Handle ``POST /title``: generate candidate paper titles for an abstract.

    The abstract is forwarded to the Hugging Face Hosted Inference API and each
    returned ``summary_text`` is normalised with ``str2title``. The response body
    is the JSON object ``{"titles": [...]}``. When the API reports an error, or
    the request to it fails, an ``"error"`` message is added and ``"titles"``
    stays empty.
    """

    def initialize(self, api_key, model, default_num_titles=5):
        """Store the settings passed via the application's URL spec.

        :param api_key: bearer token sent to the Hugging Face Inference API.
        :param model: model repository name, appended to the inference URL.
        :param default_num_titles: number of titles generated when the client
            does not send ``num_titles``.
        """
        self.api_key = api_key
        self.model = model
        self.default_num_titles = default_num_titles

    async def post(self):
        """Generate titles for the required ``abstract`` body argument.

        Optional body arguments:

        * ``temperature`` (default 1.5): clamped to at least 1.0. Sampling is
          enabled only when it is strictly greater than 1.
        * ``num_titles`` (default ``self.default_num_titles``): clamped to
          the range [1, 20].
        * ``blocking`` (default true): parsed with ``strtobool`` and forwarded
          as the API's ``wait_for_model`` option.

        :raises tornado.web.HTTPError: 400 if a numeric or boolean argument
            cannot be parsed.
        """
        from tornado.ioloop import IOLoop
        from tornado.web import HTTPError

        abstract = self.get_body_argument('abstract')
        try:
            temperature = max(1.0, float(self.get_body_argument('temperature', 1.5)))
            num_titles = min(20, max(1, int(self.get_body_argument('num_titles', self.default_num_titles))))
            # strtobool() only accepts strings, so the default must be a string
            # as well; a bool default made every request that omitted
            # 'blocking' fail with AttributeError.
            blocking = bool(strtobool(self.get_body_argument('blocking', 'true')))
        except ValueError as err:
            raise HTTPError(400, 'Invalid request parameter: %s', err) from err

        def call_api():
            return self.query_api(
                abstract,
                wait=blocking,
                do_sample=(temperature > 1),
                num_beams=10,
                temperature=temperature,
                top_k=50,
                no_repeat_ngram_size=2,
                num_return_sequences=num_titles,
            )

        result = {'titles': []}
        try:
            # query_api() uses the synchronous requests library. Running it on
            # the default executor keeps a slow call (e.g. waiting for the
            # model to load) from stalling the IOLoop, and with it every other
            # request served by this process.
            response = await IOLoop.current().run_in_executor(None, call_api)
        except (requests.RequestException, ValueError) as err:
            # ValueError also covers json.JSONDecodeError for a non-JSON body.
            result['error'] = 'Inference API request failed: {}'.format(err)
        else:
            if isinstance(response, dict) and ('error' in response):
                # e.g. 'error': 'Model ... is currently loading',
                #      'estimated_time': 19.793826919999997
                result['error'] = response['error']
            elif isinstance(response, list):
                result['titles'] = [str2title(title['summary_text']) for title in response]
            else:
                result['error'] = 'Unexpected response from inference API.'
        self.write(json.dumps(result))

    def query_api(self, inputs, cache=False, wait=False, **kwargs):
        """POST ``inputs`` to the hosted inference endpoint of ``self.model``.

        :param inputs: text sent as the ``inputs`` field.
        :param cache: value of the API's ``use_cache`` option.
        :param wait: value of the API's ``wait_for_model`` option.
        :param kwargs: forwarded verbatim as the ``parameters`` object
            (generation settings such as ``num_beams`` or ``temperature``).
        :return: the decoded JSON response body.
        :raises requests.RequestException: on transport-level failures.
        :raises ValueError: if the body is not valid UTF-8 JSON.
        """
        data = json.dumps({
            'inputs': inputs,
            'parameters': kwargs,
            'options': {'use_cache': cache, 'wait_for_model': wait},
        })
        api_url = "https://api-inference.huggingface.co/models/" + self.model
        headers = {"Authorization": f"Bearer {self.api_key}"}
        response = requests.request("POST", api_url, headers=headers, data=data)
        return json.loads(response.content.decode("utf-8"))
if __name__ == '__main__':
    # Command-line entry point: parse options, build the Tornado application
    # and serve it on the requested port in one or more processes.
    parser = argparse.ArgumentParser(description='Web interface for BERT2BERT paper title generation.', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('api_key', type=str, help='Key for the Hugging Face Hosted Inference API.')
    parser.add_argument('model', type=str, help='Name of the hosted-inference model on the Hugging Face model repository.')
    parser.add_argument('--default-num', type=int, default=5, help='Default number of generated titles per request.')
    parser.add_argument('--port', type=int, default=8080, help='Webserver port.')
    # NOTE: Tornado treats a process count <= 0 as "one per CPU core".
    parser.add_argument('--num-proc', type=int, default=1, help='Number of concurrent server processes.')
    parser.add_argument('--debug', action='store_true', default=False, help='Enable debugging mode.')
    args = parser.parse_args()

    # Application(debug=True) turns on autoreload, which Tornado documents as
    # incompatible with running multiple processes.
    if args.debug and args.num_proc != 1:
        parser.error('--debug cannot be combined with --num-proc other than 1.')

    app = Application([
        (r'/title', GenerateTitleHandler, {
            'api_key': args.api_key,
            'model': args.model,
            'default_num_titles': args.default_num,
        }),
        # Everything else is served from the 'web' directory (relative to the
        # current working directory), with index.html as the directory default.
        (r'/(.*)', StaticFileHandler, {'path': 'web', 'default_filename': 'index.html'}),
    ], debug=args.debug)

    if args.num_proc != 1:
        # Multi-process mode: bind the listening sockets in the parent, fork,
        # and only then attach them to a server in each child. Calling
        # HTTPServer.listen() before start(n), as before, registered the
        # sockets with an IOLoop prior to forking, which Tornado does not support.
        from tornado.netutil import bind_sockets
        from tornado.process import fork_processes

        sockets = bind_sockets(args.port)
        fork_processes(args.num_proc)
        server = HTTPServer(app)
        server.add_sockets(sockets)
    else:
        app.listen(args.port)
    tornado.ioloop.IOLoop.current().start()