BoboiAzumi commited on
Commit
f13c42f
·
1 Parent(s): 2b1a928
Files changed (2) hide show
  1. process.py +33 -24
  2. requirements.txt +25 -8
process.py CHANGED
@@ -1,32 +1,28 @@
1
  import io
2
- import spaces
3
-
4
  import argparse
5
  import numpy as np
6
  import torch
7
  from decord import cpu, VideoReader, bridge
8
  from transformers import AutoModelForCausalLM, AutoTokenizer
9
-
10
 
11
  MODEL_PATH = "THUDM/cogvlm2-llama3-caption"
12
-
13
  DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
14
  TORCH_TYPE = torch.bfloat16 if torch.cuda.is_available() and torch.cuda.get_device_capability()[
15
  0] >= 8 else torch.float16
16
 
17
  parser = argparse.ArgumentParser(description="CogVLM2-Video CLI Demo")
18
- parser.add_argument('--quant', type=int, choices=[4, 8], help='Enable 4-bit or 8-bit precision loading', default=0)
19
  args = parser.parse_args([])
20
 
21
-
22
  def load_video(video_data, strategy='chat'):
23
  bridge.set_bridge('torch')
24
  mp4_stream = video_data
25
  num_frames = 24
26
  decord_vr = VideoReader(io.BytesIO(mp4_stream), ctx=cpu(0))
27
-
28
  frame_id_list = None
29
  total_frames = len(decord_vr)
 
30
  if strategy == 'base':
31
  clip_end_sec = 60
32
  clip_start_sec = 0
@@ -45,11 +41,18 @@ def load_video(video_data, strategy='chat'):
45
  frame_id_list.append(index)
46
  if len(frame_id_list) >= num_frames:
47
  break
48
-
49
  video_data = decord_vr.get_batch(frame_id_list)
50
  video_data = video_data.permute(3, 0, 1, 2)
51
  return video_data
52
 
 
 
 
 
 
 
 
53
 
54
  tokenizer = AutoTokenizer.from_pretrained(
55
  MODEL_PATH,
@@ -59,11 +62,14 @@ tokenizer = AutoTokenizer.from_pretrained(
59
  model = AutoModelForCausalLM.from_pretrained(
60
  MODEL_PATH,
61
  torch_dtype=TORCH_TYPE,
62
- trust_remote_code=True
63
- ).eval().to(DEVICE)
 
 
64
 
65
- @spaces.GPU
66
- def predict(prompt, video, temperature, strategy):
 
67
  history = []
68
  query = prompt
69
  inputs = model.build_conversation_input_ids(
@@ -73,31 +79,34 @@ def predict(prompt, video, temperature, strategy):
73
  history=history,
74
  template_version=strategy
75
  )
 
76
  inputs = {
77
- 'input_ids': inputs['input_ids'].unsqueeze(0).to('cuda'),
78
- 'token_type_ids': inputs['token_type_ids'].unsqueeze(0).to('cuda'),
79
- 'attention_mask': inputs['attention_mask'].unsqueeze(0).to('cuda'),
80
- 'images': [[inputs['images'][0].to('cuda').to(TORCH_TYPE)]],
81
  }
 
82
  gen_kwargs = {
83
  "max_new_tokens": 2048,
84
  "pad_token_id": 128002,
85
  "top_k": 1,
86
- "do_sample": True,
87
  "top_p": 0.1,
88
  "temperature": temperature,
89
  }
 
90
  with torch.no_grad():
91
  outputs = model.generate(**inputs, **gen_kwargs)
92
  outputs = outputs[:, inputs['input_ids'].shape[1]:]
93
  response = tokenizer.decode(outputs[0], skip_special_tokens=True)
94
  return response
95
 
96
-
97
- def inference(video, prompt):
98
  temperature = 0.1
99
- video = open(video, 'rb').read()
100
- strategy = 'base'
101
- video_data = load_video(video, strategy=strategy)
102
- response = predict(prompt, video_data, temperature, strategy)
103
- return response
 
 
1
  import io
 
 
2
  import argparse
3
  import numpy as np
4
  import torch
5
  from decord import cpu, VideoReader, bridge
6
  from transformers import AutoModelForCausalLM, AutoTokenizer
7
+ from transformers import BitsAndBytesConfig
8
 
9
  MODEL_PATH = "THUDM/cogvlm2-llama3-caption"
 
10
  DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
11
  TORCH_TYPE = torch.bfloat16 if torch.cuda.is_available() and torch.cuda.get_device_capability()[
12
  0] >= 8 else torch.float16
13
 
14
  parser = argparse.ArgumentParser(description="CogVLM2-Video CLI Demo")
15
+ parser.add_argument('--quant', type=int, choices=[4, 8], help='Enable 4-bit or 8-bit precision loading', default=4)
16
  args = parser.parse_args([])
17
 
 
18
  def load_video(video_data, strategy='chat'):
19
  bridge.set_bridge('torch')
20
  mp4_stream = video_data
21
  num_frames = 24
22
  decord_vr = VideoReader(io.BytesIO(mp4_stream), ctx=cpu(0))
 
23
  frame_id_list = None
24
  total_frames = len(decord_vr)
25
+
26
  if strategy == 'base':
27
  clip_end_sec = 60
28
  clip_start_sec = 0
 
41
  frame_id_list.append(index)
42
  if len(frame_id_list) >= num_frames:
43
  break
44
+
45
  video_data = decord_vr.get_batch(frame_id_list)
46
  video_data = video_data.permute(3, 0, 1, 2)
47
  return video_data
48
 
49
+ # Configure quantization
50
+ quantization_config = BitsAndBytesConfig(
51
+ load_in_4bit=True,
52
+ bnb_4bit_compute_dtype=TORCH_TYPE,
53
+ bnb_4bit_use_double_quant=True,
54
+ bnb_4bit_quant_type="nf4"
55
+ )
56
 
57
  tokenizer = AutoTokenizer.from_pretrained(
58
  MODEL_PATH,
 
62
  model = AutoModelForCausalLM.from_pretrained(
63
  MODEL_PATH,
64
  torch_dtype=TORCH_TYPE,
65
+ trust_remote_code=True,
66
+ quantization_config=quantization_config,
67
+ device_map="auto"
68
+ ).eval()
69
 
70
+ def predict(prompt, video_data, temperature):
71
+ strategy = 'chat'
72
+ video = load_video(video_data, strategy=strategy)
73
  history = []
74
  query = prompt
75
  inputs = model.build_conversation_input_ids(
 
79
  history=history,
80
  template_version=strategy
81
  )
82
+
83
  inputs = {
84
+ 'input_ids': inputs['input_ids'].unsqueeze(0).to(DEVICE),
85
+ 'token_type_ids': inputs['token_type_ids'].unsqueeze(0).to(DEVICE),
86
+ 'attention_mask': inputs['attention_mask'].unsqueeze(0).to(DEVICE),
87
+ 'images': [[inputs['images'][0].to(DEVICE).to(TORCH_TYPE)]],
88
  }
89
+
90
  gen_kwargs = {
91
  "max_new_tokens": 2048,
92
  "pad_token_id": 128002,
93
  "top_k": 1,
94
+ "do_sample": False,
95
  "top_p": 0.1,
96
  "temperature": temperature,
97
  }
98
+
99
  with torch.no_grad():
100
  outputs = model.generate(**inputs, **gen_kwargs)
101
  outputs = outputs[:, inputs['input_ids'].shape[1]:]
102
  response = tokenizer.decode(outputs[0], skip_special_tokens=True)
103
  return response
104
 
105
+ def test(video, prompt):
 
106
  temperature = 0.1
107
+ video_data = open(video, 'rb').read()
108
+ response = predict(prompt, video_data, temperature)
109
+ print(response)
110
+
111
+ if __name__ == '__main__':
112
+ test()
requirements.txt CHANGED
@@ -1,8 +1,25 @@
1
- argparse
2
- numpy
3
- decord
4
- spaces
5
- transformers==4.44.2
6
- einops==0.8.0
7
- torchvision==0.16.1
8
- pytorchvideo
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ decord>=0.6.0
2
+ #根据https://download.pytorch.org/whl/torch/,python版本为[3.8,3.11]
3
+ torch==2.1.0
4
+ torchvision== 0.16.0
5
+ pytorchvideo==0.1.5
6
+ xformers
7
+ transformers==4.42.4
8
+ #git+https://github.com/huggingface/transformers.git
9
+ huggingface-hub>=0.23.0
10
+ pillow
11
+ chainlit>=1.0
12
+ pydantic>=2.7.1
13
+ timm>=0.9.16
14
+ openai>=1.30.1
15
+ loguru>=0.7.2
16
+ pydantic>=2.7.1
17
+ einops
18
+ sse-starlette>=2.1.0
19
+ flask
20
+ gunicorn
21
+ gevent
22
+ requests
23
+ gradio
24
+ accelerate
25
+ bitsandbytes>=0.39.0