---
license: apache-2.0
---
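This is a simple AI-generated image detection model: a Vision Transformer (ViT) classifier, built on the `google/vit-base-patch16-224` architecture and trained on the CIFAKE dataset, that labels an image as `REAL` or `FAKE` (AI-generated).

Example usage (the trained weights file `vit_model.pth` is expected in the current working directory):

```python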
import torch
from PIL import Image
from torchvision import transforms
from transformers import ViTForImageClassification, ViTImageProcessor

# Load the trained model.
# The google/vit-base-patch16-224 backbone is instantiated first. Its
# classifier head is then replaced by a 2-way linear layer, so the parameter
# shapes match the trained checkpoint (labels: 0 = FAKE, 1 = REAL).
model_path = 'vit_model.pth'
model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224')
model.classifier = torch.nn.Linear(model.classifier.in_features, 2)

# torch.load unpickles its input. weights_only=True restricts it to tensors and
# plain containers, which is all a state dict needs, so a tampered checkpoint
# file cannot execute arbitrary code while loading.
model.load_state_dict(
    torch.load(model_path, map_location=torch.device('cpu'), weights_only=True)
)

# Put the model in evaluation mode before running inference.
model.eval()

# Image preprocessing pipeline: resize every image to the 224x224 input size
# of the google/vit-base-patch16-224 backbone, then convert the PIL image to a
# float tensor (ToTensor yields channels-first values scaled to [0, 1]).
# NOTE(review): no mean/std normalization is applied here, whereas the
# Hugging Face ViTImageProcessor for this backbone normalizes pixels. Keep this
# pipeline identical to whatever was used when training vit_model.pth —
# confirm before adding a Normalize step.
preprocess = transforms.Compose(
    [transforms.Resize((224, 224)), transforms.ToTensor()]
)

def predict(image_path, model, preprocess):
    """Classify a single image as 'FAKE' or 'REAL'.

    Args:
        image_path: Path of the image to classify (anything ``Image.open``
            accepts).
        model: Two-class image classifier whose output exposes ``.logits``,
            e.g. the ViT model loaded above.
        preprocess: Callable turning an RGB PIL image into a tensor without a
            batch dimension; the batch dimension is added here.

    Returns:
        The predicted class name: ``'FAKE'`` for label 0, ``'REAL'`` for label 1.
    """
    # Load and preprocess the image. The ``with`` block closes the file handle
    # that Image.open keeps open. ``convert`` returns a new, fully loaded image,
    # so it stays usable after the file is closed.
    with Image.open(image_path) as img:
        image = img.convert('RGB')
    inputs = preprocess(image).unsqueeze(0)  # add a batch dimension of size 1

    # Perform inference without tracking gradients.
    with torch.no_grad():
        logits = model(inputs).logits

    # Highest-scoring class for the single image in the batch (row 0).
    predicted_label = logits.argmax(dim=-1)[0].item()

    # Map the predicted label to the corresponding class.
    label_map = {0: 'FAKE', 1: 'REAL'}
    return label_map[predicted_label]


# Example usage: classify a few images and print each prediction followed by
# the image's path. The entries below are placeholders; point them at real
# image files before running.
image_paths = [
    'path/to/image.jpg',
    'path/to/image.jpg',
    'path/to/image.jpg',
]

for image_path in image_paths:
    predicted_class = predict(image_path, model, preprocess)
    print(f'Predicted class: {predicted_class}', image_path)
```