phuongnv's picture
Update app.py
a890222 verified
raw history blame
No virus
4.58 kB
import gradio as gr
import spaces
import selfies as sf
from llama_cpp import Llama
from llama_cpp_agent import LlamaCppAgent
from llama_cpp_agent.providers import LlamaCppPythonProvider
from llama_cpp_agent.chat_history import BasicChatHistory
from llama_cpp_agent.chat_history.messages import Roles
# Custom stylesheet handed to gr.ChatInterface (css=css): spaces message rows
# evenly, rounds the message-bubble borders, and overrides the dark-mode
# border and background colors of user / assistant / pending bubbles.
css = """
.message-row {
justify-content: space-evenly !important;
}
.message-bubble-border {
border-radius: 6px !important;
}
.dark.message-bubble-border {
border-color: #343140 !important;
}
.dark.user {
background: #1e1c26 !important;
}
.dark.assistant.dark, .dark.pending.dark {
background: #16141c !important;
}
"""
def get_messages_formatter_type(model_name):
    """Return the chat-prompt formatter type to use for a model.

    ``model_name`` is accepted for interface compatibility but is not
    consulted: the ChatML formatter is returned for every model.
    """
    # Resolved at call time rather than at module import.
    import llama_cpp_agent

    return llama_cpp_agent.MessagesFormatterType.CHATML
def _extract_selfies(generated_text):
    """Return the SELFIES portion of *generated_text*.

    When the text contains a ``[/INST]`` marker, only what follows the first
    marker is kept. Everything from the next ``[/IN`` prefix onward (a
    complete or cut-off marker) is dropped. Text without any marker is
    returned whole. Surrounding whitespace is stripped in all cases.
    """
    marker = "[/INST]"
    marker_at = generated_text.find(marker)
    # str.find returns -1 when the marker is absent; start at the beginning then.
    start = 0 if marker_at == -1 else marker_at + len(marker)
    end = generated_text.find("[/IN", start)
    if end == -1:
        end = len(generated_text)
    return generated_text[start:end].strip()


@spaces.GPU(duration=120)
def respond(
    message,
    history: list[tuple[str, str]],
    max_tokens,
    temperature,
    top_p,
    top_k,
    model=None,
):
    """Stream the model's prediction for *message*, decoded from SELFIES to SMILES.

    Args:
        message: The latest user message.
        history: Earlier ``(user, assistant)`` message pairs of the chat.
        max_tokens: Upper bound on the number of generated tokens.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling threshold.
        top_k: Top-k sampling cut-off.
        model: Forwarded to ``get_messages_formatter_type``. Optional because
            the chat interface defined in this file supplies only the four
            sampling inputs above.

    Yields:
        The text to display so far, ``"Predict 1: <SMILES>"``. While the
        generated SELFIES string is still incomplete or cannot be decoded,
        the raw generated text is shown in place of the SMILES.
    """
    chat_template = get_messages_formatter_type(model)

    # NOTE(review): ".guff" looks like a typo for ".gguf" - confirm against the
    # actual model file name before changing it.
    llm = Llama(model_path="model.guff")
    provider = LlamaCppPythonProvider(llm)
    agent = LlamaCppAgent(
        provider,
        predefined_messages_formatter_type=chat_template,
        debug_output=True,
    )

    settings = provider.get_provider_default_settings()
    settings.temperature = temperature
    settings.top_k = top_k
    settings.top_p = top_p
    settings.max_tokens = max_tokens
    settings.stream = True
    # Beam-search options (num_beams / num_return_sequences) are not llama.cpp
    # sampling settings, so they are deliberately not set on `settings`.

    messages = BasicChatHistory()
    for user_text, assistant_text in history:
        messages.add_message({'role': Roles.user, 'content': user_text})
        messages.add_message({'role': Roles.assistant, 'content': assistant_text})

    stream = agent.get_chat_response(
        message,
        llm_sampling_settings=settings,
        chat_history=messages,
        returns_streaming_generator=True,
        print_output=False,
    )

    # The generator yields successive pieces of the completion (without the
    # prompt), so they are accumulated and the running text is re-parsed on
    # every piece instead of treating each piece as a full answer.
    generated = ""
    for piece in stream:
        generated += piece
        candidate = _extract_selfies(generated)
        try:
            smiles = sf.decoder(candidate)
        except sf.DecoderError:
            # Typically a half-written SELFIES token in the middle of streaming.
            smiles = None
        yield f"Predict 1: {smiles or candidate}"
# HTML snippet given to gr.Chatbot as its placeholder (placeholder=PLACEHOLDER):
# a rounded, shadowed card containing the "Retrosynthesis Chatbot" heading.
PLACEHOLDER = """
<div class="message-bubble-border" style="display:flex; max-width: 600px; border-radius: 8px; box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1); backdrop-filter: blur(10px);">
<div style="padding: .5rem 1.5rem;">
<h2 style="text-align: left; font-size: 1.5rem; font-weight: 800; margin-bottom: 0.5rem;">Retrosynthesis Chatbot</h2>
</div>
</div>
"""
# Sampling controls rendered as additional inputs of the chat interface; their
# order matches the max_tokens, temperature, top_p, top_k parameters of `respond`.
_sampling_controls = [
    gr.Slider(label="Max tokens", minimum=1, maximum=2048, step=1, value=512),
    gr.Slider(label="Temperature", minimum=0.1, maximum=4.0, step=0.1, value=1.0),
    gr.Slider(label="Top-p", minimum=0.1, maximum=1.0, step=0.05, value=1.0),
    gr.Slider(label="Top-k", minimum=0, maximum=100, step=1, value=50),
]

# Violet "Soft" theme using the Exo web font, with dark-mode color overrides.
_font_stack = [
    gr.themes.GoogleFont("Exo"),
    "ui-sans-serif",
    "system-ui",
    "sans-serif",
]
_theme = gr.themes.Soft(
    primary_hue="violet",
    secondary_hue="violet",
    neutral_hue="gray",
    font=_font_stack,
).set(
    body_background_fill_dark="#16141c",
    block_background_fill_dark="#16141c",
    block_border_width="1px",
    block_title_background_fill_dark="#1e1c26",
    input_background_fill_dark="#292733",
    button_secondary_background_fill_dark="#24212b",
    border_color_primary_dark="#343140",
    background_fill_secondary_dark="#16141c",
    color_accent_soft_dark="transparent",
)

# Chat display component showing PLACEHOLDER as its placeholder.
_chat_window = gr.Chatbot(scale=1, placeholder=PLACEHOLDER)

# The chat application object; `respond` produces the replies.
demo = gr.ChatInterface(
    respond,
    additional_inputs=_sampling_controls,
    theme=_theme,
    css=css,
    retry_btn="Retry",
    undo_btn="Undo",
    clear_btn="Clear",
    submit_btn="Send",
    description="Retrosynthesis chatbot",
    chatbot=_chat_window,
)

# Launch the app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()