Ar4ikov commited on
Commit
c967f9a
1 Parent(s): ec17185

Update wav2vec2speechclassification.py

Browse files
Files changed (1) hide show
  1. wav2vec2speechclassification.py +3 -5
wav2vec2speechclassification.py CHANGED
@@ -2,7 +2,7 @@ from dataclasses import dataclass
2
  from typing import Optional, Tuple
3
  import torch
4
  from transformers.file_utils import ModelOutput
5
- from transformers import Wav2Vec2Config
6
 
7
 
8
  @dataclass
@@ -25,8 +25,7 @@ from transformers.models.wav2vec2.modeling_wav2vec2 import (
25
 
26
  class Wav2Vec2ClassificationHead(nn.Module):
27
  """Head for wav2vec classification task."""
28
- config_class = Wav2Vec2Config
29
- model_type = "wav2vec2"
30
 
31
  def __init__(self, config):
32
  super().__init__()
@@ -45,8 +44,7 @@ class Wav2Vec2ClassificationHead(nn.Module):
45
 
46
 
47
  class Wav2Vec2ForSpeechClassification(Wav2Vec2PreTrainedModel):
48
- config_class = Wav2Vec2Config
49
- model_type = "wav2vec2"
50
 
51
  def __init__(self, config):
52
  super().__init__(config)
 
2
  from typing import Optional, Tuple
3
  import torch
4
  from transformers.file_utils import ModelOutput
5
+ from .wav2vec2fsr_config import W2V2FSRConfig
6
 
7
 
8
  @dataclass
 
25
 
26
  class Wav2Vec2ClassificationHead(nn.Module):
27
  """Head for wav2vec classification task."""
28
+ config_class = W2V2FSRConfig
 
29
 
30
  def __init__(self, config):
31
  super().__init__()
 
44
 
45
 
46
  class Wav2Vec2ForSpeechClassification(Wav2Vec2PreTrainedModel):
47
+ config_class = W2V2FSRConfig
 
48
 
49
  def __init__(self, config):
50
  super().__init__(config)