VinayHajare commited on
Commit
4618d22
1 Parent(s): 24026e0

Update inference.py

Browse files
Files changed (1) hide show
  1. inference.py +7 -6
inference.py CHANGED
@@ -37,15 +37,16 @@ checkpoint_cc12m = torch.load(cc12m_model_path, map_location=torch.device(device
37
 
38
  # Create a new Generator model and initialize it with the pre-trained weights
39
  netG = NetG(64, 100, 512, 256, 3, False, clip_model).to(device)
40
- #cub = load_model_weights(netG, checkpoint_cub['model']['netG'], multi_gpus=False)
41
- #cc12m = load_model_weights(netG, checkpoint_cc12m['model']['netG'], multi_gpus=False)
 
42
 
43
  # Function to generate images from text
44
  def generate_image_from_text(caption, model, batch_size=4):
45
  if model == "CUB":
46
- generator = load_model_weights(netG, checkpoint_cub['model']['netG'], multi_gpus=False)
47
  else:
48
- generator = load_model_weights(netG, checkpoint_cc12m['model']['netG'], multi_gpus=False)
49
 
50
  # Create the noise tensor
51
  noise = torch.randn((batch_size, 100)).to(device)
@@ -82,9 +83,9 @@ def generate_image_from_text(caption, model, batch_size=4):
82
  # Function to generate images from text
83
  def generate_image_from_text_with_persistent_storage(caption, model, batch_size=4):
84
  if model == "CUB":
85
- generator = load_model_weights(netG, checkpoint_cub['model']['netG'], multi_gpus=False)
86
  else:
87
- generator = load_model_weights(netG, checkpoint_cc12m['model']['netG'], multi_gpus=False)
88
 
89
  # Create the noise tensor
90
  noise = torch.randn((batch_size, 100)).to(device)
 
37
 
38
  # Create a new Generator model and initialize it with the pre-trained weights
39
  netG = NetG(64, 100, 512, 256, 3, False, clip_model).to(device)
40
+ netG1 = NetG(64, 100, 512, 256, 3, False, clip_model).to(device)
41
+ cub = load_model_weights(netG, checkpoint_cub['model']['netG'], multi_gpus=False)
42
+ cc12m = load_model_weights(netG1, checkpoint_cc12m['model']['netG'], multi_gpus=False)
43
 
44
  # Function to generate images from text
45
  def generate_image_from_text(caption, model, batch_size=4):
46
  if model == "CUB":
47
+ generator = cub
48
  else:
49
+ generator = cc12m
50
 
51
  # Create the noise tensor
52
  noise = torch.randn((batch_size, 100)).to(device)
 
83
  # Function to generate images from text
84
  def generate_image_from_text_with_persistent_storage(caption, model, batch_size=4):
85
  if model == "CUB":
86
+ generator = cub
87
  else:
88
+ generator = cc12m
89
 
90
  # Create the noise tensor
91
  noise = torch.randn((batch_size, 100)).to(device)