# fr-iqa / app.py
import gradio as gr
from DISTS_pytorch import DISTS
from torchvision.io import read_image
import torch
import torchvision.transforms.v2 as transforms
import spaces
from metrics.DeepDC import DeepDC
from metrics.DeepWSD import DeepWSD
from metrics.ADISTS import ADISTS
from dreamsim import dreamsim
# pyiqa requires older versions of some packages, which causes dependency conflicts during install, so we install it here at runtime.
# Specifically, it requires transformers==4.37.2.
try:
import pyiqa
except ImportError:
print("pyiqa not found. Installing...")
import subprocess
import sys
subprocess.check_call([sys.executable, "-m", "pip", "install", "pyiqa==0.1.14.1", "--no-deps"])
import pyiqa
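# A requirements.txt pin would be the usual alternative, but pip offers no per-package
# equivalent of --no-deps in requirements files, hence the runtime install above.
# Sketch of the pin that would otherwise be used:
#   pyiqa==0.1.14.1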
# Download the DreamSim weights once at startup so they are cached before any GPU call needs them
_, _ = dreamsim(pretrained=True, device="cpu")
@spaces.GPU(duration=15)
class Evaluator:
def __init__(self, device):
self.device = device
self.transform = transforms.ToDtype(dtype=torch.float32, scale=True)
self.metrics = self._init_metrics()
def _init_metrics(self):
return {
"↓ MSE": torch.nn.functional.mse_loss,
"↓ L1": torch.nn.functional.l1_loss,
"↓ DISTS": DISTS().to(self.device),
"↓ ADISTS": ADISTS().to(self.device),
"↓ DeepDC": DeepDC().to(self.device),
"↓ DeepWSD": DeepWSD().to(self.device),
"↓ LPIPS": pyiqa.create_metric("lpips", device=self.device),
"↓ DreamSim": dreamsim(pretrained=True, device=self.device)[0],
"↑ PSNR": pyiqa.create_metric("psnr", device=self.device),
"↑ SSIM": pyiqa.create_metric("ssim", device=self.device),
"↑ MS-SSIM": pyiqa.create_metric("ms_ssim", device=self.device),
"↑ CW-SSIM": pyiqa.create_metric("cw_ssim", device=self.device),
"↑ FSIM": pyiqa.create_metric("fsim", device=self.device),
}
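# Adding another full-reference metric from pyiqa is a one-line extension of the dict above,
# e.g. (sketch; assumes the metric name is registered in the installed pyiqa version):
#   "↑ VIF": pyiqa.create_metric("vif", device=self.device),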
@torch.no_grad()
def evaluate(self, img_fname1, img_fname2):
img1 = self.transform(read_image(img_fname1)).unsqueeze(0).to(self.device)
img2 = self.transform(read_image(img_fname2)).unsqueeze(0).to(self.device)
# check images are the same size
if img1.shape != img2.shape:
return "Input images must have the same dimensions!"
return "\n".join(
f"{name:<10}: {float(metric(img1, img2).item()):3,.5f}"
for name, metric in self.metrics.items()
)
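# Standalone usage sketch (outside Gradio/ZeroGPU), assuming two same-sized images on disk:
#   evaluator = Evaluator(torch.device("cuda" if torch.cuda.is_available() else "cpu"))
#   print(evaluator.evaluate("examples/01_1.jpg", "examples/01_2.jpg"))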
@spaces.GPU(duration=5)
def get_evaluator():
"""Returns a singleton Evaluator instance per worker/session."""
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if not hasattr(get_evaluator, "evaluator"):
get_evaluator.evaluator = Evaluator(device)
return get_evaluator.evaluator
@spaces.GPU(duration=10)
def compute_similarity(img1_path, img2_path):
"""Main function for Gradio interface."""
if not img1_path or not img2_path:
return "Please upload both images!"
return get_evaluator().evaluate(img1_path, img2_path)
def create_interface():
examples = [
["examples/01_1.jpg", "examples/01_1.jpg"], # Add an extra example for identical images
["examples/01_1.jpg", "examples/noise.jpg"],
*[[f"examples/{i:02d}_1.jpg", f"examples/{i:02d}_2.jpg"] for i in range(1, 10)],
]
# Custom CSS
css = """
.center-header {
display: flex;
align-items: center;
justify-content: center;
margin: 0 0 10px 0;
}
.monospace-text {
font-family: 'Courier New', Courier, monospace;
}
.metrics-table {
width: 100%;
border-collapse: collapse;
}
.metrics-table td {
padding: 10px;
vertical-align: top;
}
"""
# Build the Gradio interface
pyiqa_url = "https://github.com/chaofengc/IQA-PyTorch"
with gr.Blocks(title="FR-IQA", css=css) as demo:
gr.Markdown(f"""
<div class='center-header'><h1>Full-Reference Image Quality Assessment</h1></div>
Upload two images to compute various similarity metrics.<br>
**Note**: Images must have identical dimensions. The demo runs much faster locally, since the ZeroGPU setup re-initializes the metrics on every run.
""")
with gr.Row():
with gr.Column(scale=2):
img_fname1 = gr.Image(type="filepath", label="Image#1", height=512, width=512)
with gr.Column(scale=2):
img_fname2 = gr.Image(type="filepath", label="Image#2", height=512, width=512)
with gr.Column(scale=1):
metrics_output = gr.Textbox(label="Metrics Output", lines=22, elem_classes="monospace-text", show_copy_button=True)
with gr.Row():
submit_btn = gr.Button("Compute Metrics", variant="primary")
with gr.Row():
with gr.Column(scale=2):
gr.Examples(
examples=examples,
inputs=[img_fname1, img_fname2],
fn=compute_similarity,
outputs=metrics_output,
label="Example Pairs (all are 1024Γ—768)",
cache_examples=True,
cache_mode="lazy",
examples_per_page=6
)
with gr.Column(scale=2):
gr.Markdown(f"""
<div class='center-header'><h3>Acknowledgements</h3></div>
- Example images are from the [TryOffDiff](https://rizavelioglu.github.io/tryoffdiff) paper and are sampled from the VITON-HD dataset.
- Metrics (*score ranges are rough estimates; actual ranges may vary*):
<table class="metrics-table">
<tr>
<th>Metric</th>
<th>Score Range</th>
<th>Lower is better?</th>
<th>Source</th>
</tr>
<tr>
<td>MSE</td>
<td>[0, ∞)</td>
<td>Yes</td>
<td><a href="https://docs.pytorch.org/docs/stable/generated/torch.nn.MSELoss.html">torch</a></td>
</tr>
<tr>
<td>L1</td>
<td>[0, ∞)</td>
<td>Yes</td>
<td><a href="https://docs.pytorch.org/docs/stable/generated/torch.nn.L1Loss.html">torch</a></td>
</tr>
<tr>
<td>DISTS</td>
<td>[0, 1]</td>
<td>Yes</td>
<td><a href="https://github.com/dingkeyan93/DISTS">official</a></td>
</tr>
<tr>
<td>ADISTS</td>
<td>~[0, 1]</td>
<td>Yes</td>
<td><a href="https://github.com/dingkeyan93/A-DISTS">official</a></td>
</tr>
<tr>
<td>DeepDC</td>
<td>[0, 1]</td>
<td>Yes</td>
<td><a href="https://github.com/h4nwei/DeepDC">official</a></td>
</tr>
<tr>
<td>DeepWSD</td>
<td>[0, ∞)</td>
<td>Yes</td>
<td><a href="https://github.com/Buka-Xing/DeepWSD">official</a></td>
</tr>
<tr>
<td>LPIPS</td>
<td>[0, 1]</td>
<td>Yes</td>
<td><a href="{pyiqa_url}">pyiqa</a></td>
</tr>
<tr>
<td>DreamSim</td>
<td>[0, 1]</td>
<td>Yes</td>
<td><a href="https://github.com/ssundaram21/dreamsim">official</a></td>
</tr>
<tr>
<td>PSNR</td>
<td>[0, ∞)</td>
<td>No</td>
<td><a href="{pyiqa_url}">pyiqa</a></td>
</tr>
<tr>
<td>SSIM</td>
<td>[0, 1]</td>
<td>No</td>
<td><a href="{pyiqa_url}">pyiqa</a></td>
</tr>
<tr>
<td>MS-SSIM</td>
<td>[0, 1]</td>
<td>No</td>
<td><a href="{pyiqa_url}">pyiqa</a></td>
</tr>
<tr>
<td>CW-SSIM</td>
<td>[0, 1]</td>
<td>No</td>
<td><a href="{pyiqa_url}">pyiqa</a></td>
</tr>
<tr>
<td>FSIM</td>
<td>[0, 1]</td>
<td>No</td>
<td><a href="{pyiqa_url}">pyiqa</a></td>
</tr>
</table>
""")
submit_btn.click(
fn=compute_similarity,
inputs=[img_fname1, img_fname2],
outputs=[metrics_output]
)
return demo
if __name__ == "__main__":
demo = create_interface()
demo.launch(share=False, ssr_mode=False)
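# Local run sketch (the `spaces` GPU decorators are expected to be no-ops outside a ZeroGPU Space):
#   pip install -r requirements.txt   # assuming the Space's requirements.txt is present
#   python app.py
# Gradio then serves the demo at http://127.0.0.1:7860 by default.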