---
license: apache-2.0
---

This is a simple detector for AI-generated images, built on a Vision Transformer (ViT).
It starts from the `google/vit-base-patch16-224` checkpoint, replaces its classification
head with a two-class head (`FAKE` / `REAL`), and is fine-tuned on the CIFAKE dataset.

## Example usage

The example expects the fine-tuned weights file (`trained_modelBEST.pth`) at `model_path`.
The preprocessing pipeline below only resizes the image and converts it to a tensor; it does
not apply any mean/std normalization. Always use the same preprocessing that was used during
training.

```python
import torch
from PIL import Image
from torchvision import transforms
from transformers import ViTForImageClassification

# Load the fine-tuned model: start from the ViT base checkpoint, swap its
# classification head for a 2-class head, then load the trained weights.
model_path = 'trained_modelBEST.pth'
model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224')
model.classifier = torch.nn.Linear(model.classifier.in_features, 2)
# weights_only=True restricts unpickling to tensors and plain containers, so a
# tampered checkpoint cannot execute arbitrary code (requires PyTorch >= 1.13).
state_dict = torch.load(model_path, map_location=torch.device('cpu'), weights_only=True)
model.load_state_dict(state_dict)
model.eval()

# Image preprocessing: must match the pipeline used during training.
preprocess = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

# Output index of the 2-class head -> class name.
LABEL_MAP = {0: 'FAKE', 1: 'REAL'}


def predict(image_path, model, preprocess):
    """Classify a single image file as 'FAKE' or 'REAL'.

    Args:
        image_path: Path to the image file.
        model: The 2-class classifier loaded above (call ``model.eval()`` first).
        preprocess: Transform that turns a PIL image into a (C, H, W) tensor.

    Returns:
        The predicted class name: 'FAKE' or 'REAL'.
    """
    # Open inside a context manager so the file handle is released promptly,
    # even when classifying many images in a loop.
    with Image.open(image_path) as img:
        image = img.convert('RGB')
    inputs = preprocess(image).unsqueeze(0)  # add a batch dimension of 1

    with torch.no_grad():
        logits = model(inputs).logits
    # Reduce over the class axis explicitly, not over the flattened tensor.
    predicted_label = logits.argmax(dim=-1).item()
    return LABEL_MAP[predicted_label]


# Example usage
image_paths = [
    'path/to/real/image.jpg',
    'path/to/fake/image.jpg',
    'path/to/reddit/image.jpg',
]

for image_path in image_paths:
    predicted_class = predict(image_path, model, preprocess)
    print(f'{image_path}: {predicted_class}')
```