import PIL | |
from BeamDiffusionModel.models.diffusionModel.StableDiffusion import StableDiffusion | |
from BeamDiffusionModel.models.diffusionModel.configs.config_loader import CONFIG | |
from BeamDiffusionModel.models.clip.clip import Clip | |
import torch | |
import json | |
# Shared model instances, created once when this module is imported.
# `gen_img` calls `sd.generate_image` and `clip_score` calls `clip.similarity`.
sd = StableDiffusion()
clip = Clip()
def read_json(path):
    """Load and return the parsed contents of the JSON file at ``path``.

    Args:
        path: Filesystem path of the JSON document to read.

    Returns:
        The decoded JSON value produced by ``json.load``.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the contents are not valid JSON.
    """
    # Pin the encoding: without it ``open`` falls back to the locale default,
    # so non-ASCII JSON could decode differently (or fail) across platforms.
    with open(path, 'r', encoding='utf-8') as f:
        return json.load(f)
def get_img(path):
    """Open the image file at ``path`` with Pillow and return it.

    Args:
        path: Location of the image, passed straight to ``Image.open``.

    Returns:
        The object returned by ``PIL.Image.open``.
    """
    # A bare ``import PIL`` does not load the ``PIL.Image`` submodule, so
    # ``PIL.Image`` only resolves if some other module imported it first.
    # Import the submodule explicitly so this helper does not rely on that.
    from PIL import Image

    return Image.open(path)
def clip_score(step, imgs_path):
    """Return ``clip.similarity`` of ``step`` against the image(s) at ``imgs_path``.

    Args:
        step: Forwarded unchanged as the first argument of ``clip.similarity``.
        imgs_path: Either a ``list`` of image paths or a single image path.

    Returns:
        Whatever ``clip.similarity`` returns for ``step`` and the list of
        opened images.
    """
    # Treat a lone path as a one-element list so both cases share one code
    # path; only ``list`` instances are considered collections of paths.
    paths = imgs_path if isinstance(imgs_path, list) else [imgs_path]
    opened = [get_img(single_path) for single_path in paths]
    return clip.similarity(step, opened)
def gen_img(caption, latent=None, seed=None):
    """Generate an image for ``caption`` with the module-level diffusion model.

    Args:
        caption: Prompt forwarded to ``sd.generate_image``.
        latent: Optional value forwarded as ``latent=``; ``None`` by default.
        seed: Optional integer seed. When given (including ``0``), a
            ``torch.Generator`` seeded with it is passed to the model; when
            ``None``, no generator is passed.

    Returns:
        A ``(latents, img)`` tuple. Note the order is swapped relative to the
        ``(img, latents)`` pair unpacked from ``sd.generate_image``.
    """
    # Use CUDA only when the config allows it (default: allowed) and a CUDA
    # device is actually available; otherwise fall back to the CPU.
    use_cuda = CONFIG.get("diffusion_model", {}).get("use_cuda", True)
    device = "cuda" if use_cuda and torch.cuda.is_available() else "cpu"

    # Compare against None rather than relying on truthiness: 0 is a valid
    # seed and must still produce a seeded generator.
    generator = None
    if seed is not None:
        generator = torch.Generator(device).manual_seed(seed)

    img, latents = sd.generate_image(caption, generator=generator, latent=latent)
    return latents, img