awacke1 commited on
Commit
44f49c7
1 Parent(s): 628cba2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -4
app.py CHANGED
@@ -60,10 +60,10 @@ model = BlenderbotForConditionalGeneration.from_pretrained(mname)
60
  tokenizer = BlenderbotTokenizer.from_pretrained(mname)
61
 
62
  def take_last_tokens(inputs, note_history, history):
63
- """Filter the last 128 tokens"""
64
- if inputs['input_ids'].shape[1] > 128:
65
- inputs['input_ids'] = torch.tensor([inputs['input_ids'][0][-128:].tolist()])
66
- inputs['attention_mask'] = torch.tensor([inputs['attention_mask'][0][-128:].tolist()])
67
  note_history = ['</s> <s>'.join(note_history[0].split('</s> <s>')[2:])]
68
  history = history[1:]
69
  return inputs, note_history, history
 
60
  tokenizer = BlenderbotTokenizer.from_pretrained(mname)
61
 
62
  def take_last_tokens(inputs, note_history, history):
63
+ filterTokenCount = 128 # filter last 128 tokens
64
+ if inputs['input_ids'].shape[1] > filterTokenCount:
65
+ inputs['input_ids'] = torch.tensor([inputs['input_ids'][0][-filterTokenCount:].tolist()])
66
+ inputs['attention_mask'] = torch.tensor([inputs['attention_mask'][0][-filterTokenCount:].tolist()])
67
  note_history = ['</s> <s>'.join(note_history[0].split('</s> <s>')[2:])]
68
  history = history[1:]
69
  return inputs, note_history, history