From be0c9fd57b830c84c2a8ff3818a7e39fa24fa7ca Mon Sep 17 00:00:00 2001 From: Kazuya Ujihara Date: Sat, 15 Jun 2024 23:50:59 +0900 Subject: [PATCH 1/3] reverse the sign of xi --- molpal/acquirer/metrics.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/molpal/acquirer/metrics.py b/molpal/acquirer/metrics.py index afa067d..f6c0f12 100644 --- a/molpal/acquirer/metrics.py +++ b/molpal/acquirer/metrics.py @@ -227,7 +227,7 @@ def ei(Y_mean: np.ndarray, Y_var: np.ndarray, current_max: float, xi: float = 0. E_imp : np.ndarray the expected improvement acquisition scores """ - I = Y_mean - current_max + xi + I = Y_mean - current_max - xi Y_sd = np.sqrt(Y_var) with np.errstate(divide="ignore", invalid="ignore"): Z = I / Y_sd @@ -255,7 +255,7 @@ def pi(Y_mean: np.ndarray, Y_var: np.ndarray, current_max: float, xi: float = 0. P_imp : np.ndarray the probability of improvement acquisition scores """ - I = Y_mean - current_max + xi + I = Y_mean - current_max - xi with np.errstate(divide="ignore"): Z = I / np.sqrt(Y_var) P_imp = norm.cdf(Z) From b6303a4a7ca8c03e728fe866baf0b1ab39164582 Mon Sep 17 00:00:00 2001 From: Kazuya Ujihara Date: Sat, 15 Jun 2024 23:52:00 +0900 Subject: [PATCH 2/3] update test --- tests/test_metrics.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/test_metrics.py b/tests/test_metrics.py index 4d9c869..5ed58ce 100644 --- a/tests/test_metrics.py +++ b/tests/test_metrics.py @@ -82,14 +82,14 @@ def test_thompson(Y_mean, Y_var, stochastic): def test_ei(Y_mean: np.ndarray, Y_var_0: np.ndarray, xi, curr_max): U = metrics.ei(Y_mean, Y_var_0, curr_max, xi) - np.testing.assert_array_less(0, U[Y_mean + xi > curr_max]) - np.testing.assert_array_less(U[Y_mean + xi <= curr_max], 0) + np.testing.assert_array_less(0, U[Y_mean - xi > curr_max]) + np.testing.assert_array_less(U[Y_mean - xi <= curr_max], 0) def test_pi(Y_mean: np.ndarray, Y_var_0: np.ndarray, xi, curr_max): U = metrics.pi(Y_mean, Y_var_0, curr_max, xi) - np.testing.assert_array_equal(1, U[Y_mean + xi > curr_max]) - np.testing.assert_array_equal(0, U[Y_mean + xi <= curr_max]) + np.testing.assert_array_equal(1, U[Y_mean - xi > curr_max]) + np.testing.assert_array_equal(0, U[Y_mean - xi <= curr_max]) def test_threshold(Y_mean: np.ndarray, threshold: float): U = metrics.threshold(Y_mean, threshold) From 0479e071e0799674280e935be163306e4a3553e9 Mon Sep 17 00:00:00 2001 From: Kazuya Ujihara Date: Sat, 15 Jun 2024 23:56:12 +0900 Subject: [PATCH 3/3] flake8 --- molpal/models/chemprop/data/utils.py | 2 +- molpal/models/chemprop/features/featurization.py | 2 +- molpal/models/chemprop/utils.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/molpal/models/chemprop/data/utils.py b/molpal/models/chemprop/data/utils.py index 6f3ca1b..aa2ef0c 100644 --- a/molpal/models/chemprop/data/utils.py +++ b/molpal/models/chemprop/data/utils.py @@ -25,7 +25,7 @@ def preprocess_smiles_columns( :return: The preprocessed version of :code:`smiles_column` which is guaranteed to be a list. """ smiles_columns = smiles_columns if smiles_columns is not None else [None] - smiles_columns = [smiles_columns] if type(smiles_columns) != list else smiles_columns + smiles_columns = [smiles_columns] if not isinstance(smiles_columns, list) else smiles_columns return smiles_columns diff --git a/molpal/models/chemprop/features/featurization.py b/molpal/models/chemprop/features/featurization.py index fad1d7e..770c1e5 100644 --- a/molpal/models/chemprop/features/featurization.py +++ b/molpal/models/chemprop/features/featurization.py @@ -141,7 +141,7 @@ def __init__(self, mol: Union[str, Chem.Mol], atom_descriptors: np.ndarray = Non :param mol: A SMILES or an RDKit molecule. """ # Convert SMILES to RDKit molecule if necessary - if type(mol) == str: + if isinstance(mol, str): mol = Chem.MolFromSmiles(mol) self.n_atoms = 0 # number of atoms diff --git a/molpal/models/chemprop/utils.py b/molpal/models/chemprop/utils.py index 6a93103..1392c7f 100644 --- a/molpal/models/chemprop/utils.py +++ b/molpal/models/chemprop/utils.py @@ -116,7 +116,7 @@ def accuracy( :param threshold: The threshold above which a prediction is a 1 and below which (inclusive) a prediction is a 0. :return: The computed accuracy. """ - if type(preds[0]) == list: # multiclass + if isinstance(preds[0], list): # multiclass hard_preds = [p.index(max(p)) for p in preds] else: hard_preds = [1 if p > threshold else 0 for p in preds] # binary prediction