jacoballessio commited on
Commit
457c10a
1 Parent(s): 3d1751b

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +47 -0
README.md CHANGED
@@ -2,3 +2,50 @@
2
  license: apache-2.0
3
  ---
4
  This is a simple AI image detection model utilizing visual transformers trained on the CIFake dataset.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
  license: apache-2.0
3
  ---
4
  This is a simple AI image detection model utilizing visual transformers trained on the CIFake dataset.
5
+
6
+ Example usage:
7
+ ```
8
+ import torch
9
+ from PIL import Image
10
+ from torchvision import transforms
11
+ from transformers import ViTForImageClassification, ViTImageProcessor
12
+
13
+ # Load the trained model
14
+ model_path = 'trained_modelBEST.pth'
15
+ model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224')
16
+ model.classifier = torch.nn.Linear(model.classifier.in_features, 2)
17
+ model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
18
+ model.eval()
19
+
20
+ # Define the image preprocessing pipeline
21
+ preprocess = transforms.Compose([
22
+ transforms.Resize((224, 224)),
23
+ transforms.ToTensor(),
24
+ ])
25
+
26
+ def predict(image_path, model, preprocess):
27
+ # Load and preprocess the image
28
+ image = Image.open(image_path).convert('RGB')
29
+ inputs = preprocess(image).unsqueeze(0)
30
+
31
+ # Perform inference
32
+ with torch.no_grad():
33
+ outputs = model(inputs).logits
34
+ predicted_label = torch.argmax(outputs).item()
35
+
36
+ # Map the predicted label to the corresponding class
37
+ label_map = {0: 'FAKE', 1: 'REAL'}
38
+ predicted_class = label_map[predicted_label]
39
+ return predicted_class
40
+
41
+ # Example usage
42
+ image_paths = [
43
+ 'path/to/real/image.jpg',
44
+ 'path/to/fake/image.jpg',
45
+ 'path/to/reddit/image.jpg'
46
+ ]
47
+
48
+ for image_path in image_paths:
49
+ predicted_class = predict(image_path, model, preprocess)
50
+ print(f'Predicted class: {predicted_class}', image_path)
51
+ ```