bigcode-editor / app.py
Daniel Fried
better single <infill> error handling
a2fc40d
raw
history blame
5.54 kB
import sys
from typing import List
import traceback
import os
import base64
import json
import pprint
from huggingface_hub import Repository
from text_generation import Client
# Port the web server is expected to listen on.
PORT = 7860

# TODO: implement a maximum document length. Currently each request is limited
# by the slider-specified max length, but requests can be repeated (or long code
# entered into the editor) to get really long documents.
# if os.path.exists('unlock'):
#     # create an 'unlock' file (not checked into Git) locally to get full context lengths
#     MAX_LENGTH = 8192
# else:
#     # set to a shorter value to prevent long contexts and make the demo more efficient
#     MAX_LENGTH = 1024
# TRUNCATION_MESSAGE = f'warning: This demo is limited to {MAX_LENGTH} tokens in the document for efficiency.'
# Placeholder shown to the user when a generation is reported as truncated.
TRUNCATION_MESSAGE = 'TODO'

# Inference endpoint configuration, taken from the environment.
HF_TOKEN = os.environ.get("HF_TOKEN", None)
API_URL = os.environ.get("API_URL")

# Read with an explicit encoding so the result does not depend on the host locale.
with open("./HHH_prompt.txt", "r", encoding="utf-8") as f:
    HHH_PROMPT = f.read() + "\n\n"

# used by the model
FIM_PREFIX = "<fim_prefix>"
FIM_MIDDLE = "<fim_middle>"
FIM_SUFFIX = "<fim_suffix>"
END_OF_TEXT = "<|endoftext|>"

# used to mark infill locations in the editor
FIM_INDICATOR = "<infill>"

# Only attach an Authorization header when a token is configured; otherwise the
# literal string "Bearer None" would be sent to the endpoint.
client = Client(
    API_URL,
    headers={"Authorization": f"Bearer {HF_TOKEN}"} if HF_TOKEN else None,
)
from fastapi import FastAPI, Request
from fastapi.staticfiles import StaticFiles
from fastapi.responses import FileResponse, StreamingResponse
# The application object; both auto-generated documentation routes are
# switched off by passing None for their URLs.
app = FastAPI(redoc_url=None, docs_url=None)

# Expose the contents of the local "static" directory under the /static prefix.
app.mount("/static", StaticFiles(directory="static"), name="static")
@app.head("/")
@app.get("/")
def index() -> FileResponse:
    """Serve the editor page (static/index.html) for GET and HEAD on the root URL."""
    landing_page = "static/index.html"
    return FileResponse(path=landing_page, media_type="text/html")
def generate(prefix, suffix=None, temperature=0.9, max_new_tokens=256, top_p=0.95, repetition_penalty=1.0):
    """Request one completion from the text-generation client.

    When ``suffix`` is None the model simply continues ``prefix``. Otherwise a
    fill-in-the-middle prompt is built from ``prefix`` and ``suffix`` and the
    model output is treated as the text that goes between them.

    Args:
        prefix: Text before the generation point.
        suffix: Text after the generation point, or None for plain continuation.
        temperature: Sampling temperature; values below 1e-2 are raised to 1e-2.
        max_new_tokens: Passed through to the client unchanged.
        top_p: Nucleus sampling parameter (converted to float).
        repetition_penalty: Passed through to the client unchanged.

    Returns:
        A dict with keys ``truncated`` (always False for now), ``type``
        (``'generate'`` or ``'infill'``), ``text`` (the full document) and
        ``parts``; infill results additionally carry ``infills``.
    """
    # TODO: deduplicate code between this and `infill`
    sampling_args = {
        'temperature': max(float(temperature), 1e-2),
        'max_new_tokens': max_new_tokens,
        'top_p': float(top_p),
        'repetition_penalty': repetition_penalty,
        'do_sample': True,
        'seed': 42,
    }

    infilling = suffix is not None
    prompt = f"{FIM_PREFIX}{prefix}{FIM_SUFFIX}{suffix}{FIM_MIDDLE}" if infilling else prefix

    completion = client.generate(prompt, **sampling_args).generated_text

    # TODO: set this based on stop reason from client.generate
    truncated = False

    # Drop every trailing end-of-text marker from the model output.
    marker_length = len(END_OF_TEXT)
    while completion.endswith(END_OF_TEXT):
        completion = completion[:-marker_length]

    if not infilling:
        return {
            'truncated': truncated,
            'type': 'generate',
            'text': prompt + completion,
            'parts': [prompt],
        }

    return {
        'truncated': truncated,
        'type': 'infill',
        'text': prefix + completion + suffix,
        'parts': [prefix, suffix],
        'infills': [completion],
    }
@app.get('/generate')
async def generate_maybe(info: str):
    """Handle a plain left-to-right generation request.

    Args:
        info: A base64-encoded (URL-safe alphabet), url-escaped JSON string
            with the keys ``prompt``, ``length`` and ``temperature``. It is
            passed in the query string since GET doesn't support a body, and
            POST leads to CORS issues.

    Returns:
        On success, a dict with ``result='success'``, the original ``prompt``,
        the generated ``text`` and a ``message``. On any failure, including a
        malformed ``info`` payload, a dict with ``result='error'`` and an
        explanatory ``message``; ``prompt`` is None if it could not be decoded.
    """
    # NOTE(review): `generate` is called without `await`, so this coroutine
    # does not yield while the model is running -- confirm that serving
    # requests one at a time is intended.
    prompt = None  # stays None if the request cannot be decoded
    try:
        # Restore stripped base64 padding, following
        # https://stackoverflow.com/a/9956217/1319683 (`-len % 4` adds nothing
        # when the length is already a multiple of four).
        padded = info + '=' * (-len(info) % 4)
        form = json.loads(base64.urlsafe_b64decode(padded).decode('utf-8'))
        prompt = form['prompt']
        length_limit = int(form['length'])
        temperature = float(form['temperature'])

        generation = generate(prompt, temperature=temperature, max_new_tokens=length_limit, top_p=0.95, repetition_penalty=1.0)
        message = TRUNCATION_MESSAGE if generation['truncated'] else ''
        return {'result': 'success', 'type': 'generate', 'prompt': prompt, 'text': generation['text'], 'message': message}
    except Exception as e:
        # Top-level request boundary: log the traceback and report the problem
        # to the front end instead of letting the request fail with a 500.
        traceback.print_exception(*sys.exc_info())
        return {'result': 'error', 'type': 'generate', 'prompt': prompt, 'message': f'Error: {e}.'}
@app.get('/infill')
async def infill_maybe(info: str):
    """Handle a fill-in-the-middle request.

    Args:
        info: A base64-encoded (URL-safe alphabet), url-escaped JSON string
            with the keys ``parts``, ``length`` and ``temperature``. It is
            passed in the query string since GET doesn't support a body, and
            POST leads to CORS issues. ``parts`` must contain exactly two
            entries: the text before and the text after the infill location.

    Returns:
        On success, the dict produced by ``generate`` extended with
        ``result='success'`` and a ``message``. If ``parts`` does not have
        exactly two entries, an error dict that also carries the joined parts
        as ``text``. On any other failure, including a malformed ``info``
        payload, an error dict with an explanatory ``message``.
    """
    try:
        # Restore stripped base64 padding, following
        # https://stackoverflow.com/a/9956217/1319683 (`-len % 4` adds nothing
        # when the length is already a multiple of four).
        padded = info + '=' * (-len(info) % 4)
        form = json.loads(base64.urlsafe_b64decode(padded).decode('utf-8'))
        length_limit = int(form['length'])
        temperature = float(form['temperature'])
        parts = form['parts']

        if len(parts) > 2:
            return {'result': 'error', 'text': ''.join(parts), 'type': 'infill', 'message': "error: Only a single <infill> token is supported!"}
        if len(parts) < 2:
            # Covers both a document without a marker and an empty parts list.
            return {'result': 'error', 'text': ''.join(parts), 'type': 'infill', 'message': "error: Must have an <infill> token present!"}

        prefix, suffix = parts
        generation = generate(prefix, suffix=suffix, temperature=temperature, max_new_tokens=length_limit, top_p=0.95, repetition_penalty=1.0)
        generation['result'] = 'success'
        generation['message'] = TRUNCATION_MESSAGE if generation['truncated'] else ''
        return generation
    except Exception as e:
        # Top-level request boundary: log the traceback and report the problem
        # to the front end instead of letting the request fail with a 500.
        traceback.print_exception(*sys.exc_info())
        return {'result': 'error', 'type': 'infill', 'message': f'Error: {e}.'}
if __name__ == "__main__":
    # Script entry point: bind to all interfaces on the configured port.
    # NOTE(review): `run(..., threaded=False)` looks like the Flask API; FastAPI
    # apps are normally started by an ASGI server (e.g. `uvicorn app:app`).
    # Confirm whether this branch is ever executed.
    app.run(threaded=False, port=PORT, host='0.0.0.0')