Gaurav
Update app.py
7d844e3
raw history blame
No virus
2.51 kB
import random

import gradio as gr
import numpy as np
import torch
from scipy.spatial import Delaunay
from transformers import MaskFormerFeatureExtractor, MaskFormerForInstanceSegmentation
# Hugging Face checkpoint shared by the model and its preprocessor; defining it
# once guarantees both are always loaded from the same checkpoint.
CHECKPOINT = "facebook/maskformer-swin-tiny-ade"

# Inference device for the model.
device = torch.device("cpu")

model = MaskFormerForInstanceSegmentation.from_pretrained(CHECKPOINT).to(device)
# Evaluation mode: disables training-only behaviour such as dropout.
model.eval()

preprocessor = MaskFormerFeatureExtractor.from_pretrained(CHECKPOINT)
def _floor_mesh(mask, floor_label=3):
    """Triangulate the pixels of ``mask`` that carry ``floor_label``.

    Each matching pixel at (row, col) becomes the vertex ``[row, -col, 0]``.
    Duplicate vertices are collapsed, and a 2-D Delaunay triangulation is run
    on the first two coordinates.

    Args:
        mask: 2-D array of integer class ids.
        floor_label: class id whose pixels are triangulated.

    Returns:
        ``(vertices, simplices)``: a ``(V, 3)`` array of unique vertices and a
        ``(T, 3)`` array of triangle vertex indices into ``vertices``. Both are
        empty when the floor pixels cannot be triangulated (fewer than three
        of them, or all lying on one line).
    """
    empty = (np.empty((0, 3), dtype=np.intp), np.empty((0, 3), dtype=np.intp))
    coords = np.argwhere(np.asarray(mask) == floor_label)
    if len(coords) < 3:
        return empty
    vertices = np.unique(
        np.column_stack(
            [coords[:, 0], -coords[:, 1], np.zeros(len(coords), dtype=coords.dtype)]
        ),
        axis=0,
    )
    planar = vertices[:, :2]
    # Qhull rejects fewer than three distinct points or collinear input; pixel
    # coordinates are integers, so this rank test is exact.
    if len(planar) < 3 or np.linalg.matrix_rank(planar - planar[0]) < 2:
        return empty
    return vertices, Delaunay(planar).simplices


def visualize_instance_seg_mask(mask, verbose=False):
    """Render a label map as an RGB image with one random colour per label.

    Args:
        mask: 2-D array of integer class ids.
        verbose: when True, print the mask, its shape, its unique labels and
            the floor triangulation computed by ``_floor_mesh``, for
            debugging. The triangulation is skipped otherwise.

    Returns:
        Float array of shape ``(H, W, 3)`` with values in ``[0, 1]``.
    """
    mask = np.asarray(mask)
    labels = np.unique(mask)
    # One colour per label, drawn in ascending label order. Every channel
    # spans the full 0-255 range (red was previously limited to 0-1).
    label2color = {
        label: (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255))
        for label in labels
    }
    image = np.zeros((mask.shape[0], mask.shape[1], 3))
    # Paint every pixel of a label in one vectorised assignment instead of
    # visiting each pixel in Python.
    for label, color in label2color.items():
        image[mask == label] = color

    if verbose:
        print(mask)
        print(mask.shape)
        print("unique labels:", labels)
        # The simplices index into these unique vertices, so print them together.
        vertices, simplices = _floor_mesh(mask)
        print(vertices)
        print(simplices)

    return image / 255
def query_image(img):
    """Segment ``img`` with MaskFormer and return a colourised label map.

    Args:
        img: image array from the Gradio ``gr.Image`` input. Its first two
            dimensions are used as the output (height, width); presumably
            HWC layout, TODO confirm. ``None`` when the user submits without
            an image.

    Returns:
        The colour image produced by ``visualize_instance_seg_mask`` for the
        per-pixel argmax class, or ``None`` when ``img`` is ``None``.
    """
    if img is None:
        # Nothing uploaded: leave the output empty instead of raising.
        return None
    target_size = (img.shape[0], img.shape[1])
    # Place the inputs on the same device as the model, so a non-CPU device
    # choice also works.
    inputs = preprocessor(images=img, return_tensors="pt").to(model.device)
    with torch.no_grad():
        outputs = model(**inputs)
    # Per-class scores resized to the input resolution; [0] selects the single
    # image in the batch (shape (num_labels, H, W) per the transformers docs).
    scores = preprocessor.post_process_segmentation(
        outputs=outputs, target_size=target_size
    )[0].cpu()
    label_map = torch.argmax(scores, dim=0).numpy()
    return visualize_instance_seg_mask(label_map)
# Gradio UI: one uploaded image in, the colourised segmentation out.
_TITLE = "Image Segmentation Demo"
_DESCRIPTION = "Please upload an image to see segmentation capabilities of this model"

demo = gr.Interface(
    fn=query_image,
    inputs=[gr.Image()],
    outputs="image",
    examples=[["work2.jpg"]],
    title=_TITLE,
    description=_DESCRIPTION,
)
# debug=True keeps the main thread blocked on the server and surfaces errors
# in the console.
demo.launch(debug=True)