VMORnD committed on
Commit
360c1a4
1 Parent(s): 7740f1a

Upload core.py

Browse files
Files changed (1) hide show
  1. core.py +78 -0
core.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import transformers
2
+ from transformers import pipeline
3
+
4
+ import whisper
5
+
6
+ import datetime
7
+
8
+ transformers.utils.move_cache()  # NOTE(review): presumably migrates a legacy Hugging Face cache layout before any model is loaded - confirm against the transformers docs
9
+
10
+ # ====================================
11
+ # Load the Whisper "base" speech recognition model once, at import time.
12
+ # Disabled alternative, not executed: speech_recognition_pipeline = pipeline("automatic-speech-recognition", model="facebook/wav2vec2-base-960h")
13
+ speech_recognition_model = whisper.load_model("base")
14
+
15
+ # ====================================
16
+ # Load the English summarization tokenizer and seq2seq model (facebook/bart-large-cnn).
17
+ # Disabled alternative, not executed: text_summarization_pipeline_En = pipeline("summarization", model="facebook/bart-large-cnn")
18
+ tokenizer_En = transformers.AutoTokenizer.from_pretrained("facebook/bart-large-cnn")
19
+ text_summarization_model_En = transformers.AutoModelForSeq2SeqLM.from_pretrained("facebook/bart-large-cnn")
20
+
21
+ # ====================================
22
+ # Load the Vietnamese summarization tokenizer and seq2seq model (VietAI/vit5-large-vietnews-summarization).
23
+ tokenizer_Vi = transformers.AutoTokenizer.from_pretrained("VietAI/vit5-large-vietnews-summarization")
24
+ text_summarization_model_Vi = transformers.AutoModelForSeq2SeqLM.from_pretrained("VietAI/vit5-large-vietnews-summarization")
25
+
26
def asr_transcript(input_file):
    """Transcribe an audio file and return its text, language label and timeline.

    The audio is loaded with ``whisper.load_audio`` and transcribed by the
    module-level ``speech_recognition_model``.

    Args:
        input_file: Audio source, forwarded unchanged to ``whisper.load_audio``.

    Returns:
        A ``(text, lang, detail)`` tuple. ``text`` is the full transcript.
        ``lang`` is ``"Vietnamese"`` when the detected language code is
        ``'vi'`` and ``"English"`` for every other code. ``detail`` contains
        one ``start-end text`` line per segment, each ending with a newline.
    """
    waveform = whisper.load_audio(input_file)
    result = speech_recognition_model.transcribe(waveform)
    transcript = result['text']

    # Only Vietnamese gets its own label; every other detected code
    # (including 'en') is reported as English.
    language = "Vietnamese" if result["language"] == 'vi' else "English"

    def _clock(seconds):
        # Round to whole seconds and render through str(timedelta).
        return str(datetime.timedelta(seconds=round(seconds)))

    timeline = "".join(
        _clock(piece['start']) + "-" + _clock(piece['end']) + " " + piece['text'] + "\n"
        for piece in result['segments']
    )
    return transcript, language, timeline
43
+
44
def text_summarize_en(text_input):
    """Summarize English text with ``tokenizer_En`` and ``text_summarization_model_En``.

    Args:
        text_input: Text to summarize. It is tokenized with
            ``truncation=True`` and returned as PyTorch tensors.

    Returns:
        The decoded summary, generated with ``max_length=256`` and
        ``early_stopping=True``. Returns ``""`` without invoking the model
        when ``text_input`` is empty, ``None`` or whitespace-only.
    """
    # Nothing to summarize: skip generation instead of letting the model
    # produce text from an input that carries no content.
    if not text_input or (isinstance(text_input, str) and not text_input.strip()):
        return ""

    encoding = tokenizer_En(text_input, truncation=True, return_tensors="pt")
    generated = text_summarization_model_En.generate(
        input_ids=encoding["input_ids"],
        attention_mask=encoding["attention_mask"],
        max_length=256,
        early_stopping=True,
    )
    # Decode every returned sequence and concatenate them without a separator.
    return "".join(
        tokenizer_En.decode(
            sequence,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True,
        )
        for sequence in generated
    )
57
+
58
def text_summarize_vi(text_input):
    """Summarize Vietnamese text with ``tokenizer_Vi`` and ``text_summarization_model_Vi``.

    Args:
        text_input: Text to summarize. It is tokenized with
            ``truncation=True`` and returned as PyTorch tensors.

    Returns:
        The decoded summary, generated with ``max_length=256`` and
        ``early_stopping=True``. Returns ``""`` without invoking the model
        when ``text_input`` is empty, ``None`` or whitespace-only.
    """
    # Nothing to summarize: skip generation instead of letting the model
    # produce text from an input that carries no content.
    if not text_input or (isinstance(text_input, str) and not text_input.strip()):
        return ""

    # NOTE(review): the checkpoint behind tokenizer_Vi may expect a task
    # prefix on its input - confirm against its model card before changing.
    encoding = tokenizer_Vi(text_input, truncation=True, return_tensors="pt")
    generated = text_summarization_model_Vi.generate(
        input_ids=encoding["input_ids"],
        attention_mask=encoding["attention_mask"],
        max_length=256,
        early_stopping=True,
    )
    # Decode every returned sequence and concatenate them without a separator.
    return "".join(
        tokenizer_Vi.decode(
            sequence,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True,
        )
        for sequence in generated
    )
71
+
72
def text_summarize(text_input, lang):
    """Route text to the summarizer that matches a language label.

    Args:
        text_input: Text to summarize.
        lang: ``'English'`` or ``'Vietnamese'`` (exact, case-sensitive match).

    Returns:
        The result of ``text_summarize_en`` or ``text_summarize_vi`` for a
        supported label, or ``""`` for any other ``lang`` value.
    """
    # Unsupported label: there is no summarizer to call.
    if lang not in ('English', 'Vietnamese'):
        return ""

    summarizer = text_summarize_en if lang == 'English' else text_summarize_vi
    return summarizer(text_input)