leonardlin committed
Commit 22b3942 · 1 parent: 82777d3

better chat rebuild, added system prompt, load_in_4bit

Files changed (1): app.py (+35 -10)
app.py CHANGED
@@ -3,40 +3,61 @@
 import gradio as gr
 import logging
 import html
+from pprint import pprint
 import time
 import torch
 from threading import Thread
-from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
+from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TextIteratorStreamer
 
 # Model
 model_name = "augmxnt/shisa-7b-v1"
 
 # UI Settings
 title = "Shisa 7B"
-description = "Test out Shisa 7B in either English or Japanese."
+description = "Test out Shisa 7B in either English or Japanese. If you aren't getting the right language outputs, you can try changing the system prompt to the appropriate language. Note, we are running `load_in_4bit` to fit in 16GB of VRAM"
 placeholder = "Type Here / ここに入力してください"
 examples = [
-    "What's the best ramen in Tokyo?",
-    "あなたは熱狂的なポケモンファンです。",
-    "東京でおすすめのラーメン屋ってどこ?",
+    ["What are the best slices of pizza in New York City?"],
+    ['How do I program a simple "hello world" in Python?'],
+    ["東京でおすすめのラーメン屋ってどこ?"],
+    ["Pythonでシンプルな「ハローワールド」をプログラムするにはどうすればいいですか?"],
 ]
 
 # LLM Settings
-system_prompt = 'あなたは役に立つアシスタントです。'
-chat_history = [{"role": "system", "content": system_prompt}]
+# Initial
+system_prompt = 'You are a helpful, bilingual assistant. Reply in the same language as the user.'
+default_prompt = system_prompt
+
 tokenizer = AutoTokenizer.from_pretrained(model_name)
 model = AutoModelForCausalLM.from_pretrained(
     model_name,
     torch_dtype=torch.bfloat16,
     device_map="auto",
-    load_in_8bit=True,
-    # load_in_4bit=True
+    # load_in_8bit=True,
+    quantization_config = BitsAndBytesConfig(
+        load_in_4bit=True,
+        bnb_4bit_quant_type='nf4',
+        bnb_4bit_use_double_quant=True,
+        bnb_4bit_compute_dtype=torch.bfloat16
+    ),
 )
 streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
 
-def chat(message, history):
+def chat(message, history, system_prompt):
+    print('---')
+    pprint(history)
+    if not system_prompt:
+        system_prompt = default_prompt
+
+    # Let's just rebuild every time it's easier
+    chat_history = [{"role": "system", "content": system_prompt}]
+    for h in history:
+        chat_history.append({"role": "user", "content": h[0]})
+        chat_history.append({"role": "assistant", "content": h[1]})
     chat_history.append({"role": "user", "content": message})
+
     input_ids = tokenizer.apply_chat_template(chat_history, add_generation_prompt=True, return_tensors="pt")
+
     # for multi-gpu, find the device of the first parameter of the model
     first_param_device = next(model.parameters()).device
     input_ids = input_ids.to(first_param_device)
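The quantization block above is what the commit message's `load_in_4bit` refers to: NF4 weights take roughly half a byte per parameter, so a 7B model drops from ~14 GB in bf16 (or ~7 GB in 8-bit) to roughly 4 GB, which is what lets the Space fit in 16 GB of VRAM. A minimal standalone sketch to sanity-check the footprint, not part of the commit:

    # Hypothetical standalone check (not in the commit): load with the same
    # BitsAndBytesConfig and report the in-memory size of the weights.
    import torch
    from transformers import AutoModelForCausalLM, BitsAndBytesConfig

    model = AutoModelForCausalLM.from_pretrained(
        "augmxnt/shisa-7b-v1",
        torch_dtype=torch.bfloat16,
        device_map="auto",
        quantization_config=BitsAndBytesConfig(
            load_in_4bit=True,               # NF4 weights: ~0.5 bytes/param
            bnb_4bit_quant_type="nf4",
            bnb_4bit_use_double_quant=True,  # also quantize the quantization constants
            bnb_4bit_compute_dtype=torch.bfloat16,
        ),
    )
    # get_memory_footprint() returns the loaded weight size in bytes
    print(f"~{model.get_memory_footprint() / 1e9:.1f} GB of weights")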
 
@@ -50,6 +71,7 @@ def chat(message, history):
         repetition_penalty=1.15,
         top_p=0.95,
         eos_token_id=tokenizer.eos_token_id,
+        pad_token_id=tokenizer.eos_token_id,
     )
     # https://www.gradio.app/main/guides/creating-a-chatbot-fast#example-using-a-local-open-source-llm-with-hugging-face
     t = Thread(target=model.generate, kwargs=generate_kwargs)
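This hunk only touches the tail of `generate_kwargs`; the streaming machinery around it is elided by the diff. Following the Gradio guide linked in the code, the surrounding pattern inside `chat()` is roughly as below; only the parameters visible in the diff are confirmed, and the `max_new_tokens`, `do_sample`, and `temperature` values are assumptions:

    # Sketch of the elided streaming tail of chat(), per the linked Gradio guide.
    # Sampling settings other than those shown in the diff are assumptions.
    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,                    # TextIteratorStreamer created above
        max_new_tokens=512,
        do_sample=True,
        temperature=0.7,
        repetition_penalty=1.15,
        top_p=0.95,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.eos_token_id,  # the added line; silences the missing-pad-token warning
    )
    # generate() runs in a background thread while the streamer yields decoded text
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()
    partial_text = ""
    for new_text in streamer:
        partial_text += new_text
        yield partial_text                    # ChatInterface re-renders the growing reply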
 
@@ -71,6 +93,9 @@ chat_interface = gr.ChatInterface(
     cache_examples=False,
     undo_btn="Delete Previous",
     clear_btn="Clear",
+    additional_inputs=[
+        gr.Textbox(system_prompt, label="System Prompt (Change the language of the prompt for better replies)"),
+    ],
 )
 
 # https://huggingface.co/spaces/ysharma/Explore_llamav2_with_TGI/blob/main/app.py#L219 - we use this with construction b/c Gradio barfs on autoreload otherwise
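The final comment refers to rendering the `gr.ChatInterface` inside a `gr.Blocks` context before launching, per the linked ysharma Space. A minimal sketch of that construction; the queue/launch arguments are assumptions:

    # Sketch of the "with" construction the comment describes: rendering the
    # ChatInterface inside a Blocks context so Gradio's autoreload works.
    with gr.Blocks() as demo:
        chat_interface.render()

    demo.queue().launch()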