Knowledge Distillation for Computer Vision
知識の蒸留は、より大規模で複雑なモデル (教師) からより小規模で単純なモデル (生徒) に知識を伝達するために使用される手法です。あるモデルから別のモデルに知識を抽出するには、特定のタスク (この場合は画像分類) でトレーニングされた事前トレーニング済み教師モデルを取得し、画像分類でトレーニングされる生徒モデルをランダムに初期化します。次に、学生モデルをトレーニングして、その出力と教師の出力の差を最小限に抑え、動作を模倣します。これは Distilling the Knowledge in a Neural Network by Hinton et al で最初に導入されました。このガイドでは、タスク固有の知識の蒸留を行います。これには Beans データセット を使用します。
このガイドでは、微調整された ViT モデル (教師モデル) を抽出して MobileNet (学生モデル) 🤗 Transformers の Trainer API を使用します。
蒸留とプロセスの評価に必要なライブラリをインストールしましょう。
pip install transformers datasets accelerate tensorboard evaluate --upgrade
この例では、教師モデルとしてmerve/beans-vit-224
モデルを使用しています。これは、Bean データセットに基づいて微調整されたgoogle/vit-base-patch16-224-in21k
に基づく画像分類モデルです。このモデルをランダムに初期化された MobileNetV2 に抽出します。
次に、データセットをロードします。
from datasets import load_dataset
dataset = load_dataset("beans")
この場合、同じ解像度で同じ出力が返されるため、どちらのモデルの画像プロセッサも使用できます。 dataset
のmap()
メソッドを使用して、データセットのすべての分割に前処理を適用します。
from transformers import AutoImageProcessor
teacher_processor = AutoImageProcessor.from_pretrained("merve/beans-vit-224")
def process(examples):
processed_inputs = teacher_processor(examples["image"])
return processed_inputs
processed_datasets = dataset.map(process, batched=True)
基本的に、我々は生徒モデル(ランダムに初期化されたMobileNet)が教師モデル(微調整されたビジョン変換器)を模倣することを望む。これを実現するために、まず教師と生徒からロジット出力を得る。次に、それぞれのソフトターゲットの重要度を制御するパラメータtemperature
で分割する。lambda
と呼ばれるパラメータは蒸留ロスの重要度を量る。この例では、temperature=5
、lambda=0.5
とする。生徒と教師の間の発散を計算するために、Kullback-Leibler発散損失を使用します。2つのデータPとQが与えられたとき、KLダイバージェンスはQを使ってPを表現するためにどれだけの余分な情報が必要かを説明します。もし2つが同じであれば、QからPを説明するために必要な他の情報はないので、それらのKLダイバージェンスはゼロになります。
from transformers import TrainingArguments, Trainer
import torch
import torch.nn as nn
import torch.nn.functional as F
class ImageDistilTrainer(Trainer):
def __init__(self, *args, teacher_model=None, **kwargs):
super().__init__(*args, **kwargs)
self.teacher = teacher_model
self.student = student_model
self.loss_function = nn.KLDivLoss(reduction="batchmean")
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
self.teacher.to(device)
self.teacher.eval()
self.temperature = temperature
self.lambda_param = lambda_param
def compute_loss(self, student, inputs, return_outputs=False):
student_output = self.student(**inputs)
with torch.no_grad():
teacher_output = self.teacher(**inputs)
# Compute soft targets for teacher and student
soft_teacher = F.softmax(teacher_output.logits / self.temperature, dim=-1)
soft_student = F.log_softmax(student_output.logits / self.temperature, dim=-1)
# Compute the loss
distillation_loss = self.loss_function(soft_student, soft_teacher) * (self.temperature ** 2)
# Compute the true label loss
student_target_loss = student_output.loss
# Calculate final loss
loss = (1. - self.lambda_param) * student_target_loss + self.lambda_param * distillation_loss
return (loss, student_output) if return_outputs else loss
次に、Hugging Face Hub にログインして、trainer
を通じてモデルを Hugging Face Hub にプッシュできるようにします。
from huggingface_hub import notebook_login
notebook_login()
教師モデルと生徒モデルであるTrainingArguments
を設定しましょう。
from transformers import AutoModelForImageClassification, MobileNetV2Config, MobileNetV2ForImageClassification
training_args = TrainingArguments(
output_dir="my-awesome-model",
num_train_epochs=30,
fp16=True,
logging_dir=f"{repo_name}/logs",
logging_strategy="epoch",
eval_strategy="epoch",
save_strategy="epoch",
load_best_model_at_end=True,
metric_for_best_model="accuracy",
report_to="tensorboard",
push_to_hub=True,
hub_strategy="every_save",
hub_model_id=repo_name,
)
num_labels = len(processed_datasets["train"].features["labels"].names)
# initialize models
teacher_model = AutoModelForImageClassification.from_pretrained(
"merve/beans-vit-224",
num_labels=num_labels,
ignore_mismatched_sizes=True
)
# training MobileNetV2 from scratch
student_config = MobileNetV2Config()
student_config.num_labels = num_labels
student_model = MobileNetV2ForImageClassification(student_config)
compute_metrics
関数を使用して、テスト セットでモデルを評価できます。この関数は、トレーニング プロセス中にモデルのaccuracy
とf1
を計算するために使用されます。
import evaluate
import numpy as np
accuracy = evaluate.load("accuracy")
def compute_metrics(eval_pred):
predictions, labels = eval_pred
acc = accuracy.compute(references=labels, predictions=np.argmax(predictions, axis=1))
return {"accuracy": acc["accuracy"]}
定義したトレーニング引数を使用してTrainer
を初期化しましょう。データ照合装置も初期化します。
from transformers import DefaultDataCollator
data_collator = DefaultDataCollator()
trainer = ImageDistilTrainer(
student_model=student_model,
teacher_model=teacher_model,
training_args=training_args,
train_dataset=processed_datasets["train"],
eval_dataset=processed_datasets["validation"],
data_collator=data_collator,
tokenizer=teacher_extractor,
compute_metrics=compute_metrics,
temperature=5,
lambda_param=0.5
)
これでモデルをトレーニングできるようになりました。
trainer.train()
テスト セットでモデルを評価できます。
trainer.evaluate(processed_datasets["test"])
テスト セットでは、モデルの精度は 72% に達します。蒸留効率の健全性チェックを行うために、同じハイパーパラメータを使用して Bean データセットで MobileNet を最初からトレーニングし、テスト セットで 63% の精度を観察しました。読者の皆様には、さまざまな事前トレーニング済み教師モデル、学生アーキテクチャ、蒸留パラメータを試していただき、その結果を報告していただくようお勧めします。抽出されたモデルのトレーニング ログとチェックポイントは このリポジトリ にあり、最初からトレーニングされた MobileNetV2 はこの リポジトリ。
< > Update on GitHub