Azure99 commited on
Commit
7054442
·
verified ·
1 Parent(s): f8566ce

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -17
app.py CHANGED
@@ -1,17 +1,14 @@
1
- import time
2
 
3
  import gradio as gr
4
  import spaces
5
- from threading import Thread
6
- from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, BitsAndBytesConfig
7
  import torch
 
8
 
9
  MAX_INPUT_LIMIT = 3584
10
-
11
  MODEL_NAME = "Azure99/blossom-v5.1-9b"
12
 
13
  model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, torch_dtype=torch.float16, device_map="auto")
14
-
15
  tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
16
 
17
  GENERATE_CONFIG = dict(
@@ -22,7 +19,6 @@ GENERATE_CONFIG = dict(
22
  repetition_penalty=1.05
23
  )
24
 
25
-
26
  def get_input_ids(inst, history):
27
  prefix = ("A chat between a human and an artificial intelligence bot. "
28
  "The bot gives helpful, detailed, and polite answers to the human's questions.")
@@ -46,27 +42,17 @@ def chat(inst, history):
46
  with torch.no_grad():
47
  streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
48
  input_ids = get_input_ids(inst, history)
49
- print(len(input_ids))
50
  if len(input_ids) > MAX_INPUT_LIMIT:
51
  yield "The input is too long, please clear the history."
52
  return
53
  generation_kwargs = dict(input_ids=torch.tensor([input_ids]).to(model.device), do_sample=True,
54
  streamer=streamer, **GENERATE_CONFIG)
55
  Thread(target=model.generate, kwargs=generation_kwargs).start()
56
-
57
- # stop watch
58
- start = time.time()
59
  outputs = ""
60
  for new_text in streamer:
61
  outputs += new_text
62
  yield outputs
63
- total_time = time.time() - start
64
- output_token_len = len(tokenizer.encode(outputs, add_special_tokens=False))
65
- speed = output_token_len / total_time
66
- print("----------")
67
- print(history)
68
- print([inst, outputs])
69
- print(f"Speed: {speed:.2f} tokens/s")
70
 
71
 
72
  gr.ChatInterface(chat,
 
1
+ from threading import Thread
2
 
3
  import gradio as gr
4
  import spaces
 
 
5
  import torch
6
+ from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
7
 
8
  MAX_INPUT_LIMIT = 3584
 
9
  MODEL_NAME = "Azure99/blossom-v5.1-9b"
10
 
11
  model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, torch_dtype=torch.float16, device_map="auto")
 
12
  tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
13
 
14
  GENERATE_CONFIG = dict(
 
19
  repetition_penalty=1.05
20
  )
21
 
 
22
  def get_input_ids(inst, history):
23
  prefix = ("A chat between a human and an artificial intelligence bot. "
24
  "The bot gives helpful, detailed, and polite answers to the human's questions.")
 
42
  with torch.no_grad():
43
  streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
44
  input_ids = get_input_ids(inst, history)
 
45
  if len(input_ids) > MAX_INPUT_LIMIT:
46
  yield "The input is too long, please clear the history."
47
  return
48
  generation_kwargs = dict(input_ids=torch.tensor([input_ids]).to(model.device), do_sample=True,
49
  streamer=streamer, **GENERATE_CONFIG)
50
  Thread(target=model.generate, kwargs=generation_kwargs).start()
51
+
 
 
52
  outputs = ""
53
  for new_text in streamer:
54
  outputs += new_text
55
  yield outputs
 
 
 
 
 
 
 
56
 
57
 
58
  gr.ChatInterface(chat,