mehdidc commited on
Commit
85a9a0c
1 Parent(s): 957ae0d
Files changed (2) hide show
  1. app.py +3 -2
  2. cli.py +3 -1
app.py CHANGED
@@ -10,7 +10,7 @@ models = {
10
  "Deep ConvAE": torch.load("deep_convae.th", map_location="cpu"),
11
  "Dense K-Sparse": torch.load("fc_sparse.th", map_location="cpu"),
12
  }
13
- def gen(md, model_name, seed, nb_iter, nb_samples, width, height, nb_active, only_last, black_bg):
14
  torch.manual_seed(int(seed))
15
  bs = 64
16
  model = models[model_name]
@@ -22,6 +22,7 @@ def gen(md, model_name, seed, nb_iter, nb_samples, width, height, nb_active, onl
22
  nb_examples=int(nb_samples),
23
  w=int(width), h=int(height), c=1,
24
  batch_size=bs,
 
25
  )
26
  if not black_bg:
27
  samples = 1 - samples
@@ -48,7 +49,7 @@ iface = gr.Interface(
48
  fn=gen,
49
  inputs=[
50
  gr.Markdown(text),
51
- gr.Dropdown(list(models.keys()), value="Deep ConvAE"), gr.Number(value=0), gr.Number(value=25), gr.Number(value=1), gr.Number(value=28), gr.Number(value=28),gr.Slider(minimum=0,maximum=800, value=800, step=1), gr.Checkbox(value=False, label="Only show last iteration"), gr.Checkbox(value=True, label="Black background")
52
  ],
53
  outputs="image"
54
  )
 
10
  "Deep ConvAE": torch.load("deep_convae.th", map_location="cpu"),
11
  "Dense K-Sparse": torch.load("fc_sparse.th", map_location="cpu"),
12
  }
13
+ def gen(md, model_name, seed, nb_iter, nb_samples, width, height, nb_active, only_last, black_bg, binarize, binarize_threshold):
14
  torch.manual_seed(int(seed))
15
  bs = 64
16
  model = models[model_name]
 
22
  nb_examples=int(nb_samples),
23
  w=int(width), h=int(height), c=1,
24
  batch_size=bs,
25
+ binarize_threshold=binarize_threshold if binarize else None,
26
  )
27
  if not black_bg:
28
  samples = 1 - samples
 
49
  fn=gen,
50
  inputs=[
51
  gr.Markdown(text),
52
+ gr.Dropdown(list(models.keys()), value="Deep ConvAE"), gr.Number(value=0), gr.Number(value=25), gr.Number(value=1), gr.Number(value=28), gr.Number(value=28),gr.Slider(minimum=0,maximum=800, value=800, step=1), gr.Checkbox(value=False, label="Only show last iteration"), gr.Checkbox(value=True, label="Black background"), gr.Checkbox(value=False, label="binarize"), gr.Number(value=0.5)
53
  ],
54
  outputs="image"
55
  )
cli.py CHANGED
@@ -91,7 +91,7 @@ def save_weights(m, folder='.'):
91
  imsave('{}/feat.png'.format(folder), gr)
92
 
93
  @torch.no_grad()
94
- def iterative_refinement(ae, nb_examples=1, nb_iter=10, w=28, h=28, c=1, batch_size=None):
95
  if batch_size is None:
96
  batch_size = nb_examples
97
  x = torch.rand(nb_iter, nb_examples, c, w, h)
@@ -99,6 +99,8 @@ def iterative_refinement(ae, nb_examples=1, nb_iter=10, w=28, h=28, c=1, batch_s
99
  for j in range(0, nb_examples, batch_size):
100
  oldv = x[i-1][j:j + batch_size].to(device)
101
  newv = ae(oldv)
 
 
102
  newv = newv.data.cpu()
103
  x[i][j:j + batch_size] = newv
104
  return x
 
91
  imsave('{}/feat.png'.format(folder), gr)
92
 
93
  @torch.no_grad()
94
+ def iterative_refinement(ae, nb_examples=1, nb_iter=10, w=28, h=28, c=1, batch_size=None, binarize_threshold=None):
95
  if batch_size is None:
96
  batch_size = nb_examples
97
  x = torch.rand(nb_iter, nb_examples, c, w, h)
 
99
  for j in range(0, nb_examples, batch_size):
100
  oldv = x[i-1][j:j + batch_size].to(device)
101
  newv = ae(oldv)
102
+ if binarize_threshold is not None:
103
+ newv = (newv>binarize_threshold).float()
104
  newv = newv.data.cpu()
105
  x[i][j:j + batch_size] = newv
106
  return x