This model was fine-tuned on the AG News dataset for 2 epochs using 120,000 training samples, and evaluated on the test set with the following metrics:

- Test Loss: 0.1629
- Accuracy: 0.9521
- F1 Score: 0.9521
- Precision: 0.9522
- Recall: 0.9522
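
The card does not include the training code. For reference, below is a minimal sketch of how such a fine-tuning run could look, reusing the `CustomT5Model` class from the inference snippet further down; the optimizer, learning rate, and batch size are illustrative assumptions, not the actual training settings.

```python
# Illustrative fine-tuning sketch; the optimizer, learning rate, and batch
# size here are assumptions, not the settings used to produce the metrics above.
import torch
import torch.nn as nn
from datasets import load_dataset
from torch.utils.data import DataLoader
from transformers import T5Tokenizer

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = T5Tokenizer.from_pretrained("t5-large")

# CustomT5Model is the encoder + linear-head class defined in the
# inference snippet below.
model = CustomT5Model().to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)  # assumed LR
criterion = nn.CrossEntropyLoss()

# The AG News train split contains the 120,000 samples mentioned above
train_data = load_dataset("ag_news", split="train")

def collate(batch):
    enc = tokenizer(
        [f"classify: {ex['text']}" for ex in batch],
        max_length=99,  # same length setting as the inference code
        truncation=True,
        padding="max_length",
        return_tensors="pt",
    )
    labels = torch.tensor([ex["label"] for ex in batch])
    return enc["input_ids"], enc["attention_mask"], labels

loader = DataLoader(train_data, batch_size=8, shuffle=True, collate_fn=collate)

model.train()
for epoch in range(2):  # 2 epochs, as reported above
    for input_ids, attention_mask, labels in loader:
        optimizer.zero_grad()
        logits = model(input_ids.to(device), attention_mask.to(device))
        loss = criterion(logits, labels.to(device))
        loss.backward()
        optimizer.step()
```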

```python
# Import necessary libraries
import torch
import torch.nn as nn
from transformers import T5Tokenizer, T5ForConditionalGeneration

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Define the model class (same structure as used during training)
class CustomT5Model(nn.Module):
    def __init__(self):
        super(CustomT5Model, self).__init__()
        self.t5 = T5ForConditionalGeneration.from_pretrained("t5-large")
        self.classifier = nn.Linear(1024, 4)  # 4 classes for AG News

    def forward(self, input_ids, attention_mask=None):
        encoder_outputs = self.t5.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True
        )
        hidden_states = encoder_outputs.last_hidden_state  # (batch_size, seq_len, hidden_dim)
        logits = self.classifier(hidden_states[:, 0, :])  # Classify from the first token's hidden state (T5 has no [CLS] token)
        return logits

# Initialize the model
model = CustomT5Model().to(device)

# Load the saved model weights from Hugging Face
model_path = "https://huggingface.co/Vijayendra/T5-large-docClassification/resolve/main/best_model.pth"
model.load_state_dict(torch.hub.load_state_dict_from_url(model_path, map_location=device))
model.eval()

# Load the tokenizer
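# (T5Tokenizer is SentencePiece-based; requires the sentencepiece package)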
tokenizer = T5Tokenizer.from_pretrained("t5-large")

# Inference function
def infer(model, tokenizer, text):
    model.eval()
    with torch.no_grad():
        # Preprocess the input text
        inputs = tokenizer(
            [f"classify: {text}"],
            max_length=99,
            truncation=True,
            padding="max_length",
            return_tensors="pt"
        )
        input_ids = inputs["input_ids"].to(device)
        attention_mask = inputs["attention_mask"].to(device)

        # Get model predictions
        logits = model(input_ids=input_ids, attention_mask=attention_mask)
        preds = torch.argmax(logits, dim=-1)

        # Map class index to label
        label_map = {0: "World", 1: "Sports", 2: "Business", 3: "Sci/Tech"}
        return label_map[preds.item()]

# Example usage
text = "NASA announces new mission to study asteroids"
result = infer(model, tokenizer, text)
print(f"Predicted category: {result}")