{ "cells": [ { "cell_type": "code", "execution_count": 1, "id": "f8bdd950-1b95-4088-890a-94417292f6e1", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "[nltk_data] Downloading package punkt to /home/gorkem/nltk_data...\n", "[nltk_data] Package punkt is already up-to-date!\n", "2023-10-13 00:33:39.399490: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Downloading if not downloaded Coqui XTTS V1\n", " > tts_models/multilingual/multi-dataset/xtts_v1 is already downloaded.\n", " > Using model: xtts\n", "XTTS downloaded\n", "Loading XTTS\n", "[2023-10-13 00:34:12,573] [INFO] [logging.py:93:log_dist] [Rank -1] DeepSpeed info: version=0.8.3+f1e4fb0b, git-hash=f1e4fb0b, git-branch=HEAD\n", "[2023-10-13 00:34:12,587] [WARNING] [config_utils.py:75:_process_deprecated_field] Config parameter replace_method is deprecated. This parameter is no longer needed, please remove from your call to DeepSpeed-inference\n", "[2023-10-13 00:34:12,589] [WARNING] [config_utils.py:75:_process_deprecated_field] Config parameter mp_size is deprecated use tensor_parallel.tp_size instead\n", "[2023-10-13 00:34:12,590] [INFO] [logging.py:93:log_dist] [Rank -1] quantize_bits = 8 mlp_extra_grouping = False, quantize_groups = 1\n", "[2023-10-13 00:34:12,854] [INFO] [logging.py:93:log_dist] [Rank -1] DeepSpeed-Inference config: {'layer_id': 0, 'hidden_size': 1024, 'intermediate_size': 4096, 'heads': 16, 'num_hidden_layers': -1, 'fp16': False, 'pre_layer_norm': True, 'local_rank': -1, 'stochastic_mode': False, 'epsilon': 1e-05, 'mp_size': 1, 'q_int8': False, 'scale_attention': True, 'triangular_masking': True, 'local_attention': False, 'window_size': 1, 'rotary_dim': -1, 'rotate_half': False, 'rotate_every_two': True, 'return_tuple': True, 'mlp_after_attn': True, 'mlp_act_func_type': , 'specialized_mode': False, 'training_mp_size': 1, 'bigscience_bloom': False, 'max_out_tokens': 1024, 'scale_attn_by_inverse_layer_idx': False, 'enable_qkv_quantization': False, 'use_mup': False, 'return_single_tuple': False}\n", "Done loading TTS\n", "Loaded as API: https://sanchit-gandhi-whisper-jax.hf.space/ ✔\n" ] } ], "source": [ "from __future__ import annotations\n", "\n", "import os\n", "# By using XTTS you agree to CPML license https://coqui.ai/cpml\n", "os.environ[\"COQUI_TOS_AGREED\"] = \"1\"\n", "\n", "import gradio as gr\n", "import numpy as np\n", "import torch\n", "import nltk # we'll use this to split into sentences\n", "nltk.download('punkt')\n", "import uuid\n", "\n", "import librosa\n", "import torchaudio\n", "from TTS.api import TTS\n", "from TTS.tts.configs.xtts_config import XttsConfig\n", "from TTS.tts.models.xtts import Xtts\n", "from TTS.utils.generic_utils import get_user_data_dir\n", "\n", "# This will trigger downloading model\n", "print(\"Downloading if not downloaded Coqui XTTS V1\")\n", "tts = TTS(\"tts_models/multilingual/multi-dataset/xtts_v1\")\n", "del tts\n", "print(\"XTTS downloaded\")\n", "\n", "print(\"Loading XTTS\")\n", "#Below will use model directly for inference\n", "model_path = os.path.join(get_user_data_dir(\"tts\"), \"tts_models--multilingual--multi-dataset--xtts_v1\")\n", "config = XttsConfig()\n", "config.load_json(os.path.join(model_path, \"config.json\"))\n", "model = Xtts.init_from_config(config)\n", "model.load_checkpoint(\n", " config,\n", " checkpoint_path=os.path.join(model_path, \"model.pth\"),\n", " vocab_path=os.path.join(model_path, 
\"vocab.json\"),\n", " eval=True,\n", " use_deepspeed=True\n", ")\n", "model.cuda()\n", "print(\"Done loading TTS\")\n", "\n", "\n", "title = \"Voice chat with Mistral 7B Instruct\"\n", "\n", "DESCRIPTION = \"\"\"# Voice chat with Mistral 7B Instruct\"\"\"\n", "css = \"\"\".toast-wrap { display: none !important } \"\"\"\n", "\n", "from huggingface_hub import HfApi\n", "HF_TOKEN = os.environ.get(\"HF_TOKEN\")\n", "# will use api to restart space on a unrecoverable error\n", "api = HfApi(token=HF_TOKEN)\n", "\n", "repo_id = \"ylacombe/voice-chat-with-lama\"\n", "\n", "system_message = \"\\nYou are a helpful, respectful and honest assistant. Your answers are short, ideally a few words long, if it is possible. Always answer as helpfully as possible, while being safe.\\n\\nIf a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.\"\n", "temperature = 0.9\n", "top_p = 0.6\n", "repetition_penalty = 1.2\n", "\n", "\n", "import gradio as gr\n", "import os\n", "import time\n", "\n", "import gradio as gr\n", "from transformers import pipeline\n", "import numpy as np\n", "\n", "from gradio_client import Client\n", "from huggingface_hub import InferenceClient\n", "\n", "\n", "# This client is down\n", "#whisper_client = Client(\"https://sanchit-gandhi-whisper-large-v2.hf.space/\")\n", "# Replacement whisper client, it may be time limited\n", "whisper_client = Client(\"https://sanchit-gandhi-whisper-jax.hf.space\")\n", "text_client = InferenceClient(\n", " \"mistralai/Mistral-7B-Instruct-v0.1\"\n", ")\n" ] }, { "cell_type": "code", "execution_count": null, "id": "d8687cd2-e989-4db9-b16a-04ad9460e6f1", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Running on local URL: http://127.0.0.1:7861\n", "\n", "To create a public link, set `share=True` in `launch()`.\n" ] }, { "data": { "text/html": [ "
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "ERROR: Too many requests on mistral client\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Traceback (most recent call last):\n", " File \"/home/gorkem/.local/lib/python3.10/site-packages/gradio/queueing.py\", line 388, in call_prediction\n", " output = await route_utils.call_process_api(\n", " File \"/home/gorkem/.local/lib/python3.10/site-packages/gradio/route_utils.py\", line 219, in call_process_api\n", " output = await app.get_blocks().process_api(\n", " File \"/home/gorkem/.local/lib/python3.10/site-packages/gradio/blocks.py\", line 1437, in process_api\n", " result = await self.call_function(\n", " File \"/home/gorkem/.local/lib/python3.10/site-packages/gradio/blocks.py\", line 1123, in call_function\n", " prediction = await utils.async_iteration(iterator)\n", " File \"/home/gorkem/.local/lib/python3.10/site-packages/gradio/utils.py\", line 503, in async_iteration\n", " return await iterator.__anext__()\n", " File \"/home/gorkem/.local/lib/python3.10/site-packages/gradio/utils.py\", line 496, in __anext__\n", " return await anyio.to_thread.run_sync(\n", " File \"/home/gorkem/.local/lib/python3.10/site-packages/anyio/to_thread.py\", line 31, in run_sync\n", " return await get_asynclib().run_sync_in_worker_thread(\n", " File \"/home/gorkem/.local/lib/python3.10/site-packages/anyio/_backends/_asyncio.py\", line 937, in run_sync_in_worker_thread\n", " return await future\n", " File \"/home/gorkem/.local/lib/python3.10/site-packages/anyio/_backends/_asyncio.py\", line 867, in run\n", " result = context.run(func, *args)\n", " File \"/home/gorkem/.local/lib/python3.10/site-packages/gradio/utils.py\", line 479, in run_sync_iterator_async\n", " return next(iterator)\n", " File \"/home/gorkem/.local/lib/python3.10/site-packages/gradio/utils.py\", line 629, in gen_wrapper\n", " yield from f(*args, **kwargs)\n", " File \"/tmp/ipykernel_8679/550220560.py\", line 134, in generate_speech\n", " text_to_generate = history[-1][1]\n", "TypeError: 'NoneType' object is not subscriptable\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "ERROR: Too many requests on mistral client\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Traceback (most recent call last):\n", " File \"/home/gorkem/.local/lib/python3.10/site-packages/gradio/queueing.py\", line 388, in call_prediction\n", " output = await route_utils.call_process_api(\n", " File \"/home/gorkem/.local/lib/python3.10/site-packages/gradio/route_utils.py\", line 219, in call_process_api\n", " output = await app.get_blocks().process_api(\n", " File \"/home/gorkem/.local/lib/python3.10/site-packages/gradio/blocks.py\", line 1437, in process_api\n", " result = await self.call_function(\n", " File \"/home/gorkem/.local/lib/python3.10/site-packages/gradio/blocks.py\", line 1123, in call_function\n", " prediction = await utils.async_iteration(iterator)\n", " File \"/home/gorkem/.local/lib/python3.10/site-packages/gradio/utils.py\", line 503, in async_iteration\n", " return await iterator.__anext__()\n", " File \"/home/gorkem/.local/lib/python3.10/site-packages/gradio/utils.py\", line 496, in __anext__\n", " return await anyio.to_thread.run_sync(\n", " File \"/home/gorkem/.local/lib/python3.10/site-packages/anyio/to_thread.py\", line 31, in run_sync\n", " return await get_asynclib().run_sync_in_worker_thread(\n", " File 
\"/home/gorkem/.local/lib/python3.10/site-packages/anyio/_backends/_asyncio.py\", line 937, in run_sync_in_worker_thread\n", " return await future\n", " File \"/home/gorkem/.local/lib/python3.10/site-packages/anyio/_backends/_asyncio.py\", line 867, in run\n", " result = context.run(func, *args)\n", " File \"/home/gorkem/.local/lib/python3.10/site-packages/gradio/utils.py\", line 479, in run_sync_iterator_async\n", " return next(iterator)\n", " File \"/home/gorkem/.local/lib/python3.10/site-packages/gradio/utils.py\", line 629, in gen_wrapper\n", " yield from f(*args, **kwargs)\n", " File \"/tmp/ipykernel_8679/550220560.py\", line 134, in generate_speech\n", " text_to_generate = history[-1][1]\n", "TypeError: 'NoneType' object is not subscriptable\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "ERROR: Too many requests on mistral client\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Traceback (most recent call last):\n", " File \"/home/gorkem/.local/lib/python3.10/site-packages/gradio/queueing.py\", line 388, in call_prediction\n", " output = await route_utils.call_process_api(\n", " File \"/home/gorkem/.local/lib/python3.10/site-packages/gradio/route_utils.py\", line 219, in call_process_api\n", " output = await app.get_blocks().process_api(\n", " File \"/home/gorkem/.local/lib/python3.10/site-packages/gradio/blocks.py\", line 1437, in process_api\n", " result = await self.call_function(\n", " File \"/home/gorkem/.local/lib/python3.10/site-packages/gradio/blocks.py\", line 1123, in call_function\n", " prediction = await utils.async_iteration(iterator)\n", " File \"/home/gorkem/.local/lib/python3.10/site-packages/gradio/utils.py\", line 503, in async_iteration\n", " return await iterator.__anext__()\n", " File \"/home/gorkem/.local/lib/python3.10/site-packages/gradio/utils.py\", line 496, in __anext__\n", " return await anyio.to_thread.run_sync(\n", " File \"/home/gorkem/.local/lib/python3.10/site-packages/anyio/to_thread.py\", line 31, in run_sync\n", " return await get_asynclib().run_sync_in_worker_thread(\n", " File \"/home/gorkem/.local/lib/python3.10/site-packages/anyio/_backends/_asyncio.py\", line 937, in run_sync_in_worker_thread\n", " return await future\n", " File \"/home/gorkem/.local/lib/python3.10/site-packages/anyio/_backends/_asyncio.py\", line 867, in run\n", " result = context.run(func, *args)\n", " File \"/home/gorkem/.local/lib/python3.10/site-packages/gradio/utils.py\", line 479, in run_sync_iterator_async\n", " return next(iterator)\n", " File \"/home/gorkem/.local/lib/python3.10/site-packages/gradio/utils.py\", line 629, in gen_wrapper\n", " yield from f(*args, **kwargs)\n", " File \"/tmp/ipykernel_8679/550220560.py\", line 134, in generate_speech\n", " text_to_generate = history[-1][1]\n", "TypeError: 'NoneType' object is not subscriptable\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "ERROR: Too many requests on mistral client\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Traceback (most recent call last):\n", " File \"/home/gorkem/.local/lib/python3.10/site-packages/gradio/queueing.py\", line 388, in call_prediction\n", " output = await route_utils.call_process_api(\n", " File \"/home/gorkem/.local/lib/python3.10/site-packages/gradio/route_utils.py\", line 219, in call_process_api\n", " output = await app.get_blocks().process_api(\n", " File \"/home/gorkem/.local/lib/python3.10/site-packages/gradio/blocks.py\", line 1437, in process_api\n", " result = await self.call_function(\n", " File 
\"/home/gorkem/.local/lib/python3.10/site-packages/gradio/blocks.py\", line 1123, in call_function\n", " prediction = await utils.async_iteration(iterator)\n", " File \"/home/gorkem/.local/lib/python3.10/site-packages/gradio/utils.py\", line 503, in async_iteration\n", " return await iterator.__anext__()\n", " File \"/home/gorkem/.local/lib/python3.10/site-packages/gradio/utils.py\", line 496, in __anext__\n", " return await anyio.to_thread.run_sync(\n", " File \"/home/gorkem/.local/lib/python3.10/site-packages/anyio/to_thread.py\", line 31, in run_sync\n", " return await get_asynclib().run_sync_in_worker_thread(\n", " File \"/home/gorkem/.local/lib/python3.10/site-packages/anyio/_backends/_asyncio.py\", line 937, in run_sync_in_worker_thread\n", " return await future\n", " File \"/home/gorkem/.local/lib/python3.10/site-packages/anyio/_backends/_asyncio.py\", line 867, in run\n", " result = context.run(func, *args)\n", " File \"/home/gorkem/.local/lib/python3.10/site-packages/gradio/utils.py\", line 479, in run_sync_iterator_async\n", " return next(iterator)\n", " File \"/home/gorkem/.local/lib/python3.10/site-packages/gradio/utils.py\", line 629, in gen_wrapper\n", " yield from f(*args, **kwargs)\n", " File \"/tmp/ipykernel_8679/550220560.py\", line 134, in generate_speech\n", " text_to_generate = history[-1][1]\n", "TypeError: 'NoneType' object is not subscriptable\n" ] } ], "source": [ "\n", "###### COQUI TTS FUNCTIONS ######\n", "def get_latents(speaker_wav):\n", " # create as function as we can populate here with voice cleanup/filtering\n", " gpt_cond_latent, diffusion_conditioning, speaker_embedding = model.get_conditioning_latents(audio_path=speaker_wav)\n", " return gpt_cond_latent, diffusion_conditioning, speaker_embedding\n", "\n", "\n", "def format_prompt(message, history):\n", " prompt = \"\"\n", " for user_prompt, bot_response in history:\n", " prompt += f\"[INST] {user_prompt} [/INST]\"\n", " prompt += f\" {bot_response} \"\n", " prompt += f\"[INST] {message} [/INST]\"\n", " return prompt\n", "\n", "def generate(\n", " prompt, history, temperature=0.9, max_new_tokens=256, top_p=0.95, repetition_penalty=1.0,\n", "):\n", " temperature = float(temperature)\n", " if temperature < 1e-2:\n", " temperature = 1e-2\n", " top_p = float(top_p)\n", "\n", " generate_kwargs = dict(\n", " temperature=temperature,\n", " max_new_tokens=max_new_tokens,\n", " top_p=top_p,\n", " repetition_penalty=repetition_penalty,\n", " do_sample=True,\n", " seed=42,\n", " )\n", "\n", " formatted_prompt = format_prompt(prompt, history)\n", "\n", " try:\n", " stream = text_client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)\n", " output = \"\"\n", " for response in stream:\n", " output += response.token.text\n", " yield output\n", "\n", " except Exception as e:\n", " if \"Too Many Requests\" in str(e):\n", " print(\"ERROR: Too many requests on mistral client\")\n", " gr.Warning(\"Unfortunately Mistral is unable to process\")\n", " output = \"Unfortuanately I am not able to process your request now !\"\n", " else:\n", " print(\"Unhandled Exception: \", str(e))\n", " gr.Warning(\"Unfortunately Mistral is unable to process\")\n", " output = \"I do not know what happened but I could not understand you .\"\n", " \n", " return output\n", "\n", "\n", "def transcribe(wav_path):\n", " \n", " # get first element from whisper_jax and strip it to delete begin and end space\n", " return whisper_client.predict(\n", "\t\t\t\twav_path,\t# str (filepath or URL to file) 
"def transcribe(wav_path):\n", "    # take the first element returned by whisper-jax and strip leading/trailing whitespace\n", "    return whisper_client.predict(\n", "        wav_path,      # str: filepath or URL to file, for the 'inputs' Audio component\n", "        \"transcribe\",  # str: value for the 'Task' Radio component\n", "        False,         # return_timestamps=False for whisper-jax, see https://gist.github.com/sanchit-gandhi/781dd7003c5b201bfe16d28634c8d4cf#file-whisper_jax_endpoint-py\n", "        api_name=\"/predict\"\n", "    )[0].strip()\n", "\n", "\n", "# Chatbot demo with multimodal input (text, markdown, LaTeX, code blocks, image, audio, & video), plus support for streaming text.\n", "\n", "\n", "def add_text(history, text):\n", "    history = [] if history is None else history\n", "    history = history + [(text, None)]\n", "    return history, gr.update(value=\"\", interactive=False)\n", "\n", "\n", "def add_file(history, file):\n", "    history = [] if history is None else history\n", "\n", "    try:\n", "        text = transcribe(file)\n", "        print(\"Transcribed text:\", text)\n", "    except Exception as e:\n", "        print(str(e))\n", "        gr.Warning(\"There was an issue with transcription, please try writing for now\")\n", "        # fall back to a default prompt when transcription fails\n", "        text = \"Transcription seems to have failed, please tell me a joke about chickens\"\n", "\n", "    history = history + [(text, None)]\n", "    return history\n", "\n", "\n", "def bot(history, system_prompt=\"\"):\n", "    history = [] if history is None else history\n", "\n", "    if system_prompt == \"\":\n", "        system_prompt = system_message\n", "\n", "    history[-1][1] = \"\"\n", "    for character in generate(history[-1][0], history[:-1]):\n", "        history[-1][1] = character\n", "        yield history\n", "\n", "\n",
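"# For reference: the Chatbot history is a list of [user, bot] pairs, e.g.\n", "#   [[\"Hi there\", \"Hello! How can I help?\"], [\"Tell me a joke\", None]]\n", "# bot() streams tokens from generate() into the trailing None slot, and generate_speech()\n", "# below reads history[-1][1] as the text to synthesize.\n", "\n",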
"latent_map = {}\n", "latent_map[\"Female_Voice\"] = get_latents(\"examples/female.wav\")\n", "\n", "def get_voice(prompt, language, latent_tuple, suffix=\"0\"):\n", "    gpt_cond_latent, diffusion_conditioning, speaker_embedding = latent_tuple\n", "    # direct (non-streaming) inference\n", "    t0 = time.time()\n", "    out = model.inference(\n", "        prompt,\n", "        language,\n", "        gpt_cond_latent,\n", "        speaker_embedding,\n", "        diffusion_conditioning\n", "    )\n", "    inference_time = time.time() - t0\n", "    print(f\"I: Time to generate audio: {round(inference_time*1000)} milliseconds\")\n", "    # RTF = generation time / audio duration (samples at 24 kHz); values below 1 are faster than real time\n", "    real_time_factor = (time.time() - t0) / out['wav'].shape[-1] * 24000\n", "    print(f\"Real-time factor (RTF): {real_time_factor}\")\n", "    wav_filename = f\"output_{suffix}.wav\"\n", "    torchaudio.save(wav_filename, torch.tensor(out[\"wav\"]).unsqueeze(0), 24000)\n", "    return wav_filename\n", "\n", "def generate_speech(history):\n", "    text_to_generate = history[-1][1]\n", "    if text_to_generate is None:\n", "        # nothing was generated (e.g. the LLM request failed), so there is nothing to synthesize\n", "        return None\n", "    text_to_generate = text_to_generate.replace(\"\\n\", \" \").strip()\n", "    text_to_generate = nltk.sent_tokenize(text_to_generate)\n", "\n", "    language = \"en\"\n", "\n", "    wav_list = []\n", "    for i, sentence in enumerate(text_to_generate):\n", "        # the model sometimes emits its end-of-sequence token in the output; remove it\n", "        sentence = sentence.replace(\"</s>\", \"\")\n", "        # a quick fix for the last character; may produce odd sounds if it is attached to text\n", "        if sentence[-1] in [\"!\", \"?\", \".\", \",\"]:\n", "            # just add a space\n", "            sentence = sentence[:-1] + \" \" + sentence[-1]\n", "\n", "        print(\"Sentence:\", sentence)\n", "\n", "        try:\n", "            # generate speech using the precomputed latents; not streaming, but fast\n", "            # each sentence gets a suffix so all clips can be merged into a single audio file at the end\n", "            # (on mobile there is no autoplay support due to mobile security restrictions)\n", "            wav = get_voice(sentence, language, latent_map[\"Female_Voice\"], suffix=i)\n", "            wav_list.append(wav)\n", "\n", "            yield wav\n", "            wait_time = librosa.get_duration(path=wav)\n", "            print(\"Sleeping until the audio ends\")\n", "            time.sleep(wait_time)\n", "\n", "        except RuntimeError as e:\n", "            if \"device-side assert\" in str(e):\n", "                # cannot recover from a CUDA device-side assert; the Space needs a restart\n", "                print(f\"Exit due to: Unrecoverable exception caused by prompt:{sentence}\", flush=True)\n", "                gr.Warning(\"Unhandled exception encountered, please retry in a minute\")\n", "                print(\"Cuda device-assert Runtime encountered, need restart\")\n", "\n", "                # HF Space specific: this error is unrecoverable, so restart the Space\n", "                api.restart_space(repo_id=repo_id)\n", "            else:\n", "                print(\"RuntimeError: non device-side assert error:\", str(e))\n", "                raise e\n", "\n", "    # each sentence was spoken via autoplay above; now produce one concatenated file of all sentences\n", "    # requires pip install ffmpeg-python\n", "    files_to_concat = [ffmpeg.input(w) for w in wav_list]\n", "    combined_file_name = \"combined.wav\"\n", "    ffmpeg.concat(*files_to_concat, v=0, a=1).output(combined_file_name).run(overwrite_output=True)\n", "\n", "    return gr.Audio.update(value=combined_file_name, autoplay=False)\n", "\n", "\n", "with gr.Blocks(title=title) as demo:\n", "    gr.Markdown(DESCRIPTION)\n", "\n", "    chatbot = gr.Chatbot(\n", "        [],\n", "        elem_id=\"chatbot\",\n", "        avatar_images=('examples/lama.jpeg', 'examples/lama2.jpeg'),\n", "        bubble_full_width=False,\n", "    )\n", "\n", "    with gr.Row():\n", "        txt = gr.Textbox(\n", "            scale=3,\n", "            show_label=False,\n", "            placeholder=\"Enter text and press enter, or speak to your microphone\",\n", "            container=False,\n", "        )\n", "        txt_btn = gr.Button(value=\"Submit text\", scale=1)\n", "        btn = gr.Audio(source=\"microphone\", type=\"filepath\", scale=4)\n", "\n", "    with gr.Row():\n", "        audio = gr.Audio(type=\"numpy\", streaming=False, autoplay=True, label=\"Generated audio response\", show_label=True)\n", "\n", "    clear_btn = gr.ClearButton([chatbot, audio])\n", "\n", "    txt_msg = txt_btn.click(add_text, [chatbot, txt], [chatbot, txt], queue=False).then(\n", "        bot, chatbot, chatbot\n", "    ).then(generate_speech, chatbot, audio)\n", "\n", "    txt_msg.then(lambda: gr.update(interactive=True), None, [txt], queue=False)\n", "\n", "    txt_msg = txt.submit(add_text, [chatbot, txt], [chatbot, txt], queue=False).then(\n", "        bot, chatbot, chatbot\n", "    ).then(generate_speech, chatbot, audio)\n", "\n", "    txt_msg.then(lambda: gr.update(interactive=True), None, [txt], queue=False)\n", "\n", "    file_msg = btn.stop_recording(add_file, [chatbot, btn], [chatbot], queue=False).then(\n", "        bot, chatbot, chatbot\n", "    ).then(generate_speech, chatbot, audio)\n", "\n", "    gr.Markdown(\"\"\"\n", "This Space demonstrates how to speak to a chatbot, based solely on open-source models.\n", "It relies on 3 models:\n", "1. [Whisper-large-v2](https://huggingface.co/spaces/sanchit-gandhi/whisper-jax) as an ASR model, to transcribe recorded audio to text. It is called through a [gradio client](https://www.gradio.app/docs/client).\n", "2. [Mistral-7b-instruct](https://huggingface.co/spaces/osanseviero/mistral-super-fast) as the chat model. It is called from [huggingface_hub](https://huggingface.co/docs/huggingface_hub/guides/inference).\n", "3. [Coqui's XTTS](https://huggingface.co/spaces/coqui/xtts) as a TTS model, to generate the chatbot answers. 
This time, the model is hosted locally.\n", "\n", "Note:\n", "- By using this demo you agree to the terms of the Coqui Public Model License at https://coqui.ai/cpml\"\"\")\n", "demo.queue()\n", "demo.launch(debug=True)" ] }, { "cell_type": "code", "execution_count": null, "id": "652d675a-8912-44cb-830d-29fc5d6679d4", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.12" } }, "nbformat": 4, "nbformat_minor": 5 }