mridulk commited on
Commit
c81b721
·
verified ·
1 Parent(s): e9de0af

added the note for generation taking time

Browse files
Files changed (1) hide show
  1. app.py +11 -23
app.py CHANGED
@@ -19,7 +19,7 @@ from ldm.models.diffusion.plms import PLMSSampler
19
 
20
  def load_model_from_config(config, ckpt, verbose=False):
21
  print(f"Loading model from {ckpt}")
22
- pl_sd = torch.load(ckpt, map_location="cpu")
23
  # pl_sd = torch.load(ckpt)#, map_location="cpu")
24
  sd = pl_sd["state_dict"]
25
  model = instantiate_from_config(config.model)
@@ -31,7 +31,7 @@ def load_model_from_config(config, ckpt, verbose=False):
31
  print("unexpected keys:")
32
  print(u)
33
 
34
- # model.cuda()
35
  model.eval()
36
  return model
37
 
@@ -50,8 +50,8 @@ def masking_embed(embedding, levels=1):
50
  # LOAD MODEL GLOBALLY
51
  ckpt_path = './model_files/fishes/epoch=000119.ckpt'
52
  config_path = './model_files/fishes/2024-03-01T23-15-36-project.yaml'
53
- config = OmegaConf.load(config_path) # TODO: Optionally download from same location as ckpt and chnage this logic
54
- model = load_model_from_config(config, ckpt_path) # TODO: check path
55
 
56
  device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
57
  model = model.to(device)
@@ -78,12 +78,7 @@ def generate_image(fish_name, masking_level_input,
78
  return key
79
 
80
 
81
- if opt.plms:
82
- sampler = PLMSSampler(model)
83
- else:
84
- sampler = DDIMSampler(model)
85
-
86
-
87
 
88
  prompt = opt.prompt
89
  all_images = []
@@ -169,12 +164,6 @@ if __name__ == "__main__":
169
  help="number of ddim sampling steps",
170
  )
171
 
172
- parser.add_argument(
173
- "--plms",
174
- action='store_true',
175
- help="use plms sampling",
176
- )
177
-
178
  parser.add_argument(
179
  "--ddim_eta",
180
  type=float,
@@ -205,8 +194,6 @@ if __name__ == "__main__":
205
  opt = parser.parse_args()
206
 
207
  title = "🎞️ Phylo Diffusion - Generating Fish Images Tool"
208
- description = "Write the Species name to generate an image for.\n For Trait Masking: Specify the Level information as well"
209
-
210
 
211
  def load_example(prompt, level, option, components):
212
  components['prompt_input'].value = prompt
@@ -214,14 +201,16 @@ if __name__ == "__main__":
214
 
215
  def setup_interface():
216
  with gr.Blocks() as demo:
217
- gr.Markdown("# Phylo Diffusion - Generating Fish Images Tool")
 
218
  gr.Markdown("### Write the Species name to generate a fish image")
219
- gr.Markdown("### Select one of the experiments: Trait Masking or Trait Swapping")
 
220
 
221
  with gr.Row():
222
  with gr.Column():
223
- gr.Markdown("## Generate Images Based on Prompts")
224
- gr.Markdown("Enter a prompt to generate an image:")
225
  prompt_input = gr.Textbox(label="Species Name")
226
 
227
  # Radio button to select experiment type, with no default selection
@@ -248,7 +237,6 @@ if __name__ == "__main__":
248
  gr.Markdown("## Select an example:")
249
  examples = [
250
  ("Gambusia Affinis", "None", "", "Level 3"),
251
- ("Lepomis Auritus", "None", "", "Level 3"),
252
  ("Lepomis Auritus", "Level 3", "", "Level 3"),
253
  ("Noturus nocturnus", "None", "Notropis dorsalis", "Level 2")
254
  ]
 
19
 
20
  def load_model_from_config(config, ckpt, verbose=False):
21
  print(f"Loading model from {ckpt}")
22
+ pl_sd = torch.load(ckpt, map_location="cpu") # TODO: change for GPU resources
23
  # pl_sd = torch.load(ckpt)#, map_location="cpu")
24
  sd = pl_sd["state_dict"]
25
  model = instantiate_from_config(config.model)
 
31
  print("unexpected keys:")
32
  print(u)
33
 
34
+ # model.cuda() # TODO: change for GPU resources
35
  model.eval()
36
  return model
37
 
 
50
  # LOAD MODEL GLOBALLY
51
  ckpt_path = './model_files/fishes/epoch=000119.ckpt'
52
  config_path = './model_files/fishes/2024-03-01T23-15-36-project.yaml'
53
+ config = OmegaConf.load(config_path)
54
+ model = load_model_from_config(config, ckpt_path)
55
 
56
  device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
57
  model = model.to(device)
 
78
  return key
79
 
80
 
81
+ sampler = DDIMSampler(model)
 
 
 
 
 
82
 
83
  prompt = opt.prompt
84
  all_images = []
 
164
  help="number of ddim sampling steps",
165
  )
166
 
 
 
 
 
 
 
167
  parser.add_argument(
168
  "--ddim_eta",
169
  type=float,
 
194
  opt = parser.parse_args()
195
 
196
  title = "🎞️ Phylo Diffusion - Generating Fish Images Tool"
 
 
197
 
198
  def load_example(prompt, level, option, components):
199
  components['prompt_input'].value = prompt
 
201
 
202
  def setup_interface():
203
  with gr.Blocks() as demo:
204
+
205
+ gr.Markdown("# Phylo-Diffusion: Generating Fish Images Tool")
206
  gr.Markdown("### Write the Species name to generate a fish image")
207
+ gr.Markdown("### 1. Trait Masking: Specify the Level information to mask")
208
+ gr.Markdown("### 2. Trait Swapping: Specify the species name to swap trait with and at what level")
209
 
210
  with gr.Row():
211
  with gr.Column():
212
+ # gr.Markdown("## Generate Images Based on Prompts")
213
+ gr.Markdown("**NOTE:** The demo is currently running on free CPU resources provided by Hugging Face, so it may take up to 10 minutes to generate an image. We're working on securing additional resources to speed up the process. Thank you for your patience!")
214
  prompt_input = gr.Textbox(label="Species Name")
215
 
216
  # Radio button to select experiment type, with no default selection
 
237
  gr.Markdown("## Select an example:")
238
  examples = [
239
  ("Gambusia Affinis", "None", "", "Level 3"),
 
240
  ("Lepomis Auritus", "Level 3", "", "Level 3"),
241
  ("Noturus nocturnus", "None", "Notropis dorsalis", "Level 2")
242
  ]