Jyo-K commited on
Commit
d1c266e
·
verified ·
1 Parent(s): 8aac779

Upload 7 files

Browse files
Files changed (7) hide show
  1. app.py +181 -0
  2. graph.py +78 -0
  3. llms.py +76 -0
  4. merge.py +133 -0
  5. nodes.py +144 -0
  6. requirements.txt +17 -0
  7. state.py +20 -0
app.py ADDED
@@ -0,0 +1,181 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from dotenv import load_dotenv
3
+
4
+ load_dotenv()
5
+
6
+ import gradio as gr
7
+ from PIL import Image
8
+ from langchain_core.messages import HumanMessage, AIMessage
9
+
10
+ from graph import build_graph
11
+ from state import WorkflowState
12
+
13
+ diagnosis_graph = None
14
+ reply_graph = None
15
+ initialization_error = None
16
+
17
+ try:
18
+ diagnosis_graph, reply_graph = build_graph()
19
+ print("Gradio app and graphs initialized successfully.")
20
+ except Exception as e:
21
+ initialization_error = f"CRITICAL ERROR: Could not build graphs. {e}. Check API keys."
22
+ print(initialization_error)
23
+
24
+
25
+ def convert_history_to_langchain(chat_history):
26
+ """Converts Gradio history to Langchain message list."""
27
+ messages = []
28
+ for user_msg, ai_msg in chat_history:
29
+ if user_msg is not None:
30
+ if isinstance(user_msg, (dict, tuple)):
31
+ messages.append(HumanMessage(content="User uploaded an image."))
32
+ else:
33
+ messages.append(HumanMessage(content=user_msg))
34
+ if ai_msg is not None:
35
+ messages.append(AIMessage(content=ai_msg))
36
+ return messages
37
+
38
+ def reset_state_on_start():
39
+ """Fresh empty state for every new diagnosis."""
40
+ return WorkflowState(
41
+ image=None,
42
+ chat_history=[],
43
+ disease_prediction="",
44
+ symptoms_to_check=[],
45
+ symptoms_confirmed=[],
46
+ current_symptom_index=0,
47
+ treatment_info="",
48
+ final_diagnosis=""
49
+ )
50
+
51
+
52
+ def chat_fn(message: str, chat_history: list, agent_state: dict, img_upload: Image):
53
+ """
54
+ Handles user input and manages the agent's workflow state.
55
+ """
56
+ if initialization_error:
57
+ chat_history.append((message, initialization_error))
58
+ yield chat_history, {}, gr.update(value=None, interactive=True), gr.update(value="", interactive=True)
59
+ return
60
+
61
+ if not agent_state or agent_state.get("final_diagnosis"):
62
+ current_state = reset_state_on_start()
63
+ else:
64
+ current_state = agent_state
65
+
66
+ chat_history = chat_history or []
67
+
68
+
69
+ is_new_diagnosis = False
70
+ if img_upload and (message.lower().strip() == "start" or message == ""):
71
+ print("--- Running NEW diagnosis flow ---")
72
+ is_new_diagnosis = True
73
+ current_state = reset_state_on_start()
74
+ current_state["image"] = img_upload
75
+ graph_to_run = diagnosis_graph
76
+ chat_history.append([(img_upload,), None])
77
+
78
+ elif current_state.get("symptoms_to_check") and not current_state.get("final_diagnosis"):
79
+ print("--- Running REPLY symptom loop flow ---")
80
+ graph_to_run = reply_graph
81
+ chat_history.append([message, None])
82
+
83
+
84
+ else:
85
+
86
+ if message:
87
+ chat_history.append([message, None])
88
+ chat_history[-1][1] = "Hello! Please upload an image, then click 'Start Diagnosis'."
89
+
90
+ yield chat_history, agent_state, gr.update(value=img_upload, interactive=True), gr.update(value="", interactive=True)
91
+ return
92
+
93
+ current_state["chat_history"] = convert_history_to_langchain(chat_history)
94
+
95
+ try:
96
+ final_state = {}
97
+ for step in graph_to_run.stream(current_state, {"recursion_limit": 100}):
98
+
99
+ final_state = list(step.values())[0]
100
+
101
+ ai_response = final_state['chat_history'][-1].content
102
+ chat_history[-1][1] = ai_response
103
+
104
+ if final_state.get("final_diagnosis"):
105
+ print("--- Agent Flow ENDED ---")
106
+ yield chat_history, {}, gr.update(value=None, interactive=True), gr.update(value="", interactive=True)
107
+ else:
108
+ yield chat_history, final_state, gr.update(value=img_upload, interactive=False), gr.update(value="", interactive=True)
109
+
110
+ except Exception as e:
111
+ print(f"--- Graph Runtime Error --- \n{e}")
112
+ error_msg = f"A runtime error occurred: {e}. Please check the console."
113
+ chat_history[-1][1] = error_msg
114
+ yield chat_history, {}, gr.update(value=None, interactive=True), gr.update(value="", interactive=True)
115
+
116
+ def clear_all():
117
+ """Clears chat, state, and image."""
118
+ return [], {}, None, ""
119
+
120
+ with gr.Blocks(theme=gr.themes.Soft(), title="Agentic Skin AI") as demo:
121
+ gr.Markdown("# 🩺 Multimodal Agentic Skin Disease AI")
122
+ gr.Markdown(
123
+ "**Disclaimer:** This is a demo project and NOT a medical device. "
124
+ "Consult a real doctor for any medical concerns."
125
+ )
126
+ agent_state = gr.State({})
127
+
128
+ with gr.Row():
129
+ with gr.Column(scale=1):
130
+ gr.Markdown("### 1. Image Input")
131
+ img_upload = gr.Image(type="pil", label="Upload Skin Image", interactive=True)
132
+
133
+ btn_start = gr.Button("Start Diagnosis", variant="primary")
134
+ btn_clear = gr.Button("Clear All & Start New")
135
+
136
+ gr.Markdown(
137
+ "**Instructions:**\n"
138
+ "1. Upload an image.\n"
139
+ "2. Click **Start Diagnosis**.\n"
140
+ "3. Answer the agent's questions in the textbox."
141
+ )
142
+
143
+ with gr.Column(scale=2):
144
+ gr.Markdown("### 2. Agent Conversation")
145
+ chatbot = gr.Chatbot(label="Agent Conversation", height=500, bubble_full_width=False, avatar_images=None)
146
+ txt_msg = gr.Textbox(
147
+ label="Your message (Yes / No / etc.)",
148
+ placeholder="Answer the agent's questions here...",
149
+ interactive=True
150
+ )
151
+
152
+ txt_msg.submit(
153
+ fn=chat_fn,
154
+ inputs=[txt_msg, chatbot, agent_state, img_upload],
155
+ outputs=[chatbot, agent_state, img_upload, txt_msg]
156
+ )
157
+
158
+ btn_start.click(
159
+ fn=chat_fn,
160
+ inputs=[gr.Textbox(value="start", visible=False), chatbot, agent_state, img_upload],
161
+ outputs=[chatbot, agent_state, img_upload, txt_msg]
162
+ )
163
+
164
+ btn_clear.click(
165
+ fn=clear_all,
166
+ inputs=None,
167
+ outputs=[chatbot, agent_state, img_upload, txt_msg]
168
+ )
169
+
170
+ img_upload.upload(
171
+ fn=lambda: ([], {}, "Click 'Start Diagnosis' to begin."),
172
+ inputs=None,
173
+ outputs=[chatbot, agent_state, txt_msg]
174
+ )
175
+
176
+ if __name__ == "__main__":
177
+ if initialization_error:
178
+ print("\n\n*** CANNOT LAUNCH APP: Agent failed to initialize. ***")
179
+ print(f"*** ERROR: {initialization_error} ***")
180
+ else:
181
+ demo.launch(debug=True)
graph.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langgraph.graph import StateGraph, END
2
+ from state import WorkflowState
3
+ from nodes import (
4
+ node_analyze_image,
5
+ node_fetch_symptoms,
6
+ node_ask_symptom_question,
7
+ node_process_user_response,
8
+ node_generate_final_response,
9
+ router_check_image_analysis,
10
+ router_should_ask_symptoms,
11
+ router_should_continue_asking
12
+ )
13
+
14
+ def build_graph():
15
+ """
16
+ Builds and compiles the two agentic workflows:
17
+ 1. The main diagnosis graph.
18
+ 2. The reply-handling graph.
19
+ """
20
+
21
+ workflow = StateGraph(WorkflowState)
22
+
23
+ workflow.add_node("analyze_image", node_analyze_image)
24
+ workflow.add_node("fetch_symptoms", node_fetch_symptoms)
25
+ workflow.add_node("ask_symptom_question", node_ask_symptom_question)
26
+ workflow.add_node("generate_final_response", node_generate_final_response)
27
+
28
+ workflow.set_entry_point("analyze_image")
29
+
30
+ workflow.add_conditional_edges(
31
+ "analyze_image",
32
+ router_check_image_analysis,
33
+ {
34
+ "fetch_symptoms": "fetch_symptoms",
35
+ "end_error": END
36
+ }
37
+ )
38
+
39
+ workflow.add_conditional_edges(
40
+ "fetch_symptoms",
41
+ router_should_ask_symptoms,
42
+ {
43
+ "ask_symptom_question": "ask_symptom_question",
44
+ "generate_final_response": "generate_final_response"
45
+ }
46
+ )
47
+
48
+ workflow.add_edge("ask_symptom_question", END)
49
+ workflow.add_edge("generate_final_response", END)
50
+
51
+ diagnosis_graph = workflow.compile()
52
+
53
+
54
+ reply_workflow = StateGraph(WorkflowState)
55
+
56
+ reply_workflow.add_node("process_user_response", node_process_user_response)
57
+ reply_workflow.add_node("ask_symptom_question", node_ask_symptom_question)
58
+ reply_workflow.add_node("generate_final_response", node_generate_final_response)
59
+
60
+ reply_workflow.set_entry_point("process_user_response")
61
+
62
+ reply_workflow.add_conditional_edges(
63
+ "process_user_response",
64
+ router_should_continue_asking,
65
+ {
66
+ "ask_symptom_question": "ask_symptom_question",
67
+ "generate_final_response": "generate_final_response"
68
+ }
69
+ )
70
+
71
+ reply_workflow.add_edge("ask_symptom_question", END)
72
+ reply_workflow.add_edge("generate_final_response", END)
73
+
74
+ reply_graph = reply_workflow.compile()
75
+
76
+ print("--- LangGraph Compiled ---")
77
+
78
+ return diagnosis_graph, reply_graph
llms.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from langchain_huggingface import HuggingFaceEndpoint
3
+ from langchain_core.prompts import ChatPromptTemplate
4
+ from langchain_core.output_parsers import StrOutputParser, JsonOutputParser
5
+ from langchain_core.runnables import RunnableLambda
6
+
7
+ LLM_REPO_ID = os.environ.get("LLM_REPO_ID", "Jyo-K/Fine-Tuned-Qwen2.5_1B")
8
+ HF_API_KEY = os.environ.get("HF_API_KEY")
9
+
10
+ _llm_instance = None
11
+
12
+ def get_llm():
13
+ """Lazily initializes and returns the LLM instance."""
14
+ global _llm_instance
15
+ if _llm_instance is None:
16
+ if not HF_API_KEY:
17
+ raise ValueError("HF_TOKEN environment variable not set. Cannot initialize LLM.")
18
+ _llm_instance = HuggingFaceEndpoint(
19
+ #model= "Jyo-K/Fine-Tuned-Qwen2.5_1B",
20
+ repo_id=LLM_REPO_ID,
21
+ huggingfacehub_api_token=HF_API_KEY,
22
+ temperature=0.1,
23
+ max_new_tokens=256,
24
+ top_k=50,
25
+ top_p=0.95
26
+ )
27
+ print("--- LLM Initialized ---")
28
+ return _llm_instance
29
+
30
+ classifier_prompt = ChatPromptTemplate.from_messages([
31
+ ("system", (
32
+ "You are a helpful classification assistant. "
33
+ "Your task is to classify the user's last response as 'yes', 'no', or 'unclear' "
34
+ "based on their message. "
35
+ "User's previous message: '{last_human_message}'"
36
+ "\nRespond ONLY with a single JSON object in the format: "
37
+ "{{\"classification\": \"yes\"}} or {{\"classification\": \"no\"}} or {{\"classification\": \"unclear\"}}"
38
+ ))
39
+ ])
40
+
41
+
42
+ symptom_classifier_chain = classifier_prompt | get_llm() | JsonOutputParser()
43
+
44
+ question_prompt = ChatPromptTemplate.from_messages([
45
+ ("system", (
46
+ "You are a friendly medical assistant bot. Ask the user if they are experiencing the "
47
+ "following symptom. Be clear and concise. Do not add any extra greeting or sign-off. "
48
+ "Symptom: '{symptom}'"
49
+ "\nExample: Are you experiencing any itchiness or a rash?"
50
+ ))
51
+ ])
52
+
53
+ question_generation_chain = question_prompt | get_llm() | StrOutputParser()
54
+
55
+
56
+ summary_prompt = ChatPromptTemplate.from_messages([
57
+ ("system", (
58
+ "You are a helpful medical assistant providing a summary. "
59
+ "Based on the initial image analysis and confirmed symptoms, generate a summary. "
60
+ "DO NOT provide a definitive diagnosis. "
61
+ "Structure your response clearly: "
62
+ "1. Start by stating the potential condition identified from the image."
63
+ "2. List the symptoms the user confirmed."
64
+ "3. Provide the general treatment information found for this condition."
65
+ "4. **ALWAYS** include the provided disclaimer at the very end."
66
+ "\n---"
67
+ "Initial Image Prediction: {disease}"
68
+ "Confirmed Symptoms: {symptoms}"
69
+ "Potential Treatment Information: {treatment}"
70
+ "Disclaimer: {disclaimer}"
71
+ "\n---"
72
+ "Generate your summary now."
73
+ ))
74
+ ])
75
+
76
+ summary_generation_chain = summary_prompt | get_llm() | StrOutputParser()
merge.py ADDED
@@ -0,0 +1,133 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ LoRA Adapter Merge Script for M1 Mac
3
+ Merges LoRA adapters with the base Qwen2.5-1.5B model
4
+ """
5
+
6
+ import torch
7
+ from transformers import AutoModelForCausalLM, AutoTokenizer
8
+ from peft import PeftModel
9
+ import os
10
+ import argparse
11
+
12
+ def merge_lora_weights(
13
+ base_model_name: str = "unsloth/Qwen2.5-1.5B",
14
+ lora_adapter_path: str = "fine_tuned_model", #fine_tuned_model
15
+ output_path: str = "merged_model",
16
+ device: str = "mps" # Use Metal Performance Shaders for M1
17
+ ):
18
+ """
19
+ Merge LoRA adapters with base model
20
+
21
+ Args:
22
+ base_model_name: HuggingFace model name or path to base model
23
+ lora_adapter_path: Path to your LoRA adapter files
24
+ output_path: Where to save the merged model
25
+ device: Device to use ('mps' for M1/M2 Mac, 'cpu' for compatibility)
26
+ """
27
+
28
+ print(f"🚀 Starting LoRA merge process...")
29
+ print(f"📦 Base model: {base_model_name}")
30
+ print(f"🔧 LoRA adapters: {lora_adapter_path}")
31
+ print(f"💾 Output path: {output_path}")
32
+
33
+ # Check if MPS is available (M1/M2 Mac)
34
+ if device == "mps" and not torch.backends.mps.is_available():
35
+ print("⚠️ MPS not available, falling back to CPU")
36
+ device = "cpu"
37
+
38
+ try:
39
+ # Step 1: Load the base model
40
+ print("\n📥 Loading base model...")
41
+ base_model = AutoModelForCausalLM.from_pretrained(
42
+ base_model_name,
43
+ torch_dtype=torch.float16 if device == "mps" else torch.float32,
44
+ device_map={"": device},
45
+ low_cpu_mem_usage=True
46
+ )
47
+ print("✅ Base model loaded successfully")
48
+
49
+ # Step 2: Load the tokenizer
50
+ print("\n📥 Loading tokenizer...")
51
+ tokenizer = AutoTokenizer.from_pretrained(lora_adapter_path)
52
+ print("✅ Tokenizer loaded successfully")
53
+
54
+ # Step 3: Load LoRA adapters
55
+ print("\n📥 Loading LoRA adapters...")
56
+ model_with_lora = PeftModel.from_pretrained(
57
+ base_model,
58
+ lora_adapter_path,
59
+ device_map={"": device}
60
+ )
61
+ print("✅ LoRA adapters loaded successfully")
62
+
63
+ # Step 4: Merge weights
64
+ print("\n🔄 Merging LoRA weights into base model...")
65
+ merged_model = model_with_lora.merge_and_unload()
66
+ print("✅ Weights merged successfully")
67
+
68
+ # Step 5: Save the merged model
69
+ print(f"\n💾 Saving merged model to {output_path}...")
70
+ os.makedirs(output_path, exist_ok=True)
71
+
72
+ merged_model.save_pretrained(
73
+ output_path,
74
+ safe_serialization=True,
75
+ max_shard_size="2GB"
76
+ )
77
+ tokenizer.save_pretrained(output_path)
78
+
79
+ print("✅ Merged model saved successfully!")
80
+ print(f"\n🎉 Complete! Your merged model is ready at: {output_path}")
81
+ print(f"📊 Model size: ~3GB")
82
+
83
+ # Optional: Print model info
84
+ print("\n📋 Model Information:")
85
+ print(f" - Architecture: {merged_model.config.architectures}")
86
+ print(f" - Parameters: {sum(p.numel() for p in merged_model.parameters()):,}")
87
+ print(f" - Vocab size: {merged_model.config.vocab_size}")
88
+
89
+ return merged_model, tokenizer
90
+
91
+ except Exception as e:
92
+ print(f"\n❌ Error during merge process: {str(e)}")
93
+ raise
94
+
95
+ def main():
96
+ parser = argparse.ArgumentParser(description="Merge LoRA adapters with base model")
97
+ parser.add_argument(
98
+ "--base-model",
99
+ type=str,
100
+ default="unsloth/Qwen2.5-1.5B",
101
+ help="Base model name or path"
102
+ )
103
+ parser.add_argument(
104
+ "--lora-path",
105
+ type=str,
106
+ default="./fine_tuned_model",
107
+ help="Path to LoRA adapter files"
108
+ )
109
+ parser.add_argument(
110
+ "--output-path",
111
+ type=str,
112
+ default="./merged_model",
113
+ help="Output path for merged model"
114
+ )
115
+ parser.add_argument(
116
+ "--device",
117
+ type=str,
118
+ default="mps",
119
+ choices=["mps", "cpu"],
120
+ help="Device to use (mps for M1/M2, cpu for compatibility)"
121
+ )
122
+
123
+ args = parser.parse_args()
124
+
125
+ merge_lora_weights(
126
+ base_model_name=args.base_model,
127
+ lora_adapter_path=args.lora_path,
128
+ output_path=args.output_path,
129
+ device=args.device
130
+ )
131
+
132
+ if __name__ == "__main__":
133
+ main()
nodes.py ADDED
@@ -0,0 +1,144 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langchain_core.messages import AIMessage, HumanMessage
2
+ from state import WorkflowState
3
+ from tools import tool_analyze_skin_image, tool_fetch_disease_info
4
+ from llms import (
5
+ symptom_classifier_chain,
6
+ question_generation_chain,
7
+ summary_generation_chain
8
+ )
9
+
10
+
11
+ def node_analyze_image(state: WorkflowState) -> WorkflowState:
12
+ """Analyzes the user's image and updates the state."""
13
+ print("--- Node: Analyzing Image ---")
14
+ image = state.get("image")
15
+ if not image:
16
+ state['chat_history'].append(AIMessage(
17
+ content="Please upload an image first."
18
+ ))
19
+ return state
20
+
21
+ prediction_result = tool_analyze_skin_image.invoke({"image": image})
22
+
23
+ if "Error:" in prediction_result:
24
+ state['chat_history'].append(AIMessage(content=prediction_result))
25
+ state['final_diagnosis'] = "Error"
26
+ return state
27
+
28
+ state['disease_prediction'] = prediction_result
29
+ return state
30
+
31
+ def node_fetch_symptoms(state: WorkflowState) -> WorkflowState:
32
+ """Fetches symptoms for the predicted disease."""
33
+ print(f"--- Node: Fetching Symptoms for {state['disease_prediction']} ---")
34
+ disease = state['disease_prediction']
35
+
36
+ info = tool_fetch_disease_info.invoke({"disease_name": disease})
37
+
38
+ if "error" in info:
39
+ state['chat_history'].append(AIMessage(content=info['error']))
40
+ state['final_diagnosis'] = "Error"
41
+ return state
42
+
43
+ state['symptoms_to_check'] = info.get("symptoms", [])
44
+ state['treatment_info'] = info.get("treatment", "No treatment info available.")
45
+ state['current_symptom_index'] = 0
46
+ state['symptoms_confirmed'] = []
47
+
48
+ if not state['symptoms_to_check']:
49
+ print("No symptoms found to check. Proceeding to final response.")
50
+
51
+ return state
52
+
53
+ def node_ask_symptom_question(state: WorkflowState) -> WorkflowState:
54
+ """Asks the user the next symptom question."""
55
+ print(f"--- Node: Asking Symptom Question {state['current_symptom_index']} ---")
56
+ symptoms = state['symptoms_to_check']
57
+ index = state['current_symptom_index']
58
+
59
+ symptom = symptoms[index]
60
+
61
+ question = question_generation_chain.invoke({"symptom": symptom})
62
+
63
+ state['chat_history'].append(AIMessage(content=question))
64
+ state['current_symptom_index'] = index + 1
65
+ return state
66
+
67
+ def node_process_user_response(state: WorkflowState) -> WorkflowState:
68
+ """Processes the user's 'yes' or 'no' response to a symptom question."""
69
+ print("--- Node: Processing User Response ---")
70
+ last_human_message = state['chat_history'][-1].content
71
+
72
+ index = state['current_symptom_index']
73
+ last_asked_symptom = state['symptoms_to_check'][index - 1]
74
+
75
+ try:
76
+ classification = symptom_classifier_chain.invoke(
77
+ {"last_human_message": last_human_message}
78
+ )
79
+
80
+ if classification.get("classification") == "yes":
81
+ print(f"User confirmed symptom: {last_asked_symptom}")
82
+ state['symptoms_confirmed'].append(last_asked_symptom)
83
+ else:
84
+ print(f"User denied symptom: {last_asked_symptom}")
85
+
86
+ except Exception as e:
87
+ print(f"Error classifying user response: {e}. Assuming 'unclear'.")
88
+
89
+ return state
90
+
91
+ def node_generate_final_response(state: WorkflowState) -> WorkflowState:
92
+ """Generates the final summary and disclaimer for the user."""
93
+ print("--- Node: Generating Final Response ---")
94
+
95
+ disclaimer = (
96
+ "\n\n**DISCLAIMER:**\n"
97
+ "I am just a dumb agent, not a medical professional. "
98
+ "This is a side project for learning purposes. "
99
+ "Please **DO NOT** take this information for face value. "
100
+ "Consult a real doctor or dermatologist for any medical concerns."
101
+ )
102
+
103
+ summary = summary_generation_chain.invoke({
104
+ "disease": state['disease_prediction'],
105
+ "symptoms": ", ".join(state['symptoms_confirmed']) or "None confirmed",
106
+ "treatment": state['treatment_info'],
107
+ "disclaimer": disclaimer
108
+ })
109
+
110
+ state['chat_history'].append(AIMessage(content=summary))
111
+ state['final_diagnosis'] = "Complete"
112
+ return state
113
+
114
+
115
+ def router_should_ask_symptoms(state: WorkflowState) -> str:
116
+ """
117
+ Checks if there are symptoms to ask about.
118
+ If yes -> ask_symptom_question
119
+ If no -> generate_final_response
120
+ """
121
+ if state.get("symptoms_to_check"):
122
+ return "ask_symptom_question"
123
+ else:
124
+ return "generate_final_response"
125
+
126
+ def router_should_continue_asking(state: WorkflowState) -> str:
127
+ """
128
+ Checks if we have more symptoms to ask about after a user's response.
129
+ If yes -> ask_symptom_question
130
+ If no -> generate_final_response
131
+ """
132
+ if state['current_symptom_index'] < len(state['symptoms_to_check']):
133
+ return "ask_symptom_question"
134
+ else:
135
+ return "generate_final_response"
136
+
137
+ def router_check_image_analysis(state: WorkflowState) -> str:
138
+ """
139
+ Checks if the image analysis was successful.
140
+ """
141
+ if state.get("final_diagnosis") == "Error":
142
+ return "end_error"
143
+ else:
144
+ return "fetch_symptoms"
requirements.txt ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ pydantic
2
+ langchain-pinecone
3
+ langchain
4
+ langchain-community
5
+ ipykernel
6
+ langchain-core
7
+ langchain-text-splitters
8
+ langgraph
9
+ sentence-transformers
10
+ python-dotenv
11
+ gradio
12
+ langchain-huggingface
13
+ pinecone-client
14
+ requests
15
+ pillow
16
+ pypdf
17
+ pinecone-client
state.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import TypedDict, List, Optional
2
+ from langchain_core.messages import BaseMessage
3
+ from PIL.Image import Image
4
+
5
+ class WorkflowState(TypedDict):
6
+ """
7
+ Represents the state of our agent's workflow.
8
+ This dictionary is passed between nodes, allowing them to share information.
9
+ """
10
+ image: Optional[Image]
11
+ chat_history: List[BaseMessage]
12
+
13
+ disease_prediction: str
14
+ symptoms_to_check: List[str]
15
+ treatment_info: str
16
+
17
+ symptoms_confirmed: List[str]
18
+ current_symptom_index: int
19
+
20
+ final_diagnosis: str