silentchen commited on
Commit
d5eda1c
1 Parent(s): e731c6b

update space

Browse files
Files changed (1) hide show
  1. app.py +16 -15
app.py CHANGED
@@ -257,19 +257,20 @@ def main():
257
  elif 'cyber' in instruction:
258
  e_type = 'cyber'
259
 
260
- model = load_model('text300M', device=device)
261
- with torch.no_grad():
262
- new_proj = nn.Linear(1024 * 2, 1024, device=device, dtype=model.wrapped.input_proj.weight.dtype)
263
- new_proj.weight = nn.Parameter(torch.zeros_like(new_proj.weight))
264
- new_proj.weight[:, :1024].copy_(model.wrapped.input_proj.weight) #
265
- new_proj.bias = nn.Parameter(torch.zeros_like(new_proj.bias))
266
- new_proj.bias[:1024].copy_(model.wrapped.input_proj.bias)
267
- model.wrapped.input_proj = new_proj
268
-
269
- ckp = torch.load(
270
- hf_hub_download(repo_id='silentchen/Shap_Editor', subfolder='single', filename='{}.pt'.format(e_type)),
271
- map_location='cpu')
272
- model.load_state_dict(ckp['model'])
 
273
 
274
  noise_initial = initial_noise[e_type].to(device)
275
  noise_start_t_e_type = noise_start_t[e_type]
@@ -324,8 +325,8 @@ def main():
324
  return mesh_path[0], mesh_path[1], mesh_path[2], mesh_path[3], state
325
 
326
 
327
- del models
328
- models = None
329
  with Blocks(
330
  css=css,
331
  analytics_enabled=False,
 
257
  elif 'cyber' in instruction:
258
  e_type = 'cyber'
259
 
260
+ model = models[e_type]
261
+ # model = load_model('text300M', device=device)
262
+ # with torch.no_grad():
263
+ # new_proj = nn.Linear(1024 * 2, 1024, device=device, dtype=model.wrapped.input_proj.weight.dtype)
264
+ # new_proj.weight = nn.Parameter(torch.zeros_like(new_proj.weight))
265
+ # new_proj.weight[:, :1024].copy_(model.wrapped.input_proj.weight) #
266
+ # new_proj.bias = nn.Parameter(torch.zeros_like(new_proj.bias))
267
+ # new_proj.bias[:1024].copy_(model.wrapped.input_proj.bias)
268
+ # model.wrapped.input_proj = new_proj
269
+ #
270
+ # ckp = torch.load(
271
+ # hf_hub_download(repo_id='silentchen/Shap_Editor', subfolder='single', filename='{}.pt'.format(e_type)),
272
+ # map_location='cpu')
273
+ # model.load_state_dict(ckp['model'])
274
 
275
  noise_initial = initial_noise[e_type].to(device)
276
  noise_start_t_e_type = noise_start_t[e_type]
 
325
  return mesh_path[0], mesh_path[1], mesh_path[2], mesh_path[3], state
326
 
327
 
328
+ # del models
329
+ # models = None
330
  with Blocks(
331
  css=css,
332
  analytics_enabled=False,