mmenendezg's picture
Add files for the gradio app
c5c5181
raw history blame
No virus
1.62 kB
import functools

import numpy as np
import torch
from PIL import Image
from transformers import AutoImageProcessor

from models.mobilevit import MobileVIT
# Checkpoint of the model used in the project
MODEL_CHECKPOINT = "mmenendezg/mobilevit-fluorescent-neuronal-cells"


def _detect_accelerator():
    """Return the ``(torch.device, accelerator name)`` pair to run on.

    Backends are probed in a fixed order: Apple MPS first, then CUDA, and
    finally the CPU as the fallback.
    """
    if torch.backends.mps.is_available():
        return torch.device("mps:0"), "mps"
    if torch.cuda.is_available():
        return torch.device("cuda"), "gpu"
    return torch.device("cpu"), "cpu"


# Resolved once at import time and shared by the rest of the module.
DEVICE, ACCELERATOR = _detect_accelerator()
@functools.lru_cache(maxsize=1)
def _load_model_and_processor():
    """Build the model and its image processor once and reuse them.

    Both objects are expensive to construct, and neither depends on the
    image being segmented, so they are cached for the life of the process
    instead of being rebuilt on every prediction.

    Returns:
        A ``(model, image_processor)`` tuple; the model has already been
        moved to ``DEVICE`` and switched to evaluation mode.
    """
    mobilevit_model = MobileVIT()
    mobilevit_model.to(DEVICE)
    # Inference only: make sure dropout / batch-norm layers are not left in
    # training mode.
    mobilevit_model.eval()

    image_processor = AutoImageProcessor.from_pretrained(
        MODEL_CHECKPOINT, do_reduce_labels=False
    )
    return mobilevit_model, image_processor


def single_prediction(image):
    """Segment a single image and return the predicted mask.

    Args:
        image: A ``PIL.Image.Image``. It is converted to RGB before being
            handed to the image processor.

    Returns:
        A ``PIL.Image.Image`` built from the predicted segmentation map,
        resized to the height and width of the input image, with the
        predicted label ids multiplied by 255.
    """
    mobilevit_model, image_processor = _load_model_and_processor()

    # Normalise the input to a 3-channel uint8 array of shape (H, W, 3).
    np_image = np.asarray(image.convert("RGB"), dtype=np.uint8)
    height, width = np_image.shape[:2]

    # Preprocess and move the tensors to the selected device. The result of
    # ``.to`` is rebound rather than relying on in-place mutation.
    processed_image = image_processor(images=np_image, return_tensors="pt")
    processed_image = processed_image.to(DEVICE)

    # No gradients are needed for inference; skipping autograd saves memory
    # and time.
    with torch.no_grad():
        outputs = mobilevit_model.model(
            pixel_values=processed_image["pixel_values"]
        )
        # Resize the predicted segmentation map back to the input size.
        segmentation = image_processor.post_process_semantic_segmentation(
            outputs=outputs, target_sizes=[(height, width)]
        )

    # Scale label ids so that label 1 is rendered as 255 (white).
    # NOTE(review): assumes the label ids are only 0 and 1 -- larger ids
    # would wrap around in uint8; confirm against the checkpoint's labels.
    mask = segmentation[0].detach().cpu().numpy().astype(np.uint8) * 255
    return Image.fromarray(mask)