"""Gradio demo that removes an object from an image based on a text instruction.

The script downloads and unpacks the model archive, instantiates the model
from a config file, loads the checkpoint weights, and serves a Gradio
interface whose callback is :func:`inference`.
"""

import argparse
import os
import shutil
from zipfile import ZipFile

import gradio as gr
import numpy as np
import requests
import torch
from omegaconf import OmegaConf
from PIL import Image

import constants
import utils
from ldm.util import instantiate_from_config

# Archive that contains the default checkpoint referenced by ``--checkpoint``.
MODEL_URL = (
    "https://huggingface.co/abyildirim/inst-inpaint-models/resolve/main/gqa_inpaint.zip"
)

# Set by ``main()`` once the checkpoint is loaded; read by ``inference``.
MODEL = None


def download_model(url: str, model_dir: str = "models") -> None:
    """Download a zip archive from *url* and extract it into *model_dir*.

    The archive is stored temporarily in *model_dir* under the last path
    component of *url* and is deleted afterwards, whether or not the
    download and extraction succeeded.

    Args:
        url: Location of the zip archive.
        model_dir: Directory that receives the extracted files. It is
            created when missing.

    Raises:
        ValueError: If *url* has no file name component.
        requests.HTTPError: If the server answers with an error status.
        zipfile.BadZipFile: If the downloaded file is not a valid zip archive.
    """
    archive_name = url.split("/")[-1]
    if not archive_name:
        raise ValueError(f"Cannot derive a file name from URL: {url!r}")

    os.makedirs(model_dir, exist_ok=True)
    archive_path = os.path.join(model_dir, archive_name)

    try:
        with requests.get(url, stream=True) as response:
            # Fail here instead of saving an HTML error page as the "archive".
            response.raise_for_status()
            with open(archive_path, "wb") as archive_file:
                shutil.copyfileobj(response.raw, archive_file)

        # Extract the file that was just written rather than a fixed name,
        # so the function works for any archive URL.
        with ZipFile(archive_path, "r") as archive:
            archive.extractall(path=model_dir)
    finally:
        # Never leave a partial or already-extracted archive behind.
        if os.path.isfile(archive_path):
            os.remove(archive_path)


def inference(image: np.ndarray, instruction: str, center_crop: bool):
    """Run the inpainting model on *image* following *instruction*.

    Args:
        image: Source image as delivered by the ``gr.Image(type="numpy")``
            input component.
        instruction: Text instruction; it must start with "Remove the"
            (compared case-insensitively).
        center_crop: Forwarded to ``utils.preprocess_image``.

    Returns:
        A ``(cropped_image, output_image)`` tuple: the first value returned
        by ``utils.preprocess_image`` and the result of ``MODEL.inpaint``
        (requested as a PIL image via ``return_pil=True``).

    Raises:
        gr.Error: If the instruction has the wrong prefix, no image was
            supplied, or the model has not been loaded yet.
    """
    if not instruction.lower().startswith("remove the"):
        raise gr.Error("Instruction should start with 'Remove the' !")
    if image is None:
        # Gradio passes None when the image input has been cleared.
        raise gr.Error("Please provide a source image.")
    if MODEL is None:
        raise gr.Error("The model is not loaded yet.")

    source_image = Image.fromarray(image)
    cropped_image, model_input = utils.preprocess_image(
        source_image, center_crop=center_crop
    )
    # A fixed seed is passed so that repeated requests use the same seed.
    output_image = MODEL.inpaint(
        model_input,
        instruction,
        num_steps=100,
        device="cuda",
        return_pil=True,
        seed=0,
    )
    return cropped_image, output_image


def main() -> None:
    """Parse the command line, prepare the model, and launch the demo."""
    global MODEL

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--config",
        type=str,
        default="configs/latent-diffusion/gqa-inpaint-ldm-vq-f8-256x256.yaml",
        help="Path of the model config file",
    )
    parser.add_argument(
        "--checkpoint",
        type=str,
        default="models/gqa_inpaint/ldm/model.ckpt",
        help="Path of the model checkpoint file",
    )
    args = parser.parse_args()

    # Skip the (large) download when the requested checkpoint already exists.
    if os.path.isfile(args.checkpoint):
        print("## Model checkpoint found, skipping the download")
    else:
        print("## Downloading the model file")
        download_model(MODEL_URL)
        print("## Download is completed")

    print("## Running the demo")
    parsed_config = OmegaConf.load(args.config)
    MODEL = instantiate_from_config(parsed_config["model"])
    # NOTE(review): torch.load unpickles the file; only load checkpoints
    # that come from a trusted source.
    model_state_dict = torch.load(args.checkpoint, map_location="cpu")["state_dict"]
    MODEL.load_state_dict(model_state_dict)
    MODEL.eval()
    MODEL.to("cuda")

    # Each example unpacks into three values; the third one is not used here.
    sample_image, sample_instruction, _ = constants.EXAMPLES[3]

    gr.Interface(
        fn=inference,
        inputs=[
            gr.Image(type="numpy", value=sample_image, label="Source Image").style(
                height=256
            ),
            gr.Textbox(
                label="Instruction",
                lines=1,
                value=sample_instruction,
            ),
            gr.Checkbox(value=True, label="Center Crop", interactive=False),
        ],
        outputs=[
            gr.Image(type="pil", label="Cropped Image").style(height=256),
            gr.Image(type="pil", label="Output Image").style(height=256),
        ],
        allow_flagging="never",
        examples=constants.EXAMPLES,
        cache_examples=True,
        title=constants.TITLE,
        description=constants.DESCRIPTION,
    ).launch()


if __name__ == "__main__":
    main()