fastperson / model.py
kwmr's picture
upload
7d76a49
import subprocess
import numpy as np
import pytsmod as tsm
from firebase_admin import firestore
from moviepy.audio.AudioClip import AudioArrayClip
from moviepy.editor import *
from moviepy.video.fx.speedx import speedx
from sentence_transformers import SentenceTransformer, util
from transformers import pipeline, BertTokenizer, BertForNextSentencePrediction
import torch
import whisper
from utils import two_chnnel_to_one_channel, convert_sample_rate, log_firestore
subprocess.run(['apt-get', '-y', 'install', 'imagemagick'])
# 音声認識モデル
transcriber = whisper.load_model("medium")
# 文章の埋め込みを生成する文章の埋め込みをモデル
sentence_transformer = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
# BERTのTokenizer
tokenizer = BertTokenizer.from_pretrained("bert-base-cased")
# 2つの文が連続しているかどうかを判定するモデル
next_sentence_predict = BertForNextSentencePrediction.from_pretrained("bert-base-cased").eval()
# 文章の要約モデル
summarizer = pipeline("summarization", model="philschmid/bart-large-cnn-samsum")
def log_firestore(db, local_id="000000", action="test", sim_thr="", playback_speed=""):
data = {
"local_id": local_id,
"timestamp": firestore.SERVER_TIMESTAMP,
"action": action,
"sim_thr": sim_thr,
"playback_speed": playback_speed
}
doc_ref = db.collection("exp.1").document()
doc_ref.set(data)
print(f"Download Video: {doc_ref.id}")
def summarize_video(video_path, sim_thr, playback_speed, local_id, db):
"""
動画要約
Parameters:
video_path (str): 動画のファイルパス
sim_thr (float): 要約文との一致度合いの閾値
playback_speed (float): 再生速度
Returns:
output_path (str): 出力動画のファイルパス
summary_text (str): 要約された文章
full_textt (str): 元の文章(要約で抽出されたところを強調)
"""
print("Start summarize video")
## 動画の保存パスを設定
output_path = os.path.join(os.path.dirname(video_path), 'output.mp4')
## 動画クリップの作成
movie_clip = VideoFileClip(video_path)
## オーディオのサンプリングレートを取得
audio_sampling_rate = movie_clip.audio.fps
## オーディオをnumpy配列に変換
clip_audio = np.array(movie_clip.audio.to_soundarray())
# 文字の書き起こし
print("Start transcribing text")
## サンプリングレートを変更
audio_fp32 = convert_sample_rate(clip_audio, audio_sampling_rate, 16000)
audio_fp32 = two_chnnel_to_one_channel(audio_fp32).astype(np.float32)
## 文字起こしの結果を取得
transcription_results = transcriber.transcribe(audio_fp32)
# 文の句切れごとにテキスト/発話時間をまとめる
print("Start summarizing text/speech time")
## 句読点を指定
periods = ('.', '!', '?')
## センテンスごとのテキストと時間を格納するリストを初期化
clip_sentences = []
## 先頭の文かどうかのフラグを初期化
head_sentence = True
## センテンスごとのテキストと時間を格納
for r in transcription_results['segments']:
if head_sentence:
start_time = r['start']
clip_sentences.append({'sentence':'', 'sentences':[], 'duration':[r['start'], None], 'durations':[]})
head_sentence = False
clip_sentences[-1]['sentence'] += r['text']
clip_sentences[-1]['sentences'].append(r['text'])
clip_sentences[-1]['durations'].append([r['start'], r['end']])
if r['text'].endswith(periods):
clip_sentences[-1]['duration'][1] = r['end']
head_sentence = True
# 文章の要約
print("Start summarizing sentences")
## 文字起こしの結果を取得
transcription = transcription_results['text']
## 文字の要約を生成
summary_text = summarizer(transcription, max_length=int(len(transcription)*0.1), min_length=int(len(transcription)*0.05), do_sample=False)[0]['summary_text']
## 要約された文章を出力
print(summary_text)
# 要約文と一致する文を判別
print("Start deleting sentences that match the summary sentence")
## 要約文の各文の埋め込みを生成
summary_embedings = [sentence_transformer.encode(s, convert_to_tensor=True) for s in summary_text.split('.')]
## 重要な文のインデックスを格納するリストを初期化
important_sentence_idxs = [False]*len(clip_sentences)
## 文の埋め込みを生成して、要約文との一致が閾値以上であれば重要文としてマークする
for s, clip_sentence in enumerate(clip_sentences):
embedding = sentence_transformer.encode(clip_sentence['sentence'], convert_to_tensor=True)
for s_e in summary_embedings:
if util.pytorch_cos_sim(embedding, s_e) > sim_thr:
important_sentence_idxs[s] = True
# となりの文と接続する文を判別
print("Start identifying sentences that are connected to the sentence next to it")
def next_prob(prompt, next_sentence, b=1.2):
encoding = tokenizer(prompt, next_sentence, return_tensors="pt")
logits = next_sentence_predict(**encoding, labels=torch.LongTensor([1])).logits
pos = b ** logits[0, 0]
neg = b ** logits[0, 1]
return float(pos / (pos + neg))
## 文が接続しているかどうかのフラグを格納するリストを初期化
connection_idxs = [False]*(len(clip_sentences)-1)
## 2つの文が連続しているかどうかを判定して、接続している場合はフラグをTrueにする
for s in range(len(clip_sentences)-1):
if next_prob(clip_sentences[s]['sentence'], clip_sentences[s+1]['sentence']) > 0.88:
connection_idxs[s] = True
# 要約後の文章のみ残す
def get_important_sentences(important_sentence_idxs, connection_idxs):
"""
重要な文のインデックスリストを返す
Parameters:
important_sentence_idxs (List[bool]): 要約文と一致する文のリスト
connection_idxs (List[bool]): となりの文と接続する文かどうかの判定のリスト
Returns:
important_idxs (List[bool]): 重要な文のリスト
"""
for i, val in enumerate(important_sentence_idxs):
if val:
# 右側の要素を確認して更新する
j = i
while j < len(connection_idxs) and connection_idxs[j]:
important_sentence_idxs[j + 1] = True
j += 1
# 左側の要素を確認して更新する
j = i - 1
while j >= 0 and connection_idxs[j]:
important_sentence_idxs[j] = True
j -= 1
important_idxs = important_sentence_idxs
return important_idxs
important_idxs = get_important_sentences(important_sentence_idxs, connection_idxs)
# 要約後の文章が元の文章のどこを抽出したのかを可視化
full_textt = "<h1 class='title'>Full Transcription</h1>"
## 重要な文であれば太字に、そうでなければ通常のフォントでHTML表現のテキストを生成
for idx in range(len(important_sentence_idxs)):
seconds = clip_sentences[idx]['duration'][0] * (1/playback_speed)
minutes, seconds = divmod(seconds, 60)
if important_idxs[idx]:
full_textt += '<p> <b>' + f"{int(minutes)}:{int(seconds):02} | {clip_sentences[idx]['sentence']} </b> </p>"
else:
full_textt += f"{int(minutes)}:{int(seconds):02} | {clip_sentences[idx]['sentence']}</p>"
print(full_textt)
# 動画を結合
print("Start combine movies")
clips = []
## 重要文であれば、その文の開始時間と終了時間からクリップを生成してリストに格納
for i in range(len(important_idxs)):
if important_idxs[i]:
tmp_clips = []
for j in range(len(clip_sentences[i]['sentences'])):
start_time, end_time = clip_sentences[i]['durations'][j][0], clip_sentences[i]['durations'][j][1]
if end_time > movie_clip.duration:
end_time = movie_clip.duration
if start_time > movie_clip.duration:
continue
clip = movie_clip.subclip(start_time, end_time)
clip = clip.set_pos("center").set_duration(end_time-start_time)
tmp_clips.append(clip)
clips.append(concatenate_videoclips(tmp_clips))
# クリップをクロスディゾルブで結合
# for c in range(len(clips)-1):
# fade_duration = 2
# clips[c] = clips[c].crossfadeout(fade_duration).audio_fadeout(fade_duration)
# clips[c+1] = clips[c+1].crossfadein(fade_duration).audio_fadein(fade_duration)
# 動画を結合し再生速度を変化させる
## クリップを連結する
final_video = concatenate_videoclips(clips, method="chain")
## オーディオをnumpy配列に変換
final_video_audio = np.array(final_video.audio.to_soundarray(fps=audio_sampling_rate))
## 再生速度を変更する
if playback_speed != 1:
final_video_audio_fixed = tsm.wsola(final_video_audio, 1/playback_speed).T
else:
final_video_audio_fixed = final_video_audio
## 動画の再生速度を変更し、オーディオを設定する
final_video = speedx(final_video, factor=playback_speed)
final_video = final_video.set_audio(AudioArrayClip(final_video_audio_fixed, fps=audio_sampling_rate))
# if final_video.duration > 30:
# final_video = final_video.subclip(0, 30)
## 動画をファイルに書き込む
final_video.write_videofile(output_path)
print(output_path)
print("Success summarize video")
log_firestore(db, local_id=str(local_id), action='SV', sim_thr=str(sim_thr), playback_speed=str(playback_speed))
return output_path, summary_text, full_textt