mischeiwiller commited on
Commit
9acfc48
β€’
1 Parent(s): f23b303

Update kornia_aug.py

Browse files
Files changed (1) hide show
  1. kornia_aug.py +8 -27
kornia_aug.py CHANGED
@@ -1,17 +1,16 @@
1
  import streamlit as st
2
  import kornia
3
- from torch import nn
4
  import torch
 
5
  from torchvision.transforms import functional as F
6
  from torchvision.utils import make_grid
7
  from streamlit_ace import st_ace
8
  from PIL import Image
9
 
10
- IS_LOCAL = False #Change this
11
 
12
- @st.cache(suppress_st_warning=True)
13
  def set_transform(content):
14
- # st.write("set transform")
15
  try:
16
  transform = eval(content, {"kornia": kornia, "nn": nn}, None)
17
  except Exception as e:
@@ -32,9 +31,8 @@ scaler = int(im.height / 2)
32
  st.sidebar.image(im, caption="Input Image", width=256)
33
  image = F.pil_to_tensor(im).float() / 255
34
 
35
-
36
  # batch size is just for show
37
- batch_size = st.sidebar.slider("batch_size", min_value=4, max_value=16,value=8)
38
  gpu = st.sidebar.checkbox("Use GPU!", value=True)
39
  if not gpu:
40
  st.sidebar.markdown("With Kornia you do ops on the GPU!")
@@ -48,7 +46,7 @@ else:
48
  device = torch.device("cpu")
49
  else:
50
  st.sidebar.markdown("Running on GPU~")
51
- device = torch.device("cuda:0")
52
 
53
  predefined_transforms = [
54
  """
@@ -75,7 +73,7 @@ nn.Sequential(
75
  kornia.augmentation.RandomHorizontalFlip(p=0.7),
76
  kornia.augmentation.RandomGrayscale(p=0.5),
77
  )
78
- """,
79
  ]
80
 
81
  selected_transform = st.selectbox(
@@ -96,36 +94,22 @@ content = st_ace(
96
  readonly=readonly,
97
  )
98
  if content:
99
- # st.write(content)
100
  transform = set_transform(content)
101
 
102
- # st.write(transform)
103
-
104
- # with st.echo():
105
- # transform = nn.Sequential(
106
- # K.RandomAffine(360),
107
- # K.ColorJitter(0.2, 0.3, 0.2, 0.3)
108
- # )
109
-
110
  process = st.button("Next Batch")
111
 
112
  # Fake dataloader
113
  image_batch = torch.stack(batch_size * [image])
114
 
115
-
116
- image_batch.to(device)
117
  transformeds = None
118
  try:
119
  transformeds = transform(image_batch)
120
  except Exception as e:
121
  st.write(f"There was an error: {e}")
122
-
123
-
124
-
125
 
126
  cols = st.columns(4)
127
 
128
- # st.image(F.to_pil_image(make_grid(transformeds)))
129
  if transformeds is not None:
130
  for i, x in enumerate(transformeds):
131
  i = i % 4
@@ -136,7 +120,4 @@ st.markdown(
136
  )
137
  st.markdown(
138
  "Kornia can do a lot more than augmentations~ [Check it out](https://kornia.readthedocs.io/en/latest/get-started/introduction.html)"
139
- )
140
- # if process:
141
- # pass
142
-
 
1
  import streamlit as st
2
  import kornia
 
3
  import torch
4
+ from torch import nn
5
  from torchvision.transforms import functional as F
6
  from torchvision.utils import make_grid
7
  from streamlit_ace import st_ace
8
  from PIL import Image
9
 
10
+ IS_LOCAL = False # Change this
11
 
12
+ @st.cache_data
13
  def set_transform(content):
 
14
  try:
15
  transform = eval(content, {"kornia": kornia, "nn": nn}, None)
16
  except Exception as e:
 
31
  st.sidebar.image(im, caption="Input Image", width=256)
32
  image = F.pil_to_tensor(im).float() / 255
33
 
 
34
  # batch size is just for show
35
+ batch_size = st.sidebar.slider("batch_size", min_value=4, max_value=16, value=8)
36
  gpu = st.sidebar.checkbox("Use GPU!", value=True)
37
  if not gpu:
38
  st.sidebar.markdown("With Kornia you do ops on the GPU!")
 
46
  device = torch.device("cpu")
47
  else:
48
  st.sidebar.markdown("Running on GPU~")
49
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
50
 
51
  predefined_transforms = [
52
  """
 
73
  kornia.augmentation.RandomHorizontalFlip(p=0.7),
74
  kornia.augmentation.RandomGrayscale(p=0.5),
75
  )
76
+ """
77
  ]
78
 
79
  selected_transform = st.selectbox(
 
94
  readonly=readonly,
95
  )
96
  if content:
 
97
  transform = set_transform(content)
98
 
 
 
 
 
 
 
 
 
99
  process = st.button("Next Batch")
100
 
101
  # Fake dataloader
102
  image_batch = torch.stack(batch_size * [image])
103
 
104
+ image_batch = image_batch.to(device)
 
105
  transformeds = None
106
  try:
107
  transformeds = transform(image_batch)
108
  except Exception as e:
109
  st.write(f"There was an error: {e}")
 
 
 
110
 
111
  cols = st.columns(4)
112
 
 
113
  if transformeds is not None:
114
  for i, x in enumerate(transformeds):
115
  i = i % 4
 
120
  )
121
  st.markdown(
122
  "Kornia can do a lot more than augmentations~ [Check it out](https://kornia.readthedocs.io/en/latest/get-started/introduction.html)"
123
+ )