# hellominori / app.py
# Last commit 9e8b807 ("Update app.py") by morinop; file size 2.38 kB.
import gradio as gr
import requests
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as T
from transformers import AutoFeatureExtractor, ResNetForImageClassification
import timm
# Hugging Face checkpoint shared by the feature extractor and the classifier.
_CHECKPOINT = "microsoft/resnet-101"

# NOTE(review): feature_extractor is not referenced anywhere in the visible code.
feature_extractor = AutoFeatureExtractor.from_pretrained(_CHECKPOINT)
# nn.Module.eval() returns the module itself, so the call can be chained.
model = ResNetForImageClassification.from_pretrained(_CHECKPOINT).eval()

# Debug output: the full module tree, then the stem convolution's weights.
print(model)
print(model.resnet.embedder.embedder.convolution.weight.data)
import os
def print_bn(net=None):
    """Collect BatchNorm statistics and stem-conv weights as one flat list.

    For every ``nn.BatchNorm2d`` module of *net* (in ``net.modules()`` order)
    this appends its ``running_mean`` values, then its ``running_var`` values,
    then its ``momentum``. The number of values collected so far is printed.
    Finally every weight of ``net.resnet.embedder.embedder.convolution`` is
    appended, flattened.

    Args:
        net: Model to inspect. Defaults to the module-level ``model``.

    Returns:
        A flat ``list`` of Python numbers. NOTE(review): ``momentum`` can be
        ``None`` for a BatchNorm using a cumulative average, which the caller's
        ``:.2f`` formatting would reject.
    """
    net = model if net is None else net
    bn_data = []
    for m in net.modules():
        if isinstance(m, nn.BatchNorm2d):
            bn_data.extend(m.running_mean.data.flatten().tolist())
            bn_data.extend(m.running_var.data.flatten().tolist())
            bn_data.append(m.momentum)
    print(len(bn_data))
    # The conv weight is 4-D. It must be flattened before .tolist(); otherwise
    # the list gains nested sub-lists, and the caller's per-element float
    # formatting raises TypeError.
    conv_weight = net.resnet.embedder.embedder.convolution.weight.data
    bn_data.extend(conv_weight.flatten().tolist())
    return bn_data
def update_bn(image, net=None):
    """Overwrite BatchNorm running means with pixel values taken from *image*.

    The image is resized to 90x90 with ``T.Resize`` and flattened. Its values
    are then written, in order, into the ``running_mean`` buffer of each
    ``nn.BatchNorm2d`` module of *net* until the pixels run out. The last
    buffer touched may be filled only partially; its remaining entries keep
    their previous values.

    Args:
        image: Tensor accepted by ``T.Resize``. The caller in this file passes
            a (1, C, H, W) float tensor scaled to [0, 1].
        net: Model to modify in place. Defaults to the module-level ``model``.
    """
    net = model if net is None else net
    pixels = T.Resize((90, 90))(image).reshape(-1)
    total = pixels.shape[0]
    cursor_im = 0
    for m in net.modules():
        if not isinstance(m, nn.BatchNorm2d) or m.running_mean is None:
            continue
        if cursor_im >= total:
            break
        chunk = pixels[cursor_im:cursor_im + m.running_mean.shape[0]]
        # Copy in place instead of rebinding ``.data`` to a slice. The buffer
        # then keeps its own storage, dtype and device rather than aliasing
        # the image tensor, and a partial last chunk uses the same code path.
        m.running_mean.data[:chunk.shape[0]].copy_(chunk)
        # Report the range actually written (before advancing the cursor).
        print(cursor_im, ':', cursor_im + chunk.shape[0])
        cursor_im += chunk.shape[0]
def greet(image):
    """Gradio callback.

    With no image: return the values from ``print_bn()`` as a comma-separated
    string with two decimals each.

    With an image:
      1. set every stem-conv weight to 1;
      2. convert the image to a 1xCxHxW float tensor scaled by 1/255;
      3. write its pixels into the BatchNorm running means via ``update_bn``;
      4. run one forward pass (its output is discarded);
      5. return "Hello world!".

    Args:
        image: ``None`` or an array indexed as (H, W, C). Presumably uint8 RGB
            from the Gradio image input — confirm against the Gradio version
            in use.
    """
    if image is None:
        bn_data = print_bn()
        return ','.join(f'{x:.2f}' for x in bn_data)

    conv_layer = model.resnet.embedder.embedder.convolution
    # NOTE(review): this permanently replaces the pretrained stem weights of
    # the shared global model on every request.
    conv_layer.weight.data = torch.ones_like(conv_layer.weight.data)
    print(type(image))
    pixels = torch.as_tensor(image).float()
    print(pixels.min(), pixels.max())
    # (H, W, C) in [0, 255]  ->  (1, C, H, W) in [0, 1]
    pixels = (pixels / 255.0).unsqueeze(0).permute(0, 3, 1, 2)
    update_bn(pixels)
    print(pixels.shape)
    # Inference only: without no_grad the forward pass records an autograd
    # graph over all model parameters, which is pure memory/compute overhead.
    with torch.no_grad():
        model(pixel_values=pixels)
    return "Hello world!"
# Gradio UI: a single 224x224 image input; greet() returns text.
# gr.Image replaces the deprecated gr.inputs.Image namespace. It accepts the
# same label/shape arguments and also returns a numpy array by default.
image = gr.Image(label="Upload a photo for beauty", shape=(224, 224))
# NOTE(review): out_image is created but never wired into the interface.
out_image = gr.Image(label='Yes, it becomes better.')
iface = gr.Interface(fn=greet, inputs=image, outputs='text')

if __name__ == "__main__":
    iface.launch()