|
|
import streamlit as st |
|
|
import requests |
|
|
import json |
|
|
import os |
|
|
import datetime |
|
|
|
|
|
|
|
|
# Inference endpoint that serves the DUBS model.
SPACE_URL = "https://z7svds7k42bwhhgm.us-east-1.aws.endpoints.huggingface.cloud"
# Server-side credential for the endpoint; None when the env var is unset.
HF_API_KEY = os.getenv("HF_API_KEY")
# Turn terminator used when serializing the conversation into a prompt.
EOS_TOKEN = "<|end|>"
# Directory holding one <session_name>.json file per saved conversation.
CHAT_HISTORY_DIR = "chat_histories"
IMAGE_PATH = "DubsChat.png"
IMAGE_PATH_2 = "Reboot AI.png"
Dubs_PATH = "Dubs.png"

# st.set_page_config must be the first Streamlit command of the script run;
# previously the st.error below could run first and raise a Streamlit API error.
st.set_page_config(page_title="DUBSChat", page_icon=IMAGE_PATH, layout="wide")

try:
    os.makedirs(CHAT_HISTORY_DIR, exist_ok=True)
except OSError as e:
    st.error(f"Failed to create chat history directory: {e}")

st.logo(IMAGE_PATH_2)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def save_chat_history(session_name, messages, directory=None):
    """
    Save the chat history to ``<directory>/<session_name>.json``.

    The messages are serialized before any file is touched and written through a
    temporary file that is then renamed over the target. A serialization error
    or an interrupted write therefore never truncates or corrupts a previously
    saved session.

    Args:
        session_name: Base name of the session file (without ``.json``).
        messages: JSON-serializable list of ``{"role", "content"}`` dicts.
        directory: Folder to save into; defaults to ``CHAT_HISTORY_DIR``.

    Errors are reported in the UI via ``st.error`` instead of being raised.
    """
    base_dir = CHAT_HISTORY_DIR if directory is None else directory
    file_path = os.path.join(base_dir, f"{session_name}.json")
    tmp_path = file_path + ".tmp"

    try:
        payload = json.dumps(messages)
    except (TypeError, ValueError) as e:
        st.error(f"Failed to save chat history: {e}")
        return

    try:
        with open(tmp_path, "w", encoding="utf-8") as f:
            f.write(payload)
        # os.replace swaps the file in one step (atomic on POSIX), so readers
        # see either the old history or the complete new one.
        os.replace(tmp_path, file_path)
    except OSError as e:
        try:
            os.remove(tmp_path)
        except OSError:
            pass  # best-effort cleanup; the original error is what matters
        st.error(f"Failed to save chat history: {e}")
|
|
|
|
|
|
|
|
def load_chat_history(file_name, directory=None):
    """
    Load the chat history from a JSON file.

    Args:
        file_name: File name (including ``.json``) inside the history folder.
        directory: Folder to read from; defaults to ``CHAT_HISTORY_DIR``.

    Returns:
        The stored list of message dicts. Returns ``[]`` after showing an error
        when the file cannot be read or decoded, or when it does not contain a
        list of dicts with ``role`` and ``content`` keys. The app indexes those
        keys when rendering and building prompts.
    """
    base_dir = CHAT_HISTORY_DIR if directory is None else directory
    file_path = os.path.join(base_dir, file_name)
    try:
        with open(file_path, "r", encoding="utf-8") as f:
            messages = json.load(f)
    except (OSError, ValueError):
        # OSError: missing/unreadable file (not just FileNotFoundError).
        # ValueError: covers json.JSONDecodeError and UnicodeDecodeError.
        st.error("Failed to load chat history. Starting with a new session.")
        return []

    well_formed = isinstance(messages, list) and all(
        isinstance(m, dict) and "role" in m and "content" in m for m in messages
    )
    if not well_formed:
        st.error("Failed to load chat history. Starting with a new session.")
        return []
    return messages
|
|
|
|
|
|
|
|
def get_saved_sessions(directory=None):
    """
    Get the list of saved chat sessions.

    Args:
        directory: Folder to scan; defaults to ``CHAT_HISTORY_DIR``.

    Returns:
        Session names (file names with only the trailing ``.json`` removed),
        sorted in descending order. For the ``%Y-%m-%d_%H-%M-%S`` names this
        app generates, that puts the newest session first. Returns ``[]`` if
        the folder does not exist, e.g. when creating it failed at startup.
    """
    suffix = ".json"
    base_dir = CHAT_HISTORY_DIR if directory is None else directory
    try:
        entries = os.listdir(base_dir)
    except FileNotFoundError:
        return []
    # Slice off only the suffix: str.replace would also strip ".json" occurring
    # mid-name, producing a session name that no longer maps to its file.
    names = [entry[: -len(suffix)] for entry in entries if entry.endswith(suffix)]
    return sorted(names, reverse=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _reset_conversation():
    """Replace the current conversation with a fresh one and save it at once.

    The new session is named after the current local time
    (``%Y-%m-%d_%H-%M-%S``), and a confirmation is shown in the active
    container.
    """
    st.session_state["messages"] = [
        {"role": "system", "content": "You are DUBS, a helpful assistant capable of conversing in a friendly and knowledgeable way."},
        {"role": "assistant", "content": "Hello! How can I assist you today?"},
    ]
    st.session_state["session_name"] = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    save_chat_history(st.session_state["session_name"], st.session_state["messages"])
    st.success("Chat reset and new session started.")


# Sidebar: new-chat control, access-key entry, and a picker for saved sessions.
with st.sidebar:
    if st.button("New Chat"):
        _reset_conversation()

    # Read later by the chat-input handler to gate sending prompts.
    dubs_key = st.text_input("Enter Dubs Key", key="chatbot_api_key", type="password")

    saved_sessions = get_saved_sessions()
    if not saved_sessions:
        st.write("No past sessions available.")
    else:
        selected_session = st.radio("Past Sessions:", saved_sessions)
        if st.button("Load Session"):
            st.session_state["messages"] = load_chat_history(f"{selected_session}.json")
            st.session_state["session_name"] = selected_session
            st.success(f"Loaded session: {selected_session}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _ensure_session_defaults():
    """Seed any missing session-state keys on the first run of a browser session.

    Each default is produced by a zero-argument factory. A factory is called
    only when its key is absent, so existing state is never overwritten.
    Keys are filled in the order listed.
    """
    defaults = {
        "messages": lambda: [
            {"role": "system", "content": "You are DUBS, a helpful assistant capable of conversing in a friendly and knowledgeable way."},
            {"role": "assistant", "content": "Hello! How can I assist you today?"},
        ],
        "session_name": lambda: datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S"),
    }
    for state_key, make_default in defaults.items():
        if state_key not in st.session_state:
            st.session_state[state_key] = make_default()


_ensure_session_defaults()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Page header.
st.image(IMAGE_PATH, width=250)
st.markdown("Empowering you with a Sustainable AI")

# Replay the stored conversation. Only user and assistant turns are shown;
# any other role (such as the system prompt) is skipped.
for entry in st.session_state["messages"]:
    speaker = entry["role"]
    if speaker == "user":
        bubble = st.chat_message("user")
    elif speaker == "assistant":
        bubble = st.chat_message("assistant", avatar=Dubs_PATH)
    else:
        continue
    bubble.write(entry["content"])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _chunks_from_payload(data):
    """Yield the text fragment(s) carried by one decoded JSON payload.

    Accepts the shapes a text-generation endpoint may return:
      * ``[{"generated_text": ...}]``: a whole, non-streamed generation;
      * ``{"token": {"text": ..., "special": ...}, ...}``: a streaming event
        (special tokens such as end-of-sequence are skipped);
      * ``{"generated_text": ...}``: a single result object;
      * ``{"error": ...}``: an error body, surfaced as an error message.
    Unknown shapes yield nothing instead of raising.
    """
    if isinstance(data, list):
        if data and isinstance(data[0], dict):
            yield data[0].get("generated_text") or ""
    elif isinstance(data, dict):
        if "error" in data:
            yield f"Error: {data['error']}"
        elif isinstance(data.get("token"), dict):
            token = data["token"]
            if not token.get("special"):
                yield token.get("text") or ""
        else:
            yield data.get("generated_text") or ""


def stream_response(prompt_text, api_key, timeout=(10, 120)):
    """
    Stream text from the HF Inference Endpoint (or any streaming API).

    Yields each chunk of text as it arrives. Failures are reported by yielding a
    human-readable message instead of raising, so the caller can display it
    inline.

    Args:
        prompt_text: Fully formatted prompt sent as ``inputs``.
        api_key: Bearer token for the endpoint.
        timeout: ``requests`` timeout in seconds, or a ``(connect, read)``
            tuple. ``requests`` has no default timeout, so without this a stalled
            endpoint would block the app forever.

    Yields:
        str: Generated text fragments, or an error message.
    """
    payload = {
        "inputs": prompt_text,
        "parameters": {
            "max_new_tokens": 250,
            "return_full_text": False,
            "stream": True,
        },
    }
    headers = {
        "Accept": "application/json",
        "Authorization": f"Bearer {api_key}",
        "Content-Type": "application/json",
    }

    try:
        # The context manager releases the streamed connection even if the
        # consumer stops iterating early.
        with requests.post(
            SPACE_URL,
            json=payload,
            headers=headers,
            stream=True,
            timeout=timeout,
        ) as response:
            response.raise_for_status()
            for raw_line in response.iter_lines():
                if not raw_line:
                    continue
                text = raw_line.decode("utf-8").strip()
                # Server-sent events prefix payloads with "data:"; lines starting
                # with ":" are SSE comments/keep-alives.
                if text.startswith(":"):
                    continue
                if text.startswith("data:"):
                    text = text[len("data:"):].strip()
                if not text or text == "[DONE]":
                    continue
                yield from _chunks_from_payload(json.loads(text))
    except requests.exceptions.Timeout:
        yield "The request timed out. Please try again later."
    except requests.exceptions.RequestException as e:
        yield f"Error: {e}"
    except (json.JSONDecodeError, UnicodeDecodeError):
        yield "Error decoding server response."
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Handle a newly submitted prompt: record it, query the model, stream the reply
# into the page, then persist the updated conversation.
if prompt := st.chat_input():
    if not dubs_key:
        st.warning("Please provide a valid Dubs Key.")
    elif not HF_API_KEY:
        # Without the server-side key the request would be sent with
        # "Bearer None" and be rejected by the endpoint.
        st.error("The server is missing its HF_API_KEY configuration. Please contact the administrator.")
    else:
        st.session_state["messages"].append({"role": "user", "content": prompt})
        st.chat_message("user").write(prompt)

        # Serialize every turn as <|role|>content<|end|>, then open an assistant
        # turn so the model continues as the assistant, not as another role.
        chat_history = "".join(
            f"<|{msg['role']}|>{msg['content']}{EOS_TOKEN}"
            for msg in st.session_state["messages"]
        ) + "<|assistant|>"

        with st.spinner("Dubs is thinking... Woof Woof! 🐾"):
            assistant_message_placeholder = st.chat_message("assistant", avatar=Dubs_PATH).empty()
            full_response = ""
            for chunk in stream_response(chat_history, HF_API_KEY):
                full_response += chunk
                # Re-render the accumulated text so the reply appears to type out.
                assistant_message_placeholder.write(full_response)

        st.session_state["messages"].append({"role": "assistant", "content": full_response})
        save_chat_history(st.session_state["session_name"], st.session_state["messages"])
|
|
|
|
|
|