Paul Rock committed
Commit f30df32
1 Parent(s): ce808c8

Example added

Files changed (3)
  1. .gitignore +1 -0
  2. README.md +4 -1
  3. test_lora.py +135 -0
.gitignore ADDED
@@ -0,0 +1 @@
+ /.idea/
README.md CHANGED
@@ -19,4 +19,7 @@ tags:
 
  # Saiga/Yarn-Mistral 7B 128k, Russian Mistral-based chatbot
 
- Welcome to the adapter-only version of Saiga 7B LoRA. This model is built upon the foundation of [Nous-Yarn-Mistral-7b-128k](https://huggingface.co/NousResearch/Yarn-Mistral-7b-128k).
+ Welcome to the adapter-only version of Saiga 7B LoRA.
+ This model is built upon the foundation of [Nous-Yarn-Mistral-7b-128k](https://huggingface.co/NousResearch/Yarn-Mistral-7b-128k).
+
+
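Since the repo ships only the adapter weights, inference needs the base checkpoint as well. Below is a minimal load sketch, condensed from the test_lora.py added in this commit (8-bit quantization and the generation loop left out for brevity):

import torch
from peft import PeftConfig, PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

ADAPTER = "evilfreelancer/saiga_mistral_7b_128k_lora"

# adapter_config.json records the base checkpoint (Nous-Yarn-Mistral-7b-128k)
config = PeftConfig.from_pretrained(ADAPTER)
base = AutoModelForCausalLM.from_pretrained(
    config.base_model_name_or_path,
    torch_dtype=torch.float16,
    device_map="auto",
)
# Lay the LoRA weights over the frozen base model
model = PeftModel.from_pretrained(base, ADAPTER)
tokenizer = AutoTokenizer.from_pretrained(ADAPTER, use_fast=False)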
test_lora.py ADDED
@@ -0,0 +1,135 @@
+ import torch
+ import logging
+ from peft import PeftModel, PeftConfig
+ from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
+
+ # MODEL_NAME = "IlyaGusev/gigasaiga_lora"
+ # MODEL_NAME = "evilfreelancer/ruGPT-3.5-13B-lora"
+ # MODEL_NAME = "./output"
+ MODEL_NAME = "evilfreelancer/saiga_mistral_7b_128k_lora"
+ DEFAULT_MESSAGE_TEMPLATE = "<s>{role}\n{content}</s>\n"
+ DEFAULT_SYSTEM_PROMPT = """
+ Ты — Saiga 2, русскоязычный автоматический ассистент. Ты разговариваешь с людьми и помогаешь им.
+ """  # EN: "You are Saiga 2, a Russian-language automatic assistant. You talk to people and help them."
+
+
+ class Conversation:
+     def __init__(
+         self,
+         message_template=DEFAULT_MESSAGE_TEMPLATE,
+         system_prompt=DEFAULT_SYSTEM_PROMPT,
+         start_token_id=2,
+         # Bot token may be a list or a single int
+         bot_token_id=10093,
+         # int (number of question/answer messages to keep) or None (unlimited)
+         history_limit=None,
+     ):
+         self.logger = logging.getLogger('Conversation')
+         self.message_template = message_template
+         self.start_token_id = start_token_id
+         self.bot_token_id = bot_token_id
+         self.history_limit = history_limit
+         self.messages = [{
+             "role": "system",
+             "content": system_prompt
+         }]
+
+     def get_start_token_id(self):
+         return self.start_token_id
+
+     def get_bot_token_id(self):
+         return self.bot_token_id
+
+     def add_message(self, role, message):
+         self.messages.append({
+             "role": role,
+             "content": message
+         })
+         self.trim_history()
+
+     def add_user_message(self, message):
+         self.add_message("user", message)
+
+     def add_bot_message(self, message):
+         self.add_message("assistant", message)
+
+     def trim_history(self):
+         if self.history_limit is not None and len(self.messages) > self.history_limit + 1:
+             overflow = len(self.messages) - (self.history_limit + 1)
+             self.messages = [self.messages[0]] + self.messages[overflow + 1:]  # drop the oldest messages, keep the system prompt
+
+     def get_prompt(self, tokenizer):
+         final_text = ""
+         # print(self.messages)
+         for message in self.messages:
+             message_text = self.message_template.format(**message)
+             final_text += message_text
+
+         # Open the bot turn; bot token id may be a list
+         if isinstance(self.bot_token_id, (list, tuple)):
+             final_text += tokenizer.decode([self.start_token_id] + list(self.bot_token_id))
+         else:
+             final_text += tokenizer.decode([self.start_token_id, self.bot_token_id])
+
+         return final_text.strip()
+
+
+ def generate(model, tokenizer, prompt, generation_config):
+     data = tokenizer(prompt, return_tensors="pt")
+     data = {k: v.to(model.device) for k, v in data.items()}
+     output_ids = model.generate(
+         **data,
+         generation_config=generation_config
+     )[0]
+     output_ids = output_ids[len(data["input_ids"][0]):]  # keep only the newly generated tokens
+     output = tokenizer.decode(output_ids, skip_special_tokens=True)
+     return output.strip()
+
+
+ config = PeftConfig.from_pretrained(MODEL_NAME)
+ model = AutoModelForCausalLM.from_pretrained(
+     config.base_model_name_or_path,
+     load_in_8bit=True,
+     torch_dtype=torch.float16,
+     device_map="auto",
+     use_flash_attention_2=True,
+ )
+ model = PeftModel.from_pretrained(
+     model,
+     MODEL_NAME,
+     torch_dtype=torch.float16
+ )
+ model.eval()
+
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=False)
+ generation_config = GenerationConfig.from_pretrained(MODEL_NAME)
+ print(generation_config)
+
+ # template_path = 'internal_prompts/rugpt35.json'  # unused leftover from the ruGPT-3.5 setup
+ conversation = Conversation()
+ while True:
+     user_message = input("User: ")
+
+     # Reset chat command
+     if user_message.strip() == "/reset":
+         conversation = Conversation()
+         print("History reset completed!")
+         continue
+
+     # Skip empty messages from the user
+     if user_message.strip() == "":
+         continue
+
+     conversation.add_user_message(user_message)
+     prompt = conversation.get_prompt(tokenizer)
+     output = generate(
+         model=model,
+         tokenizer=tokenizer,
+         prompt=prompt,
+         generation_config=generation_config
+     )
+     conversation.add_bot_message(output)
+     print("Bot:", output)
+     print()
+     print("==============================")
+     print()
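Running the script starts an interactive loop: each "User:" line is appended to the history, "/reset" drops everything except the system prompt, and empty input is skipped. For reference, a self-contained sketch of the prompt layout Conversation.get_prompt assembles (the bot-turn opener that get_prompt appends via tokenizer.decode is omitted here):

# Illustration only; the template string is copied from test_lora.py above.
template = "<s>{role}\n{content}</s>\n"
messages = [
    {"role": "system", "content": "You are Saiga 2..."},
    {"role": "user", "content": "Привет!"},
]
print("".join(template.format(**m) for m in messages))
# <s>system
# You are Saiga 2...</s>
# <s>user
# Привет!</s>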