Nunzio commited on
Commit
ed3e09d
·
1 Parent(s): 6a0b93e

added code

Browse files
Files changed (1) hide show
  1. app.py +32 -4
app.py CHANGED
@@ -1,11 +1,39 @@
1
- import os, torch, torchvision
2
- import torchvision.transforms.functional
3
  from model.BiSeNet.build_bisenet import BiSeNet
4
  import gradio as gr
5
  from utils.imageHandling import hfImageToTensor, preprocessing
6
 
7
 
 
 
 
 
8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
 
10
 
11
  # %% prediction on an image
@@ -22,9 +50,9 @@ def predict(inputImage: torch.Tensor, model: BiSeNet) -> torch.Tensor:
22
  prediction (torch.Tensor): The predicted segmentation mask.
23
  """
24
  with torch.no_grad():
25
- output = model(preprocessing(inputImage))
26
  output = output[0] if isinstance(output, (tuple, list)) else output
27
- return output[0].argmax(dim=0, keepdim=True)
28
 
29
 
30
 
 
1
+ import os, torch
 
2
  from model.BiSeNet.build_bisenet import BiSeNet
3
  import gradio as gr
4
  from utils.imageHandling import hfImageToTensor, preprocessing
5
 
6
 
7
+ def run_prediction(image: gr.Image, selected_model: str)-> tuple[torch.Tensor]:
8
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
9
+ image = hfImageToTensor(image, width=1024, height=512)
10
+ return image, predict(image, loadModel(selected_model, device))
11
 
12
+ # Gradio UI
13
+ with gr.Blocks(title="🔀 BiSeNet | BiSeNetV2 Predictor") as demo:
14
+ gr.Markdown("## 🧠 Image Segmentation with BiSeNet and BiSeNetV2")
15
+ gr.Markdown("Upload an image and choose your preferred model for segmentation.")
16
+
17
+ with gr.Row():
18
+ with gr.Column():
19
+ model_selector = gr.Radio(
20
+ choices=["BiSeNet", "BiSeNetV2"],
21
+ value="BiSeNet",
22
+ label="Select model"
23
+ )
24
+ image_input = gr.Image(type="pil", label="Upload image")
25
+ submit_btn = gr.Button("🧪 Run prediction")
26
+ with gr.Column():
27
+ original_display = gr.Image(label="Original image")
28
+ result_display = gr.Image(label="Model prediction")
29
+
30
+ submit_btn.click(
31
+ fn=run_prediction,
32
+ inputs=[image_input, model_selector],
33
+ outputs=[original_display, result_display]
34
+ )
35
+
36
+ demo.launch()
37
 
38
 
39
  # %% prediction on an image
 
50
  prediction (torch.Tensor): The predicted segmentation mask.
51
  """
52
  with torch.no_grad():
53
+ output = model(preprocessing(inputImage.clone()).to(model.device))
54
  output = output[0] if isinstance(output, (tuple, list)) else output
55
+ return output[0].argmax(dim=0, keepdim=True).cpu()
56
 
57
 
58