File size: 2,559 Bytes
6588ad6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import torch
import numpy
from PIL import Image
from torchvision.transforms import ToTensor
from transformers import ViTModel, ViTFeatureExtractor
from transformers.modeling_outputs import SequenceClassifierOutput
import torch.nn as nn

# Hub name of the pretrained ViT checkpoint; the same name is used for the
# backbone inside ViTForImageClassification, so the preprocessing matches it.
_VIT_CHECKPOINT = 'google/vit-large-patch32-384'
# Preprocessor instance shared by every predict_image call; created once at import.
feature_extractor = ViTFeatureExtractor.from_pretrained(_VIT_CHECKPOINT)
class ViTForImageClassification(nn.Module):
    """ViT backbone followed by dropout and a linear classification head.

    The head classifies from the hidden state of the first token of the
    backbone's ``last_hidden_state``.

    Args:
        num_labels: Number of output classes of the linear head.
        pretrained_model_name: Checkpoint name passed to
            ``ViTModel.from_pretrained``. Defaults to the checkpoint this
            script was written for, so existing callers are unaffected.
        dropout_prob: Dropout probability applied before the classifier.
    """

    def __init__(self, num_labels=2,
                 pretrained_model_name='google/vit-large-patch32-384',
                 dropout_prob=0.1):
        super().__init__()
        self.vit = ViTModel.from_pretrained(pretrained_model_name)
        self.dropout = nn.Dropout(dropout_prob)
        # Head width follows the backbone's configured hidden size.
        self.classifier = nn.Linear(self.vit.config.hidden_size, num_labels)
        self.num_labels = num_labels

    def forward(self, pixel_values, labels=None):
        """Run the backbone and classify each image.

        Args:
            pixel_values: Preprocessed image batch for the ViT backbone
                (presumably (batch, 3, H, W) as produced by the feature
                extractor — TODO confirm).
            labels: Optional class-index targets; flattened with ``view(-1)``
                before the cross-entropy loss.

        Returns:
            ``SequenceClassifierOutput(loss=..., logits=...)`` when ``labels``
            is given, otherwise the bare logits tensor (last dim
            ``num_labels``). NOTE: the two branches return different types;
            ``predict_image`` relies on the bare-tensor form.
        """
        outputs = self.vit(pixel_values=pixel_values)
        # Hidden state at token position 0 of every sequence in the batch.
        pooled = self.dropout(outputs.last_hidden_state[:, 0])
        logits = self.classifier(pooled)

        if labels is None:
            return logits

        loss_fct = nn.CrossEntropyLoss()
        loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
        return SequenceClassifierOutput(loss=loss, logits=logits)
      
def preprocess_image(image, desired_size=384):
    """Letterbox *image* into a ``desired_size`` x ``desired_size`` RGB square.

    The image is scaled so that its longer side equals ``desired_size``
    (aspect ratio preserved) and then centred on a white RGB canvas.

    Args:
        image: A PIL image.
        desired_size: Side length, in pixels, of the square output.

    Returns:
        A new ``desired_size`` x ``desired_size`` PIL image in mode "RGB".
    """
    width, height = image.size
    longest = max(width, height)

    # Exact integer scaling: a float ratio followed by int() truncation can
    # round the longest side down to desired_size - 1. Clamp each side to at
    # least 1 so very thin images never request a zero-sized resize, which
    # PIL rejects.
    new_size = (
        max(1, width * desired_size // longest),
        max(1, height * desired_size // longest),
    )
    resized = image.resize(new_size)

    canvas = Image.new("RGB", (desired_size, desired_size), "white")
    offset = ((desired_size - new_size[0]) // 2,
              (desired_size - new_size[1]) // 2)
    # NOTE(review): paste() converts a non-RGB source (e.g. RGBA) to RGB, so
    # transparent areas are not composited onto the white background —
    # confirm this matches how training images were prepared.
    canvas.paste(resized, offset)
    return canvas

def predict_image(image, model, feature_extractor):
    """Return softmax class probabilities for a single image.

    Args:
        image: The image to classify, passed directly to ``feature_extractor``
            (in this script, the PIL image returned by ``preprocess_image``).
        model: Module whose forward returns raw logits when called without
            labels. It is switched to eval mode.
        feature_extractor: Callable taking the image and
            ``return_tensors="pt"`` and returning a mapping with a
            ``"pixel_values"`` entry (e.g. a Hugging Face ViT feature
            extractor).

    Returns:
        NumPy array of probabilities from a softmax over dim 1, i.e. one row
        per image in the extractor's batch.
    """
    # Disable dropout and other training-only behavior.
    model.eval()

    # Give the extractor the image itself. Converting with ToTensor() first
    # yields floats already scaled to [0, 1], which an extractor with
    # do_rescale enabled would divide by 255 a second time.
    inputs = feature_extractor(image, return_tensors="pt")
    pixel_values = torch.as_tensor(inputs["pixel_values"])

    # Send inputs to wherever the model lives instead of hard-coding CUDA.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")
    pixel_values = pixel_values.to(device)

    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        logits = model(pixel_values)
        probabilities = torch.nn.functional.softmax(logits, dim=1)

    return probabilities.cpu().numpy()


# Build the classifier and load the fine-tuned weights. Loading onto the CPU
# first means a checkpoint saved from GPU tensors still loads on a machine
# without CUDA; the model is moved to the target device afterwards.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
model = ViTForImageClassification(num_labels=2)
model.load_state_dict(torch.load("./AID96k_E15_384.pth", map_location="cpu"))
model.to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))
model.eval()

# Close the source file once the letterboxed copy has been produced.
with Image.open("test.png") as source_img:
    img = preprocess_image(source_img)

probs = predict_image(img, model, feature_extractor)
# Column 0 is reported as "AI", column 1 as "Human".
print(f"AI: {probs[0][0]}")
print(f"Human: {probs[0][1]}")