wangleiofficial commited on
Commit
82dc5f9
·
verified ·
1 Parent(s): cf48cb9

fix param err

Browse files
Files changed (1) hide show
  1. dnaflash.py +1 -1
dnaflash.py CHANGED
@@ -450,7 +450,7 @@ class FLASHTransformerForSequenceClassification(FLASHTransformerForPretrained):
450
 
451
  # 获取基模型输出
452
  outputs = super().forward(
453
- input_ids=input_ids
454
  )
455
  hidden_states = outputs["hidden_states"]
456
  input_mask_expanded = input_ids["attention_mask"].unsqueeze(-1).expand(hidden_states.size()) # 维度匹配
 
450
 
451
  # 获取基模型输出
452
  outputs = super().forward(
453
+ input_ids
454
  )
455
  hidden_states = outputs["hidden_states"]
456
  input_mask_expanded = input_ids["attention_mask"].unsqueeze(-1).expand(hidden_states.size()) # 维度匹配