File size: 4,184 Bytes
5e29726
2263c30
5e29726
 
2263c30
 
 
 
 
 
 
 
5e29726
 
 
2263c30
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5e29726
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2263c30
 
5e29726
 
2263c30
5e29726
2263c30
5e29726
 
 
2263c30
5e29726
 
 
2263c30
5e29726
 
 
 
 
 
2263c30
5e29726
 
 
2263c30
5e29726
 
 
 
 
 
 
 
 
2263c30
5e29726
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
import markdown
import streamlit as st
from streamlit_chat import message
from streamlit_extras.colored_header import colored_header
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig

model_id = "Narrativaai/BioGPT-Large-finetuned-chatdoctor"
tokenizer = AutoTokenizer.from_pretrained("microsoft/BioGPT-Large")
model = AutoModelForCausalLM.from_pretrained(model_id)


def answer_question(
    prompt, temperature=0.1, top_p=0.75, top_k=40, num_beams=2, **kwargs
):
    inputs = tokenizer(prompt, return_tensors="pt")
    input_ids = inputs["input_ids"].to("cpu")
    attention_mask = inputs["attention_mask"].to("cpu")
    generation_config = GenerationConfig(
        temperature=temperature, top_p=top_p, top_k=top_k, num_beams=num_beams, **kwargs
    )
    with torch.no_grad():
        generation_output = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            generation_config=generation_config,
            return_dict_in_generate=True,
            output_scores=True,
            max_new_tokens=512,
            eos_token_id=tokenizer.eos_token_id,
        )
    s = generation_output.sequences[0]
    output = tokenizer.decode(s, skip_special_tokens=True)
    return output.split(" Response:")[1]


st.set_page_config(page_title="Talk To Me", page_icon=":ambulance:", layout="wide")

colored_header(
    label="Talk To Me",
    description="Talk your way to better health",
    color_name="violet-70",
)

# st.title("Talk To Me")
# st.caption("Talk your way to better health")

# add sidebar
with open("./sidebar.md", "r") as sidebar_file:
    sidebar_content = sidebar_file.read()

with open("./styles.md", "r") as styles_file:
    styles_content = styles_file.read()


def add_sbg_from_url():
    st.markdown(
        f"""
         <style>
         .css-6qob1r {{
             background-image: url("https://images.unsplash.com/photo-1524169358666-79f22534bc6e?ixlib=rb-4.0.3&ixid=MnwxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8&auto=format&fit=crop&w=3540&q=80");
             background-attachment: fixed;
             background-size: cover
         }}
         </style>
         """,
        unsafe_allow_html=True,
    )


add_sbg_from_url()


def add_mbg_from_url():
    st.markdown(
        f"""
         <style>
         .stApp {{
             background-image: url("https://images.unsplash.com/photo-1536353602887-521e965eb03f?ixlib=rb-4.0.3&ixid=MnwxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8&auto=format&fit=crop&w=987&q=80");
             background-attachment: fixed;
             background-size: cover
         }}
         </style>
         """,
        unsafe_allow_html=True,
    )


add_mbg_from_url()


# Display the sidebar content
st.sidebar.markdown(sidebar_content)

st.write(styles_content, unsafe_allow_html=True)

# Initialize session state
if "chat_history" not in st.session_state:
    st.session_state.chat_history = []

# display default message if no chat history
if not st.session_state.chat_history:
    message("Hi, I'm a medical chat bot. Ask me a question!")

# Display the chat history
for chat in st.session_state.chat_history:
    if chat["is_user"]:
        message(chat["message"], is_user=True)
    else:
        message(chat["message"])

with st.form("user_input_form"):
    st.write("Please enter your question below:")
    user_input = st.text_input("You: ")

    # Check if user has submitted a question
    if st.form_submit_button("Submit") and user_input:
        with st.spinner('Loading model and generating response...'):
        # Generate response and update chat history
            bot_response = answer_question(f"Input: {user_input}\nResponse:")
            st.session_state.chat_history.append({"message": user_input, "is_user": True})
            st.session_state.chat_history.append(
                {"message": bot_response, "is_user": False}
            )

# Display the latest chat in the chat history
if st.session_state.chat_history:
    latest_chat = st.session_state.chat_history[-1]
    if latest_chat["is_user"]:
        message(latest_chat["message"], is_user=True)
    else:
        message(latest_chat["message"])