saeedbenadeeb committed
Commit 5fc7eb1 · 1 Parent(s): 6ceb71f
Lora Model Uploaded
Files changed:
- app.py +10 -2
- encoders/transformer.py +27 -1
- lora_only_model.pth +3 -0
- models/__init__.py +2 -1
- models/lora.py +24 -0
app.py CHANGED
@@ -12,7 +12,7 @@ emotions = ["happy", "sad", "angry", "neutral", "fear", "disgust", "surprise"]
 label_mapping = {str(idx): emotion for idx, emotion in enumerate(emotions)}

 # Load the trained model
-model_path = "
+model_path = "lora_only_model.pth"
 cfg = {
     "model": {
         "encoder": "Wav2Vec2Classifier",
@@ -25,9 +25,17 @@ cfg = {
     }
 }
 model = Wav2Vec2EmotionClassifier(num_classes=len(emotions), optimizer_cfg=cfg["model"]["optimizer"])
-
+state_dict = torch.load(model_path, map_location=torch.device("cpu"))
+model.load_state_dict(state_dict, strict=False)
+
+
 model.eval()

+
+for name, param in model.named_parameters():
+    if param.requires_grad:
+        print(f"{name}: {param.data}")
+
 # Optional: we define a minimum number of samples to avoid Wav2Vec2 conv errors
 MIN_SAMPLES = 10  # or 16000 if you want at least 1 second

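Context for the strict=False load above: the checkpoint holds only the LoRA and classifier/projector tensors (the overridden state_dict in encoders/transformer.py filters everything else out), so the frozen Wav2Vec2 backbone weights are intentionally absent and keep their pretrained values. A minimal inspection sketch, not part of the commit, assuming the checkpoint file is available locally:

# Sketch (assumption: lora_only_model.pth is present in the working directory).
# Shows that the saved state dict contains only adapter and head tensors.
import torch

state_dict = torch.load("lora_only_model.pth", map_location="cpu")
lora_keys = [k for k in state_dict if "lora" in k]
head_keys = [k for k in state_dict if "classifier" in k or "projector" in k]
print(f"{len(lora_keys)} LoRA tensors, {len(head_keys)} head tensors, {len(state_dict)} total")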
encoders/transformer.py CHANGED
@@ -3,7 +3,7 @@ import torch
 from torchmetrics import Accuracy, Precision, Recall, F1Score
 from transformers import Wav2Vec2Model, Wav2Vec2ForSequenceClassification
 import torch.nn.functional as F
-
+from models.lora import LinearWithLoRA, LoRALayer

 class Wav2Vec2Classifier(pl.LightningModule):
     def __init__(self, num_classes, optimizer_cfg = "Adam", l1_lambda=0.0):
@@ -166,6 +166,32 @@ class Wav2Vec2EmotionClassifier(pl.LightningModule):
         else:
             self.optimizer = None

+        # Apply LoRA
+        low_rank = 8
+        lora_alpha = 16
+        self.apply_lora(low_rank, lora_alpha)
+
+    def apply_lora(self, rank, alpha):
+        # Replace specific linear layers with LinearWithLoRA
+        for layer in self.model.wav2vec2.encoder.layers:
+            layer.attention.q_proj = LinearWithLoRA(layer.attention.q_proj, rank, alpha)
+            layer.attention.k_proj = LinearWithLoRA(layer.attention.k_proj, rank, alpha)
+            layer.attention.v_proj = LinearWithLoRA(layer.attention.v_proj, rank, alpha)
+            layer.attention.out_proj = LinearWithLoRA(layer.attention.out_proj, rank, alpha)
+
+            layer.feed_forward.intermediate_dense = LinearWithLoRA(layer.feed_forward.intermediate_dense, rank, alpha)
+            layer.feed_forward.output_dense = LinearWithLoRA(layer.feed_forward.output_dense, rank, alpha)
+
+    def state_dict(self, *args, **kwargs):
+        # Save only LoRA and classifier/projector parameters
+        state = super().state_dict(*args, **kwargs)
+        return {k: v for k, v in state.items() if "lora" in k or "classifier" in k or "projector" in k}
+
+    def load_state_dict(self, state_dict, strict=True):
+        missing_keys, unexpected_keys = super().load_state_dict(state_dict, strict=False)
+        if missing_keys or unexpected_keys:
+            print(f"Missing keys: {missing_keys}")
+            print(f"Unexpected keys: {unexpected_keys}")
     def forward(self, x, attention_mask=None):
         return self.model(x, attention_mask=attention_mask).logits

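For context on what apply_lora touches: in the Hugging Face Wav2Vec2 encoder, each layer exposes attention.{q,k,v,out}_proj and feed_forward.{intermediate_dense,output_dense} as nn.Linear modules, so the loop wraps six projections per encoder layer. A rough parameter-count sketch, not part of the commit; the 768-dimensional projection is an assumption matching wav2vec2-base:

# Sketch (assumptions: a stand-alone 768x768 projection; real dims depend on the checkpoint).
# Compares LoRA-added parameters at rank 8 with the wrapped base projection.
import torch.nn as nn
from models.lora import LinearWithLoRA

proj = nn.Linear(768, 768)                      # stands in for e.g. attention.q_proj
wrapped = LinearWithLoRA(proj, rank=8, alpha=16)

lora_params = sum(p.numel() for n, p in wrapped.named_parameters() if "lora" in n)
base_params = sum(p.numel() for n, p in wrapped.named_parameters() if "lora" not in n)
print(lora_params, base_params)                 # 12288 vs 590592 per wrapped projection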
lora_only_model.pth ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:fc2029a0dcf22d2b626533192bda3fa6098653df84be452b88c4db830a7c9216
+size 8185738
models/__init__.py CHANGED
@@ -1 +1,2 @@
-from . import CTCencoder
+from . import CTCencoder
+from . import lora
models/lora.py ADDED
@@ -0,0 +1,24 @@
+import torch
+import torch.nn as nn
+
+class LoRALayer(nn.Module):
+    def __init__(self, input_dim, output_dim, rank, alpha):
+        super().__init__()
+        std_dev = 1 / torch.sqrt(torch.tensor(rank).float())
+        self.A = nn.Parameter(torch.randn(input_dim, rank) * std_dev)  # Low-rank matrix A
+        self.B = nn.Parameter(torch.zeros(rank, output_dim))  # Low-rank matrix B
+        self.alpha = alpha  # Scaling factor
+    def forward(self, x):
+        # Low-rank update only: alpha * (x @ A @ B); the base output is added in LinearWithLoRA
+        return self.alpha * (x @ self.A @ self.B)
+
+
+class LinearWithLoRA(nn.Module):
+    def __init__(self, linear_layer, rank, alpha):
+        super().__init__()
+        self.linear = linear_layer  # Original linear layer
+        self.lora = LoRALayer(linear_layer.in_features, linear_layer.out_features, rank, alpha)  # LoRA layer
+
+    def forward(self, x):
+        # Combine original linear layer output with LoRA adaptation
+        return self.linear(x) + self.lora(x)
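A short usage sketch of how these two classes compose: the wrapped layer adds a scaled low-rank update to the original projection, and because B is initialized to zero the wrapper is initially an exact identity over the base layer. This is illustrative only; the dimensions and the explicit freezing step are assumptions about the intended training setup, not shown in the commit itself.

# Usage sketch (assumptions: 768-dim projection; freezing the base weights is illustrative).
import torch
import torch.nn as nn
from models.lora import LinearWithLoRA

base = nn.Linear(768, 768)
layer = LinearWithLoRA(base, rank=8, alpha=16)

# Typical LoRA fine-tuning freezes the wrapped projection and trains only A and B.
layer.linear.weight.requires_grad_(False)
layer.linear.bias.requires_grad_(False)

x = torch.randn(2, 50, 768)            # (batch, time, hidden)
y = layer(x)                           # base(x) + alpha * (x @ A @ B)
print(y.shape)                         # torch.Size([2, 50, 768])

# B starts at zero, so the wrapped layer initially reproduces the original projection.
print(torch.allclose(y, base(x)))      # True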