import gradio as gr
import torch
from timeit import default_timer as timer

from model import create_GPT_model
from utils import prepare_vocab


def main():
    """Launch the Med Chat Bot Gradio demo.

    Builds the vocabulary, constructs the GPT model, loads pretrained
    weights from ``Pretrained_GPT_med_bot.pth``, and serves a Gradio
    interface that answers medical questions.
    """
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    vocab_size, encode, decode = prepare_vocab()
    model = create_GPT_model(vocab_size=vocab_size, device=device)
    # NOTE(review): torch.load uses pickle under the hood; for a checkpoint
    # from an untrusted source, pass weights_only=True (PyTorch >= 1.13).
    # `device` is already a torch.device, so it is passed directly.
    model.load_state_dict(torch.load(
        f="Pretrained_GPT_med_bot.pth",
        map_location=device))

    def predict(question: str):
        """Generate an answer for *question* and time the generation.

        Returns:
            tuple[str, float]: the generated answer (prompt stripped off)
            and the prediction time in seconds, rounded to 5 decimals.
        """
        start = timer()
        in_len = len(question)

        prompt = torch.tensor(encode(question), dtype=torch.long,
                              device=device)

        model.eval()
        with torch.inference_mode():
            response = model.generate(prompt.unsqueeze(0),
                                      max_new_tokens=100)[0].tolist()

        # Slicing by character count assumes `decode` maps one token to one
        # character (char-level vocab) — TODO confirm against utils.prepare_vocab.
        answer = decode(response)[in_len:]

        pred_time = round(timer() - start, 5)
        return answer, pred_time

    title = "Med Chat Bot"
    example_list = [
        "What are the common symptoms of the flu?",
        "How can I prevent catching a cold?",
        "What lifestyle changes can I make to improve my heart health?",
        "Is it necessary to get vaccinated every year?",
    ]

    demo = gr.Interface(fn=predict,
                        inputs=gr.Text(),
                        outputs=[gr.Text(label="Answer"),
                                 gr.Number(label="Prediction time (s)")],
                        examples=example_list,
                        title=title)
    demo.launch()


if __name__ == "__main__":
    main()