|
|
import torch |
|
|
|
|
|
def predict_fn(data, model_and_tokenizer):
    """Run inference on input text, splitting long inputs into 512-token windows.

    The text is tokenized without special tokens, cut into windows of 510
    ids, and each window is wrapped with [CLS]/[SEP] and right-padded to a
    fixed length of 512 so the chunks can be stacked into one batch.

    Parameters
    ----------
    data : dict
        Request payload; the text is taken from ``data["inputs"]``, with the
        whole payload used as a fallback when that key is absent.
    model_and_tokenizer : tuple
        ``(model, tokenizer)`` pair, e.g. as produced by a ``model_fn`` loader.

    Returns
    -------
    The raw model output for the batch of shape ``(num_chunks, 512)``.
    """
    model, tokenizer = model_and_tokenizer

    sentences = data.pop("inputs", data)

    # Tokenize without special tokens; [CLS]/[SEP] are added per chunk below.
    encoded_input = tokenizer(sentences, add_special_tokens=False, return_tensors='pt')

    # Window size leaves room for the two special tokens added to each chunk.
    max_len = 512
    chunk_size = max_len - 2

    # Prefer the tokenizer's own special-token ids over hard-coded BERT
    # values (101/102/0), which would be wrong for any other vocabulary.
    # The BERT ids remain the fallback, so behavior with a standard BERT
    # tokenizer is unchanged.
    cls_id = getattr(tokenizer, "cls_token_id", None)
    sep_id = getattr(tokenizer, "sep_token_id", None)
    pad_id = getattr(tokenizer, "pad_token_id", None)
    cls_id = 101 if cls_id is None else cls_id
    sep_id = 102 if sep_id is None else sep_id
    pad_id = 0 if pad_id is None else pad_id

    ids = encoded_input['input_ids'][0]
    mask = encoded_input['attention_mask'][0]

    input_id_chunks = list(ids.split(chunk_size))
    mask_chunks = list(mask.split(chunk_size))

    # Guard against empty input: keep one empty chunk so the loop still
    # produces a [CLS][SEP]-only row and torch.stack never sees an empty list.
    if not input_id_chunks:
        input_id_chunks = [ids.new_empty(0)]
        mask_chunks = [mask.new_empty(0)]

    for i in range(len(input_id_chunks)):
        # Wrap each chunk: [CLS] tokens... [SEP]. Integer tensors throughout,
        # instead of float torch.Tensor literals + a trailing .long() cast.
        input_id_chunks[i] = torch.cat([
            torch.tensor([cls_id]), input_id_chunks[i].long(), torch.tensor([sep_id])
        ])
        mask_chunks[i] = torch.cat([
            torch.tensor([1]), mask_chunks[i].long(), torch.tensor([1])
        ])

        # Right-pad every chunk to the fixed window length; padding positions
        # get attention 0 so the model ignores them.
        pad_len = max_len - input_id_chunks[i].shape[0]
        if pad_len > 0:
            input_id_chunks[i] = torch.cat([
                input_id_chunks[i], torch.full((pad_len,), pad_id, dtype=torch.long)
            ])
            mask_chunks[i] = torch.cat([
                mask_chunks[i], torch.zeros(pad_len, dtype=torch.long)
            ])

    input_dict = {
        'input_ids': torch.stack(input_id_chunks).long(),
        'attention_mask': torch.stack(mask_chunks).int()
    }

    # Inference only — disable autograd bookkeeping around the forward pass.
    with torch.no_grad():
        output = model(**input_dict)

    return output
|
|
|
|
|
|