# DermBOT / gradio_app.py
# (Hugging Face file-page residue: author avatar "KeerthiVM's picture",
#  commit "Update gradio_app.py", e40f9f1 verified)
# dermbot_gradio_app.py
import gradio as gr
from PIL import Image
import torch
import torch.nn as nn
from torchvision import transforms
from torchvision.models import vit_b_16, vit_l_16, ViT_B_16_Weights, ViT_L_16_Weights
from huggingface_hub import hf_hub_download
from langchain.chains import RetrievalQA
from langchain.prompts import PromptTemplate
from qdrant_client import QdrantClient
from langchain_community.vectorstores import Qdrant
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_openai import ChatOpenAI
import os
import io
from fpdf import FPDF
# === Constants ===
# Output vocabulary of the multi-label morphology model (SkinViT): position i
# names logit i, so the ORDER must match the label order the checkpoint was
# trained with (presumably -- confirm against the training script).
multilabel_class_names = [
    "Vesicle", "Papule", "Macule", "Plaque", "Abscess", "Pustule",
    "Bulla", "Patch", "Nodule", "Ulcer", "Crust", "Erosion",
    "Excoriation", "Atrophy", "Exudate", "Purpura/Petechiae", "Fissure", "Induration",
    "Xerosis", "Telangiectasia", "Scale", "Scar", "Friable", "Sclerosis",
    "Pedunculated", "Exophytic/Fungating", "Warty/Papillomatous", "Dome-shaped",
    "Flat topped", "Brown(Hyperpigmentation)", "Translucent", "White(Hypopigmentation)",
    "Purple", "Yellow", "Black", "Erythema", "Comedo", "Lichenification",
    "Blue", "Umbilicated", "Poikiloderma", "Salmon", "Wheal", "Acuminate",
    "Burrow", "Gray", "Pigmented", "Cyst",
]

# Output vocabulary of the multi-class disease-category model (DermNetViT);
# the argmax index of that model is looked up in this list.
# NOTE(review): "uriticaria" and "bacteria_parasetic_infections" look like
# misspellings. They are kept verbatim because they may mirror the training
# dataset's label names; they are also interpolated into the LLM query.
multiclass_class_names = [
    "systemic",
    "hair",
    "drug_reactions",
    "uriticaria",
    "acne",
    "light",
    "autoimmune",
    "papulosquamous",
    "eczema",
    "skincancer",
    "benign_tumors",
    "bacteria_parasetic_infections",
    "fungal_infections",
    "viral_skin_infections",
]
# === Models ===
class SkinViT(nn.Module):
    """ViT-B/16 classifier whose final linear layer is resized to ``num_classes``.

    Used in this app as the multi-label lesion-morphology model. The wrapped
    torchvision network lives under ``self.model``, so every state-dict key is
    prefixed ``model.``. Keep that attribute name, or saved checkpoints will no
    longer load.
    """

    def __init__(self, num_classes):
        super().__init__()
        backbone = vit_b_16(weights=ViT_B_16_Weights.DEFAULT)
        # Replace the pretrained classification layer with one sized to our labels.
        feature_dim = backbone.heads.head.in_features
        backbone.heads.head = nn.Linear(feature_dim, num_classes)
        self.model = backbone

    def forward(self, x):
        """Return the wrapped ViT's raw logits for the image batch ``x``."""
        return self.model(x)
class DermNetViT(nn.Module):
    """ViT-L/16 backbone with a two-layer MLP head producing ``num_classes`` logits.

    Used in this app as the single-label disease-category model. The head is a
    positional ``nn.Sequential``, so its parameters are stored under
    ``model.heads.0.*`` and ``model.heads.2.*``. Do not switch it to named
    sub-modules, or existing checkpoints will fail to load.
    """

    def __init__(self, num_classes):
        super().__init__()
        self.model = vit_l_16(weights=ViT_L_16_Weights.DEFAULT)
        # Input width of the original (pretrained) head, reused for the new MLP.
        embed_dim = self.model.heads[0].in_features
        hidden = nn.Linear(embed_dim, 1024)
        activation = nn.ReLU()
        classifier = nn.Linear(1024, num_classes)
        self.model.heads = nn.Sequential(hidden, activation, classifier)

    def forward(self, x):
        """Return the wrapped ViT's raw logits for the image batch ``x``."""
        return self.model(x)
# === Load Model State Dicts ===
# Both checkpoints are plain state dicts fetched from this Hub repo.
MODEL_REPO_ID = "santhoshraghu/DermBOT"
multilabel_model_path = hf_hub_download(repo_id=MODEL_REPO_ID, filename="skin_vit_fold10_sd.pth")
multiclass_model_path = hf_hub_download(repo_id=MODEL_REPO_ID, filename="best_dermnet_vit_sd.pth")

multilabel_model = SkinViT(num_classes=len(multilabel_class_names))
multiclass_model = DermNetViT(num_classes=len(multiclass_class_names))

# weights_only=True restricts unpickling to tensors and primitive containers,
# so a tampered checkpoint downloaded from the Hub cannot execute code on load.
# map_location="cpu" lets GPU-saved weights load on CPU-only hosts.
multilabel_model.load_state_dict(
    torch.load(multilabel_model_path, map_location="cpu", weights_only=True)
)
multiclass_model.load_state_dict(
    torch.load(multiclass_model_path, map_location="cpu", weights_only=True)
)

# Switch both models to inference mode (disables dropout, etc.).
multilabel_model.eval()
multiclass_model.eval()
# === RAG Setup ===
# Answer-generating LLM. No API key is passed here, so it is presumably taken
# from the environment (OPENAI_API_KEY) -- confirm for the deployment.
llm = ChatOpenAI(model="gpt-4o", temperature=0.2)

# Qdrant credentials are read from the environment instead of being hard-coded:
# an API key committed to source is readable by anyone with repo access.
# NOTE(review): the key previously embedded in this file must be treated as
# leaked and rotated; set QDRANT_API_KEY (e.g. as a Space secret) before launch.
qdrant_client = QdrantClient(
    url=os.environ.get(
        "QDRANT_URL",
        "https://2715ddd8-647f-40ee-bca4-9027d193e8aa.us-east-1-0.aws.cloud.qdrant.io",
    ),
    api_key=os.environ.get("QDRANT_API_KEY"),
)

# Query-side embedding model; presumably the same model used to index
# "ks_collection_1.5BE" (vectors must match) -- confirm.
# NOTE(review): trust_remote_code=True executes code shipped in the model repo.
local_embedding = HuggingFaceEmbeddings(
    model_name="Alibaba-NLP/gte-Qwen2-1.5B-instruct",
    model_kwargs={"trust_remote_code": True, "device": "cpu"},
)

vector_store = Qdrant(
    client=qdrant_client,
    collection_name="ks_collection_1.5BE",
    embeddings=local_embedding,
)
retriever = vector_store.as_retriever()
# Prompt for the "stuff" RetrievalQA chain: {question} receives the user query
# and {context} receives the concatenated retrieved documents.
# Fix: the emergency guideline's bold marker was opened but never closed.
AI_PROMPT_TEMPLATE = """You are an AI-assisted Dermatology Chatbot, specializing in diagnosing and educating users about skin diseases.
You provide accurate, compassionate, and detailed explanations while using correct medical terminology.
Guidelines:
1. Symptoms - Explain in simple terms with proper medical definitions.
2. Causes - Include genetic, environmental, and lifestyle-related risk factors.
3. Medications & Treatments - Provide common prescription and over-the-counter treatments.
4. Warnings & Emergencies - Always recommend consulting a licensed dermatologist.
5. Emergency Note - If symptoms worsen or include difficulty breathing, **advise calling 911 immediately.**
Query: {question}
Relevant Information: {context}
Answer:
"""
prompt_template = PromptTemplate(template=AI_PROMPT_TEMPLATE, input_variables=["question", "context"])

# Retrieval-augmented QA chain: retrieved documents are "stuffed" into the
# prompt's {context} slot before calling the LLM.
rag_chain = RetrievalQA.from_chain_type(
    llm=llm,
    retriever=retriever,
    chain_type="stuff",
    chain_type_kwargs={"prompt": prompt_template, "document_variable_name": "context"},
)
# === Inference ===
# Preprocessing shared by both classifiers, built once at import instead of on
# every request. The single-element mean/std broadcasts across all channels,
# mapping ToTensor's [0, 1] range to [-1, 1].
# NOTE(review): the ViT *_Weights.DEFAULT backbones were pretrained with
# ImageNet mean/std; confirm the fine-tuning used this 0.5/0.5 normalization.
_DIAGNOSIS_TRANSFORM = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]),
])


def run_diagnosis(image, threshold=0.5):
    """Classify a skin image with both models.

    Args:
        image: A PIL image. It is converted to RGB first, because the ViT
            backbones need 3-channel input and uploads may be RGBA or grayscale.
        threshold: Sigmoid probability a morphology label must strictly exceed
            to be reported. Defaults to 0.5, the previously hard-coded value.

    Returns:
        A ``(predicted_multi, predicted_single)`` tuple:
        ``predicted_multi`` is the list of multi-label morphology names whose
        probability exceeds ``threshold`` (possibly empty), and
        ``predicted_single`` is the highest-scoring multi-class category name.
    """
    rgb_image = image.convert("RGB")
    # Add a batch dimension -> (1, 3, 224, 224).
    input_tensor = _DIAGNOSIS_TRANSFORM(rgb_image).unsqueeze(0)
    with torch.no_grad():
        # squeeze(0) removes only the batch dimension, leaving one probability per label.
        probs_multi = torch.sigmoid(multilabel_model(input_tensor)).squeeze(0).tolist()
        predicted_multi = [
            name for name, p in zip(multilabel_class_names, probs_multi) if p > threshold
        ]
        pred_idx = multiclass_model(input_tensor).argmax(dim=1).item()
    predicted_single = multiclass_class_names[pred_idx]
    return predicted_multi, predicted_single
# === Chat Function ===
def chat_with_bot(image, history=None):
    """Diagnose ``image`` and ask the RAG chain about treatment options.

    Args:
        image: PIL image from the upload widget, or ``None`` if nothing was
            uploaded. In that case the user is asked for an image instead of
            crashing inside ``run_diagnosis``.
        history: List of ``(user_message, bot_message)`` pairs. It is appended
            to in place and returned. When omitted, a fresh list is created
            per call; the old ``history=[]`` default was one list shared by
            every call.

    Returns:
        A ``(response_text, history)`` tuple.
    """
    if history is None:
        history = []
    if image is None:
        return "Please upload an image of the affected skin area first.", history

    predicted_multi, predicted_single = run_diagnosis(image)
    # Join the labels into readable text rather than interpolating the Python
    # list repr (e.g. "['Papule', 'Scale']") into the natural-language query.
    if predicted_multi:
        conditions = f"{', '.join(predicted_multi)} and {predicted_single}"
    else:
        conditions = predicted_single
    query = f"What are my treatment options for {conditions}?"

    response = rag_chain.invoke(query)["result"]
    history.append((f"User: {query}", f"AI: {response}"))
    return response, history
# === Gradio App ===
def _analyze(image, history):
    """Click handler: run chat_with_bot and also push the history to the chat panel.

    The ``gr.Chatbot`` component was previously created but never listed as an
    output, so the conversation panel always stayed empty.
    """
    response, history = chat_with_bot(image, history)
    return response, history, history


with gr.Blocks() as demo:
    gr.Markdown("# 🧬 DermBOT — Skin AI Assistant")
    # History entries are (user, bot) string pairs, i.e. the Chatbot "tuples"
    # format. NOTE(review): newer Gradio releases prefer type="messages".
    chatbot = gr.Chatbot()
    img_input = gr.Image(type="pil")
    output_text = gr.Textbox(label="DermBOT Response")
    btn = gr.Button("Analyze & Diagnose")
    state = gr.State([])
    btn.click(
        fn=_analyze,
        inputs=[img_input, state],
        outputs=[output_text, state, chatbot],
    )

if __name__ == "__main__":
    demo.launch()