yyk19 commited on
Commit
5b5da1b
·
1 Parent(s): f769af2

update candidate checkpoints.

Browse files
app.py CHANGED
@@ -74,7 +74,7 @@ def process_multi_wrapper_only_show_rendered(rendered_txt_0, rendered_txt_1, ren
74
  shared_eta, shared_a_prompt, shared_n_prompt,
75
  only_show_rendered_image=True)
76
 
77
- def load_ckpt(model_ckpt = "LAION-Glyph-10M"):
78
  global render_tool, model
79
  if torch.cuda.is_available():
80
  for i in range(5):
@@ -84,8 +84,11 @@ def load_ckpt(model_ckpt = "LAION-Glyph-10M"):
84
 
85
  if model_ckpt == "LAION-Glyph-1M":
86
  model = load_model_ckpt(model, "laion1M_model_wo_ema.ckpt")
87
- elif model_ckpt == "LAION-Glyph-10M":
88
- model = load_model_ckpt(model, "model_wo_ema.ckpt")
 
 
 
89
  render_tool = Render_Text(model)
90
  output_str = f"already change the model checkpoint to {model_ckpt}"
91
  print(output_str)
@@ -96,7 +99,8 @@ def load_ckpt(model_ckpt = "LAION-Glyph-10M"):
96
  print("empty the cuda cache")
97
 
98
  cfg = OmegaConf.load("config.yaml")
99
- model = load_model_from_config(cfg, "model_wo_ema.ckpt", verbose=True)
 
100
  # model = load_model_from_config(cfg, "model_states.pt", verbose=True)
101
  # model = load_model_from_config(cfg, "model.ckpt", verbose=True)
102
  # ddim_sampler = DDIMSampler(model)
@@ -148,7 +152,7 @@ with block:
148
  with gr.Accordion("Model Options", open=False):
149
  with gr.Row():
150
  # model_ckpt = gr.inputs.Dropdown(["LAION-Glyph-10M", "Textcaps5K-10"], label="Checkpoint", default = "LAION-Glyph-10M")
151
- model_ckpt = gr.inputs.Dropdown(["LAION-Glyph-10M", "LAION-Glyph-1M"], label="Checkpoint", default = "LAION-Glyph-10M")
152
  load_button = gr.Button(value = "Load Checkpoint")
153
 
154
  with gr.Accordion("Shared Advanced Options", open=False):
 
74
  shared_eta, shared_a_prompt, shared_n_prompt,
75
  only_show_rendered_image=True)
76
 
77
+ def load_ckpt(model_ckpt = "LAION-Glyph-10M-Epoch-5"):
78
  global render_tool, model
79
  if torch.cuda.is_available():
80
  for i in range(5):
 
84
 
85
  if model_ckpt == "LAION-Glyph-1M":
86
  model = load_model_ckpt(model, "laion1M_model_wo_ema.ckpt")
87
+ elif model_ckpt == "LAION-Glyph-10M-Epoch-5":
88
+ model = load_model_ckpt(model, "laion10M_epoch_5_model_wo_ema.ckpt")
89
+ elif model_ckpt == "LAION-Glyph-10M-Epoch-6":
90
+ model = load_model_ckpt(model, "laion10M_epoch_6_model_wo_ema.ckpt")
91
+
92
  render_tool = Render_Text(model)
93
  output_str = f"already change the model checkpoint to {model_ckpt}"
94
  print(output_str)
 
99
  print("empty the cuda cache")
100
 
101
  cfg = OmegaConf.load("config.yaml")
102
+ model = load_model_from_config(cfg, "laion10M_epoch_6_model_wo_ema.ckpt", verbose=True)
103
+ # model = load_model_from_config(cfg, "model_wo_ema.ckpt", verbose=True)
104
  # model = load_model_from_config(cfg, "model_states.pt", verbose=True)
105
  # model = load_model_from_config(cfg, "model.ckpt", verbose=True)
106
  # ddim_sampler = DDIMSampler(model)
 
152
  with gr.Accordion("Model Options", open=False):
153
  with gr.Row():
154
  # model_ckpt = gr.inputs.Dropdown(["LAION-Glyph-10M", "Textcaps5K-10"], label="Checkpoint", default = "LAION-Glyph-10M")
155
+ model_ckpt = gr.inputs.Dropdown(["LAION-Glyph-10M-Epoch-6", "LAION-Glyph-10M-Epoch-5", "LAION-Glyph-1M"], label="Checkpoint", default = "LAION-Glyph-10M-Epoch-6")
156
  load_button = gr.Button(value = "Load Checkpoint")
157
 
158
  with gr.Accordion("Shared Advanced Options", open=False):
model_wo_ema.ckpt → laion10M_epoch_5_model_wo_ema.ckpt RENAMED
File without changes
laion10M_epoch_6_model_wo_ema.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f012c42f53f80965257c4dfcff53caeea193c5e135b5fd21f953f2e2df59406c
3
+ size 6671914001
transfer.py CHANGED
@@ -1,8 +1,11 @@
1
  from omegaconf import OmegaConf
2
  from scripts.rendertext_tool import Render_Text, load_model_from_config
3
  import torch
4
- cfg = OmegaConf.load("config_cuda_ema.yaml")
5
- model = load_model_from_config(cfg, "model_states.pt", verbose=True)
 
 
 
6
 
7
  from pytorch_lightning.callbacks import ModelCheckpoint
8
  with model.ema_scope("store ema weights"):
 
1
  from omegaconf import OmegaConf
2
  from scripts.rendertext_tool import Render_Text, load_model_from_config
3
  import torch
4
+
5
+ cfg = OmegaConf.load("config_ema.yaml")
6
+ # model = load_model_from_config(cfg, "model_states.pt", verbose=True)
7
+ model = load_model_from_config(cfg, "mp_rank_00_model_states.pt", verbose=True)
8
+
9
 
10
  from pytorch_lightning.callbacks import ModelCheckpoint
11
  with model.ema_scope("store ema weights"):