THUdyh commited on
Commit
40fcdeb
1 Parent(s): 4de08c8

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -3
app.py CHANGED
@@ -6,16 +6,18 @@ from PIL import Image
6
  import numpy as np
7
  import transformers
8
  from typing import Dict, Optional, Sequence, List
9
- import spaces
 
 
10
 
11
  import sys
 
12
  from oryx.conversation import conv_templates, SeparatorStyle
13
  from oryx.model.builder import load_pretrained_model
14
  from oryx.utils import disable_torch_init
15
  from oryx.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria, process_anyres_video_genli
16
  from oryx.constants import IGNORE_INDEX, DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX
17
 
18
-
19
  model_path = "THUdyh/Oryx-7B"
20
  model_name = get_model_name_from_path(model_path)
21
  overwrite_config = {}
@@ -78,7 +80,7 @@ def preprocess_qwen(sources, tokenizer: transformers.PreTrainedTokenizer, has_im
78
  targets = torch.tensor(targets, dtype=torch.long)
79
  return input_ids
80
 
81
- @spaces.GPU(duration=120)
82
  def oryx_inference(video, text):
83
  vr = VideoReader(video, ctx=cpu(0))
84
  total_frame_num = len(vr)
 
6
  import numpy as np
7
  import transformers
8
  from typing import Dict, Optional, Sequence, List
9
+
10
+ import subprocess
11
+ subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
12
 
13
  import sys
14
+ # sys.path.append('/mnt/lzy/oryx-demo')
15
  from oryx.conversation import conv_templates, SeparatorStyle
16
  from oryx.model.builder import load_pretrained_model
17
  from oryx.utils import disable_torch_init
18
  from oryx.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria, process_anyres_video_genli
19
  from oryx.constants import IGNORE_INDEX, DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX
20
 
 
21
  model_path = "THUdyh/Oryx-7B"
22
  model_name = get_model_name_from_path(model_path)
23
  overwrite_config = {}
 
80
  targets = torch.tensor(targets, dtype=torch.long)
81
  return input_ids
82
 
83
+
84
  def oryx_inference(video, text):
85
  vr = VideoReader(video, ctx=cpu(0))
86
  total_frame_num = len(vr)