# lipnet/model.py
# (Originally uploaded to the Hugging Face Hub by thienphuc12339, commit 6d5d850.)
import logging
from typing import Iterable, Optional
import tensorflow as tf
from . import config
from .preprocessing import VideoPreprocessor
logger = logging.getLogger(__name__)
def _configure_tensorflow() -> None:
    """Quiet TensorFlow's logger and enable GPU memory growth.

    Best-effort: every failure is swallowed and reported at debug level,
    so an unusual TensorFlow install never prevents the module from loading.
    """
    try:
        tf.get_logger().setLevel(logging.ERROR)
        # Grow GPU memory on demand instead of grabbing it all up front.
        for device in tf.config.list_physical_devices("GPU"):
            tf.config.experimental.set_memory_growth(device, True)
    except Exception as exc:  # environment-dependent; never fatal
        logger.debug("TensorFlow runtime configuration skipped: %s", exc)
class LipReadingModel:
    """CTC-based lip-reading model wrapper.

    Holds character <-> index ``StringLookup`` layers and a Keras model.
    The model is loaded from ``model_path``; if loading fails for any
    reason, an untrained fallback network with the expected architecture
    is built instead so the object remains usable.
    """

    def __init__(self, model_path: str = str(config.MODEL_PATH)):
        # Initialize character mappings before loading the model.
        # NOTE(review): the vocabulary string below contains mojibake-looking
        # characters (presumably mis-encoded Vietnamese diacritics — confirm
        # against the training pipeline). It must stay byte-identical: the
        # trained checkpoint's output layer indexes into exactly this
        # character ordering.
        vocab_chars = (
            "aa\u0192bcdde\u02c6ghiklmno\u201copqrstuuvxy\u00a0\u2026?"
            "a???????????\u201a\u0160????????\u00a1\u008d?i?\u00a2\u2022?o???????????\u00a3\u2014?u??????y????'?!123456789 "
        )
        # De-duplicate while preserving first-seen order: StringLookup
        # requires unique vocabulary entries, and order fixes the indices.
        vocab = []
        seen = set()
        for ch in vocab_chars:
            if ch not in seen:
                seen.add(ch)
                vocab.append(ch)
        # Forward (char -> int) and inverse (int -> char) mappings; the
        # inverse is used when decoding CTC output back into text.
        self.char_to_num = tf.keras.layers.StringLookup(vocabulary=vocab, oov_token="")
        self.num_to_char = tf.keras.layers.StringLookup(
            vocabulary=self.char_to_num.get_vocabulary(), oov_token="", invert=True
        )
        _configure_tensorflow()
        try:
            # CTCLoss must be registered as a custom object so Keras can
            # deserialize a model that was compiled with it.
            self.model = tf.keras.models.load_model(
                model_path,
                custom_objects={"CTCLoss": self.CTCLoss},
            )
            logger.info("Model loaded successfully from %s", model_path)
        except Exception as exc:
            # Fall back to a freshly built, untrained network: predictions
            # will be meaningless, but the API surface stays intact.
            logger.error("Error loading model from %s: %s", model_path, exc)
            self.model = self.build_model()

    @staticmethod
    def CTCLoss(y_true, y_pred):
        """Batch CTC loss used at training/compile time.

        Uses the full padded widths of both tensors as the per-sample
        lengths (dim 1 of ``y_true`` as label length, dim 1 of ``y_pred``
        as input length) — assumes labels and predictions are padded
        consistently across the batch; TODO confirm with the training code.
        """
        batch_len = tf.cast(tf.shape(y_true)[0], dtype="int64")
        input_length = tf.cast(tf.shape(y_pred)[1], dtype="int64")
        label_length = tf.cast(tf.shape(y_true)[1], dtype="int64")
        # Broadcast the scalar lengths into per-sample (batch, 1) tensors,
        # the shape ctc_batch_cost expects.
        input_length = input_length * tf.ones(shape=(batch_len, 1), dtype="int64")
        label_length = label_length * tf.ones(shape=(batch_len, 1), dtype="int64")
        return tf.keras.backend.ctc_batch_cost(y_true, y_pred, input_length, label_length)

    def build_model(self):
        """Build the untrained fallback architecture.

        Input shape is (frames, TARGET_SIZE, TARGET_SIZE, 1): variable-length
        grayscale clips. Convolutions and pooling stride only over the two
        spatial axes (time stride/pool is 1), so the time dimension reaching
        the LSTMs matches the frame count — required for CTC decoding.
        Output is a per-timestep softmax over vocabulary_size() + 1 classes
        (the extra class is the CTC blank).
        """
        model = tf.keras.Sequential()
        model.add(tf.keras.layers.Conv3D(64, (3, 3, 3), strides=(1, 2, 2), input_shape=(None, config.TARGET_SIZE, config.TARGET_SIZE, 1), padding="same"))
        model.add(tf.keras.layers.BatchNormalization())
        model.add(tf.keras.layers.Activation("relu"))
        model.add(tf.keras.layers.MaxPool3D((1, 2, 2), padding="same"))
        model.add(tf.keras.layers.Conv3D(128, (3, 3, 3), strides=(1, 2, 2), padding="same"))
        model.add(tf.keras.layers.BatchNormalization())
        model.add(tf.keras.layers.Activation("relu"))
        model.add(tf.keras.layers.MaxPool3D((1, 2, 2), padding="same"))
        model.add(tf.keras.layers.Conv3D(256, (3, 3, 3), strides=(1, 2, 2), padding="same"))
        # NOTE(review): this stage uses LayerNormalization while its siblings
        # use BatchNormalization — looks intentional but worth confirming.
        model.add(tf.keras.layers.LayerNormalization())
        model.add(tf.keras.layers.Activation("relu"))
        model.add(tf.keras.layers.MaxPool3D((1, 2, 2), padding="same"))
        model.add(tf.keras.layers.Conv3D(256, (3, 3, 3), padding="same"))
        model.add(tf.keras.layers.BatchNormalization())
        model.add(tf.keras.layers.Activation("relu"))
        model.add(tf.keras.layers.MaxPool3D((1, 2, 2), padding="same"))
        # Collapse each frame's spatial feature map into a flat vector.
        model.add(tf.keras.layers.TimeDistributed(tf.keras.layers.Flatten()))
        model.add(tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(512, kernel_initializer="Orthogonal", return_sequences=True)))
        model.add(tf.keras.layers.Dropout(0.4))
        model.add(tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(256, kernel_initializer="Orthogonal", return_sequences=True)))
        model.add(tf.keras.layers.Dropout(0.4))
        # +1 output class for the CTC blank token.
        model.add(tf.keras.layers.Dense(self.char_to_num.vocabulary_size() + 1, kernel_initializer="he_normal", activation="softmax"))
        logger.info("Built the fallback model architecture.")
        return model

    def predict(self, normalized_frames: Optional[tf.Tensor]) -> str:
        """Decode one preprocessed clip into text.

        ``normalized_frames`` is the tensor produced by the preprocessor
        (presumably (frames, H, W, 1) — TODO confirm against
        VideoPreprocessor). Returns the decoded string, or a "?"-prefixed
        human-readable error message; this method never raises.
        """
        if self.model is None:
            return "? Model not loaded. Please check the model path and ensure the model file is accessible."
        if normalized_frames is None:
            return "? No frames extracted from the video. Please ensure the video contains a clear view of the face and lips."
        if int(tf.size(normalized_frames)) == 0:
            return "? No frames extracted from the video. Please ensure the video contains a clear view of the face and lips."
        try:
            # Add a batch dimension: the model expects (batch, frames, H, W, C).
            frames = tf.expand_dims(normalized_frames, axis=0)
            yhat = self.model.predict(frames, verbose=0)
            # Greedy CTC decode over the full time axis of the prediction.
            input_length = [yhat.shape[1]]
            decoded_tf = tf.keras.backend.ctc_decode(yhat, input_length=input_length, greedy=True)[0][0]
            decoded = decoded_tf.numpy().flatten()
            # -1 entries are padding emitted by ctc_decode; skip them and
            # map the remaining indices back to characters.
            prediction = "".join(
                [
                    self.num_to_char(int(num)).numpy().decode("utf-8")
                    for num in decoded
                    if int(num) != -1
                ]
            )
            return prediction.strip()
        except Exception as exc:
            logger.error("Error during prediction: %s", exc)
            return f"? An error occurred during prediction: {exc}"
def predict_from_video(
    video_path: Optional[str] = None,
    frames: Optional[Iterable] = None,
    model: Optional[LipReadingModel] = None,
    preprocessor: Optional[VideoPreprocessor] = None,
):
    """Run lip-reading inference on a video file or an iterable of frames.

    Supply either ``video_path`` or ``frames`` (the path wins when both are
    given). A model and preprocessor are constructed on demand when the
    caller does not provide them. Returns the predicted text, or a
    "?"-prefixed error message when input is missing or unusable.
    """
    # Build collaborators lazily so callers may omit them.
    model = LipReadingModel() if model is None else model
    preprocessor = VideoPreprocessor() if preprocessor is None else preprocessor

    if video_path:
        extracted = preprocessor.preprocess_video(video_path)
    elif frames is not None:
        extracted = preprocessor.preprocess_frames(frames)
    else:
        return "? No video or frames provided for prediction."

    if extracted is None:
        return "? Unable to extract frames from the provided video."
    return model.predict(extracted)