# Hugging Face Spaces page header (status shown on the page: Sleeping).
import os

import gradio as gr
import requests
import timm
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as T

# NOTE(review): os, requests and torchvision (the bare module) are not used by
# the code visible in this file; kept in case other parts rely on them.

# Load the model once at import time (downloads from the Hugging Face Hub).
# eval() makes BatchNorm layers use their stored running statistics rather
# than per-batch statistics, so forward passes leave those buffers untouched.
model = timm.create_model("hf_hub:nateraw/resnet18-random", pretrained=True)
model.eval()
def print_bn(net=None):
    """Collect the running statistics of every ``nn.BatchNorm2d`` layer.

    Despite the name, nothing is printed: the values are returned as one flat
    list so the caller can serialise them.

    Args:
        net: Module to scan. Defaults to the module-level ``model``.

    Returns:
        A flat list. For each ``nn.BatchNorm2d`` layer, in ``net.modules()``
        order, it contains the ``running_mean`` values, then the
        ``running_var`` values, then the layer's ``momentum`` (which may be
        ``None``). Layers without running buffers
        (``track_running_stats=False``) are skipped.
    """
    if net is None:
        net = model
    bn_data = []
    for m in net.modules():
        # NOTE(review): exact type check, so subclasses of BatchNorm2d are
        # skipped -- confirm this is intended.
        if type(m) is not nn.BatchNorm2d:
            continue
        if m.running_mean is None or m.running_var is None:
            continue
        # detach().cpu() so buffers living on an accelerator also work.
        bn_data.extend(m.running_mean.detach().cpu().tolist())
        bn_data.extend(m.running_var.detach().cpu().tolist())
        bn_data.append(m.momentum)
    return bn_data
def update_bn(image, net=None, size=(40, 40)):
    """Overwrite BatchNorm2d running means with pixel values from *image*.

    The image is resized to *size*, flattened, and its values are copied
    sequentially into the ``running_mean`` buffer of each ``nn.BatchNorm2d``
    layer of *net*, in ``net.modules()`` order. If the pixels run out part-way
    through a layer, only that layer's leading entries are overwritten and the
    remaining layers are left untouched.

    Args:
        image: Image tensor accepted by ``T.Resize`` (last two dims are H, W),
            e.g. the ``(1, 3, H, W)`` tensor built by ``greet``.
        net: Module to modify. Defaults to the module-level ``model``.
        size: ``(height, width)`` the image is resized to before flattening.
            NOTE(review): for a 3-channel input the default gives
            3 * 40 * 40 = 4800 values, presumably chosen to match the model's
            total BatchNorm2d channel count -- confirm.
    """
    if net is None:
        net = model
    # Resize BEFORE flattening: T.Resize needs (..., H, W) input and rejects
    # a 1-D tensor.
    pixels = T.Resize(size)(image).reshape(-1)
    total = pixels.shape[0]
    cursor = 0
    with torch.no_grad():
        for m in net.modules():
            if type(m) is not nn.BatchNorm2d or m.running_mean is None:
                continue
            if cursor >= total:
                break  # every pixel has been written
            count = min(m.running_mean.shape[0], total - cursor)
            # In-place copy: keeps the buffer's own storage instead of
            # aliasing a view of the image tensor.
            m.running_mean[:count].copy_(pixels[cursor:cursor + count])
            cursor += count
    return
def greet(image):
    """Gradio handler: dump BatchNorm statistics, or write an image into them.

    Args:
        image: ``None`` when nothing was uploaded; otherwise an array accepted
            by ``torch.tensor``, presumably ``(H, W, 3)`` pixels in 0-255 from
            the Gradio image input -- TODO confirm.

    Returns:
        With no image: every value from ``print_bn()`` formatted with ten
        decimal places and joined by commas. With an image: the fixed string
        ``"Hello world!"``, after ``update_bn`` has copied the image's pixels
        into the model's BatchNorm running means.
    """
    if image is None:
        return ','.join(f'{x:.10f}' for x in print_bn())

    print(type(image))
    image = torch.tensor(image).float()
    print(image.min(), image.max())
    # Scale to [0, 1] (assumes 0-255 input) and add a batch dimension.
    image = (image / 255.0).unsqueeze(0)
    print(image.shape)
    # NHWC -> NCHW, the layout the model expects.
    image = torch.permute(image, [0, 3, 1, 2])
    # The forward output is not used; run without autograd so no graph is
    # built. The model is in eval mode, so this does not touch BN stats.
    with torch.no_grad():
        model(image)
    update_bn(image)
    return "Hello world!"
# Gradio UI: one image-upload input (requested as 224x224 via ``shape``)
# routed to ``greet``; its string result is shown as text.
# NOTE(review): ``gr.inputs.*`` is the legacy Gradio API that was removed in
# Gradio 4 -- confirm the pinned gradio version.
image = gr.inputs.Image(
    label="Upload a photo for beautify",
    shape=(224, 224),
)
iface = gr.Interface(
    fn=greet,
    inputs=image,
    outputs="text",
)
iface.launch()