{ "cells": [ { "cell_type": "code", "execution_count": 1, "id": "3e1b0206-2912-4385-97b9-5948ed70dfc8", "metadata": {}, "outputs": [], "source": [ "import cv2\n", "import mediapipe as mp #face detector\n", "import math\n", "import numpy as np\n", "import time\n", "\n", "import warnings\n", "warnings.simplefilter(\"ignore\", UserWarning)\n", "\n", "import torch\n", "import torch.nn as nn\n", "import torch.nn.functional as F\n", "from PIL import Image\n", "from torchvision import transforms" ] }, { "cell_type": "markdown", "id": "a0907155", "metadata": {}, "source": [ "#### Model architectures" ] }, { "cell_type": "code", "execution_count": null, "id": "f67038e3", "metadata": {}, "outputs": [], "source": [ "class Bottleneck(nn.Module):\n", " expansion = 4\n", " def __init__(self, in_channels, out_channels, i_downsample=None, stride=1):\n", " super(Bottleneck, self).__init__()\n", " \n", " self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, padding=0, bias=False)\n", " self.batch_norm1 = nn.BatchNorm2d(out_channels, eps=0.001, momentum=0.99)\n", " \n", " self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding='same', bias=False)\n", " self.batch_norm2 = nn.BatchNorm2d(out_channels, eps=0.001, momentum=0.99)\n", " \n", " self.conv3 = nn.Conv2d(out_channels, out_channels*self.expansion, kernel_size=1, stride=1, padding=0, bias=False)\n", " self.batch_norm3 = nn.BatchNorm2d(out_channels*self.expansion, eps=0.001, momentum=0.99)\n", " \n", " self.i_downsample = i_downsample\n", " self.stride = stride\n", " self.relu = nn.ReLU()\n", " \n", " def forward(self, x):\n", " identity = x.clone()\n", " x = self.relu(self.batch_norm1(self.conv1(x)))\n", " \n", " x = self.relu(self.batch_norm2(self.conv2(x)))\n", " \n", " x = self.conv3(x)\n", " x = self.batch_norm3(x)\n", " \n", " #downsample if needed\n", " if self.i_downsample is not None:\n", " identity = self.i_downsample(identity)\n", " #add identity\n", " x+=identity\n", " x=self.relu(x)\n", " \n", " return x\n", "\n", "class Conv2dSame(torch.nn.Conv2d):\n", "\n", " def calc_same_pad(self, i: int, k: int, s: int, d: int) -> int:\n", " return max((math.ceil(i / s) - 1) * s + (k - 1) * d + 1 - i, 0)\n", "\n", " def forward(self, x: torch.Tensor) -> torch.Tensor:\n", " ih, iw = x.size()[-2:]\n", "\n", " pad_h = self.calc_same_pad(i=ih, k=self.kernel_size[0], s=self.stride[0], d=self.dilation[0])\n", " pad_w = self.calc_same_pad(i=iw, k=self.kernel_size[1], s=self.stride[1], d=self.dilation[1])\n", "\n", " if pad_h > 0 or pad_w > 0:\n", " x = F.pad(\n", " x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2]\n", " )\n", " return F.conv2d(\n", " x,\n", " self.weight,\n", " self.bias,\n", " self.stride,\n", " self.padding,\n", " self.dilation,\n", " self.groups,\n", " )\n", "\n", "class ResNet(nn.Module):\n", " def __init__(self, ResBlock, layer_list, num_classes, num_channels=3):\n", " super(ResNet, self).__init__()\n", " self.in_channels = 64\n", "\n", " self.conv_layer_s2_same = Conv2dSame(num_channels, 64, 7, stride=2, groups=1, bias=False)\n", " self.batch_norm1 = nn.BatchNorm2d(64, eps=0.001, momentum=0.99)\n", " self.relu = nn.ReLU()\n", " self.max_pool = nn.MaxPool2d(kernel_size = 3, stride=2)\n", " \n", " self.layer1 = self._make_layer(ResBlock, layer_list[0], planes=64, stride=1)\n", " self.layer2 = self._make_layer(ResBlock, layer_list[1], planes=128, stride=2)\n", " self.layer3 = self._make_layer(ResBlock, layer_list[2], planes=256, stride=2)\n", " self.layer4 = self._make_layer(ResBlock, layer_list[3], planes=512, stride=2)\n", " \n", " self.avgpool = nn.AdaptiveAvgPool2d((1,1))\n", " self.fc1 = nn.Linear(512*ResBlock.expansion, 512)\n", " self.relu1 = nn.ReLU()\n", " self.fc2 = nn.Linear(512, num_classes)\n", "\n", " def extract_features(self, x):\n", " x = self.relu(self.batch_norm1(self.conv_layer_s2_same(x)))\n", " x = self.max_pool(x)\n", " # print(x.shape)\n", " x = self.layer1(x)\n", " x = self.layer2(x)\n", " x = self.layer3(x)\n", " x = self.layer4(x)\n", " \n", " x = self.avgpool(x)\n", " x = x.reshape(x.shape[0], -1)\n", " x = self.fc1(x)\n", " return x\n", " \n", " def forward(self, x):\n", " x = self.extract_features(x)\n", " x = self.relu1(x)\n", " x = self.fc2(x)\n", " return x\n", " \n", " def _make_layer(self, ResBlock, blocks, planes, stride=1):\n", " ii_downsample = None\n", " layers = []\n", " \n", " if stride != 1 or self.in_channels != planes*ResBlock.expansion:\n", " ii_downsample = nn.Sequential(\n", " nn.Conv2d(self.in_channels, planes*ResBlock.expansion, kernel_size=1, stride=stride, bias=False, padding=0),\n", " nn.BatchNorm2d(planes*ResBlock.expansion, eps=0.001, momentum=0.99)\n", " )\n", " \n", " layers.append(ResBlock(self.in_channels, planes, i_downsample=ii_downsample, stride=stride))\n", " self.in_channels = planes*ResBlock.expansion\n", " \n", " for i in range(blocks-1):\n", " layers.append(ResBlock(self.in_channels, planes))\n", " \n", " return nn.Sequential(*layers)\n", " \n", "def ResNet50(num_classes, channels=3):\n", " return ResNet(Bottleneck, [3,4,6,3], num_classes, channels)\n", "\n", "\n", "class LSTMPyTorch(nn.Module):\n", " def __init__(self):\n", " super(LSTMPyTorch, self).__init__()\n", " \n", " self.lstm1 = nn.LSTM(input_size=512, hidden_size=512, batch_first=True, bidirectional=False)\n", " self.lstm2 = nn.LSTM(input_size=512, hidden_size=256, batch_first=True, bidirectional=False)\n", " self.fc = nn.Linear(256, 7)\n", " self.softmax = nn.Softmax(dim=1)\n", "\n", " def forward(self, x):\n", " x, _ = self.lstm1(x)\n", " x, _ = self.lstm2(x) \n", " x = self.fc(x[:, -1, :])\n", " x = self.softmax(x)\n", " return x" ] }, { "cell_type": "markdown", "id": "fcbcf9fa-a7cc-4d4c-b723-6d7efd49b94b", "metadata": {}, "source": [ "#### Sub functions" ] }, { "cell_type": "code", "execution_count": 2, "id": "6d0fc324-98a8-4efc-bb11-4bec8a015790", "metadata": {}, "outputs": [], "source": [ "def pth_processing(fp):\n", " class PreprocessInput(torch.nn.Module):\n", " def init(self):\n", " super(PreprocessInput, self).init()\n", "\n", " def forward(self, x):\n", " x = x.to(torch.float32)\n", " x = torch.flip(x, dims=(0,))\n", " x[0, :, :] -= 91.4953\n", " x[1, :, :] -= 103.8827\n", " x[2, :, :] -= 131.0912\n", " return x\n", "\n", " def get_img_torch(img):\n", " \n", " ttransform = transforms.Compose([\n", " transforms.PILToTensor(),\n", " PreprocessInput()\n", " ])\n", " img = img.resize((224, 224), Image.Resampling.NEAREST)\n", " img = ttransform(img)\n", " img = torch.unsqueeze(img, 0)\n", " return img\n", " return get_img_torch(fp)\n", "\n", "def tf_processing(fp):\n", " def preprocess_input(x):\n", " x_temp = np.copy(x)\n", " x_temp = x_temp[..., ::-1]\n", " x_temp[..., 0] -= 91.4953\n", " x_temp[..., 1] -= 103.8827\n", " x_temp[..., 2] -= 131.0912\n", " return x_temp\n", "\n", " def get_img_tf(img):\n", " img = cv2.resize(img, (224,224), interpolation=cv2.INTER_NEAREST)\n", " img = tf.keras.utils.img_to_array(img)\n", " img = preprocess_input(img)\n", " img = np.array([img])\n", " return img\n", "\n", " return get_img_tf(fp)\n", "\n", "def norm_coordinates(normalized_x, normalized_y, image_width, image_height):\n", " \n", " x_px = min(math.floor(normalized_x * image_width), image_width - 1)\n", " y_px = min(math.floor(normalized_y * image_height), image_height - 1)\n", " \n", " return x_px, y_px\n", "\n", "def get_box(fl, w, h):\n", " idx_to_coors = {}\n", " for idx, landmark in enumerate(fl.landmark):\n", " landmark_px = norm_coordinates(landmark.x, landmark.y, w, h)\n", "\n", " if landmark_px:\n", " idx_to_coors[idx] = landmark_px\n", "\n", " x_min = np.min(np.asarray(list(idx_to_coors.values()))[:,0])\n", " y_min = np.min(np.asarray(list(idx_to_coors.values()))[:,1])\n", " endX = np.max(np.asarray(list(idx_to_coors.values()))[:,0])\n", " endY = np.max(np.asarray(list(idx_to_coors.values()))[:,1])\n", "\n", " (startX, startY) = (max(0, x_min), max(0, y_min))\n", " (endX, endY) = (min(w - 1, endX), min(h - 1, endY))\n", " \n", " return startX, startY, endX, endY\n", "\n", "def display_EMO_PRED(img, box, label='', color=(128, 128, 128), txt_color=(255, 255, 255), line_width=2, ):\n", " lw = line_width or max(round(sum(img.shape) / 2 * 0.003), 2)\n", " text2_color = (255, 0, 255)\n", " p1, p2 = (int(box[0]), int(box[1])), (int(box[2]), int(box[3]))\n", " cv2.rectangle(img, p1, p2, text2_color, thickness=lw, lineType=cv2.LINE_AA)\n", " font = cv2.FONT_HERSHEY_SIMPLEX\n", "\n", " tf = max(lw - 1, 1)\n", " text_fond = (0, 0, 0)\n", " text_width_2, text_height_2 = cv2.getTextSize(label, font, lw / 3, tf)\n", " text_width_2 = text_width_2[0] + round(((p2[0] - p1[0]) * 10) / 360)\n", " center_face = p1[0] + round((p2[0] - p1[0]) / 2)\n", "\n", " cv2.putText(img, label,\n", " (center_face - round(text_width_2 / 2), p1[1] - round(((p2[0] - p1[0]) * 20) / 360)), font,\n", " lw / 3, text_fond, thickness=tf, lineType=cv2.LINE_AA)\n", " cv2.putText(img, label,\n", " (center_face - round(text_width_2 / 2), p1[1] - round(((p2[0] - p1[0]) * 20) / 360)), font,\n", " lw / 3, text2_color, thickness=tf, lineType=cv2.LINE_AA)\n", " return img\n", "\n", "def display_FPS(img, text, margin=1.0, box_scale=1.0):\n", " img_h, img_w, _ = img.shape\n", " line_width = int(min(img_h, img_w) * 0.001) # line width\n", " thickness = max(int(line_width / 3), 1) # font thickness\n", "\n", " font_face = cv2.FONT_HERSHEY_SIMPLEX\n", " font_color = (0, 0, 0)\n", " font_scale = thickness / 1.5\n", "\n", " t_w, t_h = cv2.getTextSize(text, font_face, font_scale, None)[0]\n", "\n", " margin_n = int(t_h * margin)\n", " sub_img = img[0 + margin_n: 0 + margin_n + t_h + int(2 * t_h * box_scale),\n", " img_w - t_w - margin_n - int(2 * t_h * box_scale): img_w - margin_n]\n", "\n", " white_rect = np.ones(sub_img.shape, dtype=np.uint8) * 255\n", "\n", " img[0 + margin_n: 0 + margin_n + t_h + int(2 * t_h * box_scale),\n", " img_w - t_w - margin_n - int(2 * t_h * box_scale):img_w - margin_n] = cv2.addWeighted(sub_img, 0.5, white_rect, .5,\n", " 1.0)\n", "\n", " cv2.putText(img=img,\n", " text=text,\n", " org=(img_w - t_w - margin_n - int(2 * t_h * box_scale) // 2,\n", " 0 + margin_n + t_h + int(2 * t_h * box_scale) // 2),\n", " fontFace=font_face,\n", " fontScale=font_scale,\n", " color=font_color,\n", " thickness=thickness,\n", " lineType=cv2.LINE_AA,\n", " bottomLeftOrigin=False)\n", "\n", " return img" ] }, { "cell_type": "markdown", "id": "bae915fd-cc3d-4dc1-83fc-c9c32e1b12a8", "metadata": {}, "source": [ "#### Testing models by webcam" ] }, { "cell_type": "code", "execution_count": 8, "id": "c05ed967-a30e-47f5-96ed-b32bab0c6879", "metadata": {}, "outputs": [], "source": [ "mp_face_mesh = mp.solutions.face_mesh\n", "\n", "name_backbone_model = 'FER_static_ResNet50_AffectNet.pt'\n", "# name_LSTM_model = 'IEMOCAP'\n", "# name_LSTM_model = 'CREMA-D'\n", "# name_LSTM_model = 'RAMAS'\n", "# name_LSTM_model = 'RAVDESS'\n", "# name_LSTM_model = 'SAVEE'\n", "name_LSTM_model = 'Aff-Wild2'\n", "\n", "# torch\n", "\n", "pth_backbone_model = ResNet50(7, channels=3)\n", "pth_backbone_model.load_state_dict(torch.load(name_backbone_model))\n", "pth_backbone_model.eval()\n", "\n", "pth_LSTM_model = LSTMPyTorch()\n", "pth_LSTM_model.load_state_dict(torch.load('FER_dinamic_LSTM_{0}.pt'.format(name_LSTM_model)))\n", "pth_LSTM_model.eval()\n", "\n", "\n", "DICT_EMO = {0: 'Neutral', 1: 'Happiness', 2: 'Sadness', 3: 'Surprise', 4: 'Fear', 5: 'Disgust', 6: 'Anger'}\n", "\n", "cap = cv2.VideoCapture(0)\n", "w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))\n", "h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))\n", "fps = np.round(cap.get(cv2.CAP_PROP_FPS))\n", "\n", "path_save_video = 'result.mp4'\n", "vid_writer = cv2.VideoWriter(path_save_video, cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h))\n", "\n", "lstm_features = []\n", " \n", "with mp_face_mesh.FaceMesh(\n", "max_num_faces=1,\n", "refine_landmarks=False,\n", "min_detection_confidence=0.5,\n", "min_tracking_confidence=0.5) as face_mesh:\n", "\n", " while cap.isOpened():\n", " t1 = time.time()\n", " success, frame = cap.read()\n", " if frame is None: break\n", "\n", " frame_copy = frame.copy()\n", " frame_copy.flags.writeable = False\n", " frame_copy = cv2.cvtColor(frame_copy, cv2.COLOR_BGR2RGB)\n", " results = face_mesh.process(frame_copy)\n", " frame_copy.flags.writeable = True\n", "\n", " if results.multi_face_landmarks:\n", " for fl in results.multi_face_landmarks:\n", " startX, startY, endX, endY = get_box(fl, w, h)\n", " cur_face = frame_copy[startY:endY, startX: endX]\n", " \n", " cur_face = pth_processing(Image.fromarray(cur_face))\n", " features = torch.nn.functional.relu(pth_backbone_model.extract_features(cur_face)).detach().numpy()\n", "\n", " if len(lstm_features) == 0:\n", " lstm_features = [features]*10\n", " else:\n", " lstm_features = lstm_features[1:] + [features]\n", "\n", " lstm_f = torch.from_numpy(np.vstack(lstm_features))\n", " lstm_f = torch.unsqueeze(lstm_f, 0)\n", " output = pth_LSTM_model(lstm_f).detach().numpy()\n", " \n", " cl = np.argmax(output)\n", " label = DICT_EMO[cl]\n", " frame = display_EMO_PRED(frame, (startX, startY, endX, endY), label+' {0:.1%}'.format(output[0][cl]), line_width=3)\n", "\n", " t2 = time.time()\n", "\n", " frame = display_FPS(frame, 'FPS: {0:.1f}'.format(1 / (t2 - t1)), box_scale=.5)\n", "\n", " vid_writer.write(frame)\n", " \n", " cv2.imshow('Webcam', frame)\n", " if cv2.waitKey(1) & 0xFF == ord('q'):\n", " break\n", "\n", " vid_writer.release()\n", " cap.release()\n", " cv2.destroyAllWindows()" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.13" } }, "nbformat": 4, "nbformat_minor": 5 }