from typing import Dict, List

import fire
import torch
from peft import PeftModel
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    GenerationConfig,
)

MODEL_BASE = "t-tech/T-lite-it-1.0"
MODEL_ADAPTER = "ZeroAgency/o1_t-lite-it-1.0_lora"

# Russian system prompt; in English: "You are an AI assistant. Format your responses as
# follows: <Thought> Your thoughts (understanding, reasoning) </Thought> <output> Your answer </output>"
SYSTEM_PROMPT = "Вы — ИИ-помощник. Отформатируйте свои ответы следующим образом: <Thought> Ваши мысли (понимание, рассуждения) </Thought> <output> Ваш ответ </output>"


class ChatHistory:
    def __init__(self, history_limit: int | None = None, system_prompt: str | None = None):
        self.history_limit: int | None = history_limit
        self.system_prompt: str | None = system_prompt
        self.messages: List[Dict] = []
        if self.system_prompt is not None:
            self.messages.append({"role": "system", "content": self.system_prompt})

    def add_message(self, role: str, message: str):
        self.messages.append({"role": role, "content": message})
        self.trim_history()

    def add_user_message(self, message: str):
        self.add_message("user", message)

    def add_assistant_message(self, message: str):
        self.add_message("assistant", message)

    def add_function_call(self, message: str):
        self.add_message("function_call", message)

    def add_function_response(self, message: str):
        self.add_message("function_response", message)

    def trim_history(self):
        # The system message (if present) occupies one extra slot and is never trimmed.
        appendix = 1 if self.system_prompt is not None else 0
        if self.history_limit is not None and len(self.messages) > self.history_limit + appendix:
            overflow = len(self.messages) - (self.history_limit + appendix)
            self.messages = self.messages[:appendix] + self.messages[appendix + overflow:]

    def get_messages(self) -> List[Dict]:
        return self.messages
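
# Illustration (hypothetical usage): with history_limit=1 and a system prompt, trimming
# keeps only the system message plus the single most recent message:
#   history = ChatHistory(history_limit=1, system_prompt=SYSTEM_PROMPT)
#   history.add_user_message("Привет!")          # -> [system, user]
#   history.add_assistant_message("<Thought>...") # -> [system, assistant]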


def generate(model, tokenizer, prompt, generation_config):
    data = tokenizer(prompt, return_tensors="pt")
    data = {k: v.to(model.device) for k, v in data.items()}
    output_ids = model.generate(**data, generation_config=generation_config)[0]
    # Drop the prompt tokens so that only the newly generated completion is decoded.
    output_ids = output_ids[len(data["input_ids"][0]):]
    output = tokenizer.decode(output_ids, skip_special_tokens=True)
    return output.strip()
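
# Optional sketch: tokens could be streamed to stdout as they are produced by passing a
# transformers.TextStreamer to model.generate inside generate(), e.g.:
#   from transformers import TextStreamer
#   streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
#   model.generate(**data, generation_config=generation_config, streamer=streamer)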


def get_prompt(tokenizer, messages: List[Dict], add_generation_prompt: bool = False):
    return tokenizer.apply_chat_template(
        messages,
        add_special_tokens=False,
        tokenize=False,
        add_generation_prompt=add_generation_prompt,
    )
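
# Note (assumption): T-lite-it-1.0 is Qwen2.5-based, so the chat template typically renders
# ChatML-style turns, roughly "<|im_start|>role\n...<|im_end|>\n" per message plus a trailing
# "<|im_start|>assistant\n" when add_generation_prompt=True; the authoritative template is the
# one shipped with the tokenizer.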


def chat(
    history_limit: int = 1,
    system_prompt: str | None = SYSTEM_PROMPT,
    max_new_tokens: int = 2048,
    repetition_penalty: float = 1.2,
    do_sample: bool = True,
    temperature: float = 0.5,
    top_p: float = 0.6,
    top_k: int = 40,
):
    tokenizer = AutoTokenizer.from_pretrained(MODEL_BASE)

    # Load the base model in 4-bit NF4 with double quantization; compute runs in bfloat16.
    quantization_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True,
    )

    # Start from the adapter's generation settings and override the sampling parameters.
    generation_config = GenerationConfig.from_pretrained(MODEL_ADAPTER)
    generation_config.max_new_tokens = max_new_tokens
    generation_config.repetition_penalty = repetition_penalty
    generation_config.do_sample = do_sample
    generation_config.temperature = temperature
    generation_config.top_p = top_p
    generation_config.top_k = top_k

    model = AutoModelForCausalLM.from_pretrained(
        MODEL_BASE,
        generation_config=generation_config,
        quantization_config=quantization_config,
        torch_dtype=torch.bfloat16,
        attn_implementation=None,
    )

    # Attach the LoRA adapter on top of the quantized base model.
    model = PeftModel.from_pretrained(
        model=model,
        model_id=MODEL_ADAPTER,
        torch_dtype=torch.bfloat16,
    )
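
    # The PeftModel wrapper keeps the 4-bit base weights frozen and applies the LoRA
    # low-rank updates on the fly during the forward pass, so generation below goes
    # through the wrapper unchanged.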

    chat_history = ChatHistory(history_limit, system_prompt)
    while True:
        user_message = input("User: ")

        # "/reset" starts a fresh conversation while keeping the system prompt.
        if user_message.strip() == "/reset":
            chat_history = ChatHistory(history_limit, system_prompt)
            print("History reset completed!")
            continue

        if user_message.strip() == "":
            continue

        chat_history.add_user_message(user_message)
        prompt = get_prompt(tokenizer, chat_history.get_messages(), add_generation_prompt=True)
        output = generate(model, tokenizer, prompt, generation_config)
        chat_history.add_assistant_message(output)
        print("Assistant:", output)


if __name__ == "__main__":
    fire.Fire(chat)
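
# Example invocation (python-fire exposes chat()'s keyword arguments as CLI flags);
# the file name chat.py is an assumption:
#   python chat.py --history_limit=4 --temperature=0.3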