nxphi47 commited on
Commit
3eeebb2
1 Parent(s): fb10dcf

Update multipurpose_chatbot/engines/transformers_engine.py

Browse files
multipurpose_chatbot/engines/transformers_engine.py CHANGED
@@ -397,6 +397,109 @@ class NewGenerationMixin(GenerationMixin):
397
 
398
 
399
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
400
  class TransformersEngine(BaseEngine):
401
  @property
402
  def max_position_embeddings(self) -> int:
@@ -424,6 +527,18 @@ class TransformersEngine(BaseEngine):
424
  print(self._model)
425
  print(f"{self.max_position_embeddings=}")
426
 
 
 
 
 
 
 
 
 
 
 
 
 
427
  @spaces.GPU
428
  def generate_yield_string(self, prompt, temperature, max_tokens, stop_strings: Optional[Tuple[str]] = None, **kwargs):
429
 
@@ -431,6 +546,9 @@ class TransformersEngine(BaseEngine):
431
  import sys
432
  # self._model._sample = types.MethodType(NewGenerationMixin.sample_stream, self._model)
433
  self._model.sample = types.MethodType(NewGenerationMixin.sample_stream, self._model)
 
 
 
434
  with torch.no_grad():
435
  inputs = self.tokenizer(prompt, return_tensors='pt')
436
  num_tokens = inputs.input_ids.size(1)
@@ -447,7 +565,7 @@ class TransformersEngine(BaseEngine):
447
 
448
  out_tokens = []
449
  response = None
450
- for token in generator:
451
  out_tokens.extend(token.tolist())
452
  response = self.tokenizer.decode(out_tokens)
453
  if "<|im_start|>assistant\n" in response:
@@ -455,11 +573,15 @@ class TransformersEngine(BaseEngine):
455
  num_tokens += 1
456
  # print(f"{response}", end='\r')
457
  # sys.stdout.flush()
 
458
  yield response, num_tokens
459
-
 
460
  if response is not None:
461
  if "<|im_start|>assistant\n" in response:
462
  response = response.split("<|im_start|>assistant\n")[-1]
 
 
463
  full_text = prompt + response
464
  num_tokens = len(self.tokenizer.encode(full_text))
465
  yield response, num_tokens
 
397
 
398
 
399
 
400
+
401
+ from ..configs import (
402
+ STREAM_CHECK_MULTIPLE,
403
+ STREAM_YIELD_MULTIPLE,
404
+ )
405
+
406
+
407
+ BLOCK_LANGS = str(os.environ.get("BLOCK_LANGS", ""))
408
+ BLOCK_LANGS = [x.strip() for x in BLOCK_LANGS.strip().split(";")] if len(BLOCK_LANGS.strip()) > 0 else []
409
+ LANG_BLOCK_HISTORY = bool(int(os.environ.get("LANG_BLOCK_HISTORY", "0")))
410
+ KEYWORDS = os.environ.get("KEYWORDS", "").strip()
411
+ KEYWORDS = KEYWORDS.split(";") if len(KEYWORDS) > 0 else []
412
+ KEYWORDS = [x.lower() for x in KEYWORDS]
413
+
414
+ LANG_BLOCK_MESSAGE = """Unsupported language."""
415
+
416
+ KEYWORD_BLOCK_MESSAGE = "Invalid request."
417
+
418
+
419
+ def _detect_lang(text):
420
+ # Disable language that may have safety risk
421
+ from langdetect import detect as detect_lang
422
+ dlang = None
423
+ try:
424
+ dlang = detect_lang(text)
425
+ except Exception as e:
426
+ if "No features in text." in str(e):
427
+ return "en"
428
+ else:
429
+ return "zh"
430
+ return dlang
431
+
432
+
433
+ def block_lang(
434
+ message: str,
435
+ history: List[Tuple[str, str]] = None,
436
+ ) -> str:
437
+ # relieve history base block
438
+ if len(BLOCK_LANGS) == 0:
439
+ return False
440
+
441
+ if LANG_BLOCK_HISTORY and history is not None and any((LANG_BLOCK_MESSAGE in x[1].strip()) for x in history):
442
+ return True
443
+ else:
444
+ _lang = _detect_lang(message)
445
+ if _lang in BLOCK_LANGS:
446
+ # print(f'Detect blocked {_lang}: {message}')
447
+ return True
448
+ else:
449
+ return False
450
+
451
+ def safety_check(text, history=None, ) -> Optional[str]:
452
+ """
453
+ Despite our effort in safety tuning and red teaming, our models may still generate harmful or illegal content.
454
+ This provides an additional security measure to enhance safety and compliance with local regulations.
455
+ """
456
+ if len(KEYWORDS) > 0 and any(x in text.lower() for x in KEYWORDS):
457
+ return KEYWORD_BLOCK_MESSAGE
458
+
459
+ if len(BLOCK_LANGS) > 0:
460
+ if block_lang(text, history):
461
+ return LANG_BLOCK_MESSAGE
462
+
463
+ return None
464
+
465
+
466
+ def safety_check_conversation_string(text, delimiter=None) -> Optional[str]:
467
+ if len(KEYWORDS) > 0 and any(x in text.lower() for x in KEYWORDS):
468
+ return KEYWORD_BLOCK_MESSAGE
469
+ if len(BLOCK_LANGS) > 0:
470
+ import re
471
+ delimiter = delimiter or (r"</s><\|im_start\|>user\n", r"</s><\|im_start\|>assistant\n", r"<\|im_start\|>system\n")
472
+ turns = re.split(r"|".join(delimiter), text)
473
+ turns = [t for t in turns if t.strip() != '']
474
+ for t in turns:
475
+ if block_lang(t):
476
+ return LANG_BLOCK_MESSAGE
477
+ return None
478
+
479
+
480
+ def is_check_safety():
481
+ return len(KEYWORDS) > 0 or len(BLOCK_LANGS) > 0
482
+
483
+
484
+ def safety_check_conversation(conversation) -> Optional[str]:
485
+ """
486
+ Despite our effort in safety tuning and red teaming, our models may still generate harmful or illegal content.
487
+ This provides an additional security measure to enhance safety and compliance with local regulations.
488
+ """
489
+ texts = [c['content'] for c in conversation]
490
+ for text in texts:
491
+ if len(KEYWORDS) > 0 and any(x in text.lower() for x in KEYWORDS):
492
+ return KEYWORD_BLOCK_MESSAGE
493
+
494
+ if len(BLOCK_LANGS) > 0:
495
+ if block_lang(text):
496
+ return LANG_BLOCK_MESSAGE
497
+ return None
498
+
499
+
500
+
501
+
502
+
503
  class TransformersEngine(BaseEngine):
504
  @property
505
  def max_position_embeddings(self) -> int:
 
527
  print(self._model)
528
  print(f"{self.max_position_embeddings=}")
529
 
530
+ def maybe_raise_safety(self, message, gen_index=-1):
531
+ if is_check_safety():
532
+ if gen_index < 0:
533
+ message_safety = safety_check_conversation_string(message)
534
+ if message_safety is not None:
535
+ raise gr.Error(message_safety)
536
+ else:
537
+ if STREAM_CHECK_MULTIPLE > 0 and gen_index % STREAM_CHECK_MULTIPLE == 0:
538
+ message_safety = safety_check_conversation_string(message)
539
+ if message_safety is not None:
540
+ raise gr.Error(message_safety)
541
+
542
  @spaces.GPU
543
  def generate_yield_string(self, prompt, temperature, max_tokens, stop_strings: Optional[Tuple[str]] = None, **kwargs):
544
 
 
546
  import sys
547
  # self._model._sample = types.MethodType(NewGenerationMixin.sample_stream, self._model)
548
  self._model.sample = types.MethodType(NewGenerationMixin.sample_stream, self._model)
549
+
550
+ self.maybe_raise_safety(prompt)
551
+
552
  with torch.no_grad():
553
  inputs = self.tokenizer(prompt, return_tensors='pt')
554
  num_tokens = inputs.input_ids.size(1)
 
565
 
566
  out_tokens = []
567
  response = None
568
+ for index, token in enumerate(generator):
569
  out_tokens.extend(token.tolist())
570
  response = self.tokenizer.decode(out_tokens)
571
  if "<|im_start|>assistant\n" in response:
 
573
  num_tokens += 1
574
  # print(f"{response}", end='\r')
575
  # sys.stdout.flush()
576
+ self.maybe_raise_safety(response, gen_index=index)
577
  yield response, num_tokens
578
+
579
+ del generator
580
  if response is not None:
581
  if "<|im_start|>assistant\n" in response:
582
  response = response.split("<|im_start|>assistant\n")[-1]
583
+
584
+ self.maybe_raise_safety(response)
585
  full_text = prompt + response
586
  num_tokens = len(self.tokenizer.encode(full_text))
587
  yield response, num_tokens