ClownRat committed on
Commit
0c50b58
1 Parent(s): 75a4b32

Update demo.

Browse files
Files changed (1) hide show
  1. app.py +3 -4
app.py CHANGED
@@ -61,24 +61,23 @@ The service is a research preview intended for non-commercial use only, subject
61
 
62
 
63
  class Chat:
64
- def __init__(self, model_path, conv_mode, model_base=None, load_8bit=False, load_4bit=False, device='cuda'):
65
  # disable_torch_init()
66
  model_name = get_model_name_from_path(model_path)
67
  self.tokenizer, self.model, processor, context_len = load_pretrained_model(
68
  model_path, model_base, model_name,
69
  load_8bit, load_4bit,
70
- device=device,
71
  offload_folder="save_folder")
72
  self.processor = processor
73
  self.conv_mode = conv_mode
74
  self.conv = conv_templates[conv_mode].copy()
75
- self.device = self.model.device
76
 
77
  def get_prompt(self, qs, state):
78
  state.append_message(state.roles[0], qs)
79
  state.append_message(state.roles[1], None)
80
  return state
81
 
 
82
  @torch.inference_mode()
83
  def generate(self, tensor: list, modals: list, prompt: str, first_run: bool, state):
84
  # TODO: support multiple turns of conversation.
@@ -92,7 +91,7 @@ class Chat:
92
  prompt = state.get_prompt()
93
  # print('\n\n\n')
94
  # print(prompt)
95
- input_ids = tokenizer_MMODAL_token(prompt, tokenizer, MMODAL_TOKEN_INDEX[modals[0]], return_tensors='pt').unsqueeze(0).to(self.device)
96
 
97
  # 3. generate response according to visual signals and prompts.
98
  stop_str = self.conv.sep if self.conv.sep_style in [SeparatorStyle.SINGLE] else self.conv.sep2
 
61
 
62
 
63
  class Chat:
64
+ def __init__(self, model_path, conv_mode, model_base=None, load_8bit=False, load_4bit=False):
65
  # disable_torch_init()
66
  model_name = get_model_name_from_path(model_path)
67
  self.tokenizer, self.model, processor, context_len = load_pretrained_model(
68
  model_path, model_base, model_name,
69
  load_8bit, load_4bit,
 
70
  offload_folder="save_folder")
71
  self.processor = processor
72
  self.conv_mode = conv_mode
73
  self.conv = conv_templates[conv_mode].copy()
 
74
 
75
  def get_prompt(self, qs, state):
76
  state.append_message(state.roles[0], qs)
77
  state.append_message(state.roles[1], None)
78
  return state
79
 
80
+ @spaces.GPU(duration=120)
81
  @torch.inference_mode()
82
  def generate(self, tensor: list, modals: list, prompt: str, first_run: bool, state):
83
  # TODO: support multiple turns of conversation.
 
91
  prompt = state.get_prompt()
92
  # print('\n\n\n')
93
  # print(prompt)
94
+ input_ids = tokenizer_MMODAL_token(prompt, tokenizer, MMODAL_TOKEN_INDEX[modals[0]], return_tensors='pt').unsqueeze(0).to(self.model.device)
95
 
96
  # 3. generate response according to visual signals and prompts.
97
  stop_str = self.conv.sep if self.conv.sep_style in [SeparatorStyle.SINGLE] else self.conv.sep2