FredZhang7
commited on
Commit
•
dc05e3b
1
Parent(s):
b5475b8
Delete Model.py
Browse files
Model.py
DELETED
@@ -1,117 +0,0 @@
|
|
1 |
-
from transformers import PreTrainedModel
|
2 |
-
import torch
|
3 |
-
import os
|
4 |
-
|
5 |
-
class InceptionV3ModelForImageClassification(PreTrainedModel):
|
6 |
-
def __init__(self, config):
|
7 |
-
super().__init__(config)
|
8 |
-
|
9 |
-
model_path = "pytorch_model.bin"
|
10 |
-
|
11 |
-
if self.config.model_name == "google-safesearch-mini":
|
12 |
-
if not os.path.exists(model_path):
|
13 |
-
import requests
|
14 |
-
url = "https://huggingface.co/FredZhang7/google-safesearch-mini/resolve/main/pytorch_model.bin"
|
15 |
-
r = requests.get(url, allow_redirects=True)
|
16 |
-
open(model_path, 'wb').write(r.content)
|
17 |
-
self.model = torch.jit.load(model_path)
|
18 |
-
else:
|
19 |
-
raise ValueError(f"Model {self.config.model_name} not found.")
|
20 |
-
|
21 |
-
def forward(self, input_ids):
|
22 |
-
return self.model(input_ids), None if self.config.model_name == "inception_v3" else self.model(input_ids)
|
23 |
-
|
24 |
-
def freeze(self):
|
25 |
-
for param in self.model.parameters():
|
26 |
-
param.requires_grad = False
|
27 |
-
|
28 |
-
def unfreeze(self):
|
29 |
-
for param in self.model.parameters():
|
30 |
-
param.requires_grad = True
|
31 |
-
|
32 |
-
def train(self, mode=True):
|
33 |
-
super().train(mode)
|
34 |
-
self.model.train(mode)
|
35 |
-
|
36 |
-
def eval(self):
|
37 |
-
return self.train(False)
|
38 |
-
|
39 |
-
def to(self, device):
|
40 |
-
self.model.to(device)
|
41 |
-
return self
|
42 |
-
|
43 |
-
def cuda(self, device=None):
|
44 |
-
return self.to("cuda")
|
45 |
-
|
46 |
-
def cpu(self):
|
47 |
-
return self.to("cpu")
|
48 |
-
|
49 |
-
def state_dict(self, destination=None, prefix='', keep_vars=False):
|
50 |
-
return self.model.state_dict(destination, prefix, keep_vars)
|
51 |
-
|
52 |
-
def load_state_dict(self, state_dict, strict=True):
|
53 |
-
return self.model.load_state_dict(state_dict, strict)
|
54 |
-
|
55 |
-
def parameters(self, recurse=True):
|
56 |
-
return self.model.parameters(recurse)
|
57 |
-
|
58 |
-
def named_parameters(self, prefix='', recurse=True):
|
59 |
-
return self.model.named_parameters(prefix, recurse)
|
60 |
-
|
61 |
-
def children(self):
|
62 |
-
return self.model.children()
|
63 |
-
|
64 |
-
def named_children(self):
|
65 |
-
return self.model.named_children()
|
66 |
-
|
67 |
-
def modules(self):
|
68 |
-
return self.model.modules()
|
69 |
-
|
70 |
-
def named_modules(self, memo=None, prefix=''):
|
71 |
-
return self.model.named_modules(memo, prefix)
|
72 |
-
|
73 |
-
def zero_grad(self, set_to_none=False):
|
74 |
-
return self.model.zero_grad(set_to_none)
|
75 |
-
|
76 |
-
def share_memory(self):
|
77 |
-
return self.model.share_memory()
|
78 |
-
|
79 |
-
def transform(self, image):
|
80 |
-
from torchvision import transforms
|
81 |
-
transform = transforms.Compose([
|
82 |
-
transforms.Resize(299),
|
83 |
-
transforms.ToTensor(),
|
84 |
-
transforms.Normalize(mean=self.config.mean, std=self.config.std)
|
85 |
-
])
|
86 |
-
image = transform(image)
|
87 |
-
return image
|
88 |
-
|
89 |
-
def open_image(self, path):
|
90 |
-
from PIL import Image
|
91 |
-
path = 'https://images.unsplash.com/photo-1594568284297-7c64464062b1'
|
92 |
-
if path.startswith('http://') or path.startswith('https://'):
|
93 |
-
import requests
|
94 |
-
from io import BytesIO
|
95 |
-
response = requests.get(path)
|
96 |
-
image = Image.open(BytesIO(response.content)).convert('RGB')
|
97 |
-
else:
|
98 |
-
image = Image.open(path).convert('RGB')
|
99 |
-
return image
|
100 |
-
|
101 |
-
def predict(self, path, device="cuda", print_tensor=True):
|
102 |
-
image = self.open_image(path)
|
103 |
-
image = self.transform(image)
|
104 |
-
image = image.unsqueeze(0)
|
105 |
-
self.eval()
|
106 |
-
if device == "cuda":
|
107 |
-
image = image.cuda()
|
108 |
-
self.cuda()
|
109 |
-
else:
|
110 |
-
image = image.cpu()
|
111 |
-
self.cpu()
|
112 |
-
with torch.no_grad():
|
113 |
-
out, aux = self(image)
|
114 |
-
if print_tensor:
|
115 |
-
print(out)
|
116 |
-
_, predicted = torch.max(out.logits, 1)
|
117 |
-
return self.config.classes[str(predicted.item())]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|