import os
import time

import pandas as pd
import streamlit as st
from langchain_community.graphs import Neo4jGraph
from neo4j import GraphDatabase

from ki_gen.planner import build_planner_graph
from ki_gen.prompts import get_initial_prompt
from ki_gen.utils import ConfigSchema, init_app, memory

# Set page config (must be the first Streamlit call in the script)
st.set_page_config(page_title="Key Issue Generator", layout="wide")

# Neo4j database configuration
NEO4J_URI = "neo4j+s://4985272f.databases.neo4j.io"
NEO4J_USERNAME = "neo4j"
NEO4J_PASSWORD = os.getenv("neo4j_password")

# API keys for LLM services, read from environment variables (the Groq key was
# dropped with the switch to Gemini)
OPENAI_API_KEY = os.getenv("openai_api_key")
GEMINI_API_KEY = os.getenv("gemini_api_key")
LANGSMITH_API_KEY = os.getenv("langsmith_api_key")
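# Note: GEMINI_API_KEY is not passed to init_app() below; the ki_gen package is
# assumed to read it directly from the environment, so only its presence is
# checked before a run starts.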

def verify_neo4j_connectivity():
    """Check the connection to the Neo4j database.

    Returns True on success, or an error string describing the failure.
    """
    try:
        # The driver context manager guarantees the driver is closed even if
        # verify_connectivity() raises.
        with GraphDatabase.driver(NEO4J_URI, auth=(NEO4J_USERNAME, NEO4J_PASSWORD)) as driver:
            driver.verify_connectivity()
        return True
    except Exception as e:
        return f"Error: {e}"

def load_config() -> ConfigSchema | None:
    """Load the run configuration with custom parameters.

    Returns None if the Neo4j database is unreachable.
    """
    custom_config = {
        "main_llm": "gemini-2.0-flash",
        "plan_method": "generation",
        "use_detailed_query": False,
        "cypher_gen_method": "guided",
        "validate_cypher": False,
        "summarize_model": "gemini-2.0-flash",
        "eval_method": "binary",
        "eval_threshold": 0.7,
        "max_docs": 15,
        "compression_method": "llm_lingua",
        "compress_rate": 0.33,
        "force_tokens": ["."],  # list format, as expected by the application
        "eval_model": "gemini-2.0-flash",
        "thread_id": "3",  # consider making this dynamic or user-specific
    }
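    # Hypothetical alternative (sketch, not wired in): derive a per-session
    # thread_id so concurrent Streamlit sessions do not share checkpointed
    # graph state, e.g.:
    #   import uuid
    #   if "thread_id" not in st.session_state:
    #       st.session_state["thread_id"] = str(uuid.uuid4())
    #   custom_config["thread_id"] = st.session_state["thread_id"]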
    
    # Attach the Neo4j graph object only if the database is reachable.
    status = verify_neo4j_connectivity()
    if status is not True:
        st.error(f"Neo4j connection issue: {status}")
        return None

    try:
        custom_config["graph"] = Neo4jGraph(
            url=NEO4J_URI,
            username=NEO4J_USERNAME,
            password=NEO4J_PASSWORD,
        )
    except Exception as e:
        st.error(f"Error creating Neo4jGraph object: {e}")
        return None

    # Wrap the parameters in a 'configurable' key, as LangGraph expects.
    return {"configurable": custom_config}


def generate_key_issues(user_query):
    """Generate key issues from Neo4j data for the given user query."""
    # Initialize the application with API keys (the Groq key is gone; Gemini is
    # picked up from the environment).
    init_app(
        openai_key=OPENAI_API_KEY,
        langsmith_key=LANGSMITH_API_KEY,
    )

    # Load configuration with custom parameters
    config = load_config()
    if not config or not config.get("configurable", {}).get("graph"):
        st.error("Failed to load configuration or connect to Neo4j. Cannot proceed.")
        return None, []
    
    # Create status containers
    plan_status = st.empty()
    plan_display = st.empty()
    retrieval_status = st.empty()
    processing_status = st.empty()
    
    # Build planner graph
    plan_status.info("Building planner graph...")
    # Pass the full config dictionary to build_planner_graph
    graph = build_planner_graph(memory, config) 
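    # `memory` (from ki_gen.utils) is assumed to be the LangGraph checkpointer;
    # it is what lets graph.get_state() / graph.update_state() below resume the
    # run across the interrupt points.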
    
    # Execute initial prompt generation
    plan_status.info(f"Generating plan for query: {user_query}")
    
    messages_content = []
    initial_prompt_data = get_initial_prompt(config, user_query)
    
    # Stream the initial plan generation. stream_mode="values" yields the full
    # graph state after each step, so the newest message is inspected each time.
    try:
        for event in graph.stream(initial_prompt_data, config, stream_mode="values"):
            if "messages" in event and event["messages"]:
                event["messages"][-1].pretty_print()
                messages_content.append(event["messages"][-1].content)
    except Exception as e:
        st.error(f"Error during initial graph stream: {e}")
        return None, []

    # Read the state captured at the first interrupt to retrieve the generated plan.
    try:
        state = graph.get_state(config)
        stored_plan = state.values.get('store_plan', [])
        if isinstance(stored_plan, list) and stored_plan:
            steps = list(range(1, len(stored_plan) + 1))
            plan_df = pd.DataFrame({'Plan steps': steps, 'Description': stored_plan})
            plan_status.success("Plan generation complete!")
            plan_display.dataframe(plan_df, use_container_width=True)
        else:
            plan_status.warning("Plan not found or empty in graph state after generation.")
            plan_display.empty()  # Clear the display if there is no plan
    except Exception as e:
        st.error(f"Error getting graph state or displaying plan: {e}")
        return None, []


    # Continue plan execution to retrieve documents. Streaming with None as the
    # input resumes the graph from its last interrupt.
    retrieval_status.info("Retrieving documents...")
    try:
        for event in graph.stream(None, config, stream_mode="values"):
            if "messages" in event and event["messages"]:
                event["messages"][-1].pretty_print()
                messages_content.append(event["messages"][-1].content)
    except Exception as e:
        st.error(f"Error during document retrieval stream: {e}")
        return None, []


    # Read the state captured at the document-retrieval interrupt.
    try:
        snapshot = graph.get_state(config)
        valid_docs_retrieved = snapshot.values.get('valid_docs', [])
        doc_count = len(valid_docs_retrieved) if isinstance(valid_docs_retrieved, list) else 0
        retrieval_status.success(f"Retrieved {doc_count} documents")

        # --- Human validation / processing steps ---
        # Interaction logic would go here if manual validation is desired; for
        # now, default processing steps are set and validation is marked done.
        processing_status.info("Processing documents...")
        process_steps = ["summarize"]  # Default: just summarize

        # Mark human validation as complete and set the processing steps
        # *before* the next stream call, which triggers processing.
        graph.update_state(config, {'human_validated': True, 'process_steps': process_steps})
    except Exception as e:
        st.error(f"Error getting state after retrieval or setting up processing: {e}")
        return None, []
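    # Hypothetical interactive alternative (sketch, not wired in): let the user
    # choose the processing steps with Streamlit widgets before resuming, e.g.:
    #   process_steps = st.multiselect(
    #       "Processing steps", ["summarize", "compress"], default=["summarize"]
    #   )
    #   if st.button("Confirm processing steps"):
    #       graph.update_state(config, {"human_validated": True,
    #                                   "process_steps": process_steps})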


    # Resume execution with document processing.
    try:
        for event in graph.stream(None, config, stream_mode="values"):
            if "messages" in event and event["messages"]:
                event["messages"][-1].pretty_print()
                messages_content.append(event["messages"][-1].content)
    except Exception as e:
        st.error(f"Error during document processing stream: {e}")
        return None, []

    # Read the final state after processing and extract the results.
    try:
        final_snapshot = graph.get_state(config)
        processing_status.success("Document processing complete!")

        # The last message in the final state is taken as the final result.
        final_result = None
        if "messages" in final_snapshot.values and final_snapshot.values["messages"]:
            final_result = final_snapshot.values["messages"][-1].content

        # valid_docs may now hold processed summaries rather than raw documents.
        valid_docs_final = final_snapshot.values.get('valid_docs', [])
        if not isinstance(valid_docs_final, list):
            valid_docs_final = []

        return final_result, valid_docs_final
    except Exception as e:
        st.error(f"Error getting final state or extracting results: {e}")
        return None, []
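
# generate_key_issues() returns (final_result, valid_docs): final_result is the
# content of the last message in the final graph state (None on failure), and
# valid_docs may hold original document dicts or string summaries, which is why
# the display loop below branches on the document type.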

# App header
st.title("Key Issue Generator")
st.write("Generate key issues from a Neo4j knowledge graph using advanced language models.")

# Check database connectivity and report it in the sidebar
connectivity_status = verify_neo4j_connectivity()
st.sidebar.header("Database Status")
if connectivity_status is True:
    st.sidebar.success("Connected to Neo4j database")
else:
    st.sidebar.error(f"Database connection issue: {connectivity_status}")

# User input section
st.header("Enter Your Query")
user_query = st.text_area(
    "What would you like to explore?",
    "What are the main challenges in AI adoption for healthcare systems?",
    height=100,
)

# Process button
if st.button("Generate Key Issues", type="primary"):
    # Require all API keys and database credentials before starting.
    if not OPENAI_API_KEY or not GEMINI_API_KEY or not LANGSMITH_API_KEY or not NEO4J_PASSWORD:
        st.error("Required API keys (OpenAI, Gemini, LangSmith) or database credentials are missing. Please check your environment variables.")
    elif connectivity_status is not True:  # Re-check the DB connection before starting
        st.error(f"Cannot start: Neo4j connection issue: {connectivity_status}")
    else:
        with st.spinner("Processing your query..."):
            start_time = time.time()
            # Call the main generation function
            final_result, valid_docs = generate_key_issues(user_query) 
            end_time = time.time()
            
            if final_result is not None:  # None signals a failure upstream
                # Display execution time
                st.sidebar.info(f"Total execution time: {round(end_time - start_time, 2)} seconds")
                
                # Display final result
                st.header("Generated Key Issues")
                st.markdown(final_result)
                
                # Option to download results
                st.download_button(
                    label="Download Results",
                    data=final_result, # Ensure final_result is string data
                    file_name="key_issues_results.txt",
                    mime="text/plain"
                )
                
                # Display retrieved/processed documents in expandable section
                if valid_docs:
                    with st.expander("View Processed Documents"):
                        for i, doc in enumerate(valid_docs):
                            st.markdown(f"### Document {i+1}")
                            # Handle doc format (could be a string summary or the original dict)
                            if isinstance(doc, dict):
                                for key in doc:
                                    st.markdown(f"**{key}**: {doc[key]}")
                            elif isinstance(doc, str):
                                st.markdown(doc)  # Display string directly if it's a summary
                            else:
                                st.markdown(str(doc))  # Fallback for other types
                            st.divider()
            else:
                # Specific error messages are shown inside generate_key_issues;
                # this is a catch-all in case nothing more specific was reported.
                st.error("Processing failed. Please check the console/logs for errors.")


# Help information in sidebar
with st.sidebar:
    st.header("About")
    st.info("""
    This application uses advanced language models (like Google Gemini) to analyze a Neo4j knowledge graph 
    and generate key issues based on your query. The process involves:
    
    1. Creating a plan based on your query
    2. Retrieving relevant documents from the database
    3. Processing and summarizing the information
    4. Generating a comprehensive response
    """)