DavidGF committed
Commit 2279d36 · 1 Parent(s): f69f387

Upload 3 files

kraken_model/configuration_kraken.py ADDED
@@ -0,0 +1,8 @@
+ from transformers import PretrainedConfig
+
+ class KrakenConfig(PretrainedConfig):
+     model_type = "kraken"
+
+     def __init__(self, config_dict=None, **kwargs):
+         super().__init__(**kwargs)
+         self.config_dict = config_dict or {}
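KrakenConfig is only a thin wrapper around a plain config_dict; the modeling code in the next file reads the keys "tokenizers", "models", "quantization", "router", and "class_indices" from it. A minimal sketch of constructing one follows; every checkpoint name and label key below is an illustrative placeholder, not part of this commit.

    # Hypothetical config_dict: the keys match what modeling_kraken.py reads,
    # but the checkpoint names and router labels are placeholders.
    config = KrakenConfig(config_dict={
        "tokenizers": {"expert1": "org/expert-1", "expert2": "org/expert-2"},
        "models": {"expert1": "org/expert-1", "expert2": "org/expert-2"},
        "quantization": {"expert1": None, "expert2": "4bit"},  # "8bit", "4bit", "awq", or None
        "router": "org/router-classifier",
        "class_indices": {"LABEL_0": 0, "LABEL_1": 1},  # router label -> index into expert1..expert5
    })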
kraken_model/modeling_kraken.py ADDED
@@ -0,0 +1,82 @@
+ import torch
+ from transformers import PreTrainedModel, AutoTokenizer, AutoModelForCausalLM, AutoModelForSequenceClassification, TextClassificationPipeline
+ from configuration_kraken import KrakenConfig
+ import tokenizer_template_switch
+
+ class KrakenForCausalLM(PreTrainedModel):
+     config_class = KrakenConfig
+
+     def __init__(self, config):
+         super().__init__(config)
+         self.tokenizers = {key: AutoTokenizer.from_pretrained(name, device_map="auto") for key, name in config.config_dict['tokenizers'].items()}
+         self.models = self.load_expert_models(config.config_dict['models'], config.config_dict['quantization'])
+         self.router_model = AutoModelForSequenceClassification.from_pretrained(config.config_dict['router'], trust_remote_code=True, device_map="auto")
+         self.tokenizer = AutoTokenizer.from_pretrained(config.config_dict['router'], trust_remote_code=True, device_map="auto")
+         self.router = TextClassificationPipeline(model=self.router_model, tokenizer=self.tokenizer)
+         self.models_indices = config.config_dict['class_indices']
+
+     def load_expert_models(self, models_dict, quantization_dict):
+         models = {}
+         for key, name in models_dict.items():
+             quantization = quantization_dict.get(key)
+             if quantization == "8bit":
+                 models[key] = AutoModelForCausalLM.from_pretrained(name, trust_remote_code=True, device_map="auto", load_in_8bit=True, torch_dtype="auto")
+             elif quantization == "4bit":
+                 models[key] = AutoModelForCausalLM.from_pretrained(name, trust_remote_code=True, device_map="auto", load_in_4bit=True, torch_dtype="auto")
+             elif quantization == "awq":
+                 models[key] = self.load_awq_model(name)
+             else:
+                 models[key] = AutoModelForCausalLM.from_pretrained(name, trust_remote_code=True, device_map="auto", torch_dtype="auto")
+         return models
+
+     def load_awq_model(self, name):
+         # AWQ checkpoints carry their own quantization config, so a plain from_pretrained suffices
+         return AutoModelForCausalLM.from_pretrained(name, trust_remote_code=True, device_map="auto")
+
+     def tokenize_inputs(self, text, model_key):
+         return self.tokenizers[model_key](text, return_tensors="pt")
+
+     def determine_model(self, text):
+         # Ask the router classifier which expert should handle this prompt
+         prediction = self.router(text)[0]["label"]
+         model_decision_index = self.models_indices[prediction]
+         model_keys = ['expert1', 'expert2', 'expert3', 'expert4', 'expert5']
+         return model_keys[model_decision_index]
+
+     def expert_tokenizer(self, text):
+         model_key = self.determine_model(text)
+         return self.tokenizers[model_key]
+
+     def generate(self, input_ids, **generate_kwargs):
+         # Decode the incoming input_ids back to text with the router tokenizer
+         text = self.tokenizer.batch_decode(input_ids, skip_special_tokens=False)[0]
+
+         msgs = tokenizer_template_switch.recover_chat_messages(text, self.tokenizer)
+         if msgs and msgs[0]['role'] == 'system' and msgs[0]['content'] == '<|im_start|>system':
+             # Drop the phantom system message left behind by template recovery
+             msgs.pop(0)
+         # If the last message is an assistant turn, drop it so the expert generates a fresh reply
+         if msgs and msgs[-1]['role'] == 'assistant':
+             msgs.pop()
+
+         # Determine the expert key using the existing routing logic
+         model_key = self.determine_model(text)
+         # Show the routing result
+         print(f"Choosing {model_key} ..")
+         # Retrieve the expert model from the dictionary
+         model = self.models[model_key]
+
+         # Re-render the conversation with the chosen expert's chat template
+         mod_txt = self.tokenizers[model_key].apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)
+         current_device = input_ids.device if isinstance(input_ids, torch.Tensor) else 'cpu'
+
+         # Tokenize with the chosen expert's tokenizer
+         tok = self.tokenizers[model_key](mod_txt, return_tensors="pt")
+         tok_input_ids = tok.input_ids.to(current_device)
+         tok_attention_mask = tok.attention_mask.to(current_device)
+
+         # Generate with the selected expert
+         return model.generate(tok_input_ids, attention_mask=tok_attention_mask, **generate_kwargs)
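A usage sketch for the class above, assuming the three files are uploaded to a model repo with a matching config.json. "org/kraken" is a placeholder repo id, and trust_remote_code=True is required because both custom classes live in the repo.

    # Hypothetical end-to-end call; "org/kraken" is a placeholder repo id.
    from transformers import AutoModelForCausalLM, AutoTokenizer

    model = AutoModelForCausalLM.from_pretrained("org/kraken", trust_remote_code=True)
    tokenizer = AutoTokenizer.from_pretrained("org/kraken", trust_remote_code=True)

    prompt = "Write a haiku about model routing."
    input_ids = tokenizer.apply_chat_template([{"role": "user", "content": prompt}],
                                              add_generation_prompt=True, return_tensors="pt")
    output_ids = model.generate(input_ids, max_new_tokens=64)
    # generate() re-tokenizes with the routed expert, so decode with that expert's tokenizer
    print(model.expert_tokenizer(prompt).decode(output_ids[0], skip_special_tokens=True))

Note the asymmetry: the prompt is encoded with the router tokenizer, but the returned ids come from the routed expert, which is why expert_tokenizer() is exposed for decoding.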
kraken_model/tokenizer_template_switch.py ADDED
@@ -0,0 +1,95 @@
+ import re
+ from transformers import AutoTokenizer
+
+ def extract_separators(template):
+     """
+     Extracts the per-role separators used in the chat template.
+     """
+     # Match the pattern between '{{' and "+ message['content']"
+     pattern = r"\{\{\s*([^{}]+?)\s*\+ message\['content'\]"
+     matches = re.findall(pattern, template)
+     # Clean up any extra spaces and return the matches
+     separators = [match.strip() for match in matches]
+
+     # If the template builds the separator from message['role'], expand it per role
+     if any("message['role']" in element for element in separators):
+         roles = ["system", "user", "assistant"]
+         separators_ = []
+         for role in roles:
+             separators_.append(separators[0].replace(" + message['role'] + ", role).replace("'", ""))
+         return separators_
+
+     return separators
+
+ def detect_eos_token(jinja_template, tokenizer):
+     if "<|im_end|>" in jinja_template:
+         return "<|im_end|>"
+     if "</s>" in jinja_template:
+         return "</s>"
+     if "eos_token" in jinja_template:
+         return tokenizer.eos_token
+     return "<|endoftext|>"
+
+ def recover_messages(formatted_message, separators, eos_token):
+     """
+     Recovers the original messages from the formatted message string.
+     """
+     # Split the formatted message on the end-of-sequence token
+     split_messages = formatted_message.split(eos_token)
+
+     # Remove the last empty string if it exists due to a trailing separator
+     if split_messages and split_messages[-1].strip() == '':
+         split_messages.pop()
+
+     # Prepare the list to hold the recovered messages
+     recovered_messages = []
+
+     # Roles after the first message alternate between "user" and "assistant"
+     alternate_roles = ["user", "assistant"]
+
+     for index, message_content in enumerate(split_messages):
+         # The first message is "system"; subsequent messages alternate roles
+         if index == 0:
+             role = "system"
+         else:
+             role = alternate_roles[(index - 1) % 2]
+
+         # Clean the message content by removing leading/trailing whitespace and separators
+         clean_content = message_content.strip()
+         for separator in separators:
+             clean_content = clean_content.replace(separator.strip("'"), '', 1).strip()
+
+         recovered_messages.append({"role": role, "content": clean_content})
+
+     return recovered_messages
+
+ def recover_chat_messages(tokenized_chat, tokenizer):
+     """
+     Given a chat string rendered by apply_chat_template and its tokenizer,
+     returns the list of message dictionaries.
+     """
+     jinja_template = tokenizer.chat_template
+     separators = extract_separators(jinja_template)
+     eos_token = detect_eos_token(jinja_template, tokenizer)
+     recovered_messages = recover_messages(tokenized_chat, separators, eos_token)
+     return recovered_messages
+
+ # Example usage
+ if __name__ == "__main__":
+     checkpoint = "Qwen/Qwen1.5-0.5B"
+     tokenizer = AutoTokenizer.from_pretrained(checkpoint)
+
+     messages = [
+         {
+             "role": "system",
+             "content": "You are a friendly chatbot who always responds in the style of a pirate",
+         },
+         {"role": "user", "content": "How many helicopters can a human eat in one sitting?"},
+     ]
+     tokenized_chat = tokenizer.apply_chat_template(messages, tokenize=False)
+     print(tokenized_chat)
+
+     recovered_messages = recover_chat_messages(tokenized_chat, tokenizer)
+     print(recovered_messages)
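Beyond the __main__ demo, a quick way to inspect what the recovery heuristics extract for any tokenizer is to call the two helpers directly. The printed values depend entirely on the template; the comment about "<|im_end|>" assumes a ChatML-style template such as Qwen's.

    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained("Qwen/Qwen1.5-0.5B")
    print(extract_separators(tok.chat_template))     # per-role separator strings
    print(detect_eos_token(tok.chat_template, tok))  # "<|im_end|>" for ChatML-style templates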