taka-yamakoshi commited on
Commit
2092dd1
1 Parent(s): ddf537a
Files changed (1) hide show
  1. app.py +3 -3
app.py CHANGED
@@ -54,13 +54,13 @@ if __name__=='__main__':
54
  tokenizer,model = load_model()
55
  mask_id = tokenizer('[MASK]').input_ids[1:-1][0]
56
 
57
- sent_1 = st.sidebar.text_input('Sentence 1',on_change=clear_data)
58
- sent_2 = st.sidebar.text_input('Sentence 2',on_change=clear_data)
59
  input_ids_1 = tokenizer(sent_1).input_ids
60
  input_ids_2 = tokenizer(sent_2).input_ids
61
  input_ids = np.array([input_ids_1,input_ids_2])
62
 
63
- outputs = model(input_ids, interv_type='swap', interv_dict = {0,{'lay':[(8,1,[0,1])]}})
64
  logprobs = jax.nn.log_softmax(outputs.logits, axis = -1)
65
  preds = [np.random.choice(np.arange(len(probs)),p=np.exp(probs)/np.sum(np.exp(probs))) for probs in logprobs[0]]
66
  st.write([tokenizer.decode([token]) for token in preds])
 
54
  tokenizer,model = load_model()
55
  mask_id = tokenizer('[MASK]').input_ids[1:-1][0]
56
 
57
+ sent_1 = st.sidebar.text_input('Sentence 1',value='It is better to play a prank on Samuel than Craig because he gets angry less often.',on_change=clear_data)
58
+ sent_2 = st.sidebar.text_input('Sentence 2',value='It is better to play a prank on Samuel than Craig because he gets angry more often.',on_change=clear_data)
59
  input_ids_1 = tokenizer(sent_1).input_ids
60
  input_ids_2 = tokenizer(sent_2).input_ids
61
  input_ids = np.array([input_ids_1,input_ids_2])
62
 
63
+ outputs = model(input_ids, interv_type='swap', interv_dict = {0:{'lay':[(8,1,[0,1])]}})
64
  logprobs = jax.nn.log_softmax(outputs.logits, axis = -1)
65
  preds = [np.random.choice(np.arange(len(probs)),p=np.exp(probs)/np.sum(np.exp(probs))) for probs in logprobs[0]]
66
  st.write([tokenizer.decode([token]) for token in preds])