# LongLe3102000's picture
# Update app.py
# 62720f5 verified
# raw
# history blame
# No virus
# 2.78 kB
from functools import lru_cache

import gradio as gr
from llama_cpp import Llama
# Custom CSS handed to gr.Interface(css=...) when the demo is built.
# NOTE(review): selectors such as .message-row, .message-bubble-border, .user,
# .assistant and .pending look like chatbot markup; this app renders a gr.JSON
# output, so these rules may not match anything — confirm.
# NOTE(review): compound selectors like `.dark.user` only match an element that
# carries BOTH classes; if `.dark` lives on an ancestor element, the descendant
# form `.dark .user` may have been intended — confirm.
css = """
.message-row {
justify-content: space-evenly !important;
}
.message-bubble-border {
border-radius: 6px !important;
}
.dark.message-bubble-border {
border-color: #343140 !important;
}
.dark.user {
background: #1e1c26 !important;
}
.dark.assistant.dark, .dark.pending.dark {
background: #16141c !important;
}
"""
# Path of the GGUF weights file loaded by llama-cpp-python.
MODEL_PATH = "model.gguf"

# Marker after which the model's prediction starts.
_INST_END = "[/INST]"
# Prefix that terminates the prediction. Presumably truncated on purpose so a
# marker cut short by the token limit still ends the prediction — confirm.
_NEXT_INST_PREFIX = "[/IN"


@lru_cache(maxsize=1)
def _load_model(model_path=MODEL_PATH):
    """Load the GGUF model once and reuse the instance for later requests.

    lru_cache does not cache exceptions, so a failed load (e.g. a missing file)
    is retried on the next request instead of being remembered.
    """
    return Llama(model_path=model_path)


def _extract_prediction(prompt, completion):
    """Return the text between the first "[/INST]" and the next "[/IN".

    The search runs over ``prompt + completion`` because the opening marker is
    presumably at the end of the prompt. If no closing marker follows, the
    prediction runs to the end of the text. If the opening marker is absent
    altogether, the stripped completion is returned as-is.
    """
    text = prompt + completion
    start = text.find(_INST_END)
    if start == -1:
        return completion.strip()
    start += len(_INST_END)
    end = text.find(_NEXT_INST_PREFIX, start)
    if end == -1:
        # Without this, slicing to -1 would silently drop the last character.
        end = len(text)
    return text[start:end].strip()


def respond(encoded_smiles, max_tokens, temperature, top_p, top_k):
    """Generate a retrosynthesis prediction for one input.

    Args:
        encoded_smiles: Prompt text sent verbatim to the model.
        max_tokens: Maximum number of tokens to generate.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling probability mass.
        top_k: Number of highest-probability tokens kept when sampling.

    Returns:
        ``{'input': encoded_smiles, 'predict': <text>}`` on success, or
        ``{'error': <message>}`` when the input is empty or any step fails.
        Errors are returned rather than raised so the UI can display them.
    """
    if not encoded_smiles or not encoded_smiles.strip():
        return {'error': 'Input must not be empty.'}
    try:
        llm = _load_model()
        # Llama.__call__ (create_completion) tokenizes the prompt, samples, and
        # detokenizes in one call. Its result holds only the generated text.
        result = llm(
            encoded_smiles,
            max_tokens=int(max_tokens),
            temperature=float(temperature),
            top_p=float(top_p),
            top_k=int(top_k),
        )
        completion = result["choices"][0]["text"]
        predicted_selfies = _extract_prediction(encoded_smiles, completion)
        return {'input': encoded_smiles, 'predict': predicted_selfies}
    except Exception as e:
        # UI boundary: report the failure as JSON instead of crashing the app.
        return {'error': str(e)}
# Gradio UI: a text prompt plus sampling sliders, wired to respond(), with the
# result rendered as JSON. Slider starting positions are set via `value=`.
demo = gr.Interface(
    fn=respond,
    inputs=[
        gr.Textbox(label="Encoded SMILES"),
        gr.Slider(minimum=1, maximum=2048, step=1, label="Max tokens", value=512),
        gr.Slider(minimum=0.1, maximum=4.0, step=0.1, label="Temperature", value=1.0),
        gr.Slider(minimum=0.1, maximum=1.0, step=0.05, label="Top-p", value=1.0),
        gr.Slider(minimum=0, maximum=100, step=1, label="Top-k", value=50),
    ],
    outputs=gr.JSON(label="Results"),
    # Violet "Soft" theme with explicit dark-mode colour overrides.
    theme=gr.themes.Soft(
        primary_hue="violet",
        secondary_hue="violet",
        neutral_hue="gray",
        font=[gr.themes.GoogleFont("Exo"), "ui-sans-serif", "system-ui", "sans-serif"],
    ).set(
        body_background_fill_dark="#16141c",
        block_background_fill_dark="#16141c",
        block_border_width="1px",
        block_title_background_fill_dark="#1e1c26",
        input_background_fill_dark="#292733",
        button_secondary_background_fill_dark="#24212b",
        border_color_primary_dark="#343140",
        background_fill_secondary_dark="#16141c",
        color_accent_soft_dark="transparent",
    ),
    css=css,
    description="Retrosynthesis chatbot",
)

if __name__ == "__main__":
    # Start the web server only when run as a script, not when imported.
    demo.launch()