import gradio as gr
from llama_cpp import Llama

css = """
.message-row {
    justify-content: space-evenly !important;
}
.message-bubble-border {
    border-radius: 6px !important;
}
.dark.message-bubble-border {
    border-color: #343140 !important;
}
.dark.user {
    background: #1e1c26 !important;
}
.dark.assistant.dark, .dark.pending.dark {
    background: #16141c !important;
}
"""

# Load the GGUF model once at startup; reloading it inside the handler
# would pay the full model-load cost on every request.
llm = Llama("model.gguf")

def respond(encoded_smiles, max_tokens, temperature, top_p, top_k):
    try:
        # Run a completion with the user-selected sampling settings.
        # echo=True keeps the prompt in the returned text so the
        # [/INST] markers can be located below.
        output = llm(
            encoded_smiles,
            max_tokens=int(max_tokens),
            temperature=float(temperature),
            top_p=float(top_p),
            top_k=int(top_k),
            echo=True,
        )
        output_text = output["choices"][0]["text"]

        # Extract the predicted SELFIES: the text between the first
        # [/INST] marker and the next one (or the end of the output).
        marker = "[/INST]"
        first_inst_index = output_text.find(marker)
        if first_inst_index == -1:
            return {"error": "no [/INST] marker found in model output"}
        start = first_inst_index + len(marker)
        second_inst_index = output_text.find(marker, start)
        end = second_inst_index if second_inst_index != -1 else len(output_text)
        predicted_selfies = output_text[start:end].strip()

        return {"input": encoded_smiles, "predict": predicted_selfies}
    except Exception as e:
        return {"error": str(e)}
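
# A minimal sketch for exercising the handler without the UI. The prompt
# below is a hypothetical example; the exact [INST]/[/INST] wrapping
# depends on how the underlying model was fine-tuned.
def _smoke_test():
    result = respond(
        "[INST] CCO [/INST]",
        max_tokens=64,
        temperature=1.0,
        top_p=1.0,
        top_k=50,
    )
    print(result)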

demo = gr.Interface(
    fn=respond,
    inputs=[
        gr.Textbox(label="Encoded SMILES"),
        gr.Slider(minimum=1, maximum=2048, step=1, label="Max tokens", value=512),
        gr.Slider(minimum=0.1, maximum=4.0, step=0.1, label="Temperature", value=1.0),
        gr.Slider(minimum=0.1, maximum=1.0, step=0.05, label="Top-p", value=1.0),
        gr.Slider(minimum=0, maximum=100, step=1, label="Top-k", value=50)
    ],
    outputs=gr.JSON(label="Results"),
    theme=gr.themes.Soft(
        primary_hue="violet",
        secondary_hue="violet",
        neutral_hue="gray",
        font=[gr.themes.GoogleFont("Exo"), "ui-sans-serif", "system-ui", "sans-serif"],
    ).set(
        body_background_fill_dark="#16141c",
        block_background_fill_dark="#16141c",
        block_border_width="1px",
        block_title_background_fill_dark="#1e1c26",
        input_background_fill_dark="#292733",
        button_secondary_background_fill_dark="#24212b",
        border_color_primary_dark="#343140",
        background_fill_secondary_dark="#16141c",
        color_accent_soft_dark="transparent"
    ),
    css=css,
    description="Retrosynthesis chatbot",
)

if __name__ == "__main__":
    demo.launch()
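
# demo.launch() serves the app locally; passing share=True would also
# create a temporary public link, handy when running on a remote machine.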