File size: 3,885 Bytes
479af2c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a240000
 
 
 
 
 
 
 
 
 
 
 
 
479af2c
 
 
 
 
 
a240000
 
 
 
479af2c
331c179
479af2c
 
 
 
 
 
 
 
b34e4c5
479af2c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
331c179
479af2c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import gradio as gr
from options.test_options import TestOptions
from data import create_dataset
from models import create_model
from PIL import Image
import torchvision.transforms as transforms
import torch
import sys
import matplotlib.pyplot as plt
# Reference CLI invocation of the underlying test script. This used to be a
# bare string expression, which Python evaluates and immediately discards:
#   python test.py --model test --name selfie2anime --dataroot selfie2anime/testB \
#       --num_test 100 --model_suffix '_B' --no_dropout


# Heading text rendered at the top of the Gradio page.
title = "MASFNet: Multi-scale Adaptive Sampling Fusion Network for Object Detection in Adverse Weather"
# Markdown rendered under the heading (currently empty).
description = ""
# NOTE(review): defined but not referenced anywhere in this file.
article = ""

def reset_interface():
    """Clear the prediction image when the page loads.

    This is wired to ``demo.load(reset_interface, None, output)``, which
    declares a single output component, so exactly one update is returned.
    The previous version returned two updates for that single output:
    ``value=None`` and ``visible=False``. With one declared output, Gradio
    would not unpack that pair. The ``visible=False`` update would also have
    hidden the result image, and nothing in the app ever makes it visible
    again.

    Returns:
        A ``gr.update`` that resets the output image's value to ``None``.
    """
    return gr.update(value=None)

def resize_image(img):
    """Return a copy of *img* resampled to 256x256 pixels with bicubic filtering."""
    target_size = (256, 256)
    resized = img.resize(target_size, Image.BICUBIC)
    return resized

def check_resolution(img):
    """Report whether *img* already has a resolution accepted without resizing.

    Args:
        img: Object with a two-element ``size`` attribute (width, height),
            e.g. a PIL image.

    Returns:
        True if the size is exactly 256x256 or 64x64, otherwise False.
    """
    w, h = img.size
    accepted_sizes = ((256, 256), (64, 64))
    return (w, h) in accepted_sizes

# Arguments handed to TestOptions().parse() (after a program name).
# The previous code set sys.argv = ['--model', '--dataroot', ...]. An
# argparse-based parser treats sys.argv[0] as the program name, so the dangling
# '--model' was silently ignored. These are the arguments that were actually
# being parsed.
_TEST_OPTION_ARGS = ['--dataroot', './data/', '--num_test', '1', '--no_dropout']


def _load_test_model():
    """Parse the test options and return a model that has been set up for inference."""
    saved_argv = sys.argv
    # TestOptions().parse() reads sys.argv. Swap it only while parsing, so the
    # rest of the process (Gradio included) keeps seeing its real argv, even if
    # parsing raises.
    program_name = saved_argv[0] if saved_argv else 'app.py'
    sys.argv = [program_name, *_TEST_OPTION_ARGS]
    try:
        opt = TestOptions().parse()
    finally:
        sys.argv = saved_argv

    opt.num_threads = 0
    opt.batch_size = 1
    opt.serial_batches = True
    opt.no_flip = True
    opt.display_id = -1
    opt.name = ''
    opt.model_suffix = '_B'
    opt.num_test = 1
    opt.no_dropout = True

    # NOTE: the previous version also called create_dataset(opt) here and never
    # used the result. The uploaded image is fed directly via set_input, so
    # that directory scan was dropped.
    model = create_model(opt)
    model.setup(opt)
    if opt.eval:
        model.eval()
    return model


def _pil_to_model_input(img):
    """Convert a PIL image into a 1x3xHxW float tensor scaled to [-1, 1]."""
    tensor = transforms.ToTensor()(img.convert('RGB'))  # 3xHxW, values in [0, 1]
    # The output is mapped back with x * 0.5 + 0.5, so the model works in
    # [-1, 1]. Feed the input in the same range, equivalent to
    # Normalize(mean=0.5, std=0.5), rather than raw [0, 1] values.
    # NOTE(review): assumes training used that normalisation, as the output
    # de-normalisation implies. Confirm against the training pipeline.
    tensor = tensor * 2.0 - 1.0
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    return tensor.unsqueeze(0).to(device)


def _model_output_to_pil(output_tensor):
    """Convert a 1xCxHxW model output in [-1, 1] into a uint8 PIL image."""
    # Drop the batch dimension and clamp, so values marginally outside [-1, 1]
    # cannot wrap around when cast to uint8. Then map to [0, 255].
    pixels = output_tensor.squeeze(0).detach().cpu().clamp(-1.0, 1.0)
    pixels = ((pixels * 0.5 + 0.5) * 255).numpy().transpose(1, 2, 0)
    pixels = pixels.astype('uint8')
    if pixels.shape[2] == 1:
        # Single-channel output: PIL expects an HxW array for grayscale.
        pixels = pixels[:, :, 0]
    return Image.fromarray(pixels)


def inference(img):
    """Run the test-mode model on one uploaded image and return the result.

    The model is rebuilt on every call, matching the original behaviour.

    Args:
        img: PIL image from the Gradio ``Image(type="pil")`` input, or ``None``
            when nothing was uploaded.

    Returns:
        The model's ``'fake'`` visual as a PIL image. Returns ``None`` if no
        image was given, the model produced no usable output, or any step
        raised; the error and its traceback are printed.
    """
    try:
        if img is None:
            print("No image received!")
            return None

        # Images that are not already 256x256 or 64x64 are resized to 256x256.
        if not check_resolution(img):
            img = resize_image(img)

        model = _load_test_model()

        data = {'A': _pil_to_model_input(img), 'A_paths': './data/'}
        model.set_input(data)
        model.test()

        output_img_tensor = model.get_current_visuals().get('fake')
        if output_img_tensor is None:
            print("No output from model!")
            return None
        if not isinstance(output_img_tensor, torch.Tensor):
            # Message text means "unexpected output type".
            print(f"意外的输出类型: {type(output_img_tensor)}")
            return None
        return _model_output_to_pil(output_img_tensor)

    except Exception as e:
        # Top-level UI boundary: report the failure (with traceback) and show
        # no image, rather than crashing the request.
        import traceback
        print(f"Error during inference: {e}")
        traceback.print_exc()
        return None

# Sample inputs listed under the upload widget.
example_images = [
    "img/1.png"
]

# Page layout: upload widget and submit button on the left, model output on the right.
with gr.Blocks() as demo:
    gr.Markdown(f"### {title}")
    gr.Markdown(description)

    with gr.Row():
        with gr.Column():
            img_input = gr.Image(type="pil", label="Upload an Image")
            submit_btn = gr.Button("Submit...")
        with gr.Column():
            output = gr.Image(type="pil", label="Prediction Result")

    # Run the model on the uploaded image when the button is clicked.
    submit_btn.click(fn=inference, inputs=img_input, outputs=output)
    # On page load, apply reset_interface's update(s) to the output image.
    demo.load(reset_interface, None, output)
    # Selecting an example only fills the upload widget. No fn/outputs are
    # given, so it does not run inference by itself.
    gr.Examples(
        examples=example_images,
        inputs=img_input,
    )

# Start the server only when run as a script, so that importing this module
# (for tests or tooling) does not launch the app as a side effect.
if __name__ == "__main__":
    demo.launch()