# bigcode-editor / app.py
# Author: Daniel Fried — "decruft and add some comments" (commit 13f28f2)
# NOTE: the lines above were Hugging Face file-viewer chrome ("raw / history
# blame / 5.35 kB") scraped into the source; kept here as a comment so the
# file remains valid Python.
import sys
from typing import List
import traceback
import os
import base64
import json
import pprint
from huggingface_hub import Repository
from text_generation import Client
PORT = 7860

# TODO: implement maximum length (currently, each iteration is limited by the
# slider-specified max length, but this can be iterated, or long code entered
# into the editor, to get really long documents).
# if os.path.exists('unlock'):
#     # create an 'unlock' file (not checked into Git) locally to get full context lengths
#     MAX_LENGTH = 8192
# else:
#     # set to a shorter value to prevent long contexts and make the demo more efficient
#     MAX_LENGTH = 1024
# TRUNCATION_MESSAGE = f'warning: This demo is limited to {MAX_LENGTH} tokens in the document for efficiency.'
# Placeholder until the length limit above is implemented. (The stray 'f'
# prefix was dropped: the literal has no placeholders.)
TRUNCATION_MESSAGE = 'TODO'

# Credentials / endpoint for the hosted text-generation backend.
HF_TOKEN = os.environ.get("HF_TOKEN", None)
API_URL = os.environ.get("API_URL")

# Read once at startup; used as a prompt prefix (consumer not visible in this file).
with open("./HHH_prompt.txt", "r") as f:
    HHH_PROMPT = f.read() + "\n\n"

# Special tokens used by the model for fill-in-the-middle (FIM) inference.
FIM_PREFIX = "<fim_prefix>"
FIM_MIDDLE = "<fim_middle>"
FIM_SUFFIX = "<fim_suffix>"
END_OF_TEXT = "<|endoftext|>"

# Marker the editor front-end inserts to request an infill at that position.
FIM_INDICATOR = "<infill>"

client = Client(
    API_URL, headers={"Authorization": f"Bearer {HF_TOKEN}"},
)
# NOTE(review): these imports would conventionally live at the top of the file
# with the other imports.
from fastapi import FastAPI, Request
from fastapi.staticfiles import StaticFiles
from fastapi.responses import FileResponse, StreamingResponse
# Disable the auto-generated docs pages; this app only serves the editor UI.
app = FastAPI(docs_url=None, redoc_url=None)
# Serve the editor front-end assets (HTML/JS/CSS) from ./static.
app.mount("/static", StaticFiles(directory="static"), name="static")
@app.head("/")
@app.get("/")
def index() -> FileResponse:
    """Serve the editor's single-page front end at the site root."""
    page = FileResponse(path="static/index.html", media_type="text/html")
    return page
def generate(prefix, suffix=None, temperature=0.9, max_new_tokens=256, top_p=0.95, repetition_penalty=1.0):
    """Generate a completion for `prefix`, or an infill when `suffix` is given.

    Returns a dict with keys:
      'type':      'infill' or 'generate'
      'text':      the full document with the generated text spliced in
      'parts':     the user-supplied segment(s) surrounding the generation
      'infills':   (infill mode only) list containing the generated middle
      'truncated': whether generation was cut short (currently always False)
    """
    # TODO: deduplicate code between this and `infill`
    temperature = float(temperature)
    # Clamp to a small positive value; near-zero temperatures are degenerate.
    if temperature < 1e-2:
        temperature = 1e-2
    top_p = float(top_p)
    generate_kwargs = dict(
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        do_sample=True,
        seed=42,
    )
    fim_mode = suffix is not None
    if fim_mode:
        # Fill-in-the-middle: the model sees prefix+suffix and produces the middle.
        # (Reuses `fim_mode` rather than re-testing `suffix is not None`.)
        prompt = f"{FIM_PREFIX}{prefix}{FIM_SUFFIX}{suffix}{FIM_MIDDLE}"
    else:
        prompt = prefix
    output = client.generate(prompt, **generate_kwargs)
    generated_text = output.generated_text
    # TODO: set this based on stop reason from client.generate
    truncated = False
    # Strip any trailing end-of-text token(s) the model emitted.
    while generated_text.endswith(END_OF_TEXT):
        generated_text = generated_text[:-len(END_OF_TEXT)]
    generation = {
        'truncated': truncated,
    }
    if fim_mode:
        generation['type'] = 'infill'
        generation['text'] = prefix + generated_text + suffix
        generation['parts'] = [prefix, suffix]
        generation['infills'] = [generated_text]
    else:
        generation['type'] = 'generate'
        generation['text'] = prompt + generated_text
        generation['parts'] = [prompt]
    return generation
@app.get('/generate')
async def generate_maybe(info: str):
    """HTTP entry point for plain (non-infill) generation.

    `info` is a base64-encoded, url-escaped json string (since GET doesn't
    support a body, and POST leads to CORS issues).
    """
    # Restore base64 padding. `-len(info) % 4` adds exactly the missing '='
    # characters and, unlike `4 - len(info) % 4`, adds none when the string is
    # already a multiple of 4 (the old form appended four spurious '=' there
    # and relied on the decoder tolerating excess padding).
    info = base64.urlsafe_b64decode(info + '=' * (-len(info) % 4)).decode('utf-8')
    form = json.loads(info)
    prompt = form['prompt']
    length_limit = int(form['length'])
    temperature = float(form['temperature'])
    try:
        generation = generate(prompt, temperature=temperature, max_new_tokens=length_limit, top_p=0.95, repetition_penalty=1.0)
        message = TRUNCATION_MESSAGE if generation['truncated'] else ''
        return {'result': 'success', 'type': 'generate', 'prompt': prompt, 'text': generation['text'], 'message': message}
    except Exception as e:
        # Log the full traceback server-side; return a short error to the client.
        traceback.print_exception(*sys.exc_info())
        return {'result': 'error', 'type': 'generate', 'prompt': prompt, 'message': f'Error: {e}.'}
@app.get('/infill')
async def infill_maybe(info: str):
    """HTTP entry point for single-region infilling.

    `info` is a base64-encoded, url-escaped json string (since GET doesn't
    support a body, and POST leads to CORS issues).
    """
    # Restore base64 padding. `-len(info) % 4` adds exactly the missing '='
    # characters and adds none when the string is already aligned (the old
    # `4 - len % 4` form appended four spurious '=' in that case).
    info = base64.urlsafe_b64decode(info + '=' * (-len(info) % 4)).decode('utf-8')
    form = json.loads(info)
    length_limit = int(form['length'])
    temperature = float(form['temperature'])
    try:
        # More than one <infill> marker splits the document into >2 parts.
        if len(form['parts']) > 2:
            return {'result': 'error', 'text': ''.join(form['parts']), 'type': 'infill', 'message': "error: Only a single infill is supported!"}
        prefix, suffix = form['parts']
        generation = generate(prefix, suffix=suffix, temperature=temperature, max_new_tokens=length_limit, top_p=0.95, repetition_penalty=1.0)
        generation['result'] = 'success'
        generation['message'] = TRUNCATION_MESSAGE if generation['truncated'] else ''
        return generation
    except Exception as e:
        # Log the full traceback server-side; return a short error to the client.
        traceback.print_exception(*sys.exc_info())
        return {'result': 'error', 'type': 'infill', 'message': f'Error: {e}.'}
if __name__ == "__main__":
    # Bug fix: FastAPI applications have no `app.run` — that is Flask's API
    # (as is the `threaded` keyword). A FastAPI app must be served by an ASGI
    # server such as uvicorn. Imported locally so the module can still be
    # loaded by an external server process without requiring uvicorn.
    import uvicorn
    uvicorn.run(app, host='0.0.0.0', port=PORT)