krum-utsav commited on
Commit
f850893
1 Parent(s): 8236b58

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +53 -0
README.md CHANGED
@@ -38,6 +38,59 @@ paraphraser.paraphrase("Hey, can yuo hepl me cancel my last order?", tone="witty
38
  # "Hey, I need your help with my last order. Can you wave your magic wand and make it disappear?"
39
  ```
40
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
  ## Sample training data
42
 
43
  ```json
 
38
  # "Hey, I need your help with my last order. Can you wave your magic wand and make it disappear?"
39
  ```
40
 
41
+ Or use the model directly with `transformers`:
42
+
43
+ ```python
44
+ import torch
+ from transformers import AutoModelForCausalLM, AutoTokenizer, StoppingCriteria, StoppingCriteriaList
45
+
46
+
47
+ DEVICE = "cuda"
48
+ EOC_FORMAT = "\n\n### END"
49
+
50
+
51
+ class StoppingCriteriaSub(StoppingCriteria):
52
+ """Helps in stopping the generation when a certain sequence of tokens is generated."""
53
+
54
+ def __init__(self, stops: list = []):
55
+ super().__init__()
56
+ self.stops = stops
57
+
58
+ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> bool:
59
+ return input_ids[0][-len(self.stops) :].tolist() == self.stops
60
+
61
+
62
+ stopping_criteria = StoppingCriteriaList(
63
+ [StoppingCriteriaSub(stops=tokenizer(EOC_FORMAT)["input_ids"])]
64
+ )
65
+
66
+
67
+ def predict(input_text: str) -> None:
68
+ tokenized = tokenizer(
69
+ input_text,
70
+ max_length=max_length,
71
+ padding=True,
72
+ truncation=True,
73
+ return_tensors="pt",
74
+ )
75
+
76
+ with torch.no_grad():
77
+ out = model.generate(
78
+ input_ids=tokenized["input_ids"].to(DEVICE),
79
+ attention_mask=tokenized["attention_mask"].to(DEVICE),
80
+ pad_token_id=tokenizer.eos_token_id,
81
+ max_new_tokens=max_new_tokens,
82
+ num_return_sequences=num_return_sequences,
83
+ do_sample=True,
84
+ temperature=temperature,
85
+ top_p=top_p,
86
+ stopping_criteria=stopping_criteria,
87
+ )
88
+
89
+ out_texts = [tokenizer.decode(o, skip_special_tokens=True) for o in out]
90
+ for o in out_texts:
91
+ print(o)
92
+ ```
93
+
94
  ## Sample training data
95
 
96
  ```json