FredZhang7 commited on
Commit
f75d25f
1 Parent(s): f6d506f

Upload model

Browse files
Files changed (4) hide show
  1. Config.py +28 -0
  2. Model.py +117 -0
  3. config.json +40 -0
  4. pytorch_model.bin +3 -0
Config.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import PretrainedConfig
2
+ from typing import List
3
+ import warnings
4
+ warnings.filterwarnings("ignore")
5
+
6
+ class InceptionV3Config(PretrainedConfig):
7
+ model_type = "inceptionv3"
8
+ def __init__(self, model_name: str = "inception_v3", num_classes: int = 3, input_size: List[int] = [3, 299, 299], interpolation: str = "bicubic", mean: List[float] = [0.5, 0.5, 0.5], std: List[float] = [0.5, 0.5, 0.5], classifier: str = "fc", has_aux: bool = True, label_offset: int = 1, classes: dict = { '0': 'nsfw_gore', '1': 'nsfw_suggestive', '2': 'safe' }, output_channels: int = 2048, use_jit=False, **kwargs):
9
+ self.model_name = model_name
10
+ self.num_classes = num_classes
11
+ self.input_size = input_size
12
+ self.interpolation = interpolation
13
+ self.mean = mean
14
+ self.std = std
15
+ self.classifier = classifier
16
+ self.has_aux = has_aux
17
+ self.label_offset = label_offset
18
+ self.classes = classes
19
+ self.output_channels = output_channels
20
+ self.use_jit = use_jit
21
+ super().__init__(**kwargs)
22
+
23
+ """
24
+
25
+ inceptionv3_config = InceptionV3Config()
26
+ inceptionv3_config.save_pretrained("inceptionv3_config")
27
+
28
+ """
Model.py ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import PreTrainedModel
2
+ import torch
3
+ import os
4
+
5
+ class InceptionV3ModelForImageClassification(PreTrainedModel):
6
+ def __init__(self, config):
7
+ super().__init__(config)
8
+
9
+ model_path = "google-safesearch-mini.bin"
10
+
11
+ if self.config.model_name == "google-safesearch-mini":
12
+ if not os.path.exists(model_path):
13
+ import requests
14
+ url = "https://huggingface.co/FredZhang7/google-safesearch-mini/resolve/main/pytorch_model.bin"
15
+ r = requests.get(url, allow_redirects=True)
16
+ open(model_path, 'wb').write(r.content)
17
+ self.model = torch.jit.load(model_path)
18
+ else:
19
+ raise ValueError(f"Model {self.config.model_name} not found.")
20
+
21
+ def forward(self, input_ids):
22
+ return self.model(input_ids), None if self.config.model_name == "inception_v3" else self.model(input_ids)
23
+
24
+ def freeze(self):
25
+ for param in self.model.parameters():
26
+ param.requires_grad = False
27
+
28
+ def unfreeze(self):
29
+ for param in self.model.parameters():
30
+ param.requires_grad = True
31
+
32
+ def train(self, mode=True):
33
+ super().train(mode)
34
+ self.model.train(mode)
35
+
36
+ def eval(self):
37
+ return self.train(False)
38
+
39
+ def to(self, device):
40
+ self.model.to(device)
41
+ return self
42
+
43
+ def cuda(self, device=None):
44
+ return self.to("cuda")
45
+
46
+ def cpu(self):
47
+ return self.to("cpu")
48
+
49
+ def state_dict(self, destination=None, prefix='', keep_vars=False):
50
+ return self.model.state_dict(destination, prefix, keep_vars)
51
+
52
+ def load_state_dict(self, state_dict, strict=True):
53
+ return self.model.load_state_dict(state_dict, strict)
54
+
55
+ def parameters(self, recurse=True):
56
+ return self.model.parameters(recurse)
57
+
58
+ def named_parameters(self, prefix='', recurse=True):
59
+ return self.model.named_parameters(prefix, recurse)
60
+
61
+ def children(self):
62
+ return self.model.children()
63
+
64
+ def named_children(self):
65
+ return self.model.named_children()
66
+
67
+ def modules(self):
68
+ return self.model.modules()
69
+
70
+ def named_modules(self, memo=None, prefix=''):
71
+ return self.model.named_modules(memo, prefix)
72
+
73
+ def zero_grad(self, set_to_none=False):
74
+ return self.model.zero_grad(set_to_none)
75
+
76
+ def share_memory(self):
77
+ return self.model.share_memory()
78
+
79
+ def transform(self, image):
80
+ from torchvision import transforms
81
+ transform = transforms.Compose([
82
+ transforms.Resize(299),
83
+ transforms.ToTensor(),
84
+ transforms.Normalize(mean=self.config.mean, std=self.config.std)
85
+ ])
86
+ image = transform(image)
87
+ return image
88
+
89
+ def open_image(self, path):
90
+ from PIL import Image
91
+ path = 'https://images.unsplash.com/photo-1594568284297-7c64464062b1'
92
+ if path.startswith('http://') or path.startswith('https://'):
93
+ import requests
94
+ from io import BytesIO
95
+ response = requests.get(path)
96
+ image = Image.open(BytesIO(response.content)).convert('RGB')
97
+ else:
98
+ image = Image.open(path).convert('RGB')
99
+ return image
100
+
101
+ def predict(self, path, device="cuda", print_tensor=True):
102
+ image = self.open_image(path)
103
+ image = self.transform(image)
104
+ image = image.unsqueeze(0)
105
+
106
+ if device == "cuda":
107
+ image = image.cuda()
108
+ self.cuda()
109
+ else:
110
+ image = image.cpu()
111
+ self.cpu()
112
+ with torch.no_grad():
113
+ out, aux = self(image)
114
+ if print_tensor:
115
+ print(out)
116
+ _, predicted = torch.max(out.logits, 1)
117
+ return self.config.classes[str(predicted.item())]
config.json ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "InceptionV3ModelForImageClassification"
4
+ ],
5
+ "auto_map": {
6
+ "AutoConfig": "Config.InceptionV3Config",
7
+ "AutoModelForImageClassification": "Model.InceptionV3ModelForImageClassification"
8
+ },
9
+ "classes": {
10
+ "0": "nsfw_gore",
11
+ "1": "nsfw_suggestive",
12
+ "2": "safe"
13
+ },
14
+ "classifier": "fc",
15
+ "has_aux": true,
16
+ "input_size": [
17
+ 3,
18
+ 299,
19
+ 299
20
+ ],
21
+ "interpolation": "bicubic",
22
+ "label_offset": 1,
23
+ "mean": [
24
+ 0.5,
25
+ 0.5,
26
+ 0.5
27
+ ],
28
+ "model_name": "google-safesearch-mini",
29
+ "model_type": "inceptionv3",
30
+ "num_classes": 3,
31
+ "output_channels": 2048,
32
+ "std": [
33
+ 0.5,
34
+ 0.5,
35
+ 0.5
36
+ ],
37
+ "torch_dtype": "float32",
38
+ "transformers_version": "4.25.1",
39
+ "use_jit": true
40
+ }
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:db510376e428b0d5f1472e4f56d31a4bfbee69b3e8a58c67a802098e00d42d12
3
+ size 100804217