pavi156 commited on
Commit
6a922b2
1 Parent(s): fd475a5
Files changed (3) hide show
  1. app.py +49 -0
  2. cifar_net.pth +3 -0
  3. requirements.txt +8 -0
app.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import torchvision.transforms as transforms
4
+ from PIL import Image
5
+
6
+ # Load the trained model
7
+ model_path = "cifar_net.pth"
8
+ model = torch.load(model_path, map_location=torch.device('cpu'))
9
+ model.eval()
10
+
11
+ # Define class labels for CIFAR-10
12
+ classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
13
+
14
+ def classify_image(image):
15
+ transform = transforms.Compose([
16
+ transforms.ToTensor(),
17
+ transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
18
+ ])
19
+
20
+ # Preprocess the input image
21
+ image = transform(image).unsqueeze(0)
22
+
23
+ # Perform inference with the model
24
+ outputs = model(image)
25
+ _, predicted = torch.max(outputs, 1)
26
+ predicted_class = classes[predicted.item()]
27
+
28
+ return predicted_class
29
+
30
+ def classify_images(images):
31
+ return [classify_image(image) for image in images]
32
+
33
+ inputs_image = gr.inputs.Image(label="Input Image", type="pil")
34
+ outputs_image = gr.outputs.Label(label="Predicted Class")
35
+ interface_image = gr.Interface(
36
+ fn=classify_images,
37
+ inputs=inputs_image,
38
+ outputs=outputs_image,
39
+ title="CIFAR-10 Image Classifier",
40
+ description="Classify images into one of the CIFAR-10 classes.",
41
+ examples=[
42
+ ['image_0.jpg'],
43
+ ['image_1.jpg']
44
+ ],
45
+ allow_flagging=False
46
+ )
47
+
48
+ if __name__ == "__main__":
49
+ interface_image.launch()
cifar_net.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a39b10e26fc76a2d0f097dfd792ff7dac5a7c79ecbb1732017a726f3c1fafdc5
3
+ size 251167
requirements.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ # Ultralytics requirements
2
+ # Usage: pip install -r requirements.txt
3
+
4
+ # Base ----------------------------------------
5
+ torch==1.9.0
6
+ torchvision==0.10.0
7
+ gradio==2.3.5
8
+ Pillow==8.2.0