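"""Console chat with a Saiga LoRA adapter.

Loads the base causal LM in 8-bit, applies the PEFT adapter from
MODEL_NAME, and runs a simple chat loop; "/reset" clears the history.
"""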
import torch
import logging
from peft import PeftModel, PeftConfig
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
# MODEL_NAME = "IlyaGusev/gigasaiga_lora"
# MODEL_NAME = "evilfreelancer/ruGPT-3.5-13B-lora"
# MODEL_NAME = "./output"
MODEL_NAME = "evilfreelancer/saiga_mistral_7b_128k_lora"
DEFAULT_MESSAGE_TEMPLATE = "<s>{role}\n{content}</s>\n"
DEFAULT_SYSTEM_PROMPT = """
Ты — Saiga 2, русскоязычный автоматический ассистент. Ты разговариваешь с людьми и помогаешь им.
"""
class Conversation:
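    """Holds the chat history and renders it into a Saiga-format prompt."""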
def __init__(
self,
message_template=DEFAULT_MESSAGE_TEMPLATE,
system_prompt=DEFAULT_SYSTEM_PROMPT,
start_token_id=2,
        # Bot token id may be a single int or a list of ints
        bot_token_id=10093,
        # Maximum number of user/bot messages kept in history (None = unlimited)
        history_limit=None,
):
self.logger = logging.getLogger('Conversation')
self.message_template = message_template
self.start_token_id = start_token_id
self.bot_token_id = bot_token_id
self.history_limit = history_limit
self.messages = [{
"role": "system",
"content": system_prompt
}]
def get_start_token_id(self):
return self.start_token_id
def get_bot_token_id(self):
return self.bot_token_id
def add_message(self, role, message):
self.messages.append({
"role": role,
"content": message
})
self.trim_history()
def add_user_message(self, message):
self.add_message("user", message)
def add_bot_message(self, message):
self.add_message("assistant", message)
def trim_history(self):
if self.history_limit is not None and len(self.messages) > self.history_limit + 1:
overflow = len(self.messages) - (self.history_limit + 1)
self.messages = [self.messages[0]] + self.messages[overflow + 1:] # remove old messages except system
def get_prompt(self, tokenizer):
final_text = ""
for message in self.messages:
message_text = self.message_template.format(**message)
final_text += message_text
# Bot token id may be an array
if isinstance(self.bot_token_id, (list, tuple)):
final_text += tokenizer.decode([self.start_token_id] + self.bot_token_id)
else:
final_text += tokenizer.decode([self.start_token_id, self.bot_token_id])
return final_text.strip()
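# Illustrative usage of Conversation (a sketch; assumes `tokenizer` is already
# loaded, as done further below):
#   conv = Conversation(history_limit=4)   # keep at most 4 recent messages
#   conv.add_user_message("Привет!")
#   prompt = conv.get_prompt(tokenizer)    # history plus the opening bot tokens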
def generate(model, tokenizer, prompt, generation_config):
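    """Run generation for a prompt and return only the newly generated text."""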
data = tokenizer(prompt, return_tensors="pt")
data = {k: v.to(model.device) for k, v in data.items()}
output_ids = model.generate(
**data,
generation_config=generation_config
)[0]
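    # Cut off the prompt tokens so only the model's continuation is decoded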
output_ids = output_ids[len(data["input_ids"][0]):]
output = tokenizer.decode(output_ids, skip_special_tokens=True)
return output.strip()
config = PeftConfig.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(
config.base_model_name_or_path,
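    # load_in_8bit requires bitsandbytes; newer transformers versions prefer
    # quantization_config=BitsAndBytesConfig(load_in_8bit=True)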
load_in_8bit=True,
torch_dtype=torch.float16,
device_map="auto",
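    # Needs the flash-attn package; on newer transformers use
    # attn_implementation="flash_attention_2" instead of this flag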
use_flash_attention_2=True,
)
model = PeftModel.from_pretrained(
model,
MODEL_NAME,
torch_dtype=torch.float16
)
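# Evaluation mode: disables dropout for inference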
model.eval()
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=False)
generation_config = GenerationConfig.from_pretrained(MODEL_NAME)
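# Show the sampling settings shipped with the adapter repo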
print(generation_config)
conversation = Conversation()
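# Simple chat REPL: empty input is skipped, "/reset" clears the history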
while True:
user_message = input("User: ")
# Reset chat command
if user_message.strip() == "/reset":
conversation = Conversation()
print("History reset completed!")
continue
# Skip empty messages from user
if user_message.strip() == "":
continue
conversation.add_user_message(user_message)
prompt = conversation.get_prompt(tokenizer)
output = generate(
model=model,
tokenizer=tokenizer,
prompt=prompt,
generation_config=generation_config
)
conversation.add_bot_message(output)
print("Bot:", output)
print()
print("==============================")
print()