NP-NP commited on
Commit
a721254
·
verified ·
1 Parent(s): cb87e23

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +16 -19
README.md CHANGED
@@ -1,60 +1,56 @@
1
  ---
2
  library_name: transformers
3
  base_model:
4
- - google/vit-base-patch16-224
5
  ---
6
 
7
- # Model Card for Model ID
8
-
9
- <!-- Provide a quick summary of what the model is/does. -->
10
 
 
11
 
 
12
 
13
  ## Model Details
14
- VIT model used for Pokemon type classification, also used in my CS 310 Final Project
15
-
16
- ### Model Description
17
-
18
- <!-- Provide a longer summary of what this model is. -->
19
 
20
- This is the model card of a 🤗 transformers model that has been trained on Pokemon image dataset to classify the all 18 types of the Pokemon.
 
 
21
 
22
- - **Developed by:** Xianglu(Steven) Zhu
23
- - **Funded by [optional]:** [More Information Needed]
24
- - **Shared by [optional]:** [More Information Needed]
25
- - **Model type:** Vision Transformer for Regression
26
 
27
- ## How to Get Started with the Model
28
-
29
- Use the code below to get started with the model
30
 
 
31
  import torch
32
  from PIL import Image
33
  import torchvision.transforms as transforms
34
  from transformers import ViTForImageClassification, ViTFeatureExtractor
35
 
 
36
  hf_model = ViTForImageClassification.from_pretrained("NP-NP/pokemon_model")
37
  hf_feature_extractor = ViTFeatureExtractor.from_pretrained("NP-NP/pokemon_model")
38
 
 
39
  transform = transforms.Compose([
40
  transforms.Resize((224, 224)),
41
  transforms.ToTensor(),
42
  transforms.Normalize(mean=hf_feature_extractor.image_mean, std=hf_feature_extractor.image_std)
43
  ])
44
 
 
45
  labels_dict = {
46
  'Grass': 0, 'Fire': 1, 'Water': 2, 'Bug': 3, 'Normal': 4, 'Poison': 5, 'Electric': 6,
47
  'Ground': 7, 'Fairy': 8, 'Fighting': 9, 'Psychic': 10, 'Rock': 11, 'Ghost': 12,
48
  'Ice': 13, 'Dragon': 14, 'Dark': 15, 'Steel': 16, 'Flying': 17
49
  }
50
-
51
  idx_to_label = {v: k for k, v in labels_dict.items()}
52
 
 
53
  image_path = "cute-pikachu-flowers-pokemon-desktop-wallpaper.jpg"
54
  image = Image.open(image_path).convert("RGB")
55
-
56
  input_tensor = transform(image).unsqueeze(0) # shape: (1, 3, 224, 224)
57
 
 
58
  hf_model.eval()
59
  with torch.no_grad():
60
  outputs = hf_model(input_tensor)
@@ -63,6 +59,7 @@ with torch.no_grad():
63
 
64
  predicted_class = idx_to_label[predicted_class_idx]
65
  print("Predicted Pokémon type:", predicted_class)
 
66
 
67
 
68
 
 
1
  ---
2
  library_name: transformers
3
  base_model:
4
+ - google/vit-base-patch16-224
5
  ---
6
 
7
+ # Model Card for Pokémon Type Classification
 
 
8
 
9
+ This model leverages a Vision Transformer (ViT) to classify Pokémon images into 18 different types.
10
 
11
+ It was developed as part of the CS 310 Final Project and trained on a Pokémon image dataset.
12
 
13
  ## Model Details
 
 
 
 
 
14
 
15
+ - **Developer:** Xianglu (Steven) Zhu
16
+ - **Purpose:** Pokémon type classification
17
+ - **Model Type:** Vision Transformer (ViT) for image classification
18
 
19
+ ## Getting Started
 
 
 
20
 
21
+ Here’s how you can use the model for classification:
 
 
22
 
23
+ ```python
24
  import torch
25
  from PIL import Image
26
  import torchvision.transforms as transforms
27
  from transformers import ViTForImageClassification, ViTFeatureExtractor
28
 
29
+ # Load the pretrained model and feature extractor
30
  hf_model = ViTForImageClassification.from_pretrained("NP-NP/pokemon_model")
31
  hf_feature_extractor = ViTFeatureExtractor.from_pretrained("NP-NP/pokemon_model")
32
 
33
+ # Define preprocessing transformations
34
  transform = transforms.Compose([
35
  transforms.Resize((224, 224)),
36
  transforms.ToTensor(),
37
  transforms.Normalize(mean=hf_feature_extractor.image_mean, std=hf_feature_extractor.image_std)
38
  ])
39
 
40
+ # Mapping of labels to indices and vice versa
41
  labels_dict = {
42
  'Grass': 0, 'Fire': 1, 'Water': 2, 'Bug': 3, 'Normal': 4, 'Poison': 5, 'Electric': 6,
43
  'Ground': 7, 'Fairy': 8, 'Fighting': 9, 'Psychic': 10, 'Rock': 11, 'Ghost': 12,
44
  'Ice': 13, 'Dragon': 14, 'Dark': 15, 'Steel': 16, 'Flying': 17
45
  }
 
46
  idx_to_label = {v: k for k, v in labels_dict.items()}
47
 
48
+ # Load and preprocess the image
49
  image_path = "cute-pikachu-flowers-pokemon-desktop-wallpaper.jpg"
50
  image = Image.open(image_path).convert("RGB")
 
51
  input_tensor = transform(image).unsqueeze(0) # shape: (1, 3, 224, 224)
52
 
53
+ # Make a prediction
54
  hf_model.eval()
55
  with torch.no_grad():
56
  outputs = hf_model(input_tensor)
 
59
 
60
  predicted_class = idx_to_label[predicted_class_idx]
61
  print("Predicted Pokémon type:", predicted_class)
62
+ ```
63
 
64
 
65