cmeraki commited on
Commit
c149c23
·
1 Parent(s): 4730a41

Added pipeline

Browse files
Files changed (1) hide show
  1. pipeline.py +113 -0
pipeline.py ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ import torch
3
+ import numpy as np
4
+ from transformers import MimiModel, GenerationConfig
5
+ from transformers import Pipeline
6
+
7
+ class IndriPipeline(Pipeline):
8
+ def __init__(self, *args, **kwargs):
9
+ super().__init__(*args, **kwargs)
10
+
11
+ self.audio_tokenizer = MimiModel.from_pretrained('kyutai/mimi').to(device=self.device)
12
+
13
+ # TODO: Ideally all of this should come from model config
14
+ self.convert_token = self.tokenizer.encode('[convert]')
15
+ self.stop_token = self.tokenizer.encode('[stop]')
16
+ self.text_modality_token = self.tokenizer.encode('[text]')
17
+ self.acoustic_modality_token = self.tokenizer.encode('[mimi]')
18
+ self.num_codebooks = 8
19
+ self.audio_offset = 50257
20
+
21
+ self.model.generation_config = GenerationConfig(
22
+ eos_token_id=self.stop_token,
23
+ max_length=kwargs.get('max_length', 1024),
24
+ temperature=kwargs.get('temperature', 0.5),
25
+ top_k=kwargs.get('top_k', 15),
26
+ do_sample=kwargs.get('do_sample', True)
27
+ )
28
+
29
+ def _sanitize_parameters(self, **kwargs):
30
+ task = kwargs.get('task', 'tts')
31
+ assert task in ['tts', 'asr'], f'Task must be one of tts, asr. You provided: {task}'
32
+
33
+ speaker = kwargs.get('speaker', '[spkr_unk]')
34
+
35
+ preprocess_kwargs = {
36
+ 'task': task,
37
+ 'speaker': speaker
38
+ }
39
+
40
+ return preprocess_kwargs, {}, {}
41
+
42
+ def _prepare_tts_tokens(self, text_tokens, speaker):
43
+ input_tokens = np.hstack([
44
+ self.text_modality_token,
45
+ text_tokens,
46
+ self.convert_token,
47
+ self.acoustic_modality_token,
48
+ self.tokenizer.encode(speaker)
49
+ ])
50
+
51
+ return input_tokens.tolist()
52
+
53
+ def _prepare_asr_tokens(self, audio_tokens):
54
+ pass
55
+
56
+ def _sanitize_text(self, text):
57
+ text = text.lower()
58
+ text = re.sub(r'\n+', ' ', text)
59
+ text = re.sub(r'[ \t]+', ' ', text)
60
+
61
+ text = re.sub(r'([,\.?])+', r'\1', text)
62
+
63
+ return text.strip()
64
+
65
+ def _deserialize_tokens(self, tokens, num_codebooks):
66
+ cb = [tokens[i::num_codebooks] for i in range(num_codebooks)]
67
+ min_shape = min([c.shape for c in cb])[0]
68
+ acoustic_tokens = torch.vstack([c[:min_shape] - 2048*i for i, c in enumerate(cb)])
69
+
70
+ return acoustic_tokens
71
+
72
+ def preprocess(self, inputs, speaker, task):
73
+ # TODO: Check for batching
74
+ if task == 'tts':
75
+ input_text = self._sanitize_text(inputs)
76
+ input_tokens = self.tokenizer.encode(input_text)
77
+ task_tokens = self._prepare_tts_tokens(input_tokens, speaker)
78
+ task_tokens = torch.tensor(task_tokens).unsqueeze(0)
79
+
80
+ elif task == 'asr':
81
+ raise ValueError('ASR task is not yet supported')
82
+
83
+ return {'task_tokens': task_tokens}
84
+
85
+ def _forward(self, model_inputs, **forward_args):
86
+
87
+ outputs = self.model.generate(model_inputs['task_tokens'])
88
+ audio_tokens = []
89
+
90
+ for idx, inputs in enumerate(model_inputs['task_tokens']):
91
+ truncated = outputs[idx, inputs.shape[-1]:]
92
+ end = torch.where(truncated == self.stop_token[0])[-1]
93
+
94
+ if end.shape[-1] > 0:
95
+ end = end[0]
96
+ else:
97
+ end = truncated.shape[-1]
98
+
99
+ truncated = truncated[:end]
100
+ truncated -= self.audio_offset
101
+ truncated = self._deserialize_tokens(torch.tensor(truncated), self.num_codebooks)
102
+ audio_tokens.append(truncated)
103
+
104
+ audio_tokens = torch.vstack(audio_tokens).unsqueeze(0)
105
+ audio = self.audio_tokenizer.decode(audio_tokens).audio_values
106
+
107
+ return {
108
+ 'audio_tokens': audio_tokens, # (B, num_codebooks, num_samples)
109
+ 'audio': audio # (B, 1, num_audio_samples)
110
+ }
111
+
112
+ def postprocess(self, model_outputs):
113
+ return model_outputs