# Uploaded by acho0057 ("add files", commit a4c9d33)
import torch
def predict_fn(data, model_and_tokenizer):
    """Run inference on text of arbitrary length by chunking into 512-token windows.

    The text is tokenized without special tokens, split into chunks of at most
    510 tokens, and each chunk is wrapped with [CLS]/[SEP] and right-padded to
    exactly 512 tokens so the model sees one fixed-width batch.

    Args:
        data: dict whose "inputs" key holds the text to classify; if the key
            is absent, ``data`` itself is passed to the tokenizer.
        model_and_tokenizer: ``(model, tokenizer)`` tuple.

    Returns:
        The raw model output for the batch of chunks.
    """
    model, tokenizer = model_and_tokenizer
    sentences = data.pop("inputs", data)

    # Tokenize without special tokens; [CLS]/[SEP] are added per chunk below.
    encoded_input = tokenizer(sentences, add_special_tokens=False, return_tensors='pt')

    # Prefer the tokenizer's own special-token ids; fall back to the BERT
    # defaults (101=[CLS], 102=[SEP], 0=[PAD]) that were hard-coded before.
    cls_id = getattr(tokenizer, "cls_token_id", None)
    sep_id = getattr(tokenizer, "sep_token_id", None)
    pad_id = getattr(tokenizer, "pad_token_id", None)
    cls_id = 101 if cls_id is None else cls_id
    sep_id = 102 if sep_id is None else sep_id
    pad_id = 0 if pad_id is None else pad_id

    max_len = 512
    window = max_len - 2  # leave room for [CLS] and [SEP]

    input_id_chunks = list(encoded_input['input_ids'][0].split(window))
    mask_chunks = list(encoded_input['attention_mask'][0].split(window))

    for i in range(len(input_id_chunks)):
        # Wrap each chunk with the special tokens; build integer tensors
        # directly instead of float torch.Tensor()s that need a later cast.
        input_id_chunks[i] = torch.cat([
            torch.tensor([cls_id], dtype=input_id_chunks[i].dtype),
            input_id_chunks[i],
            torch.tensor([sep_id], dtype=input_id_chunks[i].dtype),
        ])
        mask_chunks[i] = torch.cat([
            torch.tensor([1], dtype=mask_chunks[i].dtype),
            mask_chunks[i],
            torch.tensor([1], dtype=mask_chunks[i].dtype),
        ])
        pad_len = max_len - input_id_chunks[i].shape[0]
        if pad_len > 0:
            input_id_chunks[i] = torch.cat([
                input_id_chunks[i],
                torch.full((pad_len,), pad_id, dtype=input_id_chunks[i].dtype),
            ])
            mask_chunks[i] = torch.cat([
                mask_chunks[i],
                torch.zeros(pad_len, dtype=mask_chunks[i].dtype),
            ])

    input_dict = {
        'input_ids': torch.stack(input_id_chunks).long(),
        'attention_mask': torch.stack(mask_chunks).int(),
    }

    # Inference only: disable autograd to avoid building a graph.
    with torch.no_grad():
        output = model(**input_dict)
    return output