Nunzio commited on
Commit
c8e426d
·
1 Parent(s): 84603b7
Files changed (2) hide show
  1. app.py +7 -12
  2. utils/imageHandling.py +16 -1
app.py CHANGED
@@ -2,7 +2,7 @@ import os, torch
2
  import gradio as gr
3
  from PIL import Image
4
 
5
- from utils.imageHandling import hfImageToTensor, preprocessing, postprocessing
6
  from model.modelLoading import loadModel
7
 
8
 
@@ -71,22 +71,17 @@ with gr.Blocks(title="Semantic Segmentation Predictors") as demo:
71
  with gr.Row():
72
  gr.Markdown("## Preloaded GTA V images to be used for testing the model")
73
  with gr.Row():
74
- gta_gallery = gr.Gallery(
75
- value=sorted([Image.open(os.path.join(gta_image_dir, f)).convert("RGB") for f in os.listdir(gta_image_dir) if f.endswith(".png")]),
76
- label="GTA V Examples",
77
- show_label=False,
78
- columns=5,
79
- rows=1,
80
- height=200,
81
- type="pil"
82
  )
83
 
84
  with gr.Row():
85
  gr.Markdown("## Preloaded Cityscapes images to be used for testing the model")
86
  with gr.Row():
87
- city_gallery = gr.Gallery(value=sorted([Image.open(os.path.join(city_image_dir, f)).convert("RGB") for f in os.listdir(city_image_dir) if f.endswith(".png")]),
88
- label="Cityscapes Examples", show_label=False, columns=5, rows=1,
89
- height=256, type="pil"
90
  )
91
 
92
  submit_btn.click(
 
2
  import gradio as gr
3
  from PIL import Image
4
 
5
+ from utils.imageHandling import hfImageToTensor, preprocessing, postprocessing, loadPreloadedImages
6
  from model.modelLoading import loadModel
7
 
8
 
 
71
  with gr.Row():
72
  gr.Markdown("## Preloaded GTA V images to be used for testing the model")
73
  with gr.Row():
74
+ gta_gallery = gr.Gallery(value=loadPreloadedImages(gta_image_dir),
75
+ label="GTA V Examples", show_label=False, columns=5,
76
+ rows=1, height=200, type="pil"
 
 
 
 
 
77
  )
78
 
79
  with gr.Row():
80
  gr.Markdown("## Preloaded Cityscapes images to be used for testing the model")
81
  with gr.Row():
82
+ city_gallery = gr.Gallery(value=loadPreloadedImages(city_image_dir),
83
+ label="Cityscapes Examples", show_label=False, columns=5,
84
+ rows=1, height=256, type="pil"
85
  )
86
 
87
  submit_btn.click(
utils/imageHandling.py CHANGED
@@ -1,4 +1,5 @@
1
- import torch, torchvision
 
2
 
3
  # %% image loading
4
  def hfImageToTensor(image, width:int=1024, height:int=512)->torch.Tensor:
@@ -81,3 +82,17 @@ def postprocessing(pred: torch.Tensor) -> torch.Tensor:
81
  torch.Tensor: Processed tensor of shape (3, H, W) for visualization.
82
  """
83
  return torchvision.transforms.functional.to_pil_image(print_mask(pred.squeeze(0).cpu().to(torch.uint8)))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch, torchvision, os
2
+ from PIL import Image
3
 
4
  # %% image loading
5
  def hfImageToTensor(image, width:int=1024, height:int=512)->torch.Tensor:
 
82
  torch.Tensor: Processed tensor of shape (3, H, W) for visualization.
83
  """
84
  return torchvision.transforms.functional.to_pil_image(print_mask(pred.squeeze(0).cpu().to(torch.uint8)))
85
+
86
+
87
+ # %% preloaded images
88
+ def loadPreloadedImages(image_dir: str) -> list[Image.Image]:
89
+ """
90
+ Load preloaded images from a directory.
91
+
92
+ Args:
93
+ image_dir (str): Path to the directory containing images.
94
+
95
+ Returns:
96
+ list[Image.Image]: List of loaded images.
97
+ """
98
+ return list(map(lambda f: Image.open(os.path.join(image_dir, f)).convert("RGB"), sorted([f for f in os.listdir(image_dir) if f.endswith(".png")])))