mawady commited on
Commit
1b63bdd
1 Parent(s): 09cd20b

format code

Browse files
Files changed (1) hide show
  1. app.py +42 -34
app.py CHANGED
@@ -12,15 +12,17 @@ import gdown
12
  import urllib.request
13
  import gradio as gr
14
 
15
- #url = 'https://drive.google.com/uc?id=1VMLpE5ojF9fq0GtBKaqcMVWUIfJUfKbc'
16
  path_class_names = "./class_names_restnet_catsVSdogs.pkl"
17
- #gdown.download(url, path_class_names, quiet=False, use_cookies=False)
18
 
19
- #url = 'https://drive.google.com/uc?id=1jorQB1mpPCLH097M8paxut3v5XwVlKqp'
20
  path_model = "./model_state_restnet_catsVSdogs.pth"
21
- #gdown.download(url, path_model, quiet=False, use_cookies=False)
22
 
23
- url = "https://upload.wikimedia.org/wikipedia/commons/3/38/Adorable-animal-cat-20787.jpg"
 
 
24
  path_input = "./cat.jpg"
25
  urllib.request.urlretrieve(url, filename=path_input)
26
 
@@ -29,12 +31,14 @@ url = "https://upload.wikimedia.org/wikipedia/commons/4/43/Cute_dog.jpg"
29
  path_input = "./dog.jpg"
30
  urllib.request.urlretrieve(url, filename=path_input)
31
 
32
- data_transforms_val = transforms.Compose([
 
33
  transforms.Resize(256),
34
  transforms.CenterCrop(224),
35
  transforms.ToTensor(),
36
- transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
37
- ])
 
38
  class_names = pickle.load(open(path_class_names, "rb"))
39
 
40
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
@@ -43,45 +47,49 @@ model_ft = models.resnet18(pretrained=True)
43
  num_ftrs = model_ft.fc.in_features
44
  model_ft.fc = nn.Linear(num_ftrs, len(class_names))
45
  model_ft = model_ft.to(device)
46
- model_ft.load_state_dict(copy.deepcopy(torch.load(path_model,device)))
 
47
 
48
  def do_inference(img):
49
- img_t = data_transforms_val(img)
50
- batch_t = torch.unsqueeze(img_t, 0)
51
- model_ft.eval()
52
- # We don't need gradients for test, so wrap in
53
- # no_grad to save memory
54
- with torch.no_grad():
55
- batch_t = batch_t.to(device)
56
- # forward propagation
57
- output = model_ft( batch_t)
58
- # get prediction
59
- probs = torch.nn.functional.softmax(output, dim=1)
60
- output = torch.argsort(probs, dim=1, descending=True).cpu().numpy()[0].astype(int)
61
- probs = probs.cpu().numpy()[0]
62
- probs = probs[output]
63
- labels = np.array(class_names)[output]
64
- return {labels[i]: round(float(probs[i]),2) for i in range(len(labels))}
 
 
65
 
66
- im = gr.inputs.Image(shape=(512, 512), image_mode='RGB',
67
- invert_colors=False, source="upload",
68
- type="pil")
 
69
  title = "CatsVsDogs Classifier"
70
  description = "Playground: Inferernce of Object Classification (Binary) using ResNet18 model and CatsVsDogs dataset. Libraries: PyTorch, Gradio."
71
- examples = [['./cat.jpg'],['./dog.jpg']]
72
- article="<p style='text-align: center'><a href='https://github.com/mawady/colab-recipes-cv' target='_blank'>Colab Recipes for Computer Vision - Dr. Mohamed Elawady</a></p>"
73
  iface = gr.Interface(
74
- do_inference,
75
- im,
76
  gr.outputs.Label(num_top_classes=2),
77
  live=False,
78
  interpretation=None,
79
  title=title,
80
  description=description,
81
  article=article,
82
- examples=examples
83
  )
84
 
85
- #iface.test_launch()
86
 
87
  iface.launch()
 
12
  import urllib.request
13
  import gradio as gr
14
 
15
+ # url = 'https://drive.google.com/uc?id=1VMLpE5ojF9fq0GtBKaqcMVWUIfJUfKbc'
16
  path_class_names = "./class_names_restnet_catsVSdogs.pkl"
17
+ # gdown.download(url, path_class_names, quiet=False, use_cookies=False)
18
 
19
+ # url = 'https://drive.google.com/uc?id=1jorQB1mpPCLH097M8paxut3v5XwVlKqp'
20
  path_model = "./model_state_restnet_catsVSdogs.pth"
21
+ # gdown.download(url, path_model, quiet=False, use_cookies=False)
22
 
23
+ url = (
24
+ "https://upload.wikimedia.org/wikipedia/commons/3/38/Adorable-animal-cat-20787.jpg"
25
+ )
26
  path_input = "./cat.jpg"
27
  urllib.request.urlretrieve(url, filename=path_input)
28
 
 
31
  path_input = "./dog.jpg"
32
  urllib.request.urlretrieve(url, filename=path_input)
33
 
34
+ data_transforms_val = transforms.Compose(
35
+ [
36
  transforms.Resize(256),
37
  transforms.CenterCrop(224),
38
  transforms.ToTensor(),
39
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
40
+ ]
41
+ )
42
  class_names = pickle.load(open(path_class_names, "rb"))
43
 
44
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 
47
  num_ftrs = model_ft.fc.in_features
48
  model_ft.fc = nn.Linear(num_ftrs, len(class_names))
49
  model_ft = model_ft.to(device)
50
+ model_ft.load_state_dict(copy.deepcopy(torch.load(path_model, device)))
51
+
52
 
53
  def do_inference(img):
54
+ img_t = data_transforms_val(img)
55
+ batch_t = torch.unsqueeze(img_t, 0)
56
+ model_ft.eval()
57
+ # We don't need gradients for test, so wrap in
58
+ # no_grad to save memory
59
+ with torch.no_grad():
60
+ batch_t = batch_t.to(device)
61
+ # forward propagation
62
+ output = model_ft(batch_t)
63
+ # get prediction
64
+ probs = torch.nn.functional.softmax(output, dim=1)
65
+ output = (
66
+ torch.argsort(probs, dim=1, descending=True).cpu().numpy()[0].astype(int)
67
+ )
68
+ probs = probs.cpu().numpy()[0]
69
+ probs = probs[output]
70
+ labels = np.array(class_names)[output]
71
+ return {labels[i]: round(float(probs[i]), 2) for i in range(len(labels))}
72
 
73
+
74
+ im = gr.inputs.Image(
75
+ shape=(512, 512), image_mode="RGB", invert_colors=False, source="upload", type="pil"
76
+ )
77
  title = "CatsVsDogs Classifier"
78
  description = "Playground: Inferernce of Object Classification (Binary) using ResNet18 model and CatsVsDogs dataset. Libraries: PyTorch, Gradio."
79
+ examples = [["./cat.jpg"], ["./dog.jpg"]]
80
+ article = "<p style='text-align: center'><a href='https://github.com/mawady/colab-recipes-cv' target='_blank'>Colab Recipes for Computer Vision - Dr. Mohamed Elawady</a></p>"
81
  iface = gr.Interface(
82
+ do_inference,
83
+ im,
84
  gr.outputs.Label(num_top_classes=2),
85
  live=False,
86
  interpretation=None,
87
  title=title,
88
  description=description,
89
  article=article,
90
+ examples=examples,
91
  )
92
 
93
+ # iface.test_launch()
94
 
95
  iface.launch()