File size: 2,670 Bytes
4ec4ab6 d6e81b1 4ec4ab6 1bd862c 01508d7 d6e81b1 915173a d6e81b1 915173a d6e81b1 915173a d6e81b1 915173a 96a411a d6e81b1 e0c17c6 4ec4ab6 d6e81b1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 |
---
license: mit
datasets:
- fancyzhx/ag_news
language:
- en
metrics:
- accuracy
base_model:
- google-t5/t5-large
pipeline_tag: text-classification
tags:
- ag
- news
- document
- classification
---
This model was fine-tuned on the AG News dataset for 2 epochs using 120,000 training samples, then evaluated on the test set with the following results:

- Test loss: 0.1629
- Accuracy: 0.9521
- F1 score: 0.9521
- Precision: 0.9522
- Recall: 0.9522

```python
# Import necessary libraries
import torch
import torch.nn as nn
from transformers import T5Tokenizer, T5ForConditionalGeneration
# Pick the compute device: the GPU when PyTorch can see one, else the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
# Define the model class (same structure as used during training)
class CustomT5Model(nn.Module):
def __init__(self):
super(CustomT5Model, self).__init__()
self.t5 = T5ForConditionalGeneration.from_pretrained("t5-large")
self.classifier = nn.Linear(1024, 4) # 4 classes for AG News
def forward(self, input_ids, attention_mask=None):
encoder_outputs = self.t5.encoder(
input_ids=input_ids,
attention_mask=attention_mask,
return_dict=True
)
hidden_states = encoder_outputs.last_hidden_state # (batch_size, seq_len, hidden_dim)
logits = self.classifier(hidden_states[:, 0, :]) # Use [CLS] token representation
return logits
# Initialize the model
model = CustomT5Model().to(device)
# Load the saved model weights from Hugging Face
model_path = "https://huggingface.co/Vijayendra/T5-large-docClassification/resolve/main/best_model.pth"
model.load_state_dict(torch.hub.load_state_dict_from_url(model_path, map_location=device))
model.eval()
# Load the tokenizer
tokenizer = T5Tokenizer.from_pretrained("t5-large")
# Inference function
def infer(model, tokenizer, text):
model.eval()
with torch.no_grad():
# Preprocess the input text
inputs = tokenizer(
[f"classify: {text}"],
max_length=99,
truncation=True,
padding="max_length",
return_tensors="pt"
)
input_ids = inputs["input_ids"].to(device)
attention_mask = inputs["attention_mask"].to(device)
# Get model predictions
logits = model(input_ids=input_ids, attention_mask=attention_mask)
preds = torch.argmax(logits, dim=-1)
# Map class index to label
label_map = {0: "World", 1: "Sports", 2: "Business", 3: "Sci/Tech"}
return label_map[preds.item()]
# Example usage: classify a single headline.
text = "NASA announces new mission to study asteroids"
result = infer(model, tokenizer, text)
print(f"Predicted category: {result}")