FredZhang7 commited on
Commit
42aff3a
·
1 Parent(s): 2277b9e

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +19 -35
README.md CHANGED
@@ -8,7 +8,9 @@ pipeline_tag: image-classification
8
  ---
9
  # Google Safesearch Mini Model Card
10
 
11
- Initially, the training data consisted of 278,000 images, and the model achieved 99% training and test acc. Now, this model is trained on 2,220,000+ images scraped from Google Images, Reddit, Imgur, and Github.
 
 
12
  The InceptionV3 and Xception models have been fine-tuned to predict the likelihood of an image falling into one of three categories: nsfw_gore, nsfw_suggestive, and safe.
13
 
14
  After 20 epochs on PyTorch, the finetuned InceptionV3 model achieves 94% acc on both training and test data. After 3.3 epochs on Keras, the finetuned Xception model scores 94% acc on training set and 92% on test set.
@@ -23,64 +25,46 @@ The PyTorch model runs much slower with transformers, so downloading it external
23
  pip install --upgrade torchvision
24
  ```
25
  ```python
26
- import torch, os
 
27
  from PIL import Image
28
- import warnings
29
- warnings.filterwarnings("ignore")
30
 
31
  PATH_TO_IMAGE = 'https://images.unsplash.com/photo-1594568284297-7c64464062b1'
32
  USE_CUDA = False
33
 
 
34
  def download_model():
35
  print("Downloading google_safesearch_mini.bin...")
36
- import urllib.request
37
- url = "https://huggingface.co/FredZhang7/google-safesearch-mini/resolve/main/pytorch_model.bin"
38
- urllib.request.urlretrieve(url, "google_safesearch_mini.bin")
39
 
40
- def run():
41
  if not os.path.exists("google_safesearch_mini.bin"):
42
  download_model()
43
  model = torch.jit.load('./google_safesearch_mini.bin')
44
- if PATH_TO_IMAGE.startswith('http://') or PATH_TO_IMAGE.startswith('https://'):
45
- import requests
46
- from io import BytesIO
47
- response = requests.get(PATH_TO_IMAGE)
48
- img = Image.open(BytesIO(response.content)).convert('RGB')
49
- else:
50
- img = Image.open(PATH_TO_IMAGE).convert('RGB')
51
- from torchvision import transforms
52
- transform = transforms.Compose([
53
- transforms.Resize(299),
54
- transforms.ToTensor(),
55
- transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
56
- ])
57
- img = transform(img)
58
- img = img.unsqueeze(0)
59
  if USE_CUDA:
60
- img = img.cuda()
61
- model = model.cuda()
62
  else:
63
- img = img.cpu()
64
- model = model.cpu()
65
  model.eval()
66
  with torch.no_grad():
67
  out, _ = model(img)
68
  _, predicted = torch.max(out.data, 1)
69
- classes = {
70
- 0: 'nsfw_gore',
71
- 1: 'nsfw_suggestive',
72
- 2: 'safe'
73
- }
74
  # account for edge cases
75
- if predicted[0] != 2 and abs(out[0][2] - out[0][predicted[0]]) > 0.22:
76
  img = Image.new('RGB', image.size, color = (0, 255, 255))
77
  print("\033[93m" + "safe" + "\033[0m")
78
  else:
79
  print('\n\033[1;31m' + classes[predicted.item()] + '\033[0m' if predicted.item() != 2 else '\033[1;32m' + classes[predicted.item()] + '\033[0m\n')
80
 
81
  if __name__ == '__main__':
82
- torch.multiprocessing.freeze_support()
83
- run()
84
  ```
85
  Output Example:
86
  ![prediction](./output_example.png)
 
8
  ---
9
  # Google Safesearch Mini Model Card
10
 
11
+ <a href="https://huggingface.co/FredZhang7/google-safesearch-mini-v2"> <font size="4"> <bold> Version 2 is here! </bold> </font> </a>
12
+
13
+ This model is trained on 2,220,000+ images scraped from Google Images, Reddit, Imgur, and Github.
14
  The InceptionV3 and Xception models have been fine-tuned to predict the likelihood of an image falling into one of three categories: nsfw_gore, nsfw_suggestive, and safe.
15
 
16
  After 20 epochs on PyTorch, the finetuned InceptionV3 model achieves 94% acc on both training and test data. After 3.3 epochs on Keras, the finetuned Xception model scores 94% acc on training set and 92% on test set.
 
25
  pip install --upgrade torchvision
26
  ```
27
  ```python
28
+ import torch, os, warnings, requests
29
+ from io import BytesIO
30
  from PIL import Image
31
+ from urllib.request import urlretrieve
32
+ from torchvision import transforms
33
 
34
  PATH_TO_IMAGE = 'https://images.unsplash.com/photo-1594568284297-7c64464062b1'
35
  USE_CUDA = False
36
 
37
+ warnings.filterwarnings("ignore")
38
  def download_model():
39
  print("Downloading google_safesearch_mini.bin...")
40
+ urlretrieve("https://huggingface.co/FredZhang7/google-safesearch-mini/resolve/main/pytorch_model.bin", "google_safesearch_mini.bin")
 
 
41
 
42
+ def eval():
43
  if not os.path.exists("google_safesearch_mini.bin"):
44
  download_model()
45
  model = torch.jit.load('./google_safesearch_mini.bin')
46
+ img = Image.open(PATH_TO_IMAGE).convert('RGB') if not (PATH_TO_IMAGE.startswith('http://') or PATH_TO_IMAGE.startswith('https://')) else Image.open(BytesIO(requests.get(PATH_TO_IMAGE).content)).convert('RGB')
47
+ transform = transforms.Compose([transforms.Resize(299), transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
48
+ img = transform(img).unsqueeze(0)
 
 
 
 
 
 
 
 
 
 
 
 
49
  if USE_CUDA:
50
+ img, model = img.cuda(), model.cuda()
 
51
  else:
52
+ img, model = img.cpu(), model.cpu()
 
53
  model.eval()
54
  with torch.no_grad():
55
  out, _ = model(img)
56
  _, predicted = torch.max(out.data, 1)
57
+ classes = {0: 'nsfw_gore', 1: 'nsfw_suggestive', 2: 'safe'}
58
+
 
 
 
59
  # account for edge cases
60
+ if predicted[0] != 2 and abs(out[0][2] - out[0][predicted[0]]) > 0.20:
61
  img = Image.new('RGB', image.size, color = (0, 255, 255))
62
  print("\033[93m" + "safe" + "\033[0m")
63
  else:
64
  print('\n\033[1;31m' + classes[predicted.item()] + '\033[0m' if predicted.item() != 2 else '\033[1;32m' + classes[predicted.item()] + '\033[0m\n')
65
 
66
  if __name__ == '__main__':
67
+ eval()
 
68
  ```
69
  Output Example:
70
  ![prediction](./output_example.png)