# general_chat / ai4b.py
# Last commit: 00d821e (verified) by pvanand — "Rename aib4.py to ai4b.py"
# Size: 16.7 kB
import requests
import json
import base64
class BhashiniClient:
    """
    A client for interacting with Bhashini's ASR, NMT, and TTS services.

    On construction the client fetches the pipeline configuration for the given
    pipeline ID and caches, per task and language, the service IDs used for
    inference requests.

    Methods:
        list_available_languages(task_type): Lists available languages for a given task.
        get_supported_voices(source_language): Gets supported genders for TTS in a language.
        asr(audio_content, source_language, audio_format='wav', sampling_rate=16000): Performs ASR.
        translate(text, source_language, target_language): Translates text from source to target language.
        tts(text, source_language, gender='female', sampling_rate=8000): Performs TTS.
    """

    PIPELINE_CONFIG_ENDPOINT = "https://meity-auth.ulcacontrib.org/ulca/apis/v0/model/getModelsPipeline"
    INFERENCE_ENDPOINT = "https://dhruva-api.bhashini.gov.in/services/inference/pipeline"
    PIPELINE_ID = "64392f96daac500b55c543cd"

    # Task types requested from the pipeline-config service and parsed into pipeline_data.
    TASK_TYPES = ('asr', 'translation', 'tts')
    # Voice genders accepted by tts().
    TTS_GENDERS = ('male', 'female')
    # Labels used in "not supported" error messages for single-language tasks.
    _TASK_LABELS = {'asr': 'ASR', 'tts': 'TTS'}

    def __init__(self, user_id, api_key, pipeline_id=PIPELINE_ID, timeout=None):
        """
        Initializes the BhashiniClient with user credentials and pipeline ID.

        Args:
            user_id (str): Your user ID.
            api_key (str): Your ULCA API key.
            pipeline_id (str): The pipeline ID.
            timeout (float | tuple | None): Timeout forwarded to every HTTP
                request made by this client (seconds, or a ``(connect, read)``
                tuple as accepted by ``requests``). ``None`` (the default)
                waits indefinitely.

        Raises:
            requests.RequestException: If the pipeline configuration request
                fails (``requests.HTTPError`` for an HTTP error status).
            KeyError: If the configuration response lacks an expected field.
        """
        self.user_id = user_id
        self.api_key = api_key
        self.pipeline_id = pipeline_id
        # Must be set before _get_pipeline_config(), which uses it.
        self.timeout = timeout
        self.headers = {
            "Content-Type": "application/json",
            "userID": self.user_id,
            "ulcaApiKey": self.api_key,
        }
        self.config = self._get_pipeline_config()
        self.pipeline_data = self._parse_pipeline_config(self.config)
        self.inference_api_key = self.pipeline_data['inferenceApiKey']

    def _get_pipeline_config(self):
        """
        Retrieves the pipeline configuration.

        Returns:
            dict: The decoded JSON pipeline configuration.

        Raises:
            requests.RequestException: If the request fails
                (``requests.HTTPError`` for an HTTP error status).
        """
        payload = {
            "pipelineTasks": [{"taskType": task_type} for task_type in self.TASK_TYPES],
            "pipelineRequestConfig": {
                "pipelineId": self.pipeline_id
            },
        }
        response = requests.post(
            self.PIPELINE_CONFIG_ENDPOINT,
            headers=self.headers,
            data=json.dumps(payload),
            timeout=self.timeout,
        )
        response.raise_for_status()
        return response.json()

    def _parse_pipeline_config(self, config):
        """
        Parses the pipeline configuration and extracts necessary information.

        Args:
            config (dict): The pipeline configuration.

        Returns:
            dict: Parsed pipeline data with keys:
                'asr' / 'tts': {source_language: [service_info, ...]}
                'translation': {source_language: {target_language: [service_info, ...]}}
                'inferenceApiKey': the inference API key value
                'callbackUrl': the inference callback URL

        Raises:
            KeyError: If a required field is missing from ``config``.
        """
        endpoint = config['pipelineInferenceAPIEndPoint']
        inference_api_key = endpoint['inferenceApiKey']['value']
        callback_url = endpoint['callbackUrl']
        pipeline_data = {
            'asr': {},
            'tts': {},
            'translation': {},
            'inferenceApiKey': inference_api_key,
            'callbackUrl': callback_url,
        }
        for pipeline in config['pipelineResponseConfig']:
            task_type = pipeline['taskType']
            if task_type not in self.TASK_TYPES:
                # Ignore task types this client does not use.
                continue
            task_data = pipeline_data[task_type]
            for language_config in pipeline['config']:
                language = language_config['language']
                source_language = language['sourceLanguage']
                if task_type == 'translation':
                    target_language = language['targetLanguage']
                    task_data.setdefault(source_language, {}).setdefault(target_language, []).append({
                        'serviceId': language_config['serviceId'],
                        'sourceScriptCode': language.get('sourceScriptCode'),
                        'targetScriptCode': language.get('targetScriptCode'),
                    })
                else:
                    language_info = {
                        'serviceId': language_config['serviceId'],
                        'sourceScriptCode': language.get('sourceScriptCode'),
                    }
                    if task_type == 'tts':
                        language_info['supportedVoices'] = language_config.get('supportedVoices', [])
                    task_data.setdefault(source_language, []).append(language_info)
        return pipeline_data

    def list_available_languages(self, task_type):
        """
        Lists the available languages for the specified task.

        Args:
            task_type (str): The task type ('asr', 'translation', or 'tts').

        Returns:
            list or dict: A list of available languages, or, for translation, a
            dict mapping each source language to its list of target languages.

        Raises:
            ValueError: If an invalid task type is provided.

        Usage Example:
            client = BhashiniClient(user_id, api_key, pipeline_id)
            asr_languages = client.list_available_languages('asr')
            print("Available ASR Languages:", asr_languages)
            translation_languages = client.list_available_languages('translation')
            print("Available Translation Languages:", translation_languages)
        """
        if task_type not in self.TASK_TYPES:
            raise ValueError("Invalid task type. Choose from 'asr', 'translation', or 'tts'.")
        task_data = self.pipeline_data[task_type]
        if task_type == 'translation':
            return {src_lang: list(targets) for src_lang, targets in task_data.items()}
        return list(task_data)

    def _get_language_service(self, task_type, source_language):
        """
        Return the service info used for ``source_language`` under 'asr' or 'tts'.

        When several services are configured for a language, the first one
        listed in the pipeline configuration is used.

        Raises:
            ValueError: If the task has no service for ``source_language``.
        """
        task_data = self.pipeline_data[task_type]
        if source_language not in task_data:
            available_languages = ', '.join(self.list_available_languages(task_type))
            raise ValueError(
                f"{self._TASK_LABELS[task_type]} not supported for language '{source_language}'. "
                f"Available languages: {available_languages}"
            )
        return task_data[source_language][0]

    def get_supported_voices(self, source_language):
        """
        Returns the supported genders for TTS in the specified language.

        Args:
            source_language (str): The language code (e.g., 'hi' for Hindi).

        Returns:
            list: A list of supported genders (e.g., ['male', 'female']); empty
            if the configuration lists none.

        Raises:
            ValueError: If TTS is not supported for the language.

        Usage Example:
            client = BhashiniClient(user_id, api_key, pipeline_id)
            voices = client.get_supported_voices('hi')
            print("Supported voices for Hindi TTS:", voices)
        """
        service_info = self._get_language_service('tts', source_language)
        return service_info.get('supportedVoices', [])

    def asr(self, audio_content, source_language, audio_format='wav', sampling_rate=16000):
        """
        Performs Automatic Speech Recognition on the provided audio content.

        Args:
            audio_content (bytes): The audio content in bytes.
            source_language (str): The language code of the audio (e.g., 'hi' for Hindi).
            audio_format (str): The format of the audio content ('wav', 'mp3', 'flac', or 'ogg').
            sampling_rate (int): The sampling rate of the audio in Hz.

        Returns:
            dict: The ASR response from the API.

        Raises:
            ValueError: If the language is not supported.
            Exception: If the API request fails.

        Usage Example:
            client = BhashiniClient(user_id, api_key, pipeline_id)
            with open('audio.wav', 'rb') as f:
                audio_content = f.read()
            asr_result = client.asr(audio_content, source_language='hi', audio_format='wav')
            print("ASR Result:", asr_result)
        """
        service_info = self._get_language_service('asr', source_language)
        pipeline_task = {
            "taskType": "asr",
            "config": {
                "language": {
                    "sourceLanguage": source_language
                },
                "serviceId": service_info['serviceId'],
                "audioFormat": audio_format,
                "samplingRate": sampling_rate,
            },
        }
        input_data = {
            "audio": [
                {"audioContent": base64.b64encode(audio_content).decode('utf-8')}
            ]
        }
        return self._run_inference(pipeline_task, input_data)

    def translate(self, text, source_language, target_language):
        """
        Translates the provided text from the source language to the target language.

        Args:
            text (str): The text to translate.
            source_language (str): The source language code.
            target_language (str): The target language code.

        Returns:
            dict: The translation response from the API.

        Raises:
            ValueError: If the language pair is not supported.
            Exception: If the API request fails.

        Usage Example:
            client = BhashiniClient(user_id, api_key, pipeline_id)
            translation_result = client.translate(
                'मेरा नाम विहिर है।',
                source_language='hi',
                target_language='gu'
            )
            print("Translation Result:", translation_result)
        """
        translation_data = self.pipeline_data['translation']
        if source_language not in translation_data:
            available_languages = ', '.join(translation_data)
            raise ValueError(
                f"Translation not supported from language '{source_language}'. "
                f"Available source languages: {available_languages}"
            )
        targets = translation_data[source_language]
        if target_language not in targets:
            available_targets = ', '.join(targets)
            raise ValueError(
                f"Translation from '{source_language}' to '{target_language}' not supported. "
                f"Available target languages for '{source_language}': {available_targets}"
            )
        service_info = targets[target_language][0]
        pipeline_task = {
            "taskType": "translation",
            "config": {
                "language": {
                    "sourceLanguage": source_language,
                    "targetLanguage": target_language,
                },
                "serviceId": service_info['serviceId'],
            },
        }
        input_data = {"input": [{"source": text}]}
        return self._run_inference(pipeline_task, input_data)

    def tts(self, text, source_language, gender='female', sampling_rate=8000):
        """
        Converts the provided text to speech in the specified language.

        Args:
            text (str): The text to convert to speech.
            source_language (str): The language code of the text.
            gender (str): The desired voice gender ('male' or 'female').
            sampling_rate (int): The sampling rate in Hz.

        Returns:
            dict: The TTS response from the API.

        Raises:
            ValueError: If the language or gender is not supported.
            Exception: If the API request fails.

        Usage Example:
            client = BhashiniClient(user_id, api_key, pipeline_id)
            tts_result = client.tts(
                'હેલો વર્લ્ડ',
                source_language='gu',
                gender='female'
            )
            # Save the audio output
            audio_base64 = tts_result['pipelineResponse'][0]['audio'][0]['audioContent']
            audio_data = base64.b64decode(audio_base64)
            with open('output_audio.wav', 'wb') as f:
                f.write(audio_data)
        """
        service_info = self._get_language_service('tts', source_language)
        if gender not in self.TTS_GENDERS:
            raise ValueError("Gender must be 'male' or 'female'.")
        supported_voices = service_info.get('supportedVoices', [])
        # An empty voice list means the config does not restrict genders.
        if supported_voices and gender not in supported_voices:
            available_genders = ', '.join(supported_voices)
            raise ValueError(
                f"Gender '{gender}' not supported for language '{source_language}'. "
                f"Available genders: {available_genders}"
            )
        pipeline_task = {
            "taskType": "tts",
            "config": {
                "language": {
                    "sourceLanguage": source_language
                },
                "serviceId": service_info['serviceId'],
                "gender": gender,
                "samplingRate": sampling_rate,
            },
        }
        input_data = {"input": [{"source": text}]}
        return self._run_inference(pipeline_task, input_data)

    def _run_inference(self, pipeline_task, input_data):
        """
        Send a single pipeline task to the inference endpoint.

        Args:
            pipeline_task (dict): One entry for the payload's "pipelineTasks" list.
            input_data (dict): The payload's "inputData" object.

        Returns:
            dict: The decoded JSON response.

        Raises:
            requests.RequestException: On network-level failures.
            Exception: If the endpoint returns an HTTP error status.
        """
        payload = {
            "pipelineTasks": [pipeline_task],
            "inputData": input_data,
        }
        headers = {
            'Accept': '*/*',
            'Authorization': self.inference_api_key,
            'Content-Type': 'application/json',
        }
        response = requests.post(
            self.INFERENCE_ENDPOINT,
            headers=headers,
            data=json.dumps(payload),
            timeout=self.timeout,
        )
        self._handle_response_errors(response)
        return response.json()

    @staticmethod
    def _extract_error_message(response):
        """
        Best-effort extraction of a human-readable message from an error response.

        Returns the JSON object's 'message' field when the body is a JSON object
        ('An error occurred.' if that field is absent), otherwise the raw body text.
        """
        try:
            error_info = response.json()
        except ValueError:
            # Response.json() raises a ValueError subclass on malformed bodies:
            # json.JSONDecodeError, simplejson's JSONDecodeError, or
            # requests.JSONDecodeError depending on the installed versions.
            return response.text
        if isinstance(error_info, dict):
            return error_info.get('message', 'An error occurred.')
        # Valid JSON but not an object (list, string, number): no 'message' field.
        return response.text

    def _handle_response_errors(self, response):
        """
        Handles errors in the response.

        Args:
            response (requests.Response): The response object.

        Raises:
            Exception: If an HTTP error occurs; the message includes the
                server-provided error text, and the original
                ``requests.HTTPError`` is chained as the cause.
        """
        try:
            response.raise_for_status()
        except requests.HTTPError as http_err:
            error_message = self._extract_error_message(response)
            raise Exception(f"HTTP error occurred: {error_message}") from http_err