|
|
|
|
|
|
|
import streamlit as st |
|
import time |
|
import json |
|
import os |
|
import base64 |
|
|
from cryptography.fernet import Fernet |
|
from langchain_openai import ChatOpenAI |
|
|
from langchain_core.messages import HumanMessage, SystemMessage |
|
from langchain_openai import OpenAIEmbeddings |
|
from langchain_community.vectorstores import FAISS |
|
from langchain_community.document_loaders import PyPDFLoader |
|
from langchain.text_splitter import RecursiveCharacterTextSplitter |
|
from langchain_community.chat_message_histories import ChatMessageHistory |
|
from langchain_core.documents import Document |
|
|
|
from langchain.callbacks.base import BaseCallbackHandler |
|
|
|
from pydantic import BaseModel, Field |
|
from typing import Annotated |
|
|
|
|
|
from autogen import ConversableAgent, LLMConfig, UpdateSystemMessage |
|
import tempfile |
|
from autogen.coding import LocalCommandLineCodeExecutor
|
import matplotlib |
|
matplotlib.use('Agg') |
|
import matplotlib.pyplot as plt |
|
import io |
|
|
import re |
|
import subprocess |
|
import sys |
|
|
import contextlib |
|
|
|
|
|
def save_encrypted_key(encrypted_key, username): |
|
"""Save encrypted key to file with username prefix""" |
|
try: |
|
filename = f"{username}_encrypted_api_key" if username else ".encrypted_api_key" |
|
with open(filename, "w") as f: |
|
f.write(encrypted_key) |
|
return True |
|
    except Exception:
        return False
|
|
|
def load_encrypted_key(username): |
|
"""Load encrypted key from file with username prefix""" |
|
try: |
|
filename = f"{username}_encrypted_api_key" if username else ".encrypted_api_key" |
|
with open(filename, "r") as f: |
|
return f.read() |
|
except FileNotFoundError: |
|
return None |
|
|
|
def read_keys_from_file(file_path): |
|
with open(file_path, 'r') as file: |
|
return json.load(file) |
|
|
|
def read_prompt_from_file(path): |
|
with open(path, 'r') as f: |
|
return f.read() |
|
|
|
class Response: |
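    """Minimal stand-in for an LLM response object, exposing only .content."""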
|
def __init__(self, content): |
|
self.content = content |
|
|
|
|
|
class Feedback(BaseModel): |
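    """Structured review schema; passed as response_format to the review agent."""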
|
grade: Annotated[int, Field(description="Score from 1 to 10")] |
|
improvement_instructions: Annotated[str, Field(description="Advice on how to improve the reply")] |
|
|
|
class StreamHandler(BaseCallbackHandler): |
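    """LangChain callback that streams tokens into a Streamlit container as they arrive."""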
|
def __init__(self, container): |
|
self.container = container |
|
self.text = "" |
|
|
|
def on_llm_new_token(self, token: str, **kwargs): |
|
self.text += token |
|
self.container.markdown(self.text + "▌") |
|
|
|
|
|
st.set_page_config( |
|
page_title="CLAPP Agent", |
|
page_icon="🤖", |
|
layout="wide", |
|
initial_sidebar_state="auto" |
|
) |
|
|
|
st.markdown("# CLAPP: CLASS LLM Agent for Pair Programming") |
|
col1, col2, col3 = st.columns([1, 2, 1]) |
|
with col2: |
|
st.image("images/CLAPP.png", width=400) |
|
|
|
|
|
|
|
Initial_Agent_Instructions = read_prompt_from_file("prompts/class_instructions.txt") |
|
Review_Agent_Instructions = read_prompt_from_file("prompts/review_instructions.txt") |
|
|
|
Formatting_Agent_Instructions = read_prompt_from_file("prompts/formatting_instructions.txt") |
|
Code_Execution_Agent_Instructions = read_prompt_from_file("prompts/codeexecutor_instructions.txt") |
|
|
|
|
|
def init_session(): |
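    """Populate st.session_state with the app's defaults on first run (idempotent)."""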
|
if "messages" not in st.session_state: |
|
st.session_state.messages = [] |
|
if "debug" not in st.session_state: |
|
st.session_state.debug = False |
|
if "llm" not in st.session_state: |
|
st.session_state.llm = None |
|
if "llmBG" not in st.session_state: |
|
st.session_state.llmBG = None |
|
if "memory" not in st.session_state: |
|
st.session_state.memory = ChatMessageHistory() |
|
if "vector_store" not in st.session_state: |
|
st.session_state.vector_store = None |
|
if "last_token_count" not in st.session_state: |
|
st.session_state.last_token_count = 0 |
|
if "selected_model" not in st.session_state: |
|
st.session_state.selected_model = "gpt-4o-mini" |
|
if "greeted" not in st.session_state: |
|
st.session_state.greeted = False |
|
if "debug_messages" not in st.session_state: |
|
st.session_state.debug_messages = [] |
|
|
|
|
|
init_session() |
|
|
|
|
|
|
|
|
|
with st.sidebar: |
|
st.header("🔐 API & Assistants") |
|
api_key = st.text_input("1. OpenAI API Key", type="password") |
|
username = st.text_input("2. Username (for saving your API key)", placeholder="Enter your username") |
|
user_password = st.text_input("3. Password to encrypt/decrypt API key", type="password") |
|
|
|
|
|
if api_key and user_password: |
|
|
|
        # Derive a 32-byte Fernet key by padding/truncating the password bytes.
        # Padding the encoded bytes (not the str) keeps the key exactly 32 bytes
        # even for non-ASCII passwords; a real KDF (e.g. PBKDF2) would be stronger.
        key = base64.urlsafe_b64encode(user_password.encode("utf-8").ljust(32, b" ")[:32])
|
fernet = Fernet(key) |
|
|
|
|
|
if "saved_api_key" not in st.session_state or api_key != st.session_state.saved_api_key: |
|
try: |
|
|
|
encrypted_key = fernet.encrypt(api_key.encode()) |
|
|
|
|
|
st.session_state.saved_api_key = api_key |
|
st.session_state.encrypted_key = encrypted_key.decode() |
|
|
|
|
|
if save_encrypted_key(encrypted_key.decode(), username): |
|
st.success("API key encrypted and saved! ✅") |
|
else: |
|
st.warning("API key encrypted but couldn't save to file! ⚠️") |
|
except Exception as e: |
|
st.error(f"Error saving API key: {str(e)}") |
|
|
|
|
|
elif user_password and not api_key: |
|
|
|
encrypted_key = load_encrypted_key(username) |
|
if encrypted_key: |
|
try: |
|
|
|
                # Same padded-key derivation as when the key was saved above.
                key = base64.urlsafe_b64encode(user_password.encode("utf-8").ljust(32, b" ")[:32])
|
fernet = Fernet(key) |
|
|
|
|
|
decrypted_key = fernet.decrypt(encrypted_key.encode()).decode() |
|
|
|
|
|
api_key = decrypted_key |
|
st.session_state.saved_api_key = api_key |
|
st.success("API key loaded successfully! 🔑") |
|
            except Exception:
|
st.error("Failed to decrypt API key. Wrong password? 🔒") |
|
else: |
|
st.warning("No saved API key found. Please enter your API key first. 🔑") |
|
|
|
|
|
if st.button("🗑️ Clear Saved API Key"): |
|
deleted_files = False |
|
error_message = "" |
|
|
|
|
|
if username: |
|
filename = f"{username}_encrypted_api_key" |
|
if os.path.exists(filename): |
|
try: |
|
os.remove(filename) |
|
deleted_files = True |
|
st.success(f"Deleted key file for user: {username}") |
|
except Exception as e: |
|
error_message += f"Error clearing {filename}: {str(e)}\n" |
|
|
|
|
|
if os.path.exists(".encrypted_api_key"): |
|
try: |
|
os.remove(".encrypted_api_key") |
|
deleted_files = True |
|
st.success("Deleted default key file") |
|
except Exception as e: |
|
error_message += f"Error clearing default key file: {str(e)}\n" |
|
|
|
|
|
if "saved_api_key" in st.session_state: |
|
del st.session_state.saved_api_key |
|
if "encrypted_key" in st.session_state: |
|
del st.session_state.encrypted_key |
|
|
|
|
|
if deleted_files: |
|
st.info("Session cleared. Reloading page...") |
|
time.sleep(1) |
|
st.rerun() |
|
elif error_message: |
|
st.error(error_message) |
|
else: |
|
st.warning("No saved API keys found to delete.") |
|
|
|
st.session_state.selected_model = st.selectbox( |
|
"4. Choose LLM model 🧠", |
|
options=["gpt-4o-mini", "gpt-4o"], |
|
index=["gpt-4o-mini", "gpt-4o"].index(st.session_state.selected_model) |
|
) |
|
|
|
|
|
|
|
if "previous_model" not in st.session_state: |
|
st.session_state.previous_model = st.session_state.selected_model |
|
elif st.session_state.previous_model != st.session_state.selected_model: |
|
|
|
st.session_state.vector_store = None |
|
st.session_state.greeted = False |
|
st.session_state.messages = [] |
|
st.session_state.memory = ChatMessageHistory() |
|
st.session_state.previous_model = st.session_state.selected_model |
|
st.info("Model changed! Please initialize again with the new model.") |
|
|
|
st.write("### Response Mode") |
|
col1, col2 = st.columns([1, 2]) |
|
with col1: |
|
mode_is_fast = st.toggle("Fast Mode", value=True) |
|
with col2: |
|
if mode_is_fast: |
|
st.caption("✨ Quick responses with good quality (recommended for most uses)") |
|
else: |
|
st.caption("🎯 Swarm mode, more refined responses (may take longer)") |
|
|
|
|
|
if api_key: |
|
os.environ["OPENAI_API_KEY"] = api_key |
|
|
|
|
|
if st.button("🚀 Initialize with Selected Model"): |
|
|
|
st.session_state.llm = ChatOpenAI( |
|
model_name=st.session_state.selected_model, |
|
openai_api_key=api_key, |
|
temperature=1.0 |
|
) |
|
|
|
if st.session_state.vector_store is None: |
|
embedding_status = st.empty() |
|
embedding_status.info("🔄 Processing and embedding your RAG data... This might take a moment! ⏳") |
|
embeddings = OpenAIEmbeddings(model="text-embedding-3-large") |
|
|
|
|
|
all_docs = [] |
|
for filename in os.listdir("./class-data"): |
|
file_path = os.path.join("./class-data", filename) |
|
|
|
if filename.endswith('.pdf'): |
|
|
|
loader = PyPDFLoader(file_path) |
|
docs = loader.load() |
|
all_docs.extend(docs) |
|
elif filename.endswith(('.txt', '.py', '.ini')): |
|
|
|
with open(file_path, 'r', encoding='utf-8') as f: |
|
text = f.read() |
|
|
|
all_docs.append(Document( |
|
page_content=text, |
|
metadata={"source": filename, "type": "code" if filename.endswith('.py') else "text"} |
|
)) |
|
|
|
|
|
text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200) |
|
def sanitize(documents): |
|
for doc in documents: |
|
doc.page_content = doc.page_content.encode("utf-8", "ignore").decode("utf-8") |
|
return documents |
|
|
|
splits = text_splitter.split_documents(all_docs) |
|
splits = sanitize(splits) |
|
|
|
|
|
st.session_state.vector_store = FAISS.from_documents(splits, embedding=embeddings) |
|
embedding_status.empty() |
|
|
|
|
|
if not st.session_state.greeted: |
|
|
|
st.session_state.llm_initialized = True |
|
st.rerun() |
|
|
|
st.markdown("---") |
|
|
|
|
|
st.markdown("### 🔧 CLASS Setup") |
|
if st.checkbox("Check CLASS installation status"): |
|
try: |
|
|
|
result = subprocess.run( |
|
[sys.executable, "-c", "from classy import Class; print('CLASS successfully imported!')"], |
|
capture_output=True, |
|
text=True |
|
) |
|
|
|
if result.returncode == 0: |
|
st.success("✅ CLASS is already installed and ready to use!") |
|
else: |
|
st.error("❌ The 'classy' module is not installed. Please install CLASS using the button below.") |
|
if result.stderr: |
|
st.code(result.stderr, language="bash") |
|
except Exception as e: |
|
st.error(f"❌ Error checking CLASS installation: {str(e)}") |
|
|
|
|
|
st.text("If not installed, install CLASS to enable code execution and plotting") |
|
if st.button("🔄 Install CLASS"): |
|
|
|
status_placeholder = st.empty() |
|
status_placeholder.info("Installing CLASS... This could take a few minutes.") |
|
|
|
try: |
|
|
|
install_script_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'install_classy.sh') |
|
|
|
|
|
os.chmod(install_script_path, 0o755) |
|
|
|
|
|
            # Run the installer via bash directly; combining a list argv with
            # shell=True is fragile (only the list's first item becomes the sh -c command).
            process = subprocess.Popen(
                ["bash", install_script_path],
|
stdout=subprocess.PIPE, |
|
stderr=subprocess.STDOUT, |
|
text=True, |
|
bufsize=1, |
|
|
cwd=os.path.dirname(os.path.abspath(__file__)) |
|
) |
|
|
|
|
|
current_line_placeholder = st.empty() |
|
|
|
|
|
output_text = "" |
|
for line in iter(process.stdout.readline, ''): |
|
output_text += line |
|
|
|
if line.strip(): |
|
current_line_placeholder.info(f"Current: {line.strip()}") |
|
|
|
|
|
return_code = process.wait() |
|
|
|
|
|
current_line_placeholder.empty() |
|
|
|
|
|
if return_code == 0: |
|
status_placeholder.success("✅ CLASS installed successfully!") |
|
else: |
|
status_placeholder.error(f"❌ CLASS installation failed with return code: {return_code}") |
|
|
|
|
|
with st.expander("View Full Installation Log", expanded=False): |
|
st.code(output_text) |
|
|
|
except Exception as e: |
|
status_placeholder.error(f"Installation failed with exception: {str(e)}") |
|
st.exception(e) |
|
|
|
|
|
st.text("If CLASS is installed, test the environment") |
|
if st.button("🧪 Test CLASS"): |
|
|
|
status_placeholder = st.empty() |
|
status_placeholder.info("Testing CLASS environment... This could take a moment.") |
|
|
|
try: |
|
|
|
test_script_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'test_classy.py') |
|
|
|
|
|
with tempfile.TemporaryDirectory() as temp_dir: |
|
|
|
process = subprocess.Popen( |
|
[sys.executable, test_script_path], |
|
stdout=subprocess.PIPE, |
|
stderr=subprocess.STDOUT, |
|
text=True, |
|
bufsize=1, |
|
cwd=temp_dir |
|
) |
|
|
|
|
|
current_line_placeholder = st.empty() |
|
|
|
|
|
output_text = "" |
|
for line in iter(process.stdout.readline, ''): |
|
output_text += line |
|
|
|
if line.strip(): |
|
current_line_placeholder.info(f"Current: {line.strip()}") |
|
|
|
|
|
return_code = process.wait() |
|
|
|
|
|
current_line_placeholder.empty() |
|
|
|
|
|
if return_code == 0: |
|
status_placeholder.success("✅ CLASS test completed successfully!") |
|
else: |
|
status_placeholder.error(f"❌ CLASS test failed with return code: {return_code}") |
|
|
|
|
|
|
|
if "ModuleNotFoundError" in output_text or "ImportError" in output_text: |
|
st.error("❌ Python module import error detected. Make sure CLASS is properly installed.") |
|
|
|
if "CosmoSevereError" in output_text or "CosmoComputationError" in output_text: |
|
st.error("❌ CLASS computation error detected.") |
|
|
|
|
|
with st.expander("View Full Test Log", expanded=False): |
|
st.code(output_text) |
|
|
|
plot_path = os.path.join(temp_dir, 'cmb_temperature_spectrum.png') |
|
if os.path.exists(plot_path): |
|
|
|
st.subheader("Generated CMB Power Spectrum") |
|
st.image(plot_path, use_container_width=True) |
|
else: |
|
st.warning("⚠️ No plot was generated") |
|
|
|
except Exception as e: |
|
status_placeholder.error(f"Test failed with exception: {str(e)}") |
|
st.exception(e) |
|
|
|
st.markdown("---") |
|
st.session_state.debug = st.checkbox("🔍 Show Debug Info") |
|
if st.button("🗑️ Reset Chat"): |
|
st.session_state.clear() |
|
st.rerun() |
|
|
|
if st.session_state.last_token_count > 0: |
|
st.markdown(f"🧮 **Last response token usage:** `{st.session_state.last_token_count}` tokens") |
|
|
|
|
|
if "generated_plots" in st.session_state and st.session_state.generated_plots: |
|
with st.expander("📊 Plot Gallery", expanded=False): |
|
st.write("All plots generated during this session:") |
|
|
|
for i, plot_path in enumerate(st.session_state.generated_plots): |
|
if os.path.exists(plot_path): |
|
st.image(plot_path, width=250, caption=os.path.basename(plot_path)) |
|
st.markdown("---") |
|
|
|
|
|
def build_messages(context, question, system): |
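    """Assemble the system prompt, chat history, and RAG context into a message list."""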
|
system_msg = SystemMessage(content=system) |
|
human_msg = HumanMessage(content=f"Context:\n{context}\n\nQuestion:\n{question}") |
|
return [system_msg] + st.session_state.memory.messages + [human_msg] |
|
|
|
def build_messages_rating(context, question, answer, system): |
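    """Like build_messages, but includes the draft answer for the reviewer to grade."""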
|
system_msg = SystemMessage(content=system) |
|
human_msg = HumanMessage(content=f"Context:\n{context}\n\nQuestion:\n{question}\n\nAI Answer:\n{answer}") |
|
return [system_msg] + st.session_state.memory.messages + [human_msg] |
|
|
|
def build_messages_refinement(context, question, answer, feedback, system): |
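    """Like build_messages, but includes the draft answer and reviewer feedback for refinement."""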
|
system_msg = SystemMessage(content=system) |
|
human_msg = HumanMessage(content=f"Context:\n{context}\n\nQuestion:\n{question}\n\nAI Answer:\n{answer}\n\nReviewer Feedback:\n{feedback}") |
|
return [system_msg] + st.session_state.memory.messages + [human_msg] |
|
|
|
def format_memory_messages(memory_messages): |
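    """Render the chat history as plain 'Role: content' text for the autogen agents."""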
|
formatted = "" |
|
for msg in memory_messages: |
|
role = msg.type.capitalize() |
|
content = msg.content |
|
formatted += f"{role}: {content}\n\n" |
|
return formatted.strip() |
|
|
|
|
|
def retrieve_context(question): |
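    """Fetch the four most similar chunks from the FAISS store for the question."""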
|
docs = st.session_state.vector_store.similarity_search(question, k=4) |
|
return "\n\n".join([doc.page_content for doc in docs]) |
|
|
|
|
|
|
|
|
|
|
|
class PlotAwareExecutor(LocalCommandLineCodeExecutor): |
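    """Code executor that runs LLM-generated Python in a private temp directory and
    copies any matplotlib figure it produces into a persistent plots/ directory."""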
|
def __init__(self, **kwargs): |
|
|
|
|
plots_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'plots') |
|
os.makedirs(plots_dir, exist_ok=True) |
|
|
|
|
|
temp_dir = tempfile.TemporaryDirectory() |
|
kwargs['work_dir'] = temp_dir.name |
|
super().__init__(**kwargs) |
|
self._temp_dir = temp_dir |
|
self._plots_dir = plots_dir |
|
|
|
@contextlib.contextmanager |
|
def _capture_output(self): |
|
old_out, old_err = sys.stdout, sys.stderr |
|
buf_out, buf_err = io.StringIO(), io.StringIO() |
|
sys.stdout, sys.stderr = buf_out, buf_err |
|
try: |
|
yield buf_out, buf_err |
|
finally: |
|
sys.stdout, sys.stderr = old_out, old_err |
|
|
|
def execute_code(self, code: str): |
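        """Extract Python from a fenced code block, strip plt.show(), redirect
        plt.savefig() into the temp dir, run the script in a subprocess, and
        return (captured_output, persistent_plot_path)."""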
|
|
|
match = re.search(r"```(?:python)?\n(.*?)```", code, re.DOTALL) |
|
cleaned = match.group(1) if match else code |
|
cleaned = cleaned.replace("plt.show()", "") |
|
|
|
|
|
timestamp = time.strftime("%Y-%m-%d-%H-%M-%S") |
|
plot_filename = f'plot_{timestamp}.png' |
|
plot_path = os.path.join(self._plots_dir, plot_filename) |
|
temp_plot_path = None |
|
|
|
for line in cleaned.split("\n"): |
|
if "plt.savefig" in line: |
|
temp_plot_path = os.path.join(self._temp_dir.name, f'temporary_{timestamp}.png') |
|
cleaned = cleaned.replace(line, f"plt.savefig('{temp_plot_path}', dpi=300)") |
|
break |
|
else: |
|
|
|
if "plt." in cleaned: |
|
temp_plot_path = os.path.join(self._temp_dir.name, f'temporary_{timestamp}.png') |
|
cleaned += f"\nplt.savefig('{temp_plot_path}')" |
|
|
|
|
|
temp_script_path = os.path.join(self._temp_dir.name, f'temp_script_{timestamp}.py') |
|
with open(temp_script_path, 'w') as f: |
|
f.write(cleaned) |
|
|
|
full_output = "" |
|
try: |
|
|
|
process = subprocess.Popen( |
|
[sys.executable, temp_script_path], |
|
stdout=subprocess.PIPE, |
|
stderr=subprocess.STDOUT, |
|
text=True, |
|
bufsize=1, |
|
cwd=self._temp_dir.name |
|
) |
|
stdout, _ = process.communicate() |
|
|
|
|
|
            # stderr was merged into stdout via stderr=subprocess.STDOUT,
            # so all subprocess output arrives in `stdout`.
            if stdout:
                full_output += f"STDOUT:\n{stdout}\n"
|
|
|
|
|
if temp_plot_path and os.path.exists(temp_plot_path): |
|
import shutil |
|
shutil.copy2(temp_plot_path, plot_path) |
|
|
|
if "generated_plots" not in st.session_state: |
|
st.session_state.generated_plots = [] |
|
|
|
st.session_state.generated_plots.append(plot_path) |
|
|
|
except Exception: |
|
with self._capture_output() as (out_buf, err_buf): |
|
import traceback |
|
traceback.print_exc(file=sys.stderr) |
|
full_output += f"STDERR:\n{err_buf.getvalue()}\n" |
|
|
|
return full_output, plot_path |
|
|
|
|
|
executor = PlotAwareExecutor(timeout=10) |
|
|
|
|
|
initial_config = LLMConfig( |
|
api_type="openai", |
|
model=st.session_state.selected_model, |
|
temperature=0.2, |
|
api_key=api_key, |
|
) |
|
|
|
review_config = LLMConfig( |
|
api_type="openai", |
|
model=st.session_state.selected_model, |
|
temperature=0.7, |
|
api_key=api_key, |
|
response_format=Feedback |
|
) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
formatting_config = LLMConfig( |
|
api_type="openai", |
|
model=st.session_state.selected_model, |
|
temperature=0.3, |
|
api_key=api_key, |
|
) |
|
|
|
code_execution_config = LLMConfig( |
|
api_type="openai", |
|
model=st.session_state.selected_model, |
|
temperature=0.1, |
|
api_key=api_key, |
|
) |
|
|
|
|
|
initial_agent = ConversableAgent( |
|
name="initial_agent", |
|
system_message=f""" |
|
{Initial_Agent_Instructions}""", |
|
human_input_mode="NEVER", |
|
llm_config=initial_config |
|
) |
|
|
|
review_agent = ConversableAgent( |
|
name="review_agent", |
|
system_message=f"""{Review_Agent_Instructions}""", |
|
human_input_mode="NEVER", |
|
llm_config=review_config |
|
) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
formatting_agent = ConversableAgent( |
|
name="formatting_agent", |
|
system_message="""{Formatting_Agent_Instructions}""", |
|
human_input_mode="NEVER", |
|
llm_config=formatting_config |
|
) |
|
|
|
code_executor = ConversableAgent( |
|
name="code_executor", |
|
system_message="""{Code_Execution_Agent_Instructions}""", |
|
human_input_mode="NEVER", |
|
llm_config=code_execution_config, |
|
code_execution_config={"executor": executor}, |
|
max_consecutive_auto_reply=50 |
|
) |
|
|
|
def call_ai(context, user_input): |
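    """Answer the user's question: one streaming LLM call in fast mode, or the
    draft -> review -> refine -> format autogen pipeline in swarm mode."""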
|
if mode_is_fast: |
|
messages = build_messages(context, user_input, Initial_Agent_Instructions) |
|
response = st.session_state.llm.invoke(messages) |
|
return Response(content=response.content) |
|
else: |
|
|
|
st.markdown("Thinking (Swarm Mode)... ") |
|
|
|
|
|
conversation_history = format_memory_messages(st.session_state.memory.messages) |
|
|
|
|
|
st.markdown("Generating initial draft...") |
|
chat_result_1 = initial_agent.initiate_chat( |
|
recipient=initial_agent, |
|
message=f"Conversation history:\n{conversation_history}\n\nContext from documents: {context}\n\nUser question: {user_input}", |
|
max_turns=1, |
|
summary_method="last_msg" |
|
) |
|
draft_answer = chat_result_1.summary |
|
if st.session_state.debug: |
|
st.session_state.debug_messages.append(("Initial Draft", draft_answer)) |
|
|
|
|
|
st.markdown("Reviewing draft...") |
|
chat_result_2 = review_agent.initiate_chat( |
|
recipient=review_agent, |
|
message=f"Conversation history:\n{conversation_history}\n\nPlease review this draft answer:\n{draft_answer}", |
|
max_turns=1, |
|
summary_method="last_msg" |
|
) |
|
review_feedback = chat_result_2.summary |
|
if st.session_state.debug: |
|
st.session_state.debug_messages.append(("Review Feedback", review_feedback)) |
|
|
|
|
|
        # Refine the draft using the reviewer's feedback
        st.markdown("Refining draft...")
        chat_result_3 = initial_agent.initiate_chat(
            recipient=initial_agent,
            message=f"Conversation history:\n{conversation_history}\n\nOriginal draft:\n{draft_answer}\n\nReviewer feedback:\n{review_feedback}\n\nPlease revise the draft to address the feedback, preserving any code blocks.",
            max_turns=1,
            summary_method="last_msg"
        )
        draft_answer = chat_result_3.summary
        if st.session_state.debug:
            st.session_state.debug_messages.append(("Refined Draft", draft_answer))

|
st.markdown("Formatting final answer...") |
|
chat_result_4 = formatting_agent.initiate_chat( |
|
recipient=formatting_agent, |
|
message=f"""Please format this answer while preserving any code blocks: |
|
{draft_answer}""", |
|
max_turns=1, |
|
summary_method="last_msg" |
|
) |
|
formatted_answer = chat_result_4.summary |
|
if st.session_state.debug: |
|
st.session_state.debug_messages.append(("Formatted Answer", formatted_answer)) |
|
|
|
|
|
if "```python" in formatted_answer: |
|
|
|
formatted_answer += "\n\n> 💡 **Note**: This answer contains code. If you want to execute it, type 'execute!' in the chat." |
|
return Response(content=formatted_answer) |
|
else: |
|
return Response(content=formatted_answer) |
|
|
|
|
|
|
|
user_input = st.chat_input("Type your prompt here...") |
|
|
|
|
|
for message in st.session_state.messages: |
|
with st.chat_message(message["role"]): |
|
|
|
if "PLOT_PATH:" in message["content"]: |
|
|
|
parts = message["content"].split("PLOT_PATH:") |
|
|
|
st.markdown(parts[0]) |
|
|
|
for plot_info in parts[1:]: |
|
plot_path = plot_info.split('\n')[0].strip() |
|
if os.path.exists(plot_path): |
|
st.image(plot_path, width=700) |
|
else: |
|
st.markdown(message["content"]) |
|
|
|
|
|
if user_input: |
|
|
|
st.session_state.messages.append({"role": "user", "content": user_input}) |
|
with st.chat_message("user"): |
|
st.markdown(user_input) |
|
|
|
st.session_state.memory.add_user_message(user_input) |
|
context = retrieve_context(user_input) |
|
|
|
|
|
    # Approximate the prompt size; the gpt-4 encoding is used as a stand-in for
    # the selected model, so the count is indicative rather than exact.
    try:
        import tiktoken
        enc = tiktoken.encoding_for_model("gpt-4")
        st.session_state.last_token_count = len(enc.encode(user_input))
    except Exception:
        st.session_state.last_token_count = 0
|
|
|
|
|
with st.chat_message("assistant"): |
|
stream_box = st.empty() |
|
stream_handler = StreamHandler(stream_box) |
|
|
|
|
|
st.session_state.llm = ChatOpenAI( |
|
model_name=st.session_state.selected_model, |
|
streaming=True, |
|
callbacks=[stream_handler], |
|
openai_api_key=api_key, |
|
temperature=0.2 |
|
) |
|
|
|
|
|
if user_input.strip().lower() == "execute!": |
|
|
|
last_assistant_message = None |
|
for message in reversed(st.session_state.messages): |
|
if message["role"] == "assistant" and "```" in message["content"]: |
|
last_assistant_message = message["content"] |
|
break |
|
|
|
if last_assistant_message: |
|
st.markdown("Executing code...") |
|
st.info("🚀 Executing cleaned code...") |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
execution_output, plot_path = executor.execute_code(last_assistant_message) |
|
st.subheader("Execution Output") |
|
st.text(execution_output) |
|
|
|
if os.path.exists(plot_path): |
|
st.success("✅ Plot generated successfully!") |
|
|
|
|
|
st.image(plot_path, width=700) |
|
else: |
|
st.warning("⚠️ No plot was generated") |
|
|
|
|
|
max_iterations = 3 |
|
current_iteration = 0 |
|
has_errors = any(error_indicator in execution_output for error_indicator in ["Traceback", "Error:", "Exception:", "TypeError:", "ValueError:", "NameError:", "SyntaxError:", "Error in Class"]) |
|
|
|
while has_errors and current_iteration < max_iterations: |
|
current_iteration += 1 |
|
st.error(f"Previous error: {execution_output}") |
|
st.info(f"🔧 Fixing errors (attempt {current_iteration}/{max_iterations})...") |
|
|
|
|
|
review_message = f""" |
|
Previous answer had errors during execution: |
|
{execution_output} |
|
|
|
Please review and suggest fixes for this answer. IMPORTANT: Preserve all code blocks exactly as they are, only fix actual errors: |
|
{last_assistant_message} |
|
""" |
|
chat_result_2 = review_agent.initiate_chat( |
|
recipient=review_agent, |
|
message=review_message, |
|
max_turns=1, |
|
summary_method="last_msg" |
|
) |
|
review_feedback = chat_result_2.summary |
|
if st.session_state.debug: |
|
st.session_state.debug_messages.append(("Error Review Feedback", review_feedback)) |
|
|
|
|
|
chat_result_3 = initial_agent.initiate_chat( |
|
recipient=initial_agent, |
|
message=f"""Original answer: {last_assistant_message} |
|
Review feedback with error fixes: {review_feedback} |
|
IMPORTANT: Only fix actual errors in the code blocks. Preserve all working code exactly as it is.""", |
|
max_turns=1, |
|
summary_method="last_msg" |
|
) |
|
corrected_answer = chat_result_3.summary |
|
if st.session_state.debug: |
|
st.session_state.debug_messages.append(("Corrected Answer", corrected_answer)) |
|
|
|
|
|
chat_result_4 = formatting_agent.initiate_chat( |
|
recipient=formatting_agent, |
|
message=f"""Please format this corrected answer while preserving all code blocks: |
|
{corrected_answer} |
|
""", |
|
max_turns=1, |
|
summary_method="last_msg" |
|
) |
|
formatted_answer = chat_result_4.summary |
|
if st.session_state.debug: |
|
st.session_state.debug_messages.append(("Formatted Corrected Answer", formatted_answer)) |
|
|
|
|
|
st.info("🚀 Executing corrected code...") |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
execution_output, plot_path = executor.execute_code(formatted_answer) |
|
st.subheader("Execution Output") |
|
st.text(execution_output) |
|
|
|
if os.path.exists(plot_path): |
|
st.success("✅ Plot generated successfully!") |
|
|
|
st.image(plot_path, width=700) |
|
else: |
|
st.warning("⚠️ No plot was generated") |
|
|
|
if st.session_state.debug: |
|
st.session_state.debug_messages.append(("Execution Output", execution_output)) |
|
|
|
|
|
if not has_errors or current_iteration == max_iterations: |
|
|
|
final_answer = formatted_answer if formatted_answer else last_assistant_message |
|
response_text = f"Execution completed successfully:\n{execution_output}\n\nThe following code was executed:\n```python\n{final_answer}\n```" |
|
|
|
|
|
if os.path.exists(plot_path): |
|
response_text += f"\n\nPLOT_PATH:{plot_path}\n" |
|
|
|
if current_iteration > 0: |
|
response_text = f"After {current_iteration} correction attempts: " + response_text |
|
|
|
|
|
response = Response(content=response_text) |
|
|
|
|
|
last_assistant_message = formatted_answer |
|
has_errors = any(error_indicator in execution_output for error_indicator in ["Traceback", "Error:", "Exception:", "TypeError:", "ValueError:", "NameError:", "SyntaxError:", "Error in Class"]) |
|
|
|
if has_errors: |
|
st.markdown("> ⚠️ **Note**: Some errors could not be fixed after multiple attempts. You can request changes by describing them in the chat.") |
|
st.markdown(f"> ❌ Last execution message:\n{execution_output}") |
|
response = Response(content=f"Execution completed with errors:\n{execution_output}") |
|
else: |
|
|
|
if any(error_indicator in execution_output for error_indicator in ["Traceback", "Error:", "Exception:", "TypeError:", "ValueError:", "NameError:", "SyntaxError:"]): |
|
st.markdown("> ⚠️ **Note**: Code execution completed but with errors. You can request changes by describing them in the chat.") |
|
st.markdown(f"> ❌ Execution message:\n{execution_output}") |
|
response = Response(content=f"Execution completed with errors:\n{execution_output}") |
|
else: |
|
st.markdown(f"> ✅ Code executed successfully. Last execution message:\n{execution_output}") |
|
|
|
|
|
with st.expander("View Successfully Executed Code", expanded=False): |
|
st.markdown(last_assistant_message) |
|
|
|
|
|
response_text = f"Execution completed successfully:\n{execution_output}\n\nThe following code was executed:\n```python\n{last_assistant_message}\n```" |
|
|
|
|
|
if os.path.exists(plot_path): |
|
response_text += f"\n\nPLOT_PATH:{plot_path}\n" |
|
|
|
response = Response(content=response_text) |
|
else: |
|
response = Response(content="No code found to execute in the previous messages.") |
|
else: |
|
response = call_ai(context, user_input) |
|
if not mode_is_fast: |
|
st.markdown(response.content) |
|
|
|
st.session_state.memory.add_ai_message(response.content) |
|
st.session_state.messages.append({"role": "assistant", "content": response.content}) |
|
|
|
|
|
|
|
if "llm_initialized" in st.session_state and st.session_state.llm_initialized and not st.session_state.greeted: |
|
|
|
with st.chat_message("assistant"): |
|
|
|
welcome_container = st.empty() |
|
|
|
|
|
welcome_stream_handler = StreamHandler(welcome_container) |
|
|
|
|
|
streaming_llm = ChatOpenAI( |
|
model_name=st.session_state.selected_model, |
|
streaming=True, |
|
callbacks=[welcome_stream_handler], |
|
openai_api_key=api_key, |
|
temperature=0.2 |
|
) |
|
|
|
|
|
greeting = streaming_llm.invoke([ |
|
SystemMessage(content=Initial_Agent_Instructions), |
|
HumanMessage(content="Please greet the user and briefly explain what you can do as the CLASS code assistant.") |
|
]) |
|
|
|
|
|
st.session_state.messages.append({"role": "assistant", "content": greeting.content}) |
|
st.session_state.memory.add_ai_message(greeting.content) |
|
st.session_state.greeted = True |
|
|
|
|
|
if st.session_state.debug: |
|
with st.sidebar.expander("🛠️ Debug Information", expanded=True): |
|
|
|
debug_container = st.container() |
|
with debug_container: |
|
st.markdown("### Debug Messages") |
|
|
|
|
|
for title, message in st.session_state.debug_messages: |
|
st.markdown(f"### {title}") |
|
st.markdown(message) |
|
st.markdown("---") |
|
|
|
with st.sidebar.expander("🛠️ Context Used"): |
|
if "context" in locals(): |
|
st.markdown(context) |
|
else: |
|
st.markdown("No context retrieved yet.") |