import json
import sys
import traceback
from typing import List

from transformers import AutoModelForCausalLM, AutoTokenizer

# from flask import Flask, request, render_template
# from flask_cors import CORS
# app = Flask(__name__, static_folder='static')
# app.config['TEMPLATES_AUTO_RELOAD'] = True
# CORS(app, resources= {
#     r"/generate": {"origins": origins},
#     r"/infill": {"origins": origins},
# })
# origins=[f"http://localhost:{PORT}", "https://huggingface.co", "https://hf.space"]

CUDA = True
PORT = 7860
VERBOSE = False

import uvicorn
from fastapi import FastAPI, Request
from fastapi.staticfiles import StaticFiles
from fastapi.responses import FileResponse, StreamingResponse

app = FastAPI(docs_url=None, redoc_url=None)
app.mount("/static", StaticFiles(directory="static"), name="static")

print("loading model")
model = AutoModelForCausalLM.from_pretrained("facebook/incoder-6B")
print("loading tokenizer")
tokenizer = AutoTokenizer.from_pretrained("facebook/incoder-6B")
print("loading complete")

if CUDA:
    # half precision is needed to fit the 6B-parameter model on a single GPU
    model = model.half().cuda()

# special tokens used by InCoder's causal-masking objective
BOS = "<|endoftext|>"
EOM = "<|endofmask|>"


def make_sentinel(i):
    # sentinel i marks both the i-th blank in the prompt and the start of
    # the text that should fill it
    return f"<|mask:{i}|>"


SPECIAL_TOKENS = [make_sentinel(i) for i in range(256)] + [EOM]


def generate(input, length_limit=None, temperature=None):
    """Sample a left-to-right continuation of `input` with nucleus sampling."""
    input_ids = tokenizer(input, return_tensors="pt").input_ids
    if CUDA:
        input_ids = input_ids.cuda()
    output = model.generate(input_ids=input_ids, do_sample=True, top_p=0.95,
                            temperature=temperature, max_length=length_limit)
    detok_hypo_str = tokenizer.decode(output.flatten())
    if detok_hypo_str.startswith(BOS):
        detok_hypo_str = detok_hypo_str[len(BOS):]
    return detok_hypo_str


def infill(parts: List[str], length_limit=None, temperature=None,
           extra_sentinel=False, max_retries=1):
    """Fill in the blanks between successive elements of `parts`."""
    assert isinstance(parts, list)
    retries_attempted = 0
    done = False

    while (not done) and (retries_attempted < max_retries):
        retries_attempted += 1
        if VERBOSE:
            print(f"retry {retries_attempted}")

        # build the prompt: parts separated by sentinel tokens
        if len(parts) == 1:
            prompt = parts[0]
        else:
            prompt = ""
            for sentinel_ix, part in enumerate(parts):
                prompt += part
                if extra_sentinel or (sentinel_ix < len(parts) - 1):
                    prompt += make_sentinel(sentinel_ix)
            # prompt += TokenizerWrapper.make_sentinel(0)

        infills = []
        complete = []

        done = True

        # generate the infill for each blank, left to right, feeding earlier
        # completions back into the prompt
        for sentinel_ix, part in enumerate(parts[:-1]):
            complete.append(part)
            prompt += make_sentinel(sentinel_ix)
            completion = generate(prompt, length_limit, temperature)
            completion = completion[len(prompt):]
            if EOM not in completion:
                if VERBOSE:
                    print(f"warning: {EOM} not found")
                completion += EOM
                # TODO: break inner loop here
                done = False
            completion = completion[:completion.index(EOM) + len(EOM)]
            infilled = completion[:-len(EOM)]
            infills.append(infilled)
            complete.append(infilled)
            prompt += completion
        complete.append(parts[-1])
        text = ''.join(complete)

    if VERBOSE:
        print("generated text:")
        print(prompt)
        print()
        print("parts:")
        print(parts)
        print()
        print("infills:")
        print(infills)
        print()
        print("restitched text:")
        print(text)
        print()

    return {
        'text': text,
        'parts': parts,
        'infills': infills,
        'retries_attempted': retries_attempted,
    }


@app.head("/")
@app.get("/")
def index() -> FileResponse:
    return FileResponse(path="static/index.html", media_type="text/html")


@app.get('/generate')
async def generate_maybe(info: str):
    # form = await info.json()
    form = json.loads(info)
    prompt = form['prompt']
    length_limit = int(form['length'])
    temperature = float(form['temperature'])
    if VERBOSE:
        print(prompt)
    try:
        generation = generate(prompt, length_limit, temperature)
        return {'result': 'success', 'type': 'generate',
                'prompt': prompt, 'text': generation}
    except Exception as e:
        traceback.print_exception(*sys.exc_info())
        return {'result': 'error', 'type': 'generate', 'prompt': prompt,
                'text': f'There was an error: {e}. Tell Daniel.'}


@app.get('/infill')
async def infill_maybe(info: str):
    # form = await info.json()
    form = json.loads(info)
    length_limit = int(form['length'])
    temperature = float(form['temperature'])
    max_retries = 1
    extra_sentinel = True
    try:
        generation = infill(form['parts'], length_limit, temperature,
                            extra_sentinel=extra_sentinel, max_retries=max_retries)
        generation['result'] = 'success'
        generation['type'] = 'infill'
        return generation
        # return {'result': 'success', 'prefix': prefix, 'suffix': suffix, 'text': generation['text']}
    except Exception as e:
        traceback.print_exception(*sys.exc_info())
        print(e)
        return {'result': 'error', 'type': 'infill',
                'text': f'There was an error: {e}.'}


if __name__ == "__main__":
    # app.run(...) is Flask's API; a FastAPI app is served with uvicorn
    uvicorn.run(app, host="0.0.0.0", port=PORT)
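
# --- Example client usage ---
# A minimal sketch, not part of the app above: it assumes the server is
# running locally on PORT (7860) and that the third-party `requests`
# package is installed. Both endpoints take a single `info` query
# parameter holding a JSON-encoded form whose keys mirror what
# generate_maybe and infill_maybe parse.
#
#     import json
#     import requests
#
#     # left-to-right generation from a plain prompt
#     form = {'prompt': 'def fib(n):', 'length': 64, 'temperature': 0.6}
#     r = requests.get('http://localhost:7860/generate',
#                      params={'info': json.dumps(form)})
#     print(r.json()['text'])
#
#     # infilling: the model fills one blank between each pair of
#     # consecutive parts
#     form = {'parts': ['def fib(n):\n    ', '\n    return result'],
#             'length': 64, 'temperature': 0.6}
#     r = requests.get('http://localhost:7860/infill',
#                      params={'info': json.dumps(form)})
#     print(r.json()['infills'])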