File size: 751 Bytes
1975737
 
 
 
 
 
 
 
 
 
 
 
9eb0849
1975737
9eb0849
1975737
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
import torch
import gc

import utils
import model
import config
import image_generator as generator

def predict(prompt, pipe, loss_function=None):
    """Generate one image per configured concept and return them as a grid.

    For every ``(seed, concept)`` pair obtained by zipping ``config.SEEDS``
    with ``config.STABLE_DIFUSION_CONCEPTS``, the concept is appended to the
    base ``prompt`` and passed to ``generator.generate_images``. The
    resulting latents are stacked, converted to images and laid out with
    ``utils.populate_image_grid``.

    Args:
        prompt: Base text prompt. It is interpolated into an f-string, so it
            is expected to be a plain string (NOTE(review): confirm callers
            never pass a list).
        pipe: Pipeline object forwarded unchanged to
            ``generator.generate_images`` and
            ``utils.convert_latents_to_pil_images``.
        loss_function: Optional value forwarded unchanged to
            ``generator.generate_images``.

    Returns:
        Whatever ``utils.populate_image_grid(images, 1, len(latents))``
        returns -- presumably a single-row image grid; verify against
        ``utils``.
    """
    latents = []

    for seed_number, sd_concept in zip(config.SEEDS, config.STABLE_DIFUSION_CONCEPTS):
        # Run a garbage-collection pass before each generation; presumably
        # this is meant to lower memory pressure between runs.
        gc.collect()

        # Build the per-concept prompt in its own variable. Rebinding
        # ``prompt`` here would make every later iteration format the
        # previous iteration's list (e.g. "['cat style1'] style2") instead
        # of the caller's base prompt.
        concept_prompt = [f'{prompt} {sd_concept}']
        latent = generator.generate_images(
            pipe=pipe,
            seed_number=seed_number,
            prompt=concept_prompt,
            loss_function=loss_function,
        )
        latents.append(latent)

    # Stack the per-concept latents along the first dimension so they can be
    # decoded together.
    latents = torch.vstack(latents)
    images = utils.convert_latents_to_pil_images(pipe=pipe, latents=latents)
    grid = utils.populate_image_grid(images, 1, len(latents))
    return grid