Spaces:
Runtime error
Runtime error
import gradio as gr | |
import torch | |
from timeit import default_timer as timer | |
from model import create_GPT_model | |
from utils import prepare_vocab | |
def main():
    """Load the pretrained GPT chat model and serve it through a Gradio interface.

    Builds the vocabulary helpers with ``prepare_vocab``, creates the model
    with ``create_GPT_model``, loads the weights stored in
    ``Pretrained_GPT_med_bot.pth`` and launches a ``gr.Interface`` whose
    callback generates an answer for the submitted question.
    """
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    vocab_size, encode, decode = prepare_vocab()

    model = create_GPT_model(vocab_size=vocab_size, device=device)
    # NOTE(review): torch.load unpickles the checkpoint; only load files that
    # come from a trusted source.
    model.load_state_dict(torch.load(
        f="Pretrained_GPT_med_bot.pth",
        map_location=device))
    # The model is only ever used for inference here, so switch to eval mode
    # once instead of on every request.
    model.eval()

    def predict(question: str):
        """Generate an answer for ``question``.

        Returns:
            A tuple ``(answer, pred_time)``: the decoded continuation produced
            by the model and the elapsed wall-clock time in seconds, rounded
            to 5 decimals.
        """
        start = timer()
        token_ids = encode(question)
        if len(token_ids) == 0:
            # Nothing to condition on; a zero-length prompt would be passed
            # straight to model.generate, so answer with an empty string.
            return "", round(timer() - start, 5)

        prompt = torch.tensor(token_ids,
                              dtype=torch.long,
                              device=device)
        with torch.inference_mode():
            response = model.generate(prompt.unsqueeze(0),
                                      max_new_tokens=200)[0].tolist()

        # Assumes generate() returns the prompt tokens followed by the new
        # tokens (the previous code relied on the same layout). Drop the
        # prompt by token count rather than by character count so the cut is
        # right even when a token does not decode to exactly one character.
        answer = decode(response[len(token_ids):])
        pred_time = round(timer() - start, 5)
        return answer, pred_time

    title = "Med Chat Bot"
    example_list = [
        "What are the common symptoms of the flu?",
        "How can I prevent catching a cold?",
        "What lifestyle changes can I make to improve my heart health?",
        "Is it necessary to get vaccinated every year?"
    ]

    demo = gr.Interface(fn=predict,
                        inputs=gr.Text(),
                        outputs=[gr.Text(label="Answer"),
                                 gr.Number(label="Prediction time (s)")],
                        examples=example_list,
                        title=title)
    demo.launch()
# Launch the demo only when this file is run as a script, not when imported.
if __name__ == "__main__":
    main()