vumichien commited on
Commit
40da08b
1 Parent(s): baf282d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -13
app.py CHANGED
@@ -22,8 +22,6 @@ sys.path.append('/home/user/app/av_hubert/avhubert')
22
  print(sys.path)
23
  print(os.listdir())
24
 
25
- from fairseq import checkpoint_utils, options, tasks, utils
26
- from argparse import Namespace
27
 
28
 
29
 
@@ -46,13 +44,24 @@ from huggingface_hub import hf_hub_download
46
  import gradio as gr
47
 
48
  user_dir = "/home/user/app/av_hubert/avhubert"
 
 
 
49
  ckpt_path = hf_hub_download('vumichien/AV-HuBERT', 'model.pt')
50
  face_detector_path = "/home/user/app/mmod_human_face_detector.dat"
51
  face_predictor_path = "/home/user/app/shape_predictor_68_face_landmarks.dat"
52
  mean_face_path = "/home/user/app/20words_mean_face.npy"
53
  mouth_roi_path = "/home/user/app/roi.mp4"
 
 
 
54
  models, saved_cfg, task = checkpoint_utils.load_model_ensemble_and_task([ckpt_path])
55
- utils.import_user_module(Namespace(user_dir=user_dir))
 
 
 
 
 
56
 
57
  def detect_landmark(image, detector, predictor):
58
  gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
@@ -100,17 +109,7 @@ def predict(process_video):
100
  fo.write("".join(tsv_cont))
101
  with open(f"{data_dir}/test.wrd", "w") as fo:
102
  fo.write("".join(label_cont))
103
- modalities = ["video"]
104
- gen_subset = "test"
105
- gen_cfg = GenerationConfig(beam=20)
106
-
107
- models = [model.eval().cuda() if torch.cuda.is_available() else model.eval() for model in models]
108
- saved_cfg.task.modalities = modalities
109
- saved_cfg.task.data = data_dir
110
- saved_cfg.task.label_dir = data_dir
111
- task = tasks.setup_task(saved_cfg.task)
112
  task.load_dataset(gen_subset, task_cfg=saved_cfg.task)
113
- generator = task.build_generator(models, gen_cfg)
114
 
115
  def decode_fn(x):
116
  dictionary = task.target_dictionary
 
22
  print(sys.path)
23
  print(os.listdir())
24
 
 
 
25
 
26
 
27
 
 
44
  import gradio as gr
45
 
46
  user_dir = "/home/user/app/av_hubert/avhubert"
47
+ utils.import_user_module(Namespace(user_dir=user_dir))
48
+ data_dir = tempfile.mkdtemp()
49
+
50
  ckpt_path = hf_hub_download('vumichien/AV-HuBERT', 'model.pt')
51
  face_detector_path = "/home/user/app/mmod_human_face_detector.dat"
52
  face_predictor_path = "/home/user/app/shape_predictor_68_face_landmarks.dat"
53
  mean_face_path = "/home/user/app/20words_mean_face.npy"
54
  mouth_roi_path = "/home/user/app/roi.mp4"
55
+ modalities = ["video"]
56
+ gen_subset = "test"
57
+ gen_cfg = GenerationConfig(beam=20)
58
  models, saved_cfg, task = checkpoint_utils.load_model_ensemble_and_task([ckpt_path])
59
+ models = [model.eval().cuda() if torch.cuda.is_available() else model.eval() for model in models]
60
+ saved_cfg.task.modalities = modalities
61
+ saved_cfg.task.data = data_dir
62
+ saved_cfg.task.label_dir = data_dir
63
+ task = tasks.setup_task(saved_cfg.task)
64
+ generator = task.build_generator(models, gen_cfg)
65
 
66
  def detect_landmark(image, detector, predictor):
67
  gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
 
109
  fo.write("".join(tsv_cont))
110
  with open(f"{data_dir}/test.wrd", "w") as fo:
111
  fo.write("".join(label_cont))
 
 
 
 
 
 
 
 
 
112
  task.load_dataset(gen_subset, task_cfg=saved_cfg.task)
 
113
 
114
  def decode_fn(x):
115
  dictionary = task.target_dictionary