StephaneBah committed
Commit 88ef79d
1 Parent(s): 947af02

Update app.py

Files changed (1)
  1. app.py +22 -17
app.py CHANGED
@@ -1,8 +1,7 @@
-import streamlit as st
+import gradio as gr
 from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
 from diffusers import DiffusionPipeline
 import torch
-import accelerate
 
 # Load the models and tokenizers
 translation_model_name = "google/madlad400-3b-mt"
@@ -15,7 +14,7 @@ diffusion_model_name = "stabilityai/stable-diffusion-xl-base-1.0"
 diffusion_pipeline = DiffusionPipeline.from_pretrained(diffusion_model_name, torch_dtype=torch.float16)
 diffusion_pipeline = diffusion_pipeline.to("cuda")
 
-# Define the translation and transcription pipeline with accelerate
+# Define the translation and transcription pipeline
 translation_pipeline = pipeline("translation", model=translation_model, tokenizer=translation_tokenizer, device_map="auto")
 transcription_pipeline = pipeline("automatic-speech-recognition", model=transcription_model, device_map="auto")
 
@@ -28,22 +27,28 @@ def transcribe_and_translate_audio_fon(audio_path, num_images=1):
     translation_result = translation_pipeline(transcription_fon, source_lang="fon", target_lang="fr")
     translation_fr = translation_result[0]["translation_text"]
 
+    # Generate images based on the French translation using the diffusion model
     images = diffusion_pipeline(translation_fr, num_images_per_prompt=num_images)["images"]
 
     return images
 
-# Create a Streamlit app
-st.title("Fon Audio to Image Translation")
-
-# Upload audio file
-audio_file = st.file_uploader("Upload an audio file", type=["wav"])
-
-# Transcribe, translate and generate images
-if audio_file:
-    images = transcribe_and_translate_audio_fon(audio_file)
-    st.image(images[0])
-
-
-# Use Accelerate to distribute the computation across available GPUs
-#images = accelerate.launch(transcribe_and_translate_and_generate, audio_file="Fongbe_Speech_Dataset/Fongbe_Speech_Dataset/fongbe_speech_audio_files/wav/64_fongbe_6b36d45b77344caeb1c8d773303c9dcb_for_validation_2022-03-11-23-50-13.wav", num_images=2)
-
+# Create a Gradio interface
+def process_audio(audio, num_images):
+    images = transcribe_and_translate_audio_fon(audio, num_images)
+    return images
+
+# Define Gradio interface components
+audio_input = gr.Audio(source="upload", type="filepath", label="Upload an audio file")
+image_output = gr.Gallery(label="Generated Images").style(grid=2)
+num_images_input = gr.Slider(minimum=1, maximum=5, step=1, value=1, label="Number of Images")
+
+# Launch Gradio interface
+interface = gr.Interface(
+    fn=process_audio,
+    inputs=[audio_input, num_images_input],
+    outputs=image_output,
+    title="Fon Audio to Image Translation",
+    description="Upload an audio file in Fon, and the app will transcribe, translate to French, and generate related images."
+)
+
+interface.launch()
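
Review note on the translation call: transformers' translation pipeline takes its language hints as src_lang/tgt_lang, and even those are only wired up for tokenizers with built-in translation support (e.g. NLLB, M2M100). MADLAD-400 is T5-based and, per its model card, selects the target language through a <2xx> token prepended to the source text, so the source_lang="fon", target_lang="fr" keywords above are likely ignored or rejected. A minimal sketch of the model-card convention, with a placeholder standing in for the real ASR output:

import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

tokenizer = AutoTokenizer.from_pretrained("google/madlad400-3b-mt")
model = AutoModelForSeq2SeqLM.from_pretrained(
    "google/madlad400-3b-mt", torch_dtype=torch.float16, device_map="auto"
)

# MADLAD-400 picks the output language with a <2xx> prefix; <2fr> requests French.
transcription_fon = "..."  # placeholder for the transcription produced upstream
inputs = tokenizer("<2fr> " + transcription_fon, return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, max_new_tokens=128)
translation_fr = tokenizer.decode(outputs[0], skip_special_tokens=True)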
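Review note on the Gradio components: gr.Audio(source="upload", ...) and gr.Gallery(...).style(grid=2) are Gradio 3.x idioms. Gradio 4 renamed the audio argument to sources=[...], moved the gallery grid into the constructor as columns=, and removed .style(), so this block fails at startup on a Space that pins gradio>=4. A sketch of the same components under the 4.x API, assuming nothing else changes:

import gradio as gr

audio_input = gr.Audio(sources=["upload"], type="filepath", label="Upload an audio file")
num_images_input = gr.Slider(minimum=1, maximum=5, step=1, value=1, label="Number of Images")
image_output = gr.Gallery(label="Generated Images", columns=2)  # columns= replaces .style(grid=2)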
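Review note on removing import accelerate: this is safe, since device_map="auto" only requires accelerate to be installed, not imported. Memory is the remaining risk: the 3B translation model, the transcription model, and SDXL together are a tight fit on a single GPU. If the Space hits CUDA out-of-memory errors, diffusers' enable_model_cpu_offload() (which also depends on the installed accelerate package) can stand in for .to("cuda"), trading generation speed for headroom. A minimal sketch:

import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
)
pipe.enable_model_cpu_offload()  # keeps submodules on the GPU only while they run
images = pipe("a placeholder prompt", num_images_per_prompt=2).images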