HaryaniAnjali commited on
Commit
53b77bd
·
verified ·
1 Parent(s): 6069c51

Delete emotion_model.py

Browse files
Files changed (1) hide show
  1. emotion_model.py +0 -52
emotion_model.py DELETED
@@ -1,52 +0,0 @@
1
- import tensorflow as tf
2
- import torch
3
- import numpy as np
4
- import os
5
-
6
- # Load the Keras model
7
- keras_model = tf.keras.models.load_model('wav2vec_model.h5')
8
-
9
- # Create a PyTorch model with the same architecture
10
- class EmotionClassifier(torch.nn.Module):
11
- def __init__(self, input_shape, num_classes):
12
- super().__init__()
13
- # Adjust this architecture to match your Keras model
14
- self.flatten = torch.nn.Flatten()
15
- self.layers = torch.nn.Sequential(
16
- torch.nn.Linear(input_shape, 128),
17
- torch.nn.ReLU(),
18
- torch.nn.Dropout(0.3),
19
- torch.nn.Linear(128, 64),
20
- torch.nn.ReLU(),
21
- torch.nn.Dropout(0.3),
22
- torch.nn.Linear(64, num_classes)
23
- )
24
-
25
- def forward(self, x):
26
- x = self.flatten(x)
27
- return self.layers(x)
28
-
29
- # Create PyTorch model
30
- # Adjust these parameters based on your Keras model
31
- input_shape = 13 * 128 # n_mfcc * max_length
32
- num_classes = 7 # Number of emotions
33
- pytorch_model = EmotionClassifier(input_shape, num_classes)
34
-
35
- # Copy weights from Keras to PyTorch
36
- # This would need to be adjusted based on your exact architecture
37
- for i, layer in enumerate(keras_model.layers):
38
- if isinstance(layer, tf.keras.layers.Dense):
39
- # Get Keras weights and bias
40
- keras_weights = layer.get_weights()[0]
41
- keras_bias = layer.get_weights()[1]
42
-
43
- # Find the corresponding PyTorch layer
44
- # This is simplified; you'd need to match layers properly
45
- pytorch_layer = pytorch_model.layers[i * 2]
46
-
47
- # Copy weights and bias
48
- pytorch_layer.weight.data = torch.tensor(keras_weights.T, dtype=torch.float32)
49
- pytorch_layer.bias.data = torch.tensor(keras_bias, dtype=torch.float32)
50
-
51
- # Save the PyTorch model
52
- torch.save(pytorch_model.state_dict(), 'emotion_model.pt')