## Usage
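The snippet below loads the tokenizer and model, registers keyword-based stopping criteria so generation halts at the dialogue markers, and runs one round of inference. Note that loading with `device_map="auto"` requires the `accelerate` package to be installed.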
```python
import torch
from transformers import LlamaForCausalLM, LlamaTokenizer, GenerationConfig
from transformers import StoppingCriteria, StoppingCriteriaList

# Load the tokenizer and model
pretrained_model_name_or_path = "wenge-research/yayi-7b-llama2"
tokenizer = LlamaTokenizer.from_pretrained(pretrained_model_name_or_path)
model = LlamaForCausalLM.from_pretrained(pretrained_model_name_or_path, device_map="auto", torch_dtype=torch.bfloat16, trust_remote_code=False)

# Define the stopping criteria: stop as soon as the most recently
# generated token is one of the keyword ids
class KeywordsStoppingCriteria(StoppingCriteria):
    def __init__(self, keywords_ids: list):
        self.keywords = keywords_ids

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        if input_ids[0][-1] in self.keywords:
            return True
        return False

stop_words = ["<|End|>", "<|YaYi|>", "<|Human|>", "</s>"]
stop_ids = [tokenizer.encode(w)[-1] for w in stop_words]
stop_criteria = KeywordsStoppingCriteria(stop_ids)

# Inference
prompt = "你是谁?"  # "Who are you?"
formatted_prompt = f"""<|System|>:
You are a helpful, respectful and honest assistant named YaYi developed by Beijing Wenge Technology Co.,Ltd. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.\n\nIf a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.

<|Human|>:
{prompt}

<|YaYi|>:
"""

inputs = tokenizer(formatted_prompt, return_tensors="pt").to(model.device)
# Take the last id so the BOS token the Llama tokenizer prepends is skipped
eos_token_id = tokenizer("<|End|>").input_ids[-1]
generation_config = GenerationConfig(
    eos_token_id=eos_token_id,
    pad_token_id=eos_token_id,
    do_sample=True,
    max_new_tokens=256,
    temperature=0.3,
    repetition_penalty=1.1,
    no_repeat_ngram_size=0
)
response = model.generate(**inputs, generation_config=generation_config, stopping_criteria=StoppingCriteriaList([stop_criteria]))
# Drop the prompt tokens and decode only the newly generated part
response = [response[0][len(inputs.input_ids[0]):]]
response_str = tokenizer.batch_decode(response, skip_special_tokens=False, clean_up_tokenization_spaces=False)[0]
print(response_str)
```
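Because the snippet decodes with `skip_special_tokens=False`, the printed text may still end with a stop marker such as `<|End|>`. A minimal post-processing sketch (not part of the original card) that strips the markers, continuing from the variables defined above:

```python
# Hypothetical cleanup step: remove the dialogue/stop markers from the
# decoded text before displaying it. Reuses stop_words from the snippet above.
for stop_word in stop_words:
    response_str = response_str.replace(stop_word, "")
print(response_str.strip())
```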