googlefan commited on
Commit
aaf7d72
·
verified ·
1 Parent(s): 41accf6

Create README.md

Browse files
Files changed (1) hide show
  1. README.md +82 -0
README.md ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ language:
3
+ - ja
4
+ base_model:
5
+ - google/gemma-2-2b-jpn-it
6
+ pipeline_tag: any-to-any
7
+ ---
8
+ ```py
9
+ import transformers
10
+ import librosa
11
+ import torch
12
+ import numpy as np
13
+ from typing import Dict, Any
14
+
15
# Load the Ultravox (audio + text) model; trust_remote_code is required
# because the repository ships its own modeling code.
model = transformers.AutoModel.from_pretrained(
    "neody/ultravox-gemma-2-2b-jpn-it", trust_remote_code=True
)
# Run on GPU in bfloat16 to halve memory versus float32.
model.to("cuda", dtype=torch.bfloat16)
# Matching processor (tokenizer + audio feature extractor), same repo.
processor = transformers.AutoProcessor.from_pretrained(
    "neody/ultravox-gemma-2-2b-jpn-it", trust_remote_code=True
)
path = "record.wav"
# Resample to 16 kHz — this matches the default sampling_rate that
# preprocess() passes to the processor below.
audio, sr = librosa.load(path, sr=16000)
24
+
25
+
26
def preprocess(inputs: Dict[str, Any], device, dtype):
    """Build model-ready tensors from a raw waveform and chat turns.

    Args:
        inputs: dict with optional keys:
            - "audio": np.ndarray waveform. float64 / int16 / int32 are
              normalized to float32 in [-1, 1]; other dtypes pass through.
            - "turns": chat history as [{"role": ..., "content": ...}].
              NOTE: mutated in place — a user turn is appended when audio
              is provided and the last turn is not already a user turn.
            - "prompt": text for the appended user turn
              (default "<|audio|>").
            - "sampling_rate": sample rate of "audio"; defaults to 16000
              with a printed warning.
        device: target device for the returned tensors (e.g. "cuda").
        dtype: target floating-point dtype (e.g. torch.bfloat16).

    Returns:
        The processor output moved to `device`/`dtype`, ready to be
        splatted into `model.generate(**...)`.
    """
    turns: list = inputs.get("turns", [])

    audio = inputs.get("audio", None)
    # Convert integer PCM / float64 input to float32 if needed.
    if isinstance(audio, np.ndarray):
        if audio.dtype == np.float64:
            audio = audio.astype(np.float32)
        elif audio.dtype == np.int16:
            audio = audio.astype(np.float32) / np.float32(32768.0)
        elif audio.dtype == np.int32:
            audio = audio.astype(np.float32) / np.float32(2147483648.0)

    if audio is not None and (len(turns) == 0 or turns[-1]["role"] != "user"):
        prompt = inputs.get("prompt", "<|audio|>")
        if "<|audio|>" not in prompt:
            print(
                "Prompt does not contain '<|audio|>', appending '<|audio|>' to the end of the prompt."
            )
            # FIX: append the placeholder only when it is missing.
            # Appending unconditionally would duplicate the audio token
            # whenever the prompt already contains '<|audio|>'.
            prompt += " <|audio|>"
        turns.append({"role": "user", "content": prompt})

    text = processor.tokenizer.apply_chat_template(
        turns, add_generation_prompt=True, tokenize=False
    )

    if "sampling_rate" not in inputs and audio is not None:
        print(
            "No sampling rate provided, using default of 16kHz. We highly recommend providing the correct sampling rate."
        )

    output = processor(
        text=text,
        audio=audio,
        sampling_rate=inputs.get("sampling_rate", 16000),
    )
    # Audio features are cast to the model dtype explicitly; the final
    # .to() moves everything (token ids stay integer tensors).
    if "audio_values" in output:
        output["audio_values"] = output["audio_values"].to(device, dtype)
    return output.to(device, dtype)
66
+
67
+
68
turns = []
# Turn the recorded audio plus (empty) conversation into model inputs.
model_inputs = preprocess(
    {"audio": audio, "turns": turns, "sampling_rate": sr},
    "cuda",
    torch.bfloat16,
)
# Generate up to 300 new tokens, then decode the single returned
# sequence without special tokens.
generated_ids = model.generate(**model_inputs, max_new_tokens=300)
response = processor.tokenizer.decode(
    generated_ids.squeeze(), skip_special_tokens=True
)
print(response)
82
+ ```