
Give linalg elementwise functions a proper definition and signature #101

@Balint-R


Right now, our code is structured like this:

```python
def _gen_elementwise_unary_macro(op: DefinedOpCallable) -> CallMacro:
    @CallMacro.generate()
    def op_macro(visitor: ToMLIRBase, x: Compiled) -> Tensor | MemRef:
        ...

    return op_macro

exp = _gen_elementwise_unary_macro(mlir_linalg.exp)
log = _gen_elementwise_unary_macro(mlir_linalg.log)
abs = _gen_elementwise_unary_macro(mlir_linalg.abs)
...
```
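To illustrate one drawback of the factory approach, here is a small self-contained sketch in plain Python. It is not pydsl code: generate_macro is a hypothetical stand-in for CallMacro.generate() that returns the function unchanged, math.exp/math.log/abs stand in for the mlir_linalg ops, and lists of floats stand in for Tensor | MemRef. Under those assumptions, every generated macro is the same inner op_macro, so they all share one name and one (empty) docstring.

```python
import math
from typing import Callable


def generate_macro() -> Callable[[Callable], Callable]:
    # Hypothetical stand-in for CallMacro.generate(): returns fn unchanged.
    def decorator(fn: Callable) -> Callable:
        return fn
    return decorator


def _gen_elementwise_unary_macro(op: Callable[[float], float]) -> Callable:
    # Factory pattern: one inner function is created per op.
    @generate_macro()
    def op_macro(visitor: object, x: list[float]) -> list[float]:
        # Apply the op to every element (the "elementwise" part).
        return [op(v) for v in x]

    return op_macro


exp = _gen_elementwise_unary_macro(math.exp)
log = _gen_elementwise_unary_macro(math.log)
abs_ = _gen_elementwise_unary_macro(abs)  # trailing _ avoids shadowing builtin abs

# All generated macros carry the inner function's metadata.
print(exp.__name__, log.__name__, abs_.__name__)  # op_macro op_macro op_macro
print(exp.__doc__)  # None
```

Giving each op its own docstring here would mean patching __doc__ after the fact, which is easy to forget and invisible at the call site.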

However, we could instead restructure it like this:

```python
def _elementwise_unary(op: DefinedOpCallable, visitor: ToMLIRBase, x: Tensor | MemRef) -> Tensor | MemRef:
    ...

@CallMacro.generate()
def exp(visitor: ToMLIRBase, x: Compiled) -> Tensor | MemRef:
    """
    exp documentation
    """
    return _elementwise_unary(mlir_linalg.exp, visitor, x)


@CallMacro.generate()
def log(visitor: ToMLIRBase, x: Compiled) -> Tensor | MemRef:
    """
    log documentation
    """
    return _elementwise_unary(mlir_linalg.log, visitor, x)
```

Here, we no longer have a function that returns a CallMacro. Instead, each public function is its own CallMacro and calls the shared helper directly (_elementwise_unary should have the same behaviour as op_macro).

This results in slightly more code, but it may be better in the long term: each function can have its own docstring, and its signature is visible right where it is defined.
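A minimal sketch of the proposed structure, again in plain Python rather than pydsl. generate_macro is a hypothetical stand-in for CallMacro.generate() that returns the function unchanged; math.exp/math.log stand in for the mlir_linalg ops, and lists of floats stand in for Tensor | MemRef. Assuming the real decorator exposes the wrapped function's docstring and signature, each macro now keeps its own name, docstring and inspectable signature, while the behaviour still lives in one shared helper.

```python
import inspect
import math
from typing import Callable


def generate_macro() -> Callable[[Callable], Callable]:
    """Hypothetical stand-in for CallMacro.generate(); returns fn unchanged."""
    def decorator(fn: Callable) -> Callable:
        return fn
    return decorator


def _elementwise_unary(op: Callable[[float], float], visitor: object, x: list[float]) -> list[float]:
    # Shared implementation, playing the role op_macro had in the factory version.
    return [op(v) for v in x]


@generate_macro()
def exp(visitor: object, x: list[float]) -> list[float]:
    """Elementwise natural exponential, e**x."""
    return _elementwise_unary(math.exp, visitor, x)


@generate_macro()
def log(visitor: object, x: list[float]) -> list[float]:
    """Elementwise natural logarithm."""
    return _elementwise_unary(math.log, visitor, x)


# Each macro reports its own name, signature and docstring.
for fn in (exp, log):
    print(fn.__name__, inspect.signature(fn), "-", fn.__doc__)
```

The cost is one short wrapper per op, but help(), IDE tooltips and generated API docs can now show a distinct entry for each function.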
