mboss commited on
Commit
6d7ec2d
·
1 Parent(s): 3438c46
Files changed (2) hide show
  1. app.py +2 -2
  2. src/color_matcher.py +1 -2
app.py CHANGED
@@ -9,8 +9,8 @@ with gr.Blocks() as demo:
9
  """
10
  # ReSWD
11
 
12
- <a href="https://reservoirswd.github.io/"><img src="https://img.shields.io/badge/Project%20Page-5CE1BC.svg"></a> <br>
13
- <a href="https://reservoirswd.github.io/static/paper.pdf"><img src="https://img.shields.io/badge/Arxiv-2408.00653-B31B1B.svg"></a> <br>
14
  <a href="https://github.com/Stability-AI/ReSWD"><img src="https://img.shields.io/badge/git-%23F05032?logo=git&logoColor=white"></a>
15
  <br>
16
 
 
9
  """
10
  # ReSWD
11
 
12
+ <a href="https://reservoirswd.github.io/"><img src="https://img.shields.io/badge/Project%20Page-5CE1BC.svg"></a>
13
+ <a href="https://reservoirswd.github.io/static/paper.pdf"><img src="https://img.shields.io/badge/Arxiv-2408.00653-B31B1B.svg"></a>
14
  <a href="https://github.com/Stability-AI/ReSWD"><img src="https://img.shields.io/badge/git-%23F05032?logo=git&logoColor=white"></a>
15
  <br>
16
 
src/color_matcher.py CHANGED
@@ -85,10 +85,9 @@ def train(
85
  source_img = source_img.cuda()
86
 
87
  batch_size = source_img.shape[0]
88
- cdl = CDL(batch_size)
89
 
90
  optim = torch.optim.Adam(cdl.parameters(), lr=lr)
91
- cdl, optim = cdl.cuda(), optim.cuda()
92
 
93
  lossses = []
94
  for i in tqdm(range(num_steps), disable=silent):
 
85
  source_img = source_img.cuda()
86
 
87
  batch_size = source_img.shape[0]
88
+ cdl = CDL(batch_size).cuda()
89
 
90
  optim = torch.optim.Adam(cdl.parameters(), lr=lr)
 
91
 
92
  lossses = []
93
  for i in tqdm(range(num_steps), disable=silent):