Spaces:
Running
on
Zero
Running
on
Zero
"""Implement lattice interface.""" | |
from overrides import overrides | |
import numpy as np | |
from abc import ABCMeta, abstractmethod | |
class Lattice(metaclass=ABCMeta): | |
"""Abstract base class representing a complemented lattice.""" | |
def join(cls, probs1: np.ndarray, probs2: np.ndarray) -> np.ndarray: | |
return NotImplemented | |
def meet(cls, probs1: np.ndarray, probs2: np.ndarray) -> np.ndarray: | |
return NotImplemented | |
def join_reduce(cls, probs: np.ndarray) -> np.ndarray: | |
return NotImplemented | |
def meet_reduce(cls, probs: np.ndarray) -> np.ndarray: | |
return NotImplemented | |
class Product(Lattice): | |
"""Lattice where meet=prod and sum is defined accordingly. | |
Equivalent to assuming independence, more or less. | |
""" | |
eps = 1e-9 | |
def join(cls, probs1: np.ndarray, probs2: np.ndarray) -> np.ndarray: | |
return probs1 + probs2 - cls.meet(probs1, probs2) | |
def meet(cls, probs1: np.ndarray, probs2: np.ndarray) -> np.ndarray: | |
return probs1 * probs2 | |
def join_reduce(cls, probs: np.ndarray) -> np.ndarray: | |
"""Assumes disjoint events.""" | |
# return cls.comp(cls.meet_reduce(cls.comp(probs))) | |
return np.sum(probs, axis=-1) | |
def meet_reduce(cls, probs: np.ndarray) -> np.ndarray: | |
return np.prod(probs, axis=-1) | |
def comp(cls, probs): | |
return 1 - probs | |
def normalize(cls, probs): | |
"""Normalize a distribution by dividing by the total mass.""" | |
return probs / np.sum(probs + cls.eps, axis=-1) | |