vilarin committed on
Commit 82b38de
1 Parent(s): ef3e7ad

Update app.py

Files changed (1)
  1. app.py +23 -16
app.py CHANGED
@@ -2,22 +2,23 @@ import torch
 from PIL import Image
 import gradio as gr
 import spaces
-from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
+from transformers import AutoModelForCausalLM, GemmaTokenizerFast, TextIteratorStreamer, BitsAndBytesConfig
 import os
 from threading import Thread
 
 
 HF_TOKEN = os.environ.get("HF_TOKEN", None)
-MODEL_ID = "Qwen/Qwen2-7B-Instruct"
+MODEL_ID = "google/gemma-2-27b-it"
 MODELS = os.environ.get("MODELS")
 MODEL_NAME = MODELS.split("/")[-1]
+MAX_INPUT_TOKEN_LENGTH = int(os.environ.get("MAX_INPUT_TOKEN_LENGTH", "4096"))
 
 TITLE = "<h1><center>Qwen2-Chatbox</center></h1>"
 
 DESCRIPTION = f"""
 <h3>MODEL: <a href="https://hf.co/{MODELS}">{MODEL_NAME}</a></h3>
 <center>
-<p>Qwen is the large language model built by Alibaba Cloud.
+<p>Gemma is the large language model built by Google.
 <br>
 Feel free to test without log.
 </p>
@@ -37,13 +38,15 @@ h3 {
 """
 
 model = AutoModelForCausalLM.from_pretrained(
-    MODELS,
-    torch_dtype=torch.float16,
-    device_map="auto",
+    MODELS,
+    device_map="auto",
+    quantization_config=BitsAndBytesConfig(load_in_4bit=True),
 )
-tokenizer = AutoTokenizer.from_pretrained(MODELS)
+tokenizer = GemmaTokenizerFast.from_pretrained(MODELS)
+model.config.sliding_window = 4096
+model.eval()
 
-@spaces.GPU
+@spaces.GPU(duration=90)
 def stream_chat(message: str, history: list, temperature: float, max_new_tokens: int, top_p: float, top_k: int, penalty: float):
     print(f'message is - {message}')
     print(f'history is - {history}')
@@ -54,13 +57,16 @@ def stream_chat(message: str, history: list, temperature: float, max_new_tokens:
 
     print(f"Conversation is -\n{conversation}")
 
-    input_ids = tokenizer.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)
-    inputs = tokenizer(input_ids, return_tensors="pt").to(0)
+    input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
+    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
+        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
+        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
+    input_ids = input_ids.to(0)
 
-    streamer = TextIteratorStreamer(tokenizer, timeout=10., skip_prompt=True, skip_special_tokens=True)
+    streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
 
     generate_kwargs = dict(
-        inputs,
+        {"input_ids": input_ids},
         streamer=streamer,
         top_k=top_k,
         top_p=top_p,
@@ -68,7 +74,8 @@ def stream_chat(message: str, history: list, temperature: float, max_new_tokens:
         max_new_tokens=max_new_tokens,
         do_sample=True,
         temperature=temperature,
-        eos_token_id=[151645, 151643],
+        num_beams=1,
+        repetition_penalty=penalty,
     )
 
     thread = Thread(target=model.generate, kwargs=generate_kwargs)
@@ -81,9 +88,9 @@ def stream_chat(message: str, history: list, temperature: float, max_new_tokens:
 
 
 
-chatbot = gr.Chatbot(height=450)
+chatbot = gr.Chatbot(height=600)
 
-with gr.Blocks(css=CSS) as demo:
+with gr.Blocks(css=CSS, theme="soft") as demo:
     gr.HTML(TITLE)
     gr.HTML(DESCRIPTION)
     gr.DuplicateButton(value="Duplicate Space for private use", elem_classes="duplicate-button")
@@ -103,7 +110,7 @@ with gr.Blocks(css=CSS) as demo:
             ),
             gr.Slider(
                 minimum=128,
-                maximum=4096,
+                maximum=2048,
                 step=1,
                 value=1024,
                 label="Max new tokens",