LongLe3102000 commited on
Commit
072c58d
1 Parent(s): ded59be

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -53
app.py CHANGED
@@ -1,5 +1,4 @@
1
  import gradio as gr
2
- import spaces
3
  import selfies as sf
4
  from llama_cpp import Llama
5
  from llama_cpp_agent import LlamaCppAgent
@@ -26,16 +25,7 @@ css = """
26
  }
27
  """
28
 
29
- @spaces.GPU(duration=120)
30
- def respond(
31
- message,
32
- history: list[tuple[str, str]],
33
- max_tokens,
34
- temperature,
35
- top_p,
36
- top_k
37
- ):
38
-
39
  model_name = "model.gguf"
40
  llm = Llama(model_name)
41
  provider = LlamaCppPythonProvider(llm)
@@ -52,46 +42,23 @@ def respond(
52
  settings.top_k = top_k
53
  settings.top_p = top_p
54
  settings.max_tokens = max_tokens
55
- settings.stream = True
56
 
57
- messages = BasicChatHistory()
58
-
59
- for msn in history:
60
- user = {
61
- 'role': Roles.user,
62
- 'content': msn[0]
63
- }
64
- assistant = {
65
- 'role': Roles.assistant,
66
- 'content': msn[1]
67
- }
68
- messages.add_message(user)
69
- messages.add_message(assistant)
70
 
71
- stream = agent.get_chat_response(
72
- sf.encoder(message),
73
- llm_sampling_settings=settings,
74
- chat_history=messages,
75
- returns_streaming_generator=True,
76
- print_output=False
77
- )
78
 
79
- outputs = ""
80
- for output in stream:
81
- outputs += output
82
- yield outputs
83
-
84
- PLACEHOLDER = """
85
- <div class="message-bubble-border" style="display:flex; max-width: 600px; border-radius: 8px; box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1); backdrop-filter: blur(10px);">
86
- <div style="padding: .5rem 1.5rem;">
87
- <h2 style="text-align: left; font-size: 1.5rem; font-weight: 800; margin-bottom: 0.5rem;">Retrosynthesis Chatbot</h2>
88
- </div>
89
- </div>
90
- """
91
 
92
- demo = gr.ChatInterface(
93
- respond,
94
- additional_inputs=[
 
95
  gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max tokens"),
96
  gr.Slider(minimum=0.1, maximum=4.0, value=1.0, step=0.1, label="Temperature"),
97
  gr.Slider(
@@ -109,6 +76,7 @@ demo = gr.ChatInterface(
109
  label="Top-k",
110
  )
111
  ],
 
112
  theme=gr.themes.Soft(primary_hue="violet", secondary_hue="violet", neutral_hue="gray", font=[gr.themes.GoogleFont("Exo"), "ui-sans-serif", "system-ui", "sans-serif"]).set(
113
  body_background_fill_dark="#16141c",
114
  block_background_fill_dark="#16141c",
@@ -121,13 +89,8 @@ demo = gr.ChatInterface(
121
  color_accent_soft_dark="transparent"
122
  ),
123
  css=css,
124
- retry_btn="Retry",
125
- undo_btn="Undo",
126
- clear_btn="Clear",
127
- submit_btn="Send",
128
  description="Retrosynthesis chatbot",
129
- chatbot=gr.Chatbot(scale=1, placeholder=PLACEHOLDER)
130
  )
131
 
132
  if __name__ == "__main__":
133
- demo.launch()
 
1
  import gradio as gr
 
2
  import selfies as sf
3
  from llama_cpp import Llama
4
  from llama_cpp_agent import LlamaCppAgent
 
25
  }
26
  """
27
 
28
+ def respond(encoded_smiles, max_tokens, temperature, top_p, top_k):
 
 
 
 
 
 
 
 
 
29
  model_name = "model.gguf"
30
  llm = Llama(model_name)
31
  provider = LlamaCppPythonProvider(llm)
 
42
  settings.top_k = top_k
43
  settings.top_p = top_p
44
  settings.max_tokens = max_tokens
45
+ settings.stream = False
46
 
47
+ prompt = f"{encoded_smiles}"
48
+ input_ids = agent.tokenizer(prompt, return_tensors='pt', truncation=False).input_ids.cuda()
49
+ outputs = agent.model.generate(input_ids=input_ids, max_new_tokens=512)
50
+ output1 = agent.tokenizer.batch_decode(outputs.detach().cpu().numpy(), skip_special_tokens=True)[0][len(prompt):]
 
 
 
 
 
 
 
 
 
51
 
52
+ first_inst_index = output1.find("[/INST]")
53
+ second_inst_index = output1.find("[/IN", first_inst_index + len("[/INST]") + 1)
54
+ predicted_selfies = output1[first_inst_index + len("[/INST]") : second_inst_index].strip()
 
 
 
 
55
 
56
+ return {'input': encoded_smiles, 'predict': predicted_selfies}
 
 
 
 
 
 
 
 
 
 
 
57
 
58
+ demo = gr.Interface(
59
+ fn=respond,
60
+ inputs=[
61
+ gr.Textbox(label="Encoded SMILES"),
62
  gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max tokens"),
63
  gr.Slider(minimum=0.1, maximum=4.0, value=1.0, step=0.1, label="Temperature"),
64
  gr.Slider(
 
76
  label="Top-k",
77
  )
78
  ],
79
+ outputs=gr.JSON(label="Results"),
80
  theme=gr.themes.Soft(primary_hue="violet", secondary_hue="violet", neutral_hue="gray", font=[gr.themes.GoogleFont("Exo"), "ui-sans-serif", "system-ui", "sans-serif"]).set(
81
  body_background_fill_dark="#16141c",
82
  block_background_fill_dark="#16141c",
 
89
  color_accent_soft_dark="transparent"
90
  ),
91
  css=css,
 
 
 
 
92
  description="Retrosynthesis chatbot",
 
93
  )
94
 
95
  if __name__ == "__main__":
96
+ demo.launch()