yuhuili committed
Commit f2ce589
1 Parent(s): 75b08df

Update app.py

Files changed (1): app.py (+12 -11)
app.py CHANGED
@@ -11,7 +11,7 @@ except:
 import torch
 from fastchat.model import get_conversation_template
 import re
-from transformers import LlamaForCausalLM,AutoTokenizer
+
 
 def truncate_list(lst, num):
     if num not in lst:
@@ -73,7 +73,7 @@ def highlight_text(text, text_list,color="black"):
 
     return result
 
-@spaces.GPU(duration=30)
+@spaces.GPU(duration=60)
 def warmup(model):
     model.cuda()
     conv = get_conversation_template(args.model_type)
@@ -90,12 +90,13 @@ def warmup(model):
     prompt = conv.get_prompt()
     if args.model_type == "llama-2-chat":
         prompt += " "
-    input_ids = tokenizer([prompt]).input_ids
+    input_ids = model.tokenizer([prompt]).input_ids
     input_ids = torch.as_tensor(input_ids).to(model.base_model.device)
-    outs=model.generate(input_ids)
-    print(outs)
-@spaces.GPU(duration=30)
+    for output_ids in model.ea_generate(input_ids):
+        ol=output_ids.shape[1]
+@spaces.GPU(duration=60)
 def bot(history, temperature, top_p, use_EaInfer, highlight_EaInfer,session_state,):
+    model.cuda()
     if not history:
         return history, "0.00 tokens/s", "0.00", session_state
     pure_history = session_state.get("pure_history", [])
@@ -270,17 +271,17 @@ parser.add_argument(
 args = parser.parse_args()
 a=torch.tensor(1).cuda()
 print(a)
-model = LlamaForCausalLM.from_pretrained(
-    args.base_model_path,
+model = EaModel.from_pretrained(
+    base_model_path=args.base_model_path,
+    ea_model_path=args.ea_model_path,
+    total_token=args.total_token,
     torch_dtype=torch.float16,
     low_cpu_mem_usage=True,
     load_in_4bit=args.load_in_4bit,
     load_in_8bit=args.load_in_8bit,
     device_map="auto",
 )
-
 model.eval()
-tokenizer=AutoTokenizer.from_pretrained(args.base_model_path)
 warmup(model)
 
 custom_css = """
@@ -327,4 +328,4 @@ with gr.Blocks(css=custom_css) as demo:
     )
     stop_button.click(fn=None, inputs=None, outputs=None, cancels=[send_event,regenerate_event,enter_event])
 demo.queue()
-demo.launch(share=True)
+demo.launch()
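
The substantive change in this commit is swapping plain LlamaForCausalLM.generate for EAGLE's EaModel and its streaming ea_generate. Below is a minimal sketch of that flow, built only from the calls visible in the diff (EaModel.from_pretrained, model.tokenizer, model.ea_generate, model.base_model.device); the import path, the checkpoint names, the total_token value, and the tokens/s bookkeeping are illustrative assumptions, not part of this commit:

import time
import torch
from eagle.model.ea_model import EaModel  # assumed import path for the EAGLE package

# Hypothetical checkpoints for illustration; the app reads these from argparse.
model = EaModel.from_pretrained(
    base_model_path="meta-llama/Llama-2-7b-chat-hf",
    ea_model_path="yuhuili/EAGLE-llama2-chat-7B",
    total_token=60,  # assumed draft budget; the app passes args.total_token
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
    device_map="auto",
)
model.eval()

prompt = "[INST] Hello, who are you? [/INST] "
input_ids = torch.as_tensor(model.tokenizer([prompt]).input_ids).to(model.base_model.device)

# ea_generate is a generator: each yield is the sequence accepted so far, so the
# final yield holds the complete output (the warmup loop in the diff keeps only
# output_ids.shape[1] and discards the text).
start = time.time()
for output_ids in model.ea_generate(input_ids):
    pass
new_tokens = output_ids.shape[1] - input_ids.shape[1]
print(model.tokenizer.decode(output_ids[0, input_ids.shape[1]:]))
print(f"{new_tokens / (time.time() - start):.2f} tokens/s")

This also explains the other deletions: the standalone AutoTokenizer goes away because EaModel exposes its own model.tokenizer, and the model.cuda() calls inside the @spaces.GPU-decorated functions (now with duration=60) match ZeroGPU Spaces attaching a GPU only for the duration of the decorated call.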