Upload 7 files
app.py
ADDED
@@ -0,0 +1,181 @@
import os
from dotenv import load_dotenv

load_dotenv()

import gradio as gr
from PIL import Image
from langchain_core.messages import HumanMessage, AIMessage

from graph import build_graph
from state import WorkflowState

diagnosis_graph = None
reply_graph = None
initialization_error = None

try:
    diagnosis_graph, reply_graph = build_graph()
    print("Gradio app and graphs initialized successfully.")
except Exception as e:
    initialization_error = f"CRITICAL ERROR: Could not build graphs. {e}. Check API keys."
    print(initialization_error)


def convert_history_to_langchain(chat_history):
    """Converts Gradio history to a LangChain message list."""
    messages = []
    for user_msg, ai_msg in chat_history:
        if user_msg is not None:
            if isinstance(user_msg, (dict, tuple)):
                messages.append(HumanMessage(content="User uploaded an image."))
            else:
                messages.append(HumanMessage(content=user_msg))
        if ai_msg is not None:
            messages.append(AIMessage(content=ai_msg))
    return messages


def reset_state_on_start():
    """Fresh empty state for every new diagnosis."""
    return WorkflowState(
        image=None,
        chat_history=[],
        disease_prediction="",
        symptoms_to_check=[],
        symptoms_confirmed=[],
        current_symptom_index=0,
        treatment_info="",
        final_diagnosis=""
    )


def chat_fn(message: str, chat_history: list, agent_state: dict, img_upload: Image):
    """Handles user input and manages the agent's workflow state."""
    if initialization_error:
        chat_history.append((message, initialization_error))
        yield chat_history, {}, gr.update(value=None, interactive=True), gr.update(value="", interactive=True)
        return

    if not agent_state or agent_state.get("final_diagnosis"):
        current_state = reset_state_on_start()
    else:
        current_state = agent_state

    chat_history = chat_history or []

    is_new_diagnosis = False
    if img_upload and (message.lower().strip() == "start" or message == ""):
        print("--- Running NEW diagnosis flow ---")
        is_new_diagnosis = True
        current_state = reset_state_on_start()
        current_state["image"] = img_upload
        graph_to_run = diagnosis_graph
        chat_history.append([(img_upload,), None])

    elif current_state.get("symptoms_to_check") and not current_state.get("final_diagnosis"):
        print("--- Running REPLY symptom loop flow ---")
        graph_to_run = reply_graph
        chat_history.append([message, None])

    else:
        if message:
            chat_history.append([message, None])
            chat_history[-1][1] = "Hello! Please upload an image, then click 'Start Diagnosis'."
        yield chat_history, agent_state, gr.update(value=img_upload, interactive=True), gr.update(value="", interactive=True)
        return

    current_state["chat_history"] = convert_history_to_langchain(chat_history)

    try:
        final_state = {}
        for step in graph_to_run.stream(current_state, {"recursion_limit": 100}):
            final_state = list(step.values())[0]

        ai_response = final_state['chat_history'][-1].content
        chat_history[-1][1] = ai_response

        if final_state.get("final_diagnosis"):
            print("--- Agent Flow ENDED ---")
            yield chat_history, {}, gr.update(value=None, interactive=True), gr.update(value="", interactive=True)
        else:
            yield chat_history, final_state, gr.update(value=img_upload, interactive=False), gr.update(value="", interactive=True)

    except Exception as e:
        print(f"--- Graph Runtime Error --- \n{e}")
        error_msg = f"A runtime error occurred: {e}. Please check the console."
        chat_history[-1][1] = error_msg
        yield chat_history, {}, gr.update(value=None, interactive=True), gr.update(value="", interactive=True)


def clear_all():
    """Clears chat, state, and image."""
    return [], {}, None, ""


with gr.Blocks(theme=gr.themes.Soft(), title="Agentic Skin AI") as demo:
    gr.Markdown("# 🩺 Multimodal Agentic Skin Disease AI")
    gr.Markdown(
        "**Disclaimer:** This is a demo project and NOT a medical device. "
        "Consult a real doctor for any medical concerns."
    )
    agent_state = gr.State({})

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### 1. Image Input")
            img_upload = gr.Image(type="pil", label="Upload Skin Image", interactive=True)

            btn_start = gr.Button("Start Diagnosis", variant="primary")
            btn_clear = gr.Button("Clear All & Start New")

            gr.Markdown(
                "**Instructions:**\n"
                "1. Upload an image.\n"
                "2. Click **Start Diagnosis**.\n"
                "3. Answer the agent's questions in the textbox."
            )

        with gr.Column(scale=2):
            gr.Markdown("### 2. Agent Conversation")
            chatbot = gr.Chatbot(label="Agent Conversation", height=500, bubble_full_width=False, avatar_images=None)
            txt_msg = gr.Textbox(
                label="Your message (Yes / No / etc.)",
                placeholder="Answer the agent's questions here...",
                interactive=True
            )

    txt_msg.submit(
        fn=chat_fn,
        inputs=[txt_msg, chatbot, agent_state, img_upload],
        outputs=[chatbot, agent_state, img_upload, txt_msg]
    )

    btn_start.click(
        fn=chat_fn,
        inputs=[gr.Textbox(value="start", visible=False), chatbot, agent_state, img_upload],
        outputs=[chatbot, agent_state, img_upload, txt_msg]
    )

    btn_clear.click(
        fn=clear_all,
        inputs=None,
        outputs=[chatbot, agent_state, img_upload, txt_msg]
    )

    img_upload.upload(
        fn=lambda: ([], {}, "Click 'Start Diagnosis' to begin."),
        inputs=None,
        outputs=[chatbot, agent_state, txt_msg]
    )

if __name__ == "__main__":
    if initialization_error:
        print("\n\n*** CANNOT LAUNCH APP: Agent failed to initialize. ***")
        print(f"*** ERROR: {initialization_error} ***")
    else:
        demo.launch(debug=True)
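Note on configuration: app.py calls load_dotenv(), and llms.py reads HF_API_KEY and LLM_REPO_ID from the environment, so a local run needs a .env file (or Space secrets) along these lines. This is a minimal sketch with placeholder values; the tools module imported by nodes.py may require additional keys that are not shown in this upload.

# .env (sketch, placeholder values)
HF_API_KEY=hf_xxxxxxxxxxxxxxxx
# optional override; llms.py defaults to Jyo-K/Fine-Tuned-Qwen2.5_1B
LLM_REPO_ID=Jyo-K/Fine-Tuned-Qwen2.5_1B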
graph.py
ADDED
@@ -0,0 +1,78 @@
from langgraph.graph import StateGraph, END
from state import WorkflowState
from nodes import (
    node_analyze_image,
    node_fetch_symptoms,
    node_ask_symptom_question,
    node_process_user_response,
    node_generate_final_response,
    router_check_image_analysis,
    router_should_ask_symptoms,
    router_should_continue_asking
)

def build_graph():
    """
    Builds and compiles the two agentic workflows:
    1. The main diagnosis graph.
    2. The reply-handling graph.
    """

    workflow = StateGraph(WorkflowState)

    workflow.add_node("analyze_image", node_analyze_image)
    workflow.add_node("fetch_symptoms", node_fetch_symptoms)
    workflow.add_node("ask_symptom_question", node_ask_symptom_question)
    workflow.add_node("generate_final_response", node_generate_final_response)

    workflow.set_entry_point("analyze_image")

    workflow.add_conditional_edges(
        "analyze_image",
        router_check_image_analysis,
        {
            "fetch_symptoms": "fetch_symptoms",
            "end_error": END
        }
    )

    workflow.add_conditional_edges(
        "fetch_symptoms",
        router_should_ask_symptoms,
        {
            "ask_symptom_question": "ask_symptom_question",
            "generate_final_response": "generate_final_response"
        }
    )

    workflow.add_edge("ask_symptom_question", END)
    workflow.add_edge("generate_final_response", END)

    diagnosis_graph = workflow.compile()


    reply_workflow = StateGraph(WorkflowState)

    reply_workflow.add_node("process_user_response", node_process_user_response)
    reply_workflow.add_node("ask_symptom_question", node_ask_symptom_question)
    reply_workflow.add_node("generate_final_response", node_generate_final_response)

    reply_workflow.set_entry_point("process_user_response")

    reply_workflow.add_conditional_edges(
        "process_user_response",
        router_should_continue_asking,
        {
            "ask_symptom_question": "ask_symptom_question",
            "generate_final_response": "generate_final_response"
        }
    )

    reply_workflow.add_edge("ask_symptom_question", END)
    reply_workflow.add_edge("generate_final_response", END)

    reply_graph = reply_workflow.compile()

    print("--- LangGraph Compiled ---")

    return diagnosis_graph, reply_graph
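Both graphs route ask_symptom_question straight to END, so each compiled graph runs for exactly one turn and the caller resumes the conversation: the diagnosis graph handles the uploaded image, and the reply graph handles each subsequent user answer. The sketch below is a minimal console driver for that hand-off, assuming the same modules and API keys that app.py uses; the image path is a placeholder.

from PIL import Image
from langchain_core.messages import HumanMessage

from graph import build_graph
from state import WorkflowState

diagnosis_graph, reply_graph = build_graph()

# One pass of the diagnosis graph for an uploaded image.
state = WorkflowState(
    image=Image.open("sample_skin_image.jpg"),  # placeholder path
    chat_history=[], disease_prediction="", symptoms_to_check=[],
    symptoms_confirmed=[], current_symptom_index=0,
    treatment_info="", final_diagnosis=""
)
for step in diagnosis_graph.stream(state, {"recursion_limit": 100}):
    state = list(step.values())[0]
print("Agent:", state["chat_history"][-1].content)

# Feed each user answer back through the reply graph until a summary is produced.
while not state.get("final_diagnosis"):
    state["chat_history"].append(HumanMessage(content=input("You: ")))
    for step in reply_graph.stream(state, {"recursion_limit": 100}):
        state = list(step.values())[0]
    print("Agent:", state["chat_history"][-1].content)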
llms.py
ADDED
@@ -0,0 +1,76 @@
import os
from langchain_huggingface import HuggingFaceEndpoint
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser, JsonOutputParser
from langchain_core.runnables import RunnableLambda

LLM_REPO_ID = os.environ.get("LLM_REPO_ID", "Jyo-K/Fine-Tuned-Qwen2.5_1B")
HF_API_KEY = os.environ.get("HF_API_KEY")

_llm_instance = None

def get_llm():
    """Lazily initializes and returns the LLM instance."""
    global _llm_instance
    if _llm_instance is None:
        if not HF_API_KEY:
            raise ValueError("HF_API_KEY environment variable not set. Cannot initialize LLM.")
        _llm_instance = HuggingFaceEndpoint(
            repo_id=LLM_REPO_ID,
            huggingfacehub_api_token=HF_API_KEY,
            temperature=0.1,
            max_new_tokens=256,
            top_k=50,
            top_p=0.95
        )
        print("--- LLM Initialized ---")
    return _llm_instance

classifier_prompt = ChatPromptTemplate.from_messages([
    ("system", (
        "You are a helpful classification assistant. "
        "Your task is to classify the user's last response as 'yes', 'no', or 'unclear' "
        "based on their message. "
        "User's previous message: '{last_human_message}'"
        "\nRespond ONLY with a single JSON object in the format: "
        "{{\"classification\": \"yes\"}} or {{\"classification\": \"no\"}} or {{\"classification\": \"unclear\"}}"
    ))
])

symptom_classifier_chain = classifier_prompt | get_llm() | JsonOutputParser()

question_prompt = ChatPromptTemplate.from_messages([
    ("system", (
        "You are a friendly medical assistant bot. Ask the user if they are experiencing the "
        "following symptom. Be clear and concise. Do not add any extra greeting or sign-off. "
        "Symptom: '{symptom}'"
        "\nExample: Are you experiencing any itchiness or a rash?"
    ))
])

question_generation_chain = question_prompt | get_llm() | StrOutputParser()


summary_prompt = ChatPromptTemplate.from_messages([
    ("system", (
        "You are a helpful medical assistant providing a summary. "
        "Based on the initial image analysis and confirmed symptoms, generate a summary. "
        "DO NOT provide a definitive diagnosis. "
        "Structure your response clearly: "
        "\n1. Start by stating the potential condition identified from the image."
        "\n2. List the symptoms the user confirmed."
        "\n3. Provide the general treatment information found for this condition."
        "\n4. **ALWAYS** include the provided disclaimer at the very end."
        "\n---"
        "\nInitial Image Prediction: {disease}"
        "\nConfirmed Symptoms: {symptoms}"
        "\nPotential Treatment Information: {treatment}"
        "\nDisclaimer: {disclaimer}"
        "\n---"
        "\nGenerate your summary now."
    ))
])

summary_generation_chain = summary_prompt | get_llm() | StrOutputParser()
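Each chain is a plain prompt | endpoint | parser pipeline, so it can be smoke-tested in isolation. A small example, assuming HF_API_KEY is set and the endpoint is reachable (actual outputs depend on the hosted model):

from llms import symptom_classifier_chain, question_generation_chain

# Classify a free-form user reply into yes / no / unclear.
result = symptom_classifier_chain.invoke({"last_human_message": "yeah, it itches a lot"})
print(result)  # expected shape: {"classification": "yes"}

# Turn a raw symptom string into a question for the user.
question = question_generation_chain.invoke({"symptom": "itching or rash"})
print(question)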
merge.py
ADDED
@@ -0,0 +1,133 @@
"""
LoRA Adapter Merge Script for M1 Mac
Merges LoRA adapters with the base Qwen2.5-1.5B model
"""

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
import os
import argparse

def merge_lora_weights(
    base_model_name: str = "unsloth/Qwen2.5-1.5B",
    lora_adapter_path: str = "fine_tuned_model",
    output_path: str = "merged_model",
    device: str = "mps"  # Use Metal Performance Shaders for M1
):
    """
    Merge LoRA adapters with base model.

    Args:
        base_model_name: HuggingFace model name or path to base model
        lora_adapter_path: Path to your LoRA adapter files
        output_path: Where to save the merged model
        device: Device to use ('mps' for M1/M2 Mac, 'cpu' for compatibility)
    """

    print(f"🚀 Starting LoRA merge process...")
    print(f"📦 Base model: {base_model_name}")
    print(f"🔧 LoRA adapters: {lora_adapter_path}")
    print(f"💾 Output path: {output_path}")

    # Check if MPS is available (M1/M2 Mac)
    if device == "mps" and not torch.backends.mps.is_available():
        print("⚠️ MPS not available, falling back to CPU")
        device = "cpu"

    try:
        # Step 1: Load the base model
        print("\n📥 Loading base model...")
        base_model = AutoModelForCausalLM.from_pretrained(
            base_model_name,
            torch_dtype=torch.float16 if device == "mps" else torch.float32,
            device_map={"": device},
            low_cpu_mem_usage=True
        )
        print("✅ Base model loaded successfully")

        # Step 2: Load the tokenizer
        print("\n📥 Loading tokenizer...")
        tokenizer = AutoTokenizer.from_pretrained(lora_adapter_path)
        print("✅ Tokenizer loaded successfully")

        # Step 3: Load LoRA adapters
        print("\n📥 Loading LoRA adapters...")
        model_with_lora = PeftModel.from_pretrained(
            base_model,
            lora_adapter_path,
            device_map={"": device}
        )
        print("✅ LoRA adapters loaded successfully")

        # Step 4: Merge weights
        print("\n🔄 Merging LoRA weights into base model...")
        merged_model = model_with_lora.merge_and_unload()
        print("✅ Weights merged successfully")

        # Step 5: Save the merged model
        print(f"\n💾 Saving merged model to {output_path}...")
        os.makedirs(output_path, exist_ok=True)

        merged_model.save_pretrained(
            output_path,
            safe_serialization=True,
            max_shard_size="2GB"
        )
        tokenizer.save_pretrained(output_path)

        print("✅ Merged model saved successfully!")
        print(f"\n🎉 Complete! Your merged model is ready at: {output_path}")
        print(f"📊 Model size: ~3GB")

        # Optional: Print model info
        print("\n📋 Model Information:")
        print(f"  - Architecture: {merged_model.config.architectures}")
        print(f"  - Parameters: {sum(p.numel() for p in merged_model.parameters()):,}")
        print(f"  - Vocab size: {merged_model.config.vocab_size}")

        return merged_model, tokenizer

    except Exception as e:
        print(f"\n❌ Error during merge process: {str(e)}")
        raise

def main():
    parser = argparse.ArgumentParser(description="Merge LoRA adapters with base model")
    parser.add_argument(
        "--base-model",
        type=str,
        default="unsloth/Qwen2.5-1.5B",
        help="Base model name or path"
    )
    parser.add_argument(
        "--lora-path",
        type=str,
        default="./fine_tuned_model",
        help="Path to LoRA adapter files"
    )
    parser.add_argument(
        "--output-path",
        type=str,
        default="./merged_model",
        help="Output path for merged model"
    )
    parser.add_argument(
        "--device",
        type=str,
        default="mps",
        choices=["mps", "cpu"],
        help="Device to use (mps for M1/M2, cpu for compatibility)"
    )

    args = parser.parse_args()

    merge_lora_weights(
        base_model_name=args.base_model,
        lora_adapter_path=args.lora_path,
        output_path=args.output_path,
        device=args.device
    )

if __name__ == "__main__":
    main()
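merge.py is a standalone utility and is not imported by the Space; it is normally run as a CLI (for example: python merge.py --lora-path ./fine_tuned_model --output-path ./merged_model --device cpu), but it can also be called directly. A minimal sketch using the script's own defaults:

from merge import merge_lora_weights

# Same as the CLI defaults, but forcing the CPU path on non-Apple hardware.
merged_model, tokenizer = merge_lora_weights(
    base_model_name="unsloth/Qwen2.5-1.5B",
    lora_adapter_path="./fine_tuned_model",
    output_path="./merged_model",
    device="cpu",  # use "mps" on an M1/M2 Mac
)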
nodes.py
ADDED
@@ -0,0 +1,144 @@
from langchain_core.messages import AIMessage, HumanMessage
from state import WorkflowState
from tools import tool_analyze_skin_image, tool_fetch_disease_info
from llms import (
    symptom_classifier_chain,
    question_generation_chain,
    summary_generation_chain
)


def node_analyze_image(state: WorkflowState) -> WorkflowState:
    """Analyzes the user's image and updates the state."""
    print("--- Node: Analyzing Image ---")
    image = state.get("image")
    if not image:
        state['chat_history'].append(AIMessage(
            content="Please upload an image first."
        ))
        return state

    prediction_result = tool_analyze_skin_image.invoke({"image": image})

    if "Error:" in prediction_result:
        state['chat_history'].append(AIMessage(content=prediction_result))
        state['final_diagnosis'] = "Error"
        return state

    state['disease_prediction'] = prediction_result
    return state

def node_fetch_symptoms(state: WorkflowState) -> WorkflowState:
    """Fetches symptoms for the predicted disease."""
    print(f"--- Node: Fetching Symptoms for {state['disease_prediction']} ---")
    disease = state['disease_prediction']

    info = tool_fetch_disease_info.invoke({"disease_name": disease})

    if "error" in info:
        state['chat_history'].append(AIMessage(content=info['error']))
        state['final_diagnosis'] = "Error"
        return state

    state['symptoms_to_check'] = info.get("symptoms", [])
    state['treatment_info'] = info.get("treatment", "No treatment info available.")
    state['current_symptom_index'] = 0
    state['symptoms_confirmed'] = []

    if not state['symptoms_to_check']:
        print("No symptoms found to check. Proceeding to final response.")

    return state

def node_ask_symptom_question(state: WorkflowState) -> WorkflowState:
    """Asks the user the next symptom question."""
    print(f"--- Node: Asking Symptom Question {state['current_symptom_index']} ---")
    symptoms = state['symptoms_to_check']
    index = state['current_symptom_index']

    symptom = symptoms[index]

    question = question_generation_chain.invoke({"symptom": symptom})

    state['chat_history'].append(AIMessage(content=question))
    state['current_symptom_index'] = index + 1
    return state

def node_process_user_response(state: WorkflowState) -> WorkflowState:
    """Processes the user's 'yes' or 'no' response to a symptom question."""
    print("--- Node: Processing User Response ---")
    last_human_message = state['chat_history'][-1].content

    index = state['current_symptom_index']
    last_asked_symptom = state['symptoms_to_check'][index - 1]

    try:
        classification = symptom_classifier_chain.invoke(
            {"last_human_message": last_human_message}
        )

        if classification.get("classification") == "yes":
            print(f"User confirmed symptom: {last_asked_symptom}")
            state['symptoms_confirmed'].append(last_asked_symptom)
        else:
            print(f"User denied symptom: {last_asked_symptom}")

    except Exception as e:
        print(f"Error classifying user response: {e}. Assuming 'unclear'.")

    return state

def node_generate_final_response(state: WorkflowState) -> WorkflowState:
    """Generates the final summary and disclaimer for the user."""
    print("--- Node: Generating Final Response ---")

    disclaimer = (
        "\n\n**DISCLAIMER:**\n"
        "I am just a dumb agent, not a medical professional. "
        "This is a side project for learning purposes. "
        "Please **DO NOT** take this information at face value. "
        "Consult a real doctor or dermatologist for any medical concerns."
    )

    summary = summary_generation_chain.invoke({
        "disease": state['disease_prediction'],
        "symptoms": ", ".join(state['symptoms_confirmed']) or "None confirmed",
        "treatment": state['treatment_info'],
        "disclaimer": disclaimer
    })

    state['chat_history'].append(AIMessage(content=summary))
    state['final_diagnosis'] = "Complete"
    return state


def router_should_ask_symptoms(state: WorkflowState) -> str:
    """
    Checks if there are symptoms to ask about.
    If yes -> ask_symptom_question
    If no -> generate_final_response
    """
    if state.get("symptoms_to_check"):
        return "ask_symptom_question"
    else:
        return "generate_final_response"

def router_should_continue_asking(state: WorkflowState) -> str:
    """
    Checks if we have more symptoms to ask about after a user's response.
    If yes -> ask_symptom_question
    If no -> generate_final_response
    """
    if state['current_symptom_index'] < len(state['symptoms_to_check']):
        return "ask_symptom_question"
    else:
        return "generate_final_response"

def router_check_image_analysis(state: WorkflowState) -> str:
    """
    Checks if the image analysis was successful.
    """
    if state.get("final_diagnosis") == "Error":
        return "end_error"
    else:
        return "fetch_symptoms"
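nodes.py imports tool_analyze_skin_image and tool_fetch_disease_info from a tools module that is not part of this upload. From the way the nodes call them, the expected contract is: the image tool returns a condition name (or a string starting with 'Error:'), and the info tool returns a dict with 'symptoms' and 'treatment' keys (or an 'error' key). The stub below is a hypothetical stand-in that satisfies that contract for wiring up the graphs; it is not the project's actual tools.py.

# Hypothetical tools.py stand-in inferred from nodes.py; NOT the real implementation.
from typing import Any
from langchain_core.tools import tool

@tool
def tool_analyze_skin_image(image: Any) -> str:
    """Return a predicted skin condition name, or a string starting with 'Error:' on failure."""
    return "eczema"  # placeholder; the real tool runs an image classifier on a PIL image

@tool
def tool_fetch_disease_info(disease_name: str) -> dict:
    """Return {'symptoms': [...], 'treatment': str}, or {'error': str} on failure."""
    return {
        "symptoms": ["itching", "dry or scaly patches", "redness"],
        "treatment": "Placeholder treatment text for demonstration only.",
    }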
requirements.txt
ADDED
@@ -0,0 +1,17 @@
pydantic
langchain-pinecone
langchain
langchain-community
ipykernel
langchain-core
langchain-text-splitters
langgraph
sentence-transformers
python-dotenv
gradio
langchain-huggingface
pinecone-client
requests
pillow
pypdf
state.py
ADDED
@@ -0,0 +1,20 @@
from typing import TypedDict, List, Optional
from langchain_core.messages import BaseMessage
from PIL.Image import Image

class WorkflowState(TypedDict):
    """
    Represents the state of our agent's workflow.
    This dictionary is passed between nodes, allowing them to share information.
    """
    image: Optional[Image]
    chat_history: List[BaseMessage]

    disease_prediction: str
    symptoms_to_check: List[str]
    treatment_info: str

    symptoms_confirmed: List[str]
    current_symptom_index: int

    final_diagnosis: str
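Because WorkflowState is a TypedDict, it is an ordinary dict at runtime: the nodes read and mutate it with plain key access, and LangGraph passes the returned dict on to the next node. A tiny illustration of the fields the routers rely on:

from state import WorkflowState

state: WorkflowState = {
    "image": None,
    "chat_history": [],
    "disease_prediction": "eczema",
    "symptoms_to_check": ["itching", "redness"],
    "treatment_info": "",
    "symptoms_confirmed": [],
    "current_symptom_index": 1,
    "final_diagnosis": "",
}

# router_should_continue_asking compares these two fields to pick the next node.
remaining = len(state["symptoms_to_check"]) - state["current_symptom_index"]
print(f"{remaining} symptom question(s) left to ask")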