Update app.py
app.py CHANGED
@@ -34,7 +34,6 @@ from data import get_dataset
 import torchvision.transforms as transforms
 
 import gradio as gr
-import streamlit as st
 
 model_name = "convnext_xlarge_in22k"
 device = 'cuda' if torch.cuda.is_available() else 'cpu'
@@ -332,57 +331,42 @@ def load_model():
 """
 lseg_model, lseg_transform = load_model()
 
-
-
-input_labels = st.text_input("Input labels", value="dog, grass, other")
-gr.outputs.Label(type="confidences",num_top_classes=5)
-st.write("The labels are", input_labels)
-
-
-pimage = lseg_transform(np.array(image)).unsqueeze(0)
-
-labels = []
-for label in input_labels.split(","):
-    labels.append(label.strip())
-
-with torch.no_grad():
-    outputs = lseg_model.parallel_forward(pimage, labels)
-
-    predicts = [
-        torch.max(output, 1)[1].cpu().numpy()
-        for output in outputs
-    ]
-
-image = pimage[0].permute(1,2,0)
-image = image * 0.5 + 0.5
-image = Image.fromarray(np.uint8(255*image)).convert("RGBA")
-
-pred = predicts[0]
-new_palette = get_new_pallete(len(labels))
-mask, patches = get_new_mask_pallete(pred, new_palette, out_label_flag=True, labels=labels)
-seg = mask.convert("RGBA")
-
-fig = plt.figure()
-plt.subplot(121)
-plt.axis('off')
-
-
-plt.subplot(122)
-plt.imshow(seg)
-plt.legend(handles=patches, loc='upper right', bbox_to_anchor=(1.3, 1), prop={'size': 5})
-plt.axis('off')
-
-plt.tight_layout()
-
-#st.image([image,seg], width=700, caption=["Input image", "Segmentation"])
-st.pyplot(fig)
-
-title = "LSeg"
-
-description = "Gradio demo for LSeg for semantic segmentation. To use it, simply upload your image, or click one of the examples to load them, then add any label set"
-
-article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2201.03546' target='_blank'>Language-driven Semantic Segmentation</a> | <a href='hhttps://github.com/isl-org/lang-seg' target='_blank'>Github Repo</a></p>"
-
-
-
-gr.Interface(inference,
+def inference(image,text):
+    input_labels = text
+
+    pimage = lseg_transform(np.array(image)).unsqueeze(0)
+
+    labels = []
+    for label in input_labels.split(","):
+        labels.append(label.strip())
+
+    with torch.no_grad():
+        outputs = lseg_model.parallel_forward(pimage, labels)
+
+        predicts = [
+            torch.max(output, 1)[1].cpu().numpy()
+            for output in outputs
+        ]
+
+    image = pimage[0].permute(1,2,0)
+    image = image * 0.5 + 0.5
+    image = Image.fromarray(np.uint8(255*image)).convert("RGBA")
+
+    pred = predicts[0]
+    new_palette = get_new_pallete(len(labels))
+    mask, patches = get_new_mask_pallete(pred, new_palette, out_label_flag=True, labels=labels)
+    seg = mask.convert("RGBA")
+
+    fig = plt.figure()
+    plt.subplot(121)
+    plt.axis('off')
+
+    plt.subplot(122)
+    plt.legend(handles=patches, loc='upper right', bbox_to_anchor=(1.3, 1), prop={'size': 5})
+    plt.axis('off')
+
+    plt.tight_layout()
+
+    return plt
+
+gr.Interface(inference,["image","text"],"plot").launch(debug=True)