diff --git a/emannotationschemas/__init__.py b/emannotationschemas/__init__.py index 56c8343..de8e3be 100644 --- a/emannotationschemas/__init__.py +++ b/emannotationschemas/__init__.py @@ -52,7 +52,7 @@ DigitalTwinPropertiesBCM, ) from emannotationschemas.schemas.glia_contact import GliaContact -from emannotationschemas.schemas.groups import SimpleGroup +from emannotationschemas.schemas.groups import SimpleGroup, SimpleGroupIndexed from emannotationschemas.schemas.neuropil import FlyNeuropil from emannotationschemas.schemas.nucleus_detection import NucleusDetection from emannotationschemas.schemas.postsynaptic_compartment import PostsynapticCompartment @@ -124,6 +124,7 @@ "representative_point": RepresentativePoint, "reference_synapse_valid": ValidSynapse, "reference_simple_group": SimpleGroup, + "reference_simple_group_indexed": SimpleGroupIndexed, "fly_cell_type": FlyCellType, "fly_cell_type_ext": FlyCellTypeExt, "braincircuits_annotation_user": BrainCircuitsBoundTagAnnotationUser, diff --git a/emannotationschemas/schemas/groups.py b/emannotationschemas/schemas/groups.py index 442a4ab..b7bdf11 100644 --- a/emannotationschemas/schemas/groups.py +++ b/emannotationschemas/schemas/groups.py @@ -7,3 +7,10 @@ class SimpleGroup(ReferenceAnnotation): required=True, description="group id", ) + +class SimpleGroupIndexed(ReferenceAnnotation): + group_id = mm.fields.Int( + required=True, + description="group id", + index=True, + ) diff --git a/tests/test_simple_group.py b/tests/test_simple_group.py new file mode 100644 index 0000000..51240ab --- /dev/null +++ b/tests/test_simple_group.py @@ -0,0 +1,27 @@ +from emannotationschemas.schemas.groups import SimpleGroup, SimpleGroupIndexed + +good_simple_group = { + "pt": {"position": [31, 32, 33]}, + "group_id": 42, + "target_id": 456, +} + +good_simple_group_indexed = { + "pt": {"position": [10, 20, 30]}, + "group_id": 123, + "target_id": 456, +} + + +def test_simple_group_schema(): + schema = SimpleGroup() + result = schema.load(good_simple_group) + assert result["group_id"] == 42 + assert result["pt"]["position"] == [31, 32, 33] + + +def test_simple_group_indexed_schema(): + schema = SimpleGroupIndexed() + result = schema.load(good_simple_group_indexed) + assert result["group_id"] == 123 + assert result["pt"]["position"] == [10, 20, 30]