nxphi47 committed on
Commit
6544b41
1 Parent(s): f4b3d1c

Update multipurpose_chatbot/engines/transformers_engine.py

Browse files
multipurpose_chatbot/engines/transformers_engine.py CHANGED
@@ -429,6 +429,7 @@ class TransformersEngine(BaseEngine):
429
 
430
  # ! MUST PUT INSIDE torch.no_grad() otherwise it will overflow OOM
431
  import sys
 
432
  with torch.no_grad():
433
  inputs = self.tokenizer(prompt, return_tensors='pt')
434
  num_tokens = inputs.input_ids.size(1)
@@ -448,12 +449,16 @@ class TransformersEngine(BaseEngine):
448
  for token in generator:
449
  out_tokens.extend(token.tolist())
450
  response = self.tokenizer.decode(out_tokens)
 
 
451
  num_tokens += 1
452
  print(f"{response}", end='\r')
453
  sys.stdout.flush()
454
  yield response, num_tokens
455
 
456
  if response is not None:
 
 
457
  full_text = prompt + response
458
  num_tokens = len(self.tokenizer.encode(full_text))
459
  yield response, num_tokens
 
429
 
430
  # ! MUST PUT INSIDE torch.no_grad() otherwise it will overflow OOM
431
  import sys
432
+ self._model._sample = types.MethodType(NewGenerationMixin.sample_stream, self._model)
433
  with torch.no_grad():
434
  inputs = self.tokenizer(prompt, return_tensors='pt')
435
  num_tokens = inputs.input_ids.size(1)
 
449
  for token in generator:
450
  out_tokens.extend(token.tolist())
451
  response = self.tokenizer.decode(out_tokens)
452
+ if "<|im_start|>assistant\n" in response:
453
+ response = response.split("<|im_start|>assistant\n")
454
  num_tokens += 1
455
  print(f"{response}", end='\r')
456
  sys.stdout.flush()
457
  yield response, num_tokens
458
 
459
  if response is not None:
460
+ if "<|im_start|>assistant\n" in response:
461
+ response = response.split("<|im_start|>assistant\n")
462
  full_text = prompt + response
463
  num_tokens = len(self.tokenizer.encode(full_text))
464
  yield response, num_tokens