fffiloni committed on
Commit
6935ada
1 Parent(s): 7edde5e

Create app_zero.py

Browse files
Files changed (1) hide show
  1. app_zero.py +152 -0
app_zero.py ADDED
@@ -0,0 +1,152 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (2023) Tsinghua University, Bytedance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import gradio as gr
16
+ import spaces
17
+
18
+ import argparse
19
+ from model_zero import SALMONN
20
+
21
class ff:
    """Offline stand-in for the SALMONN model, used to debug the gradio UI.

    Enable it via ``model = ff()`` (see the commented line below the parser)
    to exercise the demo without loading any checkpoints.

    Fixes over the original stub so it is a true drop-in for how
    ``gradio_answer`` uses the real model:
    - ``prompt_pattern`` now has a default, since the caller omits it.
    - the canned reply is returned as a one-element list, because the
      caller indexes the result with ``[0]`` (a bare str would yield "I").
    """

    def generate(self, wav_path, prompt, prompt_pattern=None, num_beams=4, temperature=1.0, top_p=0.9):
        """Log the decoding arguments and return a canned one-element reply list."""
        print(f'wav_path: {wav_path}, prompt: {prompt}, temperature: {temperature}, num_beams: {num_beams}, top_p: {top_p}')
        return ["I'm sorry, but I cannot answer that question as it is not clear what you are asking. Can you please provide more context or clarify your question?"]
25
+
26
# ---------------------------------------------------------------------------
# Command-line configuration and model construction.
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser()

# Checkpoint / backbone locations and the target device (all plain strings).
for flag, default in [
    ("--device", "cuda:0"),
    ("--ckpt_path", "./salmonn_7b_v0.pth"),
    ("--whisper_path", "./whisper_large_v2"),
    ("--beats_path", "./beats/BEATs_iter3_plus_AS2M_finetuned_on_AS2M_cpt2.pt"),
    ("--vicuna_path", "./vicuna-7b-v1.5"),
]:
    parser.add_argument(flag, type=str, default=default)
parser.add_argument("--low_resource", action='store_true', default=False)
parser.add_argument("--port", default=9527)

args = parser.parse_args()
args.low_resource = True  # forced on: huggingface A10 7b demo

# model = ff()  # swap in the offline stub above to debug the UI only
model = SALMONN(
    ckpt=args.ckpt_path,
    whisper_path=args.whisper_path,
    beats_path=args.beats_path,
    vicuna_path=args.vicuna_path,
    low_resource=args.low_resource,
    lora_alpha=28,
    device='cpu'
)
# Built on CPU first, then moved to the requested device for inference.
model.to(args.device)
model.eval()
49
+
50
@spaces.GPU(enable_queue=True)
def gradio_answer(speech, text_input, num_beams, temperature, top_p):
    """Answer one user question about one uploaded audio clip.

    Parameters mirror the gradio inputs: `speech` is an audio file path,
    `text_input` is the question, and the remaining arguments are the
    decoding knobs from the Advanced Settings sliders. Returns the first
    string produced by the module-level SALMONN `model`.
    """
    return model.generate(
        wav_path=speech,
        prompt=text_input,
        num_beams=num_beams,
        temperature=temperature,
        top_p=top_p,
    )[0]
62
+
63
# Static HTML fragments for the page header (rendered with gr.HTML below).
title = """<h1 style="text-align: center;">SALMONN: Speech Audio Language Music Open Neural Network</h1>"""
# Project logo banner; currently unused (the gr.Markdown(image_src) call is commented out).
image_src = """<h1 align="center"><a href="https://github.com/bytedance/SALMONN"><img src="https://raw.githubusercontent.com/bytedance/SALMONN/main/resource/salmon.png", alt="SALMONN" border="0" style="margin: 0 auto; height: 200px;" /></a> </h1>"""
description = """<h3 style="text-align: center;">This is a simplified gradio demo for <a href="https://huggingface.co/tsinghua-ee/SALMONN-7B" target="_blank">SALMONN-7B</a>. <br />To experience SALMONN-13B, you can go to <a href="https://bytedance.github.io/SALMONN">https://bytedance.github.io/SALMONN</a>.<br /> Upload your audio and ask a question!</h3>"""

# Page-level CSS: center the main column and cap its width at 840px.
css = """
div#col-container {
    margin: 0 auto;
    max-width: 840px;
}
"""
73
+
74
# Build the gradio UI. Component creation order inside the nested `with`
# context managers defines the on-page layout.
# NOTE(review): nesting reconstructed from a whitespace-mangled diff — verify
# the indentation of the Row/Column/Accordion blocks against the deployed app.
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.HTML(title)
        #gr.Markdown(image_src)
        gr.HTML(description)

        with gr.Row():
            with gr.Column():
                # Input audio; type='filepath' so gradio hands the handler a path string.
                speech = gr.Audio(label="Audio", type='filepath')

                with gr.Row():
                    text_input = gr.Textbox(label='User question', placeholder='Please upload your audio first', interactive=True)
                    submit_btn = gr.Button("Submit")
                answer = gr.Textbox(label="Salmonn answer")

            # Decoding knobs forwarded to model.generate.
            with gr.Accordion("Advanced Settings", open=False):
                num_beams = gr.Slider(
                    minimum=1,
                    maximum=10,
                    value=4,
                    step=1,
                    interactive=True,
                    label="beam search numbers",
                )

                top_p = gr.Slider(
                    minimum=0.1,
                    maximum=1.0,
                    value=0.9,
                    step=0.1,
                    interactive=True,
                    label="top p",
                )

                # Temperature is shown but locked (interactive=False) in this demo.
                temperature = gr.Slider(
                    minimum=0.8,
                    maximum=2.0,
                    value=1.0,
                    step=0.1,
                    interactive=False,
                    label="temperature",
                )

        # Clickable example (audio, question) pairs that fill the inputs.
        with gr.Row():
            examples = gr.Examples(
                examples = [
                    ["resource/audio_demo/gunshots.wav", "Recognize the speech and give me the transcription."],
                    ["resource/audio_demo/gunshots.wav", "Listen to the speech and translate it into German."],
                    ["resource/audio_demo/gunshots.wav", "Provide the phonetic transcription for the speech."],
                    ["resource/audio_demo/gunshots.wav", "Please describe the audio."],
                    ["resource/audio_demo/gunshots.wav", "Recognize what the speaker says and describe the background audio at the same time."],
                    ["resource/audio_demo/gunshots.wav", "Use your strong reasoning skills to answer the speaker's question in detail based on the background sound."],
                    ["resource/audio_demo/duck.wav", "Please list each event in the audio in order."],
                    ["resource/audio_demo/duck.wav", "Based on the audio, write a story in detail. Your story should be highly related to the audio."],
                    ["resource/audio_demo/duck.wav", "How many speakers did you hear in this audio? Who are they?"],
                    ["resource/audio_demo/excitement.wav", "Describe the emotion of the speaker."],
                    ["resource/audio_demo/mountain.wav", "Please answer the question in detail."],
                    ["resource/audio_demo/jobs.wav", "Give me only three keywords of the text. Explain your reason."],
                    ["resource/audio_demo/2_30.wav", "What is the time mentioned in the speech?"],
                    ["resource/audio_demo/music.wav", "Please describe the music in detail."],
                    ["resource/audio_demo/music.wav", "What is the emotion of the music? Explain the reason in detail."],
                    ["resource/audio_demo/music.wav", "Can you write some lyrics of the song?"],
                    ["resource/audio_demo/music.wav", "Give me a title of the music based on its rhythm and emotion."]
                ],
                inputs=[speech, text_input]
            )

    # Pressing Enter in the textbox or clicking Submit both run the model.
    text_input.submit(
        gradio_answer, [speech, text_input, num_beams, temperature, top_p], [answer]
    )
    submit_btn.click(
        gradio_answer, [speech, text_input, num_beams, temperature, top_p], [answer]
    )

# demo.launch(share=True, enable_queue=True, server_port=int(args.port))
# Queue requests (up to 20 waiting) so the ZeroGPU decorator can schedule them.
demo.queue(max_size=20).launch(share=False)