import base64
import json
import os
import pprint
import sys
import traceback
from typing import List

from fastapi import FastAPI, Request
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import FileResponse, StreamingResponse
from fastapi.staticfiles import StaticFiles
from huggingface_hub import Repository
from requests.exceptions import ReadTimeout
from text_generation import Client

PORT = 7860

# TODO: implement maximum length (currently, each iteration is limited by the slider-specified
# max length, but this can be iterated, or long code entered into the editor, to get really long
# documents).
# if os.path.exists('unlock'):
#     # create an 'unlock' file (not checked into Git) locally to get full context lengths
#     MAX_LENGTH = 8192
# else:
#     # set to a shorter value to prevent long contexts and make the demo more efficient
#     MAX_LENGTH = 1024
# TRUNCATION_MESSAGE = f'warning: This demo is limited to {MAX_LENGTH} tokens in the document for efficiency.'
TRUNCATION_MESSAGE = 'TODO'

HF_TOKEN = os.environ.get("HF_TOKEN", None)
API_URL = os.environ.get("API_URL")

# Loaded at import time; not used by the endpoints below.
with open("./HHH_prompt.txt", "r", encoding="utf-8") as f:
    HHH_PROMPT = f.read() + "\n\n"

# Fill-in-the-middle control tokens sent to the model.
# NOTE(review): these were empty strings in the reviewed copy (most likely stripped as
# HTML-like tags somewhere upstream). With empty markers the infill prompt degenerates to
# plain prefix + suffix. The values below are StarCoder's FIM tokens -- confirm they match
# the tokenizer of the model served at API_URL.
FIM_PREFIX = "<fim_prefix>"
FIM_MIDDLE = "<fim_middle>"
FIM_SUFFIX = "<fim_suffix>"
END_OF_TEXT = "<|endoftext|>"

# Used to mark infill locations in the editor. It is only referenced here in error
# messages; presumably the frontend splits the document on it to build `parts`.
# NOTE(review): also empty in the reviewed copy -- confirm against static/index.html.
FIM_INDICATOR = "<infill>"

client = Client(
    API_URL,
    # Only authenticate when a token is configured, rather than sending "Bearer None".
    headers={"Authorization": f"Bearer {HF_TOKEN}"} if HF_TOKEN else None,
)

app = FastAPI(docs_url=None, redoc_url=None)
app.mount("/static", StaticFiles(directory="static"), name="static")


@app.head("/")
@app.get("/")
def index() -> FileResponse:
    """Serve the editor page ``static/index.html``."""
    return FileResponse(path="static/index.html", media_type="text/html")


def generate(prefix, suffix=None, temperature=0.9, max_new_tokens=256, top_p=0.95, repetition_penalty=1.0):
    """Query the text-generation server for a continuation or an infill.

    Args:
        prefix: Text before the generation point (the whole prompt in left-to-right mode).
        suffix: Text after the infill point. When not ``None``, a fill-in-the-middle prompt
            ``FIM_PREFIX + prefix + FIM_SUFFIX + suffix + FIM_MIDDLE`` is sent.
        temperature: Sampling temperature; values below 0.01 are raised to 0.01.
        max_new_tokens: Upper bound on generated tokens.
        top_p: Nucleus-sampling probability mass.
        repetition_penalty: Passed through to the server unchanged.

    Returns:
        A dict with ``truncated`` (currently always False), ``type`` (``'infill'`` or
        ``'generate'``), ``text`` (the full document including the generated span), and
        ``parts``. In infill mode it also contains ``infills`` (a one-element list).

    Raises:
        Whatever ``client.generate`` raises, e.g. ``requests.exceptions.ReadTimeout``.
    """
    fim_mode = suffix is not None
    generate_kwargs = dict(
        temperature=max(float(temperature), 1e-2),
        max_new_tokens=int(max_new_tokens),
        top_p=float(top_p),
        repetition_penalty=repetition_penalty,
        do_sample=True,
        # Fixed seed: presumably makes repeated identical requests reproducible.
        seed=42,
    )

    if fim_mode:
        prompt = f"{FIM_PREFIX}{prefix}{FIM_SUFFIX}{suffix}{FIM_MIDDLE}"
    else:
        prompt = prefix

    generated_text = client.generate(prompt, **generate_kwargs).generated_text
    # Strip any trailing end-of-text markers so they never reach the editor.
    while generated_text.endswith(END_OF_TEXT):
        generated_text = generated_text[:-len(END_OF_TEXT)]

    # TODO: set this based on the stop reason reported by client.generate
    truncated = False

    generation = {'truncated': truncated}
    if fim_mode:
        generation['type'] = 'infill'
        generation['text'] = prefix + generated_text + suffix
        generation['parts'] = [prefix, suffix]
        generation['infills'] = [generated_text]
    else:
        generation['type'] = 'generate'
        generation['text'] = prompt + generated_text
        generation['parts'] = [prompt]
    return generation


def _decode_info(info: str) -> dict:
    """Decode an ``info`` query parameter: URL-safe base64 of a UTF-8 JSON object.

    The payload travels in the query string because GET does not support a body and
    POST leads to CORS issues. Clients may drop base64 padding, so it is restored first
    (https://stackoverflow.com/a/9956217/1319683). ``-len(info) % 4`` adds exactly the
    missing ``=`` characters, or none when the length is already a multiple of four.

    Raises ``binascii.Error`` / ``UnicodeDecodeError`` / ``json.JSONDecodeError`` on
    malformed input.
    """
    padded = info + '=' * (-len(info) % 4)
    return json.loads(base64.urlsafe_b64decode(padded).decode('utf-8'))


@app.get('/generate')
async def generate_maybe(info: str):
    """Left-to-right completion endpoint.

    ``info`` decodes to a JSON object with ``prompt``, ``length`` (max new tokens), and
    ``temperature``. After decoding, model failures are reported as
    ``{'result': 'error', ...}`` with a ``message`` instead of raising.
    """
    form = _decode_info(info)
    prompt = form['prompt']
    length_limit = int(form['length'])
    temperature = float(form['temperature'])
    try:
        # `generate` makes a blocking HTTP call. Run it in a worker thread so it does not
        # stall the event loop (and therefore every other request) while the model works.
        generation = await run_in_threadpool(
            generate,
            prompt,
            temperature=temperature,
            max_new_tokens=length_limit,
            top_p=0.95,
            repetition_penalty=1.0,
        )
        message = TRUNCATION_MESSAGE if generation['truncated'] else ''
        return {'result': 'success', 'type': 'generate', 'prompt': prompt, 'text': generation['text'], 'message': message}
    except ReadTimeout as e:
        print(e)
        return {'result': 'error', 'type': 'generate', 'prompt': prompt, 'message': 'Request timed out.'}
    except Exception as e:
        traceback.print_exception(*sys.exc_info())
        return {'result': 'error', 'type': 'generate', 'prompt': prompt, 'message': f'Error: {e}.'}


@app.get('/infill')
async def infill_maybe(info: str):
    """Fill-in-the-middle endpoint.

    ``info`` decodes to a JSON object with ``parts`` (exactly two strings: the text
    before and after the single infill location), ``length``, and ``temperature``.
    On success, the dict built by ``generate`` is returned, with ``result`` and
    ``message`` added.
    """
    form = _decode_info(info)
    length_limit = int(form['length'])
    temperature = float(form['temperature'])
    try:
        parts = form['parts']
        if len(parts) > 2:
            return {'result': 'error', 'text': ''.join(parts), 'type': 'infill', 'message': f"error: Only a single {FIM_INDICATOR} token is supported!"}
        if len(parts) < 2:
            # Covers an empty list too, which would otherwise fail during unpacking.
            return {'result': 'error', 'text': ''.join(parts), 'type': 'infill', 'message': f"error: Must have an {FIM_INDICATOR} token present!"}
        prefix, suffix = parts
        generation = await run_in_threadpool(
            generate,
            prefix,
            suffix=suffix,
            temperature=temperature,
            max_new_tokens=length_limit,
            top_p=0.95,
            repetition_penalty=1.0,
        )
        generation['result'] = 'success'
        generation['message'] = TRUNCATION_MESSAGE if generation['truncated'] else ''
        return generation
    except ReadTimeout as e:
        print(e)
        # Infill requests have no single `prompt` to echo back (referencing one here was a NameError).
        return {'result': 'error', 'type': 'infill', 'message': 'Request timed out.'}
    except Exception as e:
        traceback.print_exception(*sys.exc_info())
        return {'result': 'error', 'type': 'infill', 'message': f'Error: {e}.'}


if __name__ == "__main__":
    # FastAPI apps have no `.run()` method (that is Flask's API), so the previous
    # `app.run(host='0.0.0.0', port=PORT, threaded=False)` always failed with AttributeError.
    # The app must be served by an ASGI server instead.
    sys.exit(f"Serve this app with an ASGI server, e.g. `uvicorn <module>:app --host 0.0.0.0 --port {PORT}`.")