---
library_name: transformers
base_model:
- google/vit-base-patch16-224
---

# Model Card for Pokémon Type Classification

This model leverages a Vision Transformer (ViT) to classify Pokémon images into 18 different types. It was developed as part of the CS 310 Final Project and trained on a Pokémon image dataset.

## Model Details

- **Developer:** Xianglu (Steven) Zhu
- **Purpose:** Pokémon type classification
- **Model Type:** Vision Transformer (ViT) for image classification

## Getting Started

Here's how you can use the model for classification:

```python
import torch
from PIL import Image
import torchvision.transforms as transforms
from transformers import ViTForImageClassification, ViTFeatureExtractor

# Load the pretrained model and feature extractor
hf_model = ViTForImageClassification.from_pretrained("NP-NP/pokemon_model")
hf_feature_extractor = ViTFeatureExtractor.from_pretrained("NP-NP/pokemon_model")

# Define preprocessing transformations
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=hf_feature_extractor.image_mean,
                         std=hf_feature_extractor.image_std)
])

# Mapping of labels to indices and vice versa
labels_dict = {
    'Grass': 0, 'Fire': 1, 'Water': 2, 'Bug': 3, 'Normal': 4, 'Poison': 5,
    'Electric': 6, 'Ground': 7, 'Fairy': 8, 'Fighting': 9, 'Psychic': 10,
    'Rock': 11, 'Ghost': 12, 'Ice': 13, 'Dragon': 14, 'Dark': 15,
    'Steel': 16, 'Flying': 17
}
idx_to_label = {v: k for k, v in labels_dict.items()}

# Load and preprocess the image
image_path = "cute-pikachu-flowers-pokemon-desktop-wallpaper.jpg"
image = Image.open(image_path).convert("RGB")
input_tensor = transform(image).unsqueeze(0)  # shape: (1, 3, 224, 224)

# Make a prediction
hf_model.eval()
with torch.no_grad():
    outputs = hf_model(input_tensor)
    logits = outputs.logits
    predicted_class_idx = torch.argmax(logits, dim=1).item()

predicted_class = idx_to_label[predicted_class_idx]
print("Predicted Pokémon type:", predicted_class)
```
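
As an alternative to the manual torchvision pipeline above, the feature extractor loaded from the Hub can perform the resizing and normalization itself. The sketch below is a minimal variant, assuming the same `hf_model`, `hf_feature_extractor`, and `idx_to_label` objects from the snippet above; the image file name is only a placeholder.

```python
# Minimal sketch: let the feature extractor handle preprocessing.
# Assumes hf_model, hf_feature_extractor, and idx_to_label are already defined
# as in the example above; "my_pokemon.jpg" is a placeholder path.
import torch
from PIL import Image

image = Image.open("my_pokemon.jpg").convert("RGB")
inputs = hf_feature_extractor(images=image, return_tensors="pt")  # resize + normalize

hf_model.eval()
with torch.no_grad():
    logits = hf_model(**inputs).logits

predicted_class_idx = logits.argmax(dim=-1).item()
print("Predicted Pokémon type:", idx_to_label[predicted_class_idx])
```

If the model's config stores an `id2label` mapping, the same prediction could also be obtained through `transformers.pipeline("image-classification", model="NP-NP/pokemon_model")` without the manual label table; otherwise the explicit `labels_dict` above is needed to turn class indices into type names.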