Commit c96b952 by dmayhem93
1 parent: 3123a2b

Add StoppingCriteria to example to stop on one assistant response
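
To illustrate what stopping on one assistant response means, here is a minimal sketch that needs no model. A hypothetical `generate_until_stop` helper consumes a scripted token stream and halts as soon as the latest token is one of the stop IDs, which is the kind of check a stopping criterion performs after each generation step. The stop IDs are the ones listed in the diff; the helper name and the sample stream are assumptions made for illustration, and plain integers stand in for model output.

```python
from typing import Iterable, List

# Stop IDs copied from the diff; their meaning depends on the model's tokenizer.
STOP_IDS = frozenset({50278, 50279, 50277, 1, 0})


def generate_until_stop(token_stream: Iterable[int], max_new_tokens: int = 64) -> List[int]:
    """Collect tokens until a stop ID appears or the budget runs out.

    Hypothetical stand-in for a generation loop: the stop token is kept in the
    output, mirroring a criterion that inspects the most recent token.
    """
    generated: List[int] = []
    for token_id in token_stream:
        if len(generated) >= max_new_tokens:
            break
        generated.append(token_id)
        if token_id in STOP_IDS:
            break
    return generated


# Scripted stream: three ordinary tokens, a stop ID, then tokens that
# would belong to a further turn.
stream = [310, 717, 2119, 50278, 1276, 434]
result = generate_until_stop(stream)
print(result)  # [310, 717, 2119, 50278]
```

Without the stop check, the loop would run on into the trailing tokens until `max_new_tokens` is exhausted, which is the behavior the commit aims to avoid.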

Files changed (1)
  1. README.md +11 -2
README.md CHANGED
@@ -24,12 +24,20 @@ datasets:
 Get started chatting with `StableLM-Tuned-Alpha` by using the following code snippet:
 
 ```python
-from transformers import AutoModelForCausalLM, AutoTokenizer
+from transformers import AutoModelForCausalLM, AutoTokenizer, StoppingCriteria, StoppingCriteriaList
 
 tokenizer = AutoTokenizer.from_pretrained("StabilityAI/stablelm-tuned-alpha-7b")
 model = AutoModelForCausalLM.from_pretrained("StabilityAI/stablelm-tuned-alpha-7b")
 model.half().cuda()
 
+class StopOnTokens(StoppingCriteria):
+    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
+        stop_ids = [50278, 50279, 50277, 1, 0]
+        for stop_id in stop_ids:
+            if input_ids[0][-1] == stop_id:
+                return True
+        return False
+
 system_prompt = """<|SYSTEM|># StableLM Tuned (Alpha version)
 - StableLM is a helpful and harmless open-source AI language model developed by StabilityAI.
 - StableLM is excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user.
@@ -37,12 +45,13 @@ system_prompt = """<|SYSTEM|># StableLM Tuned (Alpha version)
 - StableLM will refuse to participate in anything that could harm a human.
 """
 
-prompt = f"{system_prompt}<|USER|>What's your mood today?"
+prompt = f"{system_prompt}<|USER|>What's your mood today<|ASSISTANT|>?"
 
 inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
 tokens = model.generate(
     **inputs,
     max_new_tokens=64,
+    StoppingCriteriaList([StopOnTokens()]))
 print(tokenizer.decode(tokens[0], skip_special_tokens=True))
 ```
 
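
As committed, the added lines appear to have three problems: `torch` is referenced in the type annotations without being imported, `<|ASSISTANT|>` sits before the question mark so the user's turn is split, and `StoppingCriteriaList(...)` is passed positionally after keyword arguments, which Python rejects at parse time (the `generate` parameter for this is `stopping_criteria`). The sketch below checks the last two points without loading the model. `build_prompt` and `is_valid_call` are hypothetical helpers introduced for illustration, and the system prompt is shortened to a placeholder.

```python
def build_prompt(system_prompt: str, user_message: str) -> str:
    """Hypothetical helper: close the user turn before opening the assistant turn."""
    return f"{system_prompt}<|USER|>{user_message}<|ASSISTANT|>"


def is_valid_call(source: str) -> bool:
    """Return True if `source` parses as a Python expression."""
    try:
        compile(source, "<call>", "eval")
        return True
    except SyntaxError:
        return False


# Placeholder system prompt; the real one is the multi-line string in the README.
prompt = build_prompt("<|SYSTEM|>...", "What's your mood today?")
print(prompt.endswith("today?<|ASSISTANT|>"))  # True

# Call shape from the diff: positional argument after keyword arguments.
as_committed = "generate(**inputs, max_new_tokens=64, StoppingCriteriaList([StopOnTokens()]))"
# Same call with the criteria passed by keyword.
by_keyword = "generate(**inputs, max_new_tokens=64, stopping_criteria=StoppingCriteriaList([StopOnTokens()]))"

print(is_valid_call(as_committed))  # False
print(is_valid_call(by_keyword))    # True
```

With the keyword form, the reordered prompt, and an added `import torch`, the README snippet should at least parse and define `StopOnTokens` without a `NameError`; whether the chosen stop IDs are the right ones for this tokenizer is not something this sketch verifies.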