Ahmedasd commited on
Commit
83fcd99
·
1 Parent(s): e491cdb

first commit

Browse files
Files changed (3) hide show
  1. app.py +143 -0
  2. gradio intro.mp3 +0 -0
  3. requirements.txt +91 -0
app.py ADDED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import WhisperProcessor, WhisperForConditionalGeneration
3
+ from datasets import load_dataset
4
+ from duckduckgo_search import DDGS
5
+ from newspaper import Article
6
+ import scipy
7
+ from transformers import (
8
+ MT5Tokenizer,
9
+ AdamW,
10
+ MT5ForConditionalGeneration,
11
+ pipeline
12
+ )
13
+ from transformers import VitsModel, AutoTokenizer
14
+ import IPython.display as ipd
15
+ import torch
16
+ import numpy as np
17
+ import gradio as gr
18
+ import os
19
+
20
+ class Webapp:
21
+ def __init__(self):
22
+ self.DEVICE = 0 if torch.cuda.is_available() else "cpu"
23
+ self.REF_MODEL = 'google/mt5-small'
24
+ self.MODEL_NAME = 'Ahmedasd/arabic-summarization-hhh-100-batches'
25
+ self.model_id = "openai/whisper-base"
26
+ self.tts_model_id = "SeyedAli/Arabic-Speech-synthesis"
27
+ self.tts_model = VitsModel.from_pretrained(self.tts_model_id).to(self.DEVICE)
28
+ self.tts_tokenizer = AutoTokenizer.from_pretrained(self.tts_model_id)
29
+
30
+ self.summ_tokenizer = MT5Tokenizer.from_pretrained(self.REF_MODEL)
31
+ self.summ_model = MT5ForConditionalGeneration.from_pretrained(self.MODEL_NAME).to(self.DEVICE)
32
+
33
+ self.torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
34
+
35
+
36
+
37
+ self.stt_model = WhisperForConditionalGeneration.from_pretrained(self.model_id)
38
+ self.stt_model.to(self.DEVICE)
39
+
40
+ self.processor = WhisperProcessor.from_pretrained(self.model_id)
41
+ self.forced_decoder_ids = self.processor.get_decoder_prompt_ids(language="arabic", task="transcribe")
42
+ def speech_to_text(self, input):
43
+ print('gradio audio type: ', type(input))
44
+ print('gradio audio: ', input)
45
+ new_sample_rate = 16000
46
+ new_length = int(len(input[1]) * new_sample_rate / 48000)
47
+ audio_sr_16000 = scipy.signal.resample(input[1], new_length)
48
+ print('input audio16000: ', audio_sr_16000)
49
+ input_features = self.processor(audio_sr_16000, sampling_rate=new_sample_rate, return_tensors="pt").input_features.to(self.DEVICE)
50
+ predicted_ids = self.stt_model.generate(input_features, forced_decoder_ids=self.forced_decoder_ids)
51
+ transcription = self.processor.batch_decode(predicted_ids, skip_special_tokens=True)
52
+ return transcription
53
+ def get_articles(self, query, num):
54
+ with DDGS(timeout=20) as ddgs:
55
+ try:
56
+ results = ddgs.news(query, max_results=num)
57
+ urls = [r['url'] for r in results]
58
+ print('successful connection!')
59
+ except Exception as error:
60
+ urls = ['https://www.bbc.com/arabic/media-65576589']
61
+
62
+ articles = []
63
+ for url in urls:
64
+ article = Article(url)
65
+ article.download()
66
+ article.parse()
67
+ articles.append(article.text.replace('\n',''))
68
+ return articles
69
+ def summarize(self, text, model):
70
+ text_encoding = self.summ_tokenizer(
71
+ text,
72
+ max_length=512,
73
+ padding='max_length',
74
+ truncation=True,
75
+ return_attention_mask=True,
76
+ add_special_tokens=True,
77
+ return_tensors='pt'
78
+ )
79
+ generated_ids = self.summ_model.generate(
80
+ input_ids=text_encoding['input_ids'].to(self.DEVICE),
81
+ attention_mask = text_encoding['attention_mask'].to(self.DEVICE),
82
+ max_length=128,
83
+ # num_beams=2,
84
+ repetition_penalty=2.5,
85
+ # length_penalty=1.0,
86
+ # early_stopping=True
87
+ )
88
+
89
+ preds = [self.summ_tokenizer.decode(gen_id, skip_special_tokens=True, clean_up_tokenization_spaces=True)
90
+ for gen_id in generated_ids
91
+ ]
92
+ return "".join(preds)
93
+ def summarize_articles(self, articles: int, model):
94
+ summaries = []
95
+ for article in articles:
96
+ summaries.append(self.summarize(article, model))
97
+ return summaries
98
+ def text_to_speech(self, text):
99
+ inputs = self.tts_tokenizer(text, return_tensors="pt").to(self.DEVICE)
100
+ print('text_to_speech text: ', text)
101
+ with torch.no_grad():
102
+ wav = self.tts_model(**inputs).waveform
103
+ print('text_to_speech wav: ', wav)
104
+ return {'wav':wav, 'rate':self.tts_model.config.sampling_rate}
105
+ def topic_voice_to_summary_voices(self, topic_voice, number_articles):
106
+ topic = self.speech_to_text(topic_voice)
107
+ print('topic: ', topic)
108
+ articles = self.get_articles(topic, number_articles)
109
+ print('articles: ', articles)
110
+ summaries = self.summarize_articles(articles, self.summ_model)
111
+ print('summaries: ', summaries)
112
+ voices_wav_rate = [self.text_to_speech(summary) for summary in summaries]
113
+
114
+ return voices_wav_rate
115
+ def run(self):
116
+ with gr.Blocks(title = 'أخبار مسموعة', analytics_enabled=True, theme = gr.themes.Glass, css = 'dir: rtl;') as demo:
117
+ gr.Markdown(
118
+ """
119
+ # أخبار مسموعة
120
+ اذكر الموضوع الذي تريد البحث عنه وسوف نخبرك بملخصات الأخبار بشأنه.
121
+ """, rtl = True)
122
+ intro_voice = gr.Audio(type='filepath', value = os.getcwd() + '/gradio intro.mp3', visible = False, autoplay = True)
123
+ topic_voice = gr.Audio(type="numpy", sources = 'microphone', label ='سجل موضوع للبحث')
124
+ num_articles = gr.Slider(minimum=1, maximum=10, value=1, step = 1, label = "عدد المقالات")
125
+ output_audio = gr.Audio(streaming = True, autoplay = True, label = 'الملخصات')
126
+
127
+ # Events
128
+ # generate summaries
129
+ @topic_voice.stop_recording(inputs = [topic_voice, num_articles], outputs = output_audio)
130
+ def get_summ_audio(topic_voice, num_articles):
131
+ summ_voices = self.topic_voice_to_summary_voices(topic_voice, num_articles)
132
+ m =15000
133
+ print('summ voices: ', summ_voices)
134
+ print('wav: ')
135
+ print('max: ', (np.array(summ_voices[0]['wav'][0].cpu()*m, dtype = np.int16)).max())
136
+ print('min: ', (np.array(summ_voices[0]['wav'][0].cpu()*m, dtype = np.int16)).min())
137
+ print('len: ', len(np.array(summ_voices[0]['wav'][0].cpu(), dtype = np.int16)))
138
+ summ_audio = [(voice['rate'], np.squeeze(np.array(voice['wav'].cpu()*m, dtype = np.int16))) for voice in summ_voices]
139
+ return summ_audio[0] #only first
140
+ return demo
141
+
142
+ app = Webapp()
143
+ app.run().launch()
gradio intro.mp3 ADDED
Binary file (69.6 kB). View file
 
requirements.txt ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ traitlets==5.7.1
2
+ pygments==2.16.1
3
+ ptyprocess==0.7.0
4
+ pexpect==4.9.0
5
+ pathlib==1.0.1
6
+ decorator==4.4.2
7
+ pickleshare==0.7.5
8
+ backcall==0.2.0
9
+ wcwidth==0.2.12
10
+ platformdirs==4.1.0
11
+ ipykernel==5.5.6
12
+ tornado==6.3.2
13
+ six==1.16.0
14
+ setuptools==67.7.2
15
+ psutil==5.9.5
16
+ pyparsing==3.1.1
17
+ certifi==2023.11.17
18
+ httplib2==0.22.0
19
+ numpy==1.23.5
20
+ packaging==23.2
21
+ defusedxml==0.7.1
22
+ cffi==1.16.0
23
+ cycler==0.12.1
24
+ kiwisolver==1.4.5
25
+ debugpy==1.6.6
26
+ portpicker==1.5.2
27
+ astunparse==1.6.3
28
+ tqdm==4.66.1
29
+ mpmath==1.3.0
30
+ sympy==1.12
31
+ pydot==1.4.2
32
+ torch==2.1.0+cu121
33
+ urllib3==2.0.7
34
+ chardet==5.2.0
35
+ idna==3.6
36
+ requests==2.31.0
37
+ ipywidgets==7.7.1
38
+ pydantic==1.10.13
39
+ filelock==3.13.1
40
+ cloudpickle==2.2.1
41
+ etils==1.6.0
42
+ rich==13.7.0
43
+ transformers==4.35.2
44
+ tokenizers==0.15.0
45
+ safetensors==0.4.1
46
+ regex==2023.6.3
47
+ fsspec==2023.6.0
48
+ pytz==2023.3.post1
49
+ pyarrow==10.0.1
50
+ numexpr==2.8.8
51
+ pandas==1.5.3
52
+ soundfile==0.12.1
53
+ multidict==6.0.4
54
+ yarl==1.9.4
55
+ frozenlist==1.4.0
56
+ aiosignal==1.3.1
57
+ aiohttp==3.9.1
58
+ xxhash==3.4.1
59
+ lxml==4.9.3
60
+ soupsieve==2.5
61
+ webencodings==0.5.1
62
+ html5lib==1.1
63
+ scipy==1.11.4
64
+ wrapt==1.14.1
65
+ gast==0.5.4
66
+ termcolor==2.4.0
67
+ cryptography==41.0.7
68
+ cachetools==5.3.2
69
+ uritemplate==4.1.1
70
+ oauth2client==4.1.3
71
+ pyasn1==0.5.1
72
+ rsa==4.9
73
+ tblib==3.0.0
74
+ h5py==3.9.0
75
+ flatbuffers==23.5.26
76
+ joblib==1.3.2
77
+ threadpoolctl==3.2.0
78
+ sniffio==1.3.0
79
+ anyio==3.7.1
80
+ click==8.1.7
81
+ markupsafe==2.1.3
82
+ jinja2==3.1.2
83
+ attrs==23.1.0
84
+ referencing==0.32.0
85
+ webcolors==1.13
86
+ jsonschema==4.19.2
87
+ entrypoints==0.4
88
+ toolz==0.12.0
89
+ altair==4.2.2
90
+ mdurl==0.1.2
91
+ typer==0.9.0