Noah-Wang commited on
Commit
6737fe8
1 Parent(s): 1284b56

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +149 -15
handler.py CHANGED
@@ -1,19 +1,80 @@
1
- from typing import Dict, List, Any
2
  import timm
3
  import torch
 
4
  from timm.utils import ParseKwargs
5
  from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, DEFAULT_CROP_PCT
6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
  class EndpointHandler():
8
  def __init__(self, path=""):
9
  self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
10
- # self.aiGeneratorModel = timm.create_model('eva02_base_patch14_448.mim_in22k_ft_in22k_in1k', num_classes=9, in_chans=3, checkpoint_path=path + '/AIModelDetector.pth-6ff3631e.pth')
11
- aiArtModel = timm.create_model('eva02_base_patch14_448.mim_in22k_ft_in22k_in1k', num_classes=3, in_chans=3, checkpoint_path=path + '/AIArtDetector.pth-af59f7fa.pth')
12
- # aiGeneratorModel = aiGeneratorModel.to(self.device)
13
- aiArtModel = aiArtModel.to(self.device)
14
- # aiGeneratorModel.eval()
15
- aiArtModel.eval()
 
 
 
 
 
 
 
 
 
 
 
 
16
 
 
17
  self.transform = timm.data.create_transform(input_size=(3, 448, 448),
18
  is_training=False,
19
  use_prefetcher=False,
@@ -36,22 +97,95 @@ class EndpointHandler():
36
  crop_mode='squash',
37
  tf_preprocessing=False,
38
  separate=False)
 
 
 
 
39
 
40
- def __call__(self, data):
 
 
 
 
41
  """
42
  data args:
43
  inputs: Dict[str, Any]
44
  Return:
45
  A :obj:`list` | `dict`: will be serialized and returned
46
  """
47
- # get inputs
48
- image = data.pop("inputs", data)
49
-
50
- image_tensor1 = self.transform(image).to(self.device)
51
- with torch.no_grad():
52
- output1 = self.aiArtModel(image_tensor1.unsqueeze(0))
53
 
54
- return output1
 
55
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
 
 
 
 
 
 
 
 
 
 
 
 
57
 
 
 
 
1
+
2
  import timm
3
  import torch
4
+ from PIL import Image
5
  from timm.utils import ParseKwargs
6
  from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, DEFAULT_CROP_PCT
7
 
8
+ ###
9
+
10
+ import os
11
+ import time
12
+ from contextlib import suppress
13
+ from functools import partial
14
+
15
+ import numpy as np
16
+ import pandas as pd
17
+ import torch
18
+
19
+ from timm.data import create_dataset, create_loader, resolve_data_config, ImageNetInfo, infer_imagenet_subset
20
+ from timm.layers import apply_test_time_pool
21
+ from timm.models import create_model
22
+ from timm.utils import AverageMeter, setup_default_logging, set_jit_fuser, ParseKwargs
23
+
24
+ try:
25
+ from apex import amp
26
+ has_apex = True
27
+ except ImportError:
28
+ has_apex = False
29
+
30
+ has_native_amp = False
31
+ try:
32
+ if getattr(torch.cuda.amp, 'autocast') is not None:
33
+ has_native_amp = True
34
+ except AttributeError:
35
+ pass
36
+
37
+ # try:
38
+ # from functorch.compile import memory_efficient_fusion
39
+ # has_functorch = True
40
+ # except ImportError as e:
41
+ # has_functorch = False
42
+
43
+ has_compile = hasattr(torch, 'compile')
44
+
45
+ import PIL
46
+ import requests
47
+ import io
48
+ import base64
49
+
50
+
51
+
52
+ # ImageFile.LOAD_TRUNCATED_IMAGES = True
53
+ ###
54
+
55
  class EndpointHandler():
56
  def __init__(self, path=""):
57
  self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
58
+
59
+ if torch.cuda.is_available():
60
+ torch.backends.cuda.matmul.allow_tf32 = True
61
+ torch.backends.cudnn.benchmark = True
62
+
63
+ # May sacrifice a bit of accuracy, depending on our needs
64
+ assert has_native_amp, 'Please update PyTorch to a version with native AMP (or use APEX).'
65
+ amp_dtype = torch.float16
66
+ amp_autocast = partial(torch.autocast, device_type=self.device.type, dtype=amp_dtype)
67
+
68
+ # data_config = resolve_data_config(vars(args), model=model)
69
+
70
+ # self.aiGeneratorModel = timm.create_model('eva02_base_patch14_448.mim_in22k_ft_in22k_in1k', num_classes=9, in_chans=3, checkpoint_path=path + 'AIModelDetector.pth-6ff3631e.pth')
71
+ self.aiArtModel = timm.create_model('eva02_base_patch14_448.mim_in22k_ft_in22k_in1k', num_classes=3, in_chans=3, checkpoint_path=path + 'AIArtDetector.pth-af59f7fa.pth')
72
+ # self.aiGeneratorModel = self.aiGeneratorModel.to(self.device)
73
+ self.aiArtModel = self.aiArtModel.to(self.device)
74
+ # self.aiGeneratorModel.eval()
75
+ self.aiArtModel.eval()
76
 
77
+
78
  self.transform = timm.data.create_transform(input_size=(3, 448, 448),
79
  is_training=False,
80
  use_prefetcher=False,
 
97
  crop_mode='squash',
98
  tf_preprocessing=False,
99
  separate=False)
100
+
101
+ # assert has_compile, 'A version of torch w/ torch.compile() is required for --compile, possibly a nightly.'
102
+ # torch._dynamo.reset()
103
+ # model = torch.compile(model, backend=args.torchcompile)
104
 
105
+ self.supported_formats = ["JPEG", "PNG", "BMP", "TIFF", "WEBP", "RAW"] #GIF requires its own special implementation to get its frames
106
+ print("initialized handler.py successfully")
107
+ # self.label_map = {0: 'Dall-E 2', 1: 'DiscoDiff', 2: 'Midjourney', 3: 'NightCafe', 4: 'NovelAI', 5: 'Stable Diffusion', 6: 'StarryAI', 7: 'WomboDream', 8: 'Artbreeder'}
108
+
109
+ def __call__(self, data):
110
  """
111
  data args:
112
  inputs: Dict[str, Any]
113
  Return:
114
  A :obj:`list` | `dict`: will be serialized and returned
115
  """
116
+ inputs = data.pop("inputs")
117
+ if len(inputs) > 50:
118
+ return {'error': 'Exceeds max limit of images (50)'}
 
 
 
119
 
120
+ image_paths = inputs #['https://google_image.png', '']
121
+ batch_size = 1 # Set your desired batch size
122
 
123
+ results = {}
124
+ for i in range(0, len(image_paths), batch_size): # For each batch
125
+
126
+ batch_paths = image_paths[i:i+batch_size]
127
+ validUrls = []
128
+ batch_images = []
129
+
130
+ for j, src in enumerate(batch_paths): # Get all valid images open and inputted in batch_images
131
+ try:
132
+ # Image.open(batch_paths[j]).load() # Tests if image is okay to run inference on.
133
+ pos = src.find("base64")
134
+ if pos != -1:
135
+ # Assuming base64_str is the string value without 'data:image/jpeg;base64,'
136
+ new = Image.open(io.BytesIO(base64.decodebytes(bytes(src[pos+7:], "utf-8")))).convert("RGB")
137
+ # new.load() Necessary? Does this catch any edge cases? Without this, we don't actually load the image pixels.
138
+ batch_images.append(new)
139
+ validUrls.append(src)
140
+ else:
141
+ try:
142
+ # r = requests.get(src, stream=True)
143
+ # r.raw.decode_content = True
144
+ # new = Image.open(r.raw).convert("RGB")
145
+ # new = Image.open(urlopen(src))
146
+ headers = {
147
+ 'user-agent': 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/115.0.0.0 Safari/537.36'
148
+ }
149
+
150
+ r = requests.get(src, headers=headers)
151
+ new = Image.open(io.BytesIO(r.content)).convert("RGB")
152
+ # new.load()
153
+ batch_images.append(new)
154
+ validUrls.append(src)
155
+ except Exception as e:
156
+ results[src] = {'error': 'Failed to process image'}
157
+ # invalid_indices.append(j)
158
+ continue
159
+ # batch_images.append(batch_paths[j])
160
+
161
+ except Exception as e:
162
+ results[src] = {'error': 'Failed to process image w/ base64 in url'}
163
+ continue
164
+
165
+ # width, height = new.size
166
+
167
+ # if (width < 250 or height < 250) and len(request.data['srcs']) == 1:
168
+ # res['error'] = 'Please use a higher quality image'
169
+ # return JsonResponse(res, safe=False, status=status.HTTP_400_BAD_REQUEST)
170
+
171
+ batch_tensors = torch.stack([self.transform(img).to(torch.device('cuda' if torch.cuda.is_available() else 'cpu')) for img in batch_images])
172
+ # batch_tensors = torch.unsqueeze(batch_tensors, 0)
173
+
174
+
175
+ # batch_images = [Image.open(path) for path in batch_paths]
176
+ # batch_tensors = torch.stack([preprocess(img) for img in batch_images])
177
 
178
+ with torch.no_grad():
179
+ output1 = self.aiGeneratorModel(batch_tensors)
180
+ for k, tensor in enumerate(output1):
181
+ output = tensor.softmax(-1)
182
+ output, indice = output.topk(9)
183
+ labels = [self.label_map[x] for x in indice.cpu().numpy().tolist()]
184
+ probabilities = [round(i * 100, 2) for i in output.cpu().numpy().tolist()]
185
+ single_res = {'prob': probabilities, 'indices': labels}
186
+ results[validUrls[k]] = single_res
187
+
188
+ return results
189
 
190
+ # handler = EndpointHandler()
191
+ # handler.__call__({'inputs': ['']})