mkshing committed on
Commit 4e2136a • 1 Parent(s): 5753b4f
Files changed (1)
  1. app.py +85 -47
app.py CHANGED
@@ -1,3 +1,4 @@
+import os
 import itertools
 
 import torch
@@ -7,66 +8,103 @@ import gradio as gr
 device = "cuda" if torch.cuda.is_available() else "cpu"
 print(f"device: {device}")
 
-tokenizer = AutoTokenizer.from_pretrained("rinna/japanese-gpt-neox-3.6b-instruction-sft", use_fast=False)
-model = AutoModelForCausalLM.from_pretrained("rinna/japanese-gpt-neox-3.6b-instruction-sft", device_map="auto", torch_dtype=torch.float16)
+tokenizer = AutoTokenizer.from_pretrained(
+    "rinna/japanese-gpt-neox-3.6b-instruction-sft", use_fast=False
+)
+model = AutoModelForCausalLM.from_pretrained(
+    "rinna/japanese-gpt-neox-3.6b-instruction-sft",
+    device_map="auto",
+    torch_dtype=torch.float16,
+)
 model = model.to(device)
 
 
 @torch.no_grad()
 def inference_func(prompt, max_new_tokens=128, temperature=0.7):
-    token_ids = tokenizer.encode(prompt, add_special_tokens=False, return_tensors="pt")
-    output_ids = model.generate(
-        token_ids.to(model.device),
-        do_sample=True,
-        max_new_tokens=max_new_tokens,
-        temperature=temperature,
-        pad_token_id=tokenizer.pad_token_id,
-        bos_token_id=tokenizer.bos_token_id,
-        eos_token_id=tokenizer.eos_token_id
-    )
-    output = tokenizer.decode(output_ids.tolist()[0][token_ids.size(1):], skip_special_tokens=True)
-    output = output.replace("<NL>", "\n")
-    return output
+    token_ids = tokenizer.encode(prompt, add_special_tokens=False, return_tensors="pt")
+    output_ids = model.generate(
+        token_ids.to(model.device),
+        do_sample=True,
+        max_new_tokens=max_new_tokens,
+        temperature=temperature,
+        pad_token_id=tokenizer.pad_token_id,
+        bos_token_id=tokenizer.bos_token_id,
+        eos_token_id=tokenizer.eos_token_id,
+    )
+    output = tokenizer.decode(
+        output_ids.tolist()[0][token_ids.size(1) :], skip_special_tokens=True
+    )
+    output = output.replace("<NL>", "\n")
+    return output
 
 
 def make_prompt(message, chat_history, max_context_size: int = 10):
-    contexts = chat_history + [[message, ""]]
-    contexts = list(itertools.chain.from_iterable(contexts))
-    if max_context_size > 0:
-        context_size = max_context_size - 1
-    else:
-        context_size = 100000
-    contexts = contexts[-context_size:]
-    prompt = []
-    for idx, context in enumerate(reversed(contexts)):
-        if idx % 2 == 0:
-            prompt = [f"システム: {context}"] + prompt
-        else:
-            prompt = [f"ユーザー: {context}"] + prompt
-    prompt = "<NL>".join(prompt)
-    return prompt
+    contexts = chat_history + [[message, ""]]
+    contexts = list(itertools.chain.from_iterable(contexts))
+    if max_context_size > 0:
+        context_size = max_context_size - 1
+    else:
+        context_size = 100000
+    contexts = contexts[-context_size:]
+    prompt = []
+    for idx, context in enumerate(reversed(contexts)):
+        if idx % 2 == 0:
+            prompt = [f"システム: {context}"] + prompt
+        else:
+            prompt = [f"ユーザー: {context}"] + prompt
+    prompt = "<NL>".join(prompt)
+    return prompt
+
 
 def interact_func(message, chat_history, max_context_size, max_new_tokens, temperature):
-    prompt = make_prompt(message, chat_history, max_context_size)
-    print(f"prompt: {prompt}")
-    generated = inference_func(prompt, max_new_tokens, temperature)
-    print(f"generated: {generated}")
-    chat_history.append((message, generated))
-    return "", chat_history
+    prompt = make_prompt(message, chat_history, max_context_size)
+    print(f"prompt: {prompt}")
+    generated = inference_func(prompt, max_new_tokens, temperature)
+    print(f"generated: {generated}")
+    chat_history.append((message, generated))
+    return "", chat_history
+
+
+ORIGINAL_SPACE_ID = "mkshing/rinna-japanese-gpt-neox-3.6b-instruction-sft"
+SPACE_ID = os.getenv("SPACE_ID", ORIGINAL_SPACE_ID)
+SHARED_UI_WARNING = f"""# Attention - This Space doesn't work in this shared UI. You can duplicate and use it with a paid private T4 GPU.
+<center><a class="duplicate-button" style="display:inline-block" target="_blank" href="https://huggingface.co/spaces/{SPACE_ID}?duplicate=true"><img src="https://img.shields.io/badge/-Duplicate%20Space-blue?labelColor=white&style=flat&logo=data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAABAAAAAQCAYAAAAf8/9hAAAAAXNSR0IArs4c6QAAAP5JREFUOE+lk7FqAkEURY+ltunEgFXS2sZGIbXfEPdLlnxJyDdYB62sbbUKpLbVNhyYFzbrrA74YJlh9r079973psed0cvUD4A+4HoCjsA85X0Dfn/RBLBgBDxnQPfAEJgBY+A9gALA4tcbamSzS4xq4FOQAJgCDwV2CPKV8tZAJcAjMMkUe1vX+U+SMhfAJEHasQIWmXNN3abzDwHUrgcRGmYcgKe0bxrblHEB4E/pndMazNpSZGcsZdBlYJcEL9Afo75molJyM2FxmPgmgPqlWNLGfwZGG6UiyEvLzHYDmoPkDDiNm9JR9uboiONcBXrpY1qmgs21x1QwyZcpvxt9NS09PlsPAAAAAElFTkSuQmCC&logoWidth=14" alt="Duplicate Space"></a></center>
+"""
+
+if os.getenv("SYSTEM") == "spaces" and SPACE_ID != ORIGINAL_SPACE_ID:
+    SETTINGS = (
+        f'<a href="https://huggingface.co/spaces/{SPACE_ID}/settings">Settings</a>'
+    )
+else:
+    SETTINGS = "Settings"
+CUDA_NOT_AVAILABLE_WARNING = f"""# Attention - Running on CPU.
+<center>
+You can assign a GPU in the {SETTINGS} tab if you are running this on HF Spaces.
+"T4 small" is sufficient to run this demo.
+</center>
+"""
 
 
 with gr.Blocks() as demo:
-    gr.Markdown("# Chat with `rinna/japanese-gpt-neox-3.6b-instruction-sft`")
-    with gr.Accordion("Configs", open=False):
-        # max_context_size = the number of turns * 2
-        max_context_size = gr.Number(value=10, label="max_context_size", precision=0)
-        max_new_tokens = gr.Number(value=128, label="max_new_tokens", precision=0)
-        temperature = gr.Slider(0.0, 2.0, value=0.7, step=0.1, label="temperature")
-    chatbot = gr.Chatbot()
-    msg = gr.Textbox()
-    clear = gr.Button("Clear")
-    msg.submit(interact_func, [msg, chatbot, max_context_size, max_new_tokens, temperature], [msg, chatbot])
-    clear.click(lambda: None, None, chatbot, queue=False)
+    gr.Markdown("""# Chat with `rinna/japanese-gpt-neox-3.6b-instruction-sft`
+<a href=\"https://colab.research.google.com/github/mkshing/notebooks/blob/main/rinna_japanese_gpt_neox_3_6b_instruction_sft.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>
+
+This demo is a chat UI for [rinna/japanese-gpt-neox-3.6b-instruction-sft](https://huggingface.co/rinna/japanese-gpt-neox-3.6b-instruction-sft).
+""")
+    with gr.Accordion("Configs", open=False):
+        # max_context_size = the number of turns * 2
+        max_context_size = gr.Number(value=10, label="max_context_size", precision=0)
+        max_new_tokens = gr.Number(value=128, label="max_new_tokens", precision=0)
+        temperature = gr.Slider(0.0, 2.0, value=0.7, step=0.1, label="temperature")
+    chatbot = gr.Chatbot()
+    msg = gr.Textbox()
+    clear = gr.Button("Clear")
+    msg.submit(
+        interact_func,
+        [msg, chatbot, max_context_size, max_new_tokens, temperature],
+        [msg, chatbot],
+    )
+    clear.click(lambda: None, None, chatbot, queue=False)
 
 if __name__ == "__main__":
-    demo.launch(debug=True)
+    demo.launch(debug=True)
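For reference, `make_prompt` flattens the chat history into a `<NL>`-joined sequence of `ユーザー:`/`システム:` turns ending with an empty `システム:` turn for the model to complete, and `inference_func` decodes only the newly generated tokens and maps `<NL>` back to real newlines. A minimal sketch of the resulting prompt string (the message and history below are made-up examples, not part of the commit):

# Hypothetical inputs, purely to illustrate the prompt layout
chat_history = [["こんにちは", "こんにちは。ご用件は何でしょうか?"]]
message = "日本の首都はどこ?"
prompt = make_prompt(message, chat_history)
print(prompt)
# -> ユーザー: こんにちは<NL>システム: こんにちは。ご用件は何でしょうか?<NL>ユーザー: 日本の首都はどこ?<NL>システム: 
# inference_func(prompt) would then generate the reply after the trailing "システム: ".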