# Spaces: Sleeping
import gradio as gr | |
from llama_cpp import Llama | |
# Extra stylesheet handed to gr.Interface via its ``css=`` argument.
# NOTE(review): these selectors (.message-row, .message-bubble-border,
# .user / .assistant / .pending) look like they target chat-message markup;
# a gr.Interface with a JSON output may not render such elements, in which
# case these rules have no visible effect -- confirm they are still needed.
css = """
.message-row {
    justify-content: space-evenly !important;
}
.message-bubble-border {
    border-radius: 6px !important;
}
.dark.message-bubble-border {
    border-color: #343140 !important;
}
.dark.user {
    background: #1e1c26 !important;
}
.dark.assistant.dark, .dark.pending.dark {
    background: #16141c !important;
}
"""
# Path of the GGUF model file loaded by llama_cpp.
MODEL_PATH = "model.gguf"
# Marker after which the predicted SELFIES starts.
_INST_CLOSE = "[/INST]"
# Prefix of the next instruction marker, which ends the prediction.
_NEXT_INST_PREFIX = "[/IN"
# Loaded models keyed by path, so the GGUF file is read once per process
# rather than on every request.
_MODEL_CACHE = {}


def _load_model(model_path=MODEL_PATH):
    """Return a ``Llama`` for *model_path*, loading the file on first use only."""
    llm = _MODEL_CACHE.get(model_path)
    if llm is None:
        llm = Llama(model_path=model_path)
        _MODEL_CACHE[model_path] = llm
    return llm


def _extract_prediction(prompt, generated):
    """Return the predicted SELFIES from the prompt plus the generated text.

    The prediction is the text after the first ``[/INST]`` marker, up to the
    next marker starting with ``[/IN`` (or the end of the text when there is
    none). If no ``[/INST]`` marker appears anywhere, the generated text
    itself is returned.
    """
    # The marker may come from the prompt or from the model, so search the
    # prompt and continuation together (as decoding the full output would).
    full_text = prompt + generated
    start = full_text.find(_INST_CLOSE)
    if start == -1:
        return generated.strip()
    start += len(_INST_CLOSE)
    end = full_text.find(_NEXT_INST_PREFIX, start)
    if end == -1:
        return full_text[start:].strip()
    return full_text[start:end].strip()


def respond(encoded_smiles, max_tokens, temperature, top_p, top_k):
    """Run the GGUF model on *encoded_smiles* and return its prediction.

    Args:
        encoded_smiles: Prompt text from the "Encoded SMILES" textbox.
        max_tokens: Maximum number of tokens to generate.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling probability mass.
        top_k: Number of highest-probability tokens to sample from.

    Returns:
        ``{'input': encoded_smiles, 'predict': predicted_selfies}`` on
        success. ``{'error': message}`` when the input is empty, or when
        loading the model or generating fails; those failures are reported
        in the result rather than raised, so the JSON panel always shows one.
    """
    # Empty textbox: answer immediately instead of loading the model.
    if not encoded_smiles or not encoded_smiles.strip():
        return {'error': 'Please enter an encoded SMILES string.'}
    try:
        llm = _load_model()
        # llama_cpp's completion API takes the str prompt directly and
        # applies the sampling settings; it returns a completion dict.
        completion = llm.create_completion(
            encoded_smiles,
            max_tokens=int(max_tokens),
            temperature=float(temperature),
            top_p=float(top_p),
            top_k=int(top_k),
        )
        generated = completion["choices"][0]["text"]
        predicted_selfies = _extract_prediction(encoded_smiles, generated)
        return {'input': encoded_smiles, 'predict': predicted_selfies}
    except Exception as e:
        # UI boundary: surface the failure in the output instead of raising.
        return {'error': str(e)}
def _build_demo():
    """Assemble the Gradio interface that feeds the UI controls into ``respond``."""
    # Components listed in the order respond() receives them:
    # (encoded_smiles, max_tokens, temperature, top_p, top_k).
    smiles_box = gr.Textbox(label="Encoded SMILES")
    max_tokens_slider = gr.Slider(minimum=1, maximum=2048, step=1, label="Max tokens", value=512)
    temperature_slider = gr.Slider(minimum=0.1, maximum=4.0, step=0.1, label="Temperature", value=1.0)
    top_p_slider = gr.Slider(minimum=0.1, maximum=1.0, step=0.05, label="Top-p", value=1.0)
    top_k_slider = gr.Slider(minimum=0, maximum=100, step=1, label="Top-k", value=50)
    results_view = gr.JSON(label="Results")

    font_stack = [gr.themes.GoogleFont("Exo"), "ui-sans-serif", "system-ui", "sans-serif"]
    base_theme = gr.themes.Soft(
        primary_hue="violet",
        secondary_hue="violet",
        neutral_hue="gray",
        font=font_stack,
    )
    # Override colours (mostly the *_dark variants) and the block border width.
    themed = base_theme.set(
        body_background_fill_dark="#16141c",
        block_background_fill_dark="#16141c",
        block_border_width="1px",
        block_title_background_fill_dark="#1e1c26",
        input_background_fill_dark="#292733",
        button_secondary_background_fill_dark="#24212b",
        border_color_primary_dark="#343140",
        background_fill_secondary_dark="#16141c",
        color_accent_soft_dark="transparent",
    )

    return gr.Interface(
        fn=respond,
        inputs=[smiles_box, max_tokens_slider, temperature_slider, top_p_slider, top_k_slider],
        outputs=results_view,
        theme=themed,
        css=css,
        description="Retrosynthesis chatbot",
    )


demo = _build_demo()

if __name__ == "__main__":
    # Start the Gradio server for this app.
    demo.launch()