ashok2216 committed on
Commit 13d9684 · verified · 1 Parent(s): 7ce4952

Update app.py

Files changed (1)
  1. app.py +98 -75
app.py CHANGED
@@ -1,86 +1,109 @@
 import streamlit as st
-from huggingface_hub import InferenceClient
-
-# Initialize the InferenceClient
-client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
 
 # Streamlit app configuration
-st.set_page_config(page_title="Health Care ChatBot")
-st.title("Health Care ChatBot")
-
-# Initialize session state for messages if not present
 if 'messages' not in st.session_state:
     st.session_state.messages = [
-        {"role": "system", "content": "You are a knowledgeable and empathetic medical assistant providing accurate and compassionate health advice based on user input."}
     ]
 
-def respond(message, history, max_tokens, temperature, top_p):
-    # Prepare the list of messages for the chat completion
-    messages = [{"role": "system", "content": st.session_state.messages[0]["content"]}]
-
-    for val in history:
-        if val["role"] == "user":
-            messages.append({"role": "user", "content": val["content"]})
-        elif val["role"] == "assistant":
-            messages.append({"role": "assistant", "content": val["content"]})
-
-    messages.append({"role": "user", "content": message})
-
-    # Generate response
-    response = ""
-    response_container = st.empty()  # Placeholder to update the response text dynamically
-
-    for message in client.chat_completion(
-        messages,
-        max_tokens=max_tokens,
-        stream=True,
-        temperature=temperature,
-        top_p=top_p,
-    ):
-        token = message.choices[0].delta.content
-        response += token
-        response_container.text(response)  # Stream the response
-
-    return response
-
-# Sidebar for parameters
-with st.sidebar:
-    max_tokens = st.slider("Max new tokens", 1, 2048, 512)
-    temperature = st.slider("Temperature", 0.1, 4.0, 0.7)
-    top_p = st.slider("Top-p (nucleus sampling)", 0.1, 1.0, 0.95)
-
-# Display chat messages from history
 for message in st.session_state.messages:
-    if message["role"] == "user":
-        # User message on the right
-        col1, col2 = st.columns([1, 4])
-        with col2:
-            with st.chat_message("user"):
-                st.write(message["content"])
-        with col1:
-            st.write("")  # Empty space on the left for alignment
-
-    elif message["role"] == "assistant":
-        # Assistant message on the left
-        col1, col2 = st.columns([4, 1])
-        with col1:
-            with st.chat_message("assistant"):
-                st.write(message["content"])
-        with col2:
-            st.write("")  # Empty space on the right for alignment
-
-# Keep user input box at the bottom
-st.divider()  # Optional, to visually separate chat history from input box
-user_input = st.text_input("You:", key="user_message", placeholder="Type your message here...")
-
-if user_input:
-    # Append user message to the chat history
     st.session_state.messages.append({"role": "user", "content": user_input})
 
-    # Generate assistant response
-    response = respond(user_input, st.session_state.messages, max_tokens, temperature, top_p)
     st.session_state.messages.append({"role": "assistant", "content": response})
-
-    # Refresh to display new messages
-    st.experimental_rerun()
-
+import os
+import torch
 import streamlit as st
+from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
+from huggingface_hub import login
 
 # Streamlit app configuration
+st.set_page_config(page_title="Medical Chatbot", layout="wide")
+st.title("Medical Chatbot")
+
+# Get the token from the environment variable
+hf_token = os.getenv("HF_TOKEN")
+if hf_token is None:
+    raise ValueError("Hugging Face token not found. Please set the HF_TOKEN environment variable.")
+
+# Authenticate with Hugging Face
+login(hf_token)
+
+# Set the random seed for reproducibility
+torch.manual_seed(0)
+
+# Supported models
+model_links = {
+    "Phi-3-mini-128k-instruct": "microsoft/Phi-3-mini-128k-instruct",
+    "Mistral-7B": "mistralai/Mistral-7B-Instruct-v0.2",
+    "Zephyr-7B": "HuggingFaceH4/zephyr-7b-beta"
+}
+
+model_info = {
+    "Phi-3-mini-128k-instruct": {
+        'description': """Phi-3-mini-128k-instruct is a large language model from Microsoft for health-related interactions.
+        It has been optimized for instruct-based queries.""",
+        'logo': 'https://upload.wikimedia.org/wikipedia/commons/thumb/a/a6/Microsoft_logo_%282012%29.svg/200px-Microsoft_logo_%282012%29.svg.png'
+    },
+    "Mistral-7B": {
+        'description': """Mistral 7B is a large language model from Mistral AI optimized for Q&A tasks.""",
+        'logo': 'https://mistral.ai/images/logo_hubc88c4ece131b91c7cb753f40e9e1cc5_2589_256x0_resize_q97_h2_lanczos_3.webp'
+    },
+    "Zephyr-7B": {
+        'description': """Zephyr 7B is a Hugging Face model, fine-tuned for helpful and instructive interactions.""",
+        'logo': 'https://huggingface.co/HuggingFaceH4/zephyr-7b-gemma-v0.1/resolve/main/thumbnail.png'
+    }
+}
+
+# Sidebar for model selection and parameters
+selected_model = st.sidebar.selectbox("Select Model", model_links.keys())
+st.sidebar.write(f"You're now chatting with **{selected_model}**")
+st.sidebar.markdown(model_info[selected_model]['description'])
+st.sidebar.image(model_info[selected_model]['logo'])
+
+# Temperature slider
+temperature = st.sidebar.slider('Temperature', 0.1, 1.0, 0.7)
+
+# Reset conversation button
+def reset_conversation():
+    st.session_state.messages = []
+    st.session_state.model = selected_model
+
+st.sidebar.button('Reset Chat', on_click=reset_conversation)
+
+# Load model and tokenizer only if it's not already loaded
+if 'model' not in st.session_state or st.session_state.model != selected_model:
+    model = AutoModelForCausalLM.from_pretrained(
+        model_links[selected_model],
+        device_map="auto",
+        torch_dtype="auto",
+        trust_remote_code=True
+    )
+    tokenizer = AutoTokenizer.from_pretrained(model_links[selected_model])
+
+    # Initialize the text generation pipeline
+    pipe = pipeline(
+        "text-generation",
+        model=model,
+        tokenizer=tokenizer
+    )
+
+    st.session_state.model = selected_model
+    st.session_state.pipe = pipe
+
+# Initialize chat messages
 if 'messages' not in st.session_state:
     st.session_state.messages = [
+        {"role": "system", "content": "You are a medical chatbot. You should only respond to health questions!"}
     ]
 
+# Display chat history
 for message in st.session_state.messages:
+    with st.chat_message(message["role"]):
+        st.markdown(message["content"])
+
+# Function to generate responses
+def generate_response(messages):
+    messages_str = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages if msg['role'] != 'system'])
+    output = st.session_state.pipe(messages_str, max_new_tokens=150, temperature=temperature, return_full_text=False)
+    return output[0]['generated_text']
+
+# User input
+if user_input := st.chat_input("Ask a health question..."):
+    # Display user message
+    with st.chat_message("user"):
+        st.markdown(user_input)
     st.session_state.messages.append({"role": "user", "content": user_input})
 
+    # Generate and display assistant response
+    response = generate_response(st.session_state.messages)
+    with st.chat_message("assistant"):
+        st.markdown(response)
     st.session_state.messages.append({"role": "assistant", "content": response})