This is an implementation of
- the computation of the symmetry group of a given rank decomposition of the matrix multiplication tensor, and
- the conversion of a symmetry group and orbit structure guess into the corresponding jax search.
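To make the central object concrete, here is a minimal pure-Python sketch (not code from this repository) that checks Strassen's rank-7 decomposition of the 2x2 matrix multiplication tensor. It assumes one common convention: a decomposition is a list of triples (u, v, w) such that AB = sum_r <u_r, A> <v_r, B> w_r for all A and B, where <u, A> is the entrywise inner product. The repository may use a different convention (for example a cyclic one), and the names STRASSEN, is_decomposition and multiply are illustrative.

```python
from itertools import product

# Assumed convention: a decomposition is a list of triples (u, v, w), with u n x m,
# v m x p and w n x p, such that AB = sum_r <u_r, A> <v_r, B> w_r for all A and B,
# where <u, A> = sum_ij u_ij A_ij.
STRASSEN = [
    # (coefficients of A, coefficients of B, contribution of M_r to C)
    ([[1, 0], [0, 1]], [[1, 0], [0, 1]], [[1, 0], [0, 1]]),    # M1 = (A11 + A22)(B11 + B22)
    ([[0, 0], [1, 1]], [[1, 0], [0, 0]], [[0, 0], [1, -1]]),   # M2 = (A21 + A22) B11
    ([[1, 0], [0, 0]], [[0, 1], [0, -1]], [[0, 1], [0, 1]]),   # M3 = A11 (B12 - B22)
    ([[0, 0], [0, 1]], [[-1, 0], [1, 0]], [[1, 0], [1, 0]]),   # M4 = A22 (B21 - B11)
    ([[1, 1], [0, 0]], [[0, 0], [0, 1]], [[-1, 1], [0, 0]]),   # M5 = (A11 + A12) B22
    ([[-1, 0], [1, 0]], [[1, 1], [0, 0]], [[0, 0], [0, 1]]),   # M6 = (A21 - A11)(B11 + B12)
    ([[0, 1], [0, -1]], [[0, 0], [1, 1]], [[1, 0], [0, 0]]),   # M7 = (A12 - A22)(B21 + B22)
]


def is_decomposition(terms, n, m, p):
    """True iff sum_r u_r (x) v_r (x) w_r is the <n, m, p> matrix multiplication
    tensor, whose entry at ((i, j), (j2, k), (i2, k2)) is 1 iff i == i2, j == j2, k == k2."""
    for i, i2 in product(range(n), repeat=2):
        for j, j2 in product(range(m), repeat=2):
            for k, k2 in product(range(p), repeat=2):
                entry = sum(u[i][j] * v[j2][k] * w[i2][k2] for u, v, w in terms)
                if entry != int(i == i2 and j == j2 and k == k2):
                    return False
    return True


def multiply(terms, a, b):
    """Compute AB using one scalar multiplication per term of the decomposition."""
    n, p = len(terms[0][2]), len(terms[0][2][0])
    c = [[0] * p for _ in range(n)]
    for u, v, w in terms:
        left = sum(x * y for u_row, a_row in zip(u, a) for x, y in zip(u_row, a_row))
        right = sum(x * y for v_row, b_row in zip(v, b) for x, y in zip(v_row, b_row))
        m_r = left * right  # the single multiplication for this term
        for i in range(n):
            for k in range(p):
                c[i][k] += m_r * w[i][k]
    return c


print(is_decomposition(STRASSEN, 2, 2, 2))                     # True
print(multiply(STRASSEN, [[1, 2], [3, 4]], [[5, 6], [7, 8]]))  # [[19, 22], [43, 50]]
```

Dropping any single term (for example M7) makes is_decomposition return False, since the remaining six terms no longer reproduce every entry of the tensor.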
The matrix multiplication tensor
Applying a symmetry of
symmetry_group determines the symmetry group under some mild
assumptions on the
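The following sketch illustrates what applying a symmetry to a decomposition means. It assumes the convention AB = sum_r <u_r, A> <v_r, B> w_r (not necessarily the repository's). Substituting A = X^{-1} A' Y and B = Y^{-1} B' Z gives A'B' = X AB Z^{-1}, so invertible X, Y, Z act on a term by (u, v, w) -> (X^{-T} u Y^T, Y^{-T} v Z^T, X w Z^{-1}). The sketch checks that a shear maps the naive rank-8 decomposition of the 2x2 case to a different decomposition, while simultaneously swapping basis vectors only permutes its terms, i.e. is a symmetry of that decomposition in the sense of permuting its set of terms. The notion that symmetry_group computes may be broader (for instance, including the transpose symmetry or rescalings within a term), and all names here are hypothetical.

```python
from fractions import Fraction
from itertools import product


def matmul(a, b):
    return [[sum(a[i][t] * b[t][j] for t in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def transpose(a):
    return [list(row) for row in zip(*a)]


def inv2(a):
    """Exact inverse of an invertible 2x2 matrix."""
    det = Fraction(a[0][0] * a[1][1] - a[0][1] * a[1][0])
    return [[a[1][1] / det, -a[0][1] / det], [-a[1][0] / det, a[0][0] / det]]


def unit(i, j, size=2):
    return [[Fraction(int((r, c) == (i, j))) for c in range(size)] for r in range(size)]


def naive_decomposition(size=2):
    # AB = sum_{i,j,k} A_ij B_jk E_ik: one term (E_ij, E_jk, E_ik) per (i, j, k).
    return [(unit(i, j, size), unit(j, k, size), unit(i, k, size))
            for i, j, k in product(range(size), repeat=3)]


def is_decomposition(terms, size=2):
    # Tensor entry at ((i, j), (j2, k), (i2, k2)) must be 1 iff i == i2, j == j2, k == k2.
    for i, j, j2, k, i2, k2 in product(range(size), repeat=6):
        entry = sum(u[i][j] * v[j2][k] * w[i2][k2] for u, v, w in terms)
        if entry != int(i == i2 and j == j2 and k == k2):
            return False
    return True


def apply_symmetry(terms, x, y, z):
    # (u, v, w) -> (X^-T u Y^T, Y^-T v Z^T, X w Z^-1), for 2x2 invertible X, Y, Z.
    x_inv_t, y_inv_t, z_inv = transpose(inv2(x)), transpose(inv2(y)), inv2(z)
    y_t, z_t = transpose(y), transpose(z)
    return [(matmul(matmul(x_inv_t, u), y_t),
             matmul(matmul(y_inv_t, v), z_t),
             matmul(matmul(x, w), z_inv)) for u, v, w in terms]


def term_set(terms):
    # Hashable form of a decomposition that ignores the order of its terms.
    return {tuple(tuple(tuple(row) for row in mat) for mat in term) for term in terms}


naive = naive_decomposition()
sheared = apply_symmetry(naive, [[1, 1], [0, 1]], [[1, 0], [1, 1]], [[2, 1], [1, 1]])
swap = [[0, 1], [1, 0]]
print(is_decomposition(sheared))                                             # True
print(term_set(sheared) == term_set(naive))                                  # False
print(term_set(apply_symmetry(naive, swap, swap, swap)) == term_set(naive))  # True
```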
Given a finite subgroup of linearize_representation finds a linear
representation projecting to a projective representation under some mild
assumptions.
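Linearizing is easiest to see for a cyclic group. If the image of a matrix g in the projective group has order dividing k, then g^k = lam * I for some nonzero scalar lam, and rescaling g by any mu with mu^k = 1/lam gives a matrix of order dividing k, hence a linear representation of Z/k. The sketch below handles only this cyclic case; it is not how linearize_representation works internally, and linearize_cyclic is a hypothetical name.

```python
import cmath


def matmul(a, b):
    return [[sum(a[i][t] * b[t][j] for t in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def matpow(a, k):
    result = [[1 if i == j else 0 for j in range(len(a))] for i in range(len(a))]
    for _ in range(k):
        result = matmul(result, a)
    return result


def linearize_cyclic(g, k, tol=1e-9):
    """Given g with g**k == lam * I (projective order of g divides k), return
    mu * g with (mu * g)**k == I, where mu is one choice of lam**(-1/k)."""
    power = matpow(g, k)
    lam = power[0][0]
    for i, row in enumerate(power):
        for j, entry in enumerate(row):
            if abs(entry - (lam if i == j else 0)) > tol:
                raise ValueError("g**k is not a scalar matrix")
    # Principal branch of the logarithm; other valid choices of mu differ
    # from this one by k-th roots of unity.
    mu = cmath.exp(-cmath.log(lam) / k)
    return [[mu * entry for entry in row] for row in g]


lifted = linearize_cyclic([[0, -3], [3, 0]], 2)  # g**2 == -9 I, so mu is -i/3
print(matpow(lifted, 2))  # approximately the identity
```

The lift is not unique, since mu can be multiplied by any k-th root of unity. For non-cyclic groups a consistent choice of scalars need not exist: the Pauli matrices X and Z give a projective representation of Z/2 x Z/2, but any rescalings of them that square to I still anticommute, so they never commute as a linear representation of that group would require.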
Given a projective or linear representation, we may describe the orbit
structure of the rank-1 terms under the action. The triples in the decomposition
are collected into orbits, and we pick one element of each orbit. orbit_structure computes this information given a decomposition and
its symmetry group.
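A minimal sketch of the orbit computation, assuming the group action on terms is already available as generators that permute term labels (how orbit_structure obtains and represents this action is not shown). As an illustration it uses the naive 2x2 decomposition in the cyclic convention T = sum_{i,j,k} e_ij (x) e_jk (x) e_ki, labelling the term e_ij (x) e_jk (x) e_ki by (i, j, k). Cyclically shifting the three tensor factors sends (i, j, k) to (j, k, i), and conjugating all three factors by the 2x2 swap matrix sends it to (1-i, 1-j, 1-k). The function name orbits is hypothetical.

```python
from collections import deque
from itertools import product


def orbits(labels, generators):
    """Partition `labels` into orbits of the group generated by `generators`.

    Each generator is a function that permutes `labels`. For a finite group,
    closing under the generators alone suffices, because each inverse is a power
    of the corresponding generator. The first element of each orbit serves as
    its representative."""
    seen, result = set(), []
    for start in labels:
        if start in seen:
            continue
        seen.add(start)
        orbit, queue = [start], deque([start])
        while queue:
            label = queue.popleft()
            for g in generators:
                image = g(label)
                if image not in seen:
                    seen.add(image)
                    orbit.append(image)
                    queue.append(image)
        result.append(orbit)
    return result


labels = list(product(range(2), repeat=3))  # terms of the naive 2x2 decomposition


def rotate(t):  # cyclic shift of the three tensor factors
    return (t[1], t[2], t[0])


def complement(t):  # conjugate every factor by the 2x2 swap matrix
    return tuple(1 - x for x in t)


naive_orbits = orbits(labels, [rotate, complement])
structure = [(orbit[0], len(orbit)) for orbit in naive_orbits]
print(structure)  # [((0, 0, 0), 2), ((0, 0, 1), 6)]
```

With only the swap as generator, the same eight terms fall into four orbits of size 2.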
Finally, given a group with either a projective or a linear representation, and
given an orbit structure, get_map_to_rank1s computes a function, applicable to
numpy or jax vectors, that parameterizes decompositions with the corresponding
symmetry and orbit structure. More precisely, the set of all the matrix entries
get_map_to_rank1s picks a computable
parameterization of this subspace.
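Below is a simplified, hypothetical analogue of get_map_to_rank1s, not the repository's implementation. It assumes the group is given as an explicit list of triples (X, Y, Z) of orthogonal matrices acting in the cyclic convention T = sum e_ij (x) e_jk (x) e_ki by (u, v, w) -> (X u Y^{-1}, Y v Z^{-1}, Z w X^{-1}), and that every orbit has trivial stabilizer, so each orbit is parameterized by the free entries of one representative term. Nontrivial stabilizers, where the representative is constrained to a fixed subspace, are not handled. It works on plain Python sequences; a numpy or jax version would replace the list-based products with array operations.

```python
from itertools import product


def matmul(a, b):
    return [[sum(a[i][t] * b[t][j] for t in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def transpose(a):
    return [list(row) for row in zip(*a)]


def act(g, term):
    # (u, v, w) -> (X u Y^-1, Y v Z^-1, Z w X^-1); X, Y, Z are assumed orthogonal,
    # so each inverse is computed as a transpose.
    x, y, z = g
    u, v, w = term
    return (matmul(matmul(x, u), transpose(y)),
            matmul(matmul(y, v), transpose(z)),
            matmul(matmul(z, w), transpose(x)))


def map_to_rank1s_sketch(group, n, m, p, num_orbits):
    """Return (params -> list of rank-1 terms, number of parameters)."""
    shapes = [(n, m), (m, p), (p, n)]
    per_rep = sum(r * c for r, c in shapes)

    def params_to_rank1s(params):
        assert len(params) == num_orbits * per_rep
        terms, offset = [], 0
        for _ in range(num_orbits):
            rep = []
            for rows, cols in shapes:  # unflatten one representative term
                flat = params[offset:offset + rows * cols]
                rep.append([list(flat[r * cols:(r + 1) * cols]) for r in range(rows)])
                offset += rows * cols
            # Trivial stabilizer assumed: one term per group element.
            terms.extend(act(g, tuple(rep)) for g in group)
        return terms

    return params_to_rank1s, num_orbits * per_rep


def is_cyclic_decomposition(terms, n, m, p):
    # Entry ((a, b), (c, d), (e, f)) of T is 1 iff b == c, d == e and f == a.
    for a, f in product(range(n), repeat=2):
        for b, c in product(range(m), repeat=2):
            for d, e in product(range(p), repeat=2):
                entry = sum(u[a][b] * v[c][d] * w[e][f] for u, v, w in terms)
                if entry != int(b == c and d == e and f == a):
                    return False
    return True


def unit_flat(r, c):  # flattened 2x2 matrix unit E_rc
    flat = [0] * 4
    flat[r * 2 + c] = 1
    return flat


identity, swap = [[1, 0], [0, 1]], [[0, 1], [1, 0]]
group = [(identity, identity, identity), (swap, swap, swap)]
to_rank1s, dim = map_to_rank1s_sketch(group, 2, 2, 2, num_orbits=4)
params = []
for i, j, k in [(0, 0, 0), (0, 0, 1), (0, 1, 0), (0, 1, 1)]:  # one term per orbit
    params += unit_flat(i, j) + unit_flat(j, k) + unit_flat(k, i)
print(dim, is_cyclic_decomposition(to_rank1s(params), 2, 2, 2))  # 48 True
```

Here the 48 parameters for four representatives, together with the order-2 group generated by the swap, reproduce all 8 terms of the naive decomposition.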
Caveat: If we search for decompositions in the same way as before, fixing in advance a sublattice of allowed parameter values, then our choice of parameterization affects which decompositions we consider. In other words, there is no longer a privileged basis of the symmetric rank-1 matrix subspace on which to base our lattice. To take this approach of searching in a lattice rather than a vector space in a principled way, I expect we need to allow the lattice to be determined by the agent.
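A small two-dimensional illustration of this caveat: the same space (here all of R^2), spanned by two different bases, yields different integer lattices, so a search restricted to integer parameters reaches different points depending on which basis the parameterization uses. The helper names are illustrative.

```python
from fractions import Fraction


def coordinates(basis, vector):
    """Exact coordinates (c0, c1) with c0 * e0 + c1 * e1 == vector, for a basis
    (e0, e1) of a 2-dimensional space, via Cramer's rule."""
    (a0, a1), (b0, b1) = basis
    x, y = vector
    det = Fraction(a0 * b1 - b0 * a1)
    return ((x * b1 - b0 * y) / det, (a0 * y - x * a1) / det)


def in_lattice(basis, vector):
    """True iff `vector` is an integer combination of the basis vectors."""
    return all(c.denominator == 1 for c in coordinates(basis, vector))


standard = [(1, 0), (0, 1)]
rotated = [(1, 1), (1, -1)]  # spans the same space as `standard`
print(in_lattice(standard, (1, 0)), in_lattice(rotated, (1, 0)))  # True False
print(coordinates(rotated, (1, 0)))  # (Fraction(1, 2), Fraction(1, 2))
```

The vector (1, 0) has integer coordinates in the standard basis but half-integer coordinates in the rotated one, so an integer-lattice search over the rotated parameterization can never reach it.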
For simple examples of the corresponding jax search, see search_2x2.py,
search_4x4.py, search_2x2_notrans.py, and search_4x4_notrans.py, where we
take the symmetry groups of Strassen's decomposition and of the rank-48 decomposition found by AlphaEvolve, and look for
decompositions with only that symmetry. The notrans versions pass to only the
The code makes heavy use of GAP, through the SageMath interface, for the group
theory computations. Nix is used to ensure reproducibility of the environment.
Of course, one can also set up an environment manually by following shell.nix if
desired. Due to the way SageMath is currently packaged, one has to use the
Python bundled with Sage.
Steps to run:
- Install nix (https://nixos.org/download/)
- Enter the environment:
  $ nix-shell
  or, if you prefer, use direnv
# compute the symmetry group and orbit structure of the AlphaEvolve-found decompositions
$ sage -python ae_decomps.py
# run a simple jax search for a 4x4 rank-48 decomposition with the same
# symmetry as the AlphaEvolve-found one
$ sage -python search_4x4.py
# similarly
$ sage -python search_4x4_notrans.py
$ sage -python search_2x2.py
$ sage -python search_2x2_notrans.py