# NOTE: page-scrape residue removed. Original header read:
#   Spaces: Sleeping / File size: 4,576 Bytes / commits 2f97117, cdc6d21
# (the run of numbers 1..150 was the file viewer's line-number gutter).
import gradio as gr
import spaces
import selfies as sf
from llama_cpp import Llama
from llama_cpp_agent import LlamaCppAgent
from llama_cpp_agent.providers import LlamaCppPythonProvider
from llama_cpp_agent.chat_history import BasicChatHistory
from llama_cpp_agent.chat_history.messages import Roles
# Custom CSS injected into the Gradio app: evens out message-row spacing,
# rounds the chat bubbles, and darkens bubble/user/assistant colors so the
# chat matches the dark theme configured on `demo` below.
css = """
.message-row {
justify-content: space-evenly !important;
}
.message-bubble-border {
border-radius: 6px !important;
}
.dark.message-bubble-border {
border-color: #343140 !important;
}
.dark.user {
background: #1e1c26 !important;
}
.dark.assistant.dark, .dark.pending.dark {
background: #16141c !important;
}
"""
def get_messages_formatter_type(model_name):
    """Return the chat-prompt formatter to use for *model_name*.

    Every model served by this app uses the ChatML template, so the
    argument is accepted for interface compatibility but not consulted.
    The import is kept local so the module can be inspected without
    llama_cpp_agent installed.
    """
    from llama_cpp_agent import MessagesFormatterType

    formatter = MessagesFormatterType.CHATML
    return formatter
@spaces.GPU(duration=120)
def respond(
    message,
    history: list[tuple[str, str]],
    max_tokens,
    temperature,
    top_p,
    top_k,
    model=None,
):
    """Stream deduplicated retrosynthesis predictions for *message*.

    Builds a llama.cpp agent, replays *history* into its chat memory, then
    streams the model's response.  Each unique streamed chunk is parsed for
    the SELFIES text between "[/INST]" markers, decoded to SMILES, and the
    growing prediction list is yielded as one newline-joined string.

    Parameters
    ----------
    message : str
        Prompt describing the target molecule.
    history : list[tuple[str, str]]
        Previous (user, assistant) turns from the chat UI.
    max_tokens, temperature, top_p, top_k :
        Sampling controls forwarded to the provider settings.
    model : str, optional
        Model name; only used to select the chat template.  Defaults to
        None because the ChatInterface below wires four sliders but no
        model selector (the original code would raise TypeError here).
    """
    chat_template = get_messages_formatter_type(model)
    # NOTE(review): GGUF checkpoints conventionally end in ".gguf"; confirm
    # the bundled file really is named "model.guff" before renaming.
    llm = Llama(model_path="model.guff")
    provider = LlamaCppPythonProvider(llm)
    agent = LlamaCppAgent(
        provider,
        predefined_messages_formatter_type=chat_template,
        debug_output=True,
    )
    settings = provider.get_provider_default_settings()
    settings.temperature = temperature
    settings.top_k = top_k
    settings.top_p = top_p
    settings.max_tokens = max_tokens
    settings.stream = True
    # NOTE(review): llama.cpp sampling has no beam search; this attribute is
    # most likely ignored by the provider -- verify before relying on it.
    settings.num_beams = 10
    messages = BasicChatHistory()
    for user_text, assistant_text in history:
        messages.add_message({'role': Roles.user, 'content': user_text})
        messages.add_message({'role': Roles.assistant, 'content': assistant_text})
    stream = agent.get_chat_response(
        message,
        llm_sampling_settings=settings,
        chat_history=messages,
        returns_streaming_generator=True,
        print_output=False,
    )
    seen = set()  # raw chunks already processed, so duplicates are skipped
    unique_responses = []
    prompt_length = len(message)  # assumes the prompt is echoed verbatim at the start
    for index, output in enumerate(stream, start=1):
        if output in seen:
            continue
        seen.add(output)
        # Strip the echoed prompt, then take the text between the first
        # "[/INST]" marker and the start of the next "[/IN" marker.
        candidate = output[prompt_length:]
        start = candidate.find("[/INST]")
        if start == -1:
            # Bug fix: the original used find()'s -1 directly, yielding a
            # bogus start offset of 6; fall back to the whole string.
            selfies_span = candidate
        else:
            begin = start + len("[/INST]")
            end = candidate.find("[/IN", begin + 1)
            if end == -1:
                # Bug fix: an end index of -1 silently dropped the final
                # character; take everything to the end instead.
                end = len(candidate)
            selfies_span = candidate[begin:end]
        predicted_selfies = selfies_span.strip()
        predicted_smiles = sf.decoder(predicted_selfies)
        unique_responses.append(f"Predict {index}: {predicted_smiles}")
        yield "\n".join(unique_responses)
# HTML shown in the empty chatbot panel before the first message; reuses the
# .message-bubble-border class styled by the `css` string above.
PLACEHOLDER = """
<div class="message-bubble-border" style="display:flex; max-width: 600px; border-radius: 8px; box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1); backdrop-filter: blur(10px);">
<div style="padding: .5rem 1.5rem;">
<h2 style="text-align: left; font-size: 1.5rem; font-weight: 700; margin-bottom: 0.5rem;">Retrosynthesis Chatbot</h2>
</div>
</div>
"""
# Chat UI wiring: each slider value is passed positionally to respond() after
# (message, history), in the order listed here.
# NOTE(review): respond() declares five extra parameters (max_tokens,
# temperature, top_p, top_k, model) but only four additional inputs are wired
# here -- confirm the `model` argument is actually supplied at call time.
# NOTE(review): retry_btn/undo_btn/clear_btn/submit_btn are gradio 4.x
# ChatInterface kwargs that were removed in gradio 5 -- verify the pinned
# gradio version before upgrading.
demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=1.0, step=0.1, label="Temperature"),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=1.0,
            step=0.05,
            label="Top-p",
        ),
        gr.Slider(
            minimum=0,
            maximum=100,
            value=50,
            step=1,
            label="Top-k",
        )
    ],
    # Dark violet theme; the hex values match the colors used in `css`.
    theme=gr.themes.Soft(primary_hue="violet", secondary_hue="violet", neutral_hue="gray", font=[gr.themes.GoogleFont("Exo"), "ui-sans-serif", "system-ui", "sans-serif"]).set(
        body_background_fill_dark="#16141c",
        block_background_fill_dark="#16141c",
        block_border_width="1px",
        block_title_background_fill_dark="#1e1c26",
        input_background_fill_dark="#292733",
        button_secondary_background_fill_dark="#24212b",
        border_color_primary_dark="#343140",
        background_fill_secondary_dark="#16141c",
        color_accent_soft_dark="transparent"
    ),
    css=css,
    retry_btn="Retry",
    undo_btn="Undo",
    clear_btn="Clear",
    submit_btn="Send",
    description="Retrosynthesis chatbot",
    chatbot=gr.Chatbot(scale=1, placeholder=PLACEHOLDER)
)
# Script entry point: start the Gradio server when run directly.
# Bug fix: removed the stray trailing "|" (page-scrape artifact) that made
# this line a syntax error.
if __name__ == "__main__":
    demo.launch()