Spaces:
Running
Running
VinayHajare
commited on
Commit
•
4618d22
1
Parent(s):
24026e0
Update inference.py
Browse files- inference.py +7 -6
inference.py
CHANGED
@@ -37,15 +37,16 @@ checkpoint_cc12m = torch.load(cc12m_model_path, map_location=torch.device(device
|
|
37 |
|
38 |
# Create a new Generator model and initialize it with the pre-trained weights
|
39 |
netG = NetG(64, 100, 512, 256, 3, False, clip_model).to(device)
|
40 |
-
|
41 |
-
|
|
|
42 |
|
43 |
# Function to generate images from text
|
44 |
def generate_image_from_text(caption, model, batch_size=4):
|
45 |
if model == "CUB":
|
46 |
-
generator =
|
47 |
else:
|
48 |
-
generator =
|
49 |
|
50 |
# Create the noise tensor
|
51 |
noise = torch.randn((batch_size, 100)).to(device)
|
@@ -82,9 +83,9 @@ def generate_image_from_text(caption, model, batch_size=4):
|
|
82 |
# Function to generate images from text
|
83 |
def generate_image_from_text_with_persistent_storage(caption, model, batch_size=4):
|
84 |
if model == "CUB":
|
85 |
-
generator =
|
86 |
else:
|
87 |
-
generator =
|
88 |
|
89 |
# Create the noise tensor
|
90 |
noise = torch.randn((batch_size, 100)).to(device)
|
|
|
37 |
|
38 |
# Create a new Generator model and initialize it with the pre-trained weights
|
39 |
netG = NetG(64, 100, 512, 256, 3, False, clip_model).to(device)
|
40 |
+
netG1 = NetG(64, 100, 512, 256, 3, False, clip_model).to(device)
|
41 |
+
cub = load_model_weights(netG, checkpoint_cub['model']['netG'], multi_gpus=False)
|
42 |
+
cc12m = load_model_weights(netG1, checkpoint_cc12m['model']['netG'], multi_gpus=False)
|
43 |
|
44 |
# Function to generate images from text
|
45 |
def generate_image_from_text(caption, model, batch_size=4):
|
46 |
if model == "CUB":
|
47 |
+
generator = cub
|
48 |
else:
|
49 |
+
generator = cc12m
|
50 |
|
51 |
# Create the noise tensor
|
52 |
noise = torch.randn((batch_size, 100)).to(device)
|
|
|
83 |
# Function to generate images from text
|
84 |
def generate_image_from_text_with_persistent_storage(caption, model, batch_size=4):
|
85 |
if model == "CUB":
|
86 |
+
generator = cub
|
87 |
else:
|
88 |
+
generator = cc12m
|
89 |
|
90 |
# Create the noise tensor
|
91 |
noise = torch.randn((batch_size, 100)).to(device)
|