Spaces:
Sleeping
Sleeping
import sys
import traceback

import gradio as gr
import matplotlib.pyplot as plt
import torch
import torchvision.transforms as transforms
from PIL import Image

from data import create_dataset
from models import create_model
from options.test_options import TestOptions
"python test.py --model test --name selfie2anime --dataroot selfie2anime/testB --num_test 100 --model_suffix '_B' --no_dropout" | |
title = "MASFNet: Multi-scale Adaptive Sampling Fusion Network for Object Detection in Adverse Weather" | |
description = "" | |
article = "" | |
def reset_interface():
    """Build the pair of Gradio updates used to reset the interface.

    Returns:
        A 2-tuple ``(value_update, visibility_update)``: the first update sets
        a component's value to ``None``, the second sets ``visible=False``.
    """
    # NOTE(review): the visible `demo.load(reset_interface, None, output)` call
    # wires a single output component while two updates are returned here -
    # confirm whether the second update is still needed.
    value_update = gr.update(value=None)
    visibility_update = gr.update(visible=False)
    return value_update, visibility_update
def resize_image(img):
    """Return a copy of ``img`` resized to 256x256 with bicubic resampling.

    Args:
        img: An object with a PIL-style ``resize(size, resample)`` method.

    Returns:
        Whatever ``img.resize`` returns for the 256x256 target size.
    """
    target_size = (256, 256)
    return img.resize(target_size, Image.BICUBIC)
def check_resolution(img) -> bool:
    """Report whether ``img`` already has one of the two accepted resolutions.

    Args:
        img: Any object whose ``size`` attribute unpacks to
            ``(width, height)``, as PIL images do.

    Returns:
        ``True`` when the image is exactly 256x256 or 64x64, ``False``
        otherwise.

    Raises:
        ValueError: If ``img.size`` does not unpack into exactly two values.
    """
    # Read the image resolution.
    width, height = img.size
    # Return the membership test directly instead of branching to True/False.
    return (width, height) in ((256, 256), (64, 64))
def _build_test_options():
    """Parse the test options and pin the settings used for one-image inference.

    Returns:
        The object returned by ``TestOptions().parse()`` with the single-image
        overrides below applied.
    """
    # argparse-style parsers skip sys.argv[0] (the program name). The previous
    # list began with '--model', which therefore sat in the program-name slot
    # and was never parsed as a flag. Keep a real program name there so that
    # every listed flag is actually parsed; the parsed flags are unchanged.
    # NOTE(review): assumes TestOptions parses sys.argv with argparse - confirm.
    program = sys.argv[0] if sys.argv else 'inference'
    sys.argv = [program, '--dataroot', './data/', '--num_test', '1', '--no_dropout']

    # Load options and set them up.
    opt = TestOptions().parse()
    opt.num_threads = 0
    opt.batch_size = 1
    opt.serial_batches = True
    opt.no_flip = True
    opt.display_id = -1
    opt.name = ''
    opt.model_suffix = '_B'
    opt.num_test = 1
    opt.no_dropout = True
    return opt


def _tensor_to_image(tensor):
    """Convert a model output tensor into an 8-bit PIL image.

    Args:
        tensor: A ``torch.Tensor`` that is three-dimensional once its leading
            axis has been squeezed away.

    Returns:
        The ``PIL.Image`` built from the rescaled array.
    """
    # Drop the leading axis, then move axis 0 to the end: (0, 1, 2) -> (1, 2, 0).
    array = tensor.squeeze(0).cpu().detach().numpy().transpose(1, 2, 0)
    # Assumes the output is normalized to [-1, 1]; map it onto [0, 255].
    array = (array * 0.5 + 0.5) * 255
    # Clip before casting: values outside [0, 255] would otherwise wrap around
    # when converted to uint8 and show up as speckle artifacts.
    array = array.clip(0, 255).astype('uint8')
    return Image.fromarray(array)


def inference(img):
    """Run the test model on one image and return the generated image.

    Args:
        img: The uploaded PIL image, or ``None`` when nothing was provided.

    Returns:
        The generated PIL image, or ``None`` when no image was received, the
        model produced no ``'fake'`` visual, the visual is not a
        ``torch.Tensor``, or any exception was raised along the way.
    """
    # Debugging: check that an image was actually received.
    if img is None:
        print("No image received!")
        return None

    try:
        if not check_resolution(img):
            img = resize_image(img)

        opt = _build_test_options()

        # Create the model and set it up.
        # NOTE(review): the dataset object is never used afterwards; the call
        # is kept in case create_dataset validates './data/' or has other side
        # effects - confirm and remove if not.
        create_dataset(opt)
        # NOTE(review): the model is rebuilt and set up on every request;
        # consider building it once if setup is expensive.
        model = create_model(opt)
        model.setup(opt)
        if opt.eval:
            model.eval()

        # Convert the PIL image to a batched tensor; use the GPU if available.
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        img_tensor = transforms.ToTensor()(img.convert('RGB')).unsqueeze(0)
        img_tensor = img_tensor.to(device)
        # NOTE(review): the input is fed as produced by ToTensor while the
        # output conversion assumes a [-1, 1] range - confirm whether the
        # model expects normalized input.

        # Prepare data for the model and run it.
        model.set_input({'A': img_tensor, 'A_paths': './data/'})
        model.test()

        # Get the output visuals.
        img_out = model.get_current_visuals()
        output_img_tensor = img_out.get('fake')
        print(f'type of img_out: {type(img_out)}')

        if output_img_tensor is None:
            print("No output from model!")
            return None
        if not isinstance(output_img_tensor, torch.Tensor):
            print(f"意外的输出类型: {type(output_img_tensor)}")
            return None

        # Convert the tensor back into a PIL image.
        output_img = _tensor_to_image(output_img_tensor)
        print(f'type of output_img_tensor: {type(output_img_tensor)}')
        return output_img
    except Exception as e:
        # UI boundary: never propagate into Gradio, but keep the full
        # traceback so that failures can be diagnosed from the logs.
        print(f"Error during inference: {e}")
        traceback.print_exc()
        return None
# Example inputs handed to gr.Examples below.
example_images = ["img/1.png"]

# NOTE(review): the original indentation was lost, so the nesting below is
# reconstructed (two columns inside one row; event wiring and examples at the
# Blocks level) - confirm against the intended layout.
with gr.Blocks() as demo:
    # Page header.
    gr.Markdown(f"### {title}")
    gr.Markdown(description)

    with gr.Row():
        # Input column: image upload plus the submit button.
        with gr.Column():
            img_input = gr.Image(type="pil", label="Upload an Image")
            submit_btn = gr.Button("Submit...")
        # Output column: the predicted image.
        with gr.Column():
            output = gr.Image(type="pil", label="Prediction Result")

    # Run inference on click and apply reset_interface when the page loads.
    submit_btn.click(fn=inference, inputs=img_input, outputs=output)
    demo.load(reset_interface, None, output)

    gr.Examples(examples=example_images, inputs=img_input)

demo.launch()