jessicayjm commited on
Commit
713cf03
1 Parent(s): f798a31

update example

Browse files
Files changed (1) hide show
  1. README.md +3 -3
README.md CHANGED
@@ -63,8 +63,8 @@ del state_dict['model.embeddings.position_ids']
63
  model.load_state_dict(state_dict)
64
 
65
  # use the model
66
- target = ['My cat died yesterday.']
67
- observer = ['I am sorry for your loss.']
68
 
69
  target_encodings = tokenizer(target, padding=True, truncation=True)
70
  target_input_ids = torch.LongTensor(target_encodings['input_ids']).to('cuda')
@@ -75,5 +75,5 @@ observer_attention_mask = torch.LongTensor(observer_encodings['attention_mask'])
75
 
76
  model.eval()
77
  output = model(target_input_ids, target_attention_mask, observer_input_ids, observer_attention_mask)
78
- print(output) # [0.3304]
79
  ```
 
63
  model.load_state_dict(state_dict)
64
 
65
  # use the model
66
+ target = ["I'm so sad that my cat died yesterday."]
67
+ observer = ["It's ok to feel sad."]
68
 
69
  target_encodings = tokenizer(target, padding=True, truncation=True)
70
  target_input_ids = torch.LongTensor(target_encodings['input_ids']).to('cuda')
 
75
 
76
  model.eval()
77
  output = model(target_input_ids, target_attention_mask, observer_input_ids, observer_attention_mask)
78
+ print(output) # [0.5755]
79
  ```