# fastperson/model.py
import copy
import os
import subprocess

import numpy as np
import pytsmod as tsm
import torch
import whisper
from moviepy.audio.AudioClip import AudioArrayClip
from moviepy.editor import *
from moviepy.video.fx.speedx import speedx
from sentence_transformers import SentenceTransformer, util
from transformers import pipeline, BertTokenizer, BertForNextSentencePrediction

from utils import two_chnnel_to_one_channel, convert_sample_rate

# Install ImageMagick (used by moviepy) when the module is imported.
subprocess.run(['apt-get', '-y', 'install', 'imagemagick'])

# Load the models once at import time: Whisper for transcription, a sentence embedder for
# similarity against the summary, BERT next-sentence prediction for sentence connectivity,
# and a BART summarizer.
transcriber = whisper.load_model("medium")
sentence_transformer = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
tokenizer = BertTokenizer.from_pretrained("bert-base-cased")
next_sentence_predict = BertForNextSentencePrediction.from_pretrained("bert-base-cased").eval()
summarizer = pipeline("summarization", model="philschmid/bart-large-cnn-samsum")

def summarize_video(video_path, ratio_sum, playback_speed):
    """Summarize a video by keeping only sentences similar to an abstractive summary.

    video_path: path to the input video; ratio_sum: cosine-similarity threshold for
    keeping a sentence; playback_speed: speed factor applied to the final video.
    """
    print("Start summarizing video")
    output_path = os.path.join(os.path.dirname(video_path), 'output.mp4')
    movie_clip = VideoFileClip(video_path)
    audio_sampling_rate = movie_clip.audio.fps
    clip_audio = np.array(movie_clip.audio.to_soundarray())
    # Transcribe the speech with Whisper (expects mono float32 audio at 16 kHz)
    print("Start transcribing text")
    audio_fp32 = convert_sample_rate(clip_audio, audio_sampling_rate, 16000)
    audio_fp32 = two_chnnel_to_one_channel(audio_fp32).astype(np.float32)
    transcription_results = transcriber.transcribe(audio_fp32)
    # Group the transcribed segments into sentences, keeping the text and speech times per sentence
    print("Start summarizing text/speech time")
    periods = ('.', '!', '?')
    clip_sentences = []
    head_sentence = True
    for r in transcription_results['segments']:
        if head_sentence:
            start_time = r['start']
            clip_sentences.append({'sentence': '', 'sentences': [], 'duration': [r['start'], None], 'durations': []})
            head_sentence = False
        clip_sentences[-1]['sentence'] += r['text']
        clip_sentences[-1]['sentences'].append(r['text'])
        clip_sentences[-1]['durations'].append([r['start'], r['end']])
        if r['text'].endswith(periods):
            clip_sentences[-1]['duration'][1] = r['end']
            head_sentence = True
    # Abstractive summary of the full transcription
    print("Start summarizing sentences")
    transcription = transcription_results['text']
    summary_text = summarizer(transcription, max_length=int(len(transcription) * 0.1), min_length=int(len(transcription) * 0.05), do_sample=False)[0]['summary_text']
    print(summary_text)
    # Mark sentences that are semantically similar to a summary sentence
    print("Start identifying sentences that match the summary")
    summary_embeddings = [sentence_transformer.encode(s, convert_to_tensor=True) for s in summary_text.split('.')]
    important_sentence_idxs = [False] * len(clip_sentences)
    for s, clip_sentence in enumerate(clip_sentences):
        embedding = sentence_transformer.encode(clip_sentence['sentence'], convert_to_tensor=True)
        for s_e in summary_embeddings:
            if util.pytorch_cos_sim(embedding, s_e) > ratio_sum:
                important_sentence_idxs[s] = True
    # Mark sentences that the next-sentence-prediction model considers connected to the following sentence
    print("Start identifying sentences that are connected to the sentence next to it")

    def next_prob(prompt, next_sentence, b=1.2):
        # BERT NSP logits: index 0 = "is the next sentence", index 1 = "is not".
        # Using base b instead of e gives a softened softmax, so the probability is less peaky.
        encoding = tokenizer(prompt, next_sentence, return_tensors="pt")
        logits = next_sentence_predict(**encoding, labels=torch.LongTensor([1])).logits
        pos = b ** logits[0, 0]
        neg = b ** logits[0, 1]
        return float(pos / (pos + neg))

    connection_idxs = [False] * (len(clip_sentences) - 1)
    for s in range(len(clip_sentences) - 1):
        if next_prob(clip_sentences[s]['sentence'], clip_sentences[s + 1]['sentence']) > 0.88:
            connection_idxs[s] = True
    # Keep a sentence if it matches the summary itself or is connected to a kept sentence
    def combine_arrays(A, B):
        # A[i]: sentence i matches the summary; B[i]: sentence i connects to sentence i+1.
        # Propagate importance forwards and backwards along chains of connected sentences.
        C = copy.deepcopy(A)
        for i in range(len(A)):
            if A[i]:
                j = i
                while j < len(B) and B[j]:
                    C[j + 1] = True
                    j += 1
                j = i
                while j > 0 and B[j - 1]:
                    C[j - 1] = True
                    j -= 1
        return C

    important_idxs = combine_arrays(important_sentence_idxs, connection_idxs)
    # Build an HTML transcript that highlights the kept sentences
    html_text = "<h1 class='title'>Full Transcription</h1>"
    for idx in range(len(clip_sentences)):
        seconds = clip_sentences[idx]['duration'][0] * (1 / playback_speed)
        minutes, seconds = divmod(seconds, 60)
        if important_idxs[idx]:
            html_text += f"<p> <b> {int(minutes)}:{int(seconds):02} | {clip_sentences[idx]['sentence']} </b> </p>"
        else:
            html_text += f"<p> {int(minutes)}:{int(seconds):02} | {clip_sentences[idx]['sentence']} </p>"
    print(html_text)
    # Cut the kept sentences out of the source video, one clip per sentence
    print("Start combining clips")
    clips = []
    for i in range(len(important_idxs)):
        if important_idxs[i]:
            tmp_clips = []
            for j in range(len(clip_sentences[i]['sentences'])):
                start_time, end_time = clip_sentences[i]['durations'][j][0], clip_sentences[i]['durations'][j][1]
                if end_time > movie_clip.duration:
                    end_time = movie_clip.duration
                if start_time > movie_clip.duration:
                    continue
                clip = movie_clip.subclip(start_time, end_time)
                clip = clip.set_pos("center").set_duration(end_time - start_time)
                tmp_clips.append(clip)
            if tmp_clips:
                clips.append(concatenate_videoclips(tmp_clips))
    # (Optional) join the clips with a cross-dissolve instead of a hard cut
    # for c in range(len(clips) - 1):
    #     fade_duration = 2
    #     clips[c] = clips[c].crossfadeout(fade_duration).audio_fadeout(fade_duration)
    #     clips[c + 1] = clips[c + 1].crossfadein(fade_duration).audio_fadein(fade_duration)
    # Concatenate the kept clips and change the playback speed
    final_video = concatenate_videoclips(clips, method="chain")
    final_video_audio = np.array(final_video.audio.to_soundarray(fps=audio_sampling_rate))
    if playback_speed != 1:
        # Time-stretch the audio with WSOLA so the pitch is preserved; pytsmod expects
        # (channels, samples) while to_soundarray returns (samples, channels), hence the transposes.
        final_video_audio_fixed = tsm.wsola(final_video_audio.T, 1 / playback_speed).T
    else:
        final_video_audio_fixed = final_video_audio
    final_video = speedx(final_video, factor=playback_speed)
    final_video = final_video.set_audio(AudioArrayClip(final_video_audio_fixed, fps=audio_sampling_rate))
    # if final_video.duration > 30:
    #     final_video = final_video.subclip(0, 30)
    final_video.write_videofile(output_path)
    print(output_path)
    print("Finished summarizing video")
    return output_path, summary_text, html_text
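
# Minimal usage sketch, assuming the module is run directly; "sample.mp4" and the
# ratio_sum / playback_speed values below are illustrative placeholders, not values
# used by the app that imports this module.
if __name__ == "__main__":
    out_path, summary, html = summarize_video("sample.mp4", ratio_sum=0.5, playback_speed=1.5)
    print(out_path)
    print(summary)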