Ar4ikov commited on
Commit
71c53d9
1 Parent(s): 0b7e1d7

Update wav2vec2speechclassification.py

Browse files
Files changed (1) hide show
  1. wav2vec2speechclassification.py +4 -0
wav2vec2speechclassification.py CHANGED
@@ -2,6 +2,7 @@ from dataclasses import dataclass
2
  from typing import Optional, Tuple
3
  import torch
4
  from transformers.file_utils import ModelOutput
 
5
 
6
 
7
  @dataclass
@@ -24,6 +25,7 @@ from transformers.models.wav2vec2.modeling_wav2vec2 import (
24
 
25
  class Wav2Vec2ClassificationHead(nn.Module):
26
  """Head for wav2vec classification task."""
 
27
 
28
  def __init__(self, config):
29
  super().__init__()
@@ -42,6 +44,8 @@ class Wav2Vec2ClassificationHead(nn.Module):
42
 
43
 
44
  class Wav2Vec2ForSpeechClassification(Wav2Vec2PreTrainedModel):
 
 
45
  def __init__(self, config):
46
  super().__init__(config)
47
  self.num_labels = config.num_labels
 
2
  from typing import Optional, Tuple
3
  import torch
4
  from transformers.file_utils import ModelOutput
5
+ from transformers import AutoConfig
6
 
7
 
8
  @dataclass
 
25
 
26
  class Wav2Vec2ClassificationHead(nn.Module):
27
  """Head for wav2vec classification task."""
28
+ config_class = AutoConfig
29
 
30
  def __init__(self, config):
31
  super().__init__()
 
44
 
45
 
46
  class Wav2Vec2ForSpeechClassification(Wav2Vec2PreTrainedModel):
47
+ config_class = AutoConfig
48
+
49
  def __init__(self, config):
50
  super().__init__(config)
51
  self.num_labels = config.num_labels