# Hugging Face Space (status: Running)
import os | |
import subprocess | |
import random | |
import numpy as np | |
import json | |
from datetime import timedelta | |
import tempfile | |
import gradio as gr | |
from groq import Groq | |
# Shared Groq API client. The key is read from the "Groq_Api_Key" environment
# variable; os.getenv yields None when that variable is not set.
client = Groq(api_key=os.getenv("Groq_Api_Key"))

# llms

# Upper bound used when drawing a random seed: the largest signed 32-bit integer.
MAX_SEED = np.iinfo(np.int32).max
def update_max_tokens(model):
    """Return a Gradio update that resizes the max-tokens slider for *model*.

    Args:
        model: Identifier of the selected chat model.

    Returns:
        ``gr.update(maximum=...)`` with 8192 for the llama3/gemma models and
        32768 for mixtral, or ``None`` when the model is not recognised.
    """
    eight_k_models = (
        "llama3-70b-8192",
        "llama3-8b-8192",
        "gemma-7b-it",
        "gemma2-9b-it",
    )

    if model == "mixtral-8x7b-32768":
        ceiling = 32768
    elif model in eight_k_models:
        ceiling = 8192
    else:
        # Unrecognised model: there is nothing to update.
        return None

    return gr.update(maximum=ceiling)
def create_history_messages(history):
    """Convert Gradio chat history into chat-completion messages.

    Args:
        history: Iterable of turns, each indexable as ``turn[0]`` (the user
            message) and ``turn[1]`` (the assistant reply).

    Returns:
        A list of ``{"role": ..., "content": ...}`` dicts in chronological
        order: every user message is immediately followed by the assistant
        reply of the same turn. Messages whose content is ``None`` are
        skipped so no empty message is sent to the API.
    """
    history_messages = []
    for turn in history:
        user_text, assistant_text = turn[0], turn[1]
        # Keep each turn's two messages adjacent; emitting all user messages
        # first and all assistant messages afterwards would scramble the
        # dialogue the model sees.
        if user_text is not None:
            history_messages.append({"role": "user", "content": user_text})
        if assistant_text is not None:
            history_messages.append({"role": "assistant", "content": assistant_text})
    return history_messages
def generate_response(prompt, history, model, temperature, max_tokens, top_p, seed):
    """Stream a chat completion for *prompt*, yielding the reply as it grows.

    Args:
        prompt: The new user message.
        history: Earlier turns, converted with ``create_history_messages``.
        model: Chat model identifier passed to the API.
        temperature: Sampling temperature passed to the API.
        max_tokens: Completion token limit passed to the API.
        top_p: Nucleus-sampling value passed to the API.
        seed: Sampling seed; ``0`` means "draw a random seed".

    Yields:
        The accumulated response text after each streamed chunk that carries
        content.
    """
    conversation = create_history_messages(history)
    conversation.append({"role": "user", "content": prompt})
    print(conversation)

    # A seed of 0 is the sentinel for "pick one at random".
    chosen_seed = random.randint(1, MAX_SEED) if seed == 0 else seed

    request_options = {
        "messages": conversation,
        "model": model,
        "temperature": temperature,
        "max_tokens": max_tokens,
        "top_p": top_p,
        "seed": chosen_seed,
        "stop": None,
        "stream": True,
    }
    completion_stream = client.chat.completions.create(**request_options)

    partial_reply = ""
    for event in completion_stream:
        piece = event.choices[0].delta.content
        if piece is None:
            # Chunks without content carry nothing to append.
            continue
        partial_reply += piece
        yield partial_reply
    return partial_reply
# speech to text | |
# File extensions (without the leading dot) accepted for upload.
ALLOWED_FILE_EXTENSIONS = ["mp3", "mp4", "mpeg", "mpga", "m4a", "wav", "webm"]

# Size threshold in megabytes; larger files are downsampled before upload.
MAX_FILE_SIZE_MB = 25

# Display name -> language code offered in the language dropdowns.
LANGUAGE_CODES = {
    "English": "en",
    "Chinese": "zh",
    "German": "de",
    "Spanish": "es",
    "Russian": "ru",
    "Korean": "ko",
    "French": "fr",
    "Japanese": "ja",
    "Portuguese": "pt",
    "Turkish": "tr",
    "Polish": "pl",
    "Catalan": "ca",
    "Dutch": "nl",
    "Arabic": "ar",
    "Swedish": "sv",
    "Italian": "it",
    "Indonesian": "id",
    "Hindi": "hi",
    "Finnish": "fi",
    "Vietnamese": "vi",
    "Hebrew": "he",
    "Ukrainian": "uk",
    "Greek": "el",
    "Malay": "ms",
    "Czech": "cs",
    "Romanian": "ro",
    "Danish": "da",
    "Hungarian": "hu",
    "Tamil": "ta",
    "Norwegian": "no",
    "Thai": "th",
    "Urdu": "ur",
    "Croatian": "hr",
    "Bulgarian": "bg",
    "Lithuanian": "lt",
    "Latin": "la",
    "Māori": "mi",
    "Malayalam": "ml",
    "Welsh": "cy",
    "Slovak": "sk",
    "Telugu": "te",
    "Persian": "fa",
    "Latvian": "lv",
    "Bengali": "bn",
    "Serbian": "sr",
    "Azerbaijani": "az",
    "Slovenian": "sl",
    "Kannada": "kn",
    "Estonian": "et",
    "Macedonian": "mk",
    "Breton": "br",
    "Basque": "eu",
    "Icelandic": "is",
    "Armenian": "hy",
    "Nepali": "ne",
    "Mongolian": "mn",
    "Bosnian": "bs",
    "Kazakh": "kk",
    "Albanian": "sq",
    "Swahili": "sw",
    "Galician": "gl",
    "Marathi": "mr",
    "Panjabi": "pa",
    "Sinhala": "si",
    "Khmer": "km",
    "Shona": "sn",
    "Yoruba": "yo",
    "Somali": "so",
    "Afrikaans": "af",
    "Occitan": "oc",
    "Georgian": "ka",
    "Belarusian": "be",
    "Tajik": "tg",
    "Sindhi": "sd",
    "Gujarati": "gu",
    "Amharic": "am",
    "Yiddish": "yi",
    "Lao": "lo",
    "Uzbek": "uz",
    "Faroese": "fo",
    "Haitian": "ht",
    "Pashto": "ps",
    "Turkmen": "tk",
    "Norwegian Nynorsk": "nn",
    "Maltese": "mt",
    "Sanskrit": "sa",
    "Luxembourgish": "lb",
    "Burmese": "my",
    "Tibetan": "bo",
    "Tagalog": "tl",
    "Malagasy": "mg",
    "Assamese": "as",
    "Tatar": "tt",
    "Hawaiian": "haw",
    "Lingala": "ln",
    "Hausa": "ha",
    "Bashkir": "ba",
    # Display name fixed: this entry previously showed the raw code "jw".
    "Javanese": "jw",
    "Sundanese": "su",
}
# Checks file extension, size, and downsamples if needed.
def check_file(audio_file_path):
    """Validate an uploaded media file and downsample it when it is too large.

    Args:
        audio_file_path: Path of the uploaded file; falsy when nothing was
            uploaded.

    Returns:
        A ``(path, error)`` tuple. On success ``path`` is the file to send
        (the original, or a 16 kHz mono ``*_downsampled.wav`` copy when the
        original exceeds ``MAX_FILE_SIZE_MB``) and ``error`` is ``None``.
        On failure ``path`` is ``None`` and ``error`` is a ``gr.Error``
        instance, which is returned rather than raised.
    """
    if not audio_file_path:
        return None, gr.Error("Please upload an audio file.")

    file_size_mb = os.path.getsize(audio_file_path) / (1024 * 1024)
    file_extension = audio_file_path.split(".")[-1].lower()

    if file_extension not in ALLOWED_FILE_EXTENSIONS:
        return (
            None,
            gr.Error(
                f"Invalid file type (.{file_extension}). Allowed types: {', '.join(ALLOWED_FILE_EXTENSIONS)}"
            ),
        )

    if file_size_mb <= MAX_FILE_SIZE_MB:
        return audio_file_path, None

    gr.Warning(
        f"File size too large ({file_size_mb:.2f} MB). Attempting to downsample to 16kHz. Maximum allowed: {MAX_FILE_SIZE_MB} MB"
    )
    output_file_path = os.path.splitext(audio_file_path)[0] + "_downsampled.wav"
    try:
        subprocess.run(
            [
                "ffmpeg",
                # Overwrite a leftover output from an earlier run instead of
                # stopping at ffmpeg's overwrite confirmation.
                "-y",
                "-i",
                audio_file_path,
                "-ar",
                "16000",
                "-ac",
                "1",
                "-map",
                "0:a:",
                output_file_path,
            ],
            check=True,
        )
    except (subprocess.CalledProcessError, OSError) as e:
        # OSError covers a missing or non-executable ffmpeg binary, which
        # would otherwise escape as an unhandled exception.
        return None, gr.Error(f"Error during downsampling: {e}")

    # Check size after downsampling
    downsampled_size_mb = os.path.getsize(output_file_path) / (1024 * 1024)
    if downsampled_size_mb > MAX_FILE_SIZE_MB:
        return (
            None,
            gr.Error(
                f"File size still too large after downsampling ({downsampled_size_mb:.2f} MB). Maximum allowed: {MAX_FILE_SIZE_MB} MB"
            ),
        )
    return output_file_path, None
def transcribe_audio(audio_file_path, prompt, language, auto_detect_language, model):
    """Transcribe an uploaded audio/video file with the Groq API.

    Args:
        audio_file_path: Path of the uploaded file.
        prompt: Optional context passed to the transcription request.
        language: Language code used unless auto-detection is enabled.
        auto_detect_language: When truthy, ``language=None`` is sent instead.
        model: Transcription model identifier.

    Returns:
        The transcription text, or the error object produced by
        ``check_file`` when validation fails.
    """
    # Check and process the file first
    processed_path, error_message = check_file(audio_file_path)

    # If there's an error during file check
    if error_message:
        return error_message

    with open(processed_path, "rb") as file:
        transcription = client.audio.transcriptions.create(
            file=(os.path.basename(processed_path), file.read()),
            model=model,
            prompt=prompt,
            response_format="text",
            language=None if auto_detect_language else language,
            temperature=0.0,
        )

    # NOTE(review): with response_format="text" the SDK may hand back a plain
    # string instead of an object exposing `.text` - confirm against the Groq
    # SDK. Both shapes are accepted here.
    return getattr(transcription, "text", transcription)
def translate_audio(audio_file_path, prompt, model):
    """Translate the speech in an uploaded audio/video file via the Groq API.

    Args:
        audio_file_path: Path of the uploaded file.
        prompt: Optional context passed to the translation request.
        model: Translation model identifier.

    Returns:
        The translated text, or the error object produced by ``check_file``
        when validation fails.
    """
    # Check and process the file first
    processed_path, error_message = check_file(audio_file_path)

    # If there's an error during file check
    if error_message:
        return error_message

    with open(processed_path, "rb") as file:
        translation = client.audio.translations.create(
            file=(os.path.basename(processed_path), file.read()),
            model=model,
            prompt=prompt,
            response_format="text",
            temperature=0.0,
        )

    # NOTE(review): with response_format="text" the SDK may hand back a plain
    # string instead of an object exposing `.text` - confirm against the Groq
    # SDK. Both shapes are accepted here.
    return getattr(translation, "text", translation)
# subtitles maker | |
# helper function convert json transcription to srt | |
from datetime import timedelta | |
def create_srt_from_text(transcription_text):
    """Build SRT subtitle text from a plain transcription.

    The text is split on ``"."``; every non-empty fragment becomes one cue.
    Cue durations are estimates derived from the fragment's word count at a
    fixed speaking rate of 110 words per minute, and cues follow one another
    without gaps starting at 00:00:00,000.

    Args:
        transcription_text: The transcription as a single string.

    Returns:
        The SRT document as a string (empty when there are no fragments).
        Cues are numbered consecutively from 1 and timestamps use the
        zero-padded ``HH:MM:SS,mmm`` form.
    """
    words_per_minute = 110

    # Estimate how long a fragment takes to speak from its word count.
    def calculate_duration(text):
        words = len(text.split())
        duration_seconds = (words / words_per_minute) * 60
        return timedelta(seconds=duration_seconds)

    # Render a timedelta as a zero-padded SRT timestamp (HH:MM:SS,mmm).
    def format_timestamp(moment):
        total_ms = moment // timedelta(milliseconds=1)
        hours, remainder = divmod(total_ms, 3600000)
        minutes, remainder = divmod(remainder, 60000)
        seconds, millis = divmod(remainder, 1000)
        return f"{hours:02d}:{minutes:02d}:{seconds:02d},{millis:03d}"

    srt_lines = []
    start_time = timedelta(seconds=0)
    cue_number = 0

    for text_part in transcription_text.split("."):
        text_part = text_part.strip()
        if not text_part:
            continue
        # Count only emitted cues so the numbering has no gaps when empty
        # fragments (e.g. from "..") are skipped.
        cue_number += 1
        end_time = start_time + calculate_duration(text_part)
        srt_lines.append(
            f"{cue_number}\n"
            f"{format_timestamp(start_time)} --> {format_timestamp(end_time)}\n"
            f"{text_part}\n\n"
        )
        start_time = end_time  # Move to the next time slot

    return "".join(srt_lines)
# getting transcription + using helper function + adding subs to video if input is video
def generate_subtitles(audio_file_path, prompt, language, auto_detect_language, model):
    """Transcribe a media file, write an SRT file and burn it into videos.

    Args:
        audio_file_path: Path of the uploaded audio/video file.
        prompt: Optional context passed to the transcription request.
        language: Language code used unless auto-detection is enabled.
        auto_detect_language: When truthy, ``language=None`` is sent instead.
        model: Transcription model identifier.

    Returns:
        A ``(srt_path, video_path, error)`` tuple:
        * ``srt_path`` - temporary ``.srt`` file, or ``None`` when validation
          failed;
        * ``video_path`` - copy of the input with burned-in subtitles for
          ``.mp4``/``.webm`` inputs, otherwise ``None``;
        * ``error`` - ``None`` on success, else a ``gr.Error`` instance
          (returned, not raised).
    """
    # Check and process the file first
    processed_path, error_message = check_file(audio_file_path)

    # If there's an error during file check, report it in the error slot
    # (third position), like the ffmpeg failure path below.
    if error_message:
        return None, None, error_message

    with open(processed_path, "rb") as file:
        transcription_json = client.audio.transcriptions.create(
            file=(os.path.basename(processed_path), file.read()),
            model=model,
            prompt=prompt,
            response_format="json",
            language=None if auto_detect_language else language,  # Conditional language parameter
            temperature=0.0,
        )

    # Convert the Transcription object to a dictionary
    transcription_json = json.loads(transcription_json.to_json())
    transcription_text = transcription_json["text"]
    srt_content = create_srt_from_text(transcription_text)

    # Create a temporary file for SRT content. UTF-8 is set explicitly so
    # non-ASCII transcripts do not depend on the platform's locale encoding.
    with tempfile.NamedTemporaryFile(
        mode="w", suffix=".srt", delete=False, encoding="utf-8"
    ) as temp_srt_file:
        temp_srt_path = temp_srt_file.name
        temp_srt_file.write(srt_content)

    # Generate subtitles and add to video if input is video
    if audio_file_path.lower().endswith((".mp4", ".webm")):
        # Build the output name from the split path so only the final
        # extension is touched, never a matching directory component.
        root, extension = os.path.splitext(audio_file_path)
        output_file_path = f"{root}_with_subs{extension}"
        try:
            # Use ffmpeg to burn subtitles into the video
            subprocess.run(
                [
                    "ffmpeg",
                    # Overwrite a leftover output from an earlier run.
                    "-y",
                    "-i",
                    audio_file_path,
                    "-vf",
                    f"subtitles={temp_srt_path}",
                    output_file_path,
                ],
                check=True,
            )
        except (subprocess.CalledProcessError, OSError) as e:
            # The SRT file is still valid, so hand it back with the error.
            return temp_srt_path, None, gr.Error(f"Error during subtitle addition: {e}")
        return temp_srt_path, output_file_path, None

    return temp_srt_path, None, None
# Gradio user interface.
# NOTE(review): the nesting of tabs/columns below was reconstructed from a
# copy of this file whose indentation had been lost - confirm the tab
# hierarchy against the intended layout.
with gr.Blocks() as demo:
    gr.Markdown(
        """
# Groq API UI
Inference by Groq
Hugging Face Space by [Nick088](https://linktr.ee/Nick088)
"""
    )
    with gr.Tabs():
        with gr.TabItem("select option here:"):
            with gr.Tabs():
                with gr.TabItem("Speech To Text"):
                    gr.Markdown("Speech to Text coming soon!")

                with gr.TabItem("LLMs"):
                    with gr.Column():
                        model = gr.Dropdown(
                            choices=[
                                "llama3-70b-8192",
                                "llama3-8b-8192",
                                "mixtral-8x7b-32768",
                                "gemma-7b-it",
                                "gemma2-9b-it",
                            ],
                            value="llama3-70b-8192",
                            label="Model",
                        )
                        temperature = gr.Slider(
                            minimum=0.0,
                            maximum=1.0,
                            step=0.01,
                            value=0.5,
                            label="Temperature",
                            info="Controls diversity of the generated text. Lower is more deterministic, higher is more creative.",
                        )
                        max_tokens = gr.Slider(
                            minimum=1,
                            maximum=8192,
                            step=1,
                            value=4096,
                            label="Max Tokens",
                            info="The maximum number of tokens that the model can process in a single response.<br>Maximums: 8k for gemma 7b it, gemma2 9b it, llama3 8b & 70b, 32k for mixtral 8x7b.",
                        )
                        top_p = gr.Slider(
                            minimum=0.0,
                            maximum=1.0,
                            step=0.01,
                            value=0.5,
                            label="Top P",
                            info="A method of text generation where a model will only consider the most probable next tokens that make up the probability p.",
                        )
                        seed = gr.Number(
                            precision=0,
                            value=42,
                            label="Seed",
                            info="A starting point to initiate generation, use 0 for random",
                        )
                        # Registered once: a second identical listener would
                        # only run the same slider update twice per change.
                        model.change(update_max_tokens, inputs=[model], outputs=max_tokens)
                        chatbot = gr.ChatInterface(
                            fn=generate_response,
                            chatbot=None,
                            additional_inputs=[
                                model,
                                temperature,
                                max_tokens,
                                top_p,
                                seed,
                            ],
                        )

                with gr.TabItem("Transcription"):
                    gr.Markdown("Transcript audio from files to text!")
                    with gr.Column():
                        audio_input = gr.File(
                            type="filepath",
                            label="Upload File containing Audio",
                            file_types=[f".{ext}" for ext in ALLOWED_FILE_EXTENSIONS],
                        )
                        model_choice_transcribe = gr.Dropdown(
                            choices=["whisper-large-v3"],  # Only include 'whisper-large-v3'
                            value="whisper-large-v3",
                            label="Model",
                        )
                        transcribe_prompt = gr.Textbox(
                            label="Prompt (Optional)",
                            info="Specify any context or spelling corrections.",
                        )
                        language = gr.Dropdown(
                            choices=list(LANGUAGE_CODES.items()),
                            value="en",
                            label="Language",
                        )
                        auto_detect_language = gr.Checkbox(label="Auto Detect Language")
                        transcribe_button = gr.Button("Transcribe")
                        transcription_output = gr.Textbox(label="Transcription")
                        transcribe_button.click(
                            transcribe_audio,
                            inputs=[
                                audio_input,
                                transcribe_prompt,
                                language,
                                auto_detect_language,
                                model_choice_transcribe,
                            ],
                            outputs=transcription_output,
                        )

                with gr.TabItem("Translation"):
                    gr.Markdown("Transcript audio from files and translate them to English text!")
                    with gr.Column():
                        audio_input_translate = gr.File(
                            type="filepath",
                            label="Upload File containing Audio",
                            file_types=[f".{ext}" for ext in ALLOWED_FILE_EXTENSIONS],
                        )
                        model_choice_translate = gr.Dropdown(
                            choices=["whisper-large-v3"],  # Only include 'whisper-large-v3'
                            value="whisper-large-v3",
                            label="Model",
                        )
                        translate_prompt = gr.Textbox(
                            label="Prompt (Optional)",
                            info="Specify any context or spelling corrections.",
                        )
                        translate_button = gr.Button("Translate")
                        translation_output = gr.Textbox(label="Translation")
                        translate_button.click(
                            translate_audio,
                            inputs=[audio_input_translate, translate_prompt, model_choice_translate],
                            outputs=translation_output,
                        )

                with gr.TabItem("Subtitle Maker"):
                    with gr.Column():
                        audio_input_subtitles = gr.File(
                            # Explicit, like the other uploads: the handler
                            # treats its argument as a path string.
                            type="filepath",
                            label="Upload Audio/Video",
                            file_types=[f".{ext}" for ext in ALLOWED_FILE_EXTENSIONS],
                        )
                        model_choice_subtitles = gr.Dropdown(
                            choices=["whisper-large-v3"],  # Only include 'whisper-large-v3'
                            value="whisper-large-v3",
                            label="Model",
                        )
                        transcribe_prompt_subtitles = gr.Textbox(
                            label="Prompt (Optional)",
                            info="Specify any context or spelling corrections.",
                        )
                        language_subtitles = gr.Dropdown(
                            choices=list(LANGUAGE_CODES.items()),
                            value="en",
                            label="Language",
                        )
                        auto_detect_language_subtitles = gr.Checkbox(
                            label="Auto Detect Language"
                        )
                        transcribe_button_subtitles = gr.Button("Generate Subtitles")
                        srt_output = gr.File(label="SRT Output File")
                        video_output = gr.File(label="Output Video with Subtitles")
                        subtitles_error = gr.Textbox(label="Error")
                        transcribe_button_subtitles.click(
                            generate_subtitles,
                            inputs=[
                                audio_input_subtitles,
                                transcribe_prompt_subtitles,
                                language_subtitles,
                                auto_detect_language_subtitles,
                                model_choice_subtitles,
                            ],
                            outputs=[srt_output, video_output, subtitles_error],
                        )

demo.launch()