"""Gradio demo that writes an uploaded image into ResNet-101 BatchNorm statistics.

With no image uploaded, the app returns every ``nn.BatchNorm2d`` layer's running
statistics as a comma-separated string. With an image, the resized pixels are
copied into the layers' ``running_mean`` buffers, then a forward pass is run.
"""

import os

import gradio as gr
import requests
import timm
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as T
from transformers import AutoFeatureExtractor, ResNetForImageClassification

feature_extractor = AutoFeatureExtractor.from_pretrained("microsoft/resnet-101")
model = ResNetForImageClassification.from_pretrained("microsoft/resnet-101")
model.eval()


def print_bn():
    """Collect the running statistics of every ``nn.BatchNorm2d`` in ``model``.

    For each such layer, in ``model.modules()`` order, the list receives the
    layer's ``running_mean`` values, then its ``running_var`` values, then its
    ``momentum``. The total length is printed.

    Returns:
        A flat list of Python numbers.
    """
    bn_data = []
    for m in model.modules():
        # Exact type match: only plain BatchNorm2d layers are collected.
        if type(m) is nn.BatchNorm2d:
            bn_data.extend(m.running_mean.data.numpy().tolist())
            bn_data.extend(m.running_var.data.numpy().tolist())
            bn_data.append(m.momentum)
    print(len(bn_data))
    return bn_data


def update_bn(image):
    """Overwrite BatchNorm ``running_mean`` buffers with resized image pixels.

    The image is resized to 90x90 and flattened. The pixels are then copied
    into the ``running_mean`` buffers of successive ``nn.BatchNorm2d`` layers
    (in ``model.modules()`` order) until they run out. The last layer reached
    may be only partially overwritten. ``model`` is modified in place, so the
    change persists across calls.

    Args:
        image: float tensor accepted by ``T.Resize``; ``greet`` passes a
            ``(1, 3, H, W)`` tensor.
    """
    pixels = T.Resize((90, 90))(image).reshape(-1)
    total = pixels.shape[0]
    cursor = 0
    for m in model.modules():
        if type(m) is not nn.BatchNorm2d:
            continue
        size = m.running_mean.data.shape[0]
        end = min(cursor + size, total)
        # copy_ writes into the existing buffer instead of rebinding it to a
        # view of the input tensor, so model state never aliases user data.
        m.running_mean.data[: end - cursor].copy_(pixels[cursor:end])
        # Report the range actually written (before advancing the cursor).
        print(cursor, ':', end)
        cursor = end
        if cursor >= total:
            break


def greet(image):
    """Gradio handler.

    Args:
        image: ``None`` when nothing was uploaded; otherwise the array from the
            Gradio image input. Assumed ``(H, W, 3)`` with values in 0..255
            (TODO confirm for the installed Gradio version).

    Returns:
        With no image, the values from ``print_bn`` formatted to two decimals
        and comma-joined. Otherwise, the fixed string ``"Hello world!"`` after
        calling ``update_bn`` and running a forward pass.
    """
    if image is None:
        bn_data = print_bn()
        return ','.join(f'{x:.2f}' for x in bn_data)

    print(type(image))
    image = torch.tensor(image).float()
    print(image.min(), image.max())
    image = image / 255.0
    # (H, W, C) -> (1, C, H, W), the layout passed as ``pixel_values``.
    image = torch.permute(image.unsqueeze(0), [0, 3, 1, 2])
    update_bn(image)
    print(image.shape)
    # Inference only: avoid building an autograd graph.
    with torch.no_grad():
        model(pixel_values=image)
    return "Hello world!"


# NOTE(review): ``gr.inputs.*`` and ``shape=`` are Gradio 3.x APIs (``gr.inputs``
# was removed in Gradio 4); keep the pinned Gradio version in sync.
image = gr.inputs.Image(label="Upload a photo for beauty", shape=(224,224))
out_image = gr.inputs.Image(label='Yes, it becomes better.')  # currently unused
iface = gr.Interface(fn=greet, inputs=image, outputs='text')
iface.launch()