aixsatoshi committed
Commit 571bf3a
1 Parent(s): 2ac2a22

Update app.py

Files changed (1):
  1. app.py +124 -122
app.py CHANGED
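Note: the hunk below starts at line 9 of app.py, so the file's leading imports are not shown (the hunk header ends mid-import with `from llama_cpp_agent.chat_history.messages import Roles`). Judging from the identifiers the code uses, the preceding lines are presumably close to the following sketch; treat the exact module paths as an inference, not the file's actual header:

# Presumed imports above the hunk (not shown in this diff); inferred from the
# identifiers used below, so the exact module paths are an assumption.
import spaces
from llama_cpp import Llama
from llama_cpp_agent import LlamaCppAgent, MessagesFormatterType
from llama_cpp_agent.providers import LlamaCppPythonProvider
from llama_cpp_agent.chat_history import BasicChatHistory
from llama_cpp_agent.chat_history.messages import Roles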
@@ -9,142 +9,144 @@ from llama_cpp_agent.chat_history.messages import Roles
 import gradio as gr
 from huggingface_hub import hf_hub_download
 
 hf_hub_download(
-    repo_id="bartowski/gemma-2-9b-it-GGUF",
-    filename="gemma-2-9b-it-Q5_K_M.gguf",
-    local_dir="./models"
 )
 
-
-
 hf_hub_download(
-    repo_id="bartowski/gemma-2-27b-it-GGUF",
-    filename="gemma-2-27b-it-Q5_K_M.gguf",
-    local_dir="./models"
 )
 
-
 @spaces.GPU(duration=120)
 def respond(
-    message,
-    history: list[tuple[str, str]],
-    model,
-    system_message,
-    max_tokens,
-    temperature,
-    top_p,
-    top_k,
-    repeat_penalty,
 ):
-    chat_template = MessagesFormatterType.GEMMA_2
 
-    llm = Llama(
-        model_path=f"models/{model}",
-        flash_attn=True,
-        n_gpu_layers=81,
-        n_batch=1024,
-        n_ctx=8192,
-    )
-    provider = LlamaCppPythonProvider(llm)
 
-    agent = LlamaCppAgent(
-        provider,
-        system_prompt=f"{system_message}",
-        predefined_messages_formatter_type=chat_template,
-        debug_output=True
-    )
-
-    settings = provider.get_provider_default_settings()
-    settings.temperature = temperature
-    settings.top_k = top_k
-    settings.top_p = top_p
-    settings.max_tokens = max_tokens
-    settings.repeat_penalty = repeat_penalty
-    settings.stream = True
 
-    messages = BasicChatHistory()
 
-    for msn in history:
-        user = {
-            'role': Roles.user,
-            'content': msn[0]
-        }
-        assistant = {
-            'role': Roles.assistant,
-            'content': msn[1]
-        }
-        messages.add_message(user)
-        messages.add_message(assistant)
-
-    stream = agent.get_chat_response(
-        message,
-        llm_sampling_settings=settings,
-        chat_history=messages,
-        returns_streaming_generator=True,
-        print_output=False
-    )
-
-    outputs = ""
-    for output in stream:
-        outputs += output
-        yield outputs
 
-description = """<p align="center">Defaults to 27B it (you can switch to 9b it from additional inputs)</p>
-<p><center>
-<a href="https://huggingface.co/google/gemma-2-27b-it" target="_blank">[27B it Model]</a>
-<a href="https://huggingface.co/google/gemma-2-9b-it" target="_blank">[9B it Model]</a>
-<a href="https://huggingface.co/bartowski/gemma-2-27b-it-GGUF" target="_blank">[27B it Model GGUF]</a>
-<a href="https://huggingface.co/bartowski/gemma-2-9b-it-GGUF" target="_blank">[9B it Model GGUF]</a>
-</center></p>
-"""
 
-demo = gr.ChatInterface(
-    respond,
-    additional_inputs=[
-        gr.Dropdown([
-                'gemma-2-9b-it-Q5_K_M.gguf',
-                'gemma-2-27b-it-Q5_K_M.gguf'
-            ],
-            value="gemma-2-27b-it-Q5_K_M.gguf",
-            label="Model"
-        ),
-        gr.Textbox(value="You are a helpful assistant.", label="System message"),
-        gr.Slider(minimum=1, maximum=4096, value=2048, step=1, label="Max tokens"),
-        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
-        gr.Slider(
-            minimum=0.1,
-            maximum=1.0,
-            value=0.95,
-            step=0.05,
-            label="Top-p",
-        ),
-        gr.Slider(
-            minimum=0,
-            maximum=100,
-            value=40,
-            step=1,
-            label="Top-k",
-        ),
-        gr.Slider(
-            minimum=0.0,
-            maximum=2.0,
-            value=1.1,
-            step=0.1,
-            label="Repetition penalty",
-        ),
-    ],
-    retry_btn="Retry",
-    undo_btn="Undo",
-    clear_btn="Clear",
-    submit_btn="Send",
-    title="Chat with Gemma 2 using llama.cpp",
-    description=description,
-    chatbot=gr.Chatbot(
-        scale=1,
-        likeable=False,
-        show_copy_button=True
-    )
-)
 
 if __name__ == "__main__":
-    demo.launch()
 
 import gradio as gr
 from huggingface_hub import hf_hub_download
 
+# Download the models
 hf_hub_download(
+    repo_id="bartowski/gemma-2-9b-it-GGUF",
+    filename="gemma-2-9b-it-Q5_K_M.gguf",
+    local_dir="./models"
 )
 
 hf_hub_download(
+    repo_id="bartowski/gemma-2-27b-it-GGUF",
+    filename="gemma-2-27b-it-Q5_K_M.gguf",
+    local_dir="./models"
 )
 
+# Inference function
 @spaces.GPU(duration=120)
 def respond(
+    message,
+    history: list[tuple[str, str]],
+    model,
+    system_message,
+    max_tokens,
+    temperature,
+    top_p,
+    top_k,
+    repeat_penalty,
 ):
+    chat_template = MessagesFormatterType.GEMMA_2
 
+    llm = Llama(
+        model_path=f"models/{model}",
+        flash_attn=True,
+        n_gpu_layers=81,
+        n_batch=1024,
+        n_ctx=8192,
+    )
+    provider = LlamaCppPythonProvider(llm)
 
+    agent = LlamaCppAgent(
+        provider,
+        system_prompt=f"{system_message}",
+        predefined_messages_formatter_type=chat_template,
+        debug_output=True
+    )
 
+    settings = provider.get_provider_default_settings()
+    settings.temperature = temperature
+    settings.top_k = top_k
+    settings.top_p = top_p
+    settings.max_tokens = max_tokens
+    settings.repeat_penalty = repeat_penalty
+    settings.stream = True
 
+    messages = BasicChatHistory()
 
+    for msn in history:
+        user = {
+            'role': Roles.user,
+            'content': msn[0]
+        }
+        assistant = {
+            'role': Roles.assistant,
+            'content': msn[1]
+        }
+        messages.add_message(user)
+        messages.add_message(assistant)
 
+    stream = agent.get_chat_response(
+        message,
+        llm_sampling_settings=settings,
+        chat_history=messages,
+        returns_streaming_generator=True,
+        print_output=False
+    )
+
+    outputs = ""
+    for output in stream:
+        outputs += output
+        yield outputs
+
+# Create the Gradio interface
+def create_interface(model_name, description):
+    return gr.ChatInterface(
+        respond,
+        additional_inputs=[
+            gr.Textbox(value=model_name, label="Model", interactive=False),
+            gr.Textbox(value="You are a helpful assistant.", label="System message"),
+            gr.Slider(minimum=1, maximum=4096, value=2048, step=1, label="Max tokens"),
+            gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
+            gr.Slider(
+                minimum=0.1,
+                maximum=1.0,
+                value=0.95,
+                step=0.05,
+                label="Top-p",
+            ),
+            gr.Slider(
+                minimum=0,
+                maximum=100,
+                value=40,
+                step=1,
+                label="Top-k",
+            ),
+            gr.Slider(
+                minimum=0.0,
+                maximum=2.0,
+                value=1.1,
+                step=0.1,
+                label="Repetition penalty",
+            ),
+        ],
+        retry_btn="Retry",
+        undo_btn="Undo",
+        clear_btn="Clear",
+        submit_btn="Send",
+        title=f"Chat with Gemma 2 using llama.cpp - {model_name}",
+        description=description,
+        chatbot=gr.Chatbot(
+            scale=1,
+            likeable=False,
+            show_copy_button=True
+        )
+    )
+
+# Interfaces for each model
+description_9b = """<p align="center">Gemma-2 9B it Model</p>"""
+description_27b = """<p align="center">Gemma-2 27B it Model</p>"""
+
+interface_9b = create_interface('gemma-2-9b-it-Q5_K_M.gguf', description_9b)
+interface_27b = create_interface('gemma-2-27b-it-Q5_K_M.gguf', description_27b)
+
+# Display the two interfaces side by side with Gradio Blocks
+with gr.Blocks() as demo:
+    # gr.Markdown("# Compare Gemma-2 9B and 27B Models")
+    with gr.Row():
+        with gr.Column():
+            interface_9b.render()
+        with gr.Column():
+            interface_27b.render()
 
 if __name__ == "__main__":
+    demo.launch()
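For reference, a minimal sketch of driving the new respond() generator directly, the way gr.ChatInterface does while streaming. This is not part of the commit; the argument values are illustrative assumptions, and the positional order follows the respond() signature above:

# Sketch only (not part of the commit): exercise respond() outside Gradio.
# Requires the models downloaded above and the Spaces GPU runtime.
for partial in respond(
    "Hello!",                        # message
    [],                              # history: no prior (user, assistant) turns
    "gemma-2-9b-it-Q5_K_M.gguf",     # model file under ./models (assumed choice)
    "You are a helpful assistant.",  # system_message
    512,                             # max_tokens
    0.7,                             # temperature
    0.95,                            # top_p
    40,                              # top_k
    1.1,                             # repeat_penalty
):
    print(partial)  # each yield is the cumulative output so far, not a delta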