Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions molpal/acquirer/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ def ei(Y_mean: np.ndarray, Y_var: np.ndarray, current_max: float, xi: float = 0.
E_imp : np.ndarray
the expected improvement acquisition scores
"""
I = Y_mean - current_max + xi
I = Y_mean - current_max - xi
Y_sd = np.sqrt(Y_var)
with np.errstate(divide="ignore", invalid="ignore"):
Z = I / Y_sd
Expand Down Expand Up @@ -255,7 +255,7 @@ def pi(Y_mean: np.ndarray, Y_var: np.ndarray, current_max: float, xi: float = 0.
P_imp : np.ndarray
the probability of improvement acquisition scores
"""
I = Y_mean - current_max + xi
I = Y_mean - current_max - xi
with np.errstate(divide="ignore"):
Z = I / np.sqrt(Y_var)
P_imp = norm.cdf(Z)
Expand Down
2 changes: 1 addition & 1 deletion molpal/models/chemprop/data/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def preprocess_smiles_columns(
:return: The preprocessed version of :code:`smiles_column` which is guaranteed to be a list.
"""
smiles_columns = smiles_columns if smiles_columns is not None else [None]
smiles_columns = [smiles_columns] if type(smiles_columns) != list else smiles_columns
smiles_columns = [smiles_columns] if not isinstance(smiles_columns, list) else smiles_columns

return smiles_columns

Expand Down
2 changes: 1 addition & 1 deletion molpal/models/chemprop/features/featurization.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ def __init__(self, mol: Union[str, Chem.Mol], atom_descriptors: np.ndarray = Non
:param mol: A SMILES or an RDKit molecule.
"""
# Convert SMILES to RDKit molecule if necessary
if type(mol) == str:
if isinstance(mol, str):
mol = Chem.MolFromSmiles(mol)

self.n_atoms = 0 # number of atoms
Expand Down
2 changes: 1 addition & 1 deletion molpal/models/chemprop/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ def accuracy(
:param threshold: The threshold above which a prediction is a 1 and below which (inclusive) a prediction is a 0.
:return: The computed accuracy.
"""
if type(preds[0]) == list: # multiclass
if isinstance(preds[0], list): # multiclass
hard_preds = [p.index(max(p)) for p in preds]
else:
hard_preds = [1 if p > threshold else 0 for p in preds] # binary prediction
Expand Down
8 changes: 4 additions & 4 deletions tests/test_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,14 +82,14 @@ def test_thompson(Y_mean, Y_var, stochastic):
def test_ei(Y_mean: np.ndarray, Y_var_0: np.ndarray, xi, curr_max):
U = metrics.ei(Y_mean, Y_var_0, curr_max, xi)

np.testing.assert_array_less(0, U[Y_mean + xi > curr_max])
np.testing.assert_array_less(U[Y_mean + xi <= curr_max], 0)
np.testing.assert_array_less(0, U[Y_mean - xi > curr_max])
np.testing.assert_array_less(U[Y_mean - xi <= curr_max], 0)

def test_pi(Y_mean: np.ndarray, Y_var_0: np.ndarray, xi, curr_max):
U = metrics.pi(Y_mean, Y_var_0, curr_max, xi)

np.testing.assert_array_equal(1, U[Y_mean + xi > curr_max])
np.testing.assert_array_equal(0, U[Y_mean + xi <= curr_max])
np.testing.assert_array_equal(1, U[Y_mean - xi > curr_max])
np.testing.assert_array_equal(0, U[Y_mean - xi <= curr_max])

def test_threshold(Y_mean: np.ndarray, threshold: float):
U = metrics.threshold(Y_mean, threshold)
Expand Down