Spaces:
Running
Running
RickMartel
commited on
Commit
•
83162a1
1
Parent(s):
cdfd163
Update app.py
Browse files
app.py
CHANGED
@@ -19,10 +19,14 @@ class StoppingCriteriaSub(StoppingCriteria):
|
|
19 |
self.encounters = encounters
|
20 |
|
21 |
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
|
|
|
|
|
22 |
for stop in self.stops:
|
23 |
-
if sum( input_ids[0] == stop ) >= self.encounters:
|
24 |
-
|
|
|
25 |
|
|
|
26 |
stop_words = ['.']
|
27 |
stop_words_ids = [tokenizer(stop_word, return_tensors='pt', add_special_tokens=False)['input_ids'].squeeze() for stop_word in stop_words]
|
28 |
stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids,
|
|
|
19 |
self.encounters = encounters
|
20 |
|
21 |
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
|
22 |
+
last_tkn = input_ids[0][-1]
|
23 |
+
stop_word_found = False
|
24 |
for stop in self.stops:
|
25 |
+
if sum( input_ids[0] == stop ) >= self.encounters:
|
26 |
+
stop_word_found = True
|
27 |
+
return stop_word_found and self.stops[0] == last_tkn
|
28 |
|
29 |
+
# The StoppingCriteriaSub assumes period is the first token id.
|
30 |
stop_words = ['.']
|
31 |
stop_words_ids = [tokenizer(stop_word, return_tensors='pt', add_special_tokens=False)['input_ids'].squeeze() for stop_word in stop_words]
|
32 |
stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids,
|