chansung committed on
Commit
fe3e540
1 Parent(s): 89c0f3c

Update vid2persona/gen/local_openllm.py

Files changed (1)
  1. vid2persona/gen/local_openllm.py +7 -17
vid2persona/gen/local_openllm.py CHANGED
@@ -1,36 +1,26 @@
-# import spaces
+import spaces
 
 import torch
 from threading import Thread
 from transformers import AutoModelForCausalLM, AutoTokenizer
 from transformers import TextIteratorStreamer
 
-model = None
-tokenizer = None
+from vid2persona import init
 
-# @spaces.GPU
+@spaces.GPU
 def send_message(
     messages: list,
     model_id: str,
     max_input_token_length: int,
     parameters: dict
 ):
-    global tokenizer
-    global model
-
-    if tokenizer is None:
-        tokenizer = AutoTokenizer.from_pretrained(model_id)
-        tokenizer.use_default_system_prompt = False
-    if model is None:
-        model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map="auto")
-
-    input_ids = tokenizer.apply_chat_template(messages, return_tensors="pt")
+    input_ids = init.tokenizer.apply_chat_template(messages, return_tensors="pt")
     if input_ids.shape[1] > max_input_token_length:
         input_ids = input_ids[:, -max_input_token_length:]
         print(f"Trimmed input from conversation as it was longer than {max_input_token_length} tokens.")
-    input_ids = input_ids.to(model.device)
+    input_ids = input_ids.to(init.model.device)
 
-    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
+    streamer = TextIteratorStreamer(init.tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
     generate_kwargs = dict(
         {"input_ids": input_ids},
         streamer=streamer,
@@ -38,7 +28,7 @@ def send_message(
         num_beams=1,
         **parameters
     )
-    t = Thread(target=model.generate, kwargs=generate_kwargs)
+    t = Thread(target=init.model.generate, kwargs=generate_kwargs)
     t.start()
 
     for text in streamer:
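
The new code pulls `tokenizer` and `model` from a shared `vid2persona.init` module instead of lazily loading them inside `send_message`. That module is not part of this diff; below is a minimal sketch of what it presumably provides, reconstructed from the removed lazy-initialization lines. The `MODEL_ID` value is a hypothetical placeholder, not taken from the repository.

# Hypothetical sketch of vid2persona/init.py (not shown in this commit).
# It provides the two module-level objects the diff relies on,
# init.tokenizer and init.model, loaded once at import time.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

MODEL_ID = "HuggingFaceH4/zephyr-7b-beta"  # placeholder; the real model id is not shown here

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
tokenizer.use_default_system_prompt = False  # mirrors the removed setup code

model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.float16,
    device_map="auto",
)

Loading once at import time fits the `@spaces.GPU` pattern on Hugging Face Spaces: the decorated `send_message` only holds a GPU for the duration of each call, so the weights should already be resident rather than re-checked on every request, which is what the removed `if tokenizer is None:` / `if model is None:` guards were doing per call.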