Dupaja committed
Commit 9277202 · 1 Parent(s): 0824f3c

Attempt to reduce latency by moving more to init

Files changed (1): handler.py (+6, -6)
handler.py CHANGED
@@ -22,19 +22,19 @@ class EndpointHandler:
         self.model= SpeechT5ForTextToSpeech.from_pretrained(checkpoint)
         self.processor = SpeechT5Processor.from_pretrained(checkpoint)
         self.vocoder = SpeechT5HifiGan.from_pretrained(vocoder_id)
-        self.embeddings_dataset = load_dataset(dataset_id, split="validation")
+        embeddings_dataset = load_dataset(dataset_id, split="validation")
+        self.embeddings_dataset = embeddings_dataset
+        self.speaker_embeddings = torch.tensor(embeddings_dataset[7306]["xvector"]).unsqueeze(0)
+
 
 
     def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
 
         given_text = data.get("inputs", "")
-
-
-        speaker_embeddings = torch.tensor(self.embeddings_dataset[7306]["xvector"]).unsqueeze(0)
-
+
         inputs = self.processor(text=given_text, return_tensors="pt")
 
-        speech = self.model.generate_speech(inputs["input_ids"], speaker_embeddings, vocoder=self.vocoder)
+        speech = self.model.generate_speech(inputs["input_ids"], self.speaker_embeddings, vocoder=self.vocoder)
 
 
 
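For context, the hunk only covers part of handler.py. Below is a minimal sketch of what the handler looks like after this commit. The identifiers checkpoint, vocoder_id, and dataset_id are defined above the diffed lines and are not shown, so the values used here (microsoft/speecht5_tts, microsoft/speecht5_hifigan, Matthijs/cmu-arctic-xvectors) are assumptions based on the standard SpeechT5 setup, and the return payload of __call__ is hypothetical, since the hunk ends before the return statement.

# Sketch of handler.py after this commit. The concrete model/dataset ids and the
# response shape are assumptions; only the lines shown in the diff are confirmed.
from typing import Any, Dict

import torch
from datasets import load_dataset
from transformers import SpeechT5ForTextToSpeech, SpeechT5HifiGan, SpeechT5Processor


class EndpointHandler:
    def __init__(self, path: str = ""):
        # Assumed identifiers; the real values are defined above the diffed hunk.
        checkpoint = "microsoft/speecht5_tts"
        vocoder_id = "microsoft/speecht5_hifigan"
        dataset_id = "Matthijs/cmu-arctic-xvectors"

        self.model = SpeechT5ForTextToSpeech.from_pretrained(checkpoint)
        self.processor = SpeechT5Processor.from_pretrained(checkpoint)
        self.vocoder = SpeechT5HifiGan.from_pretrained(vocoder_id)

        # Moved here by this commit: load the x-vector dataset and build the
        # speaker-embedding tensor once at startup instead of on every request.
        embeddings_dataset = load_dataset(dataset_id, split="validation")
        self.embeddings_dataset = embeddings_dataset
        self.speaker_embeddings = torch.tensor(embeddings_dataset[7306]["xvector"]).unsqueeze(0)

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        given_text = data.get("inputs", "")

        inputs = self.processor(text=given_text, return_tensors="pt")

        speech = self.model.generate_speech(inputs["input_ids"], self.speaker_embeddings, vocoder=self.vocoder)

        # Hypothetical response shape: the diff does not show the return statement.
        return {"audio": speech.numpy().tolist(), "sampling_rate": 16000}

With the dataset load and tensor construction done in __init__, each request only pays for tokenization and generate_speech, which is the per-call latency the commit message targets.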