RickMartel commited on
Commit
83162a1
1 Parent(s): cdfd163

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -2
app.py CHANGED
@@ -19,10 +19,14 @@ class StoppingCriteriaSub(StoppingCriteria):
19
  self.encounters = encounters
20
 
21
  def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
 
 
22
  for stop in self.stops:
23
- if sum( input_ids[0] == stop ) >= self.encounters: return True
24
- return False
 
25
 
 
26
  stop_words = ['.']
27
  stop_words_ids = [tokenizer(stop_word, return_tensors='pt', add_special_tokens=False)['input_ids'].squeeze() for stop_word in stop_words]
28
  stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids,
 
19
  self.encounters = encounters
20
 
21
  def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
22
+ last_tkn = input_ids[0][-1]
23
+ stop_word_found = False
24
  for stop in self.stops:
25
+ if sum( input_ids[0] == stop ) >= self.encounters:
26
+ stop_word_found = True
27
+ return stop_word_found and self.stops[0] == last_tkn
28
 
29
+ # The StoppingCriteriaSub assumes period is the first token id.
30
  stop_words = ['.']
31
  stop_words_ids = [tokenizer(stop_word, return_tensors='pt', add_special_tokens=False)['input_ids'].squeeze() for stop_word in stop_words]
32
  stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids,