Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -2,6 +2,7 @@ import torch
|
|
| 2 |
import torch.nn as nn
|
| 3 |
from torchvision import models,transforms
|
| 4 |
from PIL import Image
|
|
|
|
| 5 |
import gradio as gr
|
| 6 |
from torchvision.transforms import transforms
|
| 7 |
|
|
@@ -13,16 +14,52 @@ t=transforms.Compose([ transforms.ToTensor(),
|
|
| 13 |
transforms.RandomHorizontalFlip(0.5),
|
| 14 |
transforms.RandomRotation(10),
|
| 15 |
])
|
| 16 |
-
class_name=[
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 17 |
|
| 18 |
-
model=torch.load("model.pth")
|
| 19 |
print(model)
|
|
|
|
| 20 |
def predict(image):
|
|
|
|
| 21 |
image=t(image).unsqueeze(0)
|
| 22 |
with torch.no_grad():
|
| 23 |
output=model(image)
|
| 24 |
_,predicted=torch.max(output,1)
|
| 25 |
-
|
|
|
|
| 26 |
return predicted_class
|
| 27 |
|
| 28 |
interface=gr.Interface(
|
|
|
|
| 2 |
import torch.nn as nn
|
| 3 |
from torchvision import models,transforms
|
| 4 |
from PIL import Image
|
| 5 |
+
import torch.nn.functional as f
|
| 6 |
import gradio as gr
|
| 7 |
from torchvision.transforms import transforms
|
| 8 |
|
|
|
|
| 14 |
transforms.RandomHorizontalFlip(0.5),
|
| 15 |
transforms.RandomRotation(10),
|
| 16 |
])
|
| 17 |
+
class_name=["airplane", "automobile", "bird", "cat", "deer", "dog", "frog", "horse", "ship","truck"]
|
| 18 |
+
|
| 19 |
+
class CIFAR_Module(nn.Module):
|
| 20 |
+
def __init__(self,in_channel):
|
| 21 |
+
self.in_channel=in_channel
|
| 22 |
+
super(CIFAR_Module,self).__init__()
|
| 23 |
+
self.con1=nn.Conv2d(in_channel,6*in_channel,5)
|
| 24 |
+
self.pool1=nn.MaxPool2d(5,stride=2)
|
| 25 |
+
self.con2=nn.Conv2d(6*in_channel,16*in_channel,5)
|
| 26 |
+
self.pool2=nn.MaxPool2d(5,stride=2)
|
| 27 |
+
self.flat=nn.Flatten()
|
| 28 |
+
self.fc1=nn.Linear(192,100*in_channel)
|
| 29 |
+
self.fc2=nn.Linear(100*in_channel,40*in_channel)
|
| 30 |
+
self.fc3=nn.Linear(40*in_channel,10)
|
| 31 |
+
def forward(self,x):
|
| 32 |
+
x=self.con1(x)
|
| 33 |
+
x=f.relu(x)
|
| 34 |
+
x=self.pool1(x)
|
| 35 |
+
x=f.relu(x)
|
| 36 |
+
x=self.con2(x)
|
| 37 |
+
x=f.relu(x)
|
| 38 |
+
x=self.pool2(x)
|
| 39 |
+
x=self.flat(x)
|
| 40 |
+
x=self.fc1(x)
|
| 41 |
+
x=f.relu(x)
|
| 42 |
+
x=self.fc2(x)
|
| 43 |
+
x=f.relu(x)
|
| 44 |
+
x=self.fc3(x)
|
| 45 |
+
return x
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
model=CIFAR_Module(3)
|
| 49 |
+
model.load_state_dict(torch.load("model.pth",weights_only=True))
|
| 50 |
+
model.eval()
|
| 51 |
+
|
| 52 |
|
|
|
|
| 53 |
print(model)
|
| 54 |
+
|
| 55 |
def predict(image):
|
| 56 |
+
image=image.resize((32,32))
|
| 57 |
image=t(image).unsqueeze(0)
|
| 58 |
with torch.no_grad():
|
| 59 |
output=model(image)
|
| 60 |
_,predicted=torch.max(output,1)
|
| 61 |
+
print(output)
|
| 62 |
+
predicted_class=class_name[predicted.item()-1]
|
| 63 |
return predicted_class
|
| 64 |
|
| 65 |
interface=gr.Interface(
|