googlefan committed on
Commit
1589c32
1 Parent(s): 8426191

Create ultravox_pipeline.py

Browse files
Files changed (1) hide show
  1. ultravox_pipeline.py +127 -0
ultravox_pipeline.py ADDED
@@ -0,0 +1,127 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ from typing import Any, Dict, List, Optional
3
+
4
+ import numpy as np
5
+ import transformers
6
+
7
+ # We must use relative import in this directory to allow uploading to HF Hub
8
+ # Even "from . import X" pattern doesn't work (undocumented and unclear why)
9
+ from .ultravox_model import UltravoxModel
10
+ from .ultravox_processing import UltravoxProcessor
11
+
12
+
13
class UltravoxPipeline(transformers.Pipeline):
    """A transformers Pipeline for audio+text chat generation with UltravoxModel.

    Accepts an input dict with optional keys:
      - "turns": chat history as a list of {"role", "content"} dicts
      - "audio": raw audio as a numpy array (float32/float64/int16/int32)
      - "sampling_rate": sample rate of the audio (defaults to 16 kHz)
      - "prompt": user prompt text; an "<|audio|>" placeholder marks where the
        audio is inserted (appended automatically if missing)
    Returns the generated assistant response as a plain string.
    """

    def __init__(
        self,
        model: UltravoxModel,
        tokenizer: Optional[transformers.PreTrainedTokenizerBase] = None,
        audio_processor: Optional[transformers.ProcessorMixin] = None,
        **kwargs
    ):
        """Build the pipeline, auto-loading tokenizer/audio processor if omitted.

        Args:
            model: the multimodal Ultravox model to run.
            tokenizer: text tokenizer; loaded from the model's (or its text
                backbone's) pretrained path when not provided.
            audio_processor: audio feature extractor; loaded from the model's
                audio backbone when not provided.
            **kwargs: forwarded to transformers.Pipeline.__init__.
        """
        if tokenizer is None:
            try:
                tokenizer = transformers.AutoTokenizer.from_pretrained(
                    model.config._name_or_path
                )
            # Narrowed from a bare `except:` — a bare except also swallows
            # KeyboardInterrupt/SystemExit. Any load failure falls back to the
            # text backbone's tokenizer.
            except Exception:
                tokenizer = transformers.AutoTokenizer.from_pretrained(
                    model.config.text_model_id or model.config.text_config._name_or_path
                )

        if audio_processor is None:
            audio_processor = transformers.AutoProcessor.from_pretrained(
                model.config.audio_model_id or model.config.audio_config._name_or_path
            )

        super().__init__(model=model, tokenizer=tokenizer, **kwargs)

        # Combined processor that interleaves audio features with the tokenized
        # text; stack_factor must match the model's audio-frame stacking.
        self.processor = UltravoxProcessor(
            audio_processor=audio_processor,
            tokenizer=tokenizer,
            stack_factor=model.config.stack_factor,
        )

    def _sanitize_parameters(self, **kwargs):
        """Route recognized generation kwargs to _forward; nothing else is used."""
        generation_keys = ("temperature", "max_new_tokens", "repetition_penalty")
        generation_kwargs = {k: kwargs[k] for k in kwargs if k in generation_keys}
        # (preprocess_kwargs, forward_kwargs, postprocess_kwargs)
        return {}, generation_kwargs, {}

    def preprocess(self, inputs: Dict[str, Any]):
        """Convert the input dict into model-ready tensors.

        Normalizes audio to float32, ensures the final turn is a user turn
        containing an "<|audio|>" placeholder when audio is supplied, applies
        the chat template, and runs the combined processor.
        """
        # Copy so we never mutate the caller's "turns" list in place.
        turns: list = list(inputs.get("turns", []))

        audio = inputs.get("audio", None)
        # Convert to float32 if needed; integer PCM is rescaled to [-1, 1).
        if isinstance(audio, np.ndarray):
            if audio.dtype == np.float64:
                audio = audio.astype(np.float32)
            elif audio.dtype == np.int16:
                audio = audio.astype(np.float32) / np.float32(32768.0)
            elif audio.dtype == np.int32:
                audio = audio.astype(np.float32) / np.float32(2147483648.0)

        if audio is not None and (len(turns) == 0 or turns[-1]["role"] != "user"):
            prompt = inputs.get("prompt", "<|audio|>")
            if "<|audio|>" not in prompt:
                logging.warning(
                    "Prompt does not contain '<|audio|>', appending '<|audio|>' to the end of the prompt."
                )

                prompt += " <|audio|>"
            turns.append({"role": "user", "content": prompt})

        text = self.processor.tokenizer.apply_chat_template(
            turns, add_generation_prompt=True, tokenize=False
        )

        if "sampling_rate" not in inputs and audio is not None:
            logging.warning(
                "No sampling rate provided, using default of 16kHz. We highly recommend providing the correct sampling rate."
            )

        output = self.processor(
            text=text,
            audio=audio,
            sampling_rate=inputs.get("sampling_rate", 16000),
        )
        if "audio_values" in output:
            # Match the model's dtype (e.g. fp16/bf16) to avoid a dtype
            # mismatch inside the audio tower.
            output["audio_values"] = output["audio_values"].to(self.model.dtype)

        return output

    def _forward(
        self,
        model_inputs: Dict[str, Any],
        temperature: Optional[float] = None,
        max_new_tokens: Optional[int] = None,
        repetition_penalty: float = 1.1,
    ) -> List[int]:
        """Run generation and return only the newly generated token ids."""
        # temperature == 0 (or None) means greedy decoding (do_sample=False).
        temperature = temperature or None
        do_sample = temperature is not None

        # Stop on EOS, and also on Llama-3-style "<|eot_id|>" if the tokenizer
        # defines it.
        terminators = [self.tokenizer.eos_token_id]
        if "<|eot_id|>" in self.tokenizer.added_tokens_encoder:
            terminators.append(self.tokenizer.convert_tokens_to_ids("<|eot_id|>"))

        input_len = model_inputs["input_ids"].shape[1]

        outputs = self.model.generate(
            **model_inputs,
            do_sample=do_sample,
            temperature=temperature,
            max_new_tokens=max_new_tokens,
            repetition_penalty=repetition_penalty,
            eos_token_id=terminators
        )
        # Strip the prompt tokens; return only the completion.
        return outputs[0][input_len:]

    def postprocess(self, model_outputs) -> str:
        """Decode generated token ids to text, dropping special tokens."""
        output_text = self.tokenizer.decode(model_outputs, skip_special_tokens=True)
        return output_text
120
+
121
+
122
# Make the pipeline discoverable by task name, so callers can do
# transformers.pipeline("ultravox-pipeline", model=...).
transformers.pipelines.PIPELINE_REGISTRY.register_pipeline(
    "ultravox-pipeline",
    type="multimodal",
    pt_model=transformers.AutoModel,
    pipeline_class=UltravoxPipeline,
)