From 5458cc9888b620404615dc89f9fadc65f02f427f Mon Sep 17 00:00:00 2001 From: jimnel Date: Wed, 2 Apr 2025 13:12:12 +0100 Subject: [PATCH 1/5] add get and del tests --- qse/qbits.py | 84 ++++++++++++++------------------------------- tests/qbits_test.py | 40 +++++++++++++++++++++ 2 files changed, 66 insertions(+), 58 deletions(-) diff --git a/qse/qbits.py b/qse/qbits.py index 8a19515..8812602 100644 --- a/qse/qbits.py +++ b/qse/qbits.py @@ -699,93 +699,61 @@ def __iter__(self): for i in range(len(self)): yield self[i] - def __getitem__(self, i): - """Return a subset of the qbits. + def __getitem__(self, indicies): + """ + Return a subset of the qbits. - i -- scalar integer, list of integers, or slice object - describing which qbits to return. + Parameters + ---------- + indicies : int | list | slice + The indicies to be returned. - If i is a scalar, return an Qbit object. If i is a list or a - slice, return an Qbits object with the same cell, pbc, and - other associated info as the original Qbits object. The - indices of the constraints will be shuffled so that they match - the indexing in the subset returned. + Returns + ------- + If indicies is a scalar, return an Qbit object. + If indicies is a list or a slice, return an Qbits object with the same cell, pbc, and + other associated info as the original Qbits object. """ - if isinstance(i, numbers.Integral): + if isinstance(indicies, numbers.Integral): nqbits = len(self) - if i < -nqbits or i >= nqbits: + if indicies < -nqbits or indicies >= nqbits: raise IndexError("Index out of range.") + return Qbit(qbits=self, index=indicies) - return Qbit(qbits=self, index=i) - elif not isinstance(i, slice): - i = np.array(i) + if not isinstance(indicies, slice): + indicies = np.array(indicies) # if i is a mask - if i.dtype == bool: - if len(i) != len(self): + if indicies.dtype == bool: + if len(indicies) != len(self): raise IndexError( "Length of mask {} must equal " - "number of qbits {}".format(len(i), len(self)) + "number of qbits {}".format(len(indicies), len(self)) ) - i = np.arange(len(self))[i] - - import copy - - conadd = [] - # Constraints need to be deepcopied, but only the relevant ones. - for con in copy.deepcopy(self.constraints): - try: - con.index_shuffle(self, i) - except (IndexError, NotImplementedError): - pass - else: - conadd.append(con) + indicies = np.arange(len(self))[indicies] qbits = self.__class__( cell=self.cell, pbc=self.pbc, info=self.info, - # should be communicated to the slice as well celldisp=self._celldisp, ) - # TODO: Do we need to shuffle indices in adsorbate_info too? qbits.arrays = {} for name, a in self.arrays.items(): - qbits.arrays[name] = a[i].copy() + qbits.arrays[name] = a[indicies].copy() - qbits.constraints = conadd return qbits - def __delitem__(self, i): - from qse.constraints import FixQbits - - for c in self._constraints: - if not isinstance(c, FixQbits): - raise RuntimeError( - "Remove constraint using set_constraint() " "before deleting qbits." - ) - - if isinstance(i, list) and len(i) > 0: + def __delitem__(self, indicies): + if isinstance(indicies, list) and len(indicies) > 0: # Make sure a list of booleans will work correctly and not be # interpreted at 0 and 1 indices. - i = np.array(i) - - if len(self._constraints) > 0: - n = len(self) - i = np.arange(n)[i] - if isinstance(i, int): - i = [i] - constraints = [] - for c in self._constraints: - c = c.delete_qbits(i, n) - if c is not None: - constraints.append(c) - self.constraints = constraints + indicies = np.array(indicies) mask = np.ones(len(self), bool) - mask[i] = False + mask[indicies] = False for name, a in self.arrays.items(): self.arrays[name] = a[mask] diff --git a/tests/qbits_test.py b/tests/qbits_test.py index 0cba5d6..4404627 100644 --- a/tests/qbits_test.py +++ b/tests/qbits_test.py @@ -127,3 +127,43 @@ def test_translate(nqbits, type_of_disp): else: qbits.translate(disp) assert np.allclose(qbits.get_positions(), positions + disp) + + +def test_get_item(): + """Test __getitem__""" + positions = np.random.rand(4, 3) + qbits = qse.Qbits(positions=positions) + + # test int + for indicies in [0, 2]: + assert isinstance(qbits[indicies], qse.Qbit) + assert np.allclose(qbits[indicies].position, positions[indicies]) + + # test list + for indicies in [[0, 2], [1, 3, 2]]: + assert isinstance(qbits[indicies], qse.Qbits) + assert np.allclose(qbits[indicies].get_positions(), positions[indicies]) + + # test slice + assert isinstance(qbits[1:3], qse.Qbits) + assert np.allclose(qbits[1:3].get_positions(), positions[1:3]) + + +@pytest.mark.parametrize("indicies", [1, [0, 1, 3]]) +def test_get_item(indicies): + """Test __delitem__""" + nqbits = 4 + positions = np.random.rand(nqbits, 3) + qbits = qse.Qbits(positions=positions) + + del qbits[indicies] + + if isinstance(indicies, int): + indicies = [indicies] + + assert isinstance(qbits, qse.Qbits) + assert len(qbits) == nqbits - len(indicies) + assert np.allclose( + qbits.get_positions(), + positions[[i for i in range(nqbits) if i not in indicies]], + ) From 20fe84ecfc63ed1b5271b57c9810aef11d8bab4c Mon Sep 17 00:00:00 2001 From: James Nelson Date: Mon, 7 Apr 2025 15:32:07 +0100 Subject: [PATCH 2/5] Update tests/qbits_test.py --- tests/qbits_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/qbits_test.py b/tests/qbits_test.py index 4404627..723b153 100644 --- a/tests/qbits_test.py +++ b/tests/qbits_test.py @@ -150,7 +150,7 @@ def test_get_item(): @pytest.mark.parametrize("indicies", [1, [0, 1, 3]]) -def test_get_item(indicies): +def test_del_item(indicies): """Test __delitem__""" nqbits = 4 positions = np.random.rand(nqbits, 3) From f4fa8f34663dff2288e46e32d61140bb21e02fb1 Mon Sep 17 00:00:00 2001 From: jimnel Date: Tue, 8 Apr 2025 09:57:16 +0100 Subject: [PATCH 3/5] add doc --- qse/qbits.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/qse/qbits.py b/qse/qbits.py index 8e9f521..c1a738e 100644 --- a/qse/qbits.py +++ b/qse/qbits.py @@ -704,10 +704,10 @@ def __getitem__(self, indicies): Returns ------- - If indicies is a scalar, return an Qbit object. - - If indicies is a list or a slice, return an Qbits object with the same cell, pbc, and - other associated info as the original Qbits object. + Qbit | Qbits. + If indicies is a scalar a Qbit object is returned. If indicies + is a list or a slice, a Qbits object with the same cell, pbc, and + other associated info as the original Qbits object is returned. """ if isinstance(indicies, numbers.Integral): @@ -741,6 +741,14 @@ def __getitem__(self, indicies): return qbits def __delitem__(self, indicies): + """ + Delete a subset of the qbits. + + Parameters + ---------- + indicies : int | list + The indicies to be deleted. + """ if isinstance(indicies, list) and len(indicies) > 0: # Make sure a list of booleans will work correctly and not be # interpreted at 0 and 1 indices. @@ -792,8 +800,6 @@ def __imul__(self, m): def draw(self, ax=None, radius=None): _draw(self, ax=ax, radius=radius) - # - def repeat(self, rep): """Create new repeated qbits object. From b8aa9e8606421b4399cb93b0939ef3001f6856c2 Mon Sep 17 00:00:00 2001 From: jimnel Date: Tue, 8 Apr 2025 13:26:35 +0100 Subject: [PATCH 4/5] indicies -> indices --- qse/qbits.py | 44 ++++++++++++++++++++++---------------------- 1 file changed, 22 insertions(+), 22 deletions(-) diff --git a/qse/qbits.py b/qse/qbits.py index c1a738e..e223410 100644 --- a/qse/qbits.py +++ b/qse/qbits.py @@ -693,39 +693,39 @@ def __iter__(self): for i in range(len(self)): yield self[i] - def __getitem__(self, indicies): + def __getitem__(self, indices): """ Return a subset of the qbits. Parameters ---------- - indicies : int | list | slice - The indicies to be returned. + indices : int | list | slice + The indices to be returned. Returns ------- Qbit | Qbits. - If indicies is a scalar a Qbit object is returned. If indicies + If indices is a scalar a Qbit object is returned. If indices is a list or a slice, a Qbits object with the same cell, pbc, and other associated info as the original Qbits object is returned. """ - if isinstance(indicies, numbers.Integral): + if isinstance(indices, numbers.Integral): nqbits = len(self) - if indicies < -nqbits or indicies >= nqbits: + if indices < -nqbits or indices >= nqbits: raise IndexError("Index out of range.") - return Qbit(qbits=self, index=indicies) + return Qbit(qbits=self, index=indices) - if not isinstance(indicies, slice): - indicies = np.array(indicies) - # if i is a mask - if indicies.dtype == bool: - if len(indicies) != len(self): + if not isinstance(indices, slice): + indices = np.array(indices) + # if indices is a mask. + if indices.dtype == bool: + if len(indices) != len(self): raise IndexError( - "Length of mask {} must equal " - "number of qbits {}".format(len(indicies), len(self)) + f"Length of mask {len(indices)} must equal " + f"number of qbits {len(self)}" ) - indicies = np.arange(len(self))[indicies] + indices = np.arange(len(self))[indices] qbits = self.__class__( cell=self.cell, @@ -736,26 +736,26 @@ def __getitem__(self, indicies): qbits.arrays = {} for name, a in self.arrays.items(): - qbits.arrays[name] = a[indicies].copy() + qbits.arrays[name] = a[indices].copy() return qbits - def __delitem__(self, indicies): + def __delitem__(self, indices): """ Delete a subset of the qbits. Parameters ---------- - indicies : int | list - The indicies to be deleted. + indices : int | list + The indices to be deleted. """ - if isinstance(indicies, list) and len(indicies) > 0: + if isinstance(indices, list) and len(indices) > 0: # Make sure a list of booleans will work correctly and not be # interpreted at 0 and 1 indices. - indicies = np.array(indicies) + indices = np.array(indices) mask = np.ones(len(self), bool) - mask[indicies] = False + mask[indices] = False for name, a in self.arrays.items(): self.arrays[name] = a[mask] From 9da5fc137e7c6f92baebf9fa68cbebcf2e541a9e Mon Sep 17 00:00:00 2001 From: jimnel Date: Tue, 8 Apr 2025 13:27:44 +0100 Subject: [PATCH 5/5] indicies -> indices --- tests/qbits_test.py | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/tests/qbits_test.py b/tests/qbits_test.py index 723b153..7a30bd0 100644 --- a/tests/qbits_test.py +++ b/tests/qbits_test.py @@ -135,35 +135,35 @@ def test_get_item(): qbits = qse.Qbits(positions=positions) # test int - for indicies in [0, 2]: - assert isinstance(qbits[indicies], qse.Qbit) - assert np.allclose(qbits[indicies].position, positions[indicies]) + for indices in [0, 2]: + assert isinstance(qbits[indices], qse.Qbit) + assert np.allclose(qbits[indices].position, positions[indices]) # test list - for indicies in [[0, 2], [1, 3, 2]]: - assert isinstance(qbits[indicies], qse.Qbits) - assert np.allclose(qbits[indicies].get_positions(), positions[indicies]) + for indices in [[0, 2], [1, 3, 2]]: + assert isinstance(qbits[indices], qse.Qbits) + assert np.allclose(qbits[indices].get_positions(), positions[indices]) # test slice assert isinstance(qbits[1:3], qse.Qbits) assert np.allclose(qbits[1:3].get_positions(), positions[1:3]) -@pytest.mark.parametrize("indicies", [1, [0, 1, 3]]) -def test_del_item(indicies): +@pytest.mark.parametrize("indices", [1, [0, 1, 3]]) +def test_del_item(indices): """Test __delitem__""" nqbits = 4 positions = np.random.rand(nqbits, 3) qbits = qse.Qbits(positions=positions) - del qbits[indicies] + del qbits[indices] - if isinstance(indicies, int): - indicies = [indicies] + if isinstance(indices, int): + indices = [indices] assert isinstance(qbits, qse.Qbits) - assert len(qbits) == nqbits - len(indicies) + assert len(qbits) == nqbits - len(indices) assert np.allclose( qbits.get_positions(), - positions[[i for i in range(nqbits) if i not in indicies]], + positions[[i for i in range(nqbits) if i not in indices]], )