RAM_plus_plus / app.py
Zilong-Zhang003
NameError
7318bea
raw
history blame
4.28 kB
import os
import io
import cv2
import gradio as gr
import numpy as np
import torch
import spaces
from PIL import Image
from functools import lru_cache
from huggingface_hub import hf_hub_download, snapshot_download
from torchvision.transforms.functional import normalize
import glob
import traceback
from restormerRFR_arch import RestormerRFR
from dino_feature_extractor import DinoFeatureModule
# Hugging Face Hub repository passed as ``repo_id`` when fetching the checkpoint.
WEIGHT_REPO_ID = "233zzl/RAM_plus_plus"
# Path of the checkpoint file inside WEIGHT_REPO_ID.
WEIGHT_FILENAME = "7task/RestormerRFR.pth"
# Name of the restoration architecture; not referenced elsewhere in the visible code.
MODEL_NAME = "RestormerRFR"
def get_device():
    """Return the CUDA device when one is available, otherwise the CPU device."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")
def warmup():
    """Fetch the restoration checkpoint and the DINOv2 snapshot ahead of inference.

    Both downloads target the ``main`` revision of a model repository; the
    returned paths are discarded.
    """
    shared = {"repo_type": "model", "revision": "main"}
    hf_hub_download(repo_id=WEIGHT_REPO_ID, filename=WEIGHT_FILENAME, **shared)
    snapshot_download(repo_id="facebook/dinov2-giant", **shared)
def build_model():
    """Instantiate RestormerRFR with the hyper-parameters used by this demo.

    No weights are loaded here; the caller is responsible for that.
    """
    arch_kwargs = {
        "inp_channels": 3,
        "out_channels": 3,
        "dim": 48,
        "num_blocks": [4, 6, 6, 8],
        "num_refinement_blocks": 4,
        "heads": [1, 2, 4, 8],
        "ffn_expansion_factor": 2.66,
        "bias": False,
        "LayerNorm_type": "WithBias",
        "finetune_type": None,
        "img_size": 128,
    }
    return RestormerRFR(**arch_kwargs)
@lru_cache(maxsize=1)
def get_dino_extractor(device):
    """Create the DINO feature module on ``device`` in eval mode.

    The result is memoised (one entry), so repeated calls with the same
    device reuse the same module instead of rebuilding it.
    """
    on_device = DinoFeatureModule().to(device)
    return on_device.eval()
@lru_cache(maxsize=1)
def get_model_and_device():
    """Build RestormerRFR, load its checkpoint, and return it with its device.

    The result is memoised, so the checkpoint is downloaded and loaded only
    on the first call.

    Returns:
        A ``(model, device)`` tuple: the model in eval mode, moved to the
        device chosen by ``get_device()``.
    """
    import logging

    device = get_device()
    model = build_model()
    weight_path = hf_hub_download(
        repo_id=WEIGHT_REPO_ID,
        filename=WEIGHT_FILENAME,
    )
    # NOTE(security): torch.load unpickles the file. The checkpoint comes from
    # a fixed Hub repository; consider passing weights_only=True once it is
    # confirmed that the file contains nothing but tensors.
    ckpt = torch.load(weight_path, map_location="cpu")

    # The weights are either nested under "params" or the file is the state
    # dict itself.
    state_dict = ckpt["params"] if "params" in ckpt else ckpt

    # strict=False tolerates key mismatches; report them so that a wrong or
    # renamed checkpoint does not silently leave layers uninitialised.
    # Assumes RestormerRFR uses torch.nn.Module.load_state_dict, which returns
    # the missing/unexpected key lists.
    load_result = model.load_state_dict(state_dict, strict=False)
    if load_result.missing_keys or load_result.unexpected_keys:
        logging.getLogger(__name__).warning(
            "Checkpoint %s loaded with %d missing and %d unexpected keys",
            WEIGHT_FILENAME,
            len(load_result.missing_keys),
            len(load_result.unexpected_keys),
        )

    model.eval().to(device)
    return model, device
@spaces.GPU(duration=240)
def restore_image(pil_img: Image.Image) -> Image.Image:
    """Restore one image with RestormerRFR guided by DINO features.

    Takes a single picture and returns the restored picture, following the
    RAM++ RestormerRFR + DINO feature inference procedure.

    Args:
        pil_img: Input picture. Converted to RGB before processing.

    Returns:
        The restored picture as an RGB PIL image.

    Raises:
        gr.Error: If no image was supplied, or if any step of inference
            fails (the message then carries the original error and traceback).
    """
    # Clicking "Run" without uploading delivers None; fail with a clear message.
    if pil_img is None:
        raise gr.Error("Please upload an image before running restoration.")

    try:
        model, device = get_model_and_device()
        dino_extractor = get_dino_extractor(device)

        # (H, W, 3) RGB scaled to [0, 1]; convert() also covers RGBA / grayscale input.
        rgb = np.asarray(pil_img.convert("RGB"), dtype=np.float32) / 255.0
        img = torch.from_numpy(np.ascontiguousarray(rgb.transpose(2, 0, 1)))  # (3,H,W), RGB
        img = img.unsqueeze(0).to(device)  # (1,3,H,W)

        mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
        std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
        normalize(img, mean, std, inplace=True)

        with torch.no_grad():
            dino_features = dino_extractor(img)
            output = model(img, dino_features)
            # Undo the input normalisation: x * std + mean.
            output = normalize(output, -1 * mean / std, 1 / std)

        # squeeze(0) removes only the batch axis, so a size-1 H or W survives.
        output = output.detach().squeeze(0).float().cpu().clamp_(0, 1).numpy()  # (3,H,W), RGB
        # The tensor is already RGB, which is what PIL expects. Reordering the
        # channels here (as a cv2.imwrite pipeline would) swaps red and blue.
        output = np.transpose(output, (1, 2, 0))  # (H,W,RGB)
        output = (output * 255.0).round().astype(np.uint8)
        # A uint8 (H, W, 3) array is interpreted as RGB without an explicit mode.
        return Image.fromarray(output)
    except Exception as e:
        raise gr.Error(f"{e}\n{traceback.format_exc()}") from e
# Markdown heading shown at the top of the demo page.
DESCRIPTION = """
# RAM++: Robust Representation Learning via Adaptive Mask for All-in-One Image Restoration
"""
# Assemble the Gradio page: input and output images side by side, optional
# example images, the run button wiring, and a usage tip.
with gr.Blocks(title="RAM++ ZeroGPU Demo") as demo:
    gr.Markdown(DESCRIPTION)
    with gr.Row():
        with gr.Column():
            inp = gr.Image(type="pil", label="load picture(JPEG/PNG)")
            btn = gr.Button("Run (ZeroGPU)")
        with gr.Column():
            out = gr.Image(type="pil", label="output")

    # Offer every image found directly under ./examples as a clickable example.
    # NOTE(review): placed below the input/output row — confirm intended layout.
    ex_files = sorted(
        path
        for ext in ("*.png", "*.jpg", "*.jpeg", "*.bmp")
        for path in glob.glob(os.path.join("examples", ext))
    )
    if ex_files:
        gr.Examples(examples=ex_files, inputs=inp, label="examples")

    btn.click(restore_image, inputs=inp, outputs=out, api_name="run")
    gr.Markdown("""
**Tips**
- If the queue is long or you hit the quota, please try again later, or upgrade to Pro for a higher ZeroGPU quota and priority.
""")
    # Run the weight downloads when the page loads.
    demo.load(fn=warmup, inputs=None, outputs=None)

if __name__ == "__main__":
    demo.launch()