Confused about bidirectional attention when implementing custom sampling loop

#25
by ericanthonymitchell - opened

I'm trying to implement a custom sampling loop for GPT-JT because I need some features that model.generate doesn't support. However, I'm confused about how the bidirectional attention mask is tracked. Can someone point me to the code where GPT-JT's bidirectional vs. causal masking is controlled?

In this answer, @juewang mentions that the causal attention mask for GPT-JT is set to 1 by default. However, loading GPT-JT with transformers.AutoModelForCausalLM.from_pretrained just loads a normal GPT-J model, and the attention bias for GPT-J defaults to causal attention, as far as I can tell from here.
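To make the question concrete, here is a minimal sketch of how a boolean mask like GPT-J's `bias` buffer gates attention scores. It is a simplified single-head version written for illustration, not the actual GPT-J code, and `masked_attention` is a hypothetical helper. A lower-triangular mask gives causal attention, while an all-ones mask (what "set to 1" would mean) lets every token attend to every other token.

```python
import torch

def masked_attention(q, k, v, mask):
    """Single-head scaled dot-product attention with a boolean mask (hypothetical helper).

    q, k, v: (seq_len, head_dim); mask: (seq_len, seq_len) bool,
    where mask[i, j] == True means query i may attend to key j.
    """
    scores = q @ k.T / q.shape[-1] ** 0.5
    # Disallowed positions get the most negative float, so softmax gives them zero weight.
    scores = torch.where(mask, scores, torch.finfo(scores.dtype).min)
    weights = torch.softmax(scores, dim=-1)
    return weights @ v, weights

seq_len, head_dim = 4, 8
torch.manual_seed(0)
q, k, v = (torch.randn(seq_len, head_dim) for _ in range(3))

causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))  # GPT-J-style causal mask
bidirectional = torch.ones(seq_len, seq_len, dtype=torch.bool)       # mask "set to 1" everywhere

_, w_causal = masked_attention(q, k, v, causal)
_, w_bidir = masked_attention(q, k, v, bidirectional)
print(w_causal[0])  # first token attends only to itself
print(w_bidir[0])   # first token attends to every token
```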

Could someone explain what I'm missing? I'm confused about how GPT-JT can implement custom attention masking, when there doesn't seem to be any GPT-JT-specific code in HuggingFace (just relying on GPT-J).

Thanks!

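I was confused because I didn't realize that the causal mask (the `bias` tensor in each attention layer, not the `attention_mask` argument) is actually a PyTorch registered buffer, i.e., part of the weights checkpoint. The modeling code only initializes it; the values actually used come from the checkpoint. The mask is in model.transformer.h[i].attn.bias.data[:].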
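The mechanism can be reproduced without GPT-J at all. This is a minimal sketch using a hypothetical `ToyAttention` module (not the real GPT-J class): the buffer is initialized as a causal mask in `__init__`, but because a persistent buffer is part of `state_dict`, loading a checkpoint that stores an all-ones mask replaces it.

```python
import torch
from torch import nn

class ToyAttention(nn.Module):
    """Hypothetical stand-in for a GPT-J attention layer; only the mask buffer matters here."""

    def __init__(self, max_positions: int = 4):
        super().__init__()
        # Initialized in code as a causal (lower-triangular) mask...
        self.register_buffer(
            "bias",
            torch.tril(torch.ones(max_positions, max_positions, dtype=torch.bool)),
        )

model = ToyAttention()
print(bool(model.bias.all()))  # False: causal right after __init__

# ...but a persistent buffer is saved in and loaded from state_dict,
# so a checkpoint can overwrite it with different values.
checkpoint = {"bias": torch.ones(4, 4, dtype=torch.bool)}
model.load_state_dict(checkpoint)
print(bool(model.bias.all()))  # True: now bidirectional
```

Buffers registered with `persistent=False` are left out of `state_dict`, so checkpoint values would not be loaded into them. Whether GPT-J's `bias` buffer is persistent (or exists at all) may depend on your transformers version, so it is worth inspecting `model.transformer.h[0].attn.bias` after loading rather than assuming.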

My simple sampling loop looks like this, for reference:

```python
import torch

@torch.no_grad()  # inference only; no autograd graph needed
def gptjt_sample(model, tokenizer, prompt_text, max_length=100, eos_token_id=None, do_sample=False):
    """Greedy or multinomial decoding with a KV cache. Assumes batch size 1.

    max_length is the maximum number of new tokens to generate.
    """
    dev = next(model.parameters()).device
    input_ids = tokenizer(prompt_text, return_tensors='pt').input_ids.to(dev)
    past_key_values = None
    output_ids = input_ids
    for _ in range(max_length):
        # First step: feed the whole prompt. Later steps: only the newest token,
        # because earlier keys/values are already stored in past_key_values.
        model_input = output_ids if past_key_values is None else output_ids[:, -1:]
        outputs = model(model_input, use_cache=True, past_key_values=past_key_values)
        past_key_values = outputs.past_key_values

        next_token_logits = outputs.logits[:, -1, :]
        if do_sample:
            probs = torch.softmax(next_token_logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
        else:
            next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)
        output_ids = torch.cat([output_ids, next_token], dim=-1)
        if eos_token_id is not None and next_token.item() == eos_token_id:
            break
    return tokenizer.decode(output_ids[0], skip_special_tokens=True)
```
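To check what a loaded checkpoint actually contains, you could call something like `describe_mask(model.transformer.h[0].attn.bias)`, assuming your transformers version still has that buffer. `describe_mask` and `cached_step_mask` below are hypothetical helpers, sketched on toy tensors. `cached_step_mask` assumes that a single cached decoding step uses the last row of the mask, sliced to the current key length; under that assumption, causal and all-ones masks agree for every step after the first, so in a cached loop like the one above the mask mainly affects how the prompt is encoded.

```python
import torch

def describe_mask(mask: torch.Tensor) -> str:
    """Classify a square attention mask (nonzero/True = may attend).

    Accepts GPT-J-style buffers of shape (1, 1, n, n) as well as (n, n).
    """
    m = mask.bool().reshape(mask.shape[-2:])
    n = m.shape[-1]
    if torch.equal(m, torch.tril(torch.ones(n, n, dtype=torch.bool))):
        return "causal"
    if bool(m.all()):
        return "bidirectional"
    return "other"

def cached_step_mask(mask: torch.Tensor, key_length: int) -> torch.Tensor:
    """Mask row for one new query attending to `key_length` keys (cached + new).

    Assumption: the buffer is sliced as [key_length - 1 : key_length, :key_length].
    """
    m = mask.bool().reshape(mask.shape[-2:])
    return m[key_length - 1 : key_length, :key_length]

n = 6
causal = torch.tril(torch.ones(1, 1, n, n))
bidirectional = torch.ones(1, 1, n, n)
print(describe_mask(causal), describe_mask(bidirectional))  # causal bidirectional

# For a single cached decoding step, both masks let the new token see every earlier position.
print(torch.equal(cached_step_mask(causal, 5), cached_step_mask(bidirectional, 5)))  # True
```

One consequence: with a bidirectional mask, a cached loop and a loop that re-encodes the full sequence at every step are not equivalent, because re-encoding would let earlier generated tokens attend to later ones.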
ericanthonymitchell changed discussion status to closed
Together org

Yes, you're right: the mask (attn.bias) is a registered buffer, so its values are overwritten by the checkpoint when it is loaded.
