added progress bar
Browse files- __pycache__/normflows.cpython-310.pyc +0 -0
- app.py +6 -1
- normflows.py +1 -2
__pycache__/normflows.cpython-310.pyc
ADDED
Binary file (11.2 kB). View file
|
|
app.py
CHANGED
@@ -13,7 +13,12 @@ bw = st.number_input('Scale',value=3.05)
|
|
13 |
def compute():
|
14 |
api = nflow(dim=8,latent=16,dataset=uploaded_file)
|
15 |
api.compile(optim=torch.optim.ASGD,bw=bw,lr=0.0001,wd=None)
|
16 |
-
|
|
|
|
|
|
|
|
|
|
|
17 |
samples = np.array(api.model.sample(
|
18 |
torch.tensor(api.scaled).float()).detach())
|
19 |
|
|
|
13 |
def compute():
|
14 |
api = nflow(dim=8,latent=16,dataset=uploaded_file)
|
15 |
api.compile(optim=torch.optim.ASGD,bw=bw,lr=0.0001,wd=None)
|
16 |
+
|
17 |
+
my_bar = st.progress(0, text='Currently in progress')
|
18 |
+
|
19 |
+
for idx in api.train(iters=10000):
|
20 |
+
my_bar.progress(idx[0]/10000, text=str(idx[1]))
|
21 |
+
|
22 |
samples = np.array(api.model.sample(
|
23 |
torch.tensor(api.scaled).float()).detach())
|
24 |
|
normflows.py
CHANGED
@@ -341,8 +341,7 @@ class nflow():
|
|
341 |
|
342 |
if idx % 100 == 0:
|
343 |
print("Loss {}".format(loss.item()))
|
344 |
-
|
345 |
-
plt.plot(self.losses)
|
346 |
|
347 |
def performance(self):
|
348 |
"""
|
|
|
341 |
|
342 |
if idx % 100 == 0:
|
343 |
print("Loss {}".format(loss.item()))
|
344 |
+
yield idx,loss.item()
|
|
|
345 |
|
346 |
def performance(self):
|
347 |
"""
|