Vijayendra committed
Commit 4ec4ab6 · verified · 1 Parent(s): 37dcd36

Update README.md

Files changed (1)
  1. README.md +66 -3
README.md CHANGED
@@ -1,3 +1,66 @@
- ---
- license: mit
- ---
+ ---
+ license: mit
+ ---
+ # Import necessary libraries
+ import torch
+ import torch.nn as nn
+ from transformers import T5Tokenizer, T5ForConditionalGeneration
+
+ # Set device
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ # Define the model class (same structure as used during training)
+ class CustomT5Model(nn.Module):
+     def __init__(self):
+         super(CustomT5Model, self).__init__()
+         self.t5 = T5ForConditionalGeneration.from_pretrained("t5-large")
+         self.classifier = nn.Linear(1024, 4)  # 4 classes for AG News
+
+     def forward(self, input_ids, attention_mask=None):
+         encoder_outputs = self.t5.encoder(
+             input_ids=input_ids,
+             attention_mask=attention_mask,
+             return_dict=True
+         )
+         hidden_states = encoder_outputs.last_hidden_state  # (batch_size, seq_len, hidden_dim)
+         logits = self.classifier(hidden_states[:, 0, :])  # Classify from the first token's encoder state (T5 has no [CLS] token)
+         return logits
+
+ # Initialize the model
+ model = CustomT5Model().to(device)
+
+ # Load the saved model weights from Hugging Face
+ model_path = "https://huggingface.co/Vijayendra/T5-large-docClassification/resolve/main/best_model.pth"
+ model.load_state_dict(torch.hub.load_state_dict_from_url(model_path, map_location=device))
+ model.eval()
+
+ # Load the tokenizer
+ tokenizer = T5Tokenizer.from_pretrained("t5-large")
+
+ # Inference function
+ def infer(model, tokenizer, text):
+     model.eval()
+     with torch.no_grad():
+         # Preprocess the input text
+         inputs = tokenizer(
+             [f"classify: {text}"],
+             max_length=99,
+             truncation=True,
+             padding="max_length",
+             return_tensors="pt"
+         )
+         input_ids = inputs["input_ids"].to(device)
+         attention_mask = inputs["attention_mask"].to(device)
+
+         # Get model predictions
+         logits = model(input_ids=input_ids, attention_mask=attention_mask)
+         preds = torch.argmax(logits, dim=-1)
+
+         # Map class index to label
+         label_map = {0: "World", 1: "Sports", 2: "Business", 3: "Sci/Tech"}
+         return label_map[preds.item()]
+
+ # Example usage
+ text = "NASA announces new mission to study asteroids"
+ result = infer(model, tokenizer, text)
+ print(f"Predicted category: {result}")
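The `infer` helper above classifies one text at a time. A minimal sketch of batched inference, assuming the README snippet has already been run so that `model`, `tokenizer`, and `device` are defined and the weights loaded; the `infer_batch` name and the two sample headlines are illustrative, not part of the repository:

```python
import torch

# Sketch only: reuses `model`, `tokenizer`, `device`, and the AG News label map
# from the README snippet above. `infer_batch` is a hypothetical helper name.
def infer_batch(model, tokenizer, texts):
    model.eval()
    with torch.no_grad():
        # Same preprocessing as the single-text `infer`, applied to a list of texts
        inputs = tokenizer(
            [f"classify: {t}" for t in texts],
            max_length=99,
            truncation=True,
            padding="max_length",
            return_tensors="pt",
        )
        logits = model(
            input_ids=inputs["input_ids"].to(device),
            attention_mask=inputs["attention_mask"].to(device),
        )
        preds = torch.argmax(logits, dim=-1)  # one class index per input text
    label_map = {0: "World", 1: "Sports", 2: "Business", 3: "Sci/Tech"}
    return [label_map[p.item()] for p in preds]

# Example usage with two illustrative headlines
print(infer_batch(model, tokenizer, [
    "Stocks rally as markets open higher",
    "Local team wins the championship final",
]))
```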