# Uploaded by Kazeemkz via huggingface_hub ("Upload folder using huggingface_hub", commit c65c91e, verified).
import torch
import torch.nn as nn
from efficientnet_pytorch import EfficientNet
import gradio as gr
# Define the custom model architecture
class CustomModel(nn.Module):
    """Tabular-to-image binary classifier built on an EfficientNet-B0 backbone.

    Six scalar features are projected onto a flattened 224x224 plane,
    batch-normalised, reshaped into a single-channel image, replicated to
    three channels, passed through EfficientNet-B0 (32 outputs) and finally
    reduced to one sigmoid probability.
    """

    def __init__(self):
        super().__init__()
        plane_size = 224 * 224  # one value per pixel of the 224x224 plane
        self.fc = nn.Linear(6, plane_size)
        self.fc_bn = nn.BatchNorm1d(plane_size)
        # Backbone loaded via EfficientNet.from_pretrained, with its final
        # layer sized to 32 outputs.
        self.pretrained_model = EfficientNet.from_pretrained('efficientnet-b0', num_classes=32)
        # Indices 0/1 of this Sequential are part of the saved state_dict keys.
        self.classification_head = nn.Sequential(nn.Linear(32, 1), nn.Sigmoid())

    def forward(self, x):
        """Map a 2-D (batch, 6) feature tensor to (batch, 1) probabilities."""
        plane = self.fc_bn(self.fc(x)).view(-1, 224, 224)
        # The backbone takes 3-channel input: copy the plane into each channel.
        rgb = plane.unsqueeze(1).repeat(1, 3, 1, 1)
        return self.classification_head(self.pretrained_model(rgb))
# Load the trained model: the checkpoint holds a state_dict for CustomModel.
model = CustomModel()
# map_location='cpu': predict() feeds CPU tensors and the model is never moved
# to a GPU, so remap any CUDA-saved tensors instead of failing on a CPU-only host.
# NOTE(review): torch.load unpickles the file, so only load trusted checkpoints.
# On torch>=1.13, weights_only=True can also be passed for a plain state_dict.
model.load_state_dict(torch.load('best_model_efficientnet_b0.pth', map_location='cpu'))
# Eval mode: BatchNorm uses running statistics, so single-sample batches work.
model.eval()
# Validation dataset loading (disabled; the app below does not use it)
#val_dataset = CustomDataset('outside.csv')
#val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)
def predict(feature1, feature2, feature3, feature4, feature5, feature6):
    """Run the loaded model on one sample and return a human-readable verdict.

    The six numeric features form a single-row batch, presumably in the order
    of the Gradio sliders (gravity, ph, osmo, cond, urea, calc).

    Returns:
        "Kidney Stone Detected" when the rounded model output equals 1,
        otherwise "No Stone Detected".
    """
    features = torch.tensor(
        [[feature1, feature2, feature3, feature4, feature5, feature6]],
        dtype=torch.float32,
    )
    # Inference only: skip building an autograd graph on every request.
    with torch.no_grad():
        probability = model(features)
    # The head ends in a Sigmoid, so rounding gives 0.0 or 1.0. torch.round
    # rounds half to even, so an output of exactly 0.5 maps to "no stone".
    prediction = probability.round().item()
    return "Kidney Stone Detected" if prediction == 1 else "No Stone Detected"
light_blue = "#ADD8E6"
# Create a Gradio interface: one slider per model feature, declared in the
# positional order predict() receives them.
inputs = [
    gr.Slider(minimum=low, maximum=high, label=caption)
    for low, high, caption in (
        (0.8, 1.5, "gravity: Specific Gravity"),
        (3, 8, "ph: pH (Potential of Hydrogen)"),
        (200, 1200, "osmo: Osmolality"),
        (5, 30, "cond: Conductivity"),
        (50, 700, "urea: Urea"),
        (0, 20, "calc: Calcium"),
    )
]
output = gr.Label()  # displays the verdict string returned by predict()
interface = gr.Interface(
    predict,
    inputs,
    output,
    title="Kidney Stone Detection NOTE- FOR RESEARCH PURPOSE ONLY-",
    description="Enter the values for each feature",
    # Inline CSS: tint the app container background light blue.
    css=f".gradio-container {{ background-color: {light_blue} }}",
)
interface.launch()  # start the Gradio server