Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 37 additions & 63 deletions qse/qbits.py
Original file line number Diff line number Diff line change
Expand Up @@ -693,93 +693,69 @@ def __iter__(self):
for i in range(len(self)):
yield self[i]

def __getitem__(self, i):
"""Return a subset of the qbits.

i -- scalar integer, list of integers, or slice object
describing which qbits to return.

If i is a scalar, return an Qbit object. If i is a list or a
slice, return an Qbits object with the same cell, pbc, and
other associated info as the original Qbits object. The
indices of the constraints will be shuffled so that they match
the indexing in the subset returned.
def __getitem__(self, indices):
"""
Return a subset of the qbits.

Parameters
----------
indices : int | list | slice
The indices to be returned.

Returns
-------
Qbit | Qbits.
If indices is a scalar a Qbit object is returned. If indices
is a list or a slice, a Qbits object with the same cell, pbc, and
other associated info as the original Qbits object is returned.
"""

if isinstance(i, numbers.Integral):
if isinstance(indices, numbers.Integral):
nqbits = len(self)
if i < -nqbits or i >= nqbits:
if indices < -nqbits or indices >= nqbits:
raise IndexError("Index out of range.")
return Qbit(qbits=self, index=indices)

return Qbit(qbits=self, index=i)
elif not isinstance(i, slice):
i = np.array(i)
# if i is a mask
if i.dtype == bool:
if len(i) != len(self):
if not isinstance(indices, slice):
indices = np.array(indices)
# if indices is a mask.
if indices.dtype == bool:
if len(indices) != len(self):
raise IndexError(
"Length of mask {} must equal "
"number of qbits {}".format(len(i), len(self))
f"Length of mask {len(indices)} must equal "
f"number of qbits {len(self)}"
)
i = np.arange(len(self))[i]

import copy

conadd = []
# Constraints need to be deepcopied, but only the relevant ones.
for con in copy.deepcopy(self.constraints):
try:
con.index_shuffle(self, i)
except (IndexError, NotImplementedError):
pass
else:
conadd.append(con)
indices = np.arange(len(self))[indices]

qbits = self.__class__(
cell=self.cell,
pbc=self.pbc,
info=self.info,
# should be communicated to the slice as well
celldisp=self._celldisp,
)
# TODO: Do we need to shuffle indices in adsorbate_info too?

qbits.arrays = {}
for name, a in self.arrays.items():
qbits.arrays[name] = a[i].copy()
qbits.arrays[name] = a[indices].copy()

qbits.constraints = conadd
return qbits

def __delitem__(self, i):
from qse.constraints import FixQbits

for c in self._constraints:
if not isinstance(c, FixQbits):
raise RuntimeError(
"Remove constraint using set_constraint() " "before deleting qbits."
)
def __delitem__(self, indices):
"""
Delete a subset of the qbits.

if isinstance(i, list) and len(i) > 0:
Parameters
----------
indices : int | list
The indices to be deleted.
"""
if isinstance(indices, list) and len(indices) > 0:
# Make sure a list of booleans will work correctly and not be
# interpreted at 0 and 1 indices.
i = np.array(i)

if len(self._constraints) > 0:
n = len(self)
i = np.arange(n)[i]
if isinstance(i, int):
i = [i]
constraints = []
for c in self._constraints:
c = c.delete_qbits(i, n)
if c is not None:
constraints.append(c)
self.constraints = constraints
indices = np.array(indices)

mask = np.ones(len(self), bool)
mask[i] = False
mask[indices] = False
for name, a in self.arrays.items():
self.arrays[name] = a[mask]

Expand Down Expand Up @@ -824,8 +800,6 @@ def __imul__(self, m):
def draw(self, ax=None, radius=None):
_draw(self, ax=ax, radius=radius)

#

def repeat(self, rep):
"""Create new repeated qbits object.

Expand Down
40 changes: 40 additions & 0 deletions tests/qbits_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,3 +127,43 @@ def test_translate(nqbits, type_of_disp):
else:
qbits.translate(disp)
assert np.allclose(qbits.get_positions(), positions + disp)


def test_get_item():
"""Test __getitem__"""
positions = np.random.rand(4, 3)
qbits = qse.Qbits(positions=positions)

# test int
for indices in [0, 2]:
assert isinstance(qbits[indices], qse.Qbit)
assert np.allclose(qbits[indices].position, positions[indices])

# test list
for indices in [[0, 2], [1, 3, 2]]:
assert isinstance(qbits[indices], qse.Qbits)
assert np.allclose(qbits[indices].get_positions(), positions[indices])

# test slice
assert isinstance(qbits[1:3], qse.Qbits)
assert np.allclose(qbits[1:3].get_positions(), positions[1:3])


@pytest.mark.parametrize("indices", [1, [0, 1, 3]])
def test_del_item(indices):
"""Test __delitem__"""
nqbits = 4
positions = np.random.rand(nqbits, 3)
qbits = qse.Qbits(positions=positions)

del qbits[indices]

if isinstance(indices, int):
indices = [indices]

assert isinstance(qbits, qse.Qbits)
assert len(qbits) == nqbits - len(indices)
assert np.allclose(
qbits.get_positions(),
positions[[i for i in range(nqbits) if i not in indices]],
)