import sys
import traceback
import os
import base64
import json

from text_generation import Client

from requests.exceptions import ReadTimeout

PORT = 7860

# TODO: implement a maximum document length (currently each iteration is limited by the
# slider-specified max length, but repeated iterations, or long code pasted into the
# editor, can still produce very long documents)
# if os.path.exists('unlock'):
#     # create an 'unlock' file (not checked into Git) locally to get full context lengths
#     MAX_LENGTH = 8192
# else:
#     # set to a shorter value to prevent long contexts and make the demo more efficient
#     MAX_LENGTH = 1024
# TRUNCATION_MESSAGE = f'warning: This demo is limited to {MAX_LENGTH} tokens in the document for efficiency.'
TRUNCATION_MESSAGE = 'TODO'

HF_TOKEN = os.environ.get("HF_TOKEN", None)
API_URL = os.environ.get("API_URL")

with open("./HHH_prompt.txt", "r") as f:
    HHH_PROMPT = f.read() + "\n\n"

# used by the model
FIM_PREFIX = "<fim_prefix>"
FIM_MIDDLE = "<fim_middle>"
FIM_SUFFIX = "<fim_suffix>"
END_OF_TEXT = "<|endoftext|>"

# used to mark infill locations in the editor
FIM_INDICATOR = "<infill>"
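# NOTE: the editor is expected to split the document on FIM_INDICATOR before calling
# /infill below, so e.g. "def add(a, b):\n    <infill>\n" would arrive as
# parts == ["def add(a, b):\n    ", "\n"]. This is a sketch of the expected payload
# shape; the split itself happens client-side, in code not shown here.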

client = Client(
    API_URL, headers={"Authorization": f"Bearer {HF_TOKEN}"},
)

from fastapi import FastAPI, Request
from fastapi.staticfiles import StaticFiles
from fastapi.responses import FileResponse, StreamingResponse
app = FastAPI(docs_url=None, redoc_url=None)
app.mount("/static", StaticFiles(directory="static"), name="static")

@app.head("/")
@app.get("/")
def index() -> FileResponse:
    return FileResponse(path="static/index.html", media_type="text/html")

def generate(prefix, suffix=None, temperature=0.9, max_new_tokens=256, top_p=0.95, repetition_penalty=1.0):
    # TODO: deduplicate code between this and `infill`
    temperature = float(temperature)
    if temperature < 1e-2:
        temperature = 1e-2
    top_p = float(top_p)
    
    generate_kwargs = dict(
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        do_sample=True,
        seed=42,
    )

    fim_mode = suffix is not None
    
    if fim_mode:
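        # e.g. prefix "def f():\n    " and suffix "\n    return x" (illustrative values)
        # produce "<fim_prefix>def f():\n    <fim_suffix>\n    return x<fim_middle>"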
        prompt = f"{FIM_PREFIX}{prefix}{FIM_SUFFIX}{suffix}{FIM_MIDDLE}"
    else:
        prompt = prefix
    output = client.generate(prompt, **generate_kwargs)
    generated_text = output.generated_text
    # TODO: set this based on stop reason from client.generate
    truncated = False
    # strip any trailing end-of-text tokens from the generation
    while generated_text.endswith(END_OF_TEXT):
        generated_text = generated_text[:-len(END_OF_TEXT)]
    generation = {
        'truncated': truncated,
    }
    if fim_mode:
        generation['type'] = 'infill'
        generation['text'] = prefix + generated_text + suffix
        generation['parts'] = [prefix, suffix]
        generation['infills'] = [generated_text]
    else:
        generation['type'] = 'generate'
        generation['text'] = prompt + generated_text
        generation['parts'] = [prompt]
    return generation
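
# Illustrative shape of the dict `generate` returns (values made up for the example):
#   FIM mode:      {'truncated': False, 'type': 'infill', 'text': 'def f():\n    return x',
#                   'parts': ['def f():\n    ', ''], 'infills': ['return x']}
#   generate mode: {'truncated': False, 'type': 'generate', 'text': 'def f():\n    return x',
#                   'parts': ['def f():\n    ']}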

@app.get('/generate')
async def generate_maybe(info: str):
    # info is a base64-encoded, url-escaped json string (since GET doesn't support a body, and POST leads to CORS issues)
    # fix padding, following https://stackoverflow.com/a/9956217/1319683
    info = base64.urlsafe_b64decode(info + '=' * (4 - len(info) % 4)).decode('utf-8')
    form = json.loads(info)
    prompt = form['prompt']
    length_limit = int(form['length'])
    temperature = float(form['temperature'])
    try:
        generation = generate(prompt, temperature=temperature, max_new_tokens=length_limit, top_p=0.95, repetition_penalty=1.0)
        if generation['truncated']:
            message = TRUNCATION_MESSAGE 
        else:
            message = ''
        return {'result': 'success', 'type': 'generate', 'prompt': prompt, 'text': generation['text'], 'message': message}
    except ReadTimeout as e:
        print(e)
        return {'result': 'error', 'type': 'generate', 'prompt': prompt, 'message': 'Request timed out.'}
    except Exception as e:
        traceback.print_exception(*sys.exc_info())
        return {'result': 'error', 'type': 'generate', 'prompt': prompt, 'message': f'Error: {e}.'}
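
# A sketch of how a client would build the `info` parameter for /generate (the field
# names match the parsing above; the URL assumes a local server on PORT):
#   payload = json.dumps({'prompt': 'def hello_world():', 'length': 64, 'temperature': 0.7})
#   info = base64.urlsafe_b64encode(payload.encode('utf-8')).decode('utf-8')
#   # GET http://localhost:7860/generate?info=<info>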

@app.get('/infill')
async def infill_maybe(info: str):
    # info is a base64-encoded, url-escaped json string (since GET doesn't support a body, and POST leads to CORS issues)
    # fix padding, following https://stackoverflow.com/a/9956217/1319683
    info = base64.urlsafe_b64decode(info + '=' * (4 - len(info) % 4)).decode('utf-8')
    form = json.loads(info)
    length_limit = int(form['length'])
    temperature = float(form['temperature'])
    try:
        if len(form['parts']) > 2:
            return {'result': 'error', 'text': ''.join(form['parts']), 'type': 'infill', 'message': "error: Only a single <infill> token is supported!"}
        elif len(form['parts']) < 2:
            return {'result': 'error', 'text': ''.join(form['parts']), 'type': 'infill', 'message': "error: Must have an <infill> token present!"}
        prefix, suffix = form['parts']
        generation = generate(prefix, suffix=suffix, temperature=temperature, max_new_tokens=length_limit, top_p=0.95, repetition_penalty=1.0)
        generation['result'] = 'success'
        if generation['truncated']:
            generation['message'] = TRUNCATION_MESSAGE
        else:
            generation['message'] = ''
        return generation
    except ReadTimeout as e:
        print(e)
        # note: no `prompt` variable exists in this handler, and the request type is
        # 'infill', so the error payload must not reference either
        return {'result': 'error', 'type': 'infill', 'message': 'Request timed out.'}
    except Exception as e:
        traceback.print_exception(*sys.exc_info())
        return {'result': 'error', 'type': 'infill', 'message': f'Error: {e}.'}
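
# The corresponding /infill payload is the same base64-encoded JSON, but with a 'parts'
# list (the document split on <infill>) instead of a 'prompt' (illustrative values):
#   payload = json.dumps({'parts': ['def add(a, b):\n    ', '\n'], 'length': 64, 'temperature': 0.7})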

if __name__ == "__main__":
    # FastAPI apps are served by an ASGI server; `app.run` is a Flask API and would
    # raise an AttributeError here, so run under uvicorn instead.
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=PORT)