phuongnv committed
Commit 2f97117
1 Parent(s): 372e8fa

Update app.py

Files changed (1)
  1. app.py +149 -150
app.py CHANGED
@@ -1,151 +1,150 @@
- import gradio as gr
- import spaces
- import selfies as sf
- from llama_cpp import Llama
- from llama_cpp_agent import LlamaCppAgent
- from llama_cpp_agent.providers import LlamaCppPythonProvider
- from llama_cpp_agent.chat_history import BasicChatHistory
- from llama_cpp_agent.chat_history.messages import Roles
-
- css = """
- .message-row {
-     justify-content: space-evenly !important;
- }
- .message-bubble-border {
-     border-radius: 6px !important;
- }
- .dark.message-bubble-border {
-     border-color: #343140 !important;
- }
- .dark.user {
-     background: #1e1c26 !important;
- }
- .dark.assistant.dark, .dark.pending.dark {
-     background: #16141c !important;
- }
- """
-
- def get_messages_formatter_type(model_name):
-     from llama_cpp_agent import MessagesFormatterType
-     return MessagesFormatterType.CHATML
-
- @spaces.GPU(duration=120)
- def respond(
-     message,
-     history: list[tuple[str, str]],
-     max_tokens,
-     temperature,
-     top_p,
-     top_k,
-     model,
- ):
-     chat_template = get_messages_formatter_type(model)
-
-     llm = Llama(model_path="model.guff")
-     provider = LlamaCppPythonProvider(llm)
-
-     agent = LlamaCppAgent(
-         provider,
-         predefined_messages_formatter_type=chat_template,
-         debug_output=True
-     )
-
-     settings = provider.get_provider_default_settings()
-     settings.temperature = temperature
-     settings.top_k = top_k
-     settings.top_p = top_p
-     settings.max_tokens = max_tokens
-     settings.stream = True
-     settings.num_beams = 10  # Enable beam search with 10 beams
-
-     messages = BasicChatHistory()
-
-     for msn in history:
-         user = {
-             'role': Roles.user,
-             'content': msn[0]
-         }
-         assistant = {
-             'role': Roles.assistant,
-             'content': msn[1]
-         }
-         messages.add_message(user)
-         messages.add_message(assistant)
-
-     stream = agent.get_chat_response(
-         message,
-         llm_sampling_settings=settings,
-         chat_history=messages,
-         returns_streaming_generator=True,
-         print_output=False
-     )
-
-     outputs = set()  # Use a set to store unique outputs
-     unique_responses = []
-     prompt_length = len(message)  # Assuming `message` is the prompt
-
-     for index, output in enumerate(stream, start=1):
-         if output not in outputs:
-             outputs.add(output)
-
-             # Post-process the output
-             output1 = output[prompt_length:]
-             first_inst_index = output1.find("[/INST]")
-             second_inst_index = output1.find("[/IN", first_inst_index + len("[/INST]") + 1)
-             predicted_selfies = output1[first_inst_index + len("[/INST]") : second_inst_index].strip()
-             predicted_smiles = sf.decoder(predicted_selfies)
-             unique_responses.append(f"Predict {index}: {predicted_smiles}")
-             yield "\n".join(unique_responses)
-
-
- PLACEHOLDER = """
- <div class="message-bubble-border" style="display:flex; max-width: 600px; border-radius: 8px; box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1); backdrop-filter: blur(10px);">
-     <div style="padding: .5rem 1.5rem;">
-         <h2 style="text-align: left; font-size: 1.5rem; font-weight: 700; margin-bottom: 0.5rem;">Chat with CausalLM 35B long (Q6K GGUF)</h2>
-         <p style="text-align: left; font-size: 16px; line-height: 1.5; margin-bottom: 15px;">You can try different models from CausalLM here.<br>Running on NVIDIA A100-SXM4-80GB MIG 3g.40gb with Zero-GPU from Hugging Face.</p>
-     </div>
- </div>
- """
-
- demo = gr.ChatInterface(
-     respond,
-     additional_inputs=[
-         gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max tokens"),
-         gr.Slider(minimum=0.1, maximum=4.0, value=1.0, step=0.1, label="Temperature"),
-         gr.Slider(
-             minimum=0.1,
-             maximum=1.0,
-             value=1.0,
-             step=0.05,
-             label="Top-p",
-         ),
-         gr.Slider(
-             minimum=0,
-             maximum=100,
-             value=50,
-             step=1,
-             label="Top-k",
-         )
-     ],
-     theme=gr.themes.Soft(primary_hue="violet", secondary_hue="violet", neutral_hue="gray", font=[gr.themes.GoogleFont("Exo"), "ui-sans-serif", "system-ui", "sans-serif"]).set(
-         body_background_fill_dark="#16141c",
-         block_background_fill_dark="#16141c",
-         block_border_width="1px",
-         block_title_background_fill_dark="#1e1c26",
-         input_background_fill_dark="#292733",
-         button_secondary_background_fill_dark="#24212b",
-         border_color_primary_dark="#343140",
-         background_fill_secondary_dark="#16141c",
-         color_accent_soft_dark="transparent"
-     ),
-     css=css,
-     retry_btn="Retry",
-     undo_btn="Undo",
-     clear_btn="Clear",
-     submit_btn="Send",
-     description="Retrosynthesis chatbot",
-     chatbot=gr.Chatbot(scale=1, placeholder=PLACEHOLDER)
- )
-
- if __name__ == "__main__":
-     demo.launch()
+ import gradio as gr
+ import spaces
+ import selfies as sf
+ from llama_cpp import Llama
+ from llama_cpp_agent import LlamaCppAgent
+ from llama_cpp_agent.providers import LlamaCppPythonProvider
+ from llama_cpp_agent.chat_history import BasicChatHistory
+ from llama_cpp_agent.chat_history.messages import Roles
+
+ css = """
+ .message-row {
+     justify-content: space-evenly !important;
+ }
+ .message-bubble-border {
+     border-radius: 6px !important;
+ }
+ .dark.message-bubble-border {
+     border-color: #343140 !important;
+ }
+ .dark.user {
+     background: #1e1c26 !important;
+ }
+ .dark.assistant.dark, .dark.pending.dark {
+     background: #16141c !important;
+ }
+ """
+
+ def get_messages_formatter_type(model_name):
+     from llama_cpp_agent import MessagesFormatterType
+     return MessagesFormatterType.CHATML
+
+ @spaces.GPU(duration=120)
+ def respond(
+     message,
+     history: list[tuple[str, str]],
+     max_tokens,
+     temperature,
+     top_p,
+     top_k,
+     model,
+ ):
+     chat_template = get_messages_formatter_type(model)
+
+     llm = Llama(model_path="model.guff")
+     provider = LlamaCppPythonProvider(llm)
+
+     agent = LlamaCppAgent(
+         provider,
+         predefined_messages_formatter_type=chat_template,
+         debug_output=True
+     )
+
+     settings = provider.get_provider_default_settings()
+     settings.temperature = temperature
+     settings.top_k = top_k
+     settings.top_p = top_p
+     settings.max_tokens = max_tokens
+     settings.stream = True
+     settings.num_beams = 10  # Enable beam search with 10 beams
+
+     messages = BasicChatHistory()
+
+     for msn in history:
+         user = {
+             'role': Roles.user,
+             'content': msn[0]
+         }
+         assistant = {
+             'role': Roles.assistant,
+             'content': msn[1]
+         }
+         messages.add_message(user)
+         messages.add_message(assistant)
+
+     stream = agent.get_chat_response(
+         message,
+         llm_sampling_settings=settings,
+         chat_history=messages,
+         returns_streaming_generator=True,
+         print_output=False
+     )
+
+     outputs = set()  # Use a set to store unique outputs
+     unique_responses = []
+     prompt_length = len(message)  # Assuming `message` is the prompt
+
+     for index, output in enumerate(stream, start=1):
+         if output not in outputs:
+             outputs.add(output)
+
+             # Post-process the output
+             output1 = output[prompt_length:]
+             first_inst_index = output1.find("[/INST]")
+             second_inst_index = output1.find("[/IN", first_inst_index + len("[/INST]") + 1)
+             predicted_selfies = output1[first_inst_index + len("[/INST]") : second_inst_index].strip()
+             predicted_smiles = sf.decoder(predicted_selfies)
+             unique_responses.append(f"Predict {index}: {predicted_smiles}")
+             yield "\n".join(unique_responses)
+
+
+ PLACEHOLDER = """
+ <div class="message-bubble-border" style="display:flex; max-width: 600px; border-radius: 8px; box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1); backdrop-filter: blur(10px);">
+     <div style="padding: .5rem 1.5rem;">
+         <h2 style="text-align: left; font-size: 1.5rem; font-weight: 700; margin-bottom: 0.5rem;">Retrosynthesis Chatbot</h2>
+     </div>
+ </div>
+ """
+
+ demo = gr.ChatInterface(
+     respond,
+     additional_inputs=[
+         gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max tokens"),
+         gr.Slider(minimum=0.1, maximum=4.0, value=1.0, step=0.1, label="Temperature"),
+         gr.Slider(
+             minimum=0.1,
+             maximum=1.0,
+             value=1.0,
+             step=0.05,
+             label="Top-p",
+         ),
+         gr.Slider(
+             minimum=0,
+             maximum=100,
+             value=50,
+             step=1,
+             label="Top-k",
+         )
+     ],
+     theme=gr.themes.Soft(primary_hue="violet", secondary_hue="violet", neutral_hue="gray", font=[gr.themes.GoogleFont("Exo"), "ui-sans-serif", "system-ui", "sans-serif"]).set(
+         body_background_fill_dark="#16141c",
+         block_background_fill_dark="#16141c",
+         block_border_width="1px",
+         block_title_background_fill_dark="#1e1c26",
+         input_background_fill_dark="#292733",
+         button_secondary_background_fill_dark="#24212b",
+         border_color_primary_dark="#343140",
+         background_fill_secondary_dark="#16141c",
+         color_accent_soft_dark="transparent"
+     ),
+     css=css,
+     retry_btn="Retry",
+     undo_btn="Undo",
+     clear_btn="Clear",
+     submit_btn="Send",
+     description="Retrosynthesis chatbot",
+     chatbot=gr.Chatbot(scale=1, placeholder=PLACEHOLDER)
+ )
+
+ if __name__ == "__main__":
+     demo.launch()
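
For reference, a minimal standalone sketch of the post-processing inside `respond()` above: it drops the echoed prompt, takes the text between the end of the first `[/INST]` marker and the start of the next one, and decodes that SELFIES string to SMILES with `selfies.decoder`. This assumes the raw stream echoes the prompt and wraps the answer in `[/INST]` markers, as the parsing code implies; the strings below are illustrative placeholders, not real outputs from this model.

```python
# Sketch of the [/INST]-span extraction and SELFIES -> SMILES decoding.
import selfies as sf

prompt = "predict the reactants: "                 # hypothetical prompt
raw_output = prompt + "[/INST][C][C][O][/INST]"    # hypothetical raw stream chunk

body = raw_output[len(prompt):]                    # strip the echoed prompt
start = body.find("[/INST]") + len("[/INST]")      # end of the opening marker
end = body.find("[/IN", start)                     # start of the closing marker
predicted_selfies = body[start:end].strip()        # "[C][C][O]"

predicted_smiles = sf.decoder(predicted_selfies)
print(predicted_smiles)                            # CCO (ethanol)
```

SELFIES is a convenient intermediate representation here because every syntactically valid SELFIES string decodes to a valid molecule, so well-formed model output can always be converted into a SMILES string.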