Agent_Derma_Doc / nodes.py
Jyo-K's picture
Upload 7 files
d1c266e verified
from langchain_core.messages import AIMessage, HumanMessage
from state import WorkflowState
from tools import tool_analyze_skin_image, tool_fetch_disease_info
from llms import (
symptom_classifier_chain,
question_generation_chain,
summary_generation_chain
)
def node_analyze_image(state: WorkflowState) -> WorkflowState:
    """Run the skin-image classifier on the uploaded image and record the result.

    On success, the tool's output is stored in ``state['disease_prediction']``.

    On failure, an explanatory ``AIMessage`` is appended to
    ``state['chat_history']`` and ``state['final_diagnosis']`` is set to
    ``"Error"``, so that ``router_check_image_analysis`` returns ``"end_error"``
    instead of continuing to ``"fetch_symptoms"``.

    Failure cases:
      * no image in the state (missing ``"image"`` key or a falsy value);
      * the tool's result contains the substring ``"Error:"``.

    Args:
        state: The workflow state; mutated in place.

    Returns:
        The same ``state`` object.
    """
    print("--- Node: Analyzing Image ---")
    image = state.get("image")
    if not image:
        state['chat_history'].append(AIMessage(
            content="Please upload an image first."
        ))
        # Flag the failure explicitly. Without this the router would continue
        # to fetch_symptoms even though no disease_prediction was produced.
        state['final_diagnosis'] = "Error"
        return state

    prediction_result = tool_analyze_skin_image.invoke({"image": image})
    # NOTE(review): assumes the tool reports failures as a string containing
    # "Error:" rather than raising — confirm against tools.py.
    if "Error:" in prediction_result:
        state['chat_history'].append(AIMessage(content=prediction_result))
        state['final_diagnosis'] = "Error"
        return state

    state['disease_prediction'] = prediction_result
    return state
def node_fetch_symptoms(state: WorkflowState) -> WorkflowState:
    """Look up the symptom checklist and treatment text for the predicted disease.

    Queries ``tool_fetch_disease_info`` with ``state['disease_prediction']``.

    * If the lookup result contains ``"error"``, that error message is appended
      to the chat history and ``final_diagnosis`` is set to ``"Error"``.
    * Otherwise the symptom list (default ``[]``), the treatment text (default
      ``"No treatment info available."``), the question cursor (``0``) and an
      empty confirmed-symptom list are written to the state.

    Args:
        state: The workflow state; mutated in place.

    Returns:
        The same ``state`` object.
    """
    predicted = state['disease_prediction']
    print(f"--- Node: Fetching Symptoms for {predicted} ---")
    lookup = tool_fetch_disease_info.invoke({"disease_name": predicted})

    if "error" not in lookup:
        state['symptoms_to_check'] = lookup.get("symptoms", [])
        state['treatment_info'] = lookup.get("treatment", "No treatment info available.")
        # Reset the per-run question bookkeeping.
        state['current_symptom_index'] = 0
        state['symptoms_confirmed'] = []
        if not state['symptoms_to_check']:
            # Informational only; routing is decided by router_should_ask_symptoms.
            print("No symptoms found to check. Proceeding to final response.")
        return state

    state['chat_history'].append(AIMessage(content=lookup['error']))
    state['final_diagnosis'] = "Error"
    return state
def node_ask_symptom_question(state: WorkflowState) -> WorkflowState:
    """Pose the next pending symptom question and advance the question cursor.

    Picks ``state['symptoms_to_check'][state['current_symptom_index']]`` and has
    ``question_generation_chain`` phrase a question about it. The question is
    appended to the chat history as an ``AIMessage``, and the index is then
    incremented by one.

    The index is not bounds-checked here. The routers in this module
    (``router_should_ask_symptoms`` / ``router_should_continue_asking``) appear
    to be what keeps it in range.

    Args:
        state: The workflow state; mutated in place.

    Returns:
        The same ``state`` object.
    """
    position = state['current_symptom_index']
    print(f"--- Node: Asking Symptom Question {position} ---")
    pending_symptom = state['symptoms_to_check'][position]
    prompt_text = question_generation_chain.invoke({"symptom": pending_symptom})
    state['chat_history'].append(AIMessage(content=prompt_text))
    # Advance so the response handler can find the symptom at position - 1.
    state['current_symptom_index'] = position + 1
    return state
def node_process_user_response(state: WorkflowState) -> WorkflowState:
    """Classify the user's reply to the last symptom question and record it.

    The most recent ``state['chat_history']`` entry is taken as the user's reply.
    The symptom it answers is ``symptoms_to_check[current_symptom_index - 1]``,
    because ``node_ask_symptom_question`` increments the index after asking.

    The reply is sent to ``symptom_classifier_chain``. If the result's
    ``"classification"`` field is ``"yes"`` (compared case- and
    whitespace-insensitively), the symptom is appended to
    ``state['symptoms_confirmed']``. Any other label, or a failed
    classification, leaves the symptom unconfirmed.

    If no question has been asked yet (index <= 0), the state is returned
    unchanged rather than wrapping around to the last symptom.

    Args:
        state: The workflow state; mutated in place.

    Returns:
        The same ``state`` object.
    """
    print("--- Node: Processing User Response ---")
    index = state['current_symptom_index']
    if index <= 0:
        # index - 1 would be -1 and silently select the *last* symptom.
        print("No symptom question has been asked yet; ignoring response.")
        return state

    last_human_message = state['chat_history'][-1].content
    last_asked_symptom = state['symptoms_to_check'][index - 1]

    try:
        classification = symptom_classifier_chain.invoke(
            {"last_human_message": last_human_message}
        )
        # Normalise so LLM variants like "Yes" / " yes\n" are still recognised.
        label = str(classification.get("classification", "")).strip().lower()
    except Exception as e:
        # Deliberate best-effort: a classifier/parsing failure must not crash
        # the conversation; the symptom is simply left unconfirmed.
        print(f"Error classifying user response: {e}. Assuming 'unclear'.")
        return state

    if label == "yes":
        print(f"User confirmed symptom: {last_asked_symptom}")
        state['symptoms_confirmed'].append(last_asked_symptom)
    else:
        print(f"User denied symptom: {last_asked_symptom}")
    return state
def node_generate_final_response(state: WorkflowState) -> WorkflowState:
    """Generate the closing summary (with disclaimer) and mark the run complete.

    The following inputs are passed to ``summary_generation_chain``:
      * the predicted disease;
      * the confirmed symptoms, comma-joined, or ``"None confirmed"``;
      * the treatment text;
      * a fixed disclaimer.

    The result is appended to the chat history as an ``AIMessage`` and
    ``final_diagnosis`` is set to ``"Complete"``.

    If an earlier node already set ``final_diagnosis`` to ``"Error"`` (e.g.
    ``node_fetch_symptoms`` after a failed lookup), its error message is
    already in the chat history. In that case no summary is generated and the
    state is returned unchanged.

    Args:
        state: The workflow state; mutated in place.

    Returns:
        The same ``state`` object.
    """
    print("--- Node: Generating Final Response ---")
    if state.get("final_diagnosis") == "Error":
        # After a failed lookup, treatment_info / symptoms_confirmed may never
        # have been populated. Summarising here would either raise KeyError or
        # report a bogus "Complete" diagnosis over the error.
        return state

    disclaimer = (
        "\n\n**DISCLAIMER:**\n"
        "I am just a dumb agent, not a medical professional. "
        "This is a side project for learning purposes. "
        "Please **DO NOT** take this information at face value. "
        "Consult a real doctor or dermatologist for any medical concerns."
    )
    confirmed = state.get('symptoms_confirmed') or []
    summary = summary_generation_chain.invoke({
        "disease": state['disease_prediction'],
        "symptoms": ", ".join(confirmed) or "None confirmed",
        "treatment": state.get('treatment_info', "No treatment info available."),
        "disclaimer": disclaimer,
    })
    state['chat_history'].append(AIMessage(content=summary))
    state['final_diagnosis'] = "Complete"
    return state
def router_should_ask_symptoms(state: WorkflowState) -> str:
    """Decide whether there are any symptom questions to ask.

    Returns:
        ``"ask_symptom_question"`` when ``state["symptoms_to_check"]`` is present
        and truthy; ``"generate_final_response"`` otherwise (including when the
        key is missing).
    """
    has_symptoms = bool(state.get("symptoms_to_check"))
    return "ask_symptom_question" if has_symptoms else "generate_final_response"
def router_should_continue_asking(state: WorkflowState) -> str:
    """Decide, after a user reply, whether another symptom question remains.

    Returns:
        ``"generate_final_response"`` once ``current_symptom_index`` has reached
        (or passed) the length of ``symptoms_to_check``; otherwise
        ``"ask_symptom_question"``.
    """
    asked = state['current_symptom_index']
    total = len(state['symptoms_to_check'])
    if asked >= total:
        return "generate_final_response"
    return "ask_symptom_question"
def router_check_image_analysis(state: WorkflowState) -> str:
    """Route on the outcome of image analysis.

    Returns:
        ``"end_error"`` when ``state["final_diagnosis"]`` equals ``"Error"``;
        ``"fetch_symptoms"`` otherwise (including when the key is absent).
    """
    analysis_failed = state.get("final_diagnosis") == "Error"
    return "end_error" if analysis_failed else "fetch_symptoms"