# gpt Gradio App by Craig Brennan
import torch
import tiktoken
import gradio as gr
from safetensors.torch import load_file
from huggingface_hub import hf_hub_download
import spaces

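# Model hyperparameters (GPT-2 medium scale: 24 layers, 16 heads, 1024-dim embeddings;
# 50,257 is the GPT-2 BPE vocabulary size)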
n_layers = 24
n_heads = 16
embed_dim = 1024
ffn_dim = embed_dim * 4
n_vocab = 50257
max_seq_len = 740
dropout = 0.0

# Number of tokens to generate between streamed UI updates
update_interval = 14

@spaces.GPU  # run on the Space's GPU (ZeroGPU) hardware
def load_model():
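    """Download the pretrained weights from the Hub and build the model on the available device."""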
    from model import GPTModel
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    tokenizer = tiktoken.encoding_for_model('gpt2')
    model_path = hf_hub_download('cbspace/gpt', 'model.safetensors')
    state_dict = load_file(model_path)

    model = GPTModel(device, n_layers, n_heads, embed_dim, ffn_dim, n_vocab, max_seq_len, dropout)
    model.load_state_dict(state_dict, strict=False)  # strict=False tolerates missing/unexpected keys
    
    model.eval()
    return model, tokenizer, device

@spaces.GPU(duration=120)  # allow up to 120 s of GPU time per generation request
def generate(prompt, out_tokens, top_k_value, temperature):
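    """Generate out_tokens tokens of text, yielding partial output so Gradio streams it to the UI."""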
    model.to(device)  # ensure the model is on the device selected at load time
    outputs = tokenizer.encode(prompt)
    tokens_remaining = int(out_tokens)
    out_text = prompt
    yield out_text

    while tokens_remaining:
        # Generate at most update_interval tokens this iteration (fewer on the final chunk)
        chunk_len = min(update_interval, tokens_remaining)
        outputs = model.generate(outputs, len(outputs) + chunk_len, temperature, top_k=int(top_k_value))
        tokens_remaining -= chunk_len
        out_text += tokenizer.decode(outputs[-chunk_len:])
        yield out_text

# Load the model, tokenizer and device once at startup
model, tokenizer, device = load_model()

app = gr.Interface(
    generate,
    [
        gr.Textbox(label='Prompt', lines=3), 
        gr.Number(label='Output Tokens', value=180),
        gr.Slider(1, 100, step=5, value=60, label='Top-k Value'),
        gr.Slider(0.1, 2.0, step=0.05, value=0.9, label='Temperature')
    ],
    gr.Textbox(label='Output', lines=15, max_lines=15)
)
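# Queue requests so the generator's streamed output reaches the client incrementally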
app.queue().launch(ssr_mode=False, share=True)