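# Load a trained 4-expert Switch Transformer checkpoint and the tokenizer used
# to prepare its inputs. The architecture hyperparameters below must match the
# ones used at training time, otherwise load_state_dict will fail on shape
# mismatches.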
import torch
from model import (
    SwitchTransformer,
    SwitchTransformerLayer,
    MultiHeadAttention,
    SwitchFeedForward,
    FeedForward,
)
from transformers import AutoTokenizer

device = 'cpu'

# Per-expert position-wise FFN: d_model=768 with the usual 4x hidden expansion.
ff = FeedForward(768, 768 * 4)
# 8-head self-attention over d_model=768 with dropout 0.2.
attn = MultiHeadAttention(8, 768, 0.2)
# Switch (top-1) MoE feed-forward layer routing tokens across 4 copies of `ff`.
st_ff = SwitchFeedForward(
    capacity_factor=1.25,
    drop_tokens=False,
    n_experts=4,
    expert=ff,
    d_model=768,
    is_scale_prob=True,
)
st_layer = SwitchTransformerLayer(
    d_model=768,
    attn=attn,
    feed_forward=st_ff,
    dropout_prob=0.2,
)
model = SwitchTransformer(
    layer=st_layer,
    n_layers=4,
    n_experts=4,
    device=device,
    load_balancing_loss_ceof=0.05,  # spelling kept to match the keyword in model.py
).to(device)

# Restore trained weights; map_location keeps GPU-saved tensors on the CPU.
model.load_state_dict(torch.load("switch_transformer.pt", map_location=torch.device('cpu')))
tokenizer = AutoTokenizer.from_pretrained("Kyrmasch/kaz-roberta-squad2-kaz")
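
# --- Usage sketch (illustrative; not part of the original script) ---
# A minimal forward-pass smoke test. Assumptions: SwitchTransformer.forward
# accepts a LongTensor of token ids shaped (batch, seq_len) and returns either
# a tensor or a tuple whose first element is a tensor; check model.py for the
# real signature before relying on this.
model.eval()  # disable dropout for inference
sample = "Қазақстан Орталық Азияда орналасқан."  # "Kazakhstan is located in Central Asia."
encoded = tokenizer(sample, return_tensors="pt")
with torch.no_grad():
    out = model(encoded["input_ids"].to(device))
print(out[0].shape if isinstance(out, tuple) else out.shape)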