import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

from mixer import Mixer
class T5Mixer(Mixer):
    def __init__(self) -> None:
        super().__init__()
        self.tokenizer = AutoTokenizer.from_pretrained(
            "llm-book/t5-base-long-livedoor-news-corpus")
        self.model = AutoModelForSeq2SeqLM.from_pretrained(
            "llm-book/t5-base-long-livedoor-news-corpus")
        # Inverse vocabulary (token id -> token text), handy for inspecting
        # the ids the decoder produces.
        self.tokenid_to_tokentext = {
            i: t for t, i in self.tokenizer.get_vocab().items()}

    @torch.no_grad()
    def get_encoder_state(self, sentence: str) -> torch.Tensor:
        # Encode a single sentence and return the encoder's last hidden
        # state, shaped (1, n_tokens, d_model).
        inputs = self.tokenizer(sentence, return_tensors="pt")
        encoder_outputs = self.model.encoder(**inputs)
        return encoder_outputs.last_hidden_state

    def get_mixed_encode_state(
            self, sentence_A: str, sentence_B: str,
            A_ratio: float = 0.5) -> torch.Tensor:
        # Interpolate the two encoder states token-wise: the longer sequence
        # sets the output length, and the shorter one is blended into its
        # leading positions only.
        es_A = self.get_encoder_state(sentence_A)
        es_B = self.get_encoder_state(sentence_B)
        n_tokens_A = es_A.size(1)
        n_tokens_B = es_B.size(1)
        if n_tokens_A >= n_tokens_B:
            es = es_A.clone() * A_ratio
            es[:, :n_tokens_B, :] += es_B * (1. - A_ratio)
        else:
            es = es_B.clone() * (1. - A_ratio)
            es[:, :n_tokens_A, :] += es_A * A_ratio
        return es

    @torch.no_grad()
    def mix_sentences(self, sentence_A: str, sentence_B: str, A_ratio: float,
                      noise_rate: float = 0.0, max_n_tokens: int = 140) -> str:
        es = self.get_mixed_encode_state(sentence_A, sentence_B, A_ratio)
        # T5 uses the pad token as the decoder start token; decode greedily,
        # one token at a time.
        to = torch.tensor([[self.tokenizer.pad_token_id]])
        for _ in range(max_n_tokens):
            od = self.model.decoder(
                input_ids=to,
                # Optionally perturb the mixed encoder state with Gaussian noise.
                encoder_hidden_states=es + torch.randn_like(es) * noise_rate,
            )
            sd = od.last_hidden_state
            # T5ForConditionalGeneration rescales the decoder output before the
            # LM head when input and output embeddings are tied; mirror that.
            if self.model.config.tie_word_embeddings:
                sd = sd * (self.model.model_dim ** -0.5)
            logits = self.model.lm_head(sd[0, -1, :])
            t_next = logits.argmax()
            # Token text of the prediction (kept for debugging/inspection).
            ttext_next = self.tokenid_to_tokentext[int(t_next)]
            to = torch.cat((to, t_next[None, None]), dim=-1)
            if t_next == self.tokenizer.eos_token_id:
                break
        return self.tokenizer.batch_decode(to, skip_special_tokens=True)[0]
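

# A minimal usage sketch, not part of the original Space: the example
# sentences, the 0.5 mixing ratio, and noise_rate=0.1 are illustrative
# assumptions rather than values taken from the app.
if __name__ == "__main__":
    mixer = T5Mixer()
    mixed = mixer.mix_sentences(
        "今日は朝から強い雨が降り続いている。",  # hypothetical sentence_A
        "新型のスマートフォンが来月発売される。",  # hypothetical sentence_B
        A_ratio=0.5,
        noise_rate=0.1,
    )
    print(mixed)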