matikosowy commited on
Commit
084852b
·
1 Parent(s): 36c8c04

torch model

Browse files
Files changed (2) hide show
  1. app.py +39 -12
  2. model.pkl → porsche_model.pth +2 -2
app.py CHANGED
@@ -1,18 +1,45 @@
1
  import gradio as gr
2
- from fastai.vision.all import *
 
 
 
 
3
 
4
- learn = load_learner('model.pkl')
5
- categories = ['911', 'Cayenne', 'Cayman', 'Macan', 'Panamera', 'Taycan']
6
 
7
- def classify_image(img):
8
- img = PILImage.create(img)
9
- dl = learn.dls.test_dl([img])
10
- probs, _ = learn.tta(dl=dl)
11
- return dict(zip(categories, map(float, probs[0])))
 
12
 
 
 
 
 
 
13
 
14
- image = gr.Image(width=224, height=224)
15
- label = gr.Label()
 
 
 
 
 
16
 
17
- intf = gr.Interface(fn=classify_image, inputs=image, outputs=label)
18
- intf.launch(inline=False, share=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
+ import torch
3
+ from PIL import Image
4
+ from torchvision import transforms
5
+ import torchvision.models as models
6
+ import torch.nn as nn
7
 
8
+ # Define the class names
9
+ class_names = ['911', 'cayenne', 'cayman', 'macan', 'panamera', 'taycan']
10
 
11
+ # Instantiate the model and load state_dict
12
+ model_ft = models.resnet34(weights=models.ResNet34_Weights.DEFAULT)
13
+ for param in model_ft.parameters():
14
+ param.requires_grad = False
15
+ for param in model_ft.layer4.parameters():
16
+ param.requires_grad = True
17
 
18
+ num_ftrs = model_ft.fc.in_features
19
+ model_ft.fc = nn.Linear(num_ftrs, len(class_names))
20
+ model_ft = model_ft.to('cuda' if torch.cuda.is_available() else 'cpu')
21
+ model_ft.load_state_dict(torch.load('model_ft.pth'))
22
+ model_ft.eval()
23
 
24
+ # Define preprocessing transforms
25
+ preprocess = transforms.Compose([
26
+ transforms.Resize(256),
27
+ transforms.CenterCrop(224),
28
+ transforms.ToTensor(),
29
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
30
+ ])
31
 
32
+ # Define the prediction function
33
+ def predict(image):
34
+ image = preprocess(image).unsqueeze(0).to(model_ft.device) # Add batch dimension and move to device
35
+ with torch.no_grad():
36
+ outputs = model_ft(image)
37
+ _, predicted = torch.max(outputs, 1)
38
+ return class_names[predicted.item()]
39
+
40
+ # Create Gradio interface
41
+ iface = gr.Interface(fn=predict,
42
+ inputs=gr.inputs.Image(type="pil"),
43
+ outputs="text")
44
+
45
+ iface.launch()
model.pkl → porsche_model.pth RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:6918bdfe8f5c78dd1e0427173cb35f60e26d9aff6ebcb3bd502fcaaa2ada966b
3
- size 87485098
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9506b43b2d24c94d97a2e4767cb1b05f84132373579fc7bdd8fb775bbc9dc9c5
3
+ size 85292402