In [1]:
!pip install transformers
!pip install datasets

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [2]:
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

In [3]:
# Load the pokedex scraped in web_scrape.ipynb; its unnamed leading column is the wiki index.
# MissingNo's "Bird" type is filtered out: special, but not special enough to break the rules.
pkmn = (
    pd.read_csv("pokemon.csv")
    .rename(columns={"Unnamed: 0": "wiki_index"})
    .loc[lambda frame: frame["primary_type"] != "Bird"]
)

In [4]:
# Keep only what the classifier needs: the target (`primary_type`) and the input text (`Notes`).
# Copied so that the column edits made to `lil` later do not touch `pkmn`.
lil = pkmn.loc[:, ["primary_type", "Notes"]].copy()

In [5]:
from datasets.dataset_dict import DatasetDict
from datasets import Dataset
import datasets

In [6]:
# Encode each primary type as an integer class id via a pandas categorical
# (categories are the sorted unique type names; codes index into them).
lil["primary_type"] = lil["primary_type"].astype("category")
lil["label"] = lil["primary_type"].cat.codes
# Model-facing frame: integer `label` plus the raw `text` that will be tokenized.
df = lil.loc[:, ["label", "Notes"]].rename(columns={"Notes": "text"})

In [7]:
# Class id <-> type name lookups, in the categorical's code order, for the model config.
id2label = dict(enumerate(lil["primary_type"].cat.categories))
label2id = {type_name: class_id for class_id, type_name in id2label.items()}

In [8]:
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# Pretrained checkpoint shared by the tokenizer and the classification model.
BASE_CHECKPOINT = "distilbert-base-uncased"

tokenizer = AutoTokenizer.from_pretrained(BASE_CHECKPOINT)
# Size the classification head from the data-derived label mapping instead of a
# hard-coded 18, so it always matches the number of primary types actually present.
# (The "newly initialized" warning for the classifier weights is expected.)
model = AutoModelForSequenceClassification.from_pretrained(
    BASE_CHECKPOINT,
    num_labels=len(id2label),
    id2label=id2label,
    label2id=label2id,
)

Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertForSequenceClassification: ['vocab_projector.bias', 'vocab_layer_norm.bias', 'vocab_layer_norm.weight', 'vocab_transform.weight', 'vocab_projector.weight', 'vocab_transform.bias']
- This IS expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['pre_classifier.weight', 'classifier.weight', 'pre_clas

In [10]:
# Fixed seed so the 70/30 split (and therefore the reported accuracy) is reproducible.
# NOTE: the split is not stratified, so rare primary types may be unevenly represented.
SEED = 42
train_df = df.sample(frac=0.7, random_state=SEED)
test_df = df.drop(index=train_df.index)

# from_pandas is the documented DataFrame -> Dataset path; preserve_index=False keeps the
# shuffled pandas index from becoming an extra "__index_level_0__" column.
train_dataset = Dataset.from_pandas(train_df, preserve_index=False)
test_dataset = Dataset.from_pandas(test_df, preserve_index=False)
my_dataset_dict = DatasetDict({"train": train_dataset, "test": test_dataset})

my_dataset_dict

DatasetDict({
    train: Dataset({
        features: ['label', 'text'],
        num_rows: 654
    })
    test: Dataset({
        features: ['label', 'text'],
        num_rows: 281
    })
})

In [11]:
def tokenize_function(examples):
    """Tokenize one batch of examples.

    Args:
        examples: batch mapping from ``Dataset.map(batched=True)``; its ``"text"``
            entry holds the strings to encode.

    Returns:
        The tokenizer's output for those strings, truncated and padded with
        ``padding="max_length"`` so every row has the same length.
    """
    batch_texts = examples["text"]
    return tokenizer(batch_texts, truncation=True, padding="max_length")


dataset = my_dataset_dict
# batched=True passes many rows per call, which is much faster for the tokenizer.
tokenized_datasets = dataset.map(tokenize_function, batched=True)

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

In [12]:
# Inspect the splits: tokenization adds `input_ids` and `attention_mask` alongside `label`/`text`.
tokenized_datasets

DatasetDict({
    train: Dataset({
        features: ['label', 'text', 'input_ids', 'attention_mask'],
        num_rows: 654
    })
    test: Dataset({
        features: ['label', 'text', 'input_ids', 'attention_mask'],
        num_rows: 281
    })
})

In [13]:
# Despite the "small_" prefix, these are the complete tokenized train/test splits, not subsamples.
small_train_dataset, small_eval_dataset = (
    tokenized_datasets[split] for split in ("train", "test")
)

In [14]:
from transformers import TrainingArguments

# The TrainingArguments actually used for fine-tuning (with per-epoch evaluation) are
# built next to the Trainer further down. A bare `TrainingArguments(output_dir=...)`
# here was dead code: it was overwritten before anything read it.

In [15]:
import numpy as np
from datasets import load_metric

# Accuracy metric from the `datasets` library.
# NOTE(review): `datasets.load_metric` is deprecated and was removed in datasets 3.0;
# on newer installs the replacement is `evaluate.load("accuracy")` from the `evaluate` package.
metric = load_metric("accuracy")


In [16]:
def compute_metrics(eval_pred):
    """Compute classification accuracy for ``Trainer`` evaluation.

    Self-contained: it no longer depends on a global metric object loaded through the
    deprecated ``datasets.load_metric``, and reports the same ``"accuracy"`` key.

    Args:
        eval_pred: ``(logits, labels)`` pair, e.g. a ``transformers.EvalPrediction``.
            Class scores lie on the last axis of ``logits``.

    Returns:
        ``{"accuracy": float}`` -- the fraction of argmax predictions equal to the
        labels, or NaN when the evaluation set is empty.
    """
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    labels = np.asarray(labels)
    if labels.size == 0:
        # Avoid numpy's "mean of empty slice" warning; there is nothing to score.
        return {"accuracy": float("nan")}
    return {"accuracy": float(np.mean(predictions == labels))}

In [17]:
from transformers import Trainer, TrainingArguments

# Evaluate on the held-out split at the end of every epoch so accuracy is logged per epoch.
# NOTE(review): newer transformers releases rename `evaluation_strategy` to `eval_strategy`.
training_args = TrainingArguments(
    output_dir="test_trainer",
    evaluation_strategy="epoch",
)

In [18]:
# Build the Trainer over the tokenized train/test splits; evaluation is scored by `compute_metrics`.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=small_train_dataset,
    eval_dataset=small_eval_dataset,
    compute_metrics=compute_metrics,
)

In [19]:
# Fine-tune the model; the per-epoch validation loss and accuracy appear in the log below.
trainer.train()

The following columns in the training set don't have a corresponding argument in `DistilBertForSequenceClassification.forward` and have been ignored: text. If text are not expected by `DistilBertForSequenceClassification.forward`,  you can safely ignore this message.
***** Running training *****
  Num examples = 654
  Num Epochs = 3
  Instantaneous batch size per device = 8
  Total train batch size (w. parallel, distributed & accumulation) = 8
  Gradient Accumulation steps = 1
  Total optimization steps = 246


Epoch,Training Loss,Validation Loss,Accuracy
1,No log,2.471577,0.274021
2,No log,2.191889,0.437722
3,No log,2.077948,0.47331


The following columns in the evaluation set don't have a corresponding argument in `DistilBertForSequenceClassification.forward` and have been ignored: text. If text are not expected by `DistilBertForSequenceClassification.forward`,  you can safely ignore this message.
***** Running Evaluation *****
  Num examples = 281
  Batch size = 8
The following columns in the evaluation set don't have a corresponding argument in `DistilBertForSequenceClassification.forward` and have been ignored: text. If text are not expected by `DistilBertForSequenceClassification.forward`,  you can safely ignore this message.
***** Running Evaluation *****
  Num examples = 281
  Batch size = 8
The following columns in the evaluation set don't have a corresponding argument in `DistilBertForSequenceClassification.forward` and have been ignored: text. If text are not expected by `DistilBertForSequenceClassification.forward`,  you can safely ignore this message.
***** Running Evaluation *****
  Num examples = 281


TrainOutput(global_step=246, training_loss=2.312268264894563, metrics={'train_runtime': 106.6333, 'train_samples_per_second': 18.4, 'train_steps_per_second': 2.307, 'total_flos': 259975195619328.0, 'train_loss': 2.312268264894563, 'epoch': 3.0})

In [20]:
# Persist the fine-tuned weights and config, plus the tokenizer, so that "./model" is a
# self-contained checkpoint that can be reloaded without the in-memory `tokenizer`.
model.save_pretrained("./model")
tokenizer.save_pretrained("./model")

Configuration saved in ./config.json
Model weights saved in ./pytorch_model.bin


In [21]:
# Reload the fine-tuned classifier from the directory written by `save_pretrained`.
# NOTE(review): the captured log below reads ./config.json rather than ./model/config.json;
# it looks stale from an earlier run with a different path -- re-run to confirm.
model2 = AutoModelForSequenceClassification.from_pretrained('./model')

loading configuration file ./config.json
Model config DistilBertConfig {
  "_name_or_path": ".",
  "activation": "gelu",
  "architectures": [
    "DistilBertForSequenceClassification"
  ],
  "attention_dropout": 0.1,
  "dim": 768,
  "dropout": 0.1,
  "hidden_dim": 3072,
  "id2label": {
    "0": "Bug",
    "1": "Dark",
    "2": "Dragon",
    "3": "Electric",
    "4": "Fairy",
    "5": "Fighting",
    "6": "Fire",
    "7": "Flying",
    "8": "Ghost",
    "9": "Grass",
    "10": "Ground",
    "11": "Ice",
    "12": "Normal",
    "13": "Poison",
    "14": "Psychic",
    "15": "Rock",
    "16": "Steel",
    "17": "Water"
  },
  "initializer_range": 0.02,
  "label2id": {
    "Bug": 0,
    "Dark": 1,
    "Dragon": 2,
    "Electric": 3,
    "Fairy": 4,
    "Fighting": 5,
    "Fire": 6,
    "Flying": 7,
    "Ghost": 8,
    "Grass": 9,
    "Ground": 10,
    "Ice": 11,
    "Normal": 12,
    "Poison": 13,
    "Psychic": 14,
    "Rock": 15,
    "Steel": 16,
    "Water": 17
  },
  "max_position_embe

In [22]:
from transformers import pipeline

# Text-classification pipeline over the reloaded model (moved to CPU) and the original tokenizer.
classifier = pipeline(
    task="text-classification",
    model=model2.to('cpu'),
    tokenizer=tokenizer,
)

In [41]:
# Probe with a free-text description; the pipeline returns the top label and its score.
classifier('This pokemon climbs buildings at night.')

[{'label': 'Bug', 'score': 0.17771221697330475}]

In [36]:
# Same description plus a water-themed clue ("pool parties").
classifier('This pokemon climbs buildings at night. They frequent midnight pool parties')

[{'label': 'Water', 'score': 0.4050225019454956}]

In [37]:
# Same description plus a plant-themed clue ("garden parties").
classifier('This pokemon climbs buildings at night. They frequent midnight garden parties')

[{'label': 'Grass', 'score': 0.38808730244636536}]

In [38]:
# Same description plus a fire-themed clue ("flame-throwing parties").
classifier('This pokemon climbs buildings at night. They frequent midnight flame-throwing parties')

[{'label': 'Fire', 'score': 0.22531799972057343}]