File size: 4,148 Bytes
222e3bd
e935ff6
ea6a933
83ebf46
ef93563
c79df46
ef93563
 
6f94cd7
74e078c
 
8e17b76
f2cf91b
ef93563
 
 
8e17b76
294478a
f2cf91b
 
044fea4
f2cf91b
9366d4b
 
2e101cc
5c4653b
 
 
 
 
 
f2cf91b
5c4653b
 
 
 
a65b632
044fea4
 
22b65d9
a65b632
477ec86
9366d4b
044fea4
 
 
 
a65b632
22b65d9
9366d4b
044fea4
5c4653b
1cd6967
74e078c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1ae721a
dcbacca
76e9cbc
74e078c
 
 
 
 
 
 
 
 
 
 
 
 
76e9cbc
74e078c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
76e9cbc
294478a
4ea5474
294478a
74e078c
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
import os
import sys
import requests
import json
from huggingface_hub import HfApi

# start xVASynth service (no HTTP)
import resources.app.no_server as xvaserver

from gr_client import BlocksDemo

# model
# Hugging Face repository that hosts the voice models.
hf_model_name = "Pendrokar/xvapitch_nvidia"
model_repo = HfApi()
# NOTE: this queries the Hub at import time; the first returned commit is
# treated as the latest one (presumably newest-first ordering — verify against
# the huggingface_hub documentation).
commits = model_repo.list_repo_commits(repo_id=hf_model_name)
latest_commit_sha = commits[0].commit_id
# Hard-coded snapshot directory inside the local Hugging Face cache; assumes the
# process runs as `user` with the default cache location and that the snapshot
# for this commit is already present there — TODO confirm.
hf_cache_models_path = f'/home/user/.cache/huggingface/hub/models--Pendrokar--xvapitch_nvidia/snapshots/{latest_commit_sha}/'
# Directory prefix that load_model() concatenates with a voice model name.
models_path = hf_cache_models_path

# Name of the voice model that was loaded last (None until a model is loaded).
current_voice_model = None
# Base speaker embedding belonging to the loaded voice model ('' until known).
base_speaker_emb = ''

def load_model(voice_model_name):
	"""Load a voice model into the in-process xVASynth service.

	This function does not modify any module-level state; in particular it
	does not record which model is currently loaded. A plain assignment to
	``current_voice_model`` here would only bind a function local.

	Args:
		voice_model_name: Model file name without extension; it is appended
			directly to ``models_path``.

	Returns:
		The ``base_speaker_emb`` value of the first entry in the ``games``
		list of the model's ``.json`` file. If a ``requests`` exception is
		raised while loading, the module-level ``base_speaker_emb`` is
		returned instead.
	"""
	model_path = models_path + voice_model_name

	data = {
		'outputs': None,
		'version': '3.0',
		'model': model_path,
		'modelType': 'xVAPitch',
		'base_lang': 'en',
		'pluginsContext': '{}',
	}

	# Fallback returned when loading fails: whatever embedding is known so far.
	embs = base_speaker_emb

	print('Loading voice model...')
	try:
		xvaserver.loadModel(data)

		# The speaker embedding is stored in the JSON file next to the model.
		with open(model_path + '.json', 'r', encoding='utf-8') as f:
			voice_model_json = json.load(f)
		embs = voice_model_json['games'][0]['base_speaker_emb']
	except requests.exceptions.RequestException as err:
		# NOTE(review): only `requests` errors are handled here; a missing or
		# malformed model JSON (OSError / ValueError / KeyError) still propagates.
		print(f'FAILED to load voice model: {err}')

	return embs


class LocalBlocksDemo(BlocksDemo):
	"""BlocksDemo whose ``predict`` calls the in-process xVASynth service (no HTTP)."""

	def predict(
		self,
		input_text,
		voice,
		lang,
		pacing,
		pitch,
		energy,
		anger,
		happy,
		sad,
		surprise,
		use_deepmoji
	):
		"""Synthesize ``input_text`` with the selected voice and describe the result.

		Args:
			input_text: Text to speak; only the first 1000 characters are used.
			voice: Voice model name; the model is (re)loaded when it differs
				from the one loaded last.
			lang: Language code passed to the service as ``base_lang``.
			pacing: Speech pace; a falsy value falls back to ``1.0``.
			pitch: Accepted for interface compatibility; not used here.
			energy: Accepted for interface compatibility; not used here.
			anger: Emotion value; values not greater than 0 are sent as 0.
			happy: Emotion value; values not greater than 0 are sent as 0.
			sad: Emotion value; values not greater than 0 are sent as 0.
			surprise: Emotion value; values not greater than 0 are sent as 0.
			use_deepmoji: Passed through as the ``run_model`` plugin setting.

		Returns:
			A list of: the output audio path ('' on failure), an HTML fragment
			with the ARPAbet symbols and their lengths, the four emotion values
			reported by the service rounded to two decimals (angry, happy, sad,
			surprise), and the raw response dict.
		"""
		# The loaded-model cache lives at module level. Without this declaration
		# the assignment below would make `base_speaker_emb` a function local and
		# raise UnboundLocalError whenever the reload is skipped.
		global current_voice_model, base_speaker_emb

		# grab only the first 1000 characters
		input_text = input_text[:1000]

		# load voice model if not the current model
		if current_voice_model != voice:
			base_speaker_emb = load_model(voice)
			# Remember the loaded voice so the next call with it skips the reload.
			current_voice_model = voice

		model_type = 'xVAPitch'
		pace = pacing if pacing else 1.0
		save_path = '/tmp/xvapitch_audio_sample.wav'
		language = lang
		use_sr = 0
		use_cleanup = 0

		pluginsContext = {}
		pluginsContext["mantella_settings"] = {
			"emAngry": (anger if anger > 0 else 0),
			"emHappy": (happy if happy > 0 else 0),
			"emSad": (sad if sad > 0 else 0),
			"emSurprise": (surprise if surprise > 0 else 0),
			"run_model": use_deepmoji
		}

		data = {
			'pluginsContext': json.dumps(pluginsContext),
			'modelType': model_type,
			# pad with whitespaces as a workaround to avoid cutoffs
			'sequence': input_text.center(len(input_text) + 2, ' '),
			'pace': pace,
			'outfile': save_path,
			'vocoder': 'n/a',
			'base_lang': language,
			'base_emb': base_speaker_emb,
			'useSR': use_sr,
			'useCleanup': use_cleanup,
		}

		print('Synthesizing...')
		try:
			json_data = xvaserver.synthesize(data)
		except requests.exceptions.RequestException as err:
			print(f'FAILED to synthesize: {err}')
			save_path = ''
			# Mirror the shape the code below reads from a real response:
			# a '|'-separated symbol string and one-element emotion lists.
			json_data = {
				'arpabet': 'Failed',
				'durations': [0],
				'em_angry': [anger],
				'em_happy': [happy],
				'em_sad': [sad],
				'em_surprise': [surprise],
			}

		arpabet_html = '<h6>ARPAbet & Phoneme lengths</h6>'
		arpabet_symbols = json_data['arpabet'].split('|')
		utter_time = 0
		# zip() stops at the shorter sequence, so a symbol/duration count
		# mismatch cannot raise IndexError.
		for symbol, duration in zip(arpabet_symbols, json_data['durations']):
			# skip PAD symbol
			if symbol == '<PAD>':
				continue

			length = float(duration)
			# Horizontal padding (in em) visualises the phoneme length.
			arpa_length = str(round(length/2, 1))
			arpabet_html += '<strong\
				class="arpabet"\
				style="padding: 0 '\
				+ str(arpa_length)\
				+'em"'\
				+f" title=\"{utter_time} + {length}\""\
				+'>'\
				+ symbol\
				+ '</strong> '
			# Re-round the running offset so float noise does not leak into the tooltip.
			utter_time = round(utter_time + round(length, 1), 1)

		return [
			save_path,
			arpabet_html,
			round(json_data['em_angry'][0], 2),
			round(json_data['em_happy'][0], 2),
			round(json_data['em_sad'][0], 2),
			round(json_data['em_surprise'][0], 2),
			json_data
		]

# Script entry point: build the local demo and launch the interface held in
# its `block` attribute (presumably a Gradio Blocks app created by BlocksDemo —
# verify in gr_client).
if __name__ == "__main__":
	print('running custom Gradio interface')
	demo = LocalBlocksDemo()
	demo.block.launch()