Indikacr committed
Commit
5866842
1 Parent(s): 4e71b12

Create app_gradio.py

Files changed (1)
  1. app_gradio.py +91 -0
app_gradio.py CHANGED
import pickle
import random

import cv2
import gradio as gr
import mediapipe as mp
import numpy as np
from PIL import Image

# Load the trained classifier once at startup
model_dict = pickle.load(open('./model.p', 'rb'))
model = model_dict['model']

mp_hands = mp.solutions.hands
mp_drawing = mp.solutions.drawing_utils
mp_drawing_styles = mp.solutions.drawing_styles

# Create the hand detector once rather than on every frame;
# static_image_mode=True treats each frame as an independent image
hands = mp_hands.Hands(static_image_mode=True, min_detection_confidence=0.3)

# Class index -> sign letter mapping used by the classifier
labels_dict = {0: 'L', 1: 'A', 2: 'B', 3: 'C', 4: 'V', 5: 'W', 6: 'Y'}

# Process a single BGR frame: detect hand landmarks and classify the sign
def predict(frame):
    data_aux = []
    frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    results = hands.process(frame_rgb)

    if results.multi_hand_landmarks:
        # Draw landmarks on every detected hand for visual feedback
        for hand_landmarks in results.multi_hand_landmarks:
            mp_drawing.draw_landmarks(
                frame,
                hand_landmarks,
                mp_hands.HAND_CONNECTIONS,
                mp_drawing_styles.get_default_hand_landmarks_style(),
                mp_drawing_styles.get_default_hand_connections_style())

        # Build the feature vector from the first detected hand only, so it
        # stays at the 42 (x, y) values the classifier expects
        hand_landmarks = results.multi_hand_landmarks[0]
        for landmark in hand_landmarks.landmark:
            data_aux.append(landmark.x)
            data_aux.append(landmark.y)

    if data_aux:
        prediction = model.predict([np.asarray(data_aux)])
        predicted_character = labels_dict[int(prediction[0])]
        # Overlay the predicted sign, then convert BGR -> RGB for PIL
        output = cv2.putText(frame, f'Sign: {predicted_character}', (5, 30), cv2.FONT_HERSHEY_SIMPLEX, 2, (255, 255, 255), 3, cv2.LINE_AA)
        output_image = Image.fromarray(cv2.cvtColor(output, cv2.COLOR_BGR2RGB))
        return output_image, predicted_character
    else:
        return Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)), "No hand landmarks detected in the current frame."

# Run predict() on every frame of a video and write out an annotated copy
def vid_inf(vid, progress=gr.Progress()):
    cap = cv2.VideoCapture(vid)
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    processed_frames = 0
    writer = None
    character_output = []
    # Random suffix avoids filename collisions between concurrent requests
    tmpname = f"output_{random.randint(111111111, 999999999)}.mp4"

    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
            break

        # Create the writer lazily, once the frame size is known
        if writer is None:
            height, width, _ = frame.shape
            writer = cv2.VideoWriter(tmpname, cv2.VideoWriter_fourcc(*'mp4v'), cap.get(cv2.CAP_PROP_FPS), (width, height))

        processed_frames += 1
        # max(..., 1) guards against containers that report a frame count of 0
        progress(processed_frames / max(total_frames, 1), desc=f"Processing frame {processed_frames}/{total_frames}")
        out, character = predict(frame)
        # predict() returns an RGB PIL image; convert back to BGR for OpenCV
        writer.write(cv2.cvtColor(np.array(out), cv2.COLOR_RGB2BGR))
        character_output.append(character)

    cap.release()
    if writer is not None:  # writer stays None if no frames were read
        writer.release()
    return tmpname, " ".join(character_output)

input_video = gr.Video(sources=["webcam", "upload"], label="Input Video")
output_video = gr.Video(label="Processed Video")
output_character = gr.Textbox(label="Sign Sequence")

# Create Gradio Interface for Video Inference
interface_video = gr.Interface(
    fn=vid_inf,
    inputs=[input_video],
    outputs=[output_video, output_character],
    title="Video Inference"
)

interface_video.launch(share=True)