mischeiwiller commited on
Commit
4f2e28e
1 Parent(s): be91971

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -15
app.py CHANGED
@@ -22,24 +22,25 @@ def preprocess_image(img):
22
 
23
  print(f"After conversion to tensor - shape: {img.shape}")
24
 
25
- # Ensure 3D tensor (C, H, W)
26
  if img.ndim == 2:
27
- img = img.unsqueeze(0)
28
- elif img.ndim == 3 and img.shape[0] not in [1, 3]:
29
- img = img.permute(2, 0, 1)
 
 
 
 
 
 
30
 
31
- print(f"After ensuring 3D - shape: {img.shape}")
32
 
33
  # Ensure 3 channel image
34
- if img.shape[0] == 1:
35
- img = img.expand(3, -1, -1)
36
- elif img.shape[0] > 3:
37
- img = img[:3] # Take only the first 3 channels if more than 3
38
-
39
- print(f"After ensuring 3 channels - shape: {img.shape}")
40
-
41
- # Add batch dimension
42
- img = img.unsqueeze(0)
43
 
44
  print(f"Final tensor shape: {img.shape}")
45
  return img
@@ -70,4 +71,4 @@ with gr.Blocks(theme='huggingface') as demo_app:
70
  gr.Examples(examples=examples, inputs=[input_image1, input_image2])
71
 
72
  if __name__ == "__main__":
73
- demo_app.launch()
 
22
 
23
  print(f"After conversion to tensor - shape: {img.shape}")
24
 
25
+ # Ensure 4D tensor (B, C, H, W)
26
  if img.ndim == 2:
27
+ img = img.unsqueeze(0).unsqueeze(0)
28
+ elif img.ndim == 3:
29
+ if img.shape[0] in [1, 3]:
30
+ img = img.unsqueeze(0)
31
+ else:
32
+ img = img.unsqueeze(1)
33
+ elif img.ndim == 4:
34
+ if img.shape[1] not in [1, 3]:
35
+ img = img.permute(0, 3, 1, 2)
36
 
37
+ print(f"After ensuring 4D - shape: {img.shape}")
38
 
39
  # Ensure 3 channel image
40
+ if img.shape[1] == 1:
41
+ img = img.repeat(1, 3, 1, 1)
42
+ elif img.shape[1] > 3:
43
+ img = img[:, :3] # Take only the first 3 channels if more than 3
 
 
 
 
 
44
 
45
  print(f"Final tensor shape: {img.shape}")
46
  return img
 
71
  gr.Examples(examples=examples, inputs=[input_image1, input_image2])
72
 
73
  if __name__ == "__main__":
74
+ demo_app.launch(share=True)