alimotahharynia commited on
Commit
f92ab61
1 Parent(s): c7b2596

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +138 -45
app.py CHANGED
@@ -49,6 +49,12 @@ def load_uniprot_dataset(dataset_name, dataset_key):
49
  logging.error(f"Error loading dataset: {e}")
50
  raise RuntimeError(f"Failed to load dataset: {e}")
51
 
 
 
 
 
 
 
52
  # SMILES Generator
53
  class SMILESGenerator:
54
  def __init__(self, model, tokenizer, uniprot_to_sequence):
@@ -102,61 +108,40 @@ class SMILESGenerator:
102
  def generate_smiles_gradio(sequence_input=None, uniprot_id=None, num_generated=10):
103
  results = {}
104
 
105
- # Process sequence inputs and include UniProt ID if found
106
  if sequence_input:
107
  sequences = [seq.strip() for seq in sequence_input.split(",") if seq.strip()]
108
  for seq in sequences:
109
  try:
110
- # Find the corresponding UniProt ID for the sequence
111
- uniprot_id_for_seq = [uid for uid, s in uniprot_to_sequence.items() if s == seq]
112
- uniprot_id_for_seq = uniprot_id_for_seq[0] if uniprot_id_for_seq else "N/A"
113
-
114
- # Generate SMILES for the sequence
115
  smiles = generator.generate_smiles(seq, num_generated)
116
- results[uniprot_id_for_seq] = {
117
- "sequence": seq,
118
- "smiles": smiles
119
- }
120
  except Exception as e:
121
- results["N/A"] = {"sequence": seq, "error": f"Error generating SMILES: {str(e)}"}
122
 
123
- # Process UniProt ID inputs and include sequence if found
124
  if uniprot_id:
125
  uniprot_ids = [uid.strip() for uid in uniprot_id.split(",") if uid.strip()]
126
  for uid in uniprot_ids:
127
- sequence = uniprot_to_sequence.get(uid, "N/A")
128
  try:
129
- # Generate SMILES for the sequence found
130
- if sequence != "N/A":
131
  smiles = generator.generate_smiles(sequence, num_generated)
132
- results[uid] = {
133
- "sequence": sequence,
134
- "smiles": smiles
135
- }
136
  else:
137
- results[uid] = {
138
- "sequence": "N/A",
139
- "error": f"UniProt ID {uid} not found in the dataset."
140
- }
141
  except Exception as e:
142
  results[uid] = {"sequence": "N/A", "error": f"Error generating SMILES: {str(e)}"}
143
 
144
- # Check if no results were generated
145
  if not results:
146
- return {"error": "No SMILES generated. Please try again with different inputs."}
147
 
148
- # Save results to a file
149
  file_path = save_smiles_to_file(results)
150
  return results, file_path
151
 
152
 
153
- def save_smiles_to_file(results):
154
- file_path = os.path.join(tempfile.gettempdir(), "generated_smiles.json")
155
- with open(file_path, "w") as f:
156
- json.dump(results, f, indent=4)
157
- return file_path
158
-
159
-
160
  # Main initialization and Gradio setup
161
  if __name__ == "__main__":
162
  setup_logging()
@@ -164,34 +149,142 @@ if __name__ == "__main__":
164
  dataset_name = "alimotahharynia/approved_drug_target"
165
  dataset_key = "uniprot_sequence"
166
 
167
- # Load model, tokenizer, and dataset
168
  model, tokenizer = load_model_and_tokenizer(model_name)
169
  uniprot_to_sequence = load_uniprot_dataset(dataset_name, dataset_key)
170
 
171
- # SMILESGenerator
172
  generator = SMILESGenerator(model, tokenizer, uniprot_to_sequence)
173
 
174
- # Gradio interface
175
- with gr.Blocks() as iface:
176
- gr.Markdown("## DrugGen interface")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
177
  with gr.Row():
178
  sequence_input = gr.Textbox(
179
- label="Input Protein Sequences",
180
- placeholder="Enter protein sequences separated by commas..."
 
181
  )
182
  uniprot_id_input = gr.Textbox(
183
  label="UniProt IDs",
184
- placeholder="Enter UniProt IDs separated by commas..."
 
185
  )
186
- num_generated_slider = gr.Slider(minimum=1, maximum=100, step=1, value=10, label="Number of Unique SMILES to Generate")
 
 
 
 
 
 
 
 
187
  output = gr.JSON(label="Generated SMILES")
188
- file_output = gr.File(label="Download output as .json")
 
 
189
 
190
- generate_button = gr.Button("Generate SMILES")
191
  generate_button.click(
192
  generate_smiles_gradio,
193
  inputs=[sequence_input, uniprot_id_input, num_generated_slider],
194
  outputs=[output, file_output]
195
  )
196
 
197
- iface.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
  logging.error(f"Error loading dataset: {e}")
50
  raise RuntimeError(f"Failed to load dataset: {e}")
51
 
52
+ def save_smiles_to_file(results):
53
+ file_path = os.path.join(tempfile.gettempdir(), "generated_smiles.json")
54
+ with open(file_path, "w") as f:
55
+ json.dump(results, f, indent=4)
56
+ return file_path
57
+
58
  # SMILES Generator
59
  class SMILESGenerator:
60
  def __init__(self, model, tokenizer, uniprot_to_sequence):
 
108
  def generate_smiles_gradio(sequence_input=None, uniprot_id=None, num_generated=10):
109
  results = {}
110
 
111
+ # Process protein sequences
112
  if sequence_input:
113
  sequences = [seq.strip() for seq in sequence_input.split(",") if seq.strip()]
114
  for seq in sequences:
115
  try:
116
+ # Always attempt to generate SMILES from the sequence (regardless of validity)
 
 
 
 
117
  smiles = generator.generate_smiles(seq, num_generated)
118
+ results[seq] = {"sequence": seq, "smiles": smiles}
 
 
 
119
  except Exception as e:
120
+ results[seq] = {"sequence": seq, "error": f"Error generating SMILES: {str(e)}"}
121
 
122
+ # Process UniProt IDs
123
  if uniprot_id:
124
  uniprot_ids = [uid.strip() for uid in uniprot_id.split(",") if uid.strip()]
125
  for uid in uniprot_ids:
126
+ sequence = uniprot_to_sequence.get(uid, None) # None if not found
127
  try:
128
+ if sequence:
 
129
  smiles = generator.generate_smiles(sequence, num_generated)
130
+ results[uid] = {"sequence": sequence, "smiles": smiles}
 
 
 
131
  else:
132
+ # UniProt ID not found
133
+ results[uid] = {"sequence": "N/A", "error": f"UniProt ID {uid} not found in dataset."}
 
 
134
  except Exception as e:
135
  results[uid] = {"sequence": "N/A", "error": f"Error generating SMILES: {str(e)}"}
136
 
 
137
  if not results:
138
+ return {"error": "No valid input provided. Please try again with different sequences or UniProt IDs."}
139
 
140
+ # Save
141
  file_path = save_smiles_to_file(results)
142
  return results, file_path
143
 
144
 
 
 
 
 
 
 
 
145
  # Main initialization and Gradio setup
146
  if __name__ == "__main__":
147
  setup_logging()
 
149
  dataset_name = "alimotahharynia/approved_drug_target"
150
  dataset_key = "uniprot_sequence"
151
 
 
152
  model, tokenizer = load_model_and_tokenizer(model_name)
153
  uniprot_to_sequence = load_uniprot_dataset(dataset_name, dataset_key)
154
 
 
155
  generator = SMILESGenerator(model, tokenizer, uniprot_to_sequence)
156
 
157
+ with gr.Blocks(theme=gr.themes.Soft(primary_hue="indigo", secondary_hue="teal")) as iface:
158
+ custom_css = """
159
+ body {
160
+ font-family: 'Roboto', sans-serif;
161
+ background-color: #fafafa;
162
+ color: #333;
163
+ font-size: 16px;
164
+ }
165
+
166
+ #app-title {
167
+ text-align: center;
168
+ font-size: 36px;
169
+ font-weight: 700;
170
+ color: #2c3e50;
171
+ margin-bottom: 20px;
172
+ }
173
+
174
+ #description {
175
+ font-size: 18px;
176
+ margin-bottom: 40px;
177
+ text-align: center;
178
+ color: #555;
179
+ }
180
+
181
+ .gr-button {
182
+ padding: 12px 24px;
183
+ font-weight: bold;
184
+ background-color: #007bff;
185
+ color: white;
186
+ border-radius: 8px;
187
+ border: none;
188
+ box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
189
+ transition: all 0.3s ease;
190
+ }
191
+
192
+ .gr-button:hover {
193
+ background-color: #0056b3;
194
+ transform: translateY(-2px);
195
+ box-shadow: 0 6px 10px rgba(0, 0, 0, 0.2);
196
+ }
197
+
198
+ .gr-input:focus {
199
+ border-color: #007bff;
200
+ box-shadow: 0 0 8px rgba(0, 123, 255, 0.3);
201
+ }
202
+
203
+ .gr-output {
204
+ background-color: #ffffff;
205
+ border-radius: 10px;
206
+ padding: 20px;
207
+ box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1);
208
+ margin-bottom: 20px;
209
+ }
210
+
211
+ .error-message {
212
+ background-color: #f8d7da;
213
+ border-color: #f5c6cb;
214
+ color: #721c24;
215
+ padding: 15px;
216
+ border-radius: 8px;
217
+ }
218
+
219
+ .success-message {
220
+ background-color: #d4edda;
221
+ border-color: #c3e6cb;
222
+ color: #155724;
223
+ padding: 15px;
224
+ border-radius: 8px;
225
+ }
226
+
227
+ .gr-row {
228
+ margin-bottom: 20px;
229
+ }
230
+ """
231
+
232
+ iface.css = custom_css
233
+ gr.Markdown("## GPT-2 Drug Generator", elem_id="app-title")
234
+ gr.Markdown(
235
+ "Generate **drug-like SMILES structures** from protein sequences or UniProt IDs. "
236
+ "Input data, specify parameters, and download the results.",
237
+ elem_id="description"
238
+ )
239
+
240
  with gr.Row():
241
  sequence_input = gr.Textbox(
242
+ label="Protein Sequences",
243
+ placeholder="Enter sequences separated by commas (e.g., MGAASGRRGP, MGETLGDSPI, ...)",
244
+ lines=3,
245
  )
246
  uniprot_id_input = gr.Textbox(
247
  label="UniProt IDs",
248
+ placeholder="Enter UniProt IDs separated by commas (e.g., P12821, P37231, ...)",
249
+ lines=1,
250
  )
251
+
252
+ num_generated_slider = gr.Slider(
253
+ minimum=1,
254
+ maximum=100,
255
+ step=1,
256
+ value=10,
257
+ label="Number of Unique SMILES to Generate",
258
+ )
259
+
260
  output = gr.JSON(label="Generated SMILES")
261
+ file_output = gr.File(label="Download Results as JSON")
262
+
263
+ generate_button = gr.Button("Generate SMILES", elem_id="generate-button")
264
 
 
265
  generate_button.click(
266
  generate_smiles_gradio,
267
  inputs=[sequence_input, uniprot_id_input, num_generated_slider],
268
  outputs=[output, file_output]
269
  )
270
 
271
+ gr.Markdown("""
272
+ ### How to Cite:
273
+ If you use this tool in your research, please cite the following work:
274
+
275
+ ```bibtex
276
+ @misc{sheikholeslami2024druggenadvancingdrugdiscovery,
277
+ title={DrugGen: Advancing Drug Discovery with Large Language Models and Reinforcement Learning Feedback},
278
+ author={Mahsa Sheikholeslami and Navid Mazrouei and Yousof Gheisari and Afshin Fasihi and Matin Irajpour and Ali Motahharynia},
279
+ year={2024},
280
+ eprint={2411.14157},
281
+ archivePrefix={arXiv},
282
+ primaryClass={q-bio.QM},
283
+ url={https://arxiv.org/abs/2411.14157},
284
+ }
285
+ ```
286
+
287
+ This will help us maintain the tool and support future development!
288
+ """)
289
+
290
+ iface.launch(allowed_paths=["/tmp"])