Spaces:
Sleeping
Sleeping
# Import and class names setup | |
import gradio as gr | |
import os | |
import torch | |
import random | |
#import nltk_u | |
import pandas as pd | |
from sklearn.model_selection import train_test_split | |
import time | |
#from model import RNN_model | |
from timeit import default_timer as timer | |
from typing import Tuple, Dict | |
################################################################################ | |
import argparse | |
import numpy as np | |
import pprint | |
import os | |
import copy | |
from str2bool import str2bool | |
from typing import Dict, Sequence | |
from sentence_transformers import SentenceTransformer | |
import torch | |
import json | |
import transformers | |
from modeling_phi import PhiForCausalLM | |
from tokenization_codegen import CodeGenTokenizer | |
################################################################################ | |
def _build_arg_parser():
    """Construct the command-line parser for this app.

    Flag names, types, defaults and help strings are part of the app's CLI
    surface; ``--help`` lists them in the order they are added here.
    """
    cli = argparse.ArgumentParser()

    # --- model / embedder selection --------------------------------------
    cli.add_argument('--device_id', type=str, default="0")
    # Original note suggested microsoft/phi-1.5 as an alternative checkpoint.
    cli.add_argument('--model', type=str, default="microsoft/phi-2", help="")
    # Original note suggested BAAI/bge-m3 as an alternative embedder.
    cli.add_argument('--embedder', type=str, default="BAAI/bge-small-en-v1.5")

    # --- run / output options ---------------------------------------------
    # NOTE(review): this default is an absolute path on one developer's machine.
    cli.add_argument(
        '--output_path',
        type=str,
        default="/home/henry/Desktop/HKU-DASC7606-A2/Outputs/ARC-Challenge-test",
        help="",
    )
    cli.add_argument('--start_index', type=int, default=0, help="")
    cli.add_argument('--end_index', type=int, default=9999, help="")
    cli.add_argument('--N', type=int, default=8, help="")
    cli.add_argument('--max_len', type=int, default=1024, help="")
    cli.add_argument('--prompt_type', type=str, default="v2.0", help="")
    cli.add_argument('--top_k', type=str2bool, default=True, help="")
    return cli


parser = _build_arg_parser()
args = parser.parse_args()
# Choose the compute device once at import time: CUDA when torch can see a
# GPU, otherwise fall back to the CPU. Only the CUDA case is announced.
device = "cuda" if torch.cuda.is_available() else "cpu"
if device == "cuda":
    print(f'################################################################# device: {device}#################################################################')
def get_model(base_model: str = "bigcode/starcoder"):
    """Load a tokenizer / causal-LM pair from ``base_model``.

    The tokenizer is a ``CodeGenTokenizer`` and the model a ``PhiForCausalLM``
    (both imported from local modules). The EOS token is reused as the padding
    token on both the tokenizer and the model config. Weights are loaded with
    ``device_map="auto"`` so placement is delegated to transformers, and the
    model is switched to eval mode before being returned.

    NOTE(review): the default ``"bigcode/starcoder"`` does not look like a
    Phi/CodeGen checkpoint; callers in this file pass ``args.model``
    explicitly, so confirm before relying on the default.

    Args:
        base_model: Hub repo id or local path of the checkpoint.

    Returns:
        ``(tokenizer, model)``.
    """
    tok = CodeGenTokenizer.from_pretrained(base_model)

    # No dedicated pad token is configured, so padding falls back to EOS.
    eos_id, eos_text = tok.eos_token_id, tok.eos_token
    tok.pad_token_id = eos_id
    tok.pad_token = eos_text

    lm = PhiForCausalLM.from_pretrained(base_model, device_map="auto")
    # Keep the model config in sync with the tokenizer's padding choice.
    lm.config.pad_token_id = tok.pad_token_id
    lm.eval()
    return tok, lm
################################################################################
# Disabled data pipeline, kept for reference (previously a no-op bare string).
# NOTE(review): nothing in the visible code uses this dataset; re-enable or
# delete once the disease-prediction feature is decided.
#
# # Import data
# df = pd.read_csv('Symptom2Disease.csv')
# df.drop('Unnamed: 0', axis=1, inplace=True)
# # Preprocess data
# df.drop_duplicates(inplace=True)
# train_data, test_data = train_test_split(df, test_size=0.15, random_state=42)

# HTML shown in the "how to use" accordion of the web UI.
# NOTE(review): the text advertises disease prediction, but no classifier is
# wired up in the visible code; confirm the copy matches the app's behaviour.
howto = """Welcome to the <b>Medical Chatbot</b>, powered by Gradio.
Currently, the chatbot can WELCOME YOU, PREDICT DISEASE based on your symptoms and SUGGEST POSSIBLE SOLUTIONS AND RECOMMENDATIONS, and BID YOU FAREWELL.
<b>How to Start:</b> Simply type your messages in the textbox to chat with the Chatbot and press enter!<br><br>
The bot will respond based on the best possible answers to your messages.
"""
# Load the language model ONCE at start-up. The previous handler called
# get_model() (and built a SentenceTransformer) inside respond(), i.e. it
# re-read the full checkpoint from disk on every chat message.
tokenizer, model = get_model(base_model=args.model)
# When a prompt exceeds --max_len tokens, drop tokens from the left so the
# trailing "Output:" cue that starts the answer is always kept.
tokenizer.truncation_side = "left"

# Upper bound on the number of tokens generated per reply.
MAX_NEW_TOKENS = 256

# Create the gradio demo
with gr.Blocks(css="""#col_container { margin-left: auto; margin-right: auto;} #chatbot {height: 520px; overflow: auto;}""") as demo:
    gr.HTML('<h1 align="center">Medical Chatbot: ARIN 7102</h1>')
    with gr.Accordion("Follow these Steps to use the Gradio WebUI", open=True):
        gr.HTML(howto)
    # elem_id is what the "#chatbot" CSS rule above targets; without it the
    # height/overflow styling was never applied.
    chatbot = gr.Chatbot(elem_id="chatbot")
    msg = gr.Textbox()
    clear = gr.ClearButton([msg, chatbot])

    def respond(message, chat_history):
        """Generate a model reply to ``message`` and append it to the history.

        The previous version passed a sentence-embedding vector straight into
        the causal LM's forward pass and stored the raw model output object as
        the reply. That is not a valid input for a causal LM and did not yield
        text. Here the message is tokenized, a continuation is generated, and
        only the newly generated tokens are decoded.

        The SentenceTransformer embedder (``args.embedder``) is no longer
        built: the old code encoded the message but never used the result for
        retrieval, and there is no corpus in this file to retrieve from.

        Args:
            message: Text typed by the user.
            chat_history: List of ``(user, bot)`` pairs held by the Chatbot
                component; may be empty/None after the chat is cleared.

        Returns:
            ``("", history)``: clears the textbox and updates the chat.
        """
        history = list(chat_history or [])
        if not message or not message.strip():
            # Nothing to answer; leave the conversation unchanged.
            return "", history

        # Phi-2's documented QA prompt format ("Instruct: ...\nOutput:").
        prompt = f"Instruct: {message.strip()}\nOutput:"
        # NOTE(review): assumes --max_len is intended as the prompt-length cap.
        inputs = tokenizer(
            prompt,
            return_tensors="pt",
            truncation=True,
            max_length=args.max_len,
        ).to(model.device)
        output_ids = model.generate(
            **inputs,
            max_new_tokens=MAX_NEW_TOKENS,
            pad_token_id=tokenizer.pad_token_id,
        )
        # generate() returns prompt + continuation; keep only the continuation.
        prompt_len = inputs["input_ids"].shape[1]
        bot_message = tokenizer.decode(
            output_ids[0, prompt_len:], skip_special_tokens=True
        ).strip()

        history.append((message, bot_message))
        return "", history

    msg.submit(respond, [msg, chatbot], [msg, chatbot])
# Launch the demo only when this file is run as a script (e.g. `python app.py`),
# so importing the module does not start a web server as a side effect.
if __name__ == "__main__":
    demo.launch()