LongLe3102000 commited on
Commit
10b4a62
1 Parent(s): aa0348e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +29 -44
app.py CHANGED
@@ -1,11 +1,5 @@
1
  import gradio as gr
2
- import selfies as sf
3
  from llama_cpp import Llama
4
- from llama_cpp_agent import LlamaCppAgent
5
- from llama_cpp_agent.providers import LlamaCppPythonProvider
6
- from llama_cpp_agent.chat_history import BasicChatHistory
7
- from llama_cpp_agent.chat_history.messages import Roles
8
- from llama_cpp_agent import MessagesFormatterType
9
 
10
  css = """
11
  .message-row {
@@ -26,34 +20,37 @@ css = """
26
  """
27
 
28
  def respond(encoded_smiles, max_tokens, temperature, top_p, top_k):
29
- model_name = "model.gguf"
30
- llm = Llama(model_name)
31
- provider = LlamaCppPythonProvider(llm)
32
- chat_template = MessagesFormatterType.CHATML
33
-
34
- agent = LlamaCppAgent(
35
- provider,
36
- predefined_messages_formatter_type=chat_template,
37
- debug_output=True
38
- )
 
 
 
39
 
40
- settings = provider.get_provider_default_settings()
41
- settings.temperature = temperature
42
- settings.top_k = top_k
43
- settings.top_p = top_p
44
- settings.max_tokens = max_tokens
45
- settings.stream = False
46
 
47
- prompt = f"{encoded_smiles}"
48
- input_ids = agent.tokenizer(prompt, return_tensors='pt', truncation=False).input_ids.cuda()
49
- outputs = agent.llm.generate(input_ids=input_ids)
50
- output1 = agent.tokenizer.batch_decode(outputs.detach().cpu().numpy(), skip_special_tokens=True)[0][len(prompt):]
51
 
52
- first_inst_index = output1.find("[/INST]")
53
- second_inst_index = output1.find("[/IN", first_inst_index + len("[/INST]") + 1)
54
- predicted_selfies = output1[first_inst_index + len("[/INST]") : second_inst_index].strip()
55
 
56
- return {'input': encoded_smiles, 'predict': predicted_selfies}
 
 
 
 
 
 
 
57
 
58
  demo = gr.Interface(
59
  fn=respond,
@@ -61,20 +58,8 @@ demo = gr.Interface(
61
  gr.Textbox(label="Encoded SMILES"),
62
  gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max tokens"),
63
  gr.Slider(minimum=0.1, maximum=4.0, value=1.0, step=0.1, label="Temperature"),
64
- gr.Slider(
65
- minimum=0.1,
66
- maximum=1.0,
67
- value=1.0,
68
- step=0.05,
69
- label="Top-p",
70
- ),
71
- gr.Slider(
72
- minimum=0,
73
- maximum=100,
74
- value=50,
75
- step=1,
76
- label="Top-k",
77
- )
78
  ],
79
  outputs=gr.JSON(label="Results"),
80
  theme=gr.themes.Soft(primary_hue="violet", secondary_hue="violet", neutral_hue="gray", font=[gr.themes.GoogleFont("Exo"), "ui-sans-serif", "system-ui", "sans-serif"]).set(
 
1
  import gradio as gr
 
2
  from llama_cpp import Llama
 
 
 
 
 
3
 
4
  css = """
5
  .message-row {
 
20
  """
21
 
22
  def respond(encoded_smiles, max_tokens, temperature, top_p, top_k):
23
+ try:
24
+ # Load the Llama model
25
+ model_name = "model.gguf"
26
+ llm = Llama(model_name) # Khởi tạo đối tượng Llama với tệp mô hình
27
+
28
+ # Set generation settings
29
+ settings = {
30
+ "max_new_tokens": max_tokens,
31
+ "temperature": temperature,
32
+ "top_p": top_p,
33
+ "top_k": top_k,
34
+ "do_sample": True,
35
+ }
36
 
37
+ # Tokenize the input
38
+ input_ids = llm.tokenizer(encoded_smiles, return_tensors='pt').input_ids
 
 
 
 
39
 
40
+ # Generate the output
41
+ outputs = llm.generate(input_ids=input_ids, **settings)
 
 
42
 
43
+ # Decode the output tokens to text
44
+ output_text = llm.tokenizer.decode(outputs[0], skip_special_tokens=True)
 
45
 
46
+ # Extract the predicted selfies from the output text
47
+ first_inst_index = output_text.find("[/INST]")
48
+ second_inst_index = output_text.find("[/IN", first_inst_index + len("[/INST]") + 1)
49
+ predicted_selfies = output_text[first_inst_index + len("[/INST]"): second_inst_index].strip()
50
+
51
+ return {'input': encoded_smiles, 'predict': predicted_selfies}
52
+ except Exception as e:
53
+ return {'error': str(e)}
54
 
55
  demo = gr.Interface(
56
  fn=respond,
 
58
  gr.Textbox(label="Encoded SMILES"),
59
  gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max tokens"),
60
  gr.Slider(minimum=0.1, maximum=4.0, value=1.0, step=0.1, label="Temperature"),
61
+ gr.Slider(minimum=0.1, maximum=1.0, value=1.0, step=0.05, label="Top-p"),
62
+ gr.Slider(minimum=0, maximum=100, value=50, step=1, label="Top-k")
 
 
 
 
 
 
 
 
 
 
 
 
63
  ],
64
  outputs=gr.JSON(label="Results"),
65
  theme=gr.themes.Soft(primary_hue="violet", secondary_hue="violet", neutral_hue="gray", font=[gr.themes.GoogleFont("Exo"), "ui-sans-serif", "system-ui", "sans-serif"]).set(