# Spaces: Sleeping
# (Page-header text captured from the hosting site; not part of the program.)
import hmac
import json
import os

import requests
import streamlit as st
from dotenv import load_dotenv

load_dotenv()
# Page config is issued before any other Streamlit command in this script.
st.set_page_config(
    page_title="Chat with Einstein LLMs!",
    page_icon=":brain:",
    layout="wide",
)
# Seed the authentication flag on first load of a session; later reruns of
# the script keep whatever value the password callback stored.
if "authenticated" not in st.session_state:
    st.session_state.authenticated = False
def check_password():
    """Render the password prompt and record the result in session state.

    The expected password is read from the ``PASSWORD`` environment variable.
    When the user submits the text input, the nested callback compares the
    entry with the expected value and sets
    ``st.session_state.authenticated`` accordingly.

    Returns:
        bool: The current value of ``st.session_state.authenticated``.
    """
    expected_password = os.getenv("PASSWORD")

    def password_entered():
        """Validate the submitted password and update the auth flag."""
        entered = st.session_state.get("password") or ""
        # An unset or empty PASSWORD must never authenticate anyone, and the
        # comparison is constant-time so it does not leak how many leading
        # characters matched. Bytes are used because compare_digest only
        # accepts ASCII-only str arguments.
        is_match = bool(expected_password) and hmac.compare_digest(
            entered.encode("utf-8"),
            expected_password.encode("utf-8"),
        )
        if is_match:
            st.session_state.authenticated = True
            # Do not keep the plaintext password in session state.
            del st.session_state["password"]
        else:
            st.session_state.authenticated = False
            # "\U0001F615" is the confused-face emoji; written as an escape so
            # the literal survives encoding round-trips.
            st.error("\U0001F615 Password incorrect")

    # Password field; the callback above runs when its value changes.
    st.text_input(
        "Please enter the password to access the Einstein Assistant",
        type="password",
        on_change=password_entered,
        key="password",
    )
    return st.session_state.authenticated
# Show the chat interface only once the session is authenticated; otherwise
# show the login page.
if st.session_state.authenticated:
    client_id = os.getenv("CLIENT_ID")
    client_secret = os.getenv("CLIENT_SECRET")
    base_url = os.getenv("BASE_URL")

    # Upper bound (seconds) for each outbound HTTP call so a stalled server
    # cannot hang the Streamlit script run indefinitely.
    REQUEST_TIMEOUT_SECONDS = 60

    # Robot-face emoji used as the assistant avatar and in the page heading;
    # written as an escape so the literal survives encoding round-trips.
    ASSISTANT_AVATAR = "\U0001F916"

    def get_access_token():
        """Fetch an OAuth access token using the client-credentials grant.

        Posts ``client_id`` / ``client_secret`` to
        ``<base_url>/services/oauth2/token``.

        Returns:
            str | None: The ``access_token`` value from the response, or
            ``None`` (after showing an error in the UI) when configuration is
            missing, the request fails, or the response carries no token.
        """
        if not (client_id and client_secret and base_url):
            st.error(
                "Error fetching access token: CLIENT_ID, CLIENT_SECRET and "
                "BASE_URL must all be configured"
            )
            return None

        url = base_url + "/services/oauth2/token"
        payload = {
            "grant_type": "client_credentials",
            "client_id": client_id,
            "client_secret": client_secret,
        }
        try:
            response = requests.post(
                url, data=payload, timeout=REQUEST_TIMEOUT_SECONDS
            )
        except requests.exceptions.RequestException as e:
            st.error(f"Error fetching access token: {str(e)}")
            return None

        if response.status_code != 200:
            st.error(
                f"Error fetching access token: {response.status_code} - {response.text}"
            )
            return None

        try:
            access_token = response.json().get("access_token")
        except (ValueError, AttributeError):
            # Body was not JSON, or was JSON but not an object.
            access_token = None
        if not access_token:
            st.error("Error fetching access token: no access_token in response")
            return None
        return access_token

    # Display name -> model API name used in the request URL.
    MODEL_OPTIONS = {
        "GPT4-Omni": "sfdc_ai__DefaultOpenAIGPT4Omni",
        "Gemini": "sfdc_ai__DefaultVertexAIGemini20Flash001",
        "Claude": "sfdc_ai__DefaultBedrockAnthropicClaude37Sonnet",
    }

    # Sidebar with model selection.
    with st.sidebar:
        st.title("Model Settings")
        selected_model_name = st.selectbox(
            "Choose AI Model",
            options=list(MODEL_OPTIONS.keys()),
            index=0,
        )
        model = MODEL_OPTIONS[selected_model_name]

    # Heading reflects the currently selected model.
    st.subheader(f"{ASSISTANT_AVATAR} Chat with {selected_model_name}")

    def get_gpt_response(prompt):
        """Send the conversation to the selected model and return its reply.

        Args:
            prompt: List of message dicts, each with at least ``"role"`` and
                ``"content"`` keys (the chat history kept in session state).

        Returns:
            str: The first generation's content, or an apology message when
            the token, the request, or response parsing fails (the error is
            also shown in the UI).
        """
        access_token = get_access_token()
        if access_token is None:
            # get_access_token has already reported the cause in the UI.
            return "I apologize, but I encountered an error. Please try again."

        url = f"https://api.salesforce.com/einstein/platform/v1/models/{model}/chat-generations"
        headers = {
            "Authorization": f"Bearer {access_token}",
            "Content-Type": "application/json;charset=utf-8",
            "x-sfdc-app-context": "EinsteinGPT",
            "x-client-feature-id": "ai-platform-models-connected-app",
        }
        # Send only role/content: history entries for uploads also carry an
        # "image" value holding the uploaded file object, which json.dumps
        # cannot serialise.
        chat_payload = {
            "messages": [
                {"role": entry["role"], "content": entry["content"]}
                for entry in prompt
            ]
        }
        try:
            response = requests.post(
                url,
                headers=headers,
                data=json.dumps(chat_payload),
                timeout=REQUEST_TIMEOUT_SECONDS,
            )
            response.raise_for_status()  # Raise for 4xx/5xx status codes
            data = response.json()
            return data["generationDetails"]["generations"][0]["content"]
        except requests.exceptions.RequestException as e:
            st.error(f"Error calling the API: {str(e)}")
            return "I apologize, but I encountered an error. Please try again."
        except (KeyError, IndexError, TypeError, ValueError) as e:
            st.error(f"Error parsing response: {str(e)}")
            return "I apologize, but I received an invalid response. Please try again."

    # Start the conversation with a greeting on first load.
    if "messages" not in st.session_state:
        st.session_state["messages"] = [
            {"role": "assistant", "content": "How can I help you?"}
        ]

    # Replay the stored history. Only assistant messages get the robot
    # avatar, matching how new messages are rendered below.
    for msg in st.session_state.messages:
        avatar = ASSISTANT_AVATAR if msg["role"] == "assistant" else None
        st.chat_message(msg["role"], avatar=avatar).write(msg["content"])
        if "image" in msg:
            st.image(msg["image"])

    prompt = st.chat_input(
        "Say something and/or attach an image",
        accept_file=True,
        file_type=["jpg", "jpeg", "png"],
    )

    if prompt:
        # Handle text input
        if prompt.text:
            st.session_state.messages.append(
                {"role": "user", "content": prompt.text}
            )
            st.chat_message("user").write(prompt.text)

        # Handle image upload (only the first attached file is used)
        if prompt.get("files"):
            uploaded_file = prompt["files"][0]
            st.session_state.messages.append({
                "role": "user",
                "content": "Uploaded an image",
                "image": uploaded_file,
            })
            st.chat_message("user").write("Uploaded an image")
            st.image(uploaded_file)

        # Get AI response if there's any input
        if prompt.text or prompt.get("files"):
            msg = get_gpt_response(st.session_state.messages)
            st.session_state.messages.append(
                {"role": "assistant", "content": msg}
            )
            st.chat_message("assistant", avatar=ASSISTANT_AVATAR).write(msg)
else:
    # Show login page
    st.title("Welcome to Einstein Assistant")
    check_password()