import sys
from typing import List
import traceback
from transformers import AutoModelForCausalLM, AutoTokenizer
import json
# from flask import Flask, request, render_template
# from flask_cors import CORS
# app = Flask(__name__, static_folder='static')
# app.config['TEMPLATES_AUTO_RELOAD'] = True
# CORS(app, resources={
#     r"/generate": {"origins": origins},
#     r"/infill": {"origins": origins},
# })
# origins=[f"http://localhost:{PORT}", "https://huggingface.co", "https://hf.space"]
CUDA = True
PORT = 7860
VERBOSE = False
from fastapi import FastAPI, Request
from fastapi.staticfiles import StaticFiles
from fastapi.responses import FileResponse, StreamingResponse
app = FastAPI(docs_url=None, redoc_url=None)
app.mount("/static", StaticFiles(directory="static"), name="static")
print("loading model")
model = AutoModelForCausalLM.from_pretrained("facebook/incoder-6B")
print("loading tokenizer")
tokenizer = AutoTokenizer.from_pretrained("facebook/incoder-6B")
print("loading complete")
if CUDA:
    # half precision roughly halves GPU memory for the 6B-parameter model
    model = model.half().cuda()
BOS = "<|endoftext|>"
EOM = "<|endofmask|>"
def make_sentinel(i):
    # sentinel tokens <|mask:0|>, <|mask:1|>, ... mark the locations to infill
    return f"<|mask:{i}|>"
SPECIAL_TOKENS = [make_sentinel(i) for i in range(256)] + [EOM]
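# Sketch of the infilling prompt format built by infill() below (illustrative,
# assuming extra_sentinel=True and two parts): each part is followed by its
# sentinel, then the target sentinel is re-appended so the model generates that
# infill, terminated by EOM:
#   def f(x):\n    <|mask:0|>\n    return y<|mask:1|><|mask:0|>  ->  ...<|endofmask|>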
def generate(text, length_limit=None, temperature=None):
    input_ids = tokenizer(text, return_tensors="pt").input_ids
    if CUDA:
        input_ids = input_ids.cuda()
    # nucleus sampling; note that max_length counts prompt tokens as well as new ones
    output = model.generate(input_ids=input_ids, do_sample=True, top_p=0.95,
                            temperature=temperature, max_length=length_limit)
    detok_hypo_str = tokenizer.decode(output.flatten())
    # strip the BOS token that the tokenizer prepends, if present
    if detok_hypo_str.startswith(BOS):
        detok_hypo_str = detok_hypo_str[len(BOS):]
    return detok_hypo_str
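# Usage sketch (illustrative values):
#   generate("def fib(n):", length_limit=128, temperature=0.6)
# returns the prompt plus sampled tokens, up to 128 tokens in total.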
def infill(parts: List[str], length_limit=None, temperature=None, extra_sentinel=False, max_retries=1):
    assert isinstance(parts, list)
    retries_attempted = 0
    done = False
    while (not done) and (retries_attempted < max_retries):
        retries_attempted += 1
        if VERBOSE:
            print(f"retry {retries_attempted}")

        # (1) build the prompt
        if len(parts) == 1:
            prompt = parts[0]
        else:
            prompt = ""
            # encode parts separated by sentinel
            for sentinel_ix, part in enumerate(parts):
                prompt += part
                if extra_sentinel or (sentinel_ix < len(parts) - 1):
                    prompt += make_sentinel(sentinel_ix)
        # prompt += TokenizerWrapper.make_sentinel(0)

        # (2) generate each infill left-to-right, feeding completions back into the prompt
        infills = []
        complete = []
        done = True
        for sentinel_ix, part in enumerate(parts[:-1]):
            complete.append(part)
            prompt += make_sentinel(sentinel_ix)
            completion = generate(prompt, length_limit, temperature)
            completion = completion[len(prompt):]
            if EOM not in completion:
                # the model never closed this infill; terminate it manually and retry
                if VERBOSE:
                    print(f"warning: {EOM} not found")
                completion += EOM
                # TODO: break inner loop here
                done = False
            completion = completion[:completion.index(EOM) + len(EOM)]
            infilled = completion[:-len(EOM)]
            infills.append(infilled)
            complete.append(infilled)
            prompt += completion
        complete.append(parts[-1])
        text = ''.join(complete)

    if VERBOSE:
        print("generated text:")
        print(prompt)
        print()
        print("parts:")
        print(parts)
        print()
        print("infills:")
        print(infills)
        print()
        print("restitched text:")
        print(text)
        print()
    return {
        'text': text,  # str, the completed document (parts with infills spliced in)
        'parts': parts,  # List[str], as passed in
        'infills': infills,  # List[str], one generated span per boundary between parts
        'retries_attempted': retries_attempted,
    }
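# Usage sketch (illustrative values): infill a function body between a header
# and a return statement.
#   result = infill(["def count(xs):\n    ", "\n    return n"],
#                   length_limit=128, temperature=0.6, extra_sentinel=True)
#   result['text']     # the restitched source
#   result['infills']  # the generated body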
@app.head("/")
@app.get("/")
def index() -> FileResponse:
    return FileResponse(path="static/index.html", media_type="text/html")
@app.get('/generate')
async def generate_maybe(info: str):
    # form = await info.json()
    # `info` arrives as a JSON string in the query parameter
    form = json.loads(info)
    prompt = form['prompt']
    length_limit = int(form['length'])
    temperature = float(form['temperature'])
    if VERBOSE:
        print(prompt)
    try:
        generation = generate(prompt, length_limit, temperature)
        return {'result': 'success', 'type': 'generate', 'prompt': prompt, 'text': generation}
    except Exception as e:
        traceback.print_exception(*sys.exc_info())
        return {'result': 'error', 'type': 'generate', 'prompt': prompt, 'text': f'There was an error: {e}. Tell Daniel.'}
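# Example request (illustrative; the JSON payload travels in the `info` query
# parameter, URL-encoded in practice):
#   GET /generate?info={"prompt": "def fib(n):", "length": 64, "temperature": 0.6}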
@app.get('/infill')
async def infill_maybe(info: str):
    # form = await info.json()
    form = json.loads(info)
    length_limit = int(form['length'])
    temperature = float(form['temperature'])
    max_retries = 1
    extra_sentinel = True
    try:
        generation = infill(form['parts'], length_limit, temperature, extra_sentinel=extra_sentinel, max_retries=max_retries)
        generation['result'] = 'success'
        generation['type'] = 'infill'
        return generation
        # return {'result': 'success', 'prefix': prefix, 'suffix': suffix, 'text': generation['text']}
    except Exception as e:
        traceback.print_exception(*sys.exc_info())
        print(e)
        return {'result': 'error', 'type': 'infill', 'text': f'There was an error: {e}.'}
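# Example request (illustrative):
#   GET /infill?info={"parts": ["def f(x):\n    ", "\n    return y"], "length": 64, "temperature": 0.6}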
if __name__ == "__main__":
    # FastAPI apps are ASGI; serve with uvicorn rather than Flask-style app.run()
    import uvicorn
    uvicorn.run(app, host='0.0.0.0', port=PORT)