# ARIN_7102 / app.py
# Import and class names setup
import gradio as gr
import os
import torch
import random
#import nltk_u
import pandas as pd
from sklearn.model_selection import train_test_split
import time
#from model import RNN_model
from timeit import default_timer as timer
from typing import Tuple, Dict
################################################################################
import argparse
import numpy as np
import pprint
import copy
from str2bool import str2bool
from typing import Dict, Sequence
from sentence_transformers import SentenceTransformer
import json
import transformers
from modeling_phi import PhiForCausalLM
from tokenization_codegen import CodeGenTokenizer
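# Note: modeling_phi and tokenization_codegen are assumed to be local modules
# (custom Phi model and CodeGen tokenizer files) shipped alongside app.py.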
################################################################################
parser = argparse.ArgumentParser()
#############################################################################################################################
parser.add_argument('--device_id', type=str, default="0")
parser.add_argument('--model', type=str, default="microsoft/phi-2", help="") ## /phi-1.5
parser.add_argument('--embedder', type=str, default="BAAI/bge-small-en-v1.5") ## /bge-small-en-v1.5 # bge-m3
parser.add_argument('--output_path', type=str, default="/home/henry/Desktop/HKU-DASC7606-A2/Outputs/ARC-Challenge-test", help="") ## -bge-m3
parser.add_argument('--start_index', type=int, default=0, help="")
parser.add_argument('--end_index', type=int, default=9999, help="")
parser.add_argument('--N', type=int, default=8, help="")
parser.add_argument('--max_len', type=int, default=1024, help="")
parser.add_argument('--prompt_type', type=str, default="v2.0", help="")
parser.add_argument('--top_k', type=str2bool, default=True, help="")
#############################################################################################################################
args = parser.parse_args()
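# Example launch (illustrative; every flag is optional and falls back to the defaults above):
#   python app.py --model microsoft/phi-2 --embedder BAAI/bge-small-en-v1.5 --max_len 1024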
if torch.cuda.is_available():
    device = "cuda"
    print(f'################################################################# device: {device} #################################################################')
else:
    device = "cpu"
def get_model(base_model: str = "bigcode/starcoder"):
    # Load the tokenizer and causal LM, reusing the EOS token for padding
    tokenizer = CodeGenTokenizer.from_pretrained(base_model)
    tokenizer.pad_token_id = tokenizer.eos_token_id
    tokenizer.pad_token = tokenizer.eos_token
    model = PhiForCausalLM.from_pretrained(
        base_model,
        device_map="auto",
    )
    model.config.pad_token_id = tokenizer.pad_token_id
    model.eval()
    return tokenizer, model
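# Illustrative usage of get_model (a sketch; the chat handler below follows the same
# tokenize -> generate -> decode pattern):
#   tokenizer, model = get_model(base_model=args.model)
#   ids = tokenizer("I have a headache and a sore throat.", return_tensors="pt").to(model.device)
#   out = model.generate(**ids, max_new_tokens=64, pad_token_id=tokenizer.pad_token_id)
#   print(tokenizer.decode(out[0], skip_special_tokens=True))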
################################################################################
'''
# Import data
df= pd.read_csv('Symptom2Disease.csv')
df.drop('Unnamed: 0', axis= 1, inplace= True)
# Preprocess data
df.drop_duplicates(inplace= True)
train_data, test_data= train_test_split(df, test_size=0.15, random_state=42 )
'''
howto= """Welcome to the <b>Medical Chatbot</b>, powered by Gradio.
Currently, the chatbot can WELCOME YOU, PREDICT DISEASE based on your symptoms and SUGGEST POSSIBLE SOLUTIONS AND RECOMENDATIONS, and BID YOU FAREWELL.
<b>How to Start:</b> Simply type your messages in the textbox to chat with the Chatbot and press enter!<br><br>
The bot will respond based on the best possible answers to your messages.
"""
# Create the gradio demo
with gr.Blocks(css="""#col_container { margin-left: auto; margin-right: auto;} #chatbot {height: 520px; overflow: auto;}""") as demo:
    gr.HTML('<h1 align="center">Medical Chatbot: ARIN 7102</h1>')
    #gr.HTML('<h3 align="center">To know more about this project')
    with gr.Accordion("Follow these Steps to use the Gradio WebUI", open=True):
        gr.HTML(howto)
    chatbot = gr.Chatbot()
    msg = gr.Textbox()
    clear = gr.ClearButton([msg, chatbot])

    # Load the sentence embedder, tokenizer, and model once when the demo is built,
    # rather than reloading them on every message
    embedder = SentenceTransformer(args.embedder, device=device)
    tokenizer, model = get_model(base_model=args.model)

    def respond(message, chat_history):
        # Embed the user message (kept for retrieving similar symptom descriptions;
        # not used during generation yet)
        message_embeddings = embedder.encode(message)
        # Generate a reply with the causal language model
        inputs = tokenizer(message, return_tensors="pt").to(model.device)
        with torch.no_grad():
            output_ids = model.generate(
                **inputs,
                max_new_tokens=args.max_len,
                pad_token_id=tokenizer.pad_token_id,
            )
        bot_message = tokenizer.decode(
            output_ids[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True
        )
        chat_history.append((message, bot_message))
        return "", chat_history

    msg.submit(respond, [msg, chatbot], [msg, chatbot])
# Launch the demo
demo.launch()