JefferyJapheth commited on
Commit
6ca127a
1 Parent(s): bd8ea4e

new prediction logic

Browse files
Files changed (1) hide show
  1. app.py +358 -0
app.py CHANGED
@@ -370,3 +370,361 @@ with mp_holistic.Holistic(min_detection_confidence=0.5, min_tracking_confidence=
370
  cap.release()
371
  cv2.destroyAllWindows()
372
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
370
  cap.release()
371
  cv2.destroyAllWindows()
372
 
373
+ video_path = './Test/HAPPY.mp4'
374
+
375
+ cap = cv2.VideoCapture(video_path)
376
+
377
+ mp_drawing = mp.solutions.drawing_utils
378
+ mp_face_mesh = mp.solutions.face_mesh
379
+ mp_hands = mp.solutions.hands
380
+ mp_pose = mp.solutions.pose
381
+
382
+ data_list = []
383
+ ROWS_PER_FRAME = 543 # Constant number of landmarks per frame
384
+
385
+ with mp_face_mesh.FaceMesh(static_image_mode=False, max_num_faces=1) as face_mesh, \
386
+ mp_hands.Hands(static_image_mode=False, max_num_hands=2) as hands, \
387
+ mp_pose.Pose(static_image_mode=False) as pose:
388
+
389
+ frame_number = 0
390
+ while cap.isOpened():
391
+ ret, image = cap.read()
392
+ if not ret:
393
+ break
394
+
395
+ # Convert the BGR image to RGB for Mediapipe
396
+ image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
397
+
398
+ # Process face landmarks
399
+ results_face = face_mesh.process(image_rgb)
400
+ if results_face.multi_face_landmarks:
401
+ face_landmarks = results_face.multi_face_landmarks[0]
402
+ for idx, landmark in enumerate(face_landmarks.landmark):
403
+ data_list.append([frame_number, f"{frame_number}-face-{idx}", "face", idx, landmark.x, landmark.y, landmark.z])
404
+
405
+ # Process hand landmarks
406
+ results_hands = hands.process(image_rgb)
407
+ if results_hands.multi_hand_landmarks:
408
+ for hand_landmarks in results_hands.multi_hand_landmarks:
409
+ for idx, landmark in enumerate(hand_landmarks.landmark):
410
+ data_list.append([frame_number, f"{frame_number}-right_hand-{idx}", "right-hand", idx, landmark.x, landmark.y, landmark.z])
411
+ mp_drawing.draw_landmarks(image, hand_landmarks, mp_hands.HAND_CONNECTIONS)
412
+
413
+ # Process pose landmarks
414
+ results_pose = pose.process(image_rgb)
415
+ if results_pose.pose_landmarks:
416
+ pose_landmarks = results_pose.pose_landmarks.landmark
417
+ for idx, landmark in enumerate(pose_landmarks):
418
+ data_list.append([frame_number, f"{frame_number}-pose-{idx}", "pose", idx, landmark.x, landmark.y, landmark.z])
419
+
420
+ # Pad the landmarks with NaN values if the number of landmarks is less than ROWS_PER_FRAME
421
+ while len(data_list) < (frame_number + 1) * ROWS_PER_FRAME:
422
+ data_list.append([frame_number, f"{frame_number}-right_hand-{len(data_list) % ROWS_PER_FRAME}", "right-hand", len(data_list) % ROWS_PER_FRAME, np.nan, np.nan, np.nan])
423
+
424
+ # Draw the landmarks on the frame (optional)
425
+ mp_drawing.draw_landmarks(image, face_landmarks, mp_face_mesh.FACEMESH_CONTOURS)
426
+ mp_drawing.draw_landmarks(image, results_pose.pose_landmarks, mp_pose.POSE_CONNECTIONS)
427
+
428
+ # Display the frame (optional)
429
+ cv2.imshow('MediaPipe', image)
430
+ frame_number += 1
431
+
432
+ # Press 'q' to quit
433
+ if cv2.waitKey(1) & 0xFF == ord('q'):
434
+ break
435
+
436
+ cap.release()
437
+ cv2.destroyAllWindows()
438
+
439
+ df = pd.DataFrame(data_list, columns=["frame", "row_id", "type", "landmark_index", "x", "y", "z"])
440
+ df.to_parquet("extracted_features.parquet", index=False)
441
+
442
+ # test_data = pd.read_parquet('./1006440534.parquet')
443
+ # test_data_kaggle = pd.read_parquet('1001373962.parquet')
444
+ # test_data_kaggle2 = pd.read_parquet('./100015657.parquet')
445
+ # test_data_kaggle3 = pd.read_parquet('./1003700302.parquet')
446
+ # test_data_kaggle4 = pd.read_parquet('./1007127288.parquet')
447
+ test_data_my_own = pd.read_parquet('extracted_features.parquet')
448
+ test_data_my_own['frame'] = test_data_my_own['frame'].astype('int16')
449
+ test_data_my_own['landmark_index'] = test_data_my_own['landmark_index'].astype('int16')
450
+
451
+
452
+
453
+ def load_relevant_data_subset(pq_path, ROWS_PER_FRAME = 543):
454
+ data_columns = ['x', 'y', 'z']
455
+ data = pd.read_parquet(pq_path, columns=data_columns)
456
+ n_frames = int( len(data) / ROWS_PER_FRAME)
457
+ print(f"Data: {len(data)} Number of Frames: {n_frames}")
458
+ data = data.values.reshape(n_frames, ROWS_PER_FRAME, len(data_columns))
459
+ return data.astype(np.float32)
460
+
461
+
462
+
463
+ # demo_raw_data = load_relevant_data_subset('./1006440534.parquet')
464
+ demo_raw_data = load_relevant_data_subset('./extracted_features.parquet')
465
+ # demo_raw_data = load_relevant_data_subset('./1003700302.parquet', test_data_kaggle3['frame'].nunique())
466
+ # demo_raw_data = load_relevant_data_subset('./extracted_features.parquet')
467
+
468
+ ORD2SIGN = {206: 'sticky',
469
+ 20: 'before',
470
+ 178: 'pretty',
471
+ 114: 'hen',
472
+ 221: 'tomorrow',
473
+ 230: 'up',
474
+ 25: 'blow',
475
+ 236: 'weus',
476
+ 184: 'read',
477
+ 191: 'say',
478
+ 248: 'zebra',
479
+ 189: 'sad',
480
+ 62: 'drawer',
481
+ 5: 'animal',
482
+ 167: 'pen',
483
+ 60: 'donkey',
484
+ 41: 'cheek',
485
+ 51: 'cowboy',
486
+ 192: 'scissors',
487
+ 181: 'quiet',
488
+ 63: 'drink',
489
+ 94: 'girl',
490
+ 200: 'sleepy',
491
+ 249: 'zipper',
492
+ 171: 'pig',
493
+ 13: 'bad',
494
+ 9: 'arm',
495
+ 61: 'down',
496
+ 123: 'if',
497
+ 240: 'why',
498
+ 166: 'pajamas',
499
+ 203: 'snow',
500
+ 137: 'loud',
501
+ 195: 'shirt',
502
+ 31: 'brown',
503
+ 146: 'moon',
504
+ 23: 'bird',
505
+ 210: 'sun',
506
+ 76: 'fast',
507
+ 1: 'after',
508
+ 54: 'cute',
509
+ 77: 'feet',
510
+ 4: 'alligator',
511
+ 87: 'food',
512
+ 113: 'hello',
513
+ 93: 'giraffe',
514
+ 180: 'puzzle',
515
+ 211: 'table',
516
+ 132: 'like',
517
+ 153: 'no',
518
+ 122: 'icecream',
519
+ 67: 'duck',
520
+ 69: 'elephant',
521
+ 141: 'many',
522
+ 18: 'bedroom',
523
+ 205: 'stay',
524
+ 74: 'fall',
525
+ 246: 'yourself',
526
+ 183: 'rain',
527
+ 135: 'listen',
528
+ 44: 'chocolate',
529
+ 124: 'into',
530
+ 11: 'awake',
531
+ 40: 'chair',
532
+ 7: 'any',
533
+ 155: 'nose',
534
+ 118: 'home',
535
+ 161: 'open',
536
+ 58: 'dog',
537
+ 50: 'cow',
538
+ 241: 'will',
539
+ 149: 'mouth',
540
+ 177: 'pretend',
541
+ 172: 'pizza',
542
+ 75: 'farm',
543
+ 163: 'outside',
544
+ 234: 'water',
545
+ 81: 'finish',
546
+ 159: 'old',
547
+ 121: 'hungry',
548
+ 112: 'helicopter',
549
+ 130: 'lamp',
550
+ 222: 'tongue',
551
+ 194: 'shhh',
552
+ 6: 'another',
553
+ 103: 'gum',
554
+ 214: 'thankyou',
555
+ 128: 'kiss',
556
+ 101: 'grass',
557
+ 64: 'drop',
558
+ 157: 'now',
559
+ 233: 'wake',
560
+ 116: 'hide',
561
+ 201: 'smile',
562
+ 226: 'toy',
563
+ 216: 'there',
564
+ 147: 'morning',
565
+ 10: 'aunt',
566
+ 102: 'green',
567
+ 36: 'car',
568
+ 213: 'taste',
569
+ 39: 'cereal',
570
+ 207: 'store',
571
+ 66: 'dryer',
572
+ 162: 'orange',
573
+ 218: 'thirsty',
574
+ 83: 'first',
575
+ 45: 'clean',
576
+ 3: 'all',
577
+ 198: 'sick',
578
+ 129: 'kitty',
579
+ 96: 'glasswindow',
580
+ 202: 'snack',
581
+ 150: 'nap',
582
+ 53: 'cut',
583
+ 73: 'face',
584
+ 99: 'grandma',
585
+ 209: 'stuck',
586
+ 91: 'garbage',
587
+ 115: 'hesheit',
588
+ 95: 'give',
589
+ 104: 'hair',
590
+ 125: 'jacket',
591
+ 165: 'owl',
592
+ 82: 'fireman',
593
+ 227: 'tree',
594
+ 16: 'because',
595
+ 17: 'bed',
596
+ 30: 'brother',
597
+ 143: 'minemy',
598
+ 127: 'jump',
599
+ 245: 'yesterday',
600
+ 145: 'mom',
601
+ 111: 'hear',
602
+ 174: 'police',
603
+ 223: 'tooth',
604
+ 212: 'talk',
605
+ 224: 'toothbrush',
606
+ 164: 'owie',
607
+ 47: 'closet',
608
+ 169: 'penny',
609
+ 24: 'black',
610
+ 85: 'flag',
611
+ 238: 'white',
612
+ 134: 'lips',
613
+ 231: 'vacuum',
614
+ 8: 'apple',
615
+ 105: 'happy',
616
+ 151: 'napkin',
617
+ 92: 'gift',
618
+ 70: 'empty',
619
+ 46: 'close',
620
+ 52: 'cry',
621
+ 138: 'mad',
622
+ 49: 'clown',
623
+ 204: 'stairs',
624
+ 42: 'child',
625
+ 173: 'please',
626
+ 65: 'dry',
627
+ 72: 'eye',
628
+ 235: 'wet',
629
+ 32: 'bug',
630
+ 109: 'haveto',
631
+ 228: 'uncle',
632
+ 199: 'sleep',
633
+ 176: 'potty',
634
+ 29: 'boy',
635
+ 136: 'look',
636
+ 107: 'hate',
637
+ 71: 'every',
638
+ 12: 'backyard',
639
+ 22: 'better',
640
+ 84: 'fish',
641
+ 56: 'dance',
642
+ 139: 'make',
643
+ 98: 'goose',
644
+ 38: 'cat',
645
+ 232: 'wait',
646
+ 14: 'balloon',
647
+ 247: 'yucky',
648
+ 2: 'airplane',
649
+ 88: 'for',
650
+ 126: 'jeans',
651
+ 154: 'noisy',
652
+ 142: 'milk',
653
+ 239: 'who',
654
+ 90: 'frog',
655
+ 35: 'can',
656
+ 215: 'that',
657
+ 117: 'high',
658
+ 244: 'yes',
659
+ 196: 'shoe',
660
+ 108: 'have',
661
+ 48: 'cloud',
662
+ 170: 'person',
663
+ 187: 'ride',
664
+ 34: 'callonphone',
665
+ 37: 'carrot',
666
+ 100: 'grandpa',
667
+ 120: 'hot',
668
+ 131: 'later',
669
+ 229: 'underwear',
670
+ 0: 'TV',
671
+ 140: 'man',
672
+ 217: 'think',
673
+ 220: 'time',
674
+ 80: 'finger',
675
+ 86: 'flower',
676
+ 15: 'bath',
677
+ 28: 'book',
678
+ 193: 'see',
679
+ 208: 'story',
680
+ 26: 'blue',
681
+ 78: 'find',
682
+ 148: 'mouse',
683
+ 79: 'fine',
684
+ 179: 'puppy',
685
+ 55: 'dad',
686
+ 21: 'beside',
687
+ 225: 'touch',
688
+ 89: 'frenchfries',
689
+ 188: 'room',
690
+ 19: 'bee',
691
+ 27: 'boat',
692
+ 156: 'not',
693
+ 59: 'doll',
694
+ 97: 'go',
695
+ 190: 'same',
696
+ 144: 'mitten',
697
+ 160: 'on',
698
+ 57: 'dirty',
699
+ 182: 'radio',
700
+ 197: 'shower',
701
+ 186: 'refrigerator',
702
+ 158: 'nuts',
703
+ 175: 'pool',
704
+ 242: 'wolf',
705
+ 243: 'yellow',
706
+ 110: 'head',
707
+ 237: 'where',
708
+ 33: 'bye',
709
+ 133: 'lion',
710
+ 152: 'night',
711
+ 106: 'hat',
712
+ 43: 'chin',
713
+ 68: 'ear',
714
+ 168: 'pencil',
715
+ 119: 'horse',
716
+ 219: 'tiger',
717
+ 185: 'red'}
718
+
719
+ import tflite_runtime.interpreter as tflite
720
+
721
+ interpreter = tflite.Interpreter("./model.tflite")
722
+ found_signatures = list(interpreter.get_signature_list().keys())
723
+ prediction_fn = interpreter.get_signature_runner("serving_default")
724
+
725
+ prediction_fn(inputs=demo_raw_data)
726
+
727
+
728
+ output = prediction_fn(inputs=demo_raw_data)
729
+ sign = output['outputs'].argmax()
730
+ print("PRED : ", ORD2SIGN.get(sign), f'[{sign}]')