Spaces:
Sleeping
Sleeping
File size: 747 Bytes
1975737 46985f9 1975737 46985f9 1975737 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 |
import torch
import gc
import utils
import model
import config
import image_generator as generator
def predict(prompt, pipe, loss_function=None):
    """Generate one image per (seed, concept) pair and return them as a grid.

    For each pair in ``zip(config.SEEDS, config.STABLE_DIFUSION_CONCEPTS)``,
    the concept is appended to ``prompt`` and ``generator.generate_images`` is
    called with that seed. The resulting latents are stacked with
    ``torch.vstack``, decoded by ``utils.convert_latents_to_pil_images`` and
    arranged by ``utils.populate_image_grid``.

    Args:
        prompt: Base text prompt (presumably a ``str`` -- confirm with callers).
            Each concept is appended to it, separated by a single space.
        pipe: Pipeline object forwarded to both the generator and the latent
            decoder (presumably a Stable Diffusion pipeline -- confirm).
        loss_function: Optional loss forwarded unchanged to
            ``generator.generate_images``.

    Returns:
        Whatever ``utils.populate_image_grid`` returns when called with
        ``1`` and the number of stacked latent rows (presumably a 1-row grid
        -- confirm against ``utils``).

    Raises:
        Whatever ``torch.vstack`` raises when no (seed, concept) pairs are
        configured, since it is then called on an empty list.
    """
    latents = []
    # NOTE: zip stops at the shorter of the two config sequences, so an
    # unmatched trailing seed or concept is silently ignored.
    for seed_number, sd_concept in zip(config.SEEDS, config.STABLE_DIFUSION_CONCEPTS):
        # Drop unreferenced tensors from the previous iteration first, then
        # return the now-unused cached CUDA blocks before the next generation.
        gc.collect()
        torch.cuda.empty_cache()
        # Keep ``prompt`` itself unchanged: rebinding it to this list would
        # make later iterations interpolate the list's repr into the prompt.
        concept_prompt = [f'{prompt} {sd_concept}']
        latent = generator.generate_images(
            pipe=pipe,
            seed_number=seed_number,
            prompt=concept_prompt,
            loss_function=loss_function,
        )
        latents.append(latent)
    latents = torch.vstack(latents)
    images = utils.convert_latents_to_pil_images(pipe=pipe, latents=latents)
    grid = utils.populate_image_grid(images, 1, len(latents))
    return grid
|