kevinwang676 commited on
Commit
79cb6e1
·
1 Parent(s): 3e9f9d7

Upload 4 files

Browse files
training/__init__.py ADDED
File without changes
training/data.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+ import requests
3
+ import os, glob
4
+
5
+ # english literature
6
+ books = [
7
+ 'https://www.gutenberg.org/cache/epub/1513/pg1513.txt',
8
+ 'https://www.gutenberg.org/files/2701/2701-0.txt',
9
+ 'https://www.gutenberg.org/cache/epub/84/pg84.txt',
10
+ 'https://www.gutenberg.org/cache/epub/2641/pg2641.txt',
11
+ 'https://www.gutenberg.org/cache/epub/1342/pg1342.txt',
12
+ 'https://www.gutenberg.org/cache/epub/100/pg100.txt'
13
+ ]
14
+
15
+ #default english
16
+ # allowed_chars = ' abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789!@#$%^&*()-_+=\"\':;[]{}/<>,.`~\n\\'
17
+
18
+ #german
19
+ allowed_chars = ' aäbcdefghijklmnoöpqrsßtuüvwxyzABCDEFGHIJKLMNOÖPQRSTUÜVWXYZ0123456789!@#$%^&*()-_+=\"\':;[]{}/<>,.`~\n\\'
20
+
21
+
22
+ def download_book(book):
23
+ return requests.get(book).content.decode('utf-8')
24
+
25
+
26
+ def filter_data(data):
27
+ print('Filtering data')
28
+ return ''.join([char for char in data if char in allowed_chars])
29
+
30
+
31
+ def load_books(fromfolder=False):
32
+ text_data = []
33
+ if fromfolder:
34
+ current_working_directory = os.getcwd()
35
+ print(current_working_directory)
36
+ path = 'text'
37
+ for filename in glob.glob(os.path.join(path, '*.txt')):
38
+ with open(os.path.join(os.getcwd(), filename), 'r') as f: # open in readonly mode
39
+ print(f'Loading {filename}')
40
+ text_data.append(filter_data(str(f.read())))
41
+ else:
42
+ print(f'Loading {len(books)} books into ram')
43
+ for book in books:
44
+ text_data.append(filter_data(str(download_book(book))))
45
+ print('Loaded books')
46
+ return ' '.join(text_data)
47
+
48
+
49
+ def random_split_chunk(data, size=14):
50
+ data = data.split(' ')
51
+ index = random.randrange(0, len(data))
52
+ return ' '.join(data[index:index+size])
training/train.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import fnmatch
3
+ import shutil
4
+
5
+ import numpy
6
+ import torchaudio
7
+ import gradio
8
+
9
+ from bark.hubert.pre_kmeans_hubert import CustomHubert
10
+ from bark.hubert.customtokenizer import auto_train
11
+ from tqdm.auto import tqdm
12
+
13
+
14
+ def training_prepare_files(path, model,progress=gradio.Progress(track_tqdm=True)):
15
+
16
+ semanticsfolder = "./training/data/output"
17
+ wavfolder = "./training/data/output_wav"
18
+ ready = os.path.join(path, 'ready')
19
+
20
+ testfiles = fnmatch.filter(os.listdir(ready), '*.npy')
21
+ if(len(testfiles) < 1):
22
+ # prepare and copy for training
23
+ hubert_model = CustomHubert(checkpoint_path=model)
24
+
25
+ wavfiles = fnmatch.filter(os.listdir(wavfolder), '*.wav')
26
+ for i, f in tqdm(enumerate(wavfiles), total=len(wavfiles)):
27
+ semaname = '.'.join(f.split('.')[:-1]) # Cut off the extension
28
+ semaname = f'{semaname}.npy'
29
+ semafilename = os.path.join(semanticsfolder, semaname)
30
+ if not os.path.isfile(semafilename):
31
+ print(f'Skipping {f} no semantics pair found!')
32
+ continue
33
+
34
+ print('Processing', f)
35
+ wav, sr = torchaudio.load(os.path.join(wavfolder, f))
36
+ if wav.shape[0] == 2: # Stereo to mono if needed
37
+ wav = wav.mean(0, keepdim=True)
38
+ output = hubert_model.forward(wav, input_sample_hz=sr)
39
+ out_array = output.cpu().numpy()
40
+ fname = f'{i}_semantic_features.npy'
41
+ numpy.save(os.path.join(ready, fname), out_array)
42
+ fname = f'{i}_semantic.npy'
43
+ shutil.copy(semafilename, os.path.join(ready, fname))
44
+
45
+ def train(path, save_every, max_epochs):
46
+ auto_train(path, save_epochs=save_every)
47
+
training/training_prepare.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+ import uuid
3
+ import numpy
4
+ import os
5
+ import random
6
+ import fnmatch
7
+
8
+ from tqdm.auto import tqdm
9
+ from scipy.io import wavfile
10
+
11
+ from bark.generation import load_model, SAMPLE_RATE
12
+ from bark.api import semantic_to_waveform
13
+
14
+ from bark import text_to_semantic
15
+ from bark.generation import load_model
16
+
17
+ from training.data import load_books, random_split_chunk
18
+
19
+ output = 'training/data/output'
20
+ output_wav = 'training/data/output_wav'
21
+
22
+
23
+ def prepare_semantics_from_text(num_generations):
24
+ loaded_data = load_books(True)
25
+
26
+ print('Loading semantics model')
27
+ load_model(use_gpu=True, use_small=False, force_reload=False, model_type='text')
28
+
29
+ if not os.path.isdir(output):
30
+ os.mkdir(output)
31
+
32
+ loop = 1
33
+ while 1:
34
+ filename = uuid.uuid4().hex + '.npy'
35
+ file_name = os.path.join(output, filename)
36
+ text = ''
37
+ while not len(text) > 0:
38
+ text = random_split_chunk(loaded_data) # Obtain a short chunk of text
39
+ text = text.strip()
40
+ print(f'{loop} Generating semantics for text:', text)
41
+ loop+=1
42
+ semantics = text_to_semantic(text, temp=round(random.uniform(0.6, 0.8), ndigits=2))
43
+ numpy.save(file_name, semantics)
44
+
45
+
46
+ def prepare_wavs_from_semantics():
47
+ if not os.path.isdir(output):
48
+ raise Exception('No \'output\' folder, make sure you run create_data.py first!')
49
+ if not os.path.isdir(output_wav):
50
+ os.mkdir(output_wav)
51
+
52
+ print('Loading coarse model')
53
+ load_model(use_gpu=True, use_small=False, force_reload=False, model_type='coarse')
54
+ print('Loading fine model')
55
+ load_model(use_gpu=True, use_small=False, force_reload=False, model_type='fine')
56
+
57
+ files = fnmatch.filter(os.listdir(output), '*.npy')
58
+ current = 1
59
+ total = len(files)
60
+
61
+ for i, f in tqdm(enumerate(files), total=len(files)):
62
+ real_name = '.'.join(f.split('.')[:-1]) # Cut off the extension
63
+ file_name = os.path.join(output, f)
64
+ out_file = os.path.join(output_wav, f'{real_name}.wav')
65
+ if not os.path.isfile(out_file) and os.path.isfile(file_name): # Don't process files that have already been processed, to be able to continue previous generations
66
+ print(f'Processing ({i+1}/{total}) -> {f}')
67
+ wav = semantic_to_waveform(numpy.load(file_name), temp=round(random.uniform(0.6, 0.8), ndigits=2))
68
+ # Change to PCM16
69
+ # wav = (wav * 32767).astype(np.int16)
70
+ wavfile.write(out_file, SAMPLE_RATE, wav)
71
+
72
+ print('Done!')
73
+