"""FastAPI demo server for code generation and fill-in-the-middle (FIM) infilling.

Requests arrive as base64url-encoded JSON on GET query strings (GET has no
body, and POST led to CORS issues), are forwarded to a text-generation
inference backend, and results are returned as JSON.
"""

import sys
from typing import List
import traceback
import os
import base64
import json
import pprint

from huggingface_hub import Repository
from text_generation import Client

PORT = 7860
VERBOSE = False

# Both branches of the original `if os.path.exists('unlock')` check assigned
# the identical value, so the check was dead code; behavior is unchanged.
MAX_LENGTH = 8192

TRUNCATION_MESSAGE = f'warning: This demo is limited to {MAX_LENGTH} tokens in the document for efficiency.'

from fastapi import FastAPI, Request
from fastapi.staticfiles import StaticFiles
from fastapi.responses import FileResponse, StreamingResponse

app = FastAPI(docs_url=None, redoc_url=None)
app.mount("/static", StaticFiles(directory="static"), name="static")

HF_TOKEN = os.environ.get("HF_TOKEN", None)
API_URL = os.environ.get("API_URL")

with open("./HHH_prompt.txt", "r") as f:
    HHH_PROMPT = f.read() + "\n\n"

# NOTE(review): these FIM sentinel tokens are empty strings here, which makes
# the FIM prompt degenerate to plain prefix+suffix concatenation. They look
# like model special tokens that were stripped -- confirm the tokens expected
# by the target model (e.g. <fim-prefix>/<fim-middle>/<fim-suffix>).
FIM_PREFIX = ""
FIM_MIDDLE = ""
FIM_SUFFIX = ""
END_OF_TEXT = "<|endoftext|>"
FIM_INDICATOR = ""

client = Client(
    API_URL,
    headers={"Authorization": f"Bearer {HF_TOKEN}"},
)


@app.head("/")
@app.get("/")
def index() -> FileResponse:
    """Serve the static front-end page."""
    return FileResponse(path="static/index.html", media_type="text/html")


def generate(prefix, suffix=None, temperature=0.9, max_new_tokens=256, top_p=0.95,
             repetition_penalty=1.0):
    """Run one generation against the inference backend.

    If `suffix` is given, a fill-in-the-middle prompt is built from the FIM
    sentinel tokens; otherwise `prefix` is continued as-is.

    Returns a dict with keys 'text', 'parts', 'type', 'truncated'
    (plus 'infills' in FIM mode).
    """
    temperature = float(temperature)
    if temperature < 1e-2:
        # Backend rejects temperatures at/near zero; clamp instead of failing.
        temperature = 1e-2
    top_p = float(top_p)

    generate_kwargs = dict(
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        do_sample=True,
        seed=42,
    )

    fim_mode = suffix is not None
    if fim_mode:
        prompt = f"{FIM_PREFIX}{prefix}{FIM_SUFFIX}{suffix}{FIM_MIDDLE}"
    else:
        prompt = prefix

    output = client.generate(prompt, **generate_kwargs)  # TODO
    generated_text = output.generated_text

    # NOTE(review): `truncated` is never set to True anywhere, so the
    # TRUNCATION_MESSAGE paths in the endpoints are currently dead. Confirm
    # whether the backend reports truncation and wire it through if so.
    truncated = False

    # Strip any trailing end-of-text sentinels emitted by the model.
    while generated_text.endswith(END_OF_TEXT):
        generated_text = generated_text[:-len(END_OF_TEXT)]

    generation = {
        'truncated': truncated,
    }
    if fim_mode:
        generation['text'] = prefix + generated_text + suffix
        generation['parts'] = [prefix, suffix]
        generation['infills'] = [generated_text]
        generation['type'] = 'infill'
    else:
        generation['text'] = prompt + generated_text
        generation['parts'] = [prompt]
        generation['type'] = 'generate'
    return generation


def _decode_info(info: str) -> dict:
    """Decode a base64url-encoded JSON request payload into a dict.

    Re-pads before decoding (https://stackoverflow.com/a/9956217/1319683).
    Fix: `-len(info) % 4` appends 0-3 '=' characters; the previous
    `4 - len(info) % 4` appended four spurious '=' when the input was
    already aligned.
    """
    padded = info + '=' * (-len(info) % 4)
    return json.loads(base64.urlsafe_b64decode(padded).decode('utf-8'))


@app.get('/generate')
async def generate_maybe(info: str):
    """Plain-continuation endpoint; payload keys: prompt, length, temperature."""
    form = _decode_info(info)
    prompt = form['prompt']
    length_limit = int(form['length'])
    temperature = float(form['temperature'])
    try:
        generation = generate(prompt, temperature=temperature,
                              max_new_tokens=length_limit,
                              top_p=0.95, repetition_penalty=1.0)
        message = TRUNCATION_MESSAGE if generation['truncated'] else ''
        return {'result': 'success', 'type': 'generate', 'prompt': prompt,
                'text': generation['text'], 'message': message}
    except Exception as e:
        # API boundary: log the traceback server-side and return a JSON error
        # to the client instead of an opaque 500.
        traceback.print_exception(*sys.exc_info())
        return {'result': 'error', 'type': 'generate', 'prompt': prompt,
                'message': f'Error: {e}.'}


@app.get('/infill')
async def infill_maybe(info: str):
    """FIM endpoint; payload keys: parts (exactly [prefix, suffix]), length, temperature."""
    form = _decode_info(info)
    length_limit = int(form['length'])
    temperature = float(form['temperature'])
    try:
        if len(form['parts']) > 2:
            # Two parts == one gap; anything more is multiple infills.
            return {'result': 'error', 'text': ''.join(form['parts']),
                    'type': 'infill',
                    'message': "error: Only a single infill is supported!"}
        prefix, suffix = form['parts']
        generation = generate(prefix, suffix=suffix, temperature=temperature,
                              max_new_tokens=length_limit,
                              top_p=0.95, repetition_penalty=1.0)
        generation['result'] = 'success'
        generation['message'] = TRUNCATION_MESSAGE if generation['truncated'] else ''
        return generation
    except Exception as e:
        traceback.print_exception(*sys.exc_info())
        return {'result': 'error', 'type': 'infill', 'message': f'Error: {e}.'}


if __name__ == "__main__":
    # Bug fix: `app.run(...)` is the Flask API (and `threaded=` is Flask-only);
    # FastAPI is an ASGI app and must be served by an ASGI server such as
    # uvicorn (installed alongside FastAPI).
    import uvicorn
    uvicorn.run(app, host='0.0.0.0', port=PORT)