Chi Honolulu committed
Commit ce082fb · 1 Parent(s): dbe41d3

Upload README.md

Files changed (1)
  1. README.md +5 -0
README.md CHANGED
@@ -33,14 +33,18 @@ Here is how to use this model to classify a context-window of a dialogue:
 from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
 import torch
 
+# Target utterance
 test_texts = ['Utterance2']
+# Bi-directional context of the target utterance
 test_text_pairs = ['Utterance1;Utterance2;Utterance3']
 
+# Load the model and tokenizer
 checkpoint_path = "chi2024/mt5-base-binary-cs-iiia"
 model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint_path)\
     .to("cuda" if torch.cuda.is_available() else "cpu")
 tokenizer = AutoTokenizer.from_pretrained(checkpoint_path)
 
+# Define helper functions
 def verbalize_input(text: str, text_pair: str) -> str:
     return "Utterance: %s\nContext: %s" % (text, text_pair)
 
@@ -53,6 +57,7 @@ def predict_one(text, pair):
         tokenizer.batch_decode(outputs, skip_special_tokens=True)]
     return decoded
 
+# Run the prediction
 preds_txt = [predict_one(t,p) for t,p in zip(test_texts, test_text_pairs)]
 preds_lbl = [1 if x == 'positive' else 0 for x in preds_txt]
 print(preds_lbl)
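
For readers skimming the diff: the body of `predict_one` falls between the two hunks and is not shown here, so the snippet above is not self-contained. Below is a minimal sketch of how the annotated example fits together end to end. The tokenization, `generate`, and decoding steps inside `predict_one` are an assumption based on standard `transformers` seq2seq usage, not the repository's exact code; only the prompt format, checkpoint name, and the surrounding lines come from the diff.

```python
# Illustrative sketch; the predict_one body is assumed, not taken from the README.
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

checkpoint_path = "chi2024/mt5-base-binary-cs-iiia"
device = "cuda" if torch.cuda.is_available() else "cpu"
model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint_path).to(device)
tokenizer = AutoTokenizer.from_pretrained(checkpoint_path)


def verbalize_input(text: str, text_pair: str) -> str:
    # Same prompt format as in the README diff
    return "Utterance: %s\nContext: %s" % (text, text_pair)


def predict_one(text: str, pair: str) -> str:
    # Assumed implementation: tokenize the verbalized prompt, generate, decode.
    inputs = tokenizer(verbalize_input(text, pair), return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = model.generate(**inputs, max_new_tokens=5)
    decoded = [d.strip() for d in
               tokenizer.batch_decode(outputs, skip_special_tokens=True)]
    return decoded[0]


# Target utterance and its bi-directional context, as in the diff
test_texts = ['Utterance2']
test_text_pairs = ['Utterance1;Utterance2;Utterance3']

preds_txt = [predict_one(t, p) for t, p in zip(test_texts, test_text_pairs)]
preds_lbl = [1 if x == 'positive' else 0 for x in preds_txt]
print(preds_lbl)  # e.g. [0] or [1], depending on the model's verbalized output
```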