FredZhang7
commited on
Commit
·
42aff3a
1
Parent(s):
2277b9e
Update README.md
Browse files
README.md
CHANGED
@@ -8,7 +8,9 @@ pipeline_tag: image-classification
|
|
8 |
---
|
9 |
# Google Safesearch Mini Model Card
|
10 |
|
11 |
-
|
|
|
|
|
12 |
The InceptionV3 and Xception models have been fine-tuned to predict the likelihood of an image falling into one of three categories: nsfw_gore, nsfw_suggestive, and safe.
|
13 |
|
14 |
After 20 epochs on PyTorch, the finetuned InceptionV3 model achieves 94% acc on both training and test data. After 3.3 epochs on Keras, the finetuned Xception model scores 94% acc on training set and 92% on test set.
|
@@ -23,64 +25,46 @@ The PyTorch model runs much slower with transformers, so downloading it external
|
|
23 |
pip install --upgrade torchvision
|
24 |
```
|
25 |
```python
|
26 |
-
import torch, os
|
|
|
27 |
from PIL import Image
|
28 |
-
import
|
29 |
-
|
30 |
|
31 |
PATH_TO_IMAGE = 'https://images.unsplash.com/photo-1594568284297-7c64464062b1'
|
32 |
USE_CUDA = False
|
33 |
|
|
|
34 |
def download_model():
|
35 |
print("Downloading google_safesearch_mini.bin...")
|
36 |
-
|
37 |
-
url = "https://huggingface.co/FredZhang7/google-safesearch-mini/resolve/main/pytorch_model.bin"
|
38 |
-
urllib.request.urlretrieve(url, "google_safesearch_mini.bin")
|
39 |
|
40 |
-
def
|
41 |
if not os.path.exists("google_safesearch_mini.bin"):
|
42 |
download_model()
|
43 |
model = torch.jit.load('./google_safesearch_mini.bin')
|
44 |
-
if PATH_TO_IMAGE.startswith('http://') or PATH_TO_IMAGE.startswith('https://')
|
45 |
-
|
46 |
-
|
47 |
-
response = requests.get(PATH_TO_IMAGE)
|
48 |
-
img = Image.open(BytesIO(response.content)).convert('RGB')
|
49 |
-
else:
|
50 |
-
img = Image.open(PATH_TO_IMAGE).convert('RGB')
|
51 |
-
from torchvision import transforms
|
52 |
-
transform = transforms.Compose([
|
53 |
-
transforms.Resize(299),
|
54 |
-
transforms.ToTensor(),
|
55 |
-
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
|
56 |
-
])
|
57 |
-
img = transform(img)
|
58 |
-
img = img.unsqueeze(0)
|
59 |
if USE_CUDA:
|
60 |
-
img = img.cuda()
|
61 |
-
model = model.cuda()
|
62 |
else:
|
63 |
-
img = img.cpu()
|
64 |
-
model = model.cpu()
|
65 |
model.eval()
|
66 |
with torch.no_grad():
|
67 |
out, _ = model(img)
|
68 |
_, predicted = torch.max(out.data, 1)
|
69 |
-
classes = {
|
70 |
-
|
71 |
-
1: 'nsfw_suggestive',
|
72 |
-
2: 'safe'
|
73 |
-
}
|
74 |
# account for edge cases
|
75 |
-
if predicted[0] != 2 and abs(out[0][2] - out[0][predicted[0]]) > 0.
|
76 |
img = Image.new('RGB', image.size, color = (0, 255, 255))
|
77 |
print("\033[93m" + "safe" + "\033[0m")
|
78 |
else:
|
79 |
print('\n\033[1;31m' + classes[predicted.item()] + '\033[0m' if predicted.item() != 2 else '\033[1;32m' + classes[predicted.item()] + '\033[0m\n')
|
80 |
|
81 |
if __name__ == '__main__':
|
82 |
-
|
83 |
-
run()
|
84 |
```
|
85 |
Output Example:
|
86 |
![prediction](./output_example.png)
|
|
|
8 |
---
|
9 |
# Google Safesearch Mini Model Card
|
10 |
|
11 |
+
<a href="https://huggingface.co/FredZhang7/google-safesearch-mini-v2"> <font size="4"> <bold> Version 2 is here! </bold> </font> </a>
|
12 |
+
|
13 |
+
This model is trained on 2,220,000+ images scraped from Google Images, Reddit, Imgur, and Github.
|
14 |
The InceptionV3 and Xception models have been fine-tuned to predict the likelihood of an image falling into one of three categories: nsfw_gore, nsfw_suggestive, and safe.
|
15 |
|
16 |
After 20 epochs on PyTorch, the finetuned InceptionV3 model achieves 94% acc on both training and test data. After 3.3 epochs on Keras, the finetuned Xception model scores 94% acc on training set and 92% on test set.
|
|
|
25 |
pip install --upgrade torchvision
|
26 |
```
|
27 |
```python
|
28 |
+
import torch, os, warnings, requests
|
29 |
+
from io import BytesIO
|
30 |
from PIL import Image
|
31 |
+
from urllib.request import urlretrieve
|
32 |
+
from torchvision import transforms
|
33 |
|
34 |
PATH_TO_IMAGE = 'https://images.unsplash.com/photo-1594568284297-7c64464062b1'
|
35 |
USE_CUDA = False
|
36 |
|
37 |
+
warnings.filterwarnings("ignore")
|
38 |
def download_model():
|
39 |
print("Downloading google_safesearch_mini.bin...")
|
40 |
+
urlretrieve("https://huggingface.co/FredZhang7/google-safesearch-mini/resolve/main/pytorch_model.bin", "google_safesearch_mini.bin")
|
|
|
|
|
41 |
|
42 |
+
def eval():
|
43 |
if not os.path.exists("google_safesearch_mini.bin"):
|
44 |
download_model()
|
45 |
model = torch.jit.load('./google_safesearch_mini.bin')
|
46 |
+
img = Image.open(PATH_TO_IMAGE).convert('RGB') if not (PATH_TO_IMAGE.startswith('http://') or PATH_TO_IMAGE.startswith('https://')) else Image.open(BytesIO(requests.get(PATH_TO_IMAGE).content)).convert('RGB')
|
47 |
+
transform = transforms.Compose([transforms.Resize(299), transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
|
48 |
+
img = transform(img).unsqueeze(0)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
49 |
if USE_CUDA:
|
50 |
+
img, model = img.cuda(), model.cuda()
|
|
|
51 |
else:
|
52 |
+
img, model = img.cpu(), model.cpu()
|
|
|
53 |
model.eval()
|
54 |
with torch.no_grad():
|
55 |
out, _ = model(img)
|
56 |
_, predicted = torch.max(out.data, 1)
|
57 |
+
classes = {0: 'nsfw_gore', 1: 'nsfw_suggestive', 2: 'safe'}
|
58 |
+
|
|
|
|
|
|
|
59 |
# account for edge cases
|
60 |
+
if predicted[0] != 2 and abs(out[0][2] - out[0][predicted[0]]) > 0.20:
|
61 |
img = Image.new('RGB', image.size, color = (0, 255, 255))
|
62 |
print("\033[93m" + "safe" + "\033[0m")
|
63 |
else:
|
64 |
print('\n\033[1;31m' + classes[predicted.item()] + '\033[0m' if predicted.item() != 2 else '\033[1;32m' + classes[predicted.item()] + '\033[0m\n')
|
65 |
|
66 |
if __name__ == '__main__':
|
67 |
+
eval()
|
|
|
68 |
```
|
69 |
Output Example:
|
70 |
![prediction](./output_example.png)
|