jiangyzy commited on
Commit
622dbe7
1 Parent(s): ba01e37

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +21 -19
app.py CHANGED
@@ -64,6 +64,22 @@ If you have any questions, please feel free to reach me out at <b>yuanzy22@mails
64
  # prompt = None
65
  negtive_prompt = ""
66
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67
 
68
  def send_input_to_concat(input_image):
69
  W, H = input_image.size
@@ -130,10 +146,11 @@ def prepare_data(device, input_image, x0, y0, x1, y1, polar, azimuth, text):
130
 
131
 
132
  @spaces.GPU(enable_queue=True, duration=180)
133
- def run_generation(sampler, model, device, input_image, x0, y0, x1, y1, polar, azimuth, text, seed):
 
134
  seed_everything(seed)
135
  batch = prepare_data(device, input_image, x0, y0, x1, y1, polar, azimuth, text)
136
- model = model.to(device)
137
  sampler = DDIMSampler(model, device=device)
138
 
139
  c = model.get_learned_conditioning(batch["image_cond"])
@@ -189,22 +206,6 @@ def load_example(input_image, x0, y0, x1, y1, polar, azimuth, prompt):
189
 
190
  @torch.no_grad()
191
  def main(args):
192
- # load model
193
- device = torch.device("cuda")
194
- preprocess_model = load_preprocess_model()
195
- config = OmegaConf.load("configs/config_customnet.yaml")
196
- model = instantiate_from_config(config.model)
197
-
198
- model_path='./customnet_v1.pt?download=true'
199
- if not os.path.exists(model_path):
200
- os.system(f'wget https://huggingface.co/TencentARC/CustomNet/resolve/main/customnet_v1.pt?download=true -P .')
201
- ckpt = torch.load(model_path, map_location="cpu")
202
- model.load_state_dict(ckpt)
203
- del ckpt
204
-
205
- model = model.to(device)
206
- sampler = None
207
-
208
  # load demo
209
  demo = gr.Blocks()
210
  with demo:
@@ -279,7 +280,8 @@ def main(args):
279
  inputs=[x0, y0, x1, y1, input_image],
280
  outputs=[x0, y0, x1, y1, location_image])
281
 
282
- start.click(partial(run_generation, sampler, model, device),
 
283
  inputs=[input_image, x0, y0, x1, y1, polar, azimuth, prompt, seed],
284
  outputs=output_image)
285
 
 
64
  # prompt = None
65
  negtive_prompt = ""
66
 
67
+ # load model
68
+ device = torch.device("cuda")
69
+ preprocess_model = load_preprocess_model()
70
+ config = OmegaConf.load("configs/config_customnet.yaml")
71
+ model = instantiate_from_config(config.model)
72
+
73
+ model_path='./customnet_v1.pt?download=true'
74
+ if not os.path.exists(model_path):
75
+ os.system(f'wget https://huggingface.co/TencentARC/CustomNet/resolve/main/customnet_v1.pt?download=true -P .')
76
+ ckpt = torch.load(model_path, map_location="cpu")
77
+ model.load_state_dict(ckpt)
78
+ del ckpt
79
+
80
+ model = model.to(device)
81
+ sampler = None
82
+
83
 
84
  def send_input_to_concat(input_image):
85
  W, H = input_image.size
 
146
 
147
 
148
  @spaces.GPU(enable_queue=True, duration=180)
149
+ # def run_generation(sampler, model, device, input_image, x0, y0, x1, y1, polar, azimuth, text, seed):
150
+ def run_generation(sampler, input_image, x0, y0, x1, y1, polar, azimuth, text, seed):
151
  seed_everything(seed)
152
  batch = prepare_data(device, input_image, x0, y0, x1, y1, polar, azimuth, text)
153
+ # model = model.to(device)
154
  sampler = DDIMSampler(model, device=device)
155
 
156
  c = model.get_learned_conditioning(batch["image_cond"])
 
206
 
207
  @torch.no_grad()
208
  def main(args):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
209
  # load demo
210
  demo = gr.Blocks()
211
  with demo:
 
280
  inputs=[x0, y0, x1, y1, input_image],
281
  outputs=[x0, y0, x1, y1, location_image])
282
 
283
+ # start.click(partial(run_generation, sampler, model, device),
284
+ start.click(partial(run_generation, sampler),
285
  inputs=[input_image, x0, y0, x1, y1, polar, azimuth, prompt, seed],
286
  outputs=output_image)
287