|
import torch |
|
import torch.nn as nn |
|
from transformers import ViTImageProcessor, ViTModel, BertTokenizerFast, BertModel |
|
from PIL import Image |
|
import gradio as gr |
|
|
|
|
|
class VisionLanguageModel(nn.Module):
    """Late-fusion classifier over a ViT image encoder and a BERT text encoder.

    The pooled output of each encoder is concatenated along dimension 1 and
    fed to a single linear layer that emits one logit per class.

    Args:
        vision_checkpoint: Name or path handed to ``ViTModel.from_pretrained``.
        language_checkpoint: Name or path handed to
            ``BertModel.from_pretrained``.
        num_labels: Number of logits produced by the classifier head.
    """

    def __init__(
        self,
        vision_checkpoint='google/vit-base-patch16-224-in21k',
        language_checkpoint='bert-base-uncased',
        num_labels=2,
    ):
        super().__init__()

        # The attribute names below become prefixes of the state-dict keys,
        # so they must stay as they are for saved checkpoints to load.
        self.vision_model = ViTModel.from_pretrained(vision_checkpoint)
        self.language_model = BertModel.from_pretrained(language_checkpoint)

        # The head consumes both pooled vectors side by side.
        fused_size = (
            self.vision_model.config.hidden_size
            + self.language_model.config.hidden_size
        )
        self.classifier = nn.Linear(fused_size, num_labels)

    def forward(self, input_ids, attention_mask, pixel_values):
        """Return classification logits for a batch of image/text pairs.

        Args:
            input_ids: Token ids forwarded to the language model.
            attention_mask: Attention mask forwarded to the language model.
            pixel_values: Image tensor forwarded to the vision model.

        Returns:
            The output of the linear head applied to the concatenated pooled
            features (presumably shaped ``(batch, num_labels)`` -- confirm
            against the encoders' pooled output shapes).
        """
        vision_pooled = self.vision_model(pixel_values=pixel_values).pooler_output
        language_pooled = self.language_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
        ).pooler_output

        fused = torch.cat((vision_pooled, language_pooled), dim=1)
        return self.classifier(fused)
|
|
|
# Instantiate the network, restore the weights stored in 'best_model.pth'
# (tensors are mapped onto the CPU while loading), and switch to eval mode.
model = VisionLanguageModel()
model.load_state_dict(
    torch.load(
        'best_model.pth',
        map_location=torch.device('cpu'),
        weights_only=True,
    )
)
model.eval()

# Pre-processing objects matching the two encoder checkpoints.
tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
feature_extractor = ViTImageProcessor.from_pretrained(
    'google/vit-base-patch16-224-in21k'
)
|
|
|
def predict(image, text_input):
    """Classify one lesion image together with its clinical notes.

    Args:
        image: Image to classify, in a form accepted by the module-level
            ``feature_extractor`` (the UI in this file passes a PIL image).
        text_input: Free-text clinical description. ``None`` is treated as
            empty text.

    Returns:
        int: Index of the largest logit. The UI in this file shows 0 as
        "Benign" and 1 as "Malignant".

    Raises:
        ValueError: If ``image`` is ``None`` (nothing was uploaded).
    """
    # Fail early with a clear message instead of an opaque pre-processing error.
    if image is None:
        raise ValueError("An image is required to make a prediction.")
    if text_input is None:
        text_input = ""

    pixel_values = feature_extractor(images=image, return_tensors="pt").pixel_values
    encoding = tokenizer(
        text_input,
        add_special_tokens=True,
        max_length=256,
        padding='max_length',
        truncation=True,
        return_tensors='pt',
    )

    # Inference only: no autograd bookkeeping needed.
    with torch.no_grad():
        logits = model(
            input_ids=encoding['input_ids'],
            attention_mask=encoding['attention_mask'],
            pixel_values=pixel_values,
        )

    # A single sample is scored, so .item() yields a plain Python int.
    return logits.argmax(dim=1).item()
|
|
|
|
|
# Gradio front end: image + clinical text in, highlighted class label out.
# NOTE(review): the original indentation was lost; the callback and the button
# are assumed to live directly under the Blocks context, below the row --
# confirm against the intended layout.
with gr.Blocks(css="""
body {
    color: black;
}
.benign, .malignant {
    background-color: white;
    border: 1px solid lightgray;
    padding: 10px;
    border-radius: 5px;
    color: black;
}
.benign.correct, .malignant.correct {
    background-color: lightgreen;
    color: black;
}
""") as demo:
    gr.Markdown(
        """
        # 🩺 SKIN LESION CLASSIFICATION
        Upload an image of a skin lesion and provide clinical details to get a prediction of benign or malignant.
        """
    )

    with gr.Row():
        with gr.Column(scale=1):
            image_input = gr.Image(type="pil", label="Upload Skin Lesion Image")
            text_input = gr.Textbox(label="Clinical Information (e.g., patient age, symptoms)")

        with gr.Column(scale=1):
            gr.Markdown("## PREDICTION RESULTS")
            benign_output = gr.HTML("<div class='benign'>Benign</div>")
            malignant_output = gr.HTML("<div class='malignant'>Malignant</div>")

    def display_prediction(image, text_input):
        """Run the classifier and return the two result cards as HTML.

        The card whose class index equals the prediction gets the extra
        ``correct`` CSS class, which the stylesheet above renders in green.

        Args:
            image: Value of the image component (``None`` if nothing uploaded).
            text_input: Value of the clinical-information textbox.

        Returns:
            tuple[str, str]: HTML for the benign card and the malignant card.

        Raises:
            gr.Error: If no image has been uploaded.
        """
        if image is None:
            # Shown to the user as a readable message in the UI.
            raise gr.Error("Please upload a skin lesion image before requesting a prediction.")

        prediction = predict(image, text_input)
        benign_marker = " correct" if prediction == 0 else ""
        malignant_marker = " correct" if prediction == 1 else ""
        return (
            f"<div class='benign{benign_marker}'>Benign</div>",
            f"<div class='malignant{malignant_marker}'>Malignant</div>",
        )

    submit_btn = gr.Button("Get Prediction")
    submit_btn.click(
        display_prediction,
        inputs=[image_input, text_input],
        outputs=[benign_output, malignant_output],
    )

demo.launch()
|
|
|
|
|
|
|
|
|
|