jiangyzy commited on
Commit
595554b
1 Parent(s): 0c288b3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -7
app.py CHANGED
@@ -129,11 +129,12 @@ def prepare_data(device, input_image, x0, y0, x1, y1, polar, azimuth, text):
129
  return batch
130
 
131
 
132
- @spaces.GPU(enable_queue=True)
133
  def run_generation(sampler, model, device, input_image, x0, y0, x1, y1, polar, azimuth, text, seed):
134
  seed_everything(seed)
135
  batch = prepare_data(device, input_image, x0, y0, x1, y1, polar, azimuth, text)
136
  model = model.to(device)
 
137
 
138
  c = model.get_learned_conditioning(batch["image_cond"])
139
  c = torch.cat([c, batch["T"]], dim=-1)
@@ -194,13 +195,13 @@ def main(args):
194
  config = OmegaConf.load("configs/config_customnet.yaml")
195
  model = instantiate_from_config(config.model)
196
 
197
- model_path='./customnet_v1.pt?download=true'
198
- if not os.path.exists(model_path):
199
- os.system(f'wget https://huggingface.co/TencentARC/CustomNet/resolve/main/customnet_v1.pt?download=true -P .')
200
 
201
- ckpt = torch.load(model_path, map_location="cpu")
202
- model.load_state_dict(ckpt)
203
- del ckpt
204
  model = model.to(device)
205
  sampler = DDIMSampler(model, device=device)
206
 
 
129
  return batch
130
 
131
 
132
+ @spaces.GPU(enable_queue=True, duration=180)
133
  def run_generation(sampler, model, device, input_image, x0, y0, x1, y1, polar, azimuth, text, seed):
134
  seed_everything(seed)
135
  batch = prepare_data(device, input_image, x0, y0, x1, y1, polar, azimuth, text)
136
  model = model.to(device)
137
+ sampler = DDIMSampler(model, device=device)
138
 
139
  c = model.get_learned_conditioning(batch["image_cond"])
140
  c = torch.cat([c, batch["T"]], dim=-1)
 
195
  config = OmegaConf.load("configs/config_customnet.yaml")
196
  model = instantiate_from_config(config.model)
197
 
198
+ # model_path='./customnet_v1.pt?download=true'
199
+ # if not os.path.exists(model_path):
200
+ # os.system(f'wget https://huggingface.co/TencentARC/CustomNet/resolve/main/customnet_v1.pt?download=true -P .')
201
 
202
+ # ckpt = torch.load(model_path, map_location="cpu")
203
+ # model.load_state_dict(ckpt)
204
+ # del ckpt
205
  model = model.to(device)
206
  sampler = DDIMSampler(model, device=device)
207