Vijayendra's picture
Update README.md
4ec4ab6 verified
|
raw
history blame
2.26 kB
metadata
license: mit

# Import the required libraries

import torch
import torch.nn as nn
from transformers import T5Tokenizer, T5ForConditionalGeneration

# Select the compute device (GPU if available, otherwise CPU)

# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Define the model class (it must match the architecture used during training so the saved weights load)

class CustomT5Model(nn.Module):
    """T5 encoder with a linear classification head (AG News by default).

    The full ``T5ForConditionalGeneration`` checkpoint is loaded, but only its
    encoder is used in ``forward``. Keep the attribute names ``t5`` and
    ``classifier`` unchanged: saved weights are restored with
    ``load_state_dict``, which matches parameters by these key prefixes.
    """

    def __init__(self, model_name="t5-large", num_labels=4):
        """Build the backbone and the classification head.

        Args:
            model_name: Pretrained T5 checkpoint to load as the backbone.
            num_labels: Number of output classes (4 for AG News).
        """
        super().__init__()
        self.t5 = T5ForConditionalGeneration.from_pretrained(model_name)
        # Size the head from the backbone config rather than hard-coding 1024
        # (t5-large's hidden size), so other T5 sizes work as well.
        self.classifier = nn.Linear(self.t5.config.d_model, num_labels)

    def forward(self, input_ids, attention_mask=None):
        """Return unnormalized class logits of shape (batch_size, num_labels).

        Args:
            input_ids: Token ids, shape (batch_size, seq_len).
            attention_mask: Optional mask of the same shape; 0 marks padding.
        """
        encoder_outputs = self.t5.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True,
        )
        hidden_states = encoder_outputs.last_hidden_state  # (batch_size, seq_len, hidden_dim)
        # T5 has no [CLS] token: this pools the encoder state at the first
        # position, which is what the trained classifier head expects.
        logits = self.classifier(hidden_states[:, 0, :])
        return logits

# Initialize the model

# Build the architecture, then move its parameters onto the selected device.
model = CustomT5Model()
model = model.to(device)

# Load the fine-tuned model weights from the Hugging Face Hub

model_path = "https://huggingface.co/Vijayendra/T5-large-docClassification/resolve/main/best_model.pth"
# The checkpoint is a pickle downloaded over the network. weights_only=True
# restricts unpickling to tensors and plain containers, so a tampered file
# cannot run arbitrary code while it is being loaded.
state_dict = torch.hub.load_state_dict_from_url(
    model_path,
    map_location=device,
    weights_only=True,
)
# Strict loading (the default) raises if the keys do not match CustomT5Model.
model.load_state_dict(state_dict)
model.eval()  # switch to evaluation mode (e.g. disables dropout)

# Load the tokenizer that matches the backbone

# The tokenizer must come from the same checkpoint as the model's T5 backbone.
tokenizer = T5Tokenizer.from_pretrained(
    pretrained_model_name_or_path="t5-large",
)

# Inference function

def infer(model, tokenizer, text, max_length=99):
    """Classify one piece of text into an AG News category.

    Args:
        model: The classifier (e.g. ``CustomT5Model``); called as
            ``model(input_ids=..., attention_mask=...)`` and expected to return
            logits whose last dimension indexes the four classes.
        tokenizer: Tokenizer matching the model's T5 backbone.
        text: Raw input text to classify.
        max_length: Token length the prompt is truncated/padded to
            (defaults to 99, the value this script has always used).

    Returns:
        One of "World", "Sports", "Business" or "Sci/Tech".
    """
    label_map = {0: "World", 1: "Sports", 2: "Business", 3: "Sci/Tech"}

    # Move inputs to wherever the model's weights live instead of relying on a
    # module-level `device` global; fall back to CPU for parameter-less models.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else torch.device("cpu")

    model.eval()
    with torch.no_grad():
        # Preprocess the input text with the "classify:" prompt prefix.
        inputs = tokenizer(
            [f"classify: {text}"],
            max_length=max_length,
            truncation=True,
            padding="max_length",
            return_tensors="pt",
        )
        input_ids = inputs["input_ids"].to(target_device)
        attention_mask = inputs["attention_mask"].to(target_device)

        # Get model predictions (batch of one, so argmax yields a single index).
        logits = model(input_ids=input_ids, attention_mask=attention_mask)
        preds = torch.argmax(logits, dim=-1)

    # Map class index to label
    return label_map[preds.item()]

# Example usage

# Classify a sample headline and print the predicted AG News category.
text = "NASA announces new mission to study asteroids"
result = infer(model, tokenizer, text)
print(f"Predicted category: {result}")