chaenykim commited on
Commit
67181a9
1 Parent(s): 15df9b5

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +21 -0
README.md CHANGED
@@ -1,3 +1,24 @@
1
  ---
2
  license: mit
3
  ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
  license: mit
3
  ---
4
+ ## Usage
5
+
6
+ ```python
7
+ import torch
8
+ from informer_models import InformerConfig, InformerForSequenceClassification
9
+
10
+ # Loading the model
11
+ model = InformerForSequenceClassification.from_pretrained("BrachioLab/supernova-classification")
12
+ model.to(device)
13
+ model.eval()
14
+ y_true = []
15
+ y_pred = []
16
+ for i, batch in enumerate(test_dataloader):
17
+ print(f"processing batch {i}")
18
+ batch = {k: v.to(device) for k, v in batch.items() if k != "objid"}
19
+ with torch.no_grad():
20
+ outputs = model(**batch)
21
+ y_true.extend(batch['labels'].cpu().numpy())
22
+ y_pred.extend(torch.argmax(outputs.logits, dim=2).squeeze().cpu().numpy())
23
+ print(f"accuracy: {sum([1 for i, j in zip(y_true, y_pred) if i == j]) / len(y_true)}")
24
+ ```