How to get fixed esm2_t33_650M_UR50D embeddings for use in a downstream task?


Here is my usage process. I'm not sure if it's correct. Can someone help me?

```python
def forward(self, encoded_inputs):
    # encoded_inputs = self.tokenizer(seqs, max_length=65, padding=True, truncation=True, return_tensors='pt')

    # Pool the last hidden state into a fixed embedding -- this is the line I'm unsure about.
    # (also tried .mean(0) vs. [:, 0, :])
    embedded_data = self.model(**encoded_inputs).last_hidden_state.mean(0)
    print("embedded_data shape ==>", embedded_data.shape)
    print(embedded_data)

    # Feed the pooled embedding to the downstream head.
    output = torch.squeeze(self.main(embedded_data))
    # print(output.shape)
    return output
```
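
For comparison, here is a self-contained sketch of what I think the pooling should look like (assuming the facebook/esm2_t33_650M_UR50D checkpoint from the Hugging Face hub; the sequences below are made-up examples): mean-pooling over the sequence dimension (dim=1) with padding masked out, so each sequence ends up as one fixed 1280-dimensional vector.

```python
import torch
from transformers import AutoTokenizer, AutoModel

# Hypothetical standalone example -- sequences below are illustrative only.
tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
model = AutoModel.from_pretrained("facebook/esm2_t33_650M_UR50D")

seqs = ["MKTAYIAKQR", "MADEEKLPPGW"]  # made-up protein sequences
encoded_inputs = tokenizer(seqs, max_length=65, padding=True,
                           truncation=True, return_tensors="pt")

with torch.no_grad():
    # last_hidden_state: (batch, seq_len, 1280) for esm2_t33_650M_UR50D
    hidden = model(**encoded_inputs).last_hidden_state

# Mean-pool over the sequence dimension (dim=1), ignoring padding tokens,
# so each sequence gets a single fixed 1280-d embedding.
mask = encoded_inputs["attention_mask"].unsqueeze(-1).float()  # (batch, seq_len, 1)
embedded_data = (hidden * mask).sum(dim=1) / mask.sum(dim=1)   # (batch, 1280)

print(embedded_data.shape)  # torch.Size([2, 1280])
```

As I understand it, `.mean(0)` in my forward averages over the batch dimension and `[:, 0, :]` would take the first (BOS) token, so the masked mean over dim=1 above is what I'm trying to compare against.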
