inst-inpaint / app.py
abyildirim's picture
model is moved to gpu
eaab7c8
raw history blame
No virus
2.98 kB
import argparse
import gradio as gr
import numpy as np
import torch
from PIL import Image
import constants
import utils
from ldm.util import instantiate_from_config
from omegaconf import OmegaConf
from zipfile import ZipFile
import os
import requests
import shutil
def download_model(url, target_dir="models"):
    """Download a zipped model archive from *url* and extract it into *target_dir*.

    The archive is streamed to ``<target_dir>/<archive file name from url>``,
    extracted into *target_dir*, and then deleted, so only the extracted files
    remain on disk.

    Args:
        url: Direct download URL of a ``.zip`` archive.
        target_dir: Directory to download into and extract under. Defaults to
            ``"models"``, the directory the default ``--checkpoint`` path uses.

    Raises:
        requests.HTTPError: If the server answers with an error status.
        zipfile.BadZipFile: If the downloaded file is not a valid zip archive.
    """
    from urllib.parse import urlsplit

    os.makedirs(target_dir, exist_ok=True)
    # Derive the archive path from the URL (ignoring any query string) instead of
    # hard-coding it, so the file we write and the file we unzip are the same one.
    archive_name = urlsplit(url).path.split("/")[-1]
    archive_path = os.path.join(target_dir, archive_name)

    with requests.get(url, stream=True, timeout=60) as response:
        # Fail loudly on 4xx/5xx rather than saving an HTML error page as a "zip".
        response.raise_for_status()
        # `response.raw` bypasses requests' content decoding; enable it so a
        # gzip/deflate Content-Encoding does not corrupt the archive on disk.
        response.raw.decode_content = True
        with open(archive_path, "wb") as file:
            shutil.copyfileobj(response.raw, file)

    try:
        with ZipFile(archive_path, "r") as archive:
            archive.extractall(path=target_dir)
    finally:
        # Remove the archive even if extraction fails, so no partial download lingers.
        os.remove(archive_path)
# Module-level model handle: assigned in the ``__main__`` block before the UI
# launches and read by ``inference``; stays ``None`` when the file is imported.
MODEL: "object | None" = None
def inference(
    image: np.ndarray,
    instruction: str,
    center_crop: bool,
    num_steps: int = 100,
    seed: int = 0,
    device: str = "cuda",
):
    """Gradio callback: remove the object named in *instruction* from *image*.

    Args:
        image: Source image as a NumPy array (the Gradio ``Image`` input uses
            ``type="numpy"``); ``None`` when the user cleared the input.
        instruction: Text instruction; must start with "Remove the"
            (case-insensitive, surrounding whitespace ignored).
        center_crop: Forwarded to ``utils.preprocess_image``.
        num_steps: Number of sampling steps passed to ``MODEL.inpaint``.
        seed: Random seed passed to ``MODEL.inpaint`` (fixed for reproducibility).
        device: Device string passed to ``MODEL.inpaint``; must match where
            ``MODEL`` was moved (``"cuda"`` in the ``__main__`` block).

    Returns:
        ``(cropped_image, output_image)``: the first value returned by
        ``utils.preprocess_image`` and the inpainted result as a PIL image
        (``return_pil=True``).

    Raises:
        gr.Error: On an invalid instruction, a missing image, or an unloaded model.
    """
    instruction = (instruction or "").strip()
    if not instruction.lower().startswith("remove the"):
        raise gr.Error("Instruction should start with 'Remove the' !")
    if image is None:
        # Clearing the image widget sends None; Image.fromarray would fail opaquely.
        raise gr.Error("Please provide a source image!")
    if MODEL is None:
        raise gr.Error("The model is not loaded yet.")

    pil_image = Image.fromarray(image)
    # preprocess_image returns (image to display, image to feed the model).
    cropped_image, model_input = utils.preprocess_image(pil_image, center_crop=center_crop)
    output_image = MODEL.inpaint(
        model_input,
        instruction,
        num_steps=num_steps,
        device=device,
        return_pil=True,
        seed=seed,
    )
    return cropped_image, output_image
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--config",
        type=str,
        default="configs/latent-diffusion/gqa-inpaint-ldm-vq-f8-256x256.yaml",
        help="Path of the model config file",
    )
    parser.add_argument(
        "--checkpoint",
        type=str,
        default="models/gqa_inpaint/ldm/model.ckpt",
        help="Path of the model checkpoint file",
    )
    args = parser.parse_args()

    # Only fetch and extract the archive when the checkpoint is missing; otherwise
    # every launch would re-download and overwrite the same files.
    if os.path.isfile(args.checkpoint):
        print("## Model file already exists, skipping download")
    else:
        print("## Downloading the model file")
        download_model("https://huggingface.co/abyildirim/inst-inpaint-models/resolve/main/gqa_inpaint.zip")
        print("## Download is completed")

    print("## Running the demo")
    parsed_config = OmegaConf.load(args.config)
    MODEL = instantiate_from_config(parsed_config["model"])
    # NOTE(review): torch.load unpickles the checkpoint and can execute arbitrary
    # code; only point --checkpoint at files from a trusted source.
    model_state_dict = torch.load(args.checkpoint, map_location="cpu")["state_dict"]
    MODEL.load_state_dict(model_state_dict)
    MODEL.eval()
    # `inference` passes device="cuda" by default, so the model must live there too.
    MODEL.to("cuda")

    # Pre-fill the UI with one of the bundled examples; its third field is unused here.
    sample_image, sample_instruction, _ = constants.EXAMPLES[3]
    # NOTE(review): `.style(height=...)` is the Gradio 3.x styling API (removed in
    # Gradio 4); pin gradio<4 or migrate to the component `height=` argument.
    gr.Interface(
        fn=inference,
        inputs=[
            gr.Image(type="numpy", value=sample_image, label="Source Image").style(
                height=256
            ),
            gr.Textbox(
                label="Instruction",
                lines=1,
                value=sample_instruction,
            ),
            # Fixed to True and not user-editable.
            gr.Checkbox(value=True, label="Center Crop", interactive=False),
        ],
        outputs=[
            gr.Image(type="pil", label="Cropped Image").style(height=256),
            gr.Image(type="pil", label="Output Image").style(height=256),
        ],
        allow_flagging="never",
        examples=constants.EXAMPLES,
        # Example outputs are precomputed at startup by calling `inference`,
        # which is why MODEL must be fully loaded before this point.
        cache_examples=True,
        title=constants.TITLE,
        description=constants.DESCRIPTION,
    ).launch()