|
from transformers import BertTokenizer, BertForSequenceClassification |
|
import torch |
|
|
|
|
|
|
|
|
|
# Load the fine-tuned classifier and its tokenizer from local checkpoints.
model = BertForSequenceClassification.from_pretrained('./test_model')

# BUG FIX: the tokenizer must be loaded with BertTokenizer, not
# BertForSequenceClassification — the original line returned a model object,
# which is not callable as `tokenizer(question, answer, ...)` below.
tokenizer = BertTokenizer.from_pretrained('./test_tokenizer')
|
|
|
|
|
def predict_relevance(question, answer):
    """Classify whether *answer* is relevant to *question*.

    Args:
        question: The question text.
        answer: The candidate answer text.

    Returns:
        "Relevant" if the model's probability for the positive class
        (index 1) exceeds 0.5, otherwise "Irrelevant". An empty/whitespace
        answer short-circuits to "Irrelevant" without running the model.
    """
    # Guard clause: an empty answer cannot be relevant; skip inference.
    if not answer.strip():
        return "Irrelevant"

    # Encode the (question, answer) pair as a single sequence-pair input.
    inputs = tokenizer(question, answer, return_tensors="pt",
                       padding=True, truncation=True)

    model.eval()
    with torch.no_grad():
        outputs = model(**inputs)
        logits = outputs.logits
        probabilities = torch.softmax(logits, dim=-1)

    # Probability of the "relevant" class (label index 1) for the single
    # example in the batch; .item() yields a plain Python float.
    # (The original also computed torch.argmax into an unused variable —
    # removed, since the threshold comparison below decides the label.)
    threshold = 0.5
    relevant_prob = probabilities[0, 1].item()

    return "Relevant" if relevant_prob > threshold else "Irrelevant"
|
|
|
|
|
# Smoke-test the classifier on one hand-written example.
question = "What is your experience with Python?"
answer = "I have minimal experience with java, mostly for small automation tasks."

result = predict_relevance(question, answer)
print(f"Relevance: {result}")
|
|