from langchain_core.messages import AIMessage, HumanMessage

from state import WorkflowState
from tools import tool_analyze_skin_image, tool_fetch_disease_info
from llms import (
    symptom_classifier_chain,
    question_generation_chain,
    summary_generation_chain,
)

# Values written to state['final_diagnosis'] and read by the routers below.
_STATUS_ERROR = "Error"
_STATUS_COMPLETE = "Complete"

# Fallback text used when the disease-info tool provides no treatment entry.
_NO_TREATMENT_INFO = "No treatment info available."

# Passed to summary_generation_chain with every final summary.
_DISCLAIMER = (
    "\n\n**DISCLAIMER:**\n"
    "I am just a dumb agent, not a medical professional. "
    "This is a side project for learning purposes. "
    "Please **DO NOT** take this information at face value. "
    "Consult a real doctor or dermatologist for any medical concerns."
)


def node_analyze_image(state: WorkflowState) -> WorkflowState:
    """Analyze the user's uploaded image and record the prediction.

    Reads ``state['image']``. On success, stores the classifier tool's result
    in ``state['disease_prediction']``. If no image is present, or the tool's
    result contains ``"Error:"``, appends an explanatory ``AIMessage`` to
    ``state['chat_history']`` and sets ``state['final_diagnosis']`` to
    ``"Error"`` so that ``router_check_image_analysis`` returns
    ``"end_error"``.

    Returns:
        The same (mutated) state object.
    """
    print("--- Node: Analyzing Image ---")
    image = state.get("image")
    if not image:
        state['chat_history'].append(AIMessage(
            content="Please upload an image first."
        ))
        # Flag the failure; without it the router would continue to
        # node_fetch_symptoms with no disease_prediction available.
        state['final_diagnosis'] = _STATUS_ERROR
        return state

    prediction_result = tool_analyze_skin_image.invoke({"image": image})
    # NOTE(review): the tool appears to report failures as text containing
    # "Error:" - confirm against tools.tool_analyze_skin_image.
    if "Error:" in prediction_result:
        state['chat_history'].append(AIMessage(content=prediction_result))
        state['final_diagnosis'] = _STATUS_ERROR
        return state

    state['disease_prediction'] = prediction_result
    # If the state object is reused across runs, drop an error flag left by
    # an earlier failed attempt so the router does not end this run.
    if state.get('final_diagnosis') == _STATUS_ERROR:
        state['final_diagnosis'] = None
    return state


def node_fetch_symptoms(state: WorkflowState) -> WorkflowState:
    """Fetch the symptom list and treatment text for the predicted disease.

    The symptom-tracking fields (``symptoms_to_check``,
    ``symptoms_confirmed``, ``current_symptom_index``) are reset before the
    tool is called so values from a previous run never leak into this one.
    If the tool result contains an ``"error"`` key, its message is appended
    to ``chat_history`` and ``final_diagnosis`` is set to ``"Error"``; the
    empty ``symptoms_to_check`` then makes ``router_should_ask_symptoms``
    return ``"generate_final_response"``, and
    ``node_generate_final_response`` skips the summary on error.

    Returns:
        The same (mutated) state object.
    """
    disease = state['disease_prediction']
    print(f"--- Node: Fetching Symptoms for {disease} ---")

    state['symptoms_to_check'] = []
    state['symptoms_confirmed'] = []
    state['current_symptom_index'] = 0

    info = tool_fetch_disease_info.invoke({"disease_name": disease})
    # NOTE(review): assumes the tool returns a dict with optional "error",
    # "symptoms" and "treatment" keys - confirm against tools.py.
    if "error" in info:
        state['chat_history'].append(AIMessage(content=info['error']))
        state['final_diagnosis'] = _STATUS_ERROR
        return state

    state['symptoms_to_check'] = info.get("symptoms") or []
    state['treatment_info'] = info.get("treatment", _NO_TREATMENT_INFO)

    if not state['symptoms_to_check']:
        print("No symptoms found to check. Proceeding to final response.")
    return state


def node_ask_symptom_question(state: WorkflowState) -> WorkflowState:
    """Ask the user about the next unchecked symptom.

    Generates a question for ``symptoms_to_check[current_symptom_index]``,
    appends it to ``chat_history`` as an ``AIMessage`` and advances
    ``current_symptom_index`` by one, so after this node the index points one
    past the symptom that was just asked about.

    Callers (the routers) must ensure the index is in range.

    Returns:
        The same (mutated) state object.
    """
    print(f"--- Node: Asking Symptom Question {state['current_symptom_index']} ---")
    index = state['current_symptom_index']
    symptom = state['symptoms_to_check'][index]

    question = question_generation_chain.invoke({"symptom": symptom})
    state['chat_history'].append(AIMessage(content=question))
    state['current_symptom_index'] = index + 1
    return state


def node_process_user_response(state: WorkflowState) -> WorkflowState:
    """Classify the user's reply to the last symptom question.

    The last message in ``chat_history`` is sent to
    ``symptom_classifier_chain``; when its ``"classification"`` is ``"yes"``
    (case-insensitive) the symptom asked about most recently
    (``symptoms_to_check[current_symptom_index - 1]``) is appended to
    ``symptoms_confirmed``. Any other label, or a classifier failure, leaves
    the symptom unconfirmed.

    The node is a no-op when no question has been asked yet, the index is
    out of range, the history is empty, or the last message is one this
    agent sent (an ``AIMessage``), i.e. there is no user reply to classify.

    Returns:
        The same (mutated) state object.
    """
    print("--- Node: Processing User Response ---")
    history = state['chat_history']
    symptoms = state.get('symptoms_to_check') or []
    index = state.get('current_symptom_index', 0)

    # index 0 would wrap to symptoms[-1] and attribute the reply to the
    # wrong symptom; an index past the end would raise IndexError.
    if not history or not 1 <= index <= len(symptoms):
        print("No pending symptom question to process. Skipping.")
        return state
    if isinstance(history[-1], AIMessage):
        print("Last message is not a user reply. Skipping.")
        return state

    last_human_message = history[-1].content
    last_asked_symptom = symptoms[index - 1]

    try:
        classification = symptom_classifier_chain.invoke(
            {"last_human_message": last_human_message}
        )
        # NOTE(review): assumes the chain returns a dict such as
        # {"classification": "yes" | "no" | "unclear"} - confirm in llms.py.
        label = str(classification.get("classification", "")).strip().lower()
    except Exception as e:  # best-effort: an LLM failure must not end the chat
        print(f"Error classifying user response: {e}. Assuming 'unclear'.")
        return state

    if label == "yes":
        print(f"User confirmed symptom: {last_asked_symptom}")
        state['symptoms_confirmed'].append(last_asked_symptom)
    elif label == "no":
        print(f"User denied symptom: {last_asked_symptom}")
    else:
        print(f"Unclear response for symptom: {last_asked_symptom}. "
              "Treating as not confirmed.")
    return state


def node_generate_final_response(state: WorkflowState) -> WorkflowState:
    """Generate the final summary (with disclaimer) for the user.

    Sends the predicted disease, the confirmed symptoms (comma-joined, or
    ``"None confirmed"``), the treatment text and a fixed disclaimer to
    ``summary_generation_chain``, appends the result to ``chat_history`` and
    sets ``final_diagnosis`` to ``"Complete"``.

    If an earlier node already set ``final_diagnosis`` to ``"Error"`` (and
    appended its own error message), no summary is generated: the inputs it
    needs may not exist.

    Returns:
        The same (mutated) state object.
    """
    print("--- Node: Generating Final Response ---")
    if state.get('final_diagnosis') == _STATUS_ERROR:
        print("Earlier step failed; skipping summary generation.")
        return state

    confirmed = state.get('symptoms_confirmed') or []
    summary = summary_generation_chain.invoke({
        "disease": state['disease_prediction'],
        "symptoms": ", ".join(str(s) for s in confirmed) or "None confirmed",
        "treatment": state.get('treatment_info', _NO_TREATMENT_INFO),
        "disclaimer": _DISCLAIMER,
    })

    state['chat_history'].append(AIMessage(content=summary))
    state['final_diagnosis'] = _STATUS_COMPLETE
    return state


def router_should_ask_symptoms(state: WorkflowState) -> str:
    """Decide the next step after symptoms were fetched.

    Returns:
        ``"ask_symptom_question"`` if ``symptoms_to_check`` is non-empty,
        otherwise ``"generate_final_response"``.
    """
    if state.get("symptoms_to_check"):
        return "ask_symptom_question"
    return "generate_final_response"


def router_should_continue_asking(state: WorkflowState) -> str:
    """Decide the next step after a user's reply was processed.

    Returns:
        ``"ask_symptom_question"`` while ``current_symptom_index`` is below
        the number of symptoms to check, otherwise
        ``"generate_final_response"``.
    """
    index = state.get('current_symptom_index', 0)
    if index < len(state.get('symptoms_to_check') or []):
        return "ask_symptom_question"
    return "generate_final_response"


def router_check_image_analysis(state: WorkflowState) -> str:
    """Decide whether image analysis succeeded.

    Returns:
        ``"end_error"`` if ``final_diagnosis`` is ``"Error"``, otherwise
        ``"fetch_symptoms"``.
    """
    if state.get("final_diagnosis") == _STATUS_ERROR:
        return "end_error"
    return "fetch_symptoms"