enpaiva committed
Commit 381bf7c
1 Parent(s): 008b9db

Update model_keeper.py

Files changed (1)
  1. model_keeper.py +25 -21
model_keeper.py CHANGED
@@ -30,32 +30,32 @@ class KeeperModelForCausalLM(PreTrainedModel):
         self.bert = None
         self.llm = None

-        if cfg:
-            print("Initializing KeeperModelForCausalLM from cfg")
-            # Initialization from configuration
+        # if cfg:
+        #     print("Initializing KeeperModelForCausalLM from cfg")
+        #     # Initialization from configuration

-            self.bert = AutoModel.from_pretrained(cfg.retriever_config['_name_or_path'])
+        #     self.bert = AutoModel.from_pretrained(cfg.retriever_config['_name_or_path'])

-            bnb_config = BitsAndBytesConfig(
-                load_in_4bit=True,
-                bnb_4bit_quant_type="nf4",
-                bnb_4bit_compute_dtype=torch.bfloat16
-            )
+        #     bnb_config = BitsAndBytesConfig(
+        #         load_in_4bit=True,
+        #         bnb_4bit_quant_type="nf4",
+        #         bnb_4bit_compute_dtype=torch.bfloat16
+        #     )

-            self.llm = AutoModelForCausalLM.from_pretrained(
-                cfg.model_config['_name_or_path'],
-                device_map=cfg.device_map,
-                torch_dtype=torch.bfloat16,
-                quantization_config=bnb_config
-            )
+        #     self.llm = AutoModelForCausalLM.from_pretrained(
+        #         cfg.model_config['_name_or_path'],
+        #         device_map=cfg.device_map,
+        #         torch_dtype=torch.bfloat16,
+        #         quantization_config=bnb_config
+        #     )

-            # Store kwargs for serialization and future loading
-            # self.init_kwargs = {'cfg': cfg}
+        #     # Store kwargs for serialization and future loading
+        #     # self.init_kwargs = {'cfg': cfg}

-            print("Initialization complete")
-        else:
-            # If cfg is not provided, this is handled in the from_pretrained method
-            print("Initializing KeeperTokenizer without cfg")
+        #     print("Initialization complete")
+        # else:
+        #     # If cfg is not provided, this is handled in the from_pretrained method
+        #     print("Initializing KeeperTokenizer without cfg")

         self.n_cands = n_cands
         self.update_both = update_both
@@ -81,6 +81,10 @@ class KeeperModelForCausalLM(PreTrainedModel):
             self.prompt_right = state_dict["prompt_right"].to(device)
             if "respuesta" in state_dict:
                 self.respuesta = state_dict["respuesta"].to(device)
+            if "bert" in state_dict:
+                self.bert = state_dict["bert"].to(device)
+            if "llm" in state_dict:
+                self.llm = state_dict["llm"].to(device)
         else:
             # Optionally handle the case where CUDA is not available
             print("CUDA is not available. Tensors will remain on CPU.")