File size: 2,706 Bytes
285573d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 |
import torch
import torch.nn.functional as F
from transformers import BertTokenizer, BertForTokenClassification
import re
import string
def preprocess_input_text(text):
    """Prepare raw text for the error-detection model.

    Inserts a space before every ASCII punctuation character (separating it
    from the preceding text), splits on any run of whitespace, lower-cases
    each word and appends a ``[MASK]`` token after it.

    Args:
        text: Raw input string.

    Returns:
        A ``(words, masked_text)`` tuple. ``words`` is the list of words with
        their original casing. ``masked_text`` is the lower-cased words, each
        followed by ``[MASK]``, joined by single spaces. Empty or
        whitespace-only input yields ``([], "")``.
    """
    # re.escape keeps characters such as "]", "\\", "^" and "-" from
    # string.punctuation literal inside the character class.
    spaced = re.sub("([" + re.escape(string.punctuation) + "])", r" \1", text)
    # split() with no argument collapses every whitespace run (spaces, tabs,
    # newlines) and never yields empty strings, unlike split(" "). Empty
    # strings would otherwise receive their own [MASK] and misalign labels.
    words = spaced.split()
    masked = []
    for word in words:
        masked.extend((word.lower(), "[MASK]"))
    return words, " ".join(masked)
def predict_using_trained_model_old(input_text, model_dir, device, max_length=128):
    """Label each word of *input_text* as correct (0) or incorrect (1).

    Loads a ``BertTokenizer`` and a 2-label ``BertForTokenClassification``
    model from *model_dir*. It runs the model on the text produced by
    ``preprocess_input_text`` (a ``[MASK]`` token after every word) and reads
    the predicted label at each ``[MASK]`` position.

    Args:
        input_text: Raw text to check.
        model_dir: Directory (or other identifier accepted by
            ``from_pretrained``) holding the tokenizer and model.
        device: torch device the model and inputs are moved to.
        max_length: Token budget passed to the tokenizer, including special
            tokens (default 128, the previously hard-coded value). Words whose
            ``[MASK]`` falls beyond this limit are omitted and a warning is
            emitted.

    Returns:
        A space-separated string alternating each word (original casing) with
        its predicted label, e.g. ``"word 0 other 1"``.
    """
    import warnings

    words, masked_text = preprocess_input_text(input_text)

    tokenizer = BertTokenizer.from_pretrained(model_dir)
    model = BertForTokenClassification.from_pretrained(model_dir, num_labels=2)
    model.to(device)
    model.eval()

    encoded = tokenizer(
        masked_text,
        max_length=max_length,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )
    input_ids = encoded["input_ids"].to(device)
    attention_mask = encoded["attention_mask"].to(device)

    with torch.no_grad():
        logits = model(input_ids, attention_mask=attention_mask).logits

    # A single sequence was encoded, so index away the batch dimension.
    predictions = torch.argmax(logits, dim=-1)[0].cpu().tolist()
    tokens = tokenizer.convert_ids_to_tokens(input_ids[0].cpu().tolist())

    # Every word is followed by exactly one [MASK], and the label predicted
    # there belongs to that word. Emitting the word at its [MASK] (rather than
    # once per WordPiece sub-token) keeps words the tokenizer splits into
    # several pieces from being repeated in the output.
    model_output = []
    labelled = 0
    for token, prediction in zip(tokens, predictions):
        if token != "[MASK]":
            continue
        if labelled >= len(words):
            # Defensive: never index past the word list.
            break
        model_output.append(words[labelled])
        model_output.append(str(prediction))
        labelled += 1

    if labelled < len(words):
        warnings.warn(
            f"Input truncated to {max_length} tokens; "
            f"{len(words) - labelled} of {len(words)} words were not labelled.",
            stacklevel=2,
        )
    return " ".join(model_output)
if __name__ == '__main__':
    # Sample sentence and model location used when run as a script; the
    # sentence appears to contain deliberate misspellings for demonstration.
    input_text = "Model u tekstu prepoznije riječi u kojima se nalazaju pogreške."
    model_dir = "."

    # Pick the first available backend: CUDA, then MPS, falling back to CPU.
    if torch.cuda.is_available():
        backend = "cuda"
    elif torch.backends.mps.is_available():
        backend = "mps"
    else:
        backend = "cpu"
    device = torch.device(backend)
    print(f"Using device: {device}")

    model_output_text = predict_using_trained_model_old(input_text, model_dir, device)
    print(f"Model output: {model_output_text}")
|