# gpt Gradio App by Craig Brennan
import torch
import tiktoken
import gradio as gr
from safetensors.torch import load_file
from huggingface_hub import hf_hub_download
import spaces
# Model hyperparameters — must match the checkpoint downloaded from 'cbspace/gpt'
# (they are passed straight into GPTModel in load_model below).
n_layers = 24            # number of transformer blocks
n_heads = 16             # attention heads per block
embed_dim = 1024         # token embedding / hidden size
ffn_dim = embed_dim * 4  # feed-forward inner dimension (standard 4x expansion)
n_vocab = 50257          # GPT-2 BPE vocabulary size
max_seq_len = 740        # maximum context length the model was trained with
dropout = 0.0            # inference only, so dropout is disabled
# Number of tokens per update interval
update_interval = 14
@spaces.GPU
def load_model():
    """Build the GPT model, fetch its pretrained weights from the Hugging
    Face Hub, and return ``(model, tokenizer, device)`` ready for inference."""
    # Imported lazily so the module import happens inside the GPU context.
    from model import GPTModel

    run_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    enc = tiktoken.encoding_for_model('gpt2')

    gpt = GPTModel(run_device, n_layers, n_heads, embed_dim, ffn_dim,
                   n_vocab, max_seq_len, dropout)
    # strict=False: tolerate keys present in the checkpoint but absent from
    # the module (or vice versa) rather than raising.
    weights = load_file(hf_hub_download('cbspace/gpt', 'model.safetensors'))
    gpt.load_state_dict(weights, strict=False)
    gpt.eval()

    return gpt, enc, run_device
@spaces.GPU(duration=120)
def generate(prompt, out_tokens, top_k_value, temperature):
    """Stream text continuations of *prompt* for the Gradio UI.

    Yields the accumulated output text after every ``update_interval`` tokens
    so the textbox updates incrementally.

    Args:
        prompt: seed text to continue.
        out_tokens: total number of tokens to generate (non-positive -> prompt only).
        top_k_value: top-k sampling cutoff passed to ``model.generate``.
        temperature: sampling temperature passed to ``model.generate``.
    """
    model.to(device)
    outputs = tokenizer.encode(prompt)
    # Clamp at 0: the original `while tokens_remaining:` looped forever on a
    # negative token count, since the counter never reaches exactly zero.
    tokens_remaining = max(int(out_tokens), 0)
    out_text = prompt
    yield out_text
    while tokens_remaining > 0:
        # Generate at most update_interval tokens per UI refresh.
        step = min(update_interval, tokens_remaining)
        # NOTE(review): total length len(outputs)+step is not capped at
        # max_seq_len here — presumably model.generate handles/ truncates
        # long contexts; confirm against GPTModel.generate.
        outputs = model.generate(outputs, len(outputs) + step,
                                 temperature, top_k=int(top_k_value))
        tokens_remaining -= step
        out_text += tokenizer.decode(outputs[-step:])
        yield out_text
# Create the model, tokenizer, and device once at import time so every
# request reuses the same loaded weights.
model, tokenizer, device = load_model()

# Gradio UI: prompt plus sampling controls in, streamed text out.
# (Fixes a stray trailing '|' artifact that made the launch line invalid.)
ui_inputs = [
    gr.Textbox(label='Prompt', lines=3),
    gr.Number(label='Output Tokens', value=180),
    gr.Slider(1, 100, step=5, value=60, label='Top-k Value'),
    gr.Slider(0.1, 2.0, step=0.05, value=0.9, label='Temperature'),
]
ui_output = gr.Textbox(label='Output', lines=15, max_lines=15)

app = gr.Interface(generate, ui_inputs, ui_output)
# queue() is required so the generator-based `generate` can stream updates.
app.queue().launch(ssr_mode=False, share=True)