DavidGF committed on
Commit
cf015ac
1 Parent(s): 1bdec50

Upload modeling_kraken_lora.py

Browse files
kraken_model/modeling_kraken_lora.py CHANGED
@@ -43,10 +43,6 @@ class KrakenForCausalLM(PreTrainedModel):
43
  adapter_keys = ['lora_expert1', 'lora_expert2', 'lora_expert3', 'lora_expert4', 'lora_expert5']
44
  return adapter_keys[model_decision_index]
45
 
46
- def expert_tokenizer(self, text):
47
- adapter_key = self.determine_adapter(text)
48
- return self.tokenizers[adapter_key]
49
-
50
 
51
  def generate(self, input_ids, **generate_kwargs):
52
  # Tokenize the input_ids
@@ -74,13 +70,19 @@ class KrakenForCausalLM(PreTrainedModel):
74
  current_device = input_ids.device if isinstance(input_ids, torch.Tensor) else 'cpu'
75
 
76
  # Tokenize accordingly to the best model
77
-
78
  tok = self.tokenizers[adapter_key](mod_txt, return_tensors="pt")
79
  tok_input_ids = tok.input_ids.to(current_device)
80
  tok_attention_mask = tok.attention_mask.to(current_device)
81
 
82
  # Generate text using the modified model
83
- return model_with_lora.generate(tok_input_ids, attention_mask=tok_attention_mask, **generate_kwargs)
 
 
 
 
 
 
 
84
 
85
 
86
 
 
43
  adapter_keys = ['lora_expert1', 'lora_expert2', 'lora_expert3', 'lora_expert4', 'lora_expert5']
44
  return adapter_keys[model_decision_index]
45
 
 
 
 
 
46
 
47
  def generate(self, input_ids, **generate_kwargs):
48
  # Tokenize the input_ids
 
70
  current_device = input_ids.device if isinstance(input_ids, torch.Tensor) else 'cpu'
71
 
72
  # Tokenize accordingly to the best model
 
73
  tok = self.tokenizers[adapter_key](mod_txt, return_tensors="pt")
74
  tok_input_ids = tok.input_ids.to(current_device)
75
  tok_attention_mask = tok.attention_mask.to(current_device)
76
 
77
  # Generate text using the modified model
78
+ output_ids = model_with_lora.generate(tok_input_ids, attention_mask=tok_attention_mask, **generate_kwargs)
79
+
80
+ # Decode the output using the expert tokenizer
81
+ decoded_text = self.tokenizers[adapter_key].decode(output_ids[0], skip_special_tokens=True)
82
+
83
+ # Retokenize the decoded text using the base tokenizer for external compatibility
84
+ retokenized_ids = self.tokenizer(decoded_text, return_tensors="pt").input_ids.to(current_device)
85
+ return retokenized_ids
86
 
87
 
88