DavidGF committed
Commit d81f6d7
Parent(s): 914d5ce

Upload folder using huggingface_hub

configuration_kraken.py ADDED
@@ -0,0 +1,8 @@
+ from transformers import PretrainedConfig
+
+ class KrakenConfig(PretrainedConfig):
+     model_type = "kraken"
+
+     def __init__(self, config_dict=None, **kwargs):
+         super().__init__(**kwargs)
+         self.config_dict = config_dict or {}
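For reference, a minimal sketch (assuming only the files in this commit) of how this config round-trips through the standard PretrainedConfig API; the config_dict keys mirror kraken_model/config.json below:

    from configuration_kraken import KrakenConfig

    # Build the config from a plain dict and round-trip it to disk.
    config = KrakenConfig(config_dict={
        "models": {"expert1": "microsoft/Phi-3-medium-128k-instruct"},
        "tokenizers": {"expert1": "microsoft/Phi-3-medium-128k-instruct"},
        "quantization": {"expert1": None},
        "router": "./kraken_router",
        "class_indices": {"LABEL_0": 0},
    })
    config.save_pretrained("./kraken_model")   # writes config.json
    reloaded = KrakenConfig.from_pretrained("./kraken_model")
    assert reloaded.config_dict["router"] == "./kraken_router"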
kraken_model/config.json ADDED
@@ -0,0 +1,40 @@
+ {
+   "architectures": [
+     "KrakenForCausalLM"
+   ],
+   "config_dict": {
+     "class_indices": {
+       "LABEL_0": 0,
+       "LABEL_1": 1,
+       "LABEL_2": 2,
+       "LABEL_3": 3,
+       "LABEL_4": 4
+     },
+     "model_type": "kraken",
+     "models": {
+       "expert1": "microsoft/Phi-3-medium-128k-instruct",
+       "expert2": "gorilla-llm/gorilla-openfunctions-v2",
+       "expert3": "ise-uiuc/Magicoder-S-DS-6.7B",
+       "expert4": "defog/llama-3-sqlcoder-8b",
+       "expert5": "VAGOsolutions/Llama-3-SauerkrautLM-8b-Instruct"
+     },
+     "quantization": {
+       "expert1": null,
+       "expert2": null,
+       "expert3": null,
+       "expert4": null,
+       "expert5": null
+     },
+     "router": "./kraken_router",
+     "tokenizers": {
+       "expert1": "microsoft/Phi-3-medium-128k-instruct",
+       "expert2": "gorilla-llm/gorilla-openfunctions-v2",
+       "expert3": "ise-uiuc/Magicoder-S-DS-6.7B",
+       "expert4": "defog/llama-3-sqlcoder-8b",
+       "expert5": "VAGOsolutions/Llama-3-SauerkrautLM-8b-Instruct"
+     }
+   },
+   "model_type": "kraken",
+   "torch_dtype": "float32",
+   "transformers_version": "4.41.0"
+ }
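The config_dict above wires the pieces together: class_indices maps the router's output labels to positions in the expert list, models/tokenizers name the five expert checkpoints, and quantization (all null here) selects an optional 8bit/4bit/awq loading mode per expert. A small standalone sketch of the lookup this implies, mirroring determine_model() in modeling_kraken.py below:

    class_indices = {"LABEL_0": 0, "LABEL_1": 1, "LABEL_2": 2, "LABEL_3": 3, "LABEL_4": 4}
    model_keys = ["expert1", "expert2", "expert3", "expert4", "expert5"]

    def label_to_expert(label: str) -> str:
        return model_keys[class_indices[label]]

    assert label_to_expert("LABEL_3") == "expert4"  # defog/llama-3-sqlcoder-8b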
kraken_model/generation_config.json ADDED
@@ -0,0 +1,4 @@
+ {
+   "_from_model_config": true,
+   "transformers_version": "4.41.0"
+ }
kraken_model/model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:e0acccf86abbb885412274d405d4555248636fe829dedbd7655bceb29a023535
+ size 1856007992
kraken_router/added_tokens.json ADDED
@@ -0,0 +1,5 @@
+ {
+   "<|endoftext|>": 151643,
+   "<|im_end|>": 151645,
+   "<|im_start|>": 151644
+ }
kraken_router/config.json ADDED
@@ -0,0 +1,44 @@
+ {
+   "_name_or_path": "Qwen/Qwen1.5-0.5B",
+   "architectures": [
+     "Qwen2ForSequenceClassification"
+   ],
+   "attention_dropout": 0.0,
+   "bos_token_id": 151643,
+   "eos_token_id": 151643,
+   "hidden_act": "silu",
+   "hidden_size": 1024,
+   "id2label": {
+     "0": "LABEL_0",
+     "1": "LABEL_1",
+     "2": "LABEL_2",
+     "3": "LABEL_3",
+     "4": "LABEL_4"
+   },
+   "initializer_range": 0.02,
+   "intermediate_size": 2816,
+   "label2id": {
+     "LABEL_0": 0,
+     "LABEL_1": 1,
+     "LABEL_2": 2,
+     "LABEL_3": 3,
+     "LABEL_4": 4
+   },
+   "max_position_embeddings": 32768,
+   "max_window_layers": 21,
+   "model_type": "qwen2",
+   "num_attention_heads": 16,
+   "num_hidden_layers": 24,
+   "num_key_value_heads": 16,
+   "pad_token_id": 151643,
+   "problem_type": "single_label_classification",
+   "rms_norm_eps": 1e-06,
+   "rope_theta": 1000000.0,
+   "sliding_window": 32768,
+   "tie_word_embeddings": true,
+   "torch_dtype": "float32",
+   "transformers_version": "4.41.0",
+   "use_cache": true,
+   "use_sliding_window": false,
+   "vocab_size": 151936
+ }
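This is a Qwen1.5-0.5B backbone with a five-way sequence-classification head, which serves as the prompt router. A sketch of querying it directly, assuming the ./kraken_router checkpoint from this commit is available locally (the prompt and predicted label are illustrative only):

    from transformers import (
        AutoModelForSequenceClassification,
        AutoTokenizer,
        TextClassificationPipeline,
    )

    router_model = AutoModelForSequenceClassification.from_pretrained("./kraken_router")
    router_tokenizer = AutoTokenizer.from_pretrained("./kraken_router")
    router = TextClassificationPipeline(model=router_model, tokenizer=router_tokenizer)

    result = router("Write a query that sums revenue per customer.")[0]
    print(result["label"])  # e.g. "LABEL_3", which kraken_model/config.json maps to expert4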
kraken_router/merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
kraken_router/model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:0462ee8274d9aab1297d68d7363e40940d90fd274cb2679c4790fdf4371f2808
+ size 1856004208
kraken_router/special_tokens_map.json ADDED
@@ -0,0 +1,14 @@
+ {
+   "additional_special_tokens": [
+     "<|im_start|>",
+     "<|im_end|>"
+   ],
+   "eos_token": {
+     "content": "<|endoftext|>",
+     "lstrip": false,
+     "normalized": false,
+     "rstrip": false,
+     "single_word": false
+   },
+   "pad_token": "<|endoftext|>"
+ }
kraken_router/tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
kraken_router/tokenizer_config.json ADDED
@@ -0,0 +1,43 @@
+ {
+   "add_prefix_space": false,
+   "added_tokens_decoder": {
+     "151643": {
+       "content": "<|endoftext|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "151644": {
+       "content": "<|im_start|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "151645": {
+       "content": "<|im_end|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     }
+   },
+   "additional_special_tokens": [
+     "<|im_start|>",
+     "<|im_end|>"
+   ],
+   "bos_token": null,
+   "chat_template": "{% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<|im_start|>system\nYou are a helpful assistant<|im_end|>\n' }}{% endif %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
+   "clean_up_tokenization_spaces": false,
+   "eos_token": "<|endoftext|>",
+   "errors": "replace",
+   "model_max_length": 32768,
+   "pad_token": "<|endoftext|>",
+   "split_special_tokens": false,
+   "tokenizer_class": "Qwen2Tokenizer",
+   "unk_token": null
+ }
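The chat_template above is standard ChatML, with a default system prompt injected when the conversation lacks one. Rendering it looks like this (a sketch assuming the router tokenizer from this commit is available at ./kraken_router):

    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained("./kraken_router")
    text = tok.apply_chat_template(
        [{"role": "user", "content": "Hi"}],
        tokenize=False,
        add_generation_prompt=True,
    )
    print(text)
    # <|im_start|>system
    # You are a helpful assistant<|im_end|>
    # <|im_start|>user
    # Hi<|im_end|>
    # <|im_start|>assistant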
kraken_router/vocab.json ADDED
The diff for this file is too large to render. See raw diff
 
modeling_kraken.py ADDED
@@ -0,0 +1,79 @@
+ import torch
+ from transformers import PreTrainedModel, AutoTokenizer, AutoModelForCausalLM, AutoModelForSequenceClassification, TextClassificationPipeline
+ from configuration_kraken import KrakenConfig
+ import tokenizer_template_switch
+
+ class KrakenForCausalLM(PreTrainedModel):
+     config_class = KrakenConfig
+
+     def __init__(self, config):
+         super().__init__(config)
+         self.tokenizers = {key: AutoTokenizer.from_pretrained(name) for key, name in config.config_dict['tokenizers'].items()}
+         self.models = self.load_expert_models(config.config_dict['models'], config.config_dict['quantization'])
+         self.router_model = AutoModelForSequenceClassification.from_pretrained(config.config_dict['router'], trust_remote_code=True, device_map="auto")
+         self.tokenizer = AutoTokenizer.from_pretrained(config.config_dict['router'], trust_remote_code=True)
+         self.router = TextClassificationPipeline(model=self.router_model, tokenizer=self.tokenizer)
+         self.models_indices = config.config_dict['class_indices']
+
+     def load_expert_models(self, models_dict, quantization_dict):
+         models = {}
+         for key, name in models_dict.items():
+             quantization = quantization_dict.get(key)
+             if quantization == "8bit":
+                 models[key] = AutoModelForCausalLM.from_pretrained(name, trust_remote_code=True, device_map="auto", load_in_8bit=True, torch_dtype="auto")
+             elif quantization == "4bit":
+                 models[key] = AutoModelForCausalLM.from_pretrained(name, trust_remote_code=True, device_map="auto", load_in_4bit=True, torch_dtype="auto")
+             elif quantization == "awq":
+                 models[key] = self.load_awq_model(name)
+             else:
+                 models[key] = AutoModelForCausalLM.from_pretrained(name, trust_remote_code=True, device_map="auto", torch_dtype="auto")
+         return models
+
+     def load_awq_model(self, name):
+         return AutoModelForCausalLM.from_pretrained(name, trust_remote_code=True, device_map="auto")
+
+     def tokenize_inputs(self, text, model_key):
+         return self.tokenizers[model_key](text, return_tensors="pt")
+
+     def determine_model(self, text):
+         # Classify the prompt and map the predicted label to an expert key
+         prediction = self.router(text)[0]["label"]
+         model_decision_index = self.models_indices[prediction]
+         model_keys = ['expert1', 'expert2', 'expert3', 'expert4', 'expert5']
+         return model_keys[model_decision_index]
+
+     def expert_tokenizer(self, text):
+         model_key = self.determine_model(text)
+         return self.tokenizers[model_key]
+
+     def generate(self, input_ids, **generate_kwargs):
+         # Decode the incoming input_ids back into text with the router tokenizer
+         text = self.tokenizer.batch_decode(input_ids, skip_special_tokens=False)[0]
+
+         msgs = tokenizer_template_switch.recover_chat_messages(text, self.tokenizer)
+         if msgs and msgs[0]['role'] == 'system' and msgs[0]['content'] == '<|im_start|>system':
+             # Drop the malformed system message left over from template recovery
+             msgs.pop(0)
+         # Check if the last element has the role 'assistant'
+         if msgs and msgs[-1]['role'] == 'assistant':
+             # Drop the trailing assistant stub added by the generation prompt
+             msgs.pop()
+
+         # Determine the expert key using the routing logic
+         model_key = self.determine_model(text)
+         # Show the routing result
+         print(f"Choosing {model_key} ..")
+         # Retrieve the expert model from the dictionary
+         model = self.models[model_key]
+
+         # Re-render the conversation with the chosen expert's chat template
+         mod_txt = self.tokenizers[model_key].apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)
+         current_device = input_ids.device if isinstance(input_ids, torch.Tensor) else 'cpu'
+
+         # Tokenize for the chosen expert and move tensors to the input device
+         tok = self.tokenizers[model_key](mod_txt, return_tensors="pt")
+         tok_input_ids = tok.input_ids.to(current_device)
+         tok_attention_mask = tok.attention_mask.to(current_device)
+
+         # Generate text using the chosen expert
+         return model.generate(tok_input_ids, attention_mask=tok_attention_mask, **generate_kwargs)
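Taken together, a minimal end-to-end usage sketch, assuming all files from this commit are on disk (this is an illustrative local-usage assumption, not an official entry point; instantiating the model downloads the router and all five experts, which needs substantial memory):

    from transformers import AutoTokenizer
    from configuration_kraken import KrakenConfig
    from modeling_kraken import KrakenForCausalLM

    config = KrakenConfig.from_pretrained("./kraken_model")
    model = KrakenForCausalLM(config)  # loads the router plus all five experts
    router_tokenizer = AutoTokenizer.from_pretrained("./kraken_router")

    messages = [{"role": "user", "content": "Return the top 5 customers by revenue as SQL."}]
    input_ids = router_tokenizer.apply_chat_template(
        messages, add_generation_prompt=True, return_tensors="pt"
    )
    output = model.generate(input_ids, max_new_tokens=128)

    # Decode with the tokenizer of whichever expert the router selected
    expert_tok = model.expert_tokenizer(router_tokenizer.batch_decode(input_ids)[0])
    print(expert_tok.decode(output[0], skip_special_tokens=True))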
tokenizer_template_switch.py ADDED
@@ -0,0 +1,95 @@
+ import re
+ from transformers import AutoTokenizer
+
+ def extract_separators(template):
+     """
+     Extracts the per-message separator expressions from a chat template.
+     """
+     # Match the expression between '{{' and "+ message['content']" in the Jinja source
+     pattern = r"\{\{\s*([^{}]+?)\s*\+ message\['content'\]"
+     matches = re.findall(pattern, template)
+     # Clean up any extra spaces and return the matches
+     separators = [match.strip() for match in matches]
+
+     # If the separator interpolates the role, expand it once per role
+     if any("message['role']" in element for element in separators):
+         roles = ["system", "user", "assistant"]
+         separators_ = []
+         for role in roles:
+             separators_.append(separators[0].replace(" + message['role'] + ", role).replace("'", ""))
+         return separators_
+
+     return separators
+
+ def detect_eos_token(jinja_template, tokenizer):
+     if "<|im_end|>" in jinja_template:
+         return "<|im_end|>"
+     if "</s>" in jinja_template:
+         return "</s>"
+     if "eos_token" in jinja_template:
+         return tokenizer.eos_token
+     return "<|endoftext|>"
+
+ def recover_messages(formatted_message, separators, eos_token):
+     """
+     Recovers the original messages from the rendered conversation string.
+     """
+     # Split the rendered conversation on the end-of-sequence token
+     split_messages = formatted_message.split(eos_token)
+
+     # Remove the last empty string if it exists due to a trailing separator
+     if split_messages and split_messages[-1].strip() == '':
+         split_messages.pop()
+
+     # Prepare the list to hold the recovered messages
+     recovered_messages = []
+
+     # Roles after the first message alternate between "user" and "assistant"
+     alternate_roles = ["user", "assistant"]
+
+     # Iterate over the split messages
+     for index, message_content in enumerate(split_messages):
+         # The first message is "system"; subsequent messages alternate
+         # between "user" and "assistant"
+         if index == 0:
+             role = "system"
+         else:
+             role = alternate_roles[(index - 1) % 2]
+
+         # Clean the content by stripping whitespace and removing the separators
+         clean_content = message_content.strip()
+         for separator in separators:
+             clean_content = clean_content.replace(separator.strip("'"), '', 1).strip()
+
+         # Append the cleaned message with its role to the list
+         recovered_messages.append({"role": role, "content": clean_content})
+
+     return recovered_messages
+
+ def recover_chat_messages(tokenized_chat, tokenizer):
+     """
+     Given a rendered chat string and its tokenizer, returns the list of message dictionaries.
+     """
+     jinja_template = tokenizer.chat_template
+     separators = extract_separators(jinja_template)
+     eos_token = detect_eos_token(jinja_template, tokenizer)
+     recovered_messages = recover_messages(tokenized_chat, separators, eos_token)
+     return recovered_messages
+
+ # Example usage
+ if __name__ == "__main__":
+     checkpoint = "Qwen/Qwen1.5-0.5B"
+     tokenizer = AutoTokenizer.from_pretrained(checkpoint)
+
+     messages = [
+         {
+             "role": "system",
+             "content": "You are a friendly chatbot who always responds in the style of a pirate",
+         },
+         {"role": "user", "content": "How many helicopters can a human eat in one sitting?"},
+     ]
+     tokenized_chat = tokenizer.apply_chat_template(messages, tokenize=False)
+     print(tokenized_chat)
+
+     recovered_messages = recover_chat_messages(tokenized_chat, tokenizer)
+     print(recovered_messages)
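For ChatML-style templates like the router's, the recovery above is an exact inverse. A hand-built illustration that exercises recover_messages() without downloading any checkpoint (the separators and eos token are written out literally, matching what extract_separators() and detect_eos_token() would return for the router's template):

    chatml = (
        "<|im_start|>system\nYou are a helpful assistant<|im_end|>\n"
        "<|im_start|>user\nHi there<|im_end|>\n"
    )
    msgs = recover_messages(
        chatml,
        ["<|im_start|>system\n", "<|im_start|>user\n", "<|im_start|>assistant\n"],
        "<|im_end|>",
    )
    print(msgs)
    # [{'role': 'system', 'content': 'You are a helpful assistant'},
    #  {'role': 'user', 'content': 'Hi there'}]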