mbesinci commited on
Commit
02bde05
·
verified ·
1 Parent(s): ade7875

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +34 -25
app.py CHANGED
@@ -1,46 +1,55 @@
1
  import gradio as gr
2
  import cv2
3
  import torch
4
- from transformers import AutoModelForImageClassification, AutoFeatureExtractor
5
 
6
- # Modeli yükleyin (bu örnek bir modeldir, uygun bir model bulun)
7
- model = AutoModelForImageClassification.from_pretrained("google/vit-base-patch16-224") # Modeli değiştirin
8
- feature_extractor = AutoFeatureExtractor.from_pretrained("google/vit-base-patch16-224")
 
9
 
10
- # Görüntüyü modele uygun hale getirmek için yardımcı fonksiyon
11
- def preprocess_image(image):
12
- inputs = feature_extractor(images=image, return_tensors="pt")
13
  return inputs['pixel_values']
14
 
15
- # Video işlemi fonksiyonu
16
- def correct_gaze_in_video(video_path):
17
  # Video dosyasını aç
18
  cap = cv2.VideoCapture(video_path)
19
  fps = cap.get(cv2.CAP_PROP_FPS)
20
  width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
21
  height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
22
-
23
  # Çıkış videosu için ayarlar
24
- output_path = "corrected_gaze_output.mp4"
25
  fourcc = cv2.VideoWriter_fourcc(*'mp4v')
26
  out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
27
-
 
28
  while cap.isOpened():
29
  ret, frame = cap.read()
30
  if not ret:
31
  break
32
 
33
- # Göz temasını düzeltmek için her kareyi işleme
34
- inputs = preprocess_image(frame)
35
- with torch.no_grad():
36
- outputs = model(pixel_values=inputs)
37
 
38
- # Modelin çıktısını işleyin (örneğin, doğrudan göz temasına göre değişiklik yapma)
39
- # corrected_frame = işlemler burada yapılabilir
40
-
41
- # Çıkışı video dosyasına ekleyin
42
- out.write(frame) # corrected_frame olarak değiştirin
43
-
 
 
 
 
 
 
 
 
 
 
 
44
  # Kaynakları serbest bırak
45
  cap.release()
46
  out.release()
@@ -49,11 +58,11 @@ def correct_gaze_in_video(video_path):
49
 
50
  # Gradio arayüzü
51
  iface = gr.Interface(
52
- fn=correct_gaze_in_video,
53
  inputs="file",
54
  outputs="file",
55
- title="Gaze Correction in Video",
56
- description="Upload a video to correct gaze direction."
57
  )
58
 
59
  iface.launch()
 
1
  import gradio as gr
2
  import cv2
3
  import torch
4
+ from transformers import VideoMAEForVideoClassification, AutoFeatureExtractor
5
 
6
+ # Göz teması algılama modeli ve özellik çıkarıcıyı yükleyin
7
+ model_name = "kanlo/videomae-base-ASD_Eye_Contact_v1"
8
+ model = VideoMAEForVideoClassification.from_pretrained(model_name)
9
+ feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)
10
 
11
+ def preprocess_frames(frames):
12
+ # Her kareyi modele uygun şekilde işleyin
13
+ inputs = feature_extractor(frames, return_tensors="pt")
14
  return inputs['pixel_values']
15
 
16
+ def detect_eye_contact(video_path):
 
17
  # Video dosyasını aç
18
  cap = cv2.VideoCapture(video_path)
19
  fps = cap.get(cv2.CAP_PROP_FPS)
20
  width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
21
  height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
22
+
23
  # Çıkış videosu için ayarlar
24
+ output_path = "eye_contact_output.mp4"
25
  fourcc = cv2.VideoWriter_fourcc(*'mp4v')
26
  out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
27
+
28
+ frames = []
29
  while cap.isOpened():
30
  ret, frame = cap.read()
31
  if not ret:
32
  break
33
 
34
+ frames.append(frame)
 
 
 
35
 
36
+ # Modeli belirli bir aralıkta çalıştırarak göz temasını algılayın
37
+ if len(frames) >= 16: # 16 karede bir işlem yapıyoruz
38
+ inputs = preprocess_frames(frames)
39
+ with torch.no_grad():
40
+ outputs = model(pixel_values=inputs)
41
+ prediction = outputs.logits.argmax(-1).item()
42
+
43
+ # Göz teması varsa çerçeveye ek açıklama ekleyin
44
+ for frame in frames:
45
+ if prediction == 1: # 1 göz teması var anlamına gelir (modelde böyle olduğunu varsayıyoruz)
46
+ cv2.putText(frame, "Eye Contact", (50, 50), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)
47
+ else:
48
+ cv2.putText(frame, "No Eye Contact", (50, 50), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2)
49
+ out.write(frame)
50
+
51
+ frames = [] # 16 karelik grubu işledikten sonra sıfırlayın
52
+
53
  # Kaynakları serbest bırak
54
  cap.release()
55
  out.release()
 
58
 
59
  # Gradio arayüzü
60
  iface = gr.Interface(
61
+ fn=detect_eye_contact,
62
  inputs="file",
63
  outputs="file",
64
+ title="Eye Contact Detection in Video",
65
+ description="Upload a video to detect eye contact in each frame."
66
  )
67
 
68
  iface.launch()