mgyigit commited on
Commit
ce56756
1 Parent(s): 1dd4981

Update gradio_app.py

Browse files
Files changed (1) hide show
  1. gradio_app.py +7 -3
gradio_app.py CHANGED
@@ -50,7 +50,7 @@ def function(model_name: str, num_molecules: int, seed_num: int) -> tuple[PIL.Im
50
  '''
51
 
52
  config = model_configs[model_name]
53
- config.inference_sample_num = num_molecules
54
  config.seed = seed_num
55
 
56
  inferer = Inference(config)
@@ -69,8 +69,12 @@ def function(model_name: str, num_molecules: int, seed_num: int) -> tuple[PIL.Im
69
 
70
  generated_molecule_list = inference_drugs.split("\n")
71
 
72
- rng = random.Random(seed)
73
- selected_molecules = rng.choices(generated_molecule_list,k=12)
 
 
 
 
74
  selected_molecules = [Chem.MolFromSmiles(mol) for mol in selected_molecules]
75
 
76
  drawOptions = Draw.rdMolDraw2D.MolDrawOptions()
 
50
  '''
51
 
52
  config = model_configs[model_name]
53
+ config.sample_num = num_molecules
54
  config.seed = seed_num
55
 
56
  inferer = Inference(config)
 
69
 
70
  generated_molecule_list = inference_drugs.split("\n")
71
 
72
+ rng = random.Random(config.seed)
73
+ if num_molecules > 12:
74
+ selected_molecules = rng.choices(generated_molecule_list, k=12)
75
+ else:
76
+ selected_molecules = rng.choices(generated_molecule_list, k=num_molecules)
77
+
78
  selected_molecules = [Chem.MolFromSmiles(mol) for mol in selected_molecules]
79
 
80
  drawOptions = Draw.rdMolDraw2D.MolDrawOptions()