from typing import Tuple, Optional

import torch

from transformers.generation.logits_process import TopKLogitsWarper, TopPLogitsWarper


def compute_rsa_probas(
    logits: torch.Tensor, prior: torch.Tensor, rationality: float = 1.0
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Compute the RSA pragmatic speaker and pragmatic listener distributions.

    :param logits: (world_size, num_beam, vocab_size) log-probabilities of the literal speaker
    :param prior: (world_size, num_beam) for each beam, the (log-)prior over the objects
    :param rationality: rationality parameter; the higher it is, the more the speaker adapts to the listener
    :return: S1, L1, both of shape (world_size, num_beam, vocab_size).
        S1[o, b, w] is the log-probability of word w given object o and the current partial summary of beam b.
        L1[o, b, w] is the log-probability of object o given word w and the current partial summary of beam b.
    """
    # Literal listener L0: combine the speaker log-probabilities with the prior and normalize over objects (dim 0).
    prod = logits + prior[..., None]
    L0 = torch.nan_to_num(torch.log_softmax(prod, dim=0), nan=-float("inf"))

    # Pragmatic speaker S1: reweight the literal speaker by the listener, normalized over the vocabulary.
    prod_s = logits + L0 * rationality
    S1 = torch.log_softmax(prod_s, dim=-1)
    S1 = torch.nan_to_num(S1, nan=-float("inf"))

    # Pragmatic listener L1: normalize over objects again.
    prod_l = logits + L0
    L1 = torch.log_softmax(prod_l, dim=0)
    L1 = torch.nan_to_num(L1, nan=-float("inf"))

    return S1, L1

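
# Illustrative sketch (not part of the original module): a quick sanity check of
# compute_rsa_probas on random inputs, showing the expected shapes and that S1
# normalizes over the vocabulary while L1 normalizes over the objects.
def _example_compute_rsa_probas() -> None:
    world_size, num_beams, vocab_size = 3, 2, 5
    logits = torch.log_softmax(torch.randn(world_size, num_beams, vocab_size), dim=-1)
    prior = torch.full((world_size, num_beams), 1.0 / world_size).log()  # uniform log-prior
    S1, L1 = compute_rsa_probas(logits, prior, rationality=2.0)
    assert S1.shape == L1.shape == (world_size, num_beams, vocab_size)
    assert torch.allclose(S1.exp().sum(dim=-1), torch.ones(world_size, num_beams), atol=1e-5)
    assert torch.allclose(L1.exp().sum(dim=0), torch.ones(num_beams, vocab_size), atol=1e-5)
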

def sample_from_probs(
    logits: torch.Tensor, num_beams: int, do_sample: bool, K: int = 10
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Draw candidate (beam, token) pairs for the next decoding step.

    :param logits: (num_beams, vocab_size) log-probabilities of the next token for the target object only
    :param num_beams: number of active beams (can be smaller than the initial number of beams,
        since some beams may have finished earlier)
    :param do_sample: if True, sample the candidates; otherwise take the top-scoring ones
    :param K: number of samples to draw per beam to create the new candidate population
    :return: idx_beam, idx_token, tokens_scores: for each candidate, the beam it extends,
        the sampled token, and the corresponding score
    """
    vocab_size = logits.shape[-1]

    if do_sample:
        # Flatten the (beam, token) grid and sample K * num_beams candidates from the joint distribution.
        logits = logits.view(num_beams * vocab_size)
        probs = torch.softmax(logits, dim=-1)
        samples = torch.multinomial(probs, num_samples=K * num_beams)

        # Recover the beam and token indices from the flat indices.
        idx_beam, idx_token = samples // vocab_size, samples % vocab_size
        tokens_scores = logits.gather(dim=-1, index=samples)

        return idx_beam, idx_token, tokens_scores

    else:
        # Deterministic variant: keep the 2 * num_beams best (beam, token) pairs.
        logits = logits.view(num_beams * vocab_size)
        tokens_scores, samples = logits.topk(2 * num_beams, dim=-1)

        idx_beam, idx_token = samples // vocab_size, samples % vocab_size

        return idx_beam, idx_token, tokens_scores

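
# Illustrative sketch (not part of the original module): drawing candidates from a
# toy 2-beam distribution. With do_sample=False the call would instead return the
# 2 * num_beams highest-scoring (beam, token) pairs.
def _example_sample_from_probs() -> None:
    num_beams, vocab_size = 2, 6
    logits = torch.log_softmax(torch.randn(num_beams, vocab_size), dim=-1)
    idx_beam, idx_token, tokens_scores = sample_from_probs(logits, num_beams, do_sample=True, K=4)
    assert idx_beam.shape == idx_token.shape == tokens_scores.shape == (num_beams * 4,)
    assert int(idx_beam.max()) < num_beams and int(idx_token.max()) < vocab_size
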

class RSAContextualDecoding:
    def __init__(self, model, tokenizer, device):
        """
        :param model: encoder-decoder model used as the literal speaker
        :param tokenizer: tokenizer associated with the model
        :param device: device on which to run the decoding
        """
        self.model = model.to(device)
        self.tokenizer = tokenizer
        self.device = device

    def fwd_pass(
        self,
        input_ids: torch.Tensor,
        decoder_input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        decoder_attention_mask: torch.Tensor,
    ) -> torch.Tensor:
        """
        Run a forward pass through the model to get the next-token logits for every (object, beam) pair.

        :param input_ids: (world_size, num_beams, input_length)
        :param decoder_input_ids: (world_size, num_beams, partial_target_length)
        :param attention_mask: (world_size, num_beams, input_length)
        :param decoder_attention_mask: (world_size, num_beams, partial_target_length)
        :return: logits: (world_size, num_beams, vocab_size)
        """
        with torch.no_grad():
            world_size, num_beams = input_ids.shape[0], decoder_input_ids.shape[1]

            # Flatten the (world_size, num_beams) grid into a single batch dimension.
            input_ids = input_ids.view(world_size * num_beams, input_ids.shape[2]).to(self.device)
            attention_mask = attention_mask.view(
                world_size * num_beams, attention_mask.shape[2]
            ).to(self.device)

            decoder_input_ids = decoder_input_ids.view(
                world_size * num_beams, decoder_input_ids.shape[2]
            ).to(self.device)

            decoder_attention_mask = decoder_attention_mask.view(
                world_size * num_beams, decoder_attention_mask.shape[2]
            ).to(self.device)

            outputs = self.model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                decoder_input_ids=decoder_input_ids,
                decoder_attention_mask=decoder_attention_mask,
            )

            # Keep only the logits of the last position and restore the (world_size, num_beams) layout.
            logits = outputs.logits[..., -1, :]
            logits = logits.view(world_size, num_beams, logits.shape[-1])

            return logits

    def duplicate_and_align_input_ids(
        self,
        input_ids: torch.Tensor,
        input_ids_mask: torch.Tensor,
        decoder_input_ids: torch.Tensor,
        decoder_input_ids_mask: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Duplicate the input_ids and decoder_input_ids so that every pair of input_ids[i] and
        decoder_input_ids[j] is present, i.e. something like:
            a 1
            a 2
            a 3
            b 1
            b 2
            b 3
            ...
        :param input_ids: (world_size, input_length)
        :param input_ids_mask: (world_size, input_length)
        :param decoder_input_ids: (num_beam, partial_target_length)
        :param decoder_input_ids_mask: (num_beam, partial_target_length)
        :return: input_ids, input_ids_mask: (world_size, num_beam, input_length)
            decoder_input_ids, decoder_input_ids_mask: (world_size, num_beam, partial_target_length)
            aligned such that all pairs of input_ids[i] and decoder_input_ids[j] are present
        """
        num_beams = decoder_input_ids.shape[0]

        # Each source text is repeated num_beams times: a a a b b b ...
        input_ids = input_ids.unsqueeze(1).repeat(1, num_beams, 1)
        input_ids_mask = input_ids_mask.unsqueeze(1).repeat(1, num_beams, 1)

        # Each beam is tiled world_size times: 1 2 3 1 2 3 ... Tiling (rather than
        # repeat_interleave) keeps decoder_input_ids[w, b] aligned with input_ids[w, b]
        # once both are flattened in fwd_pass.
        decoder_input_ids = decoder_input_ids.unsqueeze(0).repeat(self.world_size, 1, 1)
        decoder_input_ids_mask = decoder_input_ids_mask.unsqueeze(0).repeat(
            self.world_size, 1, 1
        )

        return input_ids, input_ids_mask, decoder_input_ids, decoder_input_ids_mask

    def compute_rsa_probas(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        decoder_input_ids: torch.Tensor,
        decoder_attention_mask: torch.Tensor,
        do_sample: bool = True,
        top_p: Optional[float] = None,
        top_k: Optional[int] = None,
        temperature: float = 1.0,
        rationality: float = 8.0,
        process_logits_before_rsa: bool = True,
        beam_scores: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        :param input_ids: input ids of the encoder/decoder model, i.e. the source texts
        :param attention_mask: attention mask of the encoder/decoder model
        :param decoder_input_ids: decoder ids, i.e. the partial summaries
        :param decoder_attention_mask: attention mask for the decoder
        :param do_sample: whether the tokens will be sampled or taken by argmax (controls whether the
            logits processors are applied)
        :param top_p: top-p parameter for the logits processor
        :param top_k: top-k parameter for the logits processor
        :param temperature: sampling temperature
        :param rationality: how rational the speaker is (higher means more rational)
        :param process_logits_before_rsa: whether to apply the logits processor before or after the RSA computation
        :param beam_scores: (num_beams,) scores of the beams, added to the logits
        :return: S1, L1: (world_size, num_beam, vocab_size)
        """
        assert (top_p is None) or (top_k is None), "top_p and top_k cannot be used together"
        assert (top_p is None) or do_sample, "top_p can only be used with sampling"
        assert (top_k is None) or do_sample, "top_k can only be used with sampling"

        # Build every (source text, partial summary) pair.
        (
            input_ids,
            attention_mask,
            decoder_input_ids,
            decoder_attention_mask,
        ) = self.duplicate_and_align_input_ids(
            input_ids,
            attention_mask,
            decoder_input_ids,
            decoder_attention_mask,
        )

        logits = (
            self.fwd_pass(
                input_ids, decoder_input_ids, attention_mask, decoder_attention_mask
            )
            / temperature
        )
        logits = torch.nn.functional.log_softmax(logits, dim=-1)

        world_size = input_ids.shape[0]
        num_beams = decoder_input_ids.shape[1]

        logits = logits.view(world_size * num_beams, -1)

        if do_sample and process_logits_before_rsa:
            if top_p is not None:
                logits = TopPLogitsWarper(top_p=top_p)(input_ids=None, scores=logits)
            if top_k is not None:
                logits = TopKLogitsWarper(top_k=top_k)(input_ids=None, scores=logits)

        logits = logits.view(world_size, num_beams, -1)

        # Add the running beam scores so that candidate selection accounts for them.
        if beam_scores is not None:
            logits = logits + beam_scores[None, ..., None]

        S1, L1 = compute_rsa_probas(logits, self.prior, rationality=rationality)
        logits = S1

        if do_sample and not process_logits_before_rsa:
            logits = logits.view(world_size * num_beams, -1)
            if top_p is not None:
                logits = TopPLogitsWarper(top_p=top_p)(input_ids=None, scores=logits)
            if top_k is not None:
                logits = TopKLogitsWarper(top_k=top_k)(input_ids=None, scores=logits)

            logits = logits.view(world_size, num_beams, -1)

        return logits, L1

    def generate(
        self,
        target_id: int,
        source_texts_ids: torch.Tensor,
        source_text_attention_mask: torch.Tensor,
        max_length: int = 100,
        num_beams: int = 8,
        do_sample: bool = True,
        top_p: Optional[float] = None,
        top_k: Optional[int] = None,
        temperature: float = 1.0,
        rationality: float = 1.0,
        process_logits_before_rsa: bool = True,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        :param target_id: the id of the target object
        :param source_texts_ids: (world_size, input_length) the tokenized source texts
        :param source_text_attention_mask: (world_size, input_length) the attention mask for the source texts
        :param max_length: the maximum number of tokens to generate
        :param num_beams: the number of beams
        :param do_sample: whether to sample or take the argmax
        :param top_p: top-p parameter for the logits processor
        :param top_k: top-k parameter for the logits processor
        :param temperature: sampling temperature
        :param rationality: how rational the speaker is (higher means more rational)
        :param process_logits_before_rsa: whether to apply the logits processor before or after the RSA computation
        :return: decoder_input_ids: (num_returned_beams, seq_length) the decoded sequences (finished beams first),
            beam_scores: (num_returned_beams,) the scores of the beams
        """
|
self.num_beam = num_beams
|
|
self.world_size = source_texts_ids.shape[0]
|
|
|
|
self.prior = torch.ones((self.world_size, self.num_beam)).to(self.device) / self.world_size
|
|
beam_scores = torch.zeros(self.num_beam).to(self.device)
|
|
|
|
|
|
decoder_input_ids = torch.full(
|
|
(self.num_beam, 2),
|
|
0,
|
|
dtype=torch.long,
|
|
device=self.device,
|
|
)
|
|
|
|
|
|
decoder_attention_mask = torch.ones_like(decoder_input_ids).to(self.device)
|
|
|
|
new_beams = []
|
|
finished_beams = []
|
|
|
|
|
|
        for t in range(max_length):
            num_beams = decoder_input_ids.shape[0]

            S1, L1 = self.compute_rsa_probas(
                source_texts_ids,
                source_text_attention_mask,
                decoder_input_ids,
                decoder_attention_mask,
                do_sample=do_sample,
                top_p=top_p,
                top_k=top_k,
                temperature=temperature,
                rationality=rationality,
                beam_scores=beam_scores,
                process_logits_before_rsa=process_logits_before_rsa,
            )

            # Candidate (beam, token) pairs for the target object only.
            idx_beam, idx_token, tokens_scores = sample_from_probs(
                S1[target_id], num_beams, do_sample
            )

            # Expand each candidate into (token sequence, cumulative score, listener posterior).
            new_beams = []

            for idx_t, idx_b, token_score in zip(idx_token, idx_beam, tokens_scores):
                new_beams.append(
                    (
                        decoder_input_ids[idx_b].tolist() + [idx_t.item()],
                        (beam_scores[idx_b] + token_score).item(),
                        L1[:, idx_b, idx_t.item()],
                    )
                )

            # Keep the self.num_beam best candidates.
            new_beams = sorted(new_beams, key=lambda x: x[1], reverse=True)
            new_beams = new_beams[: self.num_beam]

            # Move beams that just produced EOS to the finished pool.
            _new_beams = []
            for beam in new_beams:
                if beam[0][-1] == self.tokenizer.eos_token_id:
                    finished_beams.append(beam)
                else:
                    _new_beams.append(beam)

            new_beams = _new_beams

            if len(new_beams) == 0:
                break

            # Pad the surviving beams to the same length and rebuild the decoder tensors.
            max_beam_len = max(len(x[0]) for x in new_beams)
            new_beams = [
                (
                    x[0] + [self.tokenizer.pad_token_id] * (max_beam_len - len(x[0])),
                    x[1],
                    x[2],
                )
                for x in new_beams
            ]

            beam_scores = torch.tensor([x[1] for x in new_beams]).to(self.device)

            decoder_input_ids: torch.Tensor = torch.tensor(
                [x[0] for x in new_beams], device=self.device
            )

            decoder_attention_mask = (
                decoder_input_ids != self.tokenizer.pad_token_id
            ).long()

            # The listener posterior of each surviving beam becomes the prior of the next step.
            self.prior = torch.stack([x[2] for x in new_beams], dim=1).to(self.device)

        # Gather finished and still-active beams, pad them to a common length and return them.
        results = []

        max_beam_len = max(len(x[0]) for x in finished_beams + new_beams)
        for x in finished_beams + new_beams:
            results.append(
                (
                    x[0] + [self.tokenizer.pad_token_id] * (max_beam_len - len(x[0])),
                    x[1],
                    x[2],
                )
            )

        decoder_input_ids = torch.tensor([x[0] for x in results], device=self.device)
        beam_scores = torch.tensor([x[1] for x in results]).to(self.device)

        return decoder_input_ids, beam_scores

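
# Illustrative usage sketch (not part of the original module). It assumes an
# encoder-decoder summarization checkpoint; "facebook/bart-large-cnn" below is just
# an example, as are the source texts. target_id selects which source text the
# pragmatic speaker should make identifiable to the listener.
def _example_generate() -> None:
    from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-cnn")
    model = AutoModelForSeq2SeqLM.from_pretrained("facebook/bart-large-cnn")
    model.eval()

    source_texts = [
        "The city council approved the new bike lanes on Monday.",
        "The city council rejected the new bike lanes on Monday.",
    ]
    encoded = tokenizer(source_texts, return_tensors="pt", padding=True)

    decoder = RSAContextualDecoding(model, tokenizer, device="cpu")
    sequences, scores = decoder.generate(
        target_id=0,
        source_texts_ids=encoded["input_ids"],
        source_text_attention_mask=encoded["attention_mask"],
        max_length=30,
        num_beams=4,
        do_sample=True,
        top_k=50,
        rationality=4.0,
    )
    print(tokenizer.batch_decode(sequences, skip_special_tokens=True))
    print(scores)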