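# app.py — Gradio demo for image inpainting with the big-lama checkpoint.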
import gradio as gr
import torch
import numpy as np
from PIL import Image
from saicinpainting.evaluation.utils import move_to_device
from saicinpainting.evaluation.refinement import refine_predict
from saicinpainting.evaluation.data import pad_img_to_modulo
from saicinpainting.training.trainers import load_checkpoint
import yaml
from omegaconf import OmegaConf
from torch.utils.data._utils.collate import default_collate
import os
import requests
import zipfile
# URL of the file to download
url = "https://huggingface.co/smartywu/big-lama/resolve/main/big-lama.zip"
# Local filename to save the downloaded file
local_filename = "big-lama.zip"
# Directory to extract the files into
extract_dir = "big-lama"
# Check if the extracted directory already exists
if os.path.exists(extract_dir):
print(f"The directory '{extract_dir}' already exists. Skipping download and extraction.")
else:
# Check if the zip file already exists
if not os.path.exists(local_filename):
# Download the file
with requests.get(url, stream=True) as response:
response.raise_for_status()
with open(local_filename, 'wb') as f:
for chunk in response.iter_content(chunk_size=8192):
f.write(chunk)
print(f"Downloaded '{local_filename}' successfully.")
else:
print(f"The file '{local_filename}' already exists. Skipping download.")
# Unzip the file
with zipfile.ZipFile(local_filename, 'r') as zip_ref:
zip_ref.extractall()
print(f"Extracted '{local_filename}' into '{extract_dir}' successfully.")
# Optionally, remove the zip file after extraction
os.remove(local_filename)
print(f"Removed '{local_filename}' after extraction.")
generator = torch.Generator(device="cpu").manual_seed(42)
size = (1024, 1024)
def image_preprocess(image: Image.Image, mode="RGB", return_orig=False):
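    """Convert a PIL image to a float32 array in [0, 1], channels-first for RGB."""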
img = np.array(image.convert(mode))
if img.ndim == 3:
img = np.transpose(img, (2, 0, 1))
out_img = img.astype("float32") / 255
if return_orig:
return out_img, img
else:
return out_img
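# Inpaint the masked region of the sketchpad input with LaMa: "background"
# holds the source image and the first editor layer holds the painted mask.
# Note: the configs and checkpoint are reloaded on every call, which is
# simple but slow; they could be loaded once at import time instead.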
def infer(image):
source = image["background"].convert("RGB").resize(size)
    # The drawn layer is RGBA; threshold every painted pixel to 255, keep the
    # alpha band as a single-channel mask, and match the resized source.
    mask = image["layers"][0]
    mask = mask.point(lambda p: 255 if p > 0 else 0).split()[3]
    mask = mask.resize(size)
device = torch.device("cpu")
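    # Load the inference config, then the training config and checkpoint
    # shipped inside the downloaded big-lama directory.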
predict_config_path = "./configs/prediction/default.yaml"
with open(predict_config_path, "r") as f:
predict_config = OmegaConf.create(yaml.safe_load(f))
train_config_path = os.path.join(predict_config.model.path, "config.yaml")
with open(train_config_path, "r") as f:
train_config = OmegaConf.create(yaml.safe_load(f))
train_config.training_model.predict_only = True
train_config.visualizer.kind = "noop"
checkpoint_path = os.path.join(
predict_config.model.path, "models", predict_config.model.checkpoint
)
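    # load_checkpoint rebuilds the model from the training config and restores
    # its weights; freeze() disables gradients and puts the model in eval mode.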
model = load_checkpoint(
train_config, checkpoint_path, strict=False, map_location="cpu"
)
model.freeze()
if not predict_config.get("refine", False):
model.to(device)
img = image_preprocess(source, mode="RGB")
mask = image_preprocess(mask, mode="L")
result = dict(image=img, mask=mask[None, ...])
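    # LaMa's generator downsamples its input, so pad image and mask to a
    # multiple of pad_out_to_modulo and remember the original size for unpadding.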
if (
predict_config.dataset.pad_out_to_modulo is not None
and predict_config.dataset.pad_out_to_modulo > 1
):
result["unpad_to_size"] = result["image"].shape[1:]
result["image"] = pad_img_to_modulo(
result["image"], predict_config.dataset.pad_out_to_modulo
)
result["mask"] = pad_img_to_modulo(
result["mask"], predict_config.dataset.pad_out_to_modulo
)
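    # Collate the single sample into a batch of size one.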
batch = default_collate([result])
if predict_config.get("refine", False):
assert "unpad_to_size" in batch, "Unpadded size is required for the refinement"
# image unpadding is taken care of in the refiner, so that output image
# is same size as the input image
cur_res = refine_predict(batch, model, **predict_config.refiner)
cur_res = cur_res[0].permute(1, 2, 0).detach().cpu().numpy()
else:
with torch.no_grad():
batch = move_to_device(batch, device)
batch["mask"] = (batch["mask"] > 0) * 1
batch = model(batch)
cur_res = (
batch[predict_config.out_key][0].permute(1, 2, 0).detach().cpu().numpy()
)
unpad_to_size = batch.get("unpad_to_size", None)
if unpad_to_size is not None:
orig_height, orig_width = unpad_to_size
cur_res = cur_res[:orig_height, :orig_width]
cur_res = np.clip(cur_res * 255, 0, 255).astype("uint8")
yield cur_res
def clear_result():
return gr.update(value=None)
css = """.main-div div{display:inline-flex;align-items:center;gap:.8rem;font-size:1.75rem}.main-div div h1{font-weight:900;margin-bottom:7px}.main-div p{margin-bottom:10px;font-size:94%}a{text-decoration:underline}.tabs{margin-top:0;margin-bottom:0}#gallery{min-height:20rem}
"""
title = f"""
<div class="main-div">
<div>
<h1>LaMa Inpainting</h1>
</div>
Running on {"<b>GPU 🔥</b>" if torch.cuda.is_available() else "<b>CPU 🥶</b>"} <br><br>
</div>
"""
with gr.Blocks(css=css) as demo:
gr.HTML(title)
    with gr.Row():
        run_button = gr.Button("Generate")
    with gr.Row():
        input_image = gr.ImageMask(
            type="pil",
            label="Input Image",
            crop_size=(1024, 1024),
            layers=False,
            height=1024,
            width=1024,
        )
        result = gr.Image(
            interactive=False,
            label="Generated Image",
        )
    use_as_input_button = gr.Button("Use as Input Image", visible=False)
def use_output_as_input(output_image):
return gr.update(value=output_image)
use_as_input_button.click(
fn=use_output_as_input, inputs=[result], outputs=[input_image]
)
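    # Event chain: clear the previous result, hide the reuse button, run
    # inpainting, then reveal the reuse button again.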
run_button.click(
fn=clear_result,
inputs=None,
outputs=result,
).then(
fn=lambda: gr.update(visible=False),
inputs=None,
outputs=use_as_input_button,
).then(
fn=infer,
inputs=[input_image],
outputs=result,
).then(
fn=lambda: gr.update(visible=True),
inputs=None,
outputs=use_as_input_button,
)
demo.launch()