From 6992298a01dca4f766c5cb532fbe4f4e45c00b9b Mon Sep 17 00:00:00 2001 From: Daniel Levenson Date: Mon, 1 Nov 2021 14:00:35 -0400 Subject: [PATCH 1/2] docs: tensor ops return functions, not tensors The docs as currently written are a bit misleading as they suggest that these higher-order functions returns `Tensor`s directly, when in fact they return functions that return tensors. --- minitorch/tensor_ops.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/minitorch/tensor_ops.py b/minitorch/tensor_ops.py index 2eb9a5d0..18ce67c1 100644 --- a/minitorch/tensor_ops.py +++ b/minitorch/tensor_ops.py @@ -71,7 +71,8 @@ def map(fn): should broadcast with `a` Returns: - :class:`TensorData` : new tensor data + function: A function that takes a tensor, applies `fn` to each cell, + and returns the resulting tensor. """ f = tensor_map(fn) @@ -160,7 +161,8 @@ def zip(fn): b (:class:`TensorData`): tensor to zip over Returns: - :class:`TensorData` : new tensor data + function: A function that takes two tensors, applies `fn` to each pair + of cells, and returns the resulting tensor. """ f = tensor_zip(fn) @@ -225,7 +227,9 @@ def reduce(fn, start=0.0): dim (int): int of dim to reduce Returns: - :class:`TensorData` : new tensor + function: A function that takes a tensor, applies `fn` to all cells + along a particular axis (or all cells if `dim` is not specified), and + returns the resulting tensor. """ f = tensor_reduce(fn) From bb88ddc09464d6eb52b2720bbfa9d4448de95981 Mon Sep 17 00:00:00 2001 From: Daniel Levenson Date: Mon, 1 Nov 2021 14:22:11 -0400 Subject: [PATCH 2/2] types: add type annotations for high-order tensor functions --- minitorch/tensor_ops.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/minitorch/tensor_ops.py b/minitorch/tensor_ops.py index 18ce67c1..9cc82e25 100644 --- a/minitorch/tensor_ops.py +++ b/minitorch/tensor_ops.py @@ -6,6 +6,8 @@ shape_broadcast, MAX_DIMS, ) +from minitorch import Tensor +from typing import Callable def tensor_map(fn): @@ -44,7 +46,7 @@ def _map(out, out_shape, out_strides, in_storage, in_shape, in_strides): return _map -def map(fn): +def map(fn: Callable[[float], float]) -> Callable[[Tensor], Tensor]: """ Higher-order tensor map function :: @@ -135,7 +137,7 @@ def _zip( return _zip -def zip(fn): +def zip(fn: Callable[[float, float], float]) -> Callable[[Tensor, Tensor], Tensor]: """ Higher-order tensor zip function :: @@ -206,7 +208,7 @@ def _reduce(out, out_shape, out_strides, a_storage, a_shape, a_strides, reduce_d return _reduce -def reduce(fn, start=0.0): +def reduce(fn: Callable[[float, float], float], start: float = 0.0) -> Callable[[Tensor, float], Tensor]: """ Higher-order tensor reduce function. ::