Commit a6537fa by jan-hq (1 parent: 249da5f)

Update README.md

Files changed (1):
  1. README.md +76 -2
README.md CHANGED
@@ -32,16 +32,90 @@ We expand the Semantic tokens experiment with WhisperVQ as a tokenizer for audio

## How to Get Started with the Model

Try this model using the [Google Colab Notebook](https://colab.research.google.com/drive/18IiwN0AzBZaox5o0iidXqWD1xKq11XbZ?usp=sharing).

First, we need to convert the audio file to sound tokens:

```python
import os

import torch
import torchaudio
from huggingface_hub import hf_hub_download
from whisperspeech.vq_stoks import RQBottleneckTransformer

device = "cuda" if torch.cuda.is_available() else "cpu"

# Download the WhisperVQ quantizer checkpoint if it is not already present.
if not os.path.exists("whisper-vq-stoks-medium-en+pl-fixed.model"):
    hf_hub_download(
        repo_id="jan-hq/WhisperVQ",
        filename="whisper-vq-stoks-medium-en+pl-fixed.model",
        local_dir=".",
    )
vq_model = RQBottleneckTransformer.load_model(
    "whisper-vq-stoks-medium-en+pl-fixed.model"
).to(device)

def audio_to_sound_tokens(audio_path, target_bandwidth=1.5, device=device):
    vq_model.ensure_whisper(device)

    # Load the audio and resample to the 16 kHz rate the quantizer expects.
    wav, sr = torchaudio.load(audio_path)
    if sr != 16000:
        wav = torchaudio.functional.resample(wav, sr, 16000)
    with torch.no_grad():
        codes = vq_model.encode_audio(wav.to(device))
        codes = codes[0].cpu().tolist()

    # Wrap the discrete codes in the model's sound-token markup.
    result = ''.join(f'<|sound_{num:04d}|>' for num in codes)
    return f'<|sound_start|>{result}<|sound_end|>'

def audio_to_sound_tokens_transcript(audio_path, target_bandwidth=1.5, device=device):
    # Same conversion, but the result is prefixed with <|reserved_special_token_69|>.
    vq_model.ensure_whisper(device)

    wav, sr = torchaudio.load(audio_path)
    if sr != 16000:
        wav = torchaudio.functional.resample(wav, sr, 16000)
    with torch.no_grad():
        codes = vq_model.encode_audio(wav.to(device))
        codes = codes[0].cpu().tolist()

    result = ''.join(f'<|sound_{num:04d}|>' for num in codes)
    return f'<|reserved_special_token_69|><|sound_start|>{result}<|sound_end|>'
```
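
For example, with a local recording on disk, the conversion step can be exercised like this (a minimal sketch; `question.wav` is just a placeholder path):

```python
# Hypothetical usage: replace "question.wav" with your own audio file.
sound_tokens = audio_to_sound_tokens("question.wav")
print(sound_tokens[:64])  # e.g. '<|sound_start|><|sound_0123|>...'
```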

Then, we can run inference with the model just like any other LLM:

```python
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    pipeline,
)

def setup_pipeline(model_path, use_4bit=False, use_8bit=False):
    # Build a text-generation pipeline, optionally loading the weights 4-bit or 8-bit quantized.
    tokenizer = AutoTokenizer.from_pretrained(model_path)

    model_kwargs = {"device_map": "auto"}

    if use_4bit:
        model_kwargs["quantization_config"] = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.bfloat16,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
        )
    elif use_8bit:
        model_kwargs["quantization_config"] = BitsAndBytesConfig(
            load_in_8bit=True,
            bnb_8bit_compute_dtype=torch.bfloat16,
            bnb_8bit_use_double_quant=True,
        )
    else:
        model_kwargs["torch_dtype"] = torch.bfloat16

    model = AutoModelForCausalLM.from_pretrained(model_path, **model_kwargs)

    return pipeline("text-generation", model=model, tokenizer=tokenizer)

def generate_text(pipe, messages, max_new_tokens=64, temperature=0.0, do_sample=False):
    # Greedy decoding by default; the prompt itself is not echoed back.
    generation_args = {
        "max_new_tokens": max_new_tokens,
        "return_full_text": False,
        "temperature": temperature,
        "do_sample": do_sample,
    }

    output = pipe(messages, **generation_args)
    return output[0]['generated_text']

# Usage
llm_path = "homebrewltd/llama3.1-s-instruct-v0.2"
pipe = setup_pipeline(llm_path, use_8bit=True)
```
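
Putting the two steps together, a minimal sketch might look like the following (it assumes a recent `transformers` version that accepts chat-style message lists in the text-generation pipeline; the prompt layout is an illustration, not an official format):

```python
# Hypothetical end-to-end example, reusing the `sound_tokens` string produced earlier.
messages = [
    {"role": "user", "content": sound_tokens},
]

print(generate_text(pipe, messages, max_new_tokens=64))
```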

## Training process