nicolinho commited on
Commit
df3bc49
·
verified ·
1 Parent(s): 6b1068b

Update modeling_custom.py

Browse files
Files changed (1) hide show
  1. modeling_custom.py +4 -3
modeling_custom.py CHANGED
@@ -85,10 +85,11 @@ class CustomOutput(ModelOutput):
85
 
86
  class LlamaForRewardModelWithGating(LlamaPreTrainedModel):
87
  def __init__(self, config):
 
88
  super().__init__(config)
89
- self.model = AutoModelForSequenceClassification.from_pretrained(
90
- "sfairXC/FsfairX-LLaMA3-RM-v0.1", num_labels=1, torch_dtype=torch.bfloat16, use_flash_attention_2=True, trust_remote_code=True,)
91
- #self.model = LlamaModel(config).to(torch.bfloat16)
92
  self.num_labels = config.num_labels
93
  config_dict = config.to_dict()
94
  self.num_objectives = config_dict.get("num_objectives", 19)
 
85
 
86
  class LlamaForRewardModelWithGating(LlamaPreTrainedModel):
87
  def __init__(self, config):
88
+ config.torch_dtype = torch.bfloat16
89
  super().__init__(config)
90
+ #self.model = AutoModelForSequenceClassification.from_pretrained(
91
+ # "sfairXC/FsfairX-LLaMA3-RM-v0.1", num_labels=1, torch_dtype=torch.bfloat16, use_flash_attention_2=True, trust_remote_code=True,)
92
+ self.model = LlamaModel(config)
93
  self.num_labels = config.num_labels
94
  config_dict = config.to_dict()
95
  self.num_objectives = config_dict.get("num_objectives", 19)