# (page-scrape residue, not part of the program) Spaces: Sleeping / Sleeping
import sys | |
from typing import List | |
import traceback | |
import os | |
import base64 | |
import logging | |
logging.basicConfig(level=logging.INFO) | |
import tokenizers | |
import torch | |
from transformers import AutoModelForCausalLM, AutoTokenizer | |
import json | |
import pprint | |
# NOTE(review): presumably refers to the `tokenizers` import above, which needs to be imported *before* transformers -- confirm
# Run-mode flags: when a file named 'debug' exists in the working directory
# both flags are False (small model, no GPU); otherwise both are True.
BIG_MODEL = CUDA = not os.path.exists('debug')
# from flask import Flask, request, render_template | |
# from flask_cors import CORS | |
# app = Flask(__name__, static_folder='static') | |
# app.config['TEMPLATES_AUTO_RELOAD'] = True
# CORS(app, resources= { | |
# r"/generate": {"origins": origins}, | |
# r"/infill": {"origins": origins}, | |
# }) | |
# origins=[f"http://localhost:{PORT}", "https://huggingface.co", "https://hf.space"] | |
# Port passed to the server launcher at the bottom of the file.
PORT = 7860
# When True, infill() logs its retries, prompts and results.
VERBOSE = False
# Cap on the total document length in tokens (prompt + generation); a file
# named 'unlock' in the working directory raises the cap.
MAX_LENGTH = 2048 if os.path.exists('unlock') else 256 + 64
# User-facing notice attached to responses whose generation hit the cap.
TRUNCATION_MESSAGE = f'warning: This demo is limited to {MAX_LENGTH} tokens in the document for efficiency.'
# Checkpoint selection. The 6B checkpoint is requested from its "float16"
# revision with fp16 weights and low_cpu_mem_usage enabled; the 1B checkpoint
# is loaded with the library defaults.
if not BIG_MODEL:
    model_name = "facebook/incoder-1B"
    kwargs = {}
else:
    model_name = "facebook/incoder-6B"
    kwargs = {
        "revision": "float16",
        "torch_dtype": torch.float16,
        "low_cpu_mem_usage": True,
    }
from fastapi import FastAPI, Request | |
from fastapi.staticfiles import StaticFiles | |
from fastapi.responses import FileResponse, StreamingResponse | |
# Web application object; both interactive documentation URLs are disabled.
app = FastAPI(docs_url=None, redoc_url=None)
# Expose the files in ./static under the /static URL prefix.
app.mount("/static", StaticFiles(directory="static"), name="static")
# The model and tokenizer are loaded once, at import time, and shared by
# every request handler defined below.
logging.info("loading model")
model = AutoModelForCausalLM.from_pretrained(model_name, **kwargs)
logging.info("loading tokenizer")
tokenizer = AutoTokenizer.from_pretrained(model_name)
logging.info("loading complete")
if CUDA:
    # Convert the weights to half precision and move the model to the GPU.
    model = model.half().cuda()
# Special token strings. generate() strips a leading BOS from decoded output;
# infill() treats EOM as the end of one infilled span.
BOS = "<|endoftext|>"
EOM = "<|endofmask|>"


def make_sentinel(i):
    """Return the mask sentinel token for span index *i*, e.g. ``<|mask:0|>``."""
    return "<|mask:{}|>".format(i)


# The 256 mask sentinels (indices 0..255) followed by the end-of-mask token.
SPECIAL_TOKENS = [*map(make_sentinel, range(256)), EOM]
def generate(input, length_limit=None, temperature=None):
    """Sample a continuation of *input* from the loaded model.

    Args:
        input: Prompt text to continue.
        length_limit: Maximum number of new tokens to generate. ``None``
            means "as many as the ``MAX_LENGTH`` cap allows".
        temperature: Sampling temperature forwarded to ``model.generate``.

    Returns:
        A ``(text, truncated)`` tuple. ``text`` is the decoded prompt plus
        continuation with a leading ``BOS`` removed. ``truncated`` is True
        when the ``MAX_LENGTH`` cap shortened the request, or when there was
        no room to generate anything (in which case ``text`` is *input*
        unchanged).
    """
    input_ids = tokenizer(input, return_tensors="pt").input_ids
    if CUDA:
        input_ids = input_ids.cuda()
    current_length = input_ids.flatten().size(0)

    # The default length_limit of None used to crash on the addition below;
    # treat it as "no explicit budget" and fall back to the global cap.
    if length_limit is None:
        requested_length = MAX_LENGTH
    else:
        requested_length = current_length + length_limit
    truncated = requested_length > MAX_LENGTH
    max_length = min(requested_length, MAX_LENGTH)

    # `<=` rather than `==`: a prompt that is already longer than MAX_LENGTH
    # leaves max_length *below* current_length, and must not be handed to
    # model.generate with a max_length shorter than its own input.
    if max_length <= current_length:
        return input, True

    output = model.generate(
        input_ids=input_ids,
        do_sample=True,
        top_p=0.95,
        temperature=temperature,
        max_length=max_length,
    )
    detok_hypo_str = tokenizer.decode(output.flatten())
    if detok_hypo_str.startswith(BOS):
        detok_hypo_str = detok_hypo_str[len(BOS):]
    return detok_hypo_str, truncated
def infill(parts: List[str], length_limit=None, temperature=None, extra_sentinel=False, max_retries=1):
    """Fill the gaps between consecutive *parts* with model-generated text.

    The parts are joined with mask sentinels into one prompt; then, for each
    gap in order, the gap's sentinel is appended and ``generate`` is asked to
    continue until the ``EOM`` token.

    Args:
        parts: Text fragments; one span is generated between each adjacent
            pair. Must be a non-empty list.
        length_limit: Per-span token budget forwarded to ``generate``.
        temperature: Sampling temperature forwarded to ``generate``.
        extra_sentinel: If True, a sentinel is also appended after the last
            part when building the prompt.
        max_retries: Maximum number of full attempts (at least 1). An attempt
            is repeated when any span finished without producing ``EOM``.

    Returns:
        A dict with keys ``text`` (parts interleaved with the infills),
        ``parts``, ``infills``, ``retries_attempted`` and ``truncated``.

    Raises:
        ValueError: If *parts* is empty or *max_retries* is below 1.
    """
    assert isinstance(parts, list)
    # Fail with a clear message instead of the IndexError / UnboundLocalError
    # these inputs would otherwise hit further down.
    if not parts:
        raise ValueError("parts must contain at least one string")
    if max_retries < 1:
        raise ValueError("max_retries must be at least 1")

    retries_attempted = 0
    done = False
    last_ix = len(parts) - 1

    while (not done) and (retries_attempted < max_retries):
        any_truncated = False
        retries_attempted += 1
        if VERBOSE:
            logging.info(f"retry {retries_attempted}")

        if len(parts) == 1:
            prompt = parts[0]
        else:
            # encode parts separated by sentinel
            prompt = "".join(
                part + (make_sentinel(ix) if (extra_sentinel or ix < last_ix) else "")
                for ix, part in enumerate(parts)
            )

        infills = []
        complete = []
        done = True

        for sentinel_ix, part in enumerate(parts[:-1]):
            complete.append(part)
            # Ask the model to produce the content of this masked span.
            prompt += make_sentinel(sentinel_ix)
            completion, this_truncated = generate(prompt, length_limit, temperature)
            any_truncated |= this_truncated
            # Keep only the newly generated text after the prompt.
            completion = completion[len(prompt):]
            if EOM not in completion:
                if VERBOSE:
                    logging.info(f"warning: {EOM} not found")
                completion += EOM
                # TODO: break inner loop here
                done = False
            # Cut everything after the first EOM; the infill is what precedes it.
            completion = completion[:completion.index(EOM) + len(EOM)]
            infilled = completion[:-len(EOM)]
            infills.append(infilled)
            complete.append(infilled)
            prompt += completion
        complete.append(parts[-1])
        text = ''.join(complete)

    if VERBOSE:
        # logging.info() requires a message argument; the bare calls used
        # here before raised TypeError, so blank separators are logged as "".
        logging.info("generated text:")
        logging.info(prompt)
        logging.info("")
        logging.info("parts:")
        logging.info(parts)
        logging.info("")
        logging.info("infills:")
        logging.info(infills)
        logging.info("")
        logging.info("restitched text:")
        logging.info(text)
        logging.info("")

    return {
        'text': text,
        'parts': parts,
        'infills': infills,
        'retries_attempted': retries_attempted,
        'truncated': any_truncated,
    }
# Without a path-operation decorator this handler was never attached to `app`.
# NOTE(review): "/" is assumed to be the intended path for the landing page -- confirm.
@app.get("/")
def index() -> FileResponse:
    """Return the demo's front-end page, ``static/index.html``, as HTML."""
    return FileResponse(path="static/index.html", media_type="text/html")
# Registered explicitly: without a path-operation decorator the handler was
# never reachable through `app`.
@app.get('/generate')
async def generate_maybe(info: str):
    """Handle a left-to-right generation request.

    Args:
        info: A base64-encoded, url-escaped JSON string (GET doesn't support
            a body, and POST leads to CORS issues) holding the keys
            ``prompt``, ``length`` and ``temperature``.

    Returns:
        A dict with ``result`` set to ``'success'`` or ``'error'``, ``type``
        set to ``'generate'``, the ``prompt``, a ``message``, and on success
        the generated ``text``. Malformed requests produce the error dict
        (with ``prompt`` set to None if it could not be decoded).
    """
    prompt = None
    try:
        # Restore the '=' padding needed to reach a multiple of four
        # characters, following https://stackoverflow.com/a/9956217/1319683
        padded = info + '=' * (-len(info) % 4)
        form = json.loads(base64.urlsafe_b64decode(padded).decode('utf-8'))
        prompt = form['prompt']
        length_limit = int(form['length'])
        temperature = float(form['temperature'])
        logging.info(json.dumps({
            'length': length_limit,
            'temperature': temperature,
            'prompt': prompt,
        }))
        generation, truncated = generate(prompt, length_limit, temperature)
        message = TRUNCATION_MESSAGE if truncated else ''
        return {'result': 'success', 'type': 'generate', 'prompt': prompt, 'text': generation, 'message': message}
    except Exception as e:
        # Top-level request boundary: log everything and report the failure
        # to the client instead of letting the request crash.
        traceback.print_exception(*sys.exc_info())
        logging.error(e)
        return {'result': 'error', 'type': 'generate', 'prompt': prompt, 'message': f'Error: {e}.'}
# Registered explicitly: without a path-operation decorator the handler was
# never reachable through `app`.
@app.get('/infill')
async def infill_maybe(info: str):
    """Handle an infilling request.

    Args:
        info: A base64-encoded, url-escaped JSON string (GET doesn't support
            a body, and POST leads to CORS issues) holding the keys
            ``parts``, ``length`` and ``temperature``.

    Returns:
        On success, the dict produced by ``infill`` extended with ``result``
        (``'success'``), ``type`` (``'infill'``) and ``message``. Otherwise a
        dict with ``result`` set to ``'error'`` and an explanatory
        ``message``; requests with more than four parts are rejected this way.
    """
    max_retries = 1
    extra_sentinel = True
    try:
        # Restore the '=' padding needed to reach a multiple of four
        # characters, following https://stackoverflow.com/a/9956217/1319683
        padded = info + '=' * (-len(info) % 4)
        form = json.loads(base64.urlsafe_b64decode(padded).decode('utf-8'))
        length_limit = int(form['length'])
        temperature = float(form['temperature'])
        parts = form['parts']
        logging.info(json.dumps({
            'length': length_limit,
            'temperature': temperature,
            'parts_joined': '<infill>'.join(parts),
        }))
        # Four parts means three gaps, i.e. three <infill> tokens.
        if len(parts) > 4:
            return {'result': 'error', 'text': ''.join(parts), 'type': 'infill', 'message': "error: Can't use more than 3 <infill> tokens in this demo (for efficiency)."}
        generation = infill(parts, length_limit, temperature, extra_sentinel=extra_sentinel, max_retries=max_retries)
        generation['result'] = 'success'
        generation['type'] = 'infill'
        generation['message'] = TRUNCATION_MESSAGE if generation['truncated'] else ''
        return generation
    except Exception as e:
        # Top-level request boundary: log everything and report the failure
        # to the client instead of letting the request crash.
        traceback.print_exception(*sys.exc_info())
        logging.error(e)
        return {'result': 'error', 'type': 'infill', 'message': f'Error: {e}.'}
if __name__ == "__main__":
    # NOTE(review): `app` is created as a FastAPI instance in this file, and
    # this Flask-style run(..., threaded=False) call looks like a leftover
    # from the commented-out Flask setup -- confirm how the app is actually
    # served (presumably by an external ASGI server).
    launch_options = {'host': '0.0.0.0', 'port': PORT, 'threaded': False}
    app.run(**launch_options)