# This script requires Streamlit, LangChain, and AG2/AutoGen, plus a few supporting libraries.
# Install them with:
#   pip install streamlit openai langchain langchain-openai langchain-community \
#       cryptography pyautogen matplotlib pillow faiss-cpu pypdf tiktoken

import streamlit as st
import time
import json
import os
import base64
from cryptography.fernet import Fernet
from langchain_openai import ChatOpenAI
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_openai import OpenAIEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.chat_message_histories import ChatMessageHistory
from langchain_core.documents import Document

from langchain.callbacks.base import BaseCallbackHandler

from pydantic import BaseModel, Field
from typing import Annotated


from autogen import ConversableAgent, LLMConfig
import tempfile
from autogen.coding import LocalCommandLineCodeExecutor
import matplotlib
matplotlib.use('Agg')  # Set the backend to Agg before importing pyplot
import matplotlib.pyplot as plt
import re
import subprocess
import sys

# --- Helper Functions ---
def save_encrypted_key(encrypted_key, username):
    """Save encrypted key to file with username prefix"""
    try:
        filename = f"{username}_encrypted_api_key" if username else ".encrypted_api_key"
        with open(filename, "w") as f:
            f.write(encrypted_key)
        return True
    except Exception:
        return False

def load_encrypted_key(username):
    """Load encrypted key from file with username prefix"""
    try:
        filename = f"{username}_encrypted_api_key" if username else ".encrypted_api_key"
        with open(filename, "r") as f:
            return f.read()
    except FileNotFoundError:
        return None

def read_keys_from_file(file_path):
    with open(file_path, 'r') as file:
        return json.load(file)

def read_prompt_from_file(path):
    with open(path, 'r') as f:
        return f.read()
    
class Response:
    """Minimal wrapper so every answer path exposes a `.content` attribute, matching LangChain responses."""
    def __init__(self, content):
        self.content = content


class Feedback(BaseModel):
    grade: Annotated[int, Field(description="Score from 1 to 10")]
    improvement_instructions: Annotated[str, Field(description="Advice on how to improve the reply")]
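
# The Feedback schema above is handed to the review agent's LLMConfig as
# `response_format`, which (in AG2's structured-output mode) constrains the
# review reply to JSON matching the schema. Parsing such a reply is then a
# one-liner (illustrative sketch, assuming pydantic v2; not used by the app):
#
#   raw = '{"grade": 7, "improvement_instructions": "Tighten the intro."}'
#   feedback = Feedback.model_validate_json(raw)
#   feedback.grade  # -> 7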

class StreamHandler(BaseCallbackHandler):
    def __init__(self, container):
        self.container = container
        self.text = ""

    def on_llm_new_token(self, token: str, **kwargs):
        self.text += token
        self.container.markdown(self.text + "▌")

# --- Streamlit Page Config ---
st.set_page_config(
    page_title="CLAPP Agent",
    page_icon="🤖",
    layout="wide",
    initial_sidebar_state="auto"
)

st.markdown("# CLAPP: CLASS LLM Agent for Pair Programming")
col1, col2, col3 = st.columns([1, 2, 1])
with col2:
    st.image("images/CLAPP.png", width=400)


# New prompts for the swarm
Initial_Agent_Instructions = read_prompt_from_file("prompts/class_instructions.txt") # Reuse or adapt class_instructions
Review_Agent_Instructions = read_prompt_from_file("prompts/review_instructions.txt") # Adapt rating_instructions
#Typo_Agent_Instructions = read_prompt_from_file("prompts/typo_instructions.txt")   # New prompt file
Formatting_Agent_Instructions = read_prompt_from_file("prompts/formatting_instructions.txt") # New prompt file
Code_Execution_Agent_Instructions = read_prompt_from_file("prompts/codeexecutor_instructions.txt") # New prompt file

# --- Initialize Session State ---
def init_session():
    if "messages" not in st.session_state:
        st.session_state.messages = []
    if "debug" not in st.session_state:
        st.session_state.debug = False
    if "llm" not in st.session_state:
        st.session_state.llm = None
    if "llmBG" not in st.session_state:
        st.session_state.llmBG = None
    if "memory" not in st.session_state:
        st.session_state.memory = ChatMessageHistory()
    if "vector_store" not in st.session_state:
        st.session_state.vector_store = None
    if "last_token_count" not in st.session_state:
        st.session_state.last_token_count = 0
    if "selected_model" not in st.session_state:
        st.session_state.selected_model = "gpt-4o-mini"
    if "greeted" not in st.session_state:
        st.session_state.greeted = False
    if "debug_messages" not in st.session_state:
        st.session_state.debug_messages = []


init_session()



# --- Sidebar Configuration ---
with st.sidebar:
    st.header("🔐 API & Assistants")
    api_key = st.text_input("1. OpenAI API Key", type="password")
    username = st.text_input("2. Username (for saving your API key)", placeholder="Enter your username")
    user_password = st.text_input("3. Password to encrypt/decrypt API key", type="password")
    
    # When both API key and password are provided
    if api_key and user_password:
        # Create encryption key from password
        key = base64.urlsafe_b64encode(user_password.ljust(32)[:32].encode())
        fernet = Fernet(key)
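
        # NOTE: padding the password with spaces to 32 bytes (above) is a quick
        # hack, not a real key-derivation function; it silently truncates longer
        # passwords and fails outright for non-ASCII ones. A sturdier sketch
        # using PBKDF2 from the same `cryptography` package (illustrative, with
        # a hypothetical fixed salt; not wired into the app):
        #
        #   from cryptography.hazmat.primitives import hashes
        #   from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
        #   kdf = PBKDF2HMAC(algorithm=hashes.SHA256(), length=32,
        #                    salt=b"clapp-demo-salt", iterations=480_000)
        #   key = base64.urlsafe_b64encode(kdf.derive(user_password.encode()))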
        
        # If this is a new API key, encrypt and save it
        if "saved_api_key" not in st.session_state or api_key != st.session_state.saved_api_key:
            try:
                # Encrypt the API key
                encrypted_key = fernet.encrypt(api_key.encode())
                
                # Save to session state and file
                st.session_state.saved_api_key = api_key
                st.session_state.encrypted_key = encrypted_key.decode()
                
                # Save to file
                if save_encrypted_key(encrypted_key.decode(), username):
                    st.success("API key encrypted and saved! ✅")
                else:
                    st.warning("API key encrypted but couldn't save to file! ⚠️")
            except Exception as e:
                st.error(f"Error saving API key: {str(e)}")
    
    # Try to load saved API key if password is provided
    elif user_password and not api_key:
        # Try to load from file first
        encrypted_key = load_encrypted_key(username)
        if encrypted_key:
            try:
                # Recreate encryption key
                key = base64.urlsafe_b64encode(user_password.ljust(32)[:32].encode())
                fernet = Fernet(key)
                
                # Decrypt the saved key
                decrypted_key = fernet.decrypt(encrypted_key.encode()).decode()
                
                # Set the API key
                api_key = decrypted_key
                st.session_state.saved_api_key = api_key
                st.success("API key loaded successfully! 🔑")
            except Exception as e:
                st.error("Failed to decrypt API key. Wrong password? 🔒")
        else:
            st.warning("No saved API key found. Please enter your API key first. 🔑")

    # Add clear saved key button
    if st.button("🗑️ Clear Saved API Key"):
        deleted_files = False
        error_message = ""
        
        # Try to delete username-specific file if it exists
        if username:
            filename = f"{username}_encrypted_api_key"
            if os.path.exists(filename):
                try:
                    os.remove(filename)
                    deleted_files = True
                    st.success(f"Deleted key file for user: {username}")
                except Exception as e:
                    error_message += f"Error clearing {filename}: {str(e)}\n"
        
        # Also try to delete the default file if it exists
        if os.path.exists(".encrypted_api_key"):
            try:
                os.remove(".encrypted_api_key")
                deleted_files = True
                st.success("Deleted default key file")
            except Exception as e:
                error_message += f"Error clearing default key file: {str(e)}\n"
        
        # Clean up session state
        if "saved_api_key" in st.session_state:
            del st.session_state.saved_api_key
        if "encrypted_key" in st.session_state:
            del st.session_state.encrypted_key
        
        # Show appropriate message
        if deleted_files:
            st.info("Session cleared. Reloading page...")
            time.sleep(1)  # Brief pause so user can see the message
            st.rerun()
        elif error_message:
            st.error(error_message)
        else:
            st.warning("No saved API keys found to delete.")

    st.session_state.selected_model = st.selectbox(
        "4. Choose LLM model ๐Ÿง ",
        options=["gpt-4o-mini", "gpt-4o"],
        index=["gpt-4o-mini", "gpt-4o"].index(st.session_state.selected_model)
    )


    # Check if model has changed
    if "previous_model" not in st.session_state:
        st.session_state.previous_model = st.session_state.selected_model
    elif st.session_state.previous_model != st.session_state.selected_model:
        # Reset relevant state variables when model changes
        st.session_state.vector_store = None
        st.session_state.greeted = False
        st.session_state.messages = []
        st.session_state.memory = ChatMessageHistory()
        st.session_state.previous_model = st.session_state.selected_model
        st.info("Model changed! Please initialize again with the new model.")

    st.write("### Response Mode")
    col1, col2 = st.columns([1, 2])
    with col1:
        mode_is_fast = st.toggle("Fast Mode", value=True)
    with col2:
        if mode_is_fast:
            st.caption("✨ Quick responses with good quality (recommended for most uses)")
        else:
            st.caption("🎯 Swarm mode: more refined, multi-agent responses (may take longer)")
    

    if api_key:
        os.environ["OPENAI_API_KEY"] = api_key
        
        # Initialize only after model is selected
        if st.button("🚀 Initialize with Selected Model"):
            # First initialization without streaming
            st.session_state.llm = ChatOpenAI(
                    model_name=st.session_state.selected_model,
                    openai_api_key=api_key,
                    temperature=1.0
            )

            if st.session_state.vector_store is None:
                embedding_status = st.empty()
                embedding_status.info("🔄 Processing and embedding your RAG data... This might take a moment! ⏳")
                embeddings = OpenAIEmbeddings(model="text-embedding-3-large")
                
                # Get all files from class-data directory
                all_docs = []
                for filename in os.listdir("./class-data"):
                    file_path = os.path.join("./class-data", filename)
                    
                    if filename.endswith('.pdf'):
                        # Handle PDF files
                        loader = PyPDFLoader(file_path)
                        docs = loader.load()
                        all_docs.extend(docs)
                    elif filename.endswith(('.txt', '.py', '.ini')):  # Added .py extension
                        # Handle text and Python files
                        with open(file_path, 'r', encoding='utf-8') as f:
                            text = f.read()
                            # Create a document with metadata
                            all_docs.append(Document(
                                page_content=text,
                                metadata={"source": filename, "type": "code" if filename.endswith('.py') else "text"}
                            ))

                # Split and process all documents
                text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
                def sanitize(documents):
                    for doc in documents:
                        doc.page_content = doc.page_content.encode("utf-8", "ignore").decode("utf-8")
                    return documents
                    
                splits = text_splitter.split_documents(all_docs)
                splits = sanitize(splits)
                
                # Create vector store from all documents
                st.session_state.vector_store = FAISS.from_documents(splits, embedding=embeddings)
                embedding_status.empty()  # Clear the loading message

            # Initialize but don't generate welcome message yet
            if not st.session_state.greeted:
                # Just set the initialized flag, we'll generate the welcome message later
                st.session_state.llm_initialized = True
                st.rerun()  # Refresh the page to show the initialized state

    st.markdown("---")  # Add a separator for better visual organization
    
    # Check if CLASS is already installed
    st.markdown("### ๐Ÿ”ง CLASS Setup")
    if st.checkbox("Check CLASS installation status"):
        try:
            # Use sys.executable to run a simple test to see if classy can be imported
            result = subprocess.run(
                [sys.executable, "-c", "from classy import Class; print('CLASS successfully imported!')"],
                capture_output=True,
                text=True
            )
            
            if result.returncode == 0:
                st.success("✅ CLASS is already installed and ready to use!")
            else:
                st.error("❌ The 'classy' module is not installed. Please install CLASS using the button below.")
                if result.stderr:
                    st.code(result.stderr, language="bash")
        except Exception as e:
            st.error(f"❌ Error checking CLASS installation: {str(e)}")
    
    # Add CLASS installation and testing buttons
    st.text("If not installed, install CLASS to enable code execution and plotting")
    if st.button("🔄 Install CLASS"):
        # Show simple initial message
        status_placeholder = st.empty()
        status_placeholder.info("Installing CLASS... This could take a few minutes.")
        
        try:
            # Get the path to install_classy.sh
            install_script_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'install_classy.sh')
            
            # Make the script executable
            os.chmod(install_script_path, 0o755)
            
            # Run the installation script via bash (passing a list together with
            # shell=True is a pitfall: the extra list items become arguments to
            # the shell itself rather than to the script)
            process = subprocess.Popen(
                ["bash", install_script_path],
                stdout=subprocess.PIPE,
                stderr=subprocess.STDOUT,
                text=True,
                bufsize=1,
                cwd=os.path.dirname(os.path.abspath(__file__))
            )
            
            # Create a placeholder for the current line
            current_line_placeholder = st.empty()
            
            # Collect output in the background while showing just the last line
            output_text = ""
            for line in iter(process.stdout.readline, ''):
                output_text += line
                # Update the placeholder with just the current line (real-time feedback)
                if line.strip():  # Only update for non-empty lines
                    current_line_placeholder.info(f"Current: {line.strip()}")
            
            # Get the final return code
            return_code = process.wait()
            
            # Clear the current line placeholder when done
            current_line_placeholder.empty()
            
            # Update status based on result
            if return_code == 0:
                status_placeholder.success("✅ CLASS installed successfully!")
            else:
                status_placeholder.error(f"❌ CLASS installation failed with return code: {return_code}")
                
            # Display the full output in an expander (not expanded by default)
            with st.expander("View Full Installation Log", expanded=False):
                st.code(output_text)
                
        except Exception as e:
            status_placeholder.error(f"Installation failed with exception: {str(e)}")
            st.exception(e)  # Show the full exception for debugging

    # Add test environment button
    st.text("If CLASS is installed, test the environment")
    if st.button("🧪 Test CLASS"):
        # Show simple initial message
        status_placeholder = st.empty()
        status_placeholder.info("Testing CLASS environment... This could take a moment.")
        
        try:
            # Get the path to test_classy.py
            test_script_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'test_classy.py')
            
            # Create a temporary directory for the test
            with tempfile.TemporaryDirectory() as temp_dir:
                # Run the test script with streaming output
                process = subprocess.Popen(
                    [sys.executable, test_script_path],
                    stdout=subprocess.PIPE,
                    stderr=subprocess.STDOUT,
                    text=True,
                    bufsize=1,
                    cwd=temp_dir
                )
                
                # Create a placeholder for the current line
                current_line_placeholder = st.empty()
                
                # Collect output in the background while showing just the last line
                output_text = ""
                for line in iter(process.stdout.readline, ''):
                    output_text += line
                    # Update the placeholder with just the current line (real-time feedback)
                    if line.strip():  # Only update for non-empty lines
                        current_line_placeholder.info(f"Current: {line.strip()}")
                
                # Get the final return code
                return_code = process.wait()
                
                # Clear the current line placeholder when done
                current_line_placeholder.empty()
                
                # Update status based on result
                if return_code == 0:
                    status_placeholder.success("✅ CLASS test completed successfully!")
                else:
                    status_placeholder.error(f"❌ CLASS test failed with return code: {return_code}")

                # Check for common errors
                if "ModuleNotFoundError" in output_text or "ImportError" in output_text:
                    st.error("❌ Python module import error detected. Make sure CLASS is properly installed.")

                if "CosmoSevereError" in output_text or "CosmoComputationError" in output_text:
                    st.error("❌ CLASS computation error detected.")
                
                # Display the full output in an expander (not expanded by default)
                with st.expander("View Full Test Log", expanded=False):
                    st.code(output_text)
                    # Check if the plot was generated
                    plot_path = os.path.join(temp_dir, 'cmb_temperature_spectrum.png')
                    if os.path.exists(plot_path):
                        # Show the plot if it was generated
                        st.subheader("Generated CMB Power Spectrum")
                        st.image(plot_path, use_container_width=True)
                    else:
                        st.warning("⚠️ No plot was generated")
                    
        except Exception as e:
            status_placeholder.error(f"Test failed with exception: {str(e)}")
            st.exception(e)  # Show the full exception for debugging
    
    st.markdown("---")  # Add a separator for better visual organization
    st.session_state.debug = st.checkbox("🔍 Show Debug Info")
    if st.button("🗑️ Reset Chat"):
        st.session_state.clear()
        st.rerun()

    if st.session_state.last_token_count > 0:
        st.markdown(f"🧮 **Last response token usage:** `{st.session_state.last_token_count}` tokens")

    # --- Display all saved plots in sidebar ---
    if "generated_plots" in st.session_state and st.session_state.generated_plots:
        with st.expander("📊 Plot Gallery", expanded=False):
            st.write("All plots generated during this session:")
            # Use a single column layout for the sidebar
            for i, plot_path in enumerate(st.session_state.generated_plots):
                if os.path.exists(plot_path):
                    st.image(plot_path, width=250, caption=os.path.basename(plot_path))
                    st.markdown("---")  # Add separator between plots

# --- Retrieval + Prompt Construction ---
def build_messages(context, question, system):
    system_msg = SystemMessage(content=system)
    human_msg = HumanMessage(content=f"Context:\n{context}\n\nQuestion:\n{question}")
    return [system_msg] + st.session_state.memory.messages + [human_msg]

def build_messages_rating(context, question, answer, system):
    system_msg = SystemMessage(content=system)
    human_msg = HumanMessage(content=f"Context:\n{context}\n\nQuestion:\n{question}\n\nAI Answer:\n{answer}")
    return [system_msg] + st.session_state.memory.messages + [human_msg]

def build_messages_refinement(context, question, answer, feedback, system):
    system_msg = SystemMessage(content=system)
    human_msg = HumanMessage(content=f"Context:\n{context}\n\nQuestion:\n{question}\n\nAI Answer:\n{answer}\n\nReviewer Feedback:\n{feedback}")
    return [system_msg] + st.session_state.memory.messages + [human_msg]

def format_memory_messages(memory_messages):
    formatted = ""
    for msg in memory_messages:
        role = msg.type.capitalize()  # 'human' -> 'Human'
        content = msg.content
        formatted += f"{role}: {content}\n\n"
    return formatted.strip()
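
# Example (illustrative): two stored turns render as
#   "Human: What is CLASS?\n\nAi: CLASS is a Boltzmann solver..."
# (msg.type is 'human'/'ai', so capitalize() yields 'Human'/'Ai')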


def retrieve_context(question):
    docs = st.session_state.vector_store.similarity_search(question, k=4)
    return "\n\n".join([doc.page_content for doc in docs])


# Set up code execution environment

class PlotAwareExecutor(LocalCommandLineCodeExecutor):
    def __init__(self, **kwargs):
        # Create a persistent plots directory if it doesn't exist
        # (tempfile is already imported at module level)
        plots_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'plots')
        os.makedirs(plots_dir, exist_ok=True)
        
        # Still use a temp dir for code execution
        temp_dir = tempfile.TemporaryDirectory()
        kwargs['work_dir'] = temp_dir.name
        super().__init__(**kwargs)
        self._temp_dir = temp_dir
        self._plots_dir = plots_dir

    def execute_code(self, code: str):
        # 1) Extract code from markdown
        match = re.search(r"```(?:python)?\n(.*?)```", code, re.DOTALL)
        cleaned = match.group(1) if match else code
        cleaned = cleaned.replace("plt.show()", "")
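        # Example of what the extraction above does (illustrative): given
        #   "Here is code:\n```python\nplt.plot([1, 2])\nplt.show()\n```"
        # `cleaned` keeps only the plt.plot(...) line, with the fence stripped
        # and the blocking plt.show() call removed so the script runs headless.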
        
        # Build a timestamped target path for any figure the snippet may produce
        timestamp = time.strftime("%Y-%m-%d-%H-%M-%S")
        plot_filename = f'plot_{timestamp}.png'
        plot_path = os.path.join(self._plots_dir, plot_filename)
        temp_plot_path = None
        
        for line in cleaned.split("\n"):
            if "plt.savefig" in line: 
                temp_plot_path = os.path.join(self._temp_dir.name, f'temporary_{timestamp}.png')
                cleaned = cleaned.replace(line, f"plt.savefig('{temp_plot_path}', dpi=300)")
                break
        else:
            # If there's a plot but no save, auto-insert save
            if "plt." in cleaned:
                temp_plot_path = os.path.join(self._temp_dir.name, f'temporary_{timestamp}.png')
                cleaned += f"\nplt.savefig('{temp_plot_path}')"

        # Create a temporary Python file to execute
        temp_script_path = os.path.join(self._temp_dir.name, f'temp_script_{timestamp}.py')
        with open(temp_script_path, 'w') as f:
            f.write(cleaned)
        
        full_output = ""
        try:
            # 2) Capture stdout using subprocess
            process = subprocess.Popen(
                [sys.executable, temp_script_path],
                stdout=subprocess.PIPE,
                stderr=subprocess.STDOUT,
                text=True,
                bufsize=1, 
                cwd=self._temp_dir.name
            )
            stdout, _ = process.communicate()

            # 3) Collect the output (stderr was merged into stdout above, so the
            # subprocess returns a single combined stream)
            if stdout:
                full_output += f"STDOUT:\n{stdout}\n"
                
            # Copy plot from temp to persistent location if it exists
            if temp_plot_path and os.path.exists(temp_plot_path):
                import shutil
                shutil.copy2(temp_plot_path, plot_path)
                # Initialize the plots list if it doesn't exist
                if "generated_plots" not in st.session_state:
                    st.session_state.generated_plots = []
                # Add the persistent plot path to session state
                st.session_state.generated_plots.append(plot_path)

        except Exception:
            import traceback
            full_output += f"STDERR:\n{traceback.format_exc()}\n"

        return full_output, plot_path

# Example instantiation:
executor = PlotAwareExecutor(timeout=10)
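
# Usage sketch (illustrative, not executed at import time): the executor takes a
# fenced snippet and returns the combined subprocess output plus the path where
# any generated plot is persisted.
#
#   out, plot = executor.execute_code(
#       "```python\nimport matplotlib.pyplot as plt\nplt.plot([0, 1])\n```"
#   )
#   print(out)            # STDOUT/STDERR of the run
#   os.path.exists(plot)  # True only if a figure was actually produced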

# Global agent configurations
initial_config = LLMConfig(
    api_type="openai", 
    model=st.session_state.selected_model,
    temperature=0.2,  # Low temperature for consistent initial responses
    api_key=api_key,
)

review_config = LLMConfig(
    api_type="openai", 
    model=st.session_state.selected_model, 
    temperature=0.7,  # Higher temperature for creative reviews
    api_key=api_key,
    response_format=Feedback
)

# typo_config = LLMConfig(
#     api_type="openai", 
#     model=st.session_state.selected_model, 
#     temperature=0.1,  # Very low temperature for precise code corrections
#     api_key=api_key,
# )

formatting_config = LLMConfig(
    api_type="openai", 
    model=st.session_state.selected_model, 
    temperature=0.3,  # Moderate temperature for formatting
    api_key=api_key,
)

code_execution_config = LLMConfig(
    api_type="openai", 
    model=st.session_state.selected_model, 
    temperature=0.1,  # Very low temperature for code execution
    api_key=api_key,
)

# Global agent instances with updated system messages
initial_agent = ConversableAgent(
    name="initial_agent",
    system_message=f"""
{Initial_Agent_Instructions}""",
    human_input_mode="NEVER",
    llm_config=initial_config
)

review_agent = ConversableAgent(
    name="review_agent",
    system_message=f"""{Review_Agent_Instructions}""",
    human_input_mode="NEVER",
    llm_config=review_config
)

# typo_agent = ConversableAgent(
#     name="typo_agent",
#     system_message=f"""You are the typo and code correction agent. Your task is to:
# 1. Fix any typos or grammatical errors
# 2. Correct any code issues
# 3. Ensure proper formatting
# 4. Maintain the original meaning while improving clarity
# 5. Verify plots are saved to disk (not using show())
# 6. PRESERVE all code blocks exactly as they are unless there are actual errors
# 7. If no changes are needed, keep the original code blocks unchanged

# # {Typo_Agent_Instructions}""",
# #     human_input_mode="NEVER",
# #     llm_config=typo_config
# # )

formatting_agent = ConversableAgent(
    name="formatting_agent",
    system_message="""{Formatting_Agent_Instructions}""",
    human_input_mode="NEVER",
    llm_config=formatting_config
)

code_executor = ConversableAgent(
    name="code_executor",
    system_message="""{Code_Execution_Agent_Instructions}""",
    human_input_mode="NEVER",
    llm_config=code_execution_config,
    code_execution_config={"executor": executor},
    max_consecutive_auto_reply=50
)

def call_ai(context, user_input):
    if mode_is_fast:
        messages = build_messages(context, user_input, Initial_Agent_Instructions)
        response = st.session_state.llm.invoke(messages)
        return Response(content=response.content)
    else:
        # New Swarm Workflow for detailed mode
        st.markdown("Thinking (Swarm Mode)... ")

        # Format the conversation history for context
        conversation_history = format_memory_messages(st.session_state.memory.messages)

        # 1. Initial Agent generates the draft
        st.markdown("Generating initial draft...")
        chat_result_1 = initial_agent.initiate_chat(
            recipient=initial_agent,
            message=f"Conversation history:\n{conversation_history}\n\nContext from documents: {context}\n\nUser question: {user_input}",
            max_turns=1,
            summary_method="last_msg"
        )
        draft_answer = chat_result_1.summary
        if st.session_state.debug:
            st.session_state.debug_messages.append(("Initial Draft", draft_answer))

        # 2. Review Agent critiques the draft
        st.markdown("Reviewing draft...")
        chat_result_2 = review_agent.initiate_chat(
            recipient=review_agent,
            message=f"Conversation history:\n{conversation_history}\n\nPlease review this draft answer:\n{draft_answer}",
            max_turns=1,
            summary_method="last_msg"
        )
        review_feedback = chat_result_2.summary
        if st.session_state.debug:
            st.session_state.debug_messages.append(("Review Feedback", review_feedback))

        # # 3. Typo Agent corrects the draft
        # st.markdown("Checking for typos...")
        # chat_result_3 = typo_agent.initiate_chat(
        #     recipient=typo_agent,
        #     message=f"Original draft: {draft_answer}\n\nReview feedback: {review_feedback}",
        #     max_turns=1,
        #     summary_method="last_msg"
        # )
        # typo_corrected_answer = chat_result_3.summary
        # if st.session_state.debug: st.text(f"Typo-Corrected Answer:\n{typo_corrected_answer}")

        # 4. Formatting Agent formats the final answer
        st.markdown("Formatting final answer...")
        chat_result_4 = formatting_agent.initiate_chat(
            recipient=formatting_agent,
            message=f"""Please format this answer while preserving any code blocks:
                {draft_answer}""",
            max_turns=1,
            summary_method="last_msg"
        )
        formatted_answer = chat_result_4.summary
        if st.session_state.debug:
            st.session_state.debug_messages.append(("Formatted Answer", formatted_answer))

        # If the answer contains code, append a note about how to execute it
        if "```python" in formatted_answer:
            formatted_answer += "\n\n> 💡 **Note**: This answer contains code. If you want to execute it, type 'execute!' in the chat."
        return Response(content=formatted_answer)


# --- Chat Input ---
user_input = st.chat_input("Type your prompt here...")

# --- Display Full Chat History ---
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        # Check if this message contains a plot path marker
        if "PLOT_PATH:" in message["content"]:
            # Split content into text and plot path
            parts = message["content"].split("PLOT_PATH:")
            # Display the text part
            st.markdown(parts[0])
            # Display each plot path
            for plot_info in parts[1:]:
                plot_path = plot_info.split('\n')[0].strip()
                if os.path.exists(plot_path):
                    st.image(plot_path, width=700)
        else:
            st.markdown(message["content"])

# --- Process New Prompt ---
if user_input:
    # Show user input immediately
    st.session_state.messages.append({"role": "user", "content": user_input})
    with st.chat_message("user"):
        st.markdown(user_input)

    st.session_state.memory.add_user_message(user_input)
    context = retrieve_context(user_input)
    
    # Count prompt tokens with tiktoken (best effort; falls back to 0 if unavailable)
    try:
        import tiktoken
        enc = tiktoken.encoding_for_model("gpt-4")
        st.session_state.last_token_count = len(enc.encode(user_input))
    except Exception:
        st.session_state.last_token_count = 0
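
    # Note (illustrative): encoding_for_model("gpt-4") returns the cl100k_base
    # tokenizer, while gpt-4o-family models use o200k_base, so the count above
    # is an approximation. An exact count for gpt-4o, assuming tiktoken >= 0.7:
    #
    #   enc = tiktoken.get_encoding("o200k_base")
    #   n_tokens = len(enc.encode(user_input))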

    # Stream assistant response
    with st.chat_message("assistant"):
        stream_box = st.empty()
        stream_handler = StreamHandler(stream_box)

        # Second initialization with streaming
        st.session_state.llm = ChatOpenAI(
                model_name=st.session_state.selected_model,
                streaming=True,
                callbacks=[stream_handler],
                openai_api_key=api_key,
                temperature=0.2
        )

        # Check if this is an execution request
        if user_input.strip().lower() == "execute!":
            # Find the last assistant message containing code
            last_assistant_message = None
            for message in reversed(st.session_state.messages):
                if message["role"] == "assistant" and "```" in message["content"]:
                    last_assistant_message = message["content"]
                    break
            
            if last_assistant_message:
                st.markdown("Executing code...")
                st.info("🚀 Executing cleaned code...")
                #chat_result = code_executor.initiate_chat(
                #    recipient=code_executor,
                #    message=f"Please execute this code:\n{last_assistant_message}",
                #    max_turns=1,
                #    summary_method="last_msg"
                #)
                #execution_output = chat_result.summary
                execution_output, plot_path = executor.execute_code(last_assistant_message)
                st.subheader("Execution Output")
                st.text(execution_output)  # now contains both STDOUT and STDERR
                
                if os.path.exists(plot_path):
                    st.success("✅ Plot generated successfully!")
                    # Display the plot
                    st.image(plot_path, width=700)
                else:
                    st.warning("⚠️ No plot was generated")
                
                # Check for errors and iterate if needed
                max_iterations = 3  # Maximum number of iterations to prevent infinite loops
                current_iteration = 0
                has_errors = any(error_indicator in execution_output for error_indicator in ["Traceback", "Error:", "Exception:", "TypeError:", "ValueError:", "NameError:", "SyntaxError:", "Error in Class"])

                while has_errors and current_iteration < max_iterations:
                    current_iteration += 1
                    st.error(f"Previous error: {execution_output}")  # Show the actual error message
                    st.info(f"🔧 Fixing errors (attempt {current_iteration}/{max_iterations})...")

                    # Get new review with error information
                    review_message = f"""
                    Previous answer had errors during execution:
                    {execution_output}

                    Please review and suggest fixes for this answer. IMPORTANT: Preserve all code blocks exactly as they are, only fix actual errors:
                    {last_assistant_message}
                    """
                    chat_result_2 = review_agent.initiate_chat(
                        recipient=review_agent,
                        message=review_message,
                        max_turns=1,
                        summary_method="last_msg"
                    )
                    review_feedback = chat_result_2.summary
                    if st.session_state.debug:
                        st.session_state.debug_messages.append(("Error Review Feedback", review_feedback))

                    # Get corrected version
                    chat_result_3 = initial_agent.initiate_chat(
                        recipient=initial_agent,
                        message=f"""Original answer: {last_assistant_message}
                        Review feedback with error fixes: {review_feedback}
                        IMPORTANT: Only fix actual errors in the code blocks. Preserve all working code exactly as it is.""",
                        max_turns=1,
                        summary_method="last_msg"
                    )
                    corrected_answer = chat_result_3.summary
                    if st.session_state.debug:
                        st.session_state.debug_messages.append(("Corrected Answer", corrected_answer))

                    # Format the corrected answer
                    chat_result_4 = formatting_agent.initiate_chat(
                        recipient=formatting_agent,
                        message=f"""Please format this corrected answer while preserving all code blocks:
                        {corrected_answer}
                        """,
                        max_turns=1,
                        summary_method="last_msg"
                    )
                    formatted_answer = chat_result_4.summary
                    if st.session_state.debug:
                        st.session_state.debug_messages.append(("Formatted Corrected Answer", formatted_answer))

                    # Execute the corrected code
                    st.info("🚀 Executing corrected code...")
                    #chat_result = code_executor.initiate_chat(
                    #    recipient=code_executor,
                    #    message=f"Please execute this corrected code:\n{formatted_answer}",
                    #    max_turns=1,
                    #    summary_method="last_msg"
                    #)
                    #execution_output = chat_result.summary
                    execution_output, plot_path = executor.execute_code(formatted_answer)
                    st.subheader("Execution Output")
                    st.text(execution_output)  # now contains both STDOUT and STDERR
                    
                    if os.path.exists(plot_path):
                        st.success("✅ Plot generated successfully!")
                        # Display the plot
                        st.image(plot_path, width=700)
                    else:
                        st.warning("⚠️ No plot was generated")
                    
                    if st.session_state.debug:
                        st.session_state.debug_messages.append(("Execution Output", execution_output))

                    # Re-check the NEW output for errors before judging this attempt
                    # (previously a stale flag from the prior run was tested here)
                    has_errors = any(error_indicator in execution_output for error_indicator in ["Traceback", "Error:", "Exception:", "TypeError:", "ValueError:", "NameError:", "SyntaxError:", "Error in Class"])

                    # If the corrected code ran cleanly, or we've hit the iteration
                    # cap, build the chat response (including any plot)
                    if not has_errors or current_iteration == max_iterations:
                        final_answer = formatted_answer if formatted_answer else last_assistant_message
                        response_text = f"Execution completed successfully:\n{execution_output}\n\nThe following code was executed:\n```python\n{final_answer}\n```"

                        # Add plot path marker for rendering in the conversation
                        if os.path.exists(plot_path):
                            response_text += f"\n\nPLOT_PATH:{plot_path}\n"

                        response_text = f"After {current_iteration} correction attempt(s): " + response_text

                        # Set the response variable with our constructed text
                        response = Response(content=response_text)

                    # Use the formatted answer as the basis for the next iteration
                    last_assistant_message = formatted_answer

                if has_errors:
                    st.markdown("> ⚠️ **Note**: Some errors could not be fixed after multiple attempts. You can request changes by describing them in the chat.")
                    st.markdown(f"> ❌ Last execution message:\n{execution_output}")
                    response = Response(content=f"Execution completed with errors:\n{execution_output}")
                else:
                    # Check for common error indicators in the output
                    if any(error_indicator in execution_output for error_indicator in ["Traceback", "Error:", "Exception:", "TypeError:", "ValueError:", "NameError:", "SyntaxError:"]):
                        st.markdown("> ⚠️ **Note**: Code execution completed but with errors. You can request changes by describing them in the chat.")
                        st.markdown(f"> ❌ Execution message:\n{execution_output}")
                        response = Response(content=f"Execution completed with errors:\n{execution_output}")
                    else:
                        st.markdown(f"> ✅ Code executed successfully. Last execution message:\n{execution_output}")
                        
                        # Display the final code that was successfully executed
                        with st.expander("View Successfully Executed Code", expanded=False):
                            st.markdown(last_assistant_message)
                            
                        # Create a response message that includes the plot path
                        response_text = f"Execution completed successfully:\n{execution_output}\n\nThe following code was executed:\n```python\n{last_assistant_message}\n```"
                        
                        # Add plot path marker for rendering in the conversation
                        if os.path.exists(plot_path):
                            response_text += f"\n\nPLOT_PATH:{plot_path}\n"
                            
                        response = Response(content=response_text)
            else:
                response = Response(content="No code found to execute in the previous messages.")
        else:
            response = call_ai(context, user_input)
            if not mode_is_fast:
                st.markdown(response.content)

        st.session_state.memory.add_ai_message(response.content)
        st.session_state.messages.append({"role": "assistant", "content": response.content})

# --- Display Welcome Message (outside of sidebar) ---
# This ensures the welcome message appears in the main content area
if "llm_initialized" in st.session_state and st.session_state.llm_initialized and not st.session_state.greeted:
    # Create a chat message container for the welcome message
    with st.chat_message("assistant"):
        # Create empty container for streaming
        welcome_container = st.empty()
        
        # Set up the streaming handler
        welcome_stream_handler = StreamHandler(welcome_container)
        
        # Initialize streaming LLM
        streaming_llm = ChatOpenAI(
            model_name=st.session_state.selected_model,
            streaming=True,
            callbacks=[welcome_stream_handler],
            openai_api_key=api_key,
            temperature=0.2
        )
        
        # Generate the streaming welcome message
        greeting = streaming_llm.invoke([
            SystemMessage(content=Initial_Agent_Instructions),
            HumanMessage(content="Please greet the user and briefly explain what you can do as the CLASS code assistant.")
        ])
        
        # Save the completed message to history
        st.session_state.messages.append({"role": "assistant", "content": greeting.content})
        st.session_state.memory.add_ai_message(greeting.content)
        st.session_state.greeted = True

# --- Debug Info ---
if st.session_state.debug:
    with st.sidebar.expander("🛠️ Debug Information", expanded=True):
        # Create a container for debug messages
        debug_container = st.container()
        with debug_container:
            st.markdown("### Debug Messages")
            
            # Display all debug messages in a scrollable container
            for title, message in st.session_state.debug_messages:
                st.markdown(f"### {title}")
                st.markdown(message)
                st.markdown("---")
    
    with st.sidebar.expander("🛠️ Context Used"):
        if "context" in locals():
            st.markdown(context)
        else:
            st.markdown("No context retrieved yet.")