DavidGF committed on
Commit
2d97b03
1 Parent(s): c83613c

Update kraken_model/modeling_kraken.py

Browse files
Files changed (1) hide show
  1. kraken_model/modeling_kraken.py +11 -6
kraken_model/modeling_kraken.py CHANGED
@@ -41,10 +41,6 @@ class KrakenForCausalLM(PreTrainedModel):
41
  model_keys = ['expert1', 'expert2', 'expert3', 'expert4','expert5']
42
  return model_keys[model_decision_index]
43
 
44
- def expert_tokenizer(self, text):
45
- model_key = self.determine_model(text)
46
- return self.tokenizers[model_key]
47
-
48
 
49
  def generate(self, input_ids, **generate_kwargs):
50
  # Tokenize the input_ids
@@ -75,8 +71,17 @@ class KrakenForCausalLM(PreTrainedModel):
75
  tok_input_ids = tok.input_ids.to(current_device)
76
  tok_attention_mask = tok.attention_mask.to(current_device)
77
 
78
- # Generate text using the retrieved model
79
- return model.generate(tok_input_ids, attention_mask=tok_attention_mask, **generate_kwargs)
 
 
 
 
 
 
 
 
 
80
 
81
 
82
 
 
41
  model_keys = ['expert1', 'expert2', 'expert3', 'expert4','expert5']
42
  return model_keys[model_decision_index]
43
 
 
 
 
 
44
 
45
  def generate(self, input_ids, **generate_kwargs):
46
  # Tokenize the input_ids
 
71
  tok_input_ids = tok.input_ids.to(current_device)
72
  tok_attention_mask = tok.attention_mask.to(current_device)
73
 
74
+
75
+ # Generate text using the modified model
76
+ output_ids = model.generate(tok_input_ids, attention_mask=tok_attention_mask, **generate_kwargs)
77
+
78
+ # Decode the output using the expert tokenizer
79
+ decoded_text = self.tokenizers[model_key].decode(output_ids[0], skip_special_tokens=True)
80
+
81
+ # Retokenize the decoded text using the base tokenizer for external compatibility
82
+ retokenized_ids = self.tokenizer(decoded_text, return_tensors="pt").input_ids.to(current_device)
83
+
84
+ return retokenized_ids
85
 
86
 
87