yyk19 commited on
Commit
ebf2390
·
1 Parent(s): ba96fba

remove ema parts in checkpoints.

Browse files
Files changed (3) hide show
  1. app.py +3 -3
  2. model_states.pt → model_wo_ema.ckpt +2 -2
  3. transfer.py +8 -2
app.py CHANGED
@@ -66,11 +66,11 @@ def process_multi_wrapper_only_show_rendered(rendered_txt_0, rendered_txt_1, ren
66
  shared_eta, shared_a_prompt, shared_n_prompt,
67
  only_show_rendered_image=True)
68
 
69
- # cfg = OmegaConf.load("config.yaml")
70
- # model = load_model_from_config(cfg, "model_states.pt", verbose=True)
71
 
72
  cfg = OmegaConf.load("config.yaml")
73
- model = load_model_from_config(cfg, "model.ckpt", verbose=True)
 
 
74
 
75
  ddim_sampler = DDIMSampler(model)
76
  render_tool = Render_Text(model)
 
66
  shared_eta, shared_a_prompt, shared_n_prompt,
67
  only_show_rendered_image=True)
68
 
 
 
69
 
70
  cfg = OmegaConf.load("config.yaml")
71
+ model = load_model_from_config(cfg, "model_wo_ema.ckpt", verbose=True)
72
+ # model = load_model_from_config(cfg, "model_states.pt", verbose=True)
73
+ # model = load_model_from_config(cfg, "model.ckpt", verbose=True)
74
 
75
  ddim_sampler = DDIMSampler(model)
76
  render_tool = Render_Text(model)
model_states.pt → model_wo_ema.ckpt RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:2b56f1251182afabc8d5291e07c3a3aaf21d85d36445b25d0057fc1960d63de5
3
- size 9880058178
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0b86b22188bf580e80773a5ae101bf9787eb258349f3f1acf0ae50fd10cb3fec
3
+ size 6671922039
transfer.py CHANGED
@@ -6,9 +6,15 @@ model = load_model_from_config(cfg, "model_states.pt", verbose=True)
6
 
7
  from pytorch_lightning.callbacks import ModelCheckpoint
8
  with model.ema_scope("store ema weights"):
 
 
 
 
 
 
9
  file_content = {
10
- 'state_dict': model.state_dict()
11
  }
12
- torch.save(file_content, "model.ckpt")
13
  print("has stored the transfered ckpt.")
14
  print("trial ends!")
 
6
 
7
  from pytorch_lightning.callbacks import ModelCheckpoint
8
  with model.ema_scope("store ema weights"):
9
+ model_sd = model.state_dict()
10
+ store_sd = {}
11
+ for key in model_sd:
12
+ if "ema" in key:
13
+ continue
14
+ store_sd[key] = model_sd[key]
15
  file_content = {
16
+ 'state_dict': store_sd
17
  }
18
+ torch.save(file_content, "model_wo_ema.ckpt")
19
  print("has stored the transfered ckpt.")
20
  print("trial ends!")