# minima2 / app.py
# Source: mgupta70, revision 28c37eb ("itr1")
import gradio as gr
import torch
from torch import nn
from torchvision.transforms import ToTensor
import torchvision.transforms as transforms
import os
import numpy as np
from PIL import Image
import torch.optim as optim
import torchvision.transforms.functional as F
from torch.nn.functional import pairwise_distance
class Network(nn.Module):
    """Convolutional embedding network used as the triplet-model backbone.

    Four ``Conv2d(k=5) -> PReLU -> MaxPool2d(2) -> Dropout(0.3)`` stages are
    followed by a two-layer MLP that projects to ``emb_dim`` features.  The
    ``32 * 10 * 10`` input width of the first Linear layer matches a
    3 x 224 x 224 input (224 -> 220 -> 110 -> 106 -> 53 -> 49 -> 24 -> 20 -> 10).

    Args:
        emb_dim: size of the output embedding.
    """

    def __init__(self, emb_dim=128):
        super().__init__()
        # Attribute names and layer order must stay as they are: they define
        # the state_dict keys ("conv.0.weight", "fc.2.bias", ...) that saved
        # checkpoints are loaded against.
        self.conv = nn.Sequential(
            nn.Conv2d(3, 16, 5),
            nn.PReLU(),
            nn.MaxPool2d(2, stride=2),
            nn.Dropout(0.3),
            nn.Conv2d(16, 32, 5),
            nn.PReLU(),
            nn.MaxPool2d(2, stride=2),
            nn.Dropout(0.3),
            nn.Conv2d(32, 64, 5),
            nn.PReLU(),
            nn.MaxPool2d(2, stride=2),
            nn.Dropout(0.3),
            nn.Conv2d(64, 32, 5),
            nn.PReLU(),
            nn.MaxPool2d(2, stride=2),
            nn.Dropout(0.3),
        )
        self.fc = nn.Sequential(
            nn.Linear(32 * 10 * 10, 512),
            nn.PReLU(),
            nn.Linear(512, emb_dim),
        )

    def forward(self, x):
        """Map a (batch, 3, 224, 224) tensor to (batch, emb_dim) embeddings."""
        x = self.conv(x)
        # flatten(1) keeps the batch dimension intact.  view(-1, 32*10*10)
        # would silently fold extra spatial positions into the batch for
        # non-224 inputs (a 380x380 image yields 4 rows); with flatten such
        # inputs fail loudly in the first Linear layer instead.
        x = torch.flatten(x, 1)
        x = self.fc(x)
        # NOTE: embeddings are returned un-normalized (an L2
        # nn.functional.normalize step was left disabled).
        return x
# Inference model: 256-dim embeddings, weights restored from the checkpoint.
embedding_dims = 256
model = Network(embedding_dims)
# Not used by the inference path in this file; kept so the module-level name
# stays available.  NOTE(review): safe to drop if nothing imports it.
optimizer = optim.Adam(model.parameters(), lr=0.0002)
# map_location="cpu": the model is never moved to a GPU here, and a checkpoint
# saved from CUDA tensors would otherwise fail to load on a CPU-only host.
model.load_state_dict(torch.load('triplet_30_v1.pth', map_location='cpu'))
# Inference mode: disables the Dropout layers inside Network.
model.eval()
def dist_pred(img1, img2, img3, net=None):
    """Embed three images and return two embedding-space distances.

    Args:
        img1, img2, img3: images as delivered by the Gradio ``Image`` inputs
            (assumed H x W x C numpy arrays, Gradio's default ``type="numpy"``
            -- TODO confirm for the installed Gradio version).  2-D
            (grayscale) arrays are expanded to 3 channels and any channels
            beyond the third (e.g. alpha) are dropped.
        net: embedding module to run; defaults to the module-level ``model``.

    Returns:
        ``(score1, score2)`` as Python floats: the Euclidean distance between
        the embeddings of img1/img2 and of img2/img3.

    Raises:
        ValueError: if any of the three images is missing (``None``).
    """
    if net is None:
        net = model

    for name, img in (("img1", img1), ("img2", img2), ("img3", img3)):
        if img is None:
            raise ValueError(f"{name} is missing; please provide all three images.")

    def _to_tensor(img):
        # Mirror torchvision's ToTensor for numpy input: HWC -> CHW, and
        # uint8 [0, 255] -> float [0, 1].
        # NOTE(review): assumes the model was trained on ToTensor output with
        # no further normalization -- confirm against the training script.
        arr = np.asarray(img)
        if arr.ndim == 2:
            arr = np.stack([arr] * 3, axis=-1)
        arr = np.ascontiguousarray(arr[..., :3])
        t = torch.from_numpy(arr).permute(2, 0, 1)
        t = t.float().div(255) if arr.dtype == np.uint8 else t.float()
        if tuple(t.shape[-2:]) != (224, 224):
            # Network's first Linear layer is sized for 224x224 inputs.
            t = torch.nn.functional.interpolate(
                t.unsqueeze(0), size=(224, 224), mode="bilinear", align_corners=False
            ).squeeze(0)
        return t

    batch = torch.stack([_to_tensor(img) for img in (img1, img2, img3)])
    with torch.no_grad():
        emb = net(batch)
    # Same pairing as before: (img1, img2) and (img2, img3).
    # NOTE(review): for an anchor/positive/negative triplet the second pair is
    # often (img1, img3) -- confirm which comparison is intended.
    score1 = pairwise_distance(emb[0:1], emb[1:2])
    score2 = pairwise_distance(emb[1:2], emb[2:3])
    return score1.item(), score2.item()
# Three image inputs.  Gradio's shape=(224, 224) option resizes each upload
# before dist_pred is called, matching the 224x224 size Network expects.
image1 = gr.inputs.Image(shape=(224, 224))
image2 = gr.inputs.Image(shape=(224, 224))
image3 = gr.inputs.Image(shape=(224, 224))
# dist_pred returns two values, so the interface needs two outputs.
label = gr.outputs.Label(label="distance(img1, img2)")
label2 = gr.outputs.Label(label="distance(img2, img3)")
# Gradio expects examples as rows of [img1_path, img2_path, img3_path]; the
# previous placeholder [0] was not a valid example set.  Add real rows here
# once example images ship with the app.
examples = None
intf = gr.Interface(
    fn=dist_pred,
    inputs=[image1, image2, image3],
    outputs=[label, label2],
    examples=examples,
)
intf.launch(inline=False)