Building autograd engine tinytorch 03

Community blog post
Published January 30, 2024

image/png


Clean up & refactor

(Don't skip this part: we also add new things here.)

Let's get professional. It's a good time now to add type hints and do some cleanup. Let's start with the tensor class.

from __future__ import annotations
import numpy as np


class Tensor:
    """Array value that wraps a NumPy array and can take part in autograd."""

    def __init__(self, data):
        # Every accepted input is normalised to an ndarray.
        self.data: np.ndarray = Tensor._data_to_numpy(data)
        # Gradient tensor; None until something assigns it.
        self.grad: Tensor | None = None
        # Record of the Function call that produced this tensor, if any.
        self._ctx: Function | None = None

    @staticmethod
    def _data_to_numpy(data):
        """Convert ``data`` to a NumPy array.

        Accepted inputs:
            * ``int`` / ``float`` -> one-element array
            * ``list`` / ``tuple`` -> ``np.array(data)``
            * ``np.ndarray`` -> returned as is (no copy is made)
            * ``Tensor`` -> a copy of its underlying array

        Raises:
            ValueError: if ``data`` is of any other type.
        """
        if isinstance(data, (int, float)):
            return np.array([data])
        if isinstance(data, (list, tuple)):
            return np.array(data)
        if isinstance(data, np.ndarray):
            return data
        if isinstance(data, Tensor):
            return data.data.copy()
        raise ValueError("Invalid value passed to tensor")

    @staticmethod
    def ensure_tensor(data):
        """Return ``data`` itself if it is a Tensor, otherwise ``Tensor(data)``."""
        if isinstance(data, Tensor):
            return data
        return Tensor(data)
        

We import annotations from __future__ so that the type hints won't be a problem on older Python versions. We also add a new _data_to_numpy method so people can pass any input, like Tensor(2), Tensor([1,2]), or Tensor(np.arange(10)), and we still get consistent data underneath. There is also a new ensure_tensor, a simple method that returns its argument unchanged if it is already a Tensor and wraps it in a Tensor otherwise.

    def __add__(self, other):
        """Add two tensors, attaching the autograd record by hand."""
        # Remember which op produced the result and from which inputs.
        ctx = Function(Add, self, other)
        out = Add.forward(self, other)
        out._ctx = ctx
        return out

This is pretty ugly. Let's clean it up by modifying the Function class.

class Function:
    """Base class for operations, and the record of a single call.

    Subclasses implement ``forward`` and ``backward`` as static methods and
    are invoked through ``apply``. An instance stores which operation was
    called (``op``) and the arguments it was called with (``args``).
    """

    def __init__(self, op, *args):
        # ``op`` is the Function subclass itself, not an instance of it.
        self.op: type[Function] = op
        # ``*args`` arrives as a tuple and is stored as such.
        self.args: tuple[Tensor, ...] = args

    @classmethod
    def apply(cls, *args):
        """Run ``cls.forward(*args)`` and attach the call record to the result."""
        ctx = Function(cls, *args)
        result = cls.forward(*args)
        result._ctx = ctx
        return result

    @staticmethod
    def forward(*args):
        """Compute the operation's output; subclasses must override this."""
        raise NotImplementedError

    @staticmethod
    def backward(ctx, *args):
        """Compute input gradients from ``ctx`` and the incoming gradient(s);
        subclasses must override this."""
        raise NotImplementedError

apply is a class method that does essentially what we did by hand in __add__ and __mul__. The idea is that every operation inherits from the Function class and implements its logic as static methods, so a call becomes just Add.apply(). The forward and backward methods on the base class raise NotImplementedError to indicate that subclasses must implement them. We also add some type hints.

Let's modify Add and Mul so they inherit these nice properties from Function.


class Add(Function):
    ...  # forward/backward omitted in this excerpt; the complete listing at the end has them
class Mul(Function):
    ...  # forward/backward omitted in this excerpt; the complete listing at the end has them

Time to change the magic methods:

    def __add__(self, other) -> Tensor:
        """Return ``self + other``; ``other`` is wrapped in a Tensor if needed."""
        rhs = Tensor.ensure_tensor(other)
        return Add.apply(self, rhs)

    def __mul__(self, other) -> Tensor:
        """Return ``self * other``; ``other`` is wrapped in a Tensor if needed."""
        rhs = Tensor.ensure_tensor(other)
        return Mul.apply(self, rhs)

    def __radd__(self, other) -> Tensor:
        """Handle ``other + self`` when the left operand is not a Tensor.

        The operands are passed to ``Add`` as ``(self, other)``.
        """
        lhs = Tensor.ensure_tensor(other)
        return Add.apply(self, lhs)

    def __rmul__(self, other) -> Tensor:
        """Handle ``other * self`` when the left operand is not a Tensor.

        The operands are passed to ``Mul`` as ``(self, other)``.
        """
        lhs = Tensor.ensure_tensor(other)
        return Mul.apply(self, lhs)

Looks neat.

The new __radd__ method handles addition with the operands reversed, for when the left-hand value is not a Tensor: Tensor(1)+1 will use __add__, and 1+Tensor(1) will use __radd__.

We also want to support +=, and for that we need __iadd__ (and __imul__ for *=). Let's add those.


    def __iadd__(self, other) -> Tensor:
        """Implement ``self += other``.

        Returns the new Tensor produced by ``Add.apply``; ``self`` is not
        modified in place.
        """
        rhs = Tensor.ensure_tensor(other)
        return Add.apply(self, rhs)

    def __imul__(self, other) -> Tensor:
        """Implement ``self *= other``.

        Returns the new Tensor produced by ``Mul.apply``; ``self`` is not
        modified in place.
        """
        rhs = Tensor.ensure_tensor(other)
        return Mul.apply(self, rhs)

While we are here, let's also add detach(), which removes a node from the graph. How do we do that? By setting its _ctx to None.


class Tensor:
    ...  # other members omitted in this excerpt

    def detach(self)->Tensor:
        """Drop this tensor's link to the Function call that produced it.

        Sets ``_ctx`` to ``None`` on this object and returns the same object;
        no copy is made.
        """
        self._ctx = None
        return self

Since we added detach, let's clean up any unnecessary graphs that might have been generated in the backward pass.


    def backward(self, grad=None):
        """Propagate gradients from this tensor to the tensors it was built from.

        Args:
            grad: Incoming gradient for this tensor. When omitted, the pass
                is seeded with ``Tensor([1.])``, which is also stored as
                ``self.grad``.

        Tensors without a ``_ctx`` return immediately.
        """
        if self._ctx is None:
            return

        if grad is None:
            grad = Tensor([1.])
            self.grad = grad

        op = self._ctx.op
        child_nodes = self._ctx.args

        grads = op.backward(self._ctx, grad)
        if len(child_nodes) == 1:
            # With a single input the op's result is treated as one gradient;
            # wrap it so the zip() below pairs it with the single child.
            grads = [grads]

        for tensor, child_grad in zip(child_nodes, grads):
            if child_grad is None:
                continue
            # Drop the graph that was recorded while computing this gradient.
            child_grad = child_grad.detach()
            if tensor.grad is None:
                tensor.grad = Tensor(np.zeros_like(tensor.data))
            # Accumulate on the raw arrays. `tensor.grad += ...` would go
            # through Add.apply and attach a new graph to the stored gradient.
            tensor.grad = Tensor(tensor.grad.data + child_grad.data)
            tensor.backward(child_grad)

OK, we are doing a couple of things here. Most importantly, we detach the grads. We also skip child nodes that don't have a gradient (grad is None). In the future we will have functions that take a single input, so when a function has exactly one argument we wrap its gradient in a list so the loop works correctly.

As a bonus, we will also add clone:

    def clone(self) -> Tensor:
        """Return a new Tensor holding an independent copy of this tensor's data."""
        # NumPy arrays provide copy(); ndarray has no clone() method.
        return Tensor(self.data.copy())

Let's also add a few tiny utility functions.

Shape property on tensor class

class Tensor:
    ...  # other members omitted in this excerpt
    @property
    def shape(self) -> tuple:
        """Shape of the underlying array, as reported by ``self.data.shape``."""
        return self.data.shape

And two functions to create tensors filled with ones and zeros:

def ones(shape) -> Tensor:
    """Create a Tensor of the given shape filled with ones."""
    filled = np.ones(shape)
    return Tensor(filled)

def zeros(shape) -> Tensor:
    """Create a Tensor of the given shape filled with zeros."""
    filled = np.zeros(shape)
    return Tensor(filled)

commit 96bd8a262398ad56699175cf64592b88ebd4be11

At the end, the file should look like this:

from __future__ import annotations
import numpy as np


class Tensor:
    """Array value that wraps a NumPy array and can take part in autograd."""

    def __init__(self, data):
        # Every accepted input is normalised to an ndarray.
        self.data: np.ndarray = Tensor._data_to_numpy(data)
        # Gradient tensor; None until a backward pass assigns it.
        self.grad: Tensor | None = None
        # Record of the Function call that produced this tensor, if any.
        self._ctx: Function | None = None

    @staticmethod
    def _data_to_numpy(data):
        """Convert ``data`` to a NumPy array.

        Accepted inputs:
            * ``int`` / ``float`` -> one-element array
            * ``list`` / ``tuple`` -> ``np.array(data)``
            * ``np.ndarray`` -> returned as is (no copy is made)
            * ``Tensor`` -> a copy of its underlying array

        Raises:
            ValueError: if ``data`` is of any other type.
        """
        if isinstance(data, (int, float)):
            return np.array([data])
        if isinstance(data, (list, tuple)):
            return np.array(data)
        if isinstance(data, np.ndarray):
            return data
        if isinstance(data, Tensor):
            return data.data.copy()
        raise ValueError("Invalid value passed to tensor")

    @staticmethod
    def ensure_tensor(data):
        """Return ``data`` itself if it is a Tensor, otherwise ``Tensor(data)``."""
        if isinstance(data, Tensor):
            return data
        return Tensor(data)

    def __add__(self, other) -> Tensor:
        """Return ``self + other`` via ``Add.apply``."""
        return Add.apply(self, Tensor.ensure_tensor(other))

    def __mul__(self, other) -> Tensor:
        """Return ``self * other`` via ``Mul.apply``."""
        return Mul.apply(self, Tensor.ensure_tensor(other))

    def __radd__(self, other) -> Tensor:
        """Handle ``other + self`` when the left operand is not a Tensor."""
        return Add.apply(self, Tensor.ensure_tensor(other))

    def __rmul__(self, other) -> Tensor:
        """Handle ``other * self`` when the left operand is not a Tensor."""
        return Mul.apply(self, Tensor.ensure_tensor(other))

    def __iadd__(self, other) -> Tensor:
        """Implement ``self += other``; returns a new Tensor, not ``self``."""
        return Add.apply(self, Tensor.ensure_tensor(other))

    def __imul__(self, other) -> Tensor:
        """Implement ``self *= other``; returns a new Tensor, not ``self``."""
        return Mul.apply(self, Tensor.ensure_tensor(other))

    def __repr__(self):
        return f"tensor({self.data})"

    @property
    def shape(self) -> tuple:
        """Shape of the underlying array, as reported by ``self.data.shape``."""
        return self.data.shape

    def detach(self) -> Tensor:
        """Drop the link to the producing Function call, in place.

        Sets ``_ctx`` to ``None`` and returns this same object.
        """
        self._ctx = None
        return self

    def clone(self) -> Tensor:
        """Return a new Tensor holding an independent copy of this tensor's data."""
        # NumPy arrays provide copy(); ndarray has no clone() method.
        return Tensor(self.data.copy())

    def backward(self, grad=None):
        """Propagate gradients from this tensor to the tensors it was built from.

        Args:
            grad: Incoming gradient for this tensor. When omitted, the pass
                is seeded with ``Tensor([1.0])``, which is also stored as
                ``self.grad``.

        Tensors without a ``_ctx`` return immediately.
        """
        if self._ctx is None:
            return

        if grad is None:
            grad = Tensor([1.0])
            self.grad = grad

        op = self._ctx.op
        child_nodes = self._ctx.args

        grads = op.backward(self._ctx, grad)
        if len(child_nodes) == 1:
            # With a single input the op's result is treated as one gradient;
            # wrap it so the zip() below pairs it with the single child.
            grads = [grads]

        for tensor, child_grad in zip(child_nodes, grads):
            if child_grad is None:
                continue
            # Drop the graph that was recorded while computing this gradient.
            child_grad = child_grad.detach()
            if tensor.grad is None:
                tensor.grad = Tensor(np.zeros_like(tensor.data))
            # Accumulate on the raw arrays. `tensor.grad += ...` would go
            # through Add.apply and attach a new graph to the stored gradient.
            tensor.grad = Tensor(tensor.grad.data + child_grad.data)
            tensor.backward(child_grad)


class Function:
    """Base class for operations, and the record of a single call.

    Subclasses implement ``forward`` and ``backward`` as static methods and
    are invoked through ``apply``. An instance stores which operation was
    called (``op``) and the arguments it was called with (``args``).
    """

    def __init__(self, op, *args):
        # ``op`` is the Function subclass itself, not an instance of it.
        self.op: type[Function] = op
        # ``*args`` arrives as a tuple and is stored as such.
        self.args: tuple[Tensor, ...] = args

    @classmethod
    def apply(cls, *args):
        """Run ``cls.forward(*args)`` and attach the call record to the result."""
        ctx = Function(cls, *args)
        result = cls.forward(*args)
        result._ctx = ctx
        return result

    @staticmethod
    def forward(*args):
        """Compute the operation's output; subclasses must override this."""
        raise NotImplementedError

    @staticmethod
    def backward(ctx, *args):
        """Compute input gradients from ``ctx`` and the incoming gradient(s);
        subclasses must override this."""
        raise NotImplementedError


class Add(Function):
    """Addition of two tensors: z = x + y."""

    @staticmethod
    def forward(x, y):
        total = x.data + y.data
        return Tensor(total)

    @staticmethod
    def backward(ctx, grad):
        # dz/dx = dz/dy = 1, so each input receives the incoming gradient times one.
        x, y = ctx.args  # unpacked but not otherwise used
        grad_x = Tensor([1]) * grad
        grad_y = Tensor([1]) * grad
        return grad_x, grad_y


class Mul(Function):
    """Multiplication of two tensors: z = x * y."""

    @staticmethod
    def forward(x, y):
        product = x.data * y.data
        return Tensor(product)

    @staticmethod
    def backward(ctx, grad):
        x, y = ctx.args
        grad_x = Tensor(y.data) * grad  # dz/dx = y
        grad_y = Tensor(x.data) * grad  # dz/dy = x
        return grad_x, grad_y


def ones(shape) -> Tensor:
    """Create a Tensor of the given shape filled with ones."""
    filled = np.ones(shape)
    return Tensor(filled)


def zeros(shape) -> Tensor:
    """Create a Tensor of the given shape filled with zeros."""
    filled = np.zeros(shape)
    return Tensor(filled)


if __name__ == "__main__":

    def f(x):
        """Test function f(x) = x^3 + x."""
        cubed = x * x * x
        return cubed + x

    x = Tensor([1.2])

    # By calculus df/dx = 3x^2 + 1, so the gradient printed for x = 1.2
    # should come out at roughly 5.32.
    z = f(x)
    z.backward()
    print(f"X: {x} grad: {x.grad}")