pvanand committed on
Commit
7e87316
·
verified ·
1 Parent(s): 35fe877

Create aib4.py

Browse files
Files changed (1) hide show
  1. aib4.py +450 -0
aib4.py ADDED
@@ -0,0 +1,450 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import requests
2
+ import json
3
+ import base64
4
+
5
class BhashiniClient:
    """
    A client for interacting with Bhashini's ASR, NMT, and TTS services.

    On construction the client fetches the pipeline configuration once and
    indexes the advertised services by task and language; every later call
    validates its language arguments against that index before hitting the
    inference endpoint.

    Methods:
        list_available_languages(task_type): Lists available languages for a given task.
        get_supported_voices(source_language): Gets supported genders for TTS in a language.
        asr(audio_content, source_language, audio_format='wav', sampling_rate=16000): Performs ASR.
        translate(text, source_language, target_language): Translates text from source to target language.
        tts(text, source_language, gender='female', sampling_rate=8000): Performs TTS.
    """

    PIPELINE_CONFIG_ENDPOINT = "https://meity-auth.ulcacontrib.org/ulca/apis/v0/model/getModelsPipeline"
    INFERENCE_ENDPOINT = "https://dhruva-api.bhashini.gov.in/services/inference/pipeline"
    PIPELINE_ID = "64392f96daac500b55c543cd"

    # Task types this client knows how to index and call.
    SUPPORTED_TASKS = ('asr', 'translation', 'tts')
    # Seconds to wait on any HTTP call; without a timeout `requests` blocks forever.
    DEFAULT_TIMEOUT = 60

    def __init__(self, user_id, api_key, pipeline_id=PIPELINE_ID, timeout=DEFAULT_TIMEOUT):
        """
        Initializes the BhashiniClient with user credentials and pipeline ID.

        Args:
            user_id (str): Your user ID.
            api_key (str): Your ULCA API key.
            pipeline_id (str): The pipeline ID.
            timeout (float | tuple | None): Timeout in seconds handed to every
                ``requests.post`` call. Pass ``None`` to wait indefinitely.

        Raises:
            Exception: If the pipeline configuration retrieval fails.
        """
        self.user_id = user_id
        self.api_key = api_key
        self.pipeline_id = pipeline_id
        self.timeout = timeout
        self.headers = {
            "Content-Type": "application/json",
            "userID": self.user_id,
            "ulcaApiKey": self.api_key
        }
        self.config = self._get_pipeline_config()
        self.pipeline_data = self._parse_pipeline_config(self.config)
        self.inference_api_key = self.pipeline_data['inferenceApiKey']

    def _get_pipeline_config(self):
        """
        Retrieves the pipeline configuration.

        Returns:
            dict: The pipeline configuration.

        Raises:
            Exception: If the request fails.
        """
        payload = {
            "pipelineTasks": [
                {"taskType": "asr"},
                {"taskType": "translation"},
                {"taskType": "tts"}
            ],
            "pipelineRequestConfig": {
                "pipelineId": self.pipeline_id
            }
        }
        response = requests.post(
            self.PIPELINE_CONFIG_ENDPOINT,
            headers=self.headers,
            data=json.dumps(payload),
            timeout=self.timeout
        )
        response.raise_for_status()
        return response.json()

    def _parse_pipeline_config(self, config):
        """
        Parses the pipeline configuration and extracts necessary information.

        Args:
            config (dict): The pipeline configuration.

        Returns:
            dict: Parsed pipeline data with the keys:
                'asr' / 'tts': ``{source_language: [service_info, ...]}``
                'translation': ``{source_language: {target_language: [service_info, ...]}}``
                'inferenceApiKey': the key value used to authorize inference calls
                'callbackUrl': the callback URL, or None when the config omits it
        """
        endpoint = config['pipelineInferenceAPIEndPoint']
        pipeline_data = {
            'asr': {},
            'tts': {},
            'translation': {},
            'inferenceApiKey': endpoint['inferenceApiKey']['value'],
            # The client never uses the callback URL, so a config without it
            # should not prevent construction.
            'callbackUrl': endpoint.get('callbackUrl')
        }

        for pipeline in config['pipelineResponseConfig']:
            task_type = pipeline['taskType']
            if task_type not in self.SUPPORTED_TASKS:
                continue

            for language_config in pipeline['config']:
                language = language_config['language']
                source_language = language['sourceLanguage']
                language_info = {
                    'serviceId': language_config['serviceId'],
                    'sourceScriptCode': language.get('sourceScriptCode')
                }

                if task_type == 'translation':
                    # Translation services are indexed by source, then target.
                    language_info['targetScriptCode'] = language.get('targetScriptCode')
                    by_target = pipeline_data['translation'].setdefault(source_language, {})
                    services = by_target.setdefault(language['targetLanguage'], [])
                else:
                    if task_type == 'tts':
                        language_info['supportedVoices'] = language_config.get('supportedVoices', [])
                    services = pipeline_data[task_type].setdefault(source_language, [])

                services.append(language_info)

        return pipeline_data

    def list_available_languages(self, task_type):
        """
        Lists the available languages for the specified task.

        Args:
            task_type (str): The task type ('asr', 'translation', or 'tts').

        Returns:
            list or dict: A list of available languages, or a dictionary
            mapping each source language to its target languages for translation.

        Raises:
            ValueError: If an invalid task type is provided.

        Usage Example:
            client = BhashiniClient(user_id, api_key, pipeline_id)
            asr_languages = client.list_available_languages('asr')
            print("Available ASR Languages:", asr_languages)

            translation_languages = client.list_available_languages('translation')
            print("Available Translation Languages:", translation_languages)
        """
        if task_type not in self.SUPPORTED_TASKS:
            raise ValueError("Invalid task type. Choose from 'asr', 'translation', or 'tts'.")

        services = self.pipeline_data[task_type]
        if task_type == 'translation':
            return {src_lang: list(targets) for src_lang, targets in services.items()}
        return list(services)

    def get_supported_voices(self, source_language):
        """
        Returns the supported genders for TTS in the specified language.

        Args:
            source_language (str): The language code (e.g., 'hi' for Hindi).

        Returns:
            list: A list of supported genders (e.g., ['male', 'female']).

        Raises:
            ValueError: If TTS is not supported for the language.

        Usage Example:
            client = BhashiniClient(user_id, api_key, pipeline_id)
            voices = client.get_supported_voices('hi')
            print("Supported voices for Hindi TTS:", voices)
        """
        service_info = self._first_service('tts', source_language)
        return service_info.get('supportedVoices', [])

    def asr(self, audio_content, source_language, audio_format='wav', sampling_rate=16000):
        """
        Performs Automatic Speech Recognition on the provided audio content.

        Args:
            audio_content (bytes): The audio content in bytes.
            source_language (str): The language code of the audio (e.g., 'hi' for Hindi).
            audio_format (str): Format of the audio content ('wav', 'mp3', 'flac', 'ogg').
            sampling_rate (int): The sampling rate of the audio in Hz.

        Returns:
            dict: The ASR response from the API.

        Raises:
            ValueError: If the language is not supported.
            Exception: If the API request fails.

        Usage Example:
            client = BhashiniClient(user_id, api_key, pipeline_id)
            with open('audio.wav', 'rb') as f:
                audio_content = f.read()
            asr_result = client.asr(audio_content, source_language='hi', audio_format='wav')
            print("ASR Result:", asr_result)
        """
        service_info = self._first_service('asr', source_language)

        payload = {
            "pipelineTasks": [
                {
                    "taskType": "asr",
                    "config": {
                        "language": {
                            "sourceLanguage": source_language
                        },
                        "serviceId": service_info['serviceId'],
                        "audioFormat": audio_format,
                        "samplingRate": sampling_rate
                    }
                }
            ],
            "inputData": {
                "audio": [
                    {
                        # The payload is JSON, so raw audio bytes travel as base64 text.
                        "audioContent": base64.b64encode(audio_content).decode('utf-8')
                    }
                ]
            }
        }
        return self._post_inference(payload)

    def translate(self, text, source_language, target_language):
        """
        Translates the provided text from the source language to the target language.

        Args:
            text (str): The text to translate.
            source_language (str): The source language code.
            target_language (str): The target language code.

        Returns:
            dict: The translation response from the API.

        Raises:
            ValueError: If the language pair is not supported.
            Exception: If the API request fails.

        Usage Example:
            client = BhashiniClient(user_id, api_key, pipeline_id)
            translation_result = client.translate(
                'मेरा नाम विहिर है।',
                source_language='hi',
                target_language='gu'
            )
            print("Translation Result:", translation_result)
        """
        translation_services = self.pipeline_data['translation']
        if source_language not in translation_services:
            available_languages = ', '.join(translation_services)
            raise ValueError(
                f"Translation not supported from language '{source_language}'. "
                f"Available source languages: {available_languages}"
            )

        targets = translation_services[source_language]
        if target_language not in targets:
            available_targets = ', '.join(targets)
            raise ValueError(
                f"Translation from '{source_language}' to '{target_language}' not supported. "
                f"Available target languages for '{source_language}': {available_targets}"
            )

        # Several services may serve the same pair; the first advertised one is used.
        service_id = targets[target_language][0]['serviceId']

        payload = {
            "pipelineTasks": [
                {
                    "taskType": "translation",
                    "config": {
                        "language": {
                            "sourceLanguage": source_language,
                            "targetLanguage": target_language
                        },
                        "serviceId": service_id
                    }
                }
            ],
            "inputData": {
                "input": [
                    {
                        "source": text
                    }
                ]
            }
        }
        return self._post_inference(payload)

    def tts(self, text, source_language, gender='female', sampling_rate=8000):
        """
        Converts the provided text to speech in the specified language.

        Args:
            text (str): The text to convert to speech.
            source_language (str): The language code of the text.
            gender (str): The desired voice gender ('male' or 'female').
            sampling_rate (int): The sampling rate in Hz.

        Returns:
            dict: The TTS response from the API.

        Raises:
            ValueError: If the language or gender is not supported.
            Exception: If the API request fails.

        Usage Example:
            client = BhashiniClient(user_id, api_key, pipeline_id)
            tts_result = client.tts(
                'હેલો વર્લ્ડ',
                source_language='gu',
                gender='female'
            )
            # Save the audio output
            audio_base64 = tts_result['pipelineResponse'][0]['audio'][0]['audioContent']
            audio_data = base64.b64decode(audio_base64)
            with open('output_audio.wav', 'wb') as f:
                f.write(audio_data)
        """
        service_info = self._first_service('tts', source_language)
        supported_voices = service_info.get('supportedVoices', [])

        if gender not in ['male', 'female']:
            raise ValueError("Gender must be 'male' or 'female'.")

        # An empty voice list means the service did not advertise its voices,
        # so the request is passed through instead of being rejected here.
        if supported_voices and gender not in supported_voices:
            available_genders = ', '.join(supported_voices)
            raise ValueError(
                f"Gender '{gender}' not supported for language '{source_language}'. "
                f"Available genders: {available_genders}"
            )

        payload = {
            "pipelineTasks": [
                {
                    "taskType": "tts",
                    "config": {
                        "language": {
                            "sourceLanguage": source_language
                        },
                        "serviceId": service_info['serviceId'],
                        "gender": gender,
                        "samplingRate": sampling_rate
                    }
                }
            ],
            "inputData": {
                "input": [
                    {
                        "source": text
                    }
                ]
            }
        }
        return self._post_inference(payload)

    def _first_service(self, task_type, source_language):
        """
        Returns the first service registered for an ASR or TTS language.

        Args:
            task_type (str): Either 'asr' or 'tts'.
            source_language (str): The language code to look up.

        Returns:
            dict: The service info parsed from the pipeline configuration.

        Raises:
            ValueError: If the task has no service for the language.
        """
        services = self.pipeline_data[task_type]
        if source_language not in services:
            available_languages = ', '.join(services)
            raise ValueError(
                f"{task_type.upper()} not supported for language '{source_language}'. "
                f"Available languages: {available_languages}"
            )
        return services[source_language][0]

    def _post_inference(self, payload):
        """
        Sends a pipeline payload to the inference endpoint.

        Args:
            payload (dict): The JSON-serializable pipeline request.

        Returns:
            dict: The decoded JSON response.

        Raises:
            Exception: If the API responds with an HTTP error status.
        """
        headers = {
            'Accept': '*/*',
            'Authorization': self.inference_api_key,
            'Content-Type': 'application/json'
        }

        response = requests.post(
            self.INFERENCE_ENDPOINT,
            headers=headers,
            data=json.dumps(payload),
            timeout=self.timeout
        )

        self._handle_response_errors(response)
        return response.json()

    def _handle_response_errors(self, response):
        """
        Handles errors in the response.

        Args:
            response (requests.Response): The response object.

        Raises:
            Exception: If an HTTP error occurs. The message carries the API's
                'message' field when the body is a JSON object, otherwise the
                raw response text.
        """
        try:
            response.raise_for_status()
        except requests.HTTPError as http_err:
            try:
                error_info = response.json()
            except ValueError:
                # Every JSON decode error `requests` can raise derives from
                # ValueError, whichever JSON backend is installed.
                error_message = response.text
            else:
                if isinstance(error_info, dict):
                    error_message = error_info.get('message', 'An error occurred.')
                else:
                    # A JSON list/string body has no 'message' field to read.
                    error_message = response.text
            raise Exception(f"HTTP error occurred: {error_message}") from http_err