yuhuili commited on
Commit
57917e5
·
verified ·
1 Parent(s): 5324549

Update model/ea_model.py

Browse files
Files changed (1) hide show
  1. model/ea_model.py +1 -0
model/ea_model.py CHANGED
@@ -99,6 +99,7 @@ class EaModel(nn.Module):
99
  base_model = KVMixtralForCausalLM.from_pretrained(
100
  base_model_path, **kwargs
101
  )
 
102
 
103
  configpath=os.path.join(ea_model_path,"config.json")
104
  if not os.path.exists(configpath):
 
99
  base_model = KVMixtralForCausalLM.from_pretrained(
100
  base_model_path, **kwargs
101
  )
102
+ base_model.cuda()
103
 
104
  configpath=os.path.join(ea_model_path,"config.json")
105
  if not os.path.exists(configpath):