not-lain commited on
Commit
fbd8e59
1 Parent(s): b05deda

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +55 -0
README.md CHANGED
@@ -1,3 +1,58 @@
1
  ---
2
  license: mit
3
  ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
  license: mit
3
  ---
4
+ # Disclamer
5
+ I do not own, distribute, or take credits for this model, all copyrights belong to [Instadeep](https://huggingface.co/InstaDeepAI) under the [MIT licence](https://github.com/instadeepai/tunbert/)
6
+
7
+
8
+ # how to load the model
9
+ ```python
10
+ !git clone https://huggingface.co/not-lain/TunBERT
11
+ !pip install transformers
12
+ import torch.nn as nn
13
+ import torch
14
+ from transformers import AutoTokenizer, AutoModelForMaskedLM, AutoModelForSequenceClassification, PreTrainedModel,AutoConfig, BertModel
15
+ from transformers.modeling_outputs import SequenceClassifierOutput
16
+ config = AutoConfig.from_pretrained("not-lain/TunBERT")
17
+ class classifier(nn.Module):
18
+ def __init__(self,config):
19
+ super().__init__()
20
+
21
+ self.layer0 = nn.Linear(in_features=config.hidden_size, out_features=config.hidden_size, bias=True)
22
+ self.layer1 = nn.Linear(in_features=config.hidden_size, out_features=config.type_vocab_size, bias=True)
23
+ def forward(self,tensor):
24
+ out1 = self.layer0(tensor)
25
+ return self.layer1(out1)
26
+
27
+
28
+ class TunBERT(PreTrainedModel):
29
+ def __init__(self, config):
30
+ super().__init__(config)
31
+ self.BertModel = BertModel(config)
32
+ self.dropout = nn.Dropout(p=0.1, inplace=False)
33
+ self.classifier = classifier(config)
34
+
35
+ def forward(self,input_ids=None,token_type_ids=None,attention_mask=None,labels=None) :
36
+ outputs = self.BertModel(input_ids,token_type_ids,attention_mask)
37
+ sequence_output = self.dropout(outputs.last_hidden_state)
38
+ logits = self.classifier(sequence_output)
39
+ loss =None
40
+ if labels is not None :
41
+ loss_func = nn.CrossentropyLoss()
42
+ loss = loss_func(logits.view(-1,self.config.type_vocab_size),labels.view(-1))
43
+ return SequenceClassifierOutput(loss = loss, logits= logits, hidden_states=outputs.last_hidden_state,attentions=outputs.attentions)
44
+
45
+
46
+ tunbert = TunBERT(config)
47
+ tunbert.load_state_dict(torch.load("/content/TunBERT/pytorch_model.bin"))
48
+
49
+ tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
50
+ ```
51
+
52
+
53
+ # how to use the model
54
+ ```python
55
+ text = "[insert text here]"
56
+ inputs = tokenizer(text,return_tensors='pt')
57
+ output = model(**inputs)
58
+ ```