import sys
from typing import List
import traceback
import os
import base64

import logging
logging.basicConfig(level=logging.INFO)
import modules.cloud_logging

import tokenizers
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import json
import pprint

# needs to be imported *before* transformers
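# presence of a local 'debug' file switches the demo to the smaller 1B model on CPU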
if os.path.exists('debug'):
    BIG_MODEL = False
    CUDA = False
else:
    BIG_MODEL = True
    CUDA = True

# from flask import Flask, request, render_template
# from flask_cors import CORS
# app = Flask(__name__, static_folder='static')
# app.config['TEMPLATES_AUTO_RELOAD'] = True
# origins=[f"http://localhost:{PORT}", "https://huggingface.co", "https://hf.space"]
# CORS(app, resources= {
#     r"/generate": {"origins": origins},
#     r"/infill": {"origins": origins},
# })

PORT = 7860
VERBOSE = False

if os.path.exists('unlock'):
    MAX_LENGTH = 2048
else:
    MAX_LENGTH = 256+64
TRUNCATION_MESSAGE = f'warning: This demo is limited to {MAX_LENGTH} tokens in the document for efficiency.'

if BIG_MODEL:
    model_name = "facebook/incoder-6B"
    kwargs = dict(
        revision="float16",
        torch_dtype=torch.float16,
        low_cpu_mem_usage=True,
    )
else:
    model_name = "facebook/incoder-1B"
    kwargs = dict()

from fastapi import FastAPI, Request
from fastapi.staticfiles import StaticFiles
from fastapi.responses import FileResponse, StreamingResponse
app = FastAPI(docs_url=None, redoc_url=None)
app.mount("/static", StaticFiles(directory="static"), name="static")


logging.info("loading model")
model = AutoModelForCausalLM.from_pretrained(model_name, **kwargs)
logging.info("loading tokenizer")
tokenizer = AutoTokenizer.from_pretrained(model_name)
logging.info("loading complete")

if CUDA:
    model = model.half().cuda()

BOS = "<|endoftext|>"
EOM = "<|endofmask|>"

def make_sentinel(i):
    return f"<|mask:{i}|>"

SPECIAL_TOKENS = [make_sentinel(i) for i in range(256)] + [EOM]
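# Infilling prompt scheme (as constructed by infill() below), shown as an
# illustrative sketch rather than executable code:
#   parts  = ["def count(xs):\n    total = 0\n    ", "\n    return total"]
#   prompt = parts[0] + "<|mask:0|>" + parts[1] + "<|mask:1|>" + "<|mask:0|>"
# The model is then sampled to produce the masked span, ideally terminating it
# with "<|endofmask|>" (EOM); the span is spliced back between the two parts.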

def generate(input, length_limit=None, temperature=None):
    input_ids = tokenizer(input, return_tensors="pt").input_ids
    if CUDA:
        input_ids = input_ids.cuda()
    current_length = input_ids.flatten().size(0)
    max_length = length_limit + current_length
    truncated = False
    if max_length > MAX_LENGTH:
        max_length = MAX_LENGTH
        truncated = True
    # nothing left to generate if the input already reaches the cap
    if max_length <= current_length:
        return input, True
    output = model.generate(input_ids=input_ids, do_sample=True, top_p=0.95, temperature=temperature, max_length=max_length)
    detok_hypo_str = tokenizer.decode(output.flatten())
    if detok_hypo_str.startswith(BOS):
        detok_hypo_str = detok_hypo_str[len(BOS):]
    return detok_hypo_str, truncated
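
# Minimal usage sketch for generate() with hypothetical arguments (not executed
# at import time):
#   text, truncated = generate("def hello_world():\n", length_limit=64, temperature=0.6)
# `text` is the decoded prompt plus continuation (BOS stripped); `truncated`
# reports whether the MAX_LENGTH cap was applied.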

def infill(parts: List[str], length_limit=None, temperature=None, extra_sentinel=False, max_retries=1):
    assert isinstance(parts, list)
    retries_attempted = 0
    done = False


    while (not done) and (retries_attempted < max_retries):
        any_truncated = False
        retries_attempted += 1
        if VERBOSE:
            logging.info(f"retry {retries_attempted}")
        if len(parts) == 1:
            prompt = parts[0]
        else:
            prompt = ""
            # encode parts separated by sentinel
            for sentinel_ix, part in enumerate(parts):
                prompt += part
                if extra_sentinel or (sentinel_ix < len(parts) - 1):
                    prompt += make_sentinel(sentinel_ix)
            
            # prompt += TokenizerWrapper.make_sentinel(0)
        
        infills = []
        complete = []

        done = True

        for sentinel_ix, part in enumerate(parts[:-1]):
            complete.append(part)
            prompt += make_sentinel(sentinel_ix)
            completion, this_truncated = generate(prompt, length_limit, temperature)
            any_truncated |= this_truncated
            completion = completion[len(prompt):]
            if EOM not in completion:
                if VERBOSE:
                    logging.info(f"warning: {EOM} not found")
                completion += EOM
                # TODO: break inner loop here
                done = False
            completion = completion[:completion.index(EOM) + len(EOM)]
            infilled = completion[:-len(EOM)]
            infills.append(infilled)
            complete.append(infilled)
            prompt += completion
        complete.append(parts[-1])
        text = ''.join(complete)

    if VERBOSE:
        logging.info("generated text:")
        logging.info(prompt)
        logging.info("")
        logging.info("parts:")
        logging.info(parts)
        logging.info("")
        logging.info("infills:")
        logging.info(infills)
        logging.info("")
        logging.info("restitched text:")
        logging.info(text)
        logging.info("")
    
    return {
        'text': text,
        'parts': parts,
        'infills': infills,
        'retries_attempted': retries_attempted,
        'truncated': any_truncated,
    } 
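
# Minimal usage sketch for infill() with hypothetical arguments (not executed
# at import time):
#   result = infill(["def add(a, b):\n    ", "\n"], length_limit=64, temperature=0.6)
#   result['infills'][0]  # the generated span between the two parts
#   result['text']        # the restitched document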


@app.head("/")
@app.get("/")
def index() -> FileResponse:
    return FileResponse(path="static/index.html", media_type="text/html")

@app.get('/generate')
# async def generate_maybe(request: Request):
async def generate_maybe(info: str):
    # form = await info.json()
    # form = await request.json()
    # info is a base64-encoded, url-escaped json string (since GET doesn't support a body, and POST leads to CORS issues)
    # fix padding, following https://stackoverflow.com/a/9956217/1319683
    info = base64.urlsafe_b64decode(info + '=' * (4 - len(info) % 4)).decode('utf-8')
    form = json.loads(info)
    # print(form)
    prompt = form['prompt']
    length_limit = int(form['length'])
    temperature = float(form['temperature'])
    logging.info(json.dumps({
        'length': length_limit,
        'temperature': temperature,
        'prompt': prompt,
    }))
    try:
        generation, truncated = generate(prompt, length_limit, temperature)
        if truncated:
            message = TRUNCATION_MESSAGE 
        else:
            message = ''
        return {'result': 'success', 'type': 'generate', 'prompt': prompt, 'text': generation, 'message': message}
    except Exception as e:
        traceback.print_exception(*sys.exc_info())
        logging.error(e)
        return {'result': 'error', 'type': 'generate', 'prompt': prompt, 'message': f'Error: {e}.'}
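
# Sketch of how a client could build the base64url-encoded `info` query
# parameter that /generate and /infill expect (illustrative only; the field
# names match the keys read in the handlers, and /infill reads a 'parts' list
# of strings instead of 'prompt'):
#   payload = json.dumps({'prompt': 'def f():\n    ', 'length': 64, 'temperature': 0.6})
#   info = base64.urlsafe_b64encode(payload.encode('utf-8')).decode('ascii')
#   # then: GET /generate?info=<info>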

@app.get('/infill')
# async def infill_maybe(request: Request):
async def infill_maybe(info: str):
    # form = await info.json()
    # form = await request.json()
    # info is a base64-encoded, url-escaped json string (since GET doesn't support a body, and POST leads to CORS issues)
    # fix padding, following https://stackoverflow.com/a/9956217/1319683
    info = base64.urlsafe_b64decode(info + '=' * (4 - len(info) % 4)).decode('utf-8')
    form = json.loads(info)
    length_limit = int(form['length'])
    temperature = float(form['temperature'])
    max_retries = 1
    extra_sentinel = True
    logging.info(json.dumps({
        'length': length_limit,
        'temperature': temperature,
        'parts_joined': '<infill>'.join(form['parts']),
    }))
    try:
        if len(form['parts']) > 4:
            return {'result': 'error', 'text': ''.join(form['parts']), 'type': 'infill', 'message': f"error: Can't use more than 3 <infill> tokens in this demo (for efficiency)."}
        generation = infill(form['parts'], length_limit, temperature, extra_sentinel=extra_sentinel, max_retries=max_retries)
        generation['result'] = 'success'
        generation['type'] = 'infill'
        if generation['truncated']:
            generation['message'] = TRUNCATION_MESSAGE
        else:
            generation['message'] = ''
        return generation
        # return {'result': 'success', 'prefix': prefix, 'suffix': suffix,  'text': generation['text']}
    except Exception as e:
        traceback.print_exception(*sys.exc_info())
        logging.error(e)
        return {'result': 'error', 'type': 'infill', 'message': f'Error: {e}.'}


if __name__ == "__main__":
    # FastAPI apps are ASGI applications, so serve with uvicorn when run as a script
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=PORT)