samyak152002 commited on
Commit
a7fc3b1
1 Parent(s): 4946849

Update script

Browse files
Files changed (1) hide show
  1. script +33 -17
script CHANGED
@@ -1,26 +1,42 @@
1
  import torch
2
- from transformers import DistilBertTokenizer, DistilBertModel
3
 
4
  # Load the tokenizer and model
5
- tokenizer = DistilBertTokenizer.from_pretrained("tokenizer_config.json")
6
- model = DistilBertModel.from_pretrained("model.pt")
7
 
8
- # Define the inference function
9
- def predict(text):
10
- # Tokenize the input
11
- inputs = tokenizer(text, padding="max_length", truncation=True, return_tensors="pt")
12
 
13
- # Perform the inference
14
- with torch.no_grad():
15
- outputs = model(**inputs)
16
- logits = outputs.logits
17
 
18
- # Convert logits to probabilities
19
- probabilities = torch.softmax(logits, dim=1).squeeze().tolist()
20
 
21
- return probabilities
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
 
23
  # Example usage
24
- text = "This is a sample input."
25
- probabilities = predict(text)
26
- print(probabilities)
 
1
  import torch
2
+ from transformers import DistilBertModel, DistilBertTokenizer
3
 
4
  # Load the tokenizer and model
5
+ tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
6
+ model = DistilBertCNN(num_labels=3) # Assuming you have defined the custom classification layers
7
 
8
+ # Move the model to CPU
9
+ device = torch.device("cpu")
10
+ model.to(device)
 
11
 
12
+ # Load the saved model state dictionary
13
+ model.load_state_dict(torch.load("path/to/save/directory/model.pt", map_location=device))
 
 
14
 
15
+ # Set the model to evaluation mode
16
+ model.eval()
17
 
18
+ # Define a function to predict the class of a given tweet
19
+ def classify_tweet(tweet):
20
+ inputs = tokenizer.encode_plus(
21
+ tweet,
22
+ add_special_tokens=True,
23
+ max_length=128,
24
+ padding="max_length",
25
+ truncation=True,
26
+ return_tensors="pt"
27
+ )
28
+ input_ids = inputs["input_ids"].to(device)
29
+ attention_mask = inputs["attention_mask"].to(device)
30
+
31
+ with torch.no_grad():
32
+ outputs = model(input_ids=input_ids, attention_mask=attention_mask)
33
+
34
+ logits = outputs[0]
35
+ predicted_class = torch.argmax(logits).item()
36
+
37
+ return predicted_class
38
 
39
  # Example usage
40
+ tweet = "This is a sample tweet."
41
+ predicted_class = classify_tweet(tweet)
42
+ print(f"Predicted Class: {predicted_class}")