import gradio as gr
import torch
import numpy as np
from PIL import Image

from saicinpainting.evaluation.utils import move_to_device
from saicinpainting.evaluation.refinement import refine_predict
from saicinpainting.evaluation.data import pad_img_to_modulo
from saicinpainting.training.trainers import load_checkpoint

import yaml
from omegaconf import OmegaConf
from torch.utils.data._utils.collate import default_collate
import os

import requests
import zipfile

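# Download the pretrained big-lama checkpoint from Hugging Face on first run
# and unpack it next to this script; skip the download if it is already present.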
url = "https://huggingface.co/smartywu/big-lama/resolve/main/big-lama.zip" |
|
|
|
|
|
local_filename = "big-lama.zip" |
|
|
|
|
|
extract_dir = "big-lama" |
|
|
|
|
|
if os.path.exists(extract_dir): |
|
print(f"The directory '{extract_dir}' already exists. Skipping download and extraction.") |
|
else: |
|
|
|
if not os.path.exists(local_filename): |
|
|
|
with requests.get(url, stream=True) as response: |
|
response.raise_for_status() |
|
with open(local_filename, 'wb') as f: |
|
for chunk in response.iter_content(chunk_size=8192): |
|
f.write(chunk) |
|
print(f"Downloaded '{local_filename}' successfully.") |
|
else: |
|
print(f"The file '{local_filename}' already exists. Skipping download.") |
|
|
|
|
|
with zipfile.ZipFile(local_filename, 'r') as zip_ref: |
|
zip_ref.extractall() |
|
print(f"Extracted '{local_filename}' into '{extract_dir}' successfully.") |
|
|
|
|
|
os.remove(local_filename) |
|
print(f"Removed '{local_filename}' after extraction.") |
|
|
|
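# Demo-wide constants: a fixed CPU seed (note that `generator` is not consumed
# by the inference code below) and the working resolution inputs are resized to.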
generator = torch.Generator(device="cpu").manual_seed(42)

size = (1024, 1024)

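# Convert a PIL image to a CHW float32 array scaled to [0, 1], the layout the
# LaMa model expects.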
def image_preprocess(image: Image.Image, mode="RGB", return_orig=False):
    img = np.array(image.convert(mode))
    if img.ndim == 3:
        img = np.transpose(img, (2, 0, 1))
    out_img = img.astype("float32") / 255
    if return_orig:
        return out_img, img
    else:
        return out_img

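# Inpaint the masked region of the Gradio ImageMask payload: "background" holds
# the source photo and "layers"[0] the user-drawn mask (strokes live in the
# alpha channel of the editing layer).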
def infer(image):
    source = image["background"].convert("RGB").resize(size)

    # The editing layer is RGBA; keep its alpha channel, where brush strokes are
    # non-zero, binarize it, and resize it to match the resized source image.
    mask = image["layers"][0]
    mask = mask.point(lambda p: 255 if p > 0 else 0).split()[3]
    mask = mask.resize(size, Image.NEAREST)

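    # Load the prediction config plus the training config and weights shipped
    # with the checkpoint. Reloading on every call keeps the demo simple, at the
    # cost of extra latency per request.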
    device = torch.device("cpu")

    predict_config_path = "./configs/prediction/default.yaml"
    with open(predict_config_path, "r") as f:
        predict_config = OmegaConf.create(yaml.safe_load(f))

    train_config_path = os.path.join(predict_config.model.path, "config.yaml")
    with open(train_config_path, "r") as f:
        train_config = OmegaConf.create(yaml.safe_load(f))

    train_config.training_model.predict_only = True
    train_config.visualizer.kind = "noop"

    checkpoint_path = os.path.join(
        predict_config.model.path, "models", predict_config.model.checkpoint
    )

    model = load_checkpoint(
        train_config, checkpoint_path, strict=False, map_location="cpu"
    )
    model.freeze()
    if not predict_config.get("refine", False):
        model.to(device)

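    # Preprocess to CHW float arrays and pad image and mask to a multiple of the
    # configured modulo, remembering the original size for cropping afterwards.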
    img = image_preprocess(source, mode="RGB")
    mask = image_preprocess(mask, mode="L")

    result = dict(image=img, mask=mask[None, ...])

    if (
        predict_config.dataset.pad_out_to_modulo is not None
        and predict_config.dataset.pad_out_to_modulo > 1
    ):
        result["unpad_to_size"] = result["image"].shape[1:]
        result["image"] = pad_img_to_modulo(
            result["image"], predict_config.dataset.pad_out_to_modulo
        )
        result["mask"] = pad_img_to_modulo(
            result["mask"], predict_config.dataset.pad_out_to_modulo
        )

    batch = default_collate([result])

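    # Run either the iterative refinement procedure or a single forward pass,
    # then crop the padded prediction back to the original size.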
    if predict_config.get("refine", False):
        assert "unpad_to_size" in batch, "Unpadded size is required for the refinement"
        cur_res = refine_predict(batch, model, **predict_config.refiner)
        cur_res = cur_res[0].permute(1, 2, 0).detach().cpu().numpy()
    else:
        with torch.no_grad():
            batch = move_to_device(batch, device)
            batch["mask"] = (batch["mask"] > 0) * 1
            batch = model(batch)
            cur_res = (
                batch[predict_config.out_key][0].permute(1, 2, 0).detach().cpu().numpy()
            )
            unpad_to_size = batch.get("unpad_to_size", None)
            if unpad_to_size is not None:
                orig_height, orig_width = unpad_to_size
                cur_res = cur_res[:orig_height, :orig_width]

    cur_res = np.clip(cur_res * 255, 0, 255).astype("uint8")

    yield cur_res

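# Helper used to blank the output component before a new generation starts.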
def clear_result():
    return gr.update(value=None)

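# Styling and header markup for the Gradio Blocks UI.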
css = """.main-div div{display:inline-flex;align-items:center;gap:.8rem;font-size:1.75rem}.main-div div h1{font-weight:900;margin-bottom:7px}.main-div p{margin-bottom:10px;font-size:94%}a{text-decoration:underline}.tabs{margin-top:0;margin-bottom:0}#gallery{min-height:20rem} |
|
""" |
|
prefix = "" |
|
|
|
title = f"""
<div class="main-div">
  <div>
    <h1>LaMa model</h1>
  </div>
  Running on {"<b>GPU 🔥</b>" if torch.cuda.is_available() else "<b>CPU 🥶</b>"} <br><br>
  <a style="display:inline-block" href="https://huggingface.co/spaces/akhaliq/small-stable-diffusion-v0?duplicate=true"><img src="https://bit.ly/3gLdBN6" alt="Duplicate Space"></a>
</div>
"""

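# Assemble the UI: a mask editor and the generated result, with a button that
# feeds the output back in as the next input image.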
with gr.Blocks(css=css) as demo:
    gr.HTML(title)
    with gr.Row():
        with gr.Row():
            with gr.Column():
                with gr.Row():
                    with gr.Column():
                        run_button = gr.Button("Generate")

                with gr.Row():
                    input_image = gr.ImageMask(
                        type="pil",
                        label="Input Image",
                        crop_size=(1024, 1024),
                        layers=False,
                        height=1024,
                        width=1024,
                    )

                    result = gr.Image(
                        interactive=False,
                        label="Generated Image",
                    )
                    use_as_input_button = gr.Button("Use as Input Image", visible=False)

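    # Copy the generated image back into the editor so it can be inpainted again.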
    def use_output_as_input(output_image):
        return gr.update(value=output_image)

    use_as_input_button.click(
        fn=use_output_as_input, inputs=[result], outputs=[input_image]
    )

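    # Run chain: clear the previous result, hide the reuse button, run inference,
    # then show the reuse button again once a new image is available.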
    run_button.click(
        fn=clear_result,
        inputs=None,
        outputs=result,
    ).then(
        fn=lambda: gr.update(visible=False),
        inputs=None,
        outputs=use_as_input_button,
    ).then(
        fn=infer,
        inputs=[input_image],
        outputs=result,
    ).then(
        fn=lambda: gr.update(visible=True),
        inputs=None,
        outputs=use_as_input_button,
    )


demo.launch()