Add automatic download of imagenet_classes.txt
Browse files- app.py +18 -1
- requirements.txt +2 -1
app.py
CHANGED
|
@@ -1,7 +1,9 @@
|
|
|
|
|
| 1 |
import gradio as gr
|
| 2 |
import glob
|
| 3 |
import time
|
| 4 |
import random
|
|
|
|
| 5 |
|
| 6 |
# Import necessary libraries
|
| 7 |
from torchvision import models, transforms
|
|
@@ -20,8 +22,23 @@ transform = transforms.Compose([
|
|
| 20 |
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
|
| 21 |
])
|
| 22 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 23 |
# Load class labels
|
| 24 |
-
with open('imagenet_classes.txt') as f:
|
| 25 |
labels = [line.strip() for line in f.readlines()]
|
| 26 |
|
| 27 |
def classify_image(image):
|
|
|
|
| 1 |
+
import os
|
| 2 |
import gradio as gr
|
| 3 |
import glob
|
| 4 |
import time
|
| 5 |
import random
|
| 6 |
+
import requests
|
| 7 |
|
| 8 |
# Import necessary libraries
|
| 9 |
from torchvision import models, transforms
|
|
|
|
| 22 |
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
|
| 23 |
])
|
| 24 |
|
| 25 |
+
# Function to download imagenet_classes.txt
|
| 26 |
+
def download_imagenet_classes():
|
| 27 |
+
url = "https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt"
|
| 28 |
+
response = requests.get(url)
|
| 29 |
+
if response.status_code == 200:
|
| 30 |
+
with open("imagenet_classes.txt", "wb") as f:
|
| 31 |
+
f.write(response.content)
|
| 32 |
+
print("imagenet_classes.txt downloaded successfully.")
|
| 33 |
+
else:
|
| 34 |
+
print("Failed to download imagenet_classes.txt")
|
| 35 |
+
|
| 36 |
+
# Check if imagenet_classes.txt exists, if not, download it
|
| 37 |
+
if not os.path.exists("imagenet_classes.txt"):
|
| 38 |
+
download_imagenet_classes()
|
| 39 |
+
|
| 40 |
# Load class labels
|
| 41 |
+
with open('imagenet_classes.txt', 'r') as f:
|
| 42 |
labels = [line.strip() for line in f.readlines()]
|
| 43 |
|
| 44 |
def classify_image(image):
|
requirements.txt
CHANGED
|
@@ -2,4 +2,5 @@ gradio==3.23.0
|
|
| 2 |
Pillow==9.5.0
|
| 3 |
torch==1.13.1
|
| 4 |
torchvision==0.14.1
|
| 5 |
-
urllib3<1.27,>=1.25.4
|
|
|
|
|
|
| 2 |
Pillow==9.5.0
|
| 3 |
torch==1.13.1
|
| 4 |
torchvision==0.14.1
|
| 5 |
+
urllib3<1.27,>=1.25.4
|
| 6 |
+
requests==2.31.0
|