imvladikon committed
Commit
0f20ddf
•
1 Parent(s): d5f0fa9

Update README.md

Files changed (1)
  1. README.md +60 -1
README.md CHANGED
@@ -1,6 +1,65 @@
  ---
  language:
  - he
+ pipeline_tag: text-generation
  ---

- Experiments with encoder-decoder model, where encoder is [alephbert-base](https://huggingface.co/onlplab/alephbert-base) and [decoder is pruned T5-base model](https://huggingface.co/imvladikon/het5-base)
+ ### Description
+ Experiments with an encoder-decoder model, where the encoder is [alephbert-base](https://huggingface.co/onlplab/alephbert-base) and the [decoder is a pruned mT5-base model](https://huggingface.co/imvladikon/het5-base).
+ Could be useful for generating hard-negative samples for pair-text classification.
+
+
+ ### Usage
+
+ ```bash
+ git clone https://huggingface.co/imvladikon/alephbert-encoder-t5-decoder
+ ```
+
+ ```python
+ import torch
+ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, AutoModel
+ from transformers.modeling_outputs import BaseModelOutput
+
+ # Load the AlephBERT encoder and the pruned mT5-base decoder from the cloned repo
+ enc_checkpoint = "./alephbert-encoder-t5-decoder/encoder"
+ enc_tokenizer = AutoTokenizer.from_pretrained(enc_checkpoint)
+ encoder = AutoModel.from_pretrained(enc_checkpoint).cuda()
+
+ dec_checkpoint = "./alephbert-encoder-t5-decoder/decoder"
+ dec_tokenizer = AutoTokenizer.from_pretrained(dec_checkpoint)
+ decoder = AutoModelForSeq2SeqLM.from_pretrained(dec_checkpoint).cuda()
+
+ encoder.eval()
+ decoder.eval()
+
+
+ def encode(texts):
+     # Encode a batch of texts into normalized pooled sentence embeddings
+     encoded_input = enc_tokenizer(texts, padding=True, truncation=True, max_length=512, return_tensors='pt')
+     with torch.no_grad():
+         model_output = encoder(**encoded_input.to(encoder.device))
+     embeddings = model_output.pooler_output
+     embeddings = torch.nn.functional.normalize(embeddings)
+     return embeddings
+
+
+ def decode(embeddings, max_length=256, repetition_penalty=3.0, **kwargs):
+     # Feed each pooled embedding to the decoder as a length-1 "encoder output" sequence
+     out = decoder.generate(
+         encoder_outputs=BaseModelOutput(last_hidden_state=embeddings.unsqueeze(1)),
+         max_length=max_length,
+         repetition_penalty=repetition_penalty,
+         **kwargs,
+     )
+     return [dec_tokenizer.decode(tokens, skip_special_tokens=True) for tokens in out]
+
+
+ # "Tomorrow will remain partly cloudy; during the day the winds will strengthen
+ # in the south of the country, and haze is possible in the area."
+ text = """
+ מחר יוסיף להיות מעונן חלקית ובמהלך היום יתחזקו הרוחות בדרום הארץ וייתכן אובך באזור.
+ """.strip()
+ batch = [text]
+ embeddings = encode(batch)
+ generated = decode(embeddings, max_length=512)
+
+ for src, gen in zip(batch, generated):
+     print(src)
+     print(gen)
+     print('-----------')
+ ```
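
The description suggests the generated text can serve as hard negatives for pair-text classification. Below is a minimal sketch of that idea, reusing the `encode` and `decode` helpers from the snippet above; the `sentences` list, the filtering rule, and the label convention (0 = hard negative) are illustrative assumptions rather than part of the model card.

```python
# Sketch (assumptions): reuses encode()/decode() defined in the usage example above;
# treats the model's reconstruction of a sentence as a hard-negative candidate (label 0).
sentences = [
    "מחר יוסיף להיות מעונן חלקית ובמהלך היום יתחזקו הרוחות בדרום הארץ וייתכן אובך באזור.",
]

pairs = []
reconstructions = decode(encode(sentences), max_length=512)
for source, generated in zip(sentences, reconstructions):
    generated = generated.strip()
    # Keep only outputs that differ from the source, so each pair is "close but not identical"
    if generated and generated != source:
        pairs.append({"text_a": source, "text_b": generated, "label": 0})

print(pairs)
```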