real-jiakai commited on
Commit
cc4e08c
1 Parent(s): cf30567

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +38 -21
README.md CHANGED
@@ -71,6 +71,30 @@ model_path = "real-jiakai/roberta-base-uncased-finetuned-swag"
71
  tokenizer = AutoTokenizer.from_pretrained(model_path)
72
  model = AutoModelForMultipleChoice.from_pretrained(model_path)
73
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74
  # Example scenarios
75
  test_examples = [
76
  {
@@ -93,29 +117,22 @@ test_examples = [
93
  }
94
  ]
95
 
96
- def predict_swag(context, endings, model, tokenizer):
97
- encoding = tokenizer(
98
- [context] * 4,
99
- endings,
100
- truncation=True,
101
- max_length=128,
102
- padding="max_length",
103
- return_tensors="pt"
104
  )
105
 
106
- input_ids = encoding['input_ids'].unsqueeze(0)
107
- attention_mask = encoding['attention_mask'].unsqueeze(0)
108
-
109
- outputs = model(input_ids=input_ids, attention_mask=attention_mask)
110
- logits = outputs.logits
111
-
112
- predicted_idx = torch.argmax(logits).item()
113
-
114
- return {
115
- 'context': context,
116
- 'predicted_ending': endings[predicted_idx],
117
- 'probabilities': torch.softmax(logits, dim=1)[0].tolist()
118
- }
119
  ```
120
 
121
  ## Limitations and Biases
 
71
  tokenizer = AutoTokenizer.from_pretrained(model_path)
72
  model = AutoModelForMultipleChoice.from_pretrained(model_path)
73
 
74
+ def predict_swag(context, endings, model, tokenizer):
75
+ encoding = tokenizer(
76
+ [context] * 4,
77
+ endings,
78
+ truncation=True,
79
+ max_length=128,
80
+ padding="max_length",
81
+ return_tensors="pt"
82
+ )
83
+
84
+ input_ids = encoding['input_ids'].unsqueeze(0)
85
+ attention_mask = encoding['attention_mask'].unsqueeze(0)
86
+
87
+ outputs = model(input_ids=input_ids, attention_mask=attention_mask)
88
+ logits = outputs.logits
89
+
90
+ predicted_idx = torch.argmax(logits).item()
91
+
92
+ return {
93
+ 'context': context,
94
+ 'predicted_ending': endings[predicted_idx],
95
+ 'probabilities': torch.softmax(logits, dim=1)[0].tolist()
96
+ }
97
+
98
  # Example scenarios
99
  test_examples = [
100
  {
 
117
  }
118
  ]
119
 
120
+ for i, example in enumerate(test_examples, 1):
121
+ result = predict_swag(
122
+ example['context'],
123
+ example['endings'],
124
+ model,
125
+ tokenizer
 
 
126
  )
127
 
128
+ print(f"\n=== Test Scenario {i} ===")
129
+ print(f"Initial Context: {result['context']}")
130
+ print(f"\nPredicted Most Likely Ending: {result['predicted_ending']}")
131
+ print("\nProbabilities for All Options:")
132
+ for idx, (ending, prob) in enumerate(zip(result['all_endings'], result['probabilities'])):
133
+ print(f"Option {idx}: {ending}")
134
+ print(f"Probability: {prob:.3f}")
135
+ print("\n" + "="*50)
 
 
 
 
 
136
  ```
137
 
138
  ## Limitations and Biases