saitejad committed on
Commit
76b7899
β€’
1 Parent(s): c1cab94

app.py created

Files changed (1)
  1. app.py +103 -0
app.py ADDED
@@ -0,0 +1,103 @@
+ import gradio as gr
+ import librosa
+ import torch
+ from transformers import AutoTokenizer, pipeline
+ from auto_gptq import AutoGPTQForCausalLM
+
+ from transformers import SpeechT5Processor, SpeechT5ForSpeechToText
+
+
+ # Audio
+
+ checkpoint = "microsoft/speecht5_asr"
+ audio_processor = SpeechT5Processor.from_pretrained(checkpoint)
+ audio_model = SpeechT5ForSpeechToText.from_pretrained(checkpoint)
+
+ def process_audio(sampling_rate, waveform):
+     # convert from int16 to floating point in [-1, 1]
+     waveform = waveform / 32768.0
+
+     # convert to mono if stereo
+     if len(waveform.shape) > 1:
+         waveform = librosa.to_mono(waveform.T)
+
+     # resample to 16 kHz if necessary
+     if sampling_rate != 16000:
+         waveform = librosa.resample(waveform, orig_sr=sampling_rate, target_sr=16000)
+
+     # limit to 30 seconds
+     waveform = waveform[:16000 * 30]
+
+     # make a PyTorch tensor
+     waveform = torch.tensor(waveform)
+     return waveform
+
+
+ def audio_to_text(audio, mic_audio=None):
+     # audio = tuple (sample_rate, frames) or (sample_rate, (frames, channels))
+     if mic_audio is not None:
+         sampling_rate, waveform = mic_audio
+     elif audio is not None:
+         sampling_rate, waveform = audio
+     else:
+         return "(please provide audio)"
+
+     waveform = process_audio(sampling_rate, waveform)
+     inputs = audio_processor(audio=waveform, sampling_rate=16000, return_tensors="pt")
+     predicted_ids = audio_model.generate(**inputs, max_length=400)
+     transcription = audio_processor.batch_decode(predicted_ids, skip_special_tokens=True)
+     return transcription[0]
+
+
+ # Text Generation
+
+ model_name_or_path = "TheBloke/Llama-2-7b-Chat-GPTQ"
+ model_basename = "gptq_model-4bit-128g"
+
+ use_triton = False
+
+ llama_tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=True)
+
+ llama_model = AutoGPTQForCausalLM.from_quantized(model_name_or_path,
+                                                  model_basename=model_basename,
+                                                  use_safetensors=True,
+                                                  trust_remote_code=True,
+                                                  device="cuda:0",
+                                                  use_triton=use_triton,
+                                                  quantize_config=None)
+
+ def generate(text):
+     prompt = text
+     system_message = "You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information. Give short, simple and direct answers."
+     prompt_template = f'''[INST] <<SYS>>
+ {system_message}
+ <</SYS>>
+
+ {prompt} [/INST]'''
+
+     pipe = pipeline(
+         "text-generation",
+         model=llama_model,
+         tokenizer=llama_tokenizer,
+         max_new_tokens=512,
+         do_sample=True,  # temperature/top_p only take effect when sampling is enabled
+         temperature=0.7,
+         top_p=0.95,
+         repetition_penalty=1.15
+     )
+
+     return pipe(prompt_template)[0]['generated_text']
+
+ def audio_text_generate(audio):
+     audio_text = audio_to_text(audio)
+     generated_text = generate(audio_text)
+     # keep only the reply that follows the closing [/INST] tag
+     response = generated_text[generated_text.index("[/INST]") + len("[/INST]"):].strip()
+     return audio_text, response
+
+
+ demo = gr.Interface(fn=audio_text_generate,
+                     inputs=gr.Audio(source="microphone"),
+                     outputs=[gr.Text(label="Audio Text"), gr.Text(label="Generated Text")])
+ # examples=["https://samplelib.com/lib/preview/mp3/sample-3s.mp3"], cache_examples=True)
+
+ demo.launch()
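
For a quick check without the Gradio UI, the same (sample_rate, int16 array) tuple that gr.Audio produces can be fed straight into audio_text_generate. A minimal sketch, not part of the commit, assuming the soundfile package is installed and "sample.wav" is a hypothetical stand-in for a real recording:

import numpy as np
import soundfile as sf  # assumption: soundfile is available; app.py itself does not use it

# read a WAV as int16 at its native rate, mirroring the tuple format gr.Audio emits
data, rate = sf.read("sample.wav", dtype="int16")  # "sample.wav" is a hypothetical path
transcript, reply = audio_text_generate((rate, np.asarray(data)))
print("Transcript:", transcript)
print("Reply:", reply)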