diff --git a/minitorch/tensor_ops.py b/minitorch/tensor_ops.py
index 2eb9a5d0..9cc82e25 100644
--- a/minitorch/tensor_ops.py
+++ b/minitorch/tensor_ops.py
@@ -6,6 +6,8 @@
     shape_broadcast,
     MAX_DIMS,
 )
+from minitorch import Tensor
+from typing import Callable
 
 
 def tensor_map(fn):
@@ -44,7 +46,7 @@ def _map(out, out_shape, out_strides, in_storage, in_shape, in_strides):
     return _map
 
 
-def map(fn):
+def map(fn: Callable[[float], float]) -> Callable[[Tensor], Tensor]:
     """
     Higher-order tensor map function ::
 
@@ -71,7 +73,8 @@ def map(fn):
            should broadcast with `a`
 
     Returns:
-        :class:`TensorData` : new tensor data
+        function: A function that takes a tensor, applies `fn` to each cell,
+        and returns the resulting tensor.
     """
 
     f = tensor_map(fn)
@@ -134,7 +137,7 @@ def _zip(
     return _zip
 
 
-def zip(fn):
+def zip(fn: Callable[[float, float], float]) -> Callable[[Tensor, Tensor], Tensor]:
     """
     Higher-order tensor zip function ::
 
@@ -160,7 +163,8 @@ def zip(fn):
         b (:class:`TensorData`): tensor to zip over
 
     Returns:
-        :class:`TensorData` : new tensor data
+        function: A function that takes two tensors, applies `fn` to each pair
+        of cells, and returns the resulting tensor.
     """
 
     f = tensor_zip(fn)
@@ -204,7 +208,7 @@ def _reduce(out, out_shape, out_strides, a_storage, a_shape, a_strides, reduce_d
     return _reduce
 
 
-def reduce(fn, start=0.0):
+def reduce(fn: Callable[[float, float], float], start: float = 0.0) -> Callable[[Tensor, int], Tensor]:
     """
     Higher-order tensor reduce function. ::
 
@@ -225,7 +229,9 @@ def reduce(fn, start=0.0):
         dim (int): int of dim to reduce
 
     Returns:
-        :class:`TensorData` : new tensor
+        function: A function that takes a tensor, applies `fn` to all cells
+        along a particular axis (or all cells if `dim` is not specified), and
+        returns the resulting tensor.
     """
 
     f = tensor_reduce(fn)
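A quick usage sketch of the newly annotated functions (not part of the patch). It assumes the module-level `map`, `zip`, and `reduce` shown above are importable from `minitorch.tensor_ops`, and that `minitorch.tensor` builds a `Tensor` from a nested list; adjust the imports to match your build.

```python
# Hedged sketch: the direct imports and `minitorch.tensor` constructor are
# assumptions about the surrounding library, not part of this patch.
import minitorch
from minitorch.tensor_ops import map as tmap, zip as tzip, reduce as treduce

t = minitorch.tensor([[1.0, -2.0], [3.0, -4.0]])

neg = tmap(lambda x: -x)                  # Callable[[Tensor], Tensor]
add = tzip(lambda x, y: x + y)            # Callable[[Tensor, Tensor], Tensor]
total = treduce(lambda x, y: x + y, 0.0)  # Callable[[Tensor, int], Tensor]

out1 = neg(t)       # applies fn to each cell: elementwise negation
out2 = add(t, t)    # applies fn to each pair of cells, with broadcasting
out3 = total(t, 0)  # folds fn over dim 0, starting from `start`
```

Aliasing the imports (`tmap`, `tzip`, `treduce`) avoids shadowing Python's built-in `map` and `zip` in user code.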