diff --git a/molpal/acquirer/metrics.py b/molpal/acquirer/metrics.py index afa067d..f6c0f12 100644 --- a/molpal/acquirer/metrics.py +++ b/molpal/acquirer/metrics.py @@ -227,7 +227,7 @@ def ei(Y_mean: np.ndarray, Y_var: np.ndarray, current_max: float, xi: float = 0. E_imp : np.ndarray the expected improvement acquisition scores """ - I = Y_mean - current_max + xi + I = Y_mean - current_max - xi Y_sd = np.sqrt(Y_var) with np.errstate(divide="ignore", invalid="ignore"): Z = I / Y_sd @@ -255,7 +255,7 @@ def pi(Y_mean: np.ndarray, Y_var: np.ndarray, current_max: float, xi: float = 0. P_imp : np.ndarray the probability of improvement acquisition scores """ - I = Y_mean - current_max + xi + I = Y_mean - current_max - xi with np.errstate(divide="ignore"): Z = I / np.sqrt(Y_var) P_imp = norm.cdf(Z) diff --git a/molpal/models/chemprop/data/utils.py b/molpal/models/chemprop/data/utils.py index 6f3ca1b..aa2ef0c 100644 --- a/molpal/models/chemprop/data/utils.py +++ b/molpal/models/chemprop/data/utils.py @@ -25,7 +25,7 @@ def preprocess_smiles_columns( :return: The preprocessed version of :code:`smiles_column` which is guaranteed to be a list. """ smiles_columns = smiles_columns if smiles_columns is not None else [None] - smiles_columns = [smiles_columns] if type(smiles_columns) != list else smiles_columns + smiles_columns = [smiles_columns] if not isinstance(smiles_columns, list) else smiles_columns return smiles_columns diff --git a/molpal/models/chemprop/features/featurization.py b/molpal/models/chemprop/features/featurization.py index fad1d7e..770c1e5 100644 --- a/molpal/models/chemprop/features/featurization.py +++ b/molpal/models/chemprop/features/featurization.py @@ -141,7 +141,7 @@ def __init__(self, mol: Union[str, Chem.Mol], atom_descriptors: np.ndarray = Non :param mol: A SMILES or an RDKit molecule. """ # Convert SMILES to RDKit molecule if necessary - if type(mol) == str: + if isinstance(mol, str): mol = Chem.MolFromSmiles(mol) self.n_atoms = 0 # number of atoms diff --git a/molpal/models/chemprop/utils.py b/molpal/models/chemprop/utils.py index 6a93103..1392c7f 100644 --- a/molpal/models/chemprop/utils.py +++ b/molpal/models/chemprop/utils.py @@ -116,7 +116,7 @@ def accuracy( :param threshold: The threshold above which a prediction is a 1 and below which (inclusive) a prediction is a 0. :return: The computed accuracy. """ - if type(preds[0]) == list: # multiclass + if isinstance(preds[0], list): # multiclass hard_preds = [p.index(max(p)) for p in preds] else: hard_preds = [1 if p > threshold else 0 for p in preds] # binary prediction diff --git a/tests/test_metrics.py b/tests/test_metrics.py index 4d9c869..5ed58ce 100644 --- a/tests/test_metrics.py +++ b/tests/test_metrics.py @@ -82,14 +82,14 @@ def test_thompson(Y_mean, Y_var, stochastic): def test_ei(Y_mean: np.ndarray, Y_var_0: np.ndarray, xi, curr_max): U = metrics.ei(Y_mean, Y_var_0, curr_max, xi) - np.testing.assert_array_less(0, U[Y_mean + xi > curr_max]) - np.testing.assert_array_less(U[Y_mean + xi <= curr_max], 0) + np.testing.assert_array_less(0, U[Y_mean - xi > curr_max]) + np.testing.assert_array_less(U[Y_mean - xi <= curr_max], 0) def test_pi(Y_mean: np.ndarray, Y_var_0: np.ndarray, xi, curr_max): U = metrics.pi(Y_mean, Y_var_0, curr_max, xi) - np.testing.assert_array_equal(1, U[Y_mean + xi > curr_max]) - np.testing.assert_array_equal(0, U[Y_mean + xi <= curr_max]) + np.testing.assert_array_equal(1, U[Y_mean - xi > curr_max]) + np.testing.assert_array_equal(0, U[Y_mean - xi <= curr_max]) def test_threshold(Y_mean: np.ndarray, threshold: float): U = metrics.threshold(Y_mean, threshold)