yezdata commited on
Commit
38a732c
·
verified ·
1 Parent(s): 7913eef

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +3 -2
README.md CHANGED
@@ -94,9 +94,10 @@ To obtain probabilistic outputs and uncertainty metrics, use the `mc_forward` me
94
  ```python
95
  # Perform 50 stochastic passes
96
  N_SAMPLES = 50
97
- model.eval()
98
 
99
  inputs = tokenizer("I am so happy you are here!", return_tensors="pt")
 
 
100
  with torch.no_grad():
101
  logits_mc = model.mc_forward(inputs['input_ids'], inputs['attention_mask'], n_samples=N_SAMPLES) # Automatically keeps Dropout active, even when in model.eval
102
 
@@ -120,7 +121,7 @@ for idx in sorted_indices:
120
  prob, unc = m_probs[idx].item(), u_vals[idx].item()
121
  label = model.config.id2label[idx.item()]
122
 
123
- if prob > 0.05: # Print only emotions with prob > 5% (optional for clarity)
124
  print(f"{label:<15} | {prob:>8.2%} | ±{unc:>8.4f}")
125
  ```
126
 
 
94
  ```python
95
  # Perform 50 stochastic passes
96
  N_SAMPLES = 50
 
97
 
98
  inputs = tokenizer("I am so happy you are here!", return_tensors="pt")
99
+
100
+ model.eval()
101
  with torch.no_grad():
102
  logits_mc = model.mc_forward(inputs['input_ids'], inputs['attention_mask'], n_samples=N_SAMPLES) # Automatically keeps Dropout active, even when in model.eval
103
 
 
121
  prob, unc = m_probs[idx].item(), u_vals[idx].item()
122
  label = model.config.id2label[idx.item()]
123
 
124
+ if prob > 0.05: # Print only emotions with prob > 5%
125
  print(f"{label:<15} | {prob:>8.2%} | ±{unc:>8.4f}")
126
  ```
127