from __future__ import annotations import gc import pathlib import sys import gradio as gr import PIL.Image import numpy as np import torch from diffusers import StableDiffusionPipeline sys.path.insert(0, './ReVersion') # below are original import os # import argparse # import torch from PIL import Image # from diffusers import StableDiffusionPipeline # sys.path.insert(0, './ReVersion') # from templates.templates import inference_templates import math """ Inference script for generating batch results """ def make_image_grid(imgs, rows, cols): assert len(imgs) == rows*cols w, h = imgs[0].size grid = Image.new('RGB', size=(cols*w, rows*h)) grid_w, grid_h = grid.size for i, img in enumerate(imgs): grid.paste(img, box=(i%cols*w, i//cols*h)) return grid def inference_fn( model_id, prompt, num_samples, guidance_scale, ): # create inference pipeline device = 'cuda' if torch.cuda.is_available() else 'cpu' pipe = StableDiffusionPipeline.from_pretrained(os.path.join('experiments', model_id),torch_dtype=torch.float16).to(device) # make directory to save images image_root_folder = os.path.join('experiments', model_id, 'inference') os.makedirs(image_root_folder, exist_ok = True) # if prompt is None and args.template_name is None: # raise ValueError("please input a single prompt through'--prompt' or select a batch of prompts using '--template_name'.") # single text prompt if prompt is not None: prompt_list = [prompt] else: prompt_list = [] # if args.template_name is not None: # # read the selected text prompts for generation # prompt_list.extend(inference_templates[args.template_name]) for prompt in prompt_list: # insert relation prompt # prompt = prompt.lower().replace("", "").format(placeholder_string) prompt = prompt.lower().replace("", "").format("") # make sub-folder image_folder = os.path.join(image_root_folder, prompt, 'samples') os.makedirs(image_folder, exist_ok = True) # batch generation images = pipe(prompt, num_inference_steps=50, guidance_scale=guidance_scale, num_images_per_prompt=num_samples).images # save generated images for idx, image in enumerate(images): image_name = f"{str(idx).zfill(4)}.png" image_path = os.path.join(image_folder, image_name) image.save(image_path) # save a grid of images image_grid = make_image_grid(images, rows=2, cols=math.ceil(num_samples/2)) image_grid_path = os.path.join(image_root_folder, prompt, f'{prompt}.png') return image_grid if __name__ == "__main__": inference_fn()