From 639ce0b1df0ce828050b77f3971be32c1c1c2171 Mon Sep 17 00:00:00 2001 From: syntonym Date: Thu, 25 Sep 2025 17:43:14 +0800 Subject: [PATCH 1/2] Implement __setitem__, __getitem__ and __delitem__ --- src/container/list.rs | 37 ++++++++++++++++++++++++++++++++++- src/container/map.rs | 26 +++++++++++++++++++++++- src/container/movable_list.rs | 35 +++++++++++++++++++++++++++++++++ tests/test_map.py | 4 +++- 4 files changed, 99 insertions(+), 3 deletions(-) diff --git a/src/container/list.rs b/src/container/list.rs index 76246a3..beb53cd 100644 --- a/src/container/list.rs +++ b/src/container/list.rs @@ -2,6 +2,7 @@ use std::sync::Arc; use loro::{ContainerTrait, LoroList as LoroListInner}; use pyo3::prelude::*; +use pyo3::{BoundObject, types::PySlice, exceptions::PyIndexError}; use crate::{ doc::LoroDoc, @@ -12,7 +13,13 @@ use crate::{ use super::{Container, Cursor, Side}; -#[pyclass(frozen)] +#[derive(FromPyObject)] +enum SliceOrInt<'py> { + Slice(Bound<'py, PySlice>), + Int(usize), +} + +#[pyclass(frozen, sequence)] #[derive(Debug, Clone, Default)] pub struct LoroList(pub LoroListInner); @@ -106,6 +113,34 @@ impl LoroList { }) } + pub fn __getitem__<'py>(&self, py: Python<'py>, index: SliceOrInt<'py>) -> PyResult> { + match index { + SliceOrInt::Slice(slice) => { + let indices = slice.indices(self.0.len() as isize)?; + let mut i = indices.start; + let mut list: Vec = Vec::with_capacity(indices.slicelength); + + for _ in 0..indices.slicelength { + list.push(self.0.get(i as usize).ok_or(PyIndexError::new_err("index out of range"))?.into()); + i += indices.step; + } + list.into_pyobject(py) + }, + SliceOrInt::Int(idx) => { + let value: ValueOrContainer = self.0.get(usize::try_from(idx)?).ok_or(PyIndexError::new_err("index out of range"))?.into(); + Ok(value.into_pyobject(py)?.into_any().into_bound()) + } + } + } + + pub fn __setitem__(&self, index: usize, v: LoroValue) -> PyLoroResult<()> { + self.insert(index, v) + } + + pub fn __delitem__(&self, index: usize) -> PyLoroResult<()> { + self.delete(index, 1) + } + /// Get the length of the list. #[inline] pub fn __len__(&self) -> usize { diff --git a/src/container/map.rs b/src/container/map.rs index d41d5f4..515d4dd 100644 --- a/src/container/map.rs +++ b/src/container/map.rs @@ -17,7 +17,7 @@ pub fn register_class(m: &Bound<'_, PyModule>) -> PyResult<()> { Ok(()) } -#[pyclass(frozen)] +#[pyclass(frozen, mapping)] #[derive(Debug, Clone, Default)] pub struct LoroMap(pub LoroMapInner); @@ -65,6 +65,30 @@ impl LoroMap { self.0.len() } + pub fn __contains__(&self, key: &str) -> bool { + match self.0.get(key) { + Some(_) => true, + None => false + } + } + + pub fn __getitem__(&self, key: &str) -> Option { + self.get(key) + } + + pub fn __setitem__(&self, key: &str, value: LoroValue) -> PyLoroResult<()> { + self.insert(key, value) + } + + pub fn __delitem__(&self, key: &str) -> PyLoroResult<()> { + self.0.delete(key)?; + Ok(()) + } + + pub fn __iter__(&self) -> Vec { + self.0.keys().map(|k| k.to_string()).collect() + } + /// Get the ID of the map. #[getter] pub fn id(&self) -> ContainerID { diff --git a/src/container/movable_list.rs b/src/container/movable_list.rs index b385d95..6655833 100644 --- a/src/container/movable_list.rs +++ b/src/container/movable_list.rs @@ -8,6 +8,7 @@ use crate::{ }; use loro::{ContainerTrait, LoroMovableList as LoroMovableListInner, PeerID}; use pyo3::prelude::*; +use pyo3::{BoundObject, types::PySlice, exceptions::PyIndexError}; use super::{Container, Cursor, Side}; @@ -15,6 +16,12 @@ use super::{Container, Cursor, Side}; #[derive(Debug, Clone, Default)] pub struct LoroMovableList(pub LoroMovableListInner); +#[derive(FromPyObject)] +enum SliceOrInt<'py> { + Slice(Bound<'py, PySlice>), + Int(usize), +} + #[pymethods] impl LoroMovableList { /// Create a new container that is detached from the document. @@ -63,6 +70,34 @@ impl LoroMovableList { self.0.len() } + pub fn __getitem__<'py>(&self, py: Python<'py>, index: SliceOrInt<'py>) -> PyResult> { + match index { + SliceOrInt::Slice(slice) => { + let indices = slice.indices(self.0.len() as isize)?; + let mut i = indices.start; + let mut list: Vec = Vec::with_capacity(indices.slicelength); + + for _ in 0..indices.slicelength { + list.push(self.0.get(i as usize).ok_or(PyIndexError::new_err("index out of range"))?.into()); + i += indices.step; + } + list.into_pyobject(py) + }, + SliceOrInt::Int(idx) => { + let value: ValueOrContainer = self.0.get(usize::try_from(idx)?).ok_or(PyIndexError::new_err("index out of range"))?.into(); + Ok(value.into_pyobject(py)?.into_any().into_bound()) + } + } + } + + pub fn __setitem__(&self, index: usize, v: LoroValue) -> PyLoroResult<()> { + self.insert(index, v) + } + + pub fn __delitem__(&self, index: usize) -> PyLoroResult<()> { + self.delete(index, 1) + } + /// Whether the list is empty. pub fn is_empty(&self) -> bool { self.__len__() == 0 diff --git a/tests/test_map.py b/tests/test_map.py index dadf9ea..18048bd 100644 --- a/tests/test_map.py +++ b/tests/test_map.py @@ -9,4 +9,6 @@ def test_map(): doc.commit() assert doc.get_deep_value() == { "map": {"key": "value", "key2": ["value2"]}, - } \ No newline at end of file + } + map["key2"] = "value2" + assert map["key2"].value == "value2" From 3d3bddf85b7196593a1641bd74484dd844c3bcd9 Mon Sep 17 00:00:00 2001 From: Leon Zhao Date: Thu, 25 Sep 2025 17:43:15 +0800 Subject: [PATCH 2/2] feat: more magic func for pythonic --- loro.pyi | 60 +++++++++++ src/container.rs | 1 + src/container/counter.rs | 56 ++++++++++- src/container/list.rs | 110 +++++++++++++++++--- src/container/map.rs | 24 +++-- src/container/movable_list.rs | 111 +++++++++++++++++---- src/container/text.rs | 86 +++++++++++++++- src/container/tree.rs | 4 + src/container/utils.rs | 38 +++++++ tests/test_magic_methods.py | 182 ++++++++++++++++++++++++++++++++++ 10 files changed, 627 insertions(+), 45 deletions(-) create mode 100644 src/container/utils.rs create mode 100644 tests/test_magic_methods.py diff --git a/loro.pyi b/loro.pyi index 96e2903..ba66fa3 100644 --- a/loro.pyi +++ b/loro.pyi @@ -185,6 +185,14 @@ class ImportStatus: class LoroCounter: id: ContainerID value: float + def __float__(self) -> float: ... + def __int__(self) -> int: ... + def __add__(self, other: typing.Any) -> float: ... + def __radd__(self, other: typing.Any) -> float: ... + def __sub__(self, other: typing.Any) -> float: ... + def __rsub__(self, other: typing.Any) -> float: ... + def __neg__(self) -> float: ... + def __abs__(self) -> float: ... def __new__( cls, ): ... @@ -1051,6 +1059,24 @@ class LoroList: """ ... + @typing.overload + def __getitem__(self, index: int) -> ValueOrContainer: ... + + @typing.overload + def __getitem__(self, index: slice) -> list[ValueOrContainer]: ... + + @typing.overload + def __setitem__(self, index: int, value: LoroValue) -> None: ... + + @typing.overload + def __setitem__(self, index: slice, value: typing.Iterable[LoroValue]) -> None: ... + + @typing.overload + def __delitem__(self, index: int) -> None: ... + + @typing.overload + def __delitem__(self, index: slice) -> None: ... + def insert_container(self, pos: int, child: Container) -> Container: r""" Insert a container with the given type at the given index. @@ -1193,6 +1219,15 @@ class LoroMap: """ ... + def __contains__(self, key: str) -> bool: ... + + @typing.overload + def __getitem__(self, key: str) -> ValueOrContainer: ... + + def __setitem__(self, key: str, value: LoroValue) -> None: ... + + def __delitem__(self, key: str) -> None: ... + def is_empty(self) -> bool: r""" Whether the map is empty. @@ -1264,6 +1299,8 @@ class LoroMap: """ ... + def items(self) -> list[tuple[str, ValueOrContainer]]: ... + def get_last_editor(self, key: str) -> typing.Optional[int]: r""" Get the peer id of the last editor on the given entry @@ -1468,6 +1505,18 @@ class LoroMovableList: """ ... + @typing.overload + def __setitem__(self, index: int, value: LoroValue) -> None: ... + + @typing.overload + def __setitem__(self, index: slice, value: typing.Iterable[LoroValue]) -> None: ... + + @typing.overload + def __delitem__(self, index: int) -> None: ... + + @typing.overload + def __delitem__(self, index: slice) -> None: ... + def get_creator_at(self, pos: int) -> typing.Optional[int]: r""" Get the creator of the list item at the given position. @@ -1514,6 +1563,16 @@ class LoroText: len_utf8: int len_unicode: int len_utf16: int + def __str__(self) -> str: ... + def __repr__(self) -> str: ... + def __len__(self) -> int: ... + @typing.overload + def __getitem__(self, index: int) -> str: ... + @typing.overload + def __getitem__(self, index: slice) -> str: ... + def __contains__(self, item: typing.Any) -> bool: ... + def __add__(self, other: typing.Union[str, "LoroText"]) -> str: ... + def __radd__(self, other: typing.Union[str, "LoroText"]) -> str: ... def __new__( cls, ): ... @@ -1728,6 +1787,7 @@ class LoroTree: is_attached: bool roots: list[TreeID] id: ContainerID + def __contains__(self, target: TreeID) -> bool: ... def __new__( cls, ): ... diff --git a/src/container.rs b/src/container.rs index 0ad87de..7da7717 100644 --- a/src/container.rs +++ b/src/container.rs @@ -7,6 +7,7 @@ mod movable_list; mod text; mod tree; mod unknown; +pub mod utils; pub use counter::LoroCounter; pub use list::LoroList; pub use map::LoroMap; diff --git a/src/container/counter.rs b/src/container/counter.rs index 7620657..3c0de87 100644 --- a/src/container/counter.rs +++ b/src/container/counter.rs @@ -7,7 +7,19 @@ use crate::{ value::ContainerID, }; use loro::{ContainerTrait, LoroCounter as LoroCounterInner}; -use pyo3::prelude::*; +use pyo3::{exceptions::PyTypeError, prelude::*, Bound, PyRef}; + +impl LoroCounter { + fn coerce_to_f64(other: &Bound<'_, PyAny>) -> PyLoroResult { + if let Ok(value) = other.extract::() { + Ok(value) + } else if let Ok(counter) = other.extract::>() { + Ok(counter.get_value()) + } else { + Err(PyTypeError::new_err("expected a number or LoroCounter").into()) + } + } +} #[pyclass(frozen)] #[derive(Debug, Clone, Default)] @@ -46,6 +58,48 @@ impl LoroCounter { self.0.get_value() } + pub fn __float__(&self) -> f64 { + self.0.get_value() + } + + pub fn __int__(&self) -> PyLoroResult { + let value = self.0.get_value(); + if !value.is_finite() { + return Err(PyTypeError::new_err("cannot convert non-finite counter to int").into()); + } + if value < i64::MIN as f64 || value > i64::MAX as f64 { + return Err(PyTypeError::new_err("counter value out of range for int").into()); + } + Ok(value.trunc() as i64) + } + + pub fn __add__(&self, other: Bound<'_, PyAny>) -> PyLoroResult { + let delta = Self::coerce_to_f64(&other)?; + Ok(self.0.get_value() + delta) + } + + pub fn __radd__(&self, other: Bound<'_, PyAny>) -> PyLoroResult { + self.__add__(other) + } + + pub fn __sub__(&self, other: Bound<'_, PyAny>) -> PyLoroResult { + let delta = Self::coerce_to_f64(&other)?; + Ok(self.0.get_value() - delta) + } + + pub fn __rsub__(&self, other: Bound<'_, PyAny>) -> PyLoroResult { + let delta = Self::coerce_to_f64(&other)?; + Ok(delta - self.0.get_value()) + } + + pub fn __neg__(&self) -> f64 { + -self.0.get_value() + } + + pub fn __abs__(&self) -> f64 { + self.0.get_value().abs() + } + pub fn doc(&self) -> Option { self.0.doc().map(|doc| doc.into()) } diff --git a/src/container/list.rs b/src/container/list.rs index beb53cd..be95c91 100644 --- a/src/container/list.rs +++ b/src/container/list.rs @@ -2,23 +2,21 @@ use std::sync::Arc; use loro::{ContainerTrait, LoroList as LoroListInner}; use pyo3::prelude::*; -use pyo3::{BoundObject, types::PySlice, exceptions::PyIndexError}; +use pyo3::{ + exceptions::{PyIndexError, PyValueError}, + BoundObject, +}; +use crate::container::utils::{py_any_to_loro_values, slice_indices_positions, SliceOrInt}; use crate::{ doc::LoroDoc, - err::PyLoroResult, + err::{PyLoroError, PyLoroResult}, event::{DiffEvent, Subscription}, value::{ContainerID, LoroValue, ValueOrContainer, ID}, }; use super::{Container, Cursor, Side}; -#[derive(FromPyObject)] -enum SliceOrInt<'py> { - Slice(Bound<'py, PySlice>), - Int(usize), -} - #[pyclass(frozen, sequence)] #[derive(Debug, Clone, Default)] pub struct LoroList(pub LoroListInner); @@ -113,7 +111,11 @@ impl LoroList { }) } - pub fn __getitem__<'py>(&self, py: Python<'py>, index: SliceOrInt<'py>) -> PyResult> { + pub fn __getitem__<'py>( + &self, + py: Python<'py>, + index: SliceOrInt<'py>, + ) -> PyResult> { match index { SliceOrInt::Slice(slice) => { let indices = slice.indices(self.0.len() as isize)?; @@ -121,24 +123,100 @@ impl LoroList { let mut list: Vec = Vec::with_capacity(indices.slicelength); for _ in 0..indices.slicelength { - list.push(self.0.get(i as usize).ok_or(PyIndexError::new_err("index out of range"))?.into()); + list.push( + self.0 + .get(i as usize) + .ok_or(PyIndexError::new_err("index out of range"))? + .into(), + ); i += indices.step; } list.into_pyobject(py) - }, + } SliceOrInt::Int(idx) => { - let value: ValueOrContainer = self.0.get(usize::try_from(idx)?).ok_or(PyIndexError::new_err("index out of range"))?.into(); + let value: ValueOrContainer = self + .0 + .get(usize::try_from(idx)?) + .ok_or(PyIndexError::new_err("index out of range"))? + .into(); Ok(value.into_pyobject(py)?.into_any().into_bound()) } } } - pub fn __setitem__(&self, index: usize, v: LoroValue) -> PyLoroResult<()> { - self.insert(index, v) + pub fn __setitem__<'py>( + &self, + index: SliceOrInt<'py>, + value: Bound<'py, PyAny>, + ) -> PyLoroResult<()> { + match index { + SliceOrInt::Int(idx) => { + let extracted: LoroValue = value.extract().map_err(PyLoroError::from)?; + self.0.delete(idx, 1).map_err(PyLoroError::from)?; + self.0.insert(idx, extracted.0).map_err(PyLoroError::from)?; + Ok(()) + } + SliceOrInt::Slice(slice) => { + let len = self.__len__() as isize; + let indices = slice.indices(len).map_err(PyLoroError::from)?; + + let values = py_any_to_loro_values(&value).map_err(PyLoroError::from)?; + + if indices.step == 1 { + self.0 + .delete(indices.start as usize, indices.slicelength) + .map_err(PyLoroError::from)?; + + let mut pos = indices.start as usize; + for v in values { + self.0.insert(pos, v).map_err(PyLoroError::from)?; + pos += 1; + } + Ok(()) + } else { + if values.len() != indices.slicelength { + return Err(PyValueError::new_err(format!( + "attempt to assign sequence of size {} to extended slice of size {}", + values.len(), + indices.slicelength + )) + .into()); + } + + let positions = slice_indices_positions(&indices); + for (pos, v) in positions.into_iter().zip(values.into_iter()) { + self.0.delete(pos, 1).map_err(PyLoroError::from)?; + self.0.insert(pos, v).map_err(PyLoroError::from)?; + } + Ok(()) + } + } + } } - pub fn __delitem__(&self, index: usize) -> PyLoroResult<()> { - self.delete(index, 1) + pub fn __delitem__<'py>(&self, index: SliceOrInt<'py>) -> PyLoroResult<()> { + match index { + SliceOrInt::Int(idx) => self.delete(idx, 1), + SliceOrInt::Slice(slice) => { + let len = self.__len__() as isize; + let indices = slice.indices(len).map_err(PyLoroError::from)?; + + if indices.slicelength == 0 { + return Ok(()); + } + + if indices.step == 1 { + self.delete(indices.start as usize, indices.slicelength) + } else { + let mut positions = slice_indices_positions(&indices); + positions.sort_unstable(); + for pos in positions.into_iter().rev() { + self.delete(pos, 1)?; + } + Ok(()) + } + } + } } /// Get the length of the list. diff --git a/src/container/map.rs b/src/container/map.rs index 515d4dd..ec2c3ac 100644 --- a/src/container/map.rs +++ b/src/container/map.rs @@ -1,7 +1,7 @@ use std::sync::Arc; use loro::{ContainerTrait, LoroMap as LoroMapInner, PeerID}; -use pyo3::prelude::*; +use pyo3::{exceptions::PyKeyError, prelude::*, PyErr}; use crate::{ doc::LoroDoc, @@ -66,27 +66,31 @@ impl LoroMap { } pub fn __contains__(&self, key: &str) -> bool { - match self.0.get(key) { - Some(_) => true, - None => false - } + self.0.get(key).is_some() } - pub fn __getitem__(&self, key: &str) -> Option { + pub fn __getitem__(&self, key: &str) -> PyResult { self.get(key) + .ok_or_else(|| PyKeyError::new_err(format!("Key {key} not found"))) } pub fn __setitem__(&self, key: &str, value: LoroValue) -> PyLoroResult<()> { self.insert(key, value) } - pub fn __delitem__(&self, key: &str) -> PyLoroResult<()> { - self.0.delete(key)?; + pub fn __delitem__(&self, key: &str) -> PyResult<()> { + if !self.__contains__(key) { + return Err(PyKeyError::new_err(format!("Key {key} not found"))); + } + self.delete(key).map_err(PyErr::from)?; Ok(()) } - pub fn __iter__(&self) -> Vec { - self.0.keys().map(|k| k.to_string()).collect() + pub fn items(&self) -> Vec<(String, ValueOrContainer)> { + self.0 + .keys() + .filter_map(|k| self.get(&k).map(|v| (k.to_string(), v))) + .collect() } /// Get the ID of the map. diff --git a/src/container/movable_list.rs b/src/container/movable_list.rs index 6655833..0cdf170 100644 --- a/src/container/movable_list.rs +++ b/src/container/movable_list.rs @@ -1,27 +1,25 @@ use std::sync::Arc; use crate::{ + container::utils::{py_any_to_loro_values, slice_indices_positions, SliceOrInt}, doc::LoroDoc, - err::PyLoroResult, + err::{PyLoroError, PyLoroResult}, event::{DiffEvent, Subscription}, value::{ContainerID, LoroValue, ValueOrContainer}, }; use loro::{ContainerTrait, LoroMovableList as LoroMovableListInner, PeerID}; use pyo3::prelude::*; -use pyo3::{BoundObject, types::PySlice, exceptions::PyIndexError}; +use pyo3::{ + exceptions::{PyIndexError, PyValueError}, + BoundObject, +}; use super::{Container, Cursor, Side}; -#[pyclass(frozen)] +#[pyclass(frozen, sequence)] #[derive(Debug, Clone, Default)] pub struct LoroMovableList(pub LoroMovableListInner); -#[derive(FromPyObject)] -enum SliceOrInt<'py> { - Slice(Bound<'py, PySlice>), - Int(usize), -} - #[pymethods] impl LoroMovableList { /// Create a new container that is detached from the document. @@ -70,7 +68,11 @@ impl LoroMovableList { self.0.len() } - pub fn __getitem__<'py>(&self, py: Python<'py>, index: SliceOrInt<'py>) -> PyResult> { + pub fn __getitem__<'py>( + &self, + py: Python<'py>, + index: SliceOrInt<'py>, + ) -> PyResult> { match index { SliceOrInt::Slice(slice) => { let indices = slice.indices(self.0.len() as isize)?; @@ -78,24 +80,99 @@ impl LoroMovableList { let mut list: Vec = Vec::with_capacity(indices.slicelength); for _ in 0..indices.slicelength { - list.push(self.0.get(i as usize).ok_or(PyIndexError::new_err("index out of range"))?.into()); + list.push( + self.0 + .get(i as usize) + .ok_or(PyIndexError::new_err("index out of range"))? + .into(), + ); i += indices.step; } list.into_pyobject(py) - }, + } SliceOrInt::Int(idx) => { - let value: ValueOrContainer = self.0.get(usize::try_from(idx)?).ok_or(PyIndexError::new_err("index out of range"))?.into(); + let value: ValueOrContainer = self + .0 + .get(usize::try_from(idx)?) + .ok_or(PyIndexError::new_err("index out of range"))? + .into(); Ok(value.into_pyobject(py)?.into_any().into_bound()) } } } - pub fn __setitem__(&self, index: usize, v: LoroValue) -> PyLoroResult<()> { - self.insert(index, v) + pub fn __setitem__<'py>( + &self, + index: SliceOrInt<'py>, + value: Bound<'py, PyAny>, + ) -> PyLoroResult<()> { + match index { + SliceOrInt::Int(idx) => { + let extracted: LoroValue = value.extract().map_err(PyLoroError::from)?; + self.0.set(idx, extracted.0).map_err(PyLoroError::from)?; + Ok(()) + } + SliceOrInt::Slice(slice) => { + let len = self.__len__() as isize; + let indices = slice.indices(len).map_err(PyLoroError::from)?; + + let values = py_any_to_loro_values(&value).map_err(PyLoroError::from)?; + + if indices.step == 1 { + self.0 + .delete(indices.start as usize, indices.slicelength) + .map_err(PyLoroError::from)?; + + let mut pos = indices.start as usize; + for v in values { + self.0.insert(pos, v).map_err(PyLoroError::from)?; + pos += 1; + } + Ok(()) + } else { + if values.len() != indices.slicelength { + return Err(PyValueError::new_err(format!( + "attempt to assign sequence of size {} to extended slice of size {}", + values.len(), + indices.slicelength + )) + .into()); + } + + let mut current = indices.start; + for v in values { + self.0.set(current as usize, v).map_err(PyLoroError::from)?; + current += indices.step; + } + Ok(()) + } + } + } } - pub fn __delitem__(&self, index: usize) -> PyLoroResult<()> { - self.delete(index, 1) + pub fn __delitem__<'py>(&self, index: SliceOrInt<'py>) -> PyLoroResult<()> { + match index { + SliceOrInt::Int(idx) => self.delete(idx, 1), + SliceOrInt::Slice(slice) => { + let len = self.__len__() as isize; + let indices = slice.indices(len).map_err(PyLoroError::from)?; + + if indices.slicelength == 0 { + return Ok(()); + } + + if indices.step == 1 { + self.delete(indices.start as usize, indices.slicelength) + } else { + let mut positions = slice_indices_positions(&indices); + positions.sort_unstable(); + for pos in positions.into_iter().rev() { + self.delete(pos, 1)?; + } + Ok(()) + } + } + } } /// Whether the list is empty. diff --git a/src/container/text.rs b/src/container/text.rs index cb96b5b..6da77e6 100644 --- a/src/container/text.rs +++ b/src/container/text.rs @@ -1,5 +1,10 @@ use loro::{ContainerTrait, LoroText as LoroTextInner, PeerID}; -use pyo3::{exceptions::PyValueError, prelude::*, types::PyBytes}; +use pyo3::{ + exceptions::{PyIndexError, PyTypeError, PyValueError}, + prelude::*, + types::{PyBytes, PySlice, PyString}, + Bound, PyErr, PyRef, +}; use std::{fmt::Display, sync::Arc}; use crate::{ @@ -40,6 +45,14 @@ impl LoroText { self.0.is_attached() } + pub fn __str__(&self) -> String { + self.0.to_string() + } + + pub fn __repr__(&self) -> String { + format!("LoroText({:?})", self.0.to_string()) + } + /// Get the [ContainerID] of the text container. #[getter] pub fn id(&self) -> ContainerID { @@ -104,6 +117,77 @@ impl LoroText { self.0.is_empty() } + pub fn __len__(&self) -> usize { + self.len_unicode() + } + + pub fn __getitem__<'py>( + &self, + py: Python<'py>, + key: Bound<'py, PyAny>, + ) -> PyResult> { + if let Ok(index) = key.extract::() { + let len = self.len_unicode() as isize; + let mut idx = index; + if idx < 0 { + idx += len; + } + if idx < 0 || idx >= len { + return Err(PyIndexError::new_err("string index out of range")); + } + let ch = self.char_at(idx as usize).map_err(PyErr::from)?; + let mut buf = [0u8; 4]; + let as_str = ch.encode_utf8(&mut buf); + Ok(PyString::new(py, as_str).into_any()) + } else if let Ok(slice) = key.downcast::() { + let text = self.0.to_string(); + let py_str = PyString::new(py, &text); + py_str.get_item(slice) + } else { + Err(PyTypeError::new_err( + "text indices must be integers or slices", + )) + } + } + + pub fn __contains__(&self, item: Bound<'_, PyAny>) -> PyResult { + let text = self.0.to_string(); + if let Ok(substr) = item.extract::<&str>() { + Ok(text.contains(substr)) + } else if let Ok(other) = item.extract::>() { + Ok(text.contains(&other.0.to_string())) + } else { + Ok(false) + } + } + + pub fn __add__(&self, other: Bound<'_, PyAny>) -> PyResult { + let mut result = self.0.to_string(); + if let Ok(substr) = other.extract::<&str>() { + result.push_str(substr); + Ok(result) + } else if let Ok(other_text) = other.extract::>() { + result.push_str(&other_text.0.to_string()); + Ok(result) + } else { + Err(PyTypeError::new_err("can only concatenate str or LoroText")) + } + } + + pub fn __radd__(&self, other: Bound<'_, PyAny>) -> PyResult { + if let Ok(prefix) = other.extract::<&str>() { + Ok(format!("{}{}", prefix, self.0.to_string())) + } else if let Ok(other_text) = other.extract::>() { + Ok(format!( + "{}{}", + other_text.0.to_string(), + self.0.to_string() + )) + } else { + Err(PyTypeError::new_err("can only concatenate str or LoroText")) + } + } + /// Get the length of the text container in UTF-8. #[getter] pub fn len_utf8(&self) -> usize { diff --git a/src/container/tree.rs b/src/container/tree.rs index 3b7c0db..3ff71ba 100644 --- a/src/container/tree.rs +++ b/src/container/tree.rs @@ -43,6 +43,10 @@ impl LoroTree { self.0.is_attached() } + pub fn __contains__(&self, target: TreeID) -> bool { + self.contains(target) + } + /// Create a new tree node and return the [`TreeID`]. /// /// If the `parent` is `None`, the created node is the root of a tree. diff --git a/src/container/utils.rs b/src/container/utils.rs new file mode 100644 index 0000000..cb9c37f --- /dev/null +++ b/src/container/utils.rs @@ -0,0 +1,38 @@ +use crate::value::LoroValue as PyLoroValue; +use loro::LoroValue as CoreLoroValue; +use pyo3::{ + exceptions::PyTypeError, + types::{PyAnyMethods, PySequence, PySequenceMethods, PySlice, PySliceIndices}, + Bound, FromPyObject, PyAny, PyResult, +}; + +#[derive(FromPyObject)] +pub enum SliceOrInt<'py> { + Slice(Bound<'py, PySlice>), + Int(usize), +} + +pub fn py_any_to_loro_values(obj: &Bound<'_, PyAny>) -> PyResult> { + let sequence = obj + .downcast::() + .map_err(|_| PyTypeError::new_err("can only assign an iterable to a slice"))?; + + let length = sequence.len()?; + let mut values = Vec::with_capacity(length); + for idx in 0..length { + let element = sequence.get_item(idx)?; + let extracted: PyLoroValue = element.extract()?; + values.push(extracted.0); + } + Ok(values) +} + +pub fn slice_indices_positions(indices: &PySliceIndices) -> Vec { + let mut positions = Vec::with_capacity(indices.slicelength); + let mut current = indices.start; + for _ in 0..indices.slicelength { + positions.push(current as usize); + current += indices.step; + } + positions +} diff --git a/tests/test_magic_methods.py b/tests/test_magic_methods.py new file mode 100644 index 0000000..b859ce8 --- /dev/null +++ b/tests/test_magic_methods.py @@ -0,0 +1,182 @@ +import pytest + +from loro import LoroDoc, LoroText + + +def make_text(content: str = "hello"): + doc = LoroDoc() + text = doc.get_text("text") + text.insert(0, content) + return doc, text + + +def test_lorotext_magic_methods_basic(): + _, text = make_text("hello") + + assert str(text) == "hello" + assert repr(text) == 'LoroText("hello")' + assert len(text) == 5 + assert text[0] == "h" + assert text[-1] == "o" + assert text[1:4] == "ell" + + +def test_lorotext_magic_methods_contains_and_add(): + doc, text = make_text("hello world") + + other = doc.get_text("other") + other.insert(0, "world") + + assert "hello" in text + assert other in text + assert "absent" not in text + + assert text + "!" == "hello world!" + assert text + other == "hello worldworld" + assert "say " + text == "say hello world" + assert other + text == "worldhello world" + + +def test_lorotext_getitem_type_errors(): + _, text = make_text("hello") + + with pytest.raises(TypeError): + _ = text["invalid"] + + +def test_lorotree_contains(): + doc = LoroDoc() + tree = doc.get_tree("tree") + root = tree.create() + child = tree.create(root) + + assert root in tree + assert child in tree + + other_doc = LoroDoc() + other_tree = other_doc.get_tree("other") + other_root = other_tree.create() + + assert other_root not in tree + + +def test_lorocounter_numeric_magic_methods(): + doc = LoroDoc() + counter = doc.get_counter("counter") + counter.increment(10) + doc.commit() + + assert float(counter) == pytest.approx(10.0) + assert int(counter) == 10 + assert counter + 2 == pytest.approx(12.0) + assert 2 + counter == pytest.approx(12.0) + assert counter - 3 == pytest.approx(7.0) + assert 3 - counter == pytest.approx(-7.0) + + counter.increment(5) + doc.commit() + assert counter.value == pytest.approx(15.0) + + counter.decrement(2) + doc.commit() + assert counter.value == pytest.approx(13.0) + + other = doc.get_counter("other") + other.increment(4) + doc.commit() + + assert counter + other == pytest.approx(17.0) + assert -counter == pytest.approx(-13.0) + assert abs(counter) == pytest.approx(13.0) + + +def test_lorolist_magic_methods(): + doc = LoroDoc() + lst = doc.get_list("items") + for i, value in enumerate([1, 2, "three"]): + lst.insert(i, value) + doc.commit() + + assert len(lst) == 3 + assert lst[0].value == 1 + assert [x.value for x in lst[1:3]] == [2, "three"] + + lst[1] = 42 + assert lst.to_vec() == [1, 42, "three"] + + del lst[0] + assert lst.to_vec() == [42, "three"] + + lst[0:1] = ["alpha", "beta"] + assert lst.to_vec() == ["alpha", "beta", "three"] + + lst[::2] = ["first", "second"] + assert lst.to_vec() == ["first", "beta", "second"] + + with pytest.raises(ValueError): + lst[::2] = ["only one"] + + del lst[1:3] + assert lst.to_vec() == ["first"] + + lst[:] = ["reset", "list"] + assert lst.to_vec() == ["reset", "list"] + + del lst[::-1] + assert lst.to_vec() == [] + + +def test_loromovable_list_magic_methods(): + doc = LoroDoc() + mlist = doc.get_movable_list("items") + for i, value in enumerate(["a", "b", "c"]): + mlist.insert(i, value) + + assert len(mlist) == 3 + assert mlist[2].value == "c" + assert [x.value for x in mlist[0:2]] == ["a", "b"] + + mlist[1] = "beta" + assert mlist.to_vec() == ["a", "beta", "c"] + + del mlist[2] + assert mlist.to_vec() == ["a", "beta"] + + mlist[:] = ["x", "y", "z"] + assert mlist.to_vec() == ["x", "y", "z"] + + mlist[1:3] = ["Y", "Z", "extra"] + assert mlist.to_vec() == ["x", "Y", "Z", "extra"] + + mlist[::2] = ["even", "odd"] + assert mlist.to_vec() == ["even", "Y", "odd", "extra"] + + with pytest.raises(ValueError): + mlist[::2] = ["mismatch"] + + del mlist[1::2] + assert mlist.to_vec() == ["even", "odd"] + + del mlist[::-1] + assert mlist.to_vec() == [] + + +def test_loromap_magic_methods(): + doc = LoroDoc() + map_obj = doc.get_map("map") + + map_obj["x"] = 10 + map_obj["y"] = "value" + + assert len(map_obj) == 2 + assert "x" in map_obj + assert map_obj["x"].value == 10 + + del map_obj["x"] + assert "x" not in map_obj + + with pytest.raises(KeyError): + _ = map_obj["missing"] + + with pytest.raises(KeyError): + del map_obj["missing"]