import logging
from typing import Any, Dict, List, Optional

import transformers

# We must use relative imports in this directory so that the files can be
# uploaded to and loaded from the HF Hub as custom code. Even the
# "from . import X" pattern doesn't work (undocumented, and it's unclear why).
from .ultravox_model import UltravoxModel
from .ultravox_processing import UltravoxProcessor


class UltravoxPipeline(transformers.Pipeline):
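    """Audio + text pipeline for Ultravox models.

    Expects a dict with an "audio" array plus either chat "turns" or a text
    "prompt" containing the <|audio|> placeholder; returns the generated text.
    """
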
    def __init__(
        self,
        model: UltravoxModel,
        tokenizer: Optional[transformers.PreTrainedTokenizerBase] = None,
        audio_processor: Optional[transformers.ProcessorMixin] = None,
        **kwargs
    ):
        if tokenizer is None:
            tokenizer = transformers.AutoTokenizer.from_pretrained(
                model.config._name_or_path
            )

        if audio_processor is None:
            audio_processor = transformers.Wav2Vec2Processor.from_pretrained(
                model.config.audio_model_id
            )

        self.processor = UltravoxProcessor(
            audio_processor, tokenizer=tokenizer, stack_factor=model.config.stack_factor
        )

        super().__init__(model=model, tokenizer=tokenizer, **kwargs)

    def _sanitize_parameters(self, **kwargs):
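        # Pipeline subclasses must return three kwargs dicts from here:
        # (preprocess_kwargs, forward_kwargs, postprocess_kwargs). Only the
        # generation parameters below are routed, all of them to _forward.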
        generation_kwargs = {}
        if "temperature" in kwargs:
            generation_kwargs["temperature"] = kwargs["temperature"]
        if "max_new_tokens" in kwargs:
            generation_kwargs["max_new_tokens"] = kwargs["max_new_tokens"]
        if "repetition_penalty" in kwargs:
            generation_kwargs["repetition_penalty"] = kwargs["repetition_penalty"]
        return {}, generation_kwargs, {}

    def preprocess(self, inputs: Dict[str, Any]):
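        # Accepted keys: "turns" (a chat history) or "prompt" (a single user
        # message), "audio" (required), and "sampling_rate" (defaults to 16 kHz).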
        if "turns" in inputs:
            turns = inputs["turns"]
        else:
            prompt = inputs.get("prompt", "<|audio|>")
            if "<|audio|>" not in prompt:
                logging.warning(
                    "Prompt does not contain '<|audio|>', appending '<|audio|>' to the end of the prompt."
                )
                prompt += " <|audio|>"
            turns = [{"role": "user", "content": prompt}]

        # add_generation_prompt=True appends the assistant header so the model
        # starts a fresh assistant turn rather than continuing the user's text.
        text = self.processor.tokenizer.apply_chat_template(
            turns, tokenize=False, add_generation_prompt=True
        )

        # TODO: allow text-only mode?
        if "audio" not in inputs:
            # Raise rather than assert: asserts are stripped under `python -O`.
            raise ValueError("Audio input is required")

        if "sampling_rate" not in inputs:
            logging.warning(
                "No sampling rate provided, using default of 16kHz. We highly recommend providing the correct sampling rate."
            )

        return self.processor(
            text=text,
            audio=inputs["audio"],
            sampling_rate=inputs.get("sampling_rate", 16000),
        )

    def _forward(
        self,
        model_inputs: Dict[str, Any],
        temperature: Optional[float] = None,
        max_new_tokens: Optional[int] = None,
        repetition_penalty: float = 1.1,
    ) -> List[int]:
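        # A temperature of 0 (or None) disables sampling, i.e. greedy decoding.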
        temperature = temperature or None
        do_sample = temperature is not None

        # Stop on the tokenizer's EOS token; Llama 3-style chat tokenizers also
        # mark end-of-turn with <|eot_id|>, so include it when present.
        terminators = [self.tokenizer.eos_token_id]
        if "<|eot_id|>" in self.tokenizer.added_tokens_encoder:
            terminators.append(self.tokenizer.convert_tokens_to_ids("<|eot_id|>"))

        # Remember the prompt length so only newly generated tokens are returned.
        input_len = model_inputs["input_ids"].shape[1]

        outputs = self.model.generate(
            **model_inputs,
            do_sample=do_sample,
            temperature=temperature,
            max_new_tokens=max_new_tokens,
            repetition_penalty=repetition_penalty,
            eos_token_id=terminators
        )
        return outputs[0][input_len:]

    def postprocess(self, model_outputs) -> str:
        output_text = self.tokenizer.decode(model_outputs, skip_special_tokens=True)
        return output_text


transformers.pipelines.PIPELINE_REGISTRY.register_pipeline(
    "ultravox-pipeline",
    pipeline_class=UltravoxPipeline,
    pt_model=transformers.AutoModel,
    type="multimodal",
)
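

# A minimal usage sketch (the checkpoint name "fixie-ai/ultravox-v0_2" and the
# synthetic audio below are illustrative assumptions, not part of this file):
#
#   import numpy as np
#   import transformers
#
#   pipe = transformers.pipeline(
#       model="fixie-ai/ultravox-v0_2", trust_remote_code=True
#   )
#   audio = np.zeros(16_000, dtype=np.float32)  # 1 second of silence at 16 kHz
#   turns = [{"role": "user", "content": "What do you hear? <|audio|>"}]
#   print(
#       pipe(
#           {"audio": audio, "turns": turns, "sampling_rate": 16_000},
#           max_new_tokens=30,
#       )
#   )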