Shoubin committed
Commit 0ea72e5
Parent: 3cd74d2

Update app.py

Files changed (1): app.py +25 -25
app.py CHANGED
@@ -34,23 +34,23 @@ image_size = img_size
 transform = transforms.Compose([ToUint8(), ToTHWC(), transforms_video.ToTensorVideo(), normalize])
 
 print('model loading')
-sevila = SeViLA(
-    img_size=img_size,
-    drop_path_rate=drop_path_rate,
-    use_grad_checkpoint=use_grad_checkpoint,
-    vit_precision=vit_precision,
-    freeze_vit=freeze_vit,
-    num_query_token=num_query_token,
-    t5_model=t5_model,
-    prompt=prompt,
-    max_txt_len=max_txt_len,
-    apply_lemmatizer=apply_lemmatizer,
-    frame_num=4,
-    answer_num=answer_num,
-    task=task,
-    )
+# sevila = SeViLA(
+#     img_size=img_size,
+#     drop_path_rate=drop_path_rate,
+#     use_grad_checkpoint=use_grad_checkpoint,
+#     vit_precision=vit_precision,
+#     freeze_vit=freeze_vit,
+#     num_query_token=num_query_token,
+#     t5_model=t5_model,
+#     prompt=prompt,
+#     max_txt_len=max_txt_len,
+#     apply_lemmatizer=apply_lemmatizer,
+#     frame_num=4,
+#     answer_num=answer_num,
+#     task=task,
+#     )
 
-sevila.load_checkpoint(url_or_filename='https://huggingface.co/Shoubin/SeViLA/resolve/main/sevila_pretrained.pth')
+# sevila.load_checkpoint(url_or_filename='https://huggingface.co/Shoubin/SeViLA/resolve/main/sevila_pretrained.pth')
 print('model loaded')
 
 ANS_MAPPING = {0 : 'A', 1 : 'B', 2 : 'C', 3 : 'D', 4 : 'E'}
@@ -68,11 +68,11 @@ def sevila_demo(video,
     else:
         device = 'cpu'
 
-    global sevila
-    if device == "cpu":
-        sevila = sevila.float()
-    else:
-        sevila = sevila.to(int(device))
+    # global sevila
+    # if device == "cpu":
+    #     sevila = sevila.float()
+    # else:
+    #     sevila = sevila.to(int(device))
 
     vpath = video
     raw_clip, indice, fps, vlen = load_video_demo(
@@ -98,11 +98,11 @@ def sevila_demo(video,
     text_input_qa = 'Question: ' + question + ' ' + options + ' ' + QA_prompt
     text_input_loc = 'Question: ' + question + ' ' + options + ' ' + LOC_propmpt
 
-    out = sevila.generate_demo(clip, text_input_qa, text_input_loc, int(keyframe_num))
+    # out = sevila.generate_demo(clip, text_input_qa, text_input_loc, int(keyframe_num))
     # print(out)
-    answer_id = out['output_text'][0]
+    answer_id = 0 #out['output_text'][0]
     answer = option_dict[answer_id]
-    select_index = out['frame_idx'][0]
+    select_index = [1,2,3,4]#out['frame_idx'][0]
     # images = []
     keyframes = []
     timestamps =[]
@@ -170,7 +170,7 @@ with gr.Blocks(title="SeViLA demo") as demo:
             keyframe_num = gr.Textbox(placeholder=4, label='# Keyframe')
             # device = gr.Textbox(placeholder=0, label='Device')
            gen_btn = gr.Button(value='Locate and Answer!')
-        with gr.Column(scale=1, min_width=100):
+        with gr.Column(scale=1, min_width=600):
             keyframes = gr.Gallery(
                 label="Keyframes", show_label=False, elem_id="gallery", max_width=100, max_height=100,
             ).style(columns=[4], rows=[1], object_fit="contain", height='auto')
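
With the model construction, checkpoint load, and generate_demo call commented out, sevila_demo falls back to fixed placeholder outputs. Below is a minimal runnable sketch of that stub path, assuming option_dict maps answer indices to option strings and that timestamps are derived by dividing frame indices by fps; stub_demo_outputs, the example option list, and the timestamp arithmetic are illustrative assumptions, not code from app.py.

ANS_MAPPING = {0: 'A', 1: 'B', 2: 'C', 3: 'D', 4: 'E'}  # copied from the diff context

def stub_demo_outputs(options, fps=30.0):
    # Placeholder values hard-coded by this commit in sevila_demo:
    answer_id = 0                # was: out['output_text'][0]
    select_index = [1, 2, 3, 4]  # was: out['frame_idx'][0]
    option_dict = {i: opt for i, opt in enumerate(options)}  # assumed shape of option_dict
    answer = option_dict[answer_id]
    letter = ANS_MAPPING[answer_id]
    timestamps = [round(i / fps, 2) for i in select_index]  # assumed frame -> second mapping
    return letter, answer, select_index, timestamps

# Whatever video and question are supplied, the UI now always reports
# option A and keyframes 1-4:
print(stub_demo_outputs(['a dog', 'a cat', 'a bird']))
# -> ('A', 'a dog', [1, 2, 3, 4], [0.03, 0.07, 0.1, 0.13])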