File size: 3,820 Bytes
3400791
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
import argparse
import json
import numpy as np
import string
import re
import requests
from distutils.util import strtobool

import tornado
from tornado.web import Application, RequestHandler, StaticFileHandler, HTTPServer


def str2title(str):
    """Normalize capitalization and punctuation spacing of a title string.

    Every whitespace-separated word is capitalized (runs of whitespace
    collapse to single spaces), then the spacing around punctuation is
    repaired:

    * ``' - - - '`` becomes an em dash and ``' - - '`` an en dash,
    * the space just inside parentheses is removed,
    * ``word - word`` is joined to ``word-word``, including chains such as
      ``state - of - the - art``,
    * whitespace between a word character or double quote and a following
      colon is removed,
    * whitespace just inside a pair of double quotes is removed.

    Args:
        str: Raw title text.

    Returns:
        The cleaned-up title.
    """

    title = string.capwords(str)
    title = title.replace(' - - - ', ' — ')
    title = title.replace(' - - ', ' – ')
    title = title.replace('( ', '(')
    title = title.replace(' )', ')')
    # Lookarounds keep the neighbouring word characters out of the match, so
    # the word between two hyphens can be the right-hand context of the first
    # hyphen and the left-hand context of the second. Consuming them (as a
    # plain capture group would) skips every second hyphen in a chain.
    title = re.sub(r'(?<=\w)\s+-\s+(?=\w)', '-', title)
    title = re.sub(r'(\w|")\s+:', r'\1:', title)
    title = re.sub(r'"\s+([^"]+)\s+"', r'"\1"', title)
    return title


class GenerateTitleHandler(RequestHandler):
    """Handle POST requests that turn a paper abstract into candidate titles.

    The abstract is forwarded to a model on the Hugging Face inference API
    and the reply is written back as a JSON document of the form
    ``{"titles": [...]}``, with an additional ``"error"`` entry when the API
    reports one.
    """

    def initialize(self, api_key, model, default_num_titles=5):
        """Store the handler settings.

        Args:
            api_key: Token sent as bearer authorization to the inference API.
            model: Model name appended to the inference API URL.
            default_num_titles: Number of titles requested when the request
                body does not contain ``num_titles``.
        """

        self.api_key = api_key
        self.model = model
        self.default_num_titles = default_num_titles


    def post(self):
        """Generate titles for the abstract given in the request body.

        Body arguments:
            abstract: Text to generate titles for (required).
            temperature: Sampling temperature; values below 1.0 are raised
                to 1.0. Defaults to 1.5.
            num_titles: Number of titles to return, clamped to 1..20.
                Defaults to ``default_num_titles``.
            blocking: Boolean string as understood by ``strtobool``; passed
                to the API as ``wait_for_model``. Defaults to true.
        """

        abstract = self.get_body_argument('abstract')
        temperature = max(1.0, float(self.get_body_argument('temperature', 1.5)))
        num_titles = min(20, max(1, int(self.get_body_argument('num_titles', self.default_num_titles))))
        # The default has to be a string: strtobool() calls .lower() on its
        # argument, so a bool default made every request without an explicit
        # 'blocking' argument fail with an AttributeError.
        blocking = bool(strtobool(self.get_body_argument('blocking', 'true')))

        response = self.query_api(
            abstract,
            wait=blocking,
            # Sampling is only enabled above temperature 1; at exactly 1.0
            # the request asks for non-sampled generation.
            do_sample=(temperature > 1),
            num_beams=10,
            temperature=temperature,
            top_k=50,
            no_repeat_ngram_size=2,
            num_return_sequences=num_titles
        )

        result = { 'titles' : [] }
        if isinstance(response, dict) and ('error' in response):
            # Example of an error reply:
            # 'error' : 'Model Callidior/bert2bert-base-arxiv-titlegen is currently loading'
            # 'estimated_time': 19.793826919999997
            result['error'] = response['error']
        else:
            # Otherwise the reply is treated as a sequence of objects that
            # each carry the generated text under 'summary_text'.
            result['titles'] = [str2title(title['summary_text']) for title in response]

        self.write(json.dumps(result))


    def query_api(self, inputs, cache=False, wait=False, **kwargs):
        """Send one request to the inference API and decode its reply.

        Args:
            inputs: Text sent as the ``inputs`` field of the request.
            cache: Value for the ``use_cache`` option.
            wait: Value for the ``wait_for_model`` option.
            **kwargs: Sent unchanged as the ``parameters`` object.

        Returns:
            The JSON-decoded response body.
        """

        data = json.dumps({
            'inputs' : inputs,
            'parameters' : kwargs,
            'options' : { 'use_cache' : cache, 'wait_for_model' : wait }
        })
        api_url = "https://api-inference.huggingface.co/models/" + self.model
        headers = { "Authorization": f"Bearer {self.api_key}" }
        # NOTE(review): no timeout is passed, so this synchronous call waits
        # for as long as the API takes to answer.
        response = requests.request("POST", api_url, headers=headers, data=data)
        return json.loads(response.content.decode("utf-8"))


if __name__ == '__main__':

    # Command-line interface: API credentials and model are required, the
    # server settings are optional.
    parser = argparse.ArgumentParser(description='Web interface for BERT2BERT paper title generation.', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('api_key', type=str, help='Key for the Hugging Face Hosted Inference API.')
    parser.add_argument('model', type=str, help='Name of the hosted-inference model on the Hugging Face model repository.')
    parser.add_argument('--default-num', type=int, default=5, help='Default number of generated titles per request.')
    parser.add_argument('--port', type=int, default=8080, help='Webserver port.')
    parser.add_argument('--num-proc', type=int, default=1, help='Number of concurrent server processes.')
    parser.add_argument('--debug', action='store_true', default=False, help='Enable debugging mode.')
    args = parser.parse_args()

    # '/title' is the generation endpoint; every other path is served as a
    # static file from the 'web' directory.
    app = Application([
        (r'/title', GenerateTitleHandler, {
            'api_key' : args.api_key,
            'model' : args.model,
            'default_num_titles' : args.default_num
        }),
        (r'/(.*)', StaticFileHandler, { 'path' : 'web', 'default_filename' : 'index.html' })
    ], debug=args.debug)
    if args.num_proc != 1:
        # Multi-process mode: the listening socket must only be bound before
        # the fork and attached to an IOLoop inside each child process, which
        # is what bind() followed by start() does. listen() would already
        # attach it to the parent's IOLoop before forking.
        server = HTTPServer(app)
        server.bind(args.port)
        server.start(args.num_proc)
    else:
        app.listen(args.port)
    tornado.ioloop.IOLoop.current().start()