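# app.py — header text captured from the Hugging Face file view, kept as a
# comment so the file parses:
#   author: rwcuffney; commit c51d635 ("Create app.py"); size 1.1 kB;
#   page links "raw" / "history" / "blame"; scan status "No virus".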
from datasets import load_dataset
from transformers import AutoModelForSequenceClassification
from transformers import AutoTokenizer

# Hub repository that holds both the fine-tuned classifier and its tokenizer.
MODEL_ID = 'rwcuffney/autotrain-pick_a_card-3726099224'

# load_dataset() itself neither batches nor shuffles: extra keyword arguments are
# forwarded to the dataset builder's configuration rather than controlling
# iteration, so batch size / shuffling must be applied when iterating the data.
# With no `split=` argument this returns a DatasetDict keyed by split name.
dataset = load_dataset('rwcuffney/pick_a_card_test')

model = AutoModelForSequenceClassification.from_pretrained(MODEL_ID)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
def preprocess_text(text, max_length=128):
    """Tokenize *text* into fixed-length PyTorch tensors for the classifier.

    Args:
        text: Input passed straight to the module-level ``tokenizer``
            (e.g. a string or a list of strings).
        max_length: Length every sequence is padded or truncated to.
            Defaults to 128.

    Returns:
        The tokenizer output with ``return_tensors='pt'`` — a mapping of
        torch tensors (e.g. ``input_ids``) that can be unpacked into the model.
    """
    # padding='max_length' pads every example to exactly max_length, so all
    # batches share one shape; truncation drops tokens beyond that length.
    return tokenizer(
        text,
        padding='max_length',
        truncation=True,
        max_length=max_length,
        return_tensors='pt',
    )
import torch

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model.to(device)
model.eval()  # inference mode: disables training-only layers such as dropout

# Batching and shuffling are done here, at iteration time.
BATCH_SIZE = 32
SHUFFLE = True

# load_dataset() without `split=` yields a DatasetDict (split name -> Dataset);
# iterating it directly would produce split names, not rows, so walk each split.
for split_name, split in dataset.items():
    if SHUFFLE:
        split = split.shuffle()
    for start in range(0, len(split), BATCH_SIZE):
        # Slicing a Dataset returns a dict: column name -> list of values.
        batch = split[start:start + BATCH_SIZE]

        # Preprocess the text.
        # NOTE(review): assumes the dataset has a 'text' column — confirm against
        # the dataset card before relying on this.
        inputs = preprocess_text(batch['text']).to(device)

        # Make predictions without building an autograd graph.
        with torch.no_grad():
            outputs = model(**inputs)
        predicted_classes = torch.argmax(outputs.logits, dim=-1)

        # Map class indices through the model's own id2label table so names
        # follow the model's output order; .tolist() yields plain Python ints.
        predicted_labels = [model.config.id2label[i] for i in predicted_classes.tolist()]
        print(predicted_labels)