import base64
import json
import os
import sys
import traceback

import uvicorn
from fastapi import FastAPI
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles
from text_generation import Client
PORT = 7860
# TODO: enforce an overall maximum document length (currently each request is
# limited by the slider-specified max length, but repeated requests, or long
# code pasted into the editor, can still produce very long documents)
# if os.path.exists('unlock'):
#     # create an 'unlock' file (not checked into Git) locally to get full context lengths
#     MAX_LENGTH = 8192
# else:
#     # set to a shorter value to prevent long contexts and make the demo more efficient
#     MAX_LENGTH = 1024
# TRUNCATION_MESSAGE = f'warning: This demo is limited to {MAX_LENGTH} tokens in the document for efficiency.'
TRUNCATION_MESSAGE = 'TODO'
HF_TOKEN = os.environ.get("HF_TOKEN", None)
API_URL = os.environ.get("API_URL")
with open("./HHH_prompt.txt", "r") as f:
    HHH_PROMPT = f.read() + "\n\n"
# used by the model
FIM_PREFIX = "<fim_prefix>"
FIM_MIDDLE = "<fim_middle>"
FIM_SUFFIX = "<fim_suffix>"
END_OF_TEXT = "<|endoftext|>"
# used to mark infill locations in the editor
FIM_INDICATOR = "<infill>"
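# Example (illustrative): if the editor document is "def add(a, b):\n    <infill>\n",
# the client splits it at <infill> and `generate` below assembles the prompt
#     <fim_prefix>def add(a, b):\n    <fim_suffix>\n<fim_middle>
# so the model's completion fills the gap between prefix and suffix.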
client = Client(
    API_URL, headers={"Authorization": f"Bearer {HF_TOKEN}"},
)
app = FastAPI(docs_url=None, redoc_url=None)
app.mount("/static", StaticFiles(directory="static"), name="static")
@app.head("/")
@app.get("/")
def index() -> FileResponse:
    return FileResponse(path="static/index.html", media_type="text/html")
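# The HEAD route mirrors GET, likely so the hosting platform's health checks
# (which may use HEAD requests) get a 200 instead of a 405; e.g. (sketch):
#     curl -I http://localhost:7860/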
def generate(prefix, suffix=None, temperature=0.9, max_new_tokens=256, top_p=0.95, repetition_penalty=1.0):
    # TODO: deduplicate code between this and `infill`
    temperature = float(temperature)
    if temperature < 1e-2:
        temperature = 1e-2
    top_p = float(top_p)
    generate_kwargs = dict(
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        do_sample=True,
        seed=42,
    )
    fim_mode = suffix is not None
    if fim_mode:
        prompt = f"{FIM_PREFIX}{prefix}{FIM_SUFFIX}{suffix}{FIM_MIDDLE}"
    else:
        prompt = prefix
    output = client.generate(prompt, **generate_kwargs)
    generated_text = output.generated_text
    # TODO: set this based on the stop reason from client.generate
    truncated = False
    # strip any trailing end-of-text tokens from the generation
    while generated_text.endswith(END_OF_TEXT):
        generated_text = generated_text[: -len(END_OF_TEXT)]
    generation = {
        'truncated': truncated,
    }
    if fim_mode:
        generation['type'] = 'infill'
        generation['text'] = prefix + generated_text + suffix
        generation['parts'] = [prefix, suffix]
        generation['infills'] = [generated_text]
    else:
        generation['type'] = 'generate'
        generation['text'] = prompt + generated_text
        generation['parts'] = [prompt]
    return generation
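# Example usage (sketch; actual output depends on the model behind API_URL):
#     generation = generate("def square(x):\n    ", suffix="\n    return y\n")
#     generation['text']     # full document with the generated middle spliced in
#     generation['infills']  # [the generated middle on its own]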
@app.get('/generate')
async def generate_maybe(info: str):
    # info is a base64-encoded, url-escaped json string (since GET doesn't support a body, and POST leads to CORS issues)
    # fix padding, following https://stackoverflow.com/a/9956217/1319683
    info = base64.urlsafe_b64decode(info + '=' * (4 - len(info) % 4)).decode('utf-8')
    form = json.loads(info)
    prompt = form['prompt']
    length_limit = int(form['length'])
    temperature = float(form['temperature'])
    try:
        generation = generate(prompt, temperature=temperature, max_new_tokens=length_limit, top_p=0.95, repetition_penalty=1.0)
        if generation['truncated']:
            message = TRUNCATION_MESSAGE
        else:
            message = ''
        return {'result': 'success', 'type': 'generate', 'prompt': prompt, 'text': generation['text'], 'message': message}
    except Exception as e:
        traceback.print_exception(*sys.exc_info())
        return {'result': 'error', 'type': 'generate', 'prompt': prompt, 'message': f'Error: {e}.'}
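# Client-side sketch of building the `info` parameter (hypothetical values):
#     payload = {'prompt': 'def hello():', 'length': 64, 'temperature': 0.7}
#     info = base64.urlsafe_b64encode(json.dumps(payload).encode('utf-8')).decode('utf-8')
#     # then: GET /generate?info=<info>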
@app.get('/infill')
async def infill_maybe(info: str):
    # info is a base64-encoded, url-escaped json string (since GET doesn't support a body, and POST leads to CORS issues)
    # fix padding, following https://stackoverflow.com/a/9956217/1319683
    info = base64.urlsafe_b64decode(info + '=' * (4 - len(info) % 4)).decode('utf-8')
    form = json.loads(info)
    length_limit = int(form['length'])
    temperature = float(form['temperature'])
    try:
        if len(form['parts']) > 2:
            return {'result': 'error', 'text': ''.join(form['parts']), 'type': 'infill', 'message': "error: Only a single <infill> token is supported!"}
        elif len(form['parts']) == 1:
            return {'result': 'error', 'text': ''.join(form['parts']), 'type': 'infill', 'message': "error: Must have an <infill> token present!"}
        prefix, suffix = form['parts']
        generation = generate(prefix, suffix=suffix, temperature=temperature, max_new_tokens=length_limit, top_p=0.95, repetition_penalty=1.0)
        generation['result'] = 'success'
        if generation['truncated']:
            generation['message'] = TRUNCATION_MESSAGE
        else:
            generation['message'] = ''
        return generation
    except Exception as e:
        traceback.print_exception(*sys.exc_info())
        return {'result': 'error', 'type': 'infill', 'message': f'Error: {e}.'}
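# Client-side sketch for /infill (assumed): split the document at FIM_INDICATOR
# so exactly one <infill> yields the two-element `parts` this handler expects:
#     parts = document.split('<infill>')
#     payload = {'parts': parts, 'length': 64, 'temperature': 0.7}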
if __name__ == "__main__":
    # FastAPI apps have no `run` method (that's Flask); serve with uvicorn instead
    uvicorn.run(app, host="0.0.0.0", port=PORT)
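# Equivalent CLI invocation (assuming this file is named app.py):
#     uvicorn app:app --host 0.0.0.0 --port 7860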