"""Gradio demo that samples text from a character-level GPT checkpoint.

Loads model weights from ``model/ckpt.pt`` and the character vocabulary from
``model/meta.pkl``, then serves a small web UI for prompted generation.
"""
import os
import pickle

import gradio as gr
import torch

from model import GPT, GPTConfig

ckpt_path = 'model/ckpt.pt'
meta_path = 'model/meta.pkl'
seed = 1337
device = 'cpu'
torch.manual_seed(seed)

# Load the model and meta data.
# NOTE(review): torch.load and pickle.load deserialize pickled objects and can
# execute arbitrary code; only point these paths at files you trust.
checkpoint = torch.load(ckpt_path, map_location=device)
gptconf = GPTConfig(**checkpoint['model_args'])
model = GPT(gptconf)
state_dict = checkpoint['model']
# Some checkpoints prefix every key with '_orig_mod.' (presumably saved from a
# torch.compile'd module); strip it so keys match the plain GPT module.
unwanted_prefix = '_orig_mod.'
for k in list(state_dict):  # copy keys: the dict is mutated inside the loop
    if k.startswith(unwanted_prefix):
        state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
model.load_state_dict(state_dict)
model.eval()
model.to(device)

with open(meta_path, 'rb') as f:
    meta = pickle.load(f)
stoi, itos = meta['stoi'], meta['itos']


def encode(s):
    """Map a string to a list of token ids, one id per character."""
    return [stoi[c] for c in s]


def decode(l):
    """Map a list of token ids back to the corresponding string."""
    return ''.join(itos[i] for i in l)


# Define the function for generating text
def generate_text(start, temperature, max_new_tokens):
    """Sample text from the model, conditioned on ``start``.

    Args:
        start: prompt text; every character must appear in the vocabulary
            loaded from ``meta_path``.
        temperature: sampling temperature forwarded to ``model.generate``.
        max_new_tokens: number of tokens to sample. Gradio sliders may deliver
            this as a float, so it is converted to ``int`` before use.

    Returns:
        The decoded model output (presumably prompt + continuation, if
        ``GPT.generate`` returns the full sequence -- confirm).

    Raises:
        gr.Error: if the prompt is empty or contains characters that are not
            in the vocabulary (shown to the user instead of a KeyError crash).
    """
    if not start:
        # An empty prompt would give the model a zero-length context.
        raise gr.Error("Please enter a starting prompt.")
    unknown = sorted({c for c in start if c not in stoi})
    if unknown:
        raise gr.Error(
            f"Unsupported characters in prompt: {''.join(unknown)!r}"
        )

    # Shape (1, len(start)): add a batch dimension of size 1.
    x = torch.tensor(encode(start), dtype=torch.long, device=device)[None, ...]
    with torch.no_grad():
        y = model.generate(x, int(max_new_tokens), temperature=temperature)
    return decode(y[0].tolist())


# Create a Gradio interface with sliders
examples = [['sport', 0.7, 200], ['lord', 1.2, 300]]
iface = gr.Interface(
    fn=generate_text,
    inputs=[
        gr.Textbox(label="Starting Prompt"),
        gr.Slider(minimum=0.1, maximum=4, step=0.1, label="Temperature"),
        gr.Slider(minimum=100, maximum=1000, step=50, label="Max New Tokens"),
    ],
    outputs=gr.Textbox(label="Generated Text"),
    examples=examples,
)

if __name__ == "__main__":
    # Only start the (blocking) web server when run as a script, not on import.
    iface.launch()