Retrieva, Inc. org

SDPA attention の追加

下記のようにすることで Attention 部分の処理が torch.matmul から torch の sdpa に変更されます(指定しない場合は eager)

model = AutoModel.from_pretrained("retrieva-jp/bert-1.3b", trust_remote_code=True, attn_implementation="sdpa")

SDPA Attention の検証結果

  • SDPA Attention を利用した場合と、これまでの Attention(eager)を利用した場合で出力が大きく変わらないことを検証済み

スクリーンショット 2024-07-08 14.40.33.png

  • SDPA Attention を有効にすることで、秒間あたりのトークン処理数などが改善されることを確認済み

image.png

Retrieva, Inc. org

内部でも SDPA の有無で出力が変更しないことを確認できたためマージします

Katsumata420 changed pull request status to merged

Sign up or log in to comment