binarize
Browse files
app.py
CHANGED
@@ -10,7 +10,7 @@ models = {
|
|
10 |
"Deep ConvAE": torch.load("deep_convae.th", map_location="cpu"),
|
11 |
"Dense K-Sparse": torch.load("fc_sparse.th", map_location="cpu"),
|
12 |
}
|
13 |
-
def gen(md, model_name, seed, nb_iter, nb_samples, width, height, nb_active, only_last, black_bg):
|
14 |
torch.manual_seed(int(seed))
|
15 |
bs = 64
|
16 |
model = models[model_name]
|
@@ -22,6 +22,7 @@ def gen(md, model_name, seed, nb_iter, nb_samples, width, height, nb_active, onl
|
|
22 |
nb_examples=int(nb_samples),
|
23 |
w=int(width), h=int(height), c=1,
|
24 |
batch_size=bs,
|
|
|
25 |
)
|
26 |
if not black_bg:
|
27 |
samples = 1 - samples
|
@@ -48,7 +49,7 @@ iface = gr.Interface(
|
|
48 |
fn=gen,
|
49 |
inputs=[
|
50 |
gr.Markdown(text),
|
51 |
-
gr.Dropdown(list(models.keys()), value="Deep ConvAE"), gr.Number(value=0), gr.Number(value=25), gr.Number(value=1), gr.Number(value=28), gr.Number(value=28),gr.Slider(minimum=0,maximum=800, value=800, step=1), gr.Checkbox(value=False, label="Only show last iteration"), gr.Checkbox(value=True, label="Black background")
|
52 |
],
|
53 |
outputs="image"
|
54 |
)
|
|
|
10 |
"Deep ConvAE": torch.load("deep_convae.th", map_location="cpu"),
|
11 |
"Dense K-Sparse": torch.load("fc_sparse.th", map_location="cpu"),
|
12 |
}
|
13 |
+
def gen(md, model_name, seed, nb_iter, nb_samples, width, height, nb_active, only_last, black_bg, binarize, binarize_threshold):
|
14 |
torch.manual_seed(int(seed))
|
15 |
bs = 64
|
16 |
model = models[model_name]
|
|
|
22 |
nb_examples=int(nb_samples),
|
23 |
w=int(width), h=int(height), c=1,
|
24 |
batch_size=bs,
|
25 |
+
binarize_threshold=binarize_threshold if binarize else None,
|
26 |
)
|
27 |
if not black_bg:
|
28 |
samples = 1 - samples
|
|
|
49 |
fn=gen,
|
50 |
inputs=[
|
51 |
gr.Markdown(text),
|
52 |
+
gr.Dropdown(list(models.keys()), value="Deep ConvAE"), gr.Number(value=0), gr.Number(value=25), gr.Number(value=1), gr.Number(value=28), gr.Number(value=28),gr.Slider(minimum=0,maximum=800, value=800, step=1), gr.Checkbox(value=False, label="Only show last iteration"), gr.Checkbox(value=True, label="Black background"), gr.Checkbox(value=False, label="binarize"), gr.Number(value=0.5)
|
53 |
],
|
54 |
outputs="image"
|
55 |
)
|
cli.py
CHANGED
@@ -91,7 +91,7 @@ def save_weights(m, folder='.'):
|
|
91 |
imsave('{}/feat.png'.format(folder), gr)
|
92 |
|
93 |
@torch.no_grad()
|
94 |
-
def iterative_refinement(ae, nb_examples=1, nb_iter=10, w=28, h=28, c=1, batch_size=None):
|
95 |
if batch_size is None:
|
96 |
batch_size = nb_examples
|
97 |
x = torch.rand(nb_iter, nb_examples, c, w, h)
|
@@ -99,6 +99,8 @@ def iterative_refinement(ae, nb_examples=1, nb_iter=10, w=28, h=28, c=1, batch_s
|
|
99 |
for j in range(0, nb_examples, batch_size):
|
100 |
oldv = x[i-1][j:j + batch_size].to(device)
|
101 |
newv = ae(oldv)
|
|
|
|
|
102 |
newv = newv.data.cpu()
|
103 |
x[i][j:j + batch_size] = newv
|
104 |
return x
|
|
|
91 |
imsave('{}/feat.png'.format(folder), gr)
|
92 |
|
93 |
@torch.no_grad()
|
94 |
+
def iterative_refinement(ae, nb_examples=1, nb_iter=10, w=28, h=28, c=1, batch_size=None, binarize_threshold=None):
|
95 |
if batch_size is None:
|
96 |
batch_size = nb_examples
|
97 |
x = torch.rand(nb_iter, nb_examples, c, w, h)
|
|
|
99 |
for j in range(0, nb_examples, batch_size):
|
100 |
oldv = x[i-1][j:j + batch_size].to(device)
|
101 |
newv = ae(oldv)
|
102 |
+
if binarize_threshold is not None:
|
103 |
+
newv = (newv>binarize_threshold).float()
|
104 |
newv = newv.data.cpu()
|
105 |
x[i][j:j + batch_size] = newv
|
106 |
return x
|