Martijn van Beers commited on
Commit
315cc6b
1 Parent(s): c59b0ef

Update to use the debiased model

Browse files
Files changed (1) hide show
  1. app.py +2 -2
app.py CHANGED
@@ -94,8 +94,8 @@ else:
94
 
95
  model_id = "gpt2"
96
  model_gpt = GPT2LMHeadModel.from_pretrained(model_id).to(device)
97
- #model_custom = torch.load("./gpt2_attn_heads_dm_top10_seed_1.pt")
98
- model_custom = GPT2LMHeadModel.from_pretrained("gpt2-large").to(device)
99
  tokenizer = GPT2TokenizerFast.from_pretrained(model_id)
100
  dataset = CrowSPairsDataset()
101
 
 
94
 
95
  model_id = "gpt2"
96
  model_gpt = GPT2LMHeadModel.from_pretrained(model_id).to(device)
97
+ model_custom = GPT2LMHeadModel.from_pretrained("iabhijith/GPT2-small-debiased").to(device)
98
+
99
  tokenizer = GPT2TokenizerFast.from_pretrained(model_id)
100
  dataset = CrowSPairsDataset()
101