"""Helpers for generating a sequence of images with Stable Diffusion.

A first image is produced from a text prompt, then each subsequent prompt is
applied to the previous image through an image-to-image pipeline.
"""

import torch
from diffusers import (
    StableDiffusionImg2ImgPipeline,
    StableDiffusionPipeline,
)

# Hugging Face Hub id of the checkpoint both pipelines load by default.
_DEFAULT_MODEL_ID = "stabilityai/stable-diffusion-2"


def check_cuda_device():
    """Return the CUDA device when CUDA is available, otherwise the CPU device."""
    return torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def _resolve_device(device=None):
    """Turn an optional device spec (str or ``torch.device``) into a ``torch.device``.

    A falsy value (``None`` or ``""``) selects the device automatically via
    ``check_cuda_device``.
    """
    if not device:
        return check_cuda_device()
    return torch.device(device)


def _dtype_for(device):
    """Pick the weight dtype for *device*: float32 on CPU, float16 elsewhere."""
    # Half-precision pipelines are not expected to run on CPU, so fall back
    # to full precision there; accelerators keep the original float16.
    return torch.float32 if device.type == "cpu" else torch.float16


def _load_pipeline(pipeline_cls, source, device=None):
    """Load *pipeline_cls* from *source* and place it on the resolved device."""
    target = _resolve_device(device)
    pipe = pipeline_cls.from_pretrained(source, torch_dtype=_dtype_for(target))
    pipe.to(target)
    return pipe


def get_the_model(device=None):
    """Load the text-to-image pipeline.

    Args:
        device: Device to place the pipeline on (e.g. ``"cuda"``, ``"cpu"``,
            ``"cuda:0"`` or a ``torch.device``). When falsy, CUDA is used if
            available, otherwise CPU.

    Returns:
        A ``StableDiffusionPipeline`` placed on the chosen device.
    """
    return _load_pipeline(StableDiffusionPipeline, _DEFAULT_MODEL_ID, device)


def get_image_to_image_model(path=None, device=None):
    """Load the image-to-image pipeline.

    Args:
        path: Checkpoint to load instead of the default model id. When falsy,
            the default model id is used.
        device: Device to place the pipeline on; same rules as
            ``get_the_model``. Any spec accepted by ``torch.device`` is
            honoured, not only the literal strings ``"cuda"`` and ``"cpu"``.

    Returns:
        A ``StableDiffusionImg2ImgPipeline`` placed on the chosen device.

    Raises:
        RuntimeError: If *device* is not a spec ``torch.device`` accepts.
    """
    return _load_pipeline(
        StableDiffusionImg2ImgPipeline,
        path or _DEFAULT_MODEL_ID,
        device,
    )


def gen_initial_img(int_prompt):
    """Generate the starting image for *int_prompt* with 100 inference steps.

    A fresh text-to-image pipeline is loaded on every call.
    """
    model = get_the_model(None)
    return model(int_prompt, num_inference_steps=100).images[0]


def generate_story(int_prompt, steps, iterations=100):
    """Generate an initial image and evolve it through a list of prompts.

    Each prompt in *steps* is applied to the image produced by the previous
    step (the initial image for the first step).

    Args:
        int_prompt: Text prompt for the initial image.
        steps: Iterable of text prompts, applied in order.
        iterations: ``num_inference_steps`` passed to every image-to-image
            call.

    Returns:
        A tuple ``(init_img, image_dic)`` where ``image_dic`` maps the step
        index to ``{"image": <generated image>, "prompt": <step prompt>}``.
    """
    init_img = gen_initial_img(int_prompt)
    img2img_model = get_image_to_image_model()

    image_dic = {}
    current = init_img
    for idx, step in enumerate(steps):
        print(f"step: {idx}")
        print(step)
        # Feed the previous output back in so the images form a sequence.
        current = img2img_model(
            prompt=step,
            image=current,
            strength=0.75,
            guidance_scale=7.5,
            num_inference_steps=iterations,
        ).images[0]
        image_dic[idx] = {
            "image": current,
            "prompt": step,
        }
    return init_img, image_dic