Med_chat_bot / app.py
Matthev00's picture
improved app
914b0b5
raw
history blame contribute delete
No virus
1.65 kB
import gradio as gr
import torch
from timeit import default_timer as timer
from model import create_GPT_model
from utils import prepare_vocab
def main():
    """Load the pretrained GPT model and serve it through a Gradio interface.

    Selects ``cuda:0`` when CUDA is available (CPU otherwise), obtains the
    vocabulary size and the ``encode``/``decode`` helpers from
    ``prepare_vocab``, builds the model with ``create_GPT_model``, restores
    its weights from ``Pretrained_GPT_med_bot.pth`` and launches a
    ``gr.Interface`` whose handler is the nested ``predict`` function.
    """
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    vocab_size, encode, decode = prepare_vocab()

    model = create_GPT_model(vocab_size=vocab_size, device=device)
    # NOTE(review): torch.load unpickles the checkpoint; only load files from
    # a trusted source.
    model.load_state_dict(torch.load(
        f="Pretrained_GPT_med_bot.pth",
        map_location=device))  # `device` is already a torch.device
    # Switch to evaluation mode once here instead of on every request.
    model.eval()

    def predict(question: str):
        """Generate an answer for ``question`` and time the prediction.

        Args:
            question: Text typed by the user.

        Returns:
            A ``(answer, pred_time)`` tuple: the decoded newly generated
            tokens and the elapsed wall-clock seconds rounded to 5 decimals.
        """
        start = timer()

        token_ids = encode(question or "")
        if len(token_ids) == 0:
            # Nothing to condition on: skip generation rather than feeding
            # the model a zero-length sequence.
            return "Please enter a question.", round(timer() - start, 5)

        prompt = torch.tensor(token_ids,
                              dtype=torch.long,
                              device=device)
        with torch.inference_mode():
            generated = model.generate(prompt.unsqueeze(0),
                                       max_new_tokens=100)[0].tolist()

        # Drop the prompt by token count rather than by character count so
        # the slice stays correct even if `encode` does not map one character
        # to exactly one token.
        # Assumes `generate` returns the prompt tokens followed by the new
        # ones (the original text-level slice relied on the same layout).
        answer = decode(generated[len(token_ids):])

        pred_time = round(timer() - start, 5)
        return answer, pred_time

    title = "Med Chat Bot"
    example_list = [
        "What are the common symptoms of the flu?",
        "How can I prevent catching a cold?",
        "What lifestyle changes can I make to improve my heart health?",
        "Is it necessary to get vaccinated every year?",
    ]

    demo = gr.Interface(fn=predict,
                        inputs=gr.Text(),
                        outputs=[gr.Text(label="Answer"),
                                 gr.Number(label="Prediction time (s)")],
                        examples=example_list,
                        title=title)
    demo.launch()
# Start the app only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()