FredZhang7 commited on
Commit
dc05e3b
1 Parent(s): b5475b8

Delete Model.py

Browse files
Files changed (1) hide show
  1. Model.py +0 -117
Model.py DELETED
@@ -1,117 +0,0 @@
1
- from transformers import PreTrainedModel
2
- import torch
3
- import os
4
-
5
- class InceptionV3ModelForImageClassification(PreTrainedModel):
6
- def __init__(self, config):
7
- super().__init__(config)
8
-
9
- model_path = "pytorch_model.bin"
10
-
11
- if self.config.model_name == "google-safesearch-mini":
12
- if not os.path.exists(model_path):
13
- import requests
14
- url = "https://huggingface.co/FredZhang7/google-safesearch-mini/resolve/main/pytorch_model.bin"
15
- r = requests.get(url, allow_redirects=True)
16
- open(model_path, 'wb').write(r.content)
17
- self.model = torch.jit.load(model_path)
18
- else:
19
- raise ValueError(f"Model {self.config.model_name} not found.")
20
-
21
- def forward(self, input_ids):
22
- return self.model(input_ids), None if self.config.model_name == "inception_v3" else self.model(input_ids)
23
-
24
- def freeze(self):
25
- for param in self.model.parameters():
26
- param.requires_grad = False
27
-
28
- def unfreeze(self):
29
- for param in self.model.parameters():
30
- param.requires_grad = True
31
-
32
- def train(self, mode=True):
33
- super().train(mode)
34
- self.model.train(mode)
35
-
36
- def eval(self):
37
- return self.train(False)
38
-
39
- def to(self, device):
40
- self.model.to(device)
41
- return self
42
-
43
- def cuda(self, device=None):
44
- return self.to("cuda")
45
-
46
- def cpu(self):
47
- return self.to("cpu")
48
-
49
- def state_dict(self, destination=None, prefix='', keep_vars=False):
50
- return self.model.state_dict(destination, prefix, keep_vars)
51
-
52
- def load_state_dict(self, state_dict, strict=True):
53
- return self.model.load_state_dict(state_dict, strict)
54
-
55
- def parameters(self, recurse=True):
56
- return self.model.parameters(recurse)
57
-
58
- def named_parameters(self, prefix='', recurse=True):
59
- return self.model.named_parameters(prefix, recurse)
60
-
61
- def children(self):
62
- return self.model.children()
63
-
64
- def named_children(self):
65
- return self.model.named_children()
66
-
67
- def modules(self):
68
- return self.model.modules()
69
-
70
- def named_modules(self, memo=None, prefix=''):
71
- return self.model.named_modules(memo, prefix)
72
-
73
- def zero_grad(self, set_to_none=False):
74
- return self.model.zero_grad(set_to_none)
75
-
76
- def share_memory(self):
77
- return self.model.share_memory()
78
-
79
- def transform(self, image):
80
- from torchvision import transforms
81
- transform = transforms.Compose([
82
- transforms.Resize(299),
83
- transforms.ToTensor(),
84
- transforms.Normalize(mean=self.config.mean, std=self.config.std)
85
- ])
86
- image = transform(image)
87
- return image
88
-
89
- def open_image(self, path):
90
- from PIL import Image
91
- path = 'https://images.unsplash.com/photo-1594568284297-7c64464062b1'
92
- if path.startswith('http://') or path.startswith('https://'):
93
- import requests
94
- from io import BytesIO
95
- response = requests.get(path)
96
- image = Image.open(BytesIO(response.content)).convert('RGB')
97
- else:
98
- image = Image.open(path).convert('RGB')
99
- return image
100
-
101
- def predict(self, path, device="cuda", print_tensor=True):
102
- image = self.open_image(path)
103
- image = self.transform(image)
104
- image = image.unsqueeze(0)
105
- self.eval()
106
- if device == "cuda":
107
- image = image.cuda()
108
- self.cuda()
109
- else:
110
- image = image.cpu()
111
- self.cpu()
112
- with torch.no_grad():
113
- out, aux = self(image)
114
- if print_tensor:
115
- print(out)
116
- _, predicted = torch.max(out.logits, 1)
117
- return self.config.classes[str(predicted.item())]