shangeth committed on
Commit
ec9a712
1 Parent(s): 101c04c

quantization added

Browse files
Files changed (3) hide show
  1. app.py +1 -1
  2. model.py +1 -1
  3. trainer.py +10 -1
app.py CHANGED
@@ -34,7 +34,7 @@ def plot_mel_spectrogram(mel_spec):
34
  def get_or_load_model():
35
  if 'model' not in st.session_state or 'tokenizer' not in st.session_state or 'processor' not in st.session_state:
36
  ckpt_path = "checkpoints/pretrained_checkpoint.ckpt"
37
- model = SpeechLLMLightning.load_from_checkpoint(ckpt_path)
38
  tokenizer = model.llm_tokenizer
39
  model.eval()
40
  model.freeze()
 
34
  def get_or_load_model():
35
  if 'model' not in st.session_state or 'tokenizer' not in st.session_state or 'processor' not in st.session_state:
36
  ckpt_path = "checkpoints/pretrained_checkpoint.ckpt"
37
+ model = SpeechLLMLightning.load_from_checkpoint(ckpt_path, quantize=True)
38
  tokenizer = model.llm_tokenizer
39
  model.eval()
40
  model.freeze()
model.py CHANGED
@@ -13,7 +13,7 @@ else:
13
  class HubertXCNNEnoder(nn.Module):
14
  def __init__(self, audio_enc_dim, llm_dim, finetune=False):
15
  super().__init__()
16
- self.encoder = HubertModel.from_pretrained('facebook/hubert-xlarge-ll60k', device_map = device)
17
  for param in self.encoder.parameters():
18
  param.requires_grad = False
19
 
 
13
  class HubertXCNNEnoder(nn.Module):
14
  def __init__(self, audio_enc_dim, llm_dim, finetune=False):
15
  super().__init__()
16
+ self.encoder = HubertModel.from_pretrained('facebook/hubert-xlarge-ll60k').to(device)
17
  for param in self.encoder.parameters():
18
  param.requires_grad = False
19
 
trainer.py CHANGED
@@ -6,6 +6,9 @@ from peft import LoraConfig, get_peft_model, PeftModel
6
  import pytorch_lightning as pl
7
  from model import HubertXCNNEnoder
8
 
 
 
 
9
 
10
  if torch.cuda.is_available():
11
  # Set the device to CUDA
@@ -15,7 +18,7 @@ else:
15
  device = "cpu"
16
 
17
  class SpeechLLMLightning(pl.LightningModule):
18
- def __init__(self, audio_enc_dim=512, llm_dim=2048, llm_name="TinyLlama/TinyLlama-1.1B-Chat-v1.0"):
19
  super().__init__()
20
  self.save_hyperparameters()
21
 
@@ -48,6 +51,12 @@ class SpeechLLMLightning(pl.LightningModule):
48
  self.audio_encoder.eval()
49
  self.llm_model.eval()
50
 
 
 
 
 
 
 
51
 
52
  def encode(self, mel, pre_tokenized_ids, post_tokenized_ids, output_tokenized_ids):
53
  batch_size = mel.shape[0]
 
6
  import pytorch_lightning as pl
7
  from model import HubertXCNNEnoder
8
 
9
+ from torch.quantization import quantize_dynamic
10
+ import torch.jit as jit
11
+
12
 
13
  if torch.cuda.is_available():
14
  # Set the device to CUDA
 
18
  device = "cpu"
19
 
20
  class SpeechLLMLightning(pl.LightningModule):
21
+ def __init__(self, audio_enc_dim=512, llm_dim=2048, llm_name="TinyLlama/TinyLlama-1.1B-Chat-v1.0", quantize=True):
22
  super().__init__()
23
  self.save_hyperparameters()
24
 
 
51
  self.audio_encoder.eval()
52
  self.llm_model.eval()
53
 
54
+ if quantize:
55
+ self.llm_model = jit.script(self.llm_model)
56
+ self.llm_model = quantize_dynamic(
57
+ self.llm_model, {nn.Linear}, dtype=torch.qint8
58
+ )
59
+
60
 
61
  def encode(self, mel, pre_tokenized_ids, post_tokenized_ids, output_tokenized_ids):
62
  batch_size = mel.shape[0]