File size: 3,208 Bytes
8797a7e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 |
import warnings
from typing import Optional

import torch
from torch import Tensor, LongTensor
from transformers import T5ForConditionalGeneration, T5Config
from transformers import TextIteratorStreamer
from transformers.generation.configuration_utils import GenerationConfig
class TextToTextModel(T5ForConditionalGeneration):
    """T5 conditional-generation model with a preset-driven ``my_generate`` helper."""

    # Decoding presets keyed by ``search_type``. Each value maps
    # ``GenerationConfig`` attribute names to the value applied for that mode.
    _SEARCH_PRESETS = {
        # Greedy decoding: num_beams=1 and do_sample=False.
        'greedy': {
            'num_beams': 1,
            'do_sample': False,
        },
        # NOTE: num_beams>1 together with do_sample=True selects beam-search
        # multinomial sampling, so this mode is not deterministic.
        'beam': {
            'top_k': 50,
            'num_beams': 5,
            'do_sample': True,
            'top_p': 0.95,
            'no_repeat_ngram_size': 4,
            # Negative length penalty; per the GenerationConfig docs this
            # favours shorter sequences.
            'length_penalty': -2.0,
            'early_stopping': True,
        },
        # Multinomial sampling: num_beams=1 and do_sample=True.
        'sampling': {
            'num_beams': 1,
            'do_sample': True,
            'top_k': 50,
            # Temperature below 1 sharpens the distribution; larger values
            # flatten it toward uniform.
            'temperature': 0.98,
            'top_p': 0.80,
            'no_repeat_ngram_size': 4,
        },
        # Contrastive search: penalty_alpha>0 and top_k>1.
        'contrastive': {
            'penalty_alpha': 0.5,
            'top_k': 50,
        },
    }

    def __init__(self, config: T5Config) -> None:
        """Initialise the model; TextToTextModel inherits T5ForConditionalGeneration."""
        super().__init__(config)

    @torch.no_grad()
    def my_generate(
        self,
        input_ids: LongTensor,
        attention_mask: LongTensor,
        max_seq_len: int = 256,
        search_type: str = 'beam',
        streamer: Optional[TextIteratorStreamer] = None,
    ) -> Tensor:
        """Run ``generate`` with one of a few predefined decoding presets.

        A custom wrapper around ``generate`` that is convenient for calling
        and testing. Gradients are disabled for the whole call.

        Args:
            input_ids: Passed to ``generate`` as ``inputs``.
            attention_mask: Passed to ``generate`` unchanged.
            max_seq_len: Assigned to ``max_new_tokens``, i.e. the limit on
                newly generated tokens.
            search_type: One of ``'greedy'``, ``'beam'``, ``'sampling'`` or
                ``'contrastive'``:

                - ``'greedy'``: greedy decoding (``num_beams=1``,
                  ``do_sample=False``).
                - ``'beam'``: beam-search multinomial sampling
                  (``num_beams>1``, ``do_sample=True``).
                - ``'sampling'``: multinomial sampling (``num_beams=1``,
                  ``do_sample=True``).
                - ``'contrastive'``: contrastive search
                  (``penalty_alpha>0`` and ``top_k>1``).

                Any other value triggers a ``RuntimeWarning`` and no preset is
                applied, so generation uses the ``GenerationConfig`` defaults
                plus the shared settings below.
            streamer: Optional streamer forwarded to ``generate``.

        Returns:
            The value returned by ``self.generate``.
        """
        preset = self._SEARCH_PRESETS.get(search_type)
        if preset is None:
            # Previously an unknown mode was accepted without any signal;
            # keep the fallback for compatibility but make it visible.
            warnings.warn(
                f"Unknown search_type {search_type!r}; expected one of "
                f"{sorted(self._SEARCH_PRESETS)}. Falling back to the default "
                "GenerationConfig decoding settings.",
                RuntimeWarning,
                stacklevel=2,
            )
            preset = {}

        # Settings shared by every decoding mode.
        generation_config = GenerationConfig()
        generation_config.remove_invalid_values = True
        # Hard-coded special-token ids; presumably they match the project's
        # tokenizer -- TODO confirm against the tokenizer config.
        generation_config.eos_token_id = 1
        generation_config.pad_token_id = 0
        generation_config.decoder_start_token_id = self.config.decoder_start_token_id
        generation_config.max_new_tokens = max_seq_len
        # generation_config.repetition_penalty = 1.1  # penalty for repeated tokens

        # Mode-specific settings.
        for option, value in preset.items():
            setattr(generation_config, option, value)

        return self.generate(
            inputs=input_ids,
            attention_mask=attention_mask,
            generation_config=generation_config,
            streamer=streamer,
        )
|