Skip to content

API Reference: BEMB

model special

bayesian_coefficient

Bayesian Coefficient is the building block for the BEMB model.

Author: Tianyu Du Update: Apr. 28, 2022

BayesianCoefficient (Module)

Source code in bemb/model/bayesian_coefficient.py
class BayesianCoefficient(nn.Module):
    """A learnable coefficient with a Gaussian prior and a Gaussian variational distribution.

    The module stores the parameters of the variational distribution (a mean and a per-entry
    `variational_logstd`, both of shape (num_classes, dim)) and the fixed parameters of the prior.
    When `obs2prior` is enabled, the prior mean is `x_obs @ H^T`, where `H` is itself a
    `BayesianCoefficient` of shape (dim, num_obs) stored in `self.prior_H`.
    """

    def __init__(self,
                 variation: str,
                 num_classes: int,
                 obs2prior: bool,
                 num_obs: Optional[int] = None,
                 dim: int = 1,
                 prior_mean: float = 0.0,
                 prior_variance: float = 1.0
                 ) -> None:
        """The Bayesian coefficient object represents a learnable tensor mu_i in R^dim, where i is from a family
            (e.g., user, item), so the variational mean has num_classes * dim learnable entries in total.
            The prior distribution of mu_i is N(prior_mean, prior_variance * I) if `obs2prior` is False, or
            N(H * x_obs_i, prior_variance * I) with H of shape (dim, num_obs) if `obs2prior` is True.
            The variational distribution of mu_i is a Gaussian with a learnable mean and a learnable diagonal
            covariance. The mean of the variational distribution consists of two parts:
                1. The fixed part, which is not learnable. This part is particularly useful when the researcher wants
                    to impose some structure on the variational distribution. For example, the researcher might have
                    some variational mean learned from another model and wish to use BEMB to polish the learned mean.
                2. The flexible part, which is the main learnable part of the variational mean.

        Args:
            variation (str): the variation, one of 'item', 'user', 'constant', 'category'.
                # TODO: this will be removed in the next version, after we have a complete test pipeline.
            num_classes (int): number of classes in the coefficient. For example, if we have user-specific
                coefficients, `theta_user`, the `num_classes` should be the number of users. If we have
                item-specific coefficients, the `num_classes` should be the number of items.
            obs2prior (bool): whether the mean of coefficient prior depends on the observable or not.
            num_obs (int, optional): the number of observables associated with each class. For example, if the
                coefficient is item-specific, and we have `obs2prior` set to True, the `num_obs` should be the
                number of observables for each item. Required when `obs2prior` is True.
                Defaults to None.
            dim (int, optional): the dimension of the coefficient.
                Defaults to 1.
            prior_mean (float): the mean of the prior distribution of coefficient.
                Defaults to 0.0.
            prior_variance (float): the variance of the prior distribution of coefficient.
                Defaults to 1.0.

        Raises:
            NotImplementedError: if `obs2prior` is True for a 'constant' or 'category' variation.
            ValueError: if `obs2prior` is True but `num_obs` is not provided.
        """
        super(BayesianCoefficient, self).__init__()
        # do we use this at all? TODO: drop self.variation.
        assert variation in ['item', 'user', 'constant', 'category']

        self.variation = variation
        self.obs2prior = obs2prior
        if obs2prior and variation in ('constant', 'category'):
            raise NotImplementedError('obs2prior is not supported for constant and category variation at present.')
        if obs2prior and num_obs is None:
            # fail early with a clear message; otherwise the nested prior_H construction below
            # would crash while trying to allocate a tensor with a None dimension.
            raise ValueError('num_obs must be provided when obs2prior is True.')

        self.num_classes = num_classes
        self.num_obs = num_obs
        self.dim = dim  # the dimension of greek letter parameter.
        self.prior_mean = prior_mean
        self.prior_variance = prior_variance

        # assert self.prior_variance > 0

        # create prior distribution.
        if self.obs2prior:
            # the mean of prior distribution depends on observables.
            # initiate a Bayesian Coefficient with shape (dim, num_obs) standard Gaussian.
            self.prior_H = BayesianCoefficient(variation='constant', num_classes=dim, obs2prior=False,
                                               dim=num_obs, prior_variance=1.0)
        else:
            # despite its name, this buffer holds `prior_mean` in every entry (zero only by default).
            self.register_buffer(
                'prior_zero_mean', torch.zeros(num_classes, dim) + (self.prior_mean))

        # the prior covariance is purely diagonal: the low-rank factor is all zeros.
        self.register_buffer('prior_cov_factor',
                             torch.zeros(num_classes, dim, 1))
        self.register_buffer('prior_cov_diag', torch.ones(
            num_classes, dim) * self.prior_variance)

        # create variational distribution.
        self.variational_mean_flexible = nn.Parameter(
            torch.randn(num_classes, dim), requires_grad=True)
        self.variational_logstd = nn.Parameter(
            torch.randn(num_classes, dim), requires_grad=True)

        self.register_buffer('variational_cov_factor',
                             torch.zeros(num_classes, dim, 1))

        # no fixed part until update_variational_mean_fixed() is called.
        self.variational_mean_fixed = None

    def __repr__(self) -> str:
        """Constructs a string representation of the Bayesian coefficient object.

        Returns:
            str: the string representation of the Bayesian coefficient object.
        """
        if self.obs2prior:
            prior_str = f'prior=N(H*X_obs(H shape={self.prior_H.prior_zero_mean.shape}, X_obs shape={self.prior_H.dim}), Ix{self.prior_variance})'
        else:
            # NOTE(review): this text does not reflect non-default prior_mean / prior_variance.
            prior_str = 'prior=N(0, I)'
        return f'BayesianCoefficient(num_classes={self.num_classes}, dimension={self.dim}, {prior_str})'

    def update_variational_mean_fixed(self, new_value: torch.Tensor) -> None:
        """Updates the fixed part of the mean of the variational distribution.

        Args:
            new_value (torch.Tensor): the new value of the fixed part of the mean of the variational distribution,
                it must have the same shape as the flexible part, i.e., (num_classes, dim).
        """
        assert new_value.shape == self.variational_mean_flexible.shape
        # remove the current attribute (the initial None or a previously registered buffer)
        # before registering the new buffer under the same name.
        del self.variational_mean_fixed
        self.register_buffer('variational_mean_fixed', new_value)

    @property
    def variational_mean(self) -> torch.Tensor:
        """Returns the mean of the variational distribution.

        Returns:
            torch.Tensor: the current mean of the variational distribution with shape (num_classes, dim).
        """
        if self.variational_mean_fixed is None:
            return self.variational_mean_flexible
        else:
            return self.variational_mean_fixed + self.variational_mean_flexible

    def log_prior(self,
                  sample: torch.Tensor,
                  H_sample: Optional[torch.Tensor] = None,
                  x_obs: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Computes the logP_{Prior}(Coefficient Sample) for provided samples of the coefficient. The prior will either
        be a Gaussian with the fixed mean `prior_mean` (if `obs2prior` is False) or a Gaussian whose mean is
        computed from `x_obs` and `H_sample` (if `obs2prior` is True).

        Args:
            sample (torch.Tensor): Monte Carlo samples of the variable with shape (num_seeds, num_classes, dim), where
                sample[i, :, :] corresponds to one sample of the coefficient.

            # arguments required only if `obs2prior == True`:
            H_sample (Optional[torch.Tensor], optional): Monte Carlo samples of the weight in obs2prior term, with
                shape (num_seeds, dim, self.num_obs), this is required if and only if obs2prior == True.
                Defaults to None.
            x_obs (Optional[torch.Tensor], optional): observables for obs2prior with shape (num_classes, num_obs),
                required if and only if obs2prior == True.
                Defaults to None.

        Returns:
            torch.Tensor: the log prior of the variable with shape (num_seeds, num_classes).

        Raises:
            ValueError: if `obs2prior` is True but `H_sample` or `x_obs` is missing.
        """
        # p(sample)
        num_seeds, num_classes, dim = sample.shape
        if self.obs2prior:
            if H_sample is None or x_obs is None:
                raise ValueError('H_sample and x_obs are required when obs2prior is True.')
            assert H_sample.shape == (num_seeds, dim, self.num_obs)
            assert x_obs.shape == (num_classes, self.num_obs)
            # replicate the observables across seeds so that bmm can pair them with each H sample.
            x_obs = x_obs.view(1, num_classes, self.num_obs).expand(
                num_seeds, -1, -1)
            H_sample = torch.transpose(H_sample, 1, 2)
            assert H_sample.shape == (num_seeds, self.num_obs, dim)
            mu = torch.bmm(x_obs, H_sample)
            assert mu.shape == (num_seeds, num_classes, dim)
        else:
            mu = self.prior_zero_mean
        out = LowRankMultivariateNormal(loc=mu,
                                        cov_factor=self.prior_cov_factor,
                                        cov_diag=self.prior_cov_diag).log_prob(sample)
        # shape (num_seeds, num_classes)
        assert out.shape == (num_seeds, num_classes)
        return out

    def log_variational(self, sample: torch.Tensor) -> torch.Tensor:
        """Given a set of sampled values of coefficients, with shape (num_seeds, num_classes, dim), computes
            the log probability of these sampled values of coefficients under the current variational distribution.

        Args:
            sample (torch.Tensor): a tensor of shape (num_seeds, num_classes, dim) containing sampled values of
                coefficients, where sample[i, :, :] corresponds to one sample of the coefficient.

        Returns:
            torch.Tensor: a tensor of shape (num_seeds, num_classes) containing the log probability of provided
                samples under the variational distribution. The output is split by random seeds and classes, you
                can sum along the second axis (i.e., the num_classes axis) to get the total log probability.
        """
        num_seeds, num_classes, dim = sample.shape
        out = self.variational_distribution.log_prob(sample)
        assert out.shape == (num_seeds, num_classes)
        return out

    def rsample(self, num_seeds: int = 1) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Samples values of the coefficient from the variational distribution using re-parameterization trick.

        Args:
            num_seeds (int, optional): number of values to be sampled. Defaults to 1.

        Returns:
            Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: if `obs2prior` is disabled, returns a tensor of
                shape (num_seeds, num_classes, dim) where each output[i, :, :] corresponds to one sample of the
                coefficient.
                If `obs2prior` is enabled, returns a tuple of samples: (1) a tensor of shape
                (num_seeds, num_classes, dim) containing sampled values of coefficient, and (2) a tensor of shape
                (num_seeds, dim, num_obs) containing samples of the H weight in the prior distribution.
        """
        value_sample = self.variational_distribution.rsample(
            torch.Size([num_seeds]))
        if self.obs2prior:
            # sample obs2prior H as well.
            H_sample = self.prior_H.rsample(num_seeds=num_seeds)
            return (value_sample, H_sample)
        else:
            return value_sample

    @property
    def variational_distribution(self) -> LowRankMultivariateNormal:
        """Constructs the current variational distribution of the coefficient from current variational mean and covariance.
        """
        # NOTE(review): `cov_diag` is the diagonal of the covariance matrix, so exp(variational_logstd)
        # is used as a variance here despite the parameter's name. Left unchanged so that the meaning
        # of already-trained parameters is preserved.
        return LowRankMultivariateNormal(loc=self.variational_mean,
                                         cov_factor=self.variational_cov_factor,
                                         cov_diag=torch.exp(self.variational_logstd))

    @property
    def device(self) -> torch.device:
        """Returns the device of tensors contained in this module."""
        return self.variational_mean.device
device: device property readonly

Returns the device of tensors contained in this module.

variational_distribution: LowRankMultivariateNormal property readonly

Constructs the current variational distribution of the coefficient from current variational mean and covariance.

variational_mean: Tensor property readonly

Returns the mean of the variational distribution.

Returns:

Type Description
torch.Tensor

the current mean of the variational distribution with shape (num_classes, dim).

__init__(self, variation, num_classes, obs2prior, num_obs=None, dim=1, prior_mean=0.0, prior_variance=1.0) special

The Bayesian coefficient object represents a learnable tensor mu_i in R^k, where i is from a family (e.g., user, item), so the variational mean has num_classes * dim learnable entries in total. The prior distribution of mu_i is N(0, I) or N(H*X_obs(H shape=num_obs, X_obs shape=dim), Ix1). The posterior (i.e., variational) distribution of mu_i is a Gaussian distribution with a learnable mean mu_i and a learnable diagonal covariance. The mean of the variational distribution consists of two parts: 1. The fixed part, which is not learnable. This part is particularly useful when the researcher wants to impose some structure on the variational distribution. For example, the researcher might have some variational mean learned from another model and wish to use BEMB to polish the learned mean. 2. The flexible part, which is the main learnable part of the variational mean.

Parameters:

Name Type Description Default
variation str

the variation # TODO: this will be removed in the next version, after we have a complete test pipeline.

required
num_classes int

number of classes in the coefficient. For example, if we have user-specific coefficients, theta_user, the num_classes should be the number of users. If we have item-specific coefficients, the num_classes should be the number of items.

required
obs2prior bool

whether the mean of coefficient prior depends on the observable or not.

required
num_obs int

the number of observables associated with each class. For example, if the coefficient is item-specific, and we have obs2prior set to True, the num_obs should be the number of observables for each item. Defaults to None.

None
dim int

the dimension of the coefficient. Defaults to 1.

1
prior_mean float

the mean of the prior distribution of coefficient. Defaults to 0.0.

0.0
prior_variance float

the variance of the prior distribution of coefficient. Defaults to 1.0.

1.0
Source code in bemb/model/bayesian_coefficient.py
def __init__(self,
             variation: str,
             num_classes: int,
             obs2prior: bool,
             num_obs: Optional[int] = None,
             dim: int = 1,
             prior_mean: float = 0.0,
             prior_variance: float = 1.0
             ) -> None:
    """The Bayesian coefficient object represents a learnable tensor mu_i in R^dim, where i is from a family
        (e.g., user, item), so the variational mean has num_classes * dim learnable entries in total.
        The prior distribution of mu_i is N(prior_mean, prior_variance * I) if `obs2prior` is False, or
        N(H * x_obs_i, prior_variance * I) with H of shape (dim, num_obs) if `obs2prior` is True.
        The variational distribution of mu_i is a Gaussian with a learnable mean and a learnable diagonal
        covariance. The mean of the variational distribution consists of two parts:
            1. The fixed part, which is not learnable. This part is particularly useful when the researcher wants
                to impose some structure on the variational distribution. For example, the researcher might have
                some variational mean learned from another model and wish to use BEMB to polish the learned mean.
            2. The flexible part, which is the main learnable part of the variational mean.

    Args:
        variation (str): the variation, one of 'item', 'user', 'constant', 'category'.
            # TODO: this will be removed in the next version, after we have a complete test pipeline.
        num_classes (int): number of classes in the coefficient. For example, if we have user-specific
            coefficients, `theta_user`, the `num_classes` should be the number of users. If we have
            item-specific coefficients, the `num_classes` should be the number of items.
        obs2prior (bool): whether the mean of coefficient prior depends on the observable or not.
        num_obs (int, optional): the number of observables associated with each class. For example, if the
            coefficient is item-specific, and we have `obs2prior` set to True, the `num_obs` should be the
            number of observables for each item. Required when `obs2prior` is True.
            Defaults to None.
        dim (int, optional): the dimension of the coefficient.
            Defaults to 1.
        prior_mean (float): the mean of the prior distribution of coefficient.
            Defaults to 0.0.
        prior_variance (float): the variance of the prior distribution of coefficient.
            Defaults to 1.0.

    Raises:
        NotImplementedError: if `obs2prior` is True for a 'constant' or 'category' variation.
        ValueError: if `obs2prior` is True but `num_obs` is not provided.
    """
    super(BayesianCoefficient, self).__init__()
    # do we use this at all? TODO: drop self.variation.
    assert variation in ['item', 'user', 'constant', 'category']

    self.variation = variation
    self.obs2prior = obs2prior
    if obs2prior and variation in ('constant', 'category'):
        raise NotImplementedError('obs2prior is not supported for constant and category variation at present.')
    if obs2prior and num_obs is None:
        # fail early with a clear message; otherwise the nested prior_H construction below
        # would crash while trying to allocate a tensor with a None dimension.
        raise ValueError('num_obs must be provided when obs2prior is True.')

    self.num_classes = num_classes
    self.num_obs = num_obs
    self.dim = dim  # the dimension of greek letter parameter.
    self.prior_mean = prior_mean
    self.prior_variance = prior_variance

    # assert self.prior_variance > 0

    # create prior distribution.
    if self.obs2prior:
        # the mean of prior distribution depends on observables.
        # initiate a Bayesian Coefficient with shape (dim, num_obs) standard Gaussian.
        self.prior_H = BayesianCoefficient(variation='constant', num_classes=dim, obs2prior=False,
                                           dim=num_obs, prior_variance=1.0)
    else:
        # despite its name, this buffer holds `prior_mean` in every entry (zero only by default).
        self.register_buffer(
            'prior_zero_mean', torch.zeros(num_classes, dim) + (self.prior_mean))

    # the prior covariance is purely diagonal: the low-rank factor is all zeros.
    self.register_buffer('prior_cov_factor',
                         torch.zeros(num_classes, dim, 1))
    self.register_buffer('prior_cov_diag', torch.ones(
        num_classes, dim) * self.prior_variance)

    # create variational distribution.
    self.variational_mean_flexible = nn.Parameter(
        torch.randn(num_classes, dim), requires_grad=True)
    self.variational_logstd = nn.Parameter(
        torch.randn(num_classes, dim), requires_grad=True)

    self.register_buffer('variational_cov_factor',
                         torch.zeros(num_classes, dim, 1))

    # no fixed part until update_variational_mean_fixed() is called.
    self.variational_mean_fixed = None
__repr__(self) special

Constructs a string representation of the Bayesian coefficient object.

Returns:

Type Description
str

the string representation of the Bayesian coefficient object.

Source code in bemb/model/bayesian_coefficient.py
def __repr__(self) -> str:
    """Builds a human-readable summary of the coefficient.

    The summary lists the number of classes, the coefficient dimension, and a short description of
    the prior (which mentions the H weight when `obs2prior` is enabled).

    Returns:
        str: the string representation of the Bayesian coefficient object.
    """
    if not self.obs2prior:
        prior_desc = 'prior=N(0, I)'
    else:
        h_coef = self.prior_H
        prior_desc = (
            f'prior=N(H*X_obs(H shape={h_coef.prior_zero_mean.shape}, '
            f'X_obs shape={h_coef.dim}), Ix{self.prior_variance})'
        )
    return f'BayesianCoefficient(num_classes={self.num_classes}, dimension={self.dim}, {prior_desc})'
log_prior(self, sample, H_sample=None, x_obs=None)

Computes the logP_{Prior}(Coefficient Sample) for provided samples of the coefficient. The prior will either be a zero-mean Gaussian (if obs2prior is False) or a Gaussian with a learnable mean (if obs2prior is True).

Parameters:

Name Type Description Default
sample torch.Tensor

Monte Carlo samples of the variable with shape (num_seeds, num_classes, dim), where sample[i, :, :] corresponds to one sample of the coefficient.

required
Note: the following arguments are required only if `obs2prior == True`.
H_sample Optional[torch.Tensor]

Monte Carlo samples of the weight in obs2prior term, with shape (num_seeds, dim, self.num_obs), this is required if and only if obs2prior == True. Defaults to None.

None
x_obs Optional[torch.Tensor]

observables for obs2prior with shape (num_classes, num_obs), required if and only if obs2prior == True. Defaults to None.

None

Returns:

Type Description
torch.Tensor

the log prior of the variable with shape (num_seeds, num_classes).

Source code in bemb/model/bayesian_coefficient.py
def log_prior(self,
              sample: torch.Tensor,
              H_sample: Optional[torch.Tensor] = None,
              x_obs: Optional[torch.Tensor] = None) -> torch.Tensor:
    """
    Computes the logP_{Prior}(Coefficient Sample) for provided samples of the coefficient. The prior will either
    be a Gaussian with a fixed mean (if `obs2prior` is False) or a Gaussian whose mean is computed from `x_obs`
    and `H_sample` (if `obs2prior` is True).

    Args:
        sample (torch.Tensor): Monte Carlo samples of the variable with shape (num_seeds, num_classes, dim), where
            sample[i, :, :] corresponds to one sample of the coefficient.

        # arguments required only if `obs2prior == True`:
        H_sample (Optional[torch.Tensor], optional): Monte Carlo samples of the weight in obs2prior term, with shape
            (num_seeds, dim, self.num_obs), this is required if and only if obs2prior == True.
            Defaults to None.
        x_obs (Optional[torch.Tensor], optional): observables for obs2prior with shape (num_classes, num_obs),
            required if and only if obs2prior == True.
            Defaults to None.

    Returns:
        torch.Tensor: the log prior of the variable with shape (num_seeds, num_classes).

    Raises:
        ValueError: if `obs2prior` is True but `H_sample` or `x_obs` is missing.
    """
    # p(sample)
    num_seeds, num_classes, dim = sample.shape
    if self.obs2prior:
        if H_sample is None or x_obs is None:
            # previously this surfaced as an AttributeError on None.shape.
            raise ValueError('H_sample and x_obs are required when obs2prior is True.')
        assert H_sample.shape == (num_seeds, dim, self.num_obs)
        assert x_obs.shape == (num_classes, self.num_obs)
        # replicate the observables across seeds so that bmm can pair them with each H sample.
        x_obs = x_obs.view(1, num_classes, self.num_obs).expand(
            num_seeds, -1, -1)
        H_sample = torch.transpose(H_sample, 1, 2)
        assert H_sample.shape == (num_seeds, self.num_obs, dim)
        mu = torch.bmm(x_obs, H_sample)
        assert mu.shape == (num_seeds, num_classes, dim)
    else:
        mu = self.prior_zero_mean
    out = LowRankMultivariateNormal(loc=mu,
                                    cov_factor=self.prior_cov_factor,
                                    cov_diag=self.prior_cov_diag).log_prob(sample)
    # shape (num_seeds, num_classes)
    assert out.shape == (num_seeds, num_classes)
    return out
log_variational(self, sample)

Given a set of sampled values of coefficients, with shape (num_seeds, num_classes, dim), computes the log probability of these sampled values of coefficients under the current variational distribution.

Parameters:

Name Type Description Default
sample torch.Tensor

a tensor of shape (num_seeds, num_classes, dim) containing sampled values of coefficients, where sample[i, :, :] corresponds to one sample of the coefficient.

required

Returns:

Type Description
torch.Tensor

a tensor of shape (num_seeds, num_classes) containing the log probability of provided samples under the variational distribution. The output is split by random seeds and classes, you can sum along the second axis (i.e., the num_classes axis) to get the total log probability.

Source code in bemb/model/bayesian_coefficient.py
def log_variational(self, sample: torch.Tensor) -> torch.Tensor:
    """Evaluates the log density of coefficient samples under the current variational distribution.

    Args:
        sample (torch.Tensor): sampled coefficient values with shape (num_seeds, num_classes, dim);
            sample[i, :, :] is one draw of the whole coefficient.

    Returns:
        torch.Tensor: log probabilities with shape (num_seeds, num_classes), one entry per seed and class.
            Summing over the second axis (the num_classes axis) gives the total log probability per seed.
    """
    # unpacking doubles as a check that the input is three-dimensional.
    n_draws, n_classes, _ = sample.shape
    log_density = self.variational_distribution.log_prob(sample)
    assert log_density.shape == (n_draws, n_classes)
    return log_density
rsample(self, num_seeds=1)

Samples values of the coefficient from the variational distribution using re-parameterization trick.

Parameters:

Name Type Description Default
num_seeds int

number of values to be sampled. Defaults to 1.

1

Returns:

Type Description
Union[torch.Tensor, Tuple[torch.Tensor]]

if obs2prior is disabled, returns a tensor of shape (num_seeds, num_classes, dim) where each output[i, :, :] corresponds to one sample of the coefficient. If obs2prior is enabled, returns a tuple of samples: (1) a tensor of shape (num_seeds, num_classes, dim) containing sampled values of coefficient, and (2) a tensor of shape (num_seeds, dim, num_obs) containing samples of the H weight in the prior distribution.

Source code in bemb/model/bayesian_coefficient.py
def rsample(self, num_seeds: int = 1) -> Union[torch.Tensor, Tuple[torch.Tensor]]:
    """Draws re-parameterized samples of the coefficient from its variational distribution.

    Args:
        num_seeds (int, optional): how many draws to take. Defaults to 1.

    Returns:
        Union[torch.Tensor, Tuple[torch.Tensor]]: with `obs2prior` disabled, a single tensor of shape
            (num_seeds, num_classes, dim) whose slice output[i, :, :] is one draw of the coefficient.
            With `obs2prior` enabled, a pair (coefficient draws, H draws): the first tensor has shape
            (num_seeds, num_classes, dim) and the second, sampled from `self.prior_H`, has shape
            (num_seeds, dim, num_obs).
    """
    draw_shape = torch.Size([num_seeds])
    coef_draws = self.variational_distribution.rsample(draw_shape)
    if not self.obs2prior:
        return coef_draws
    # the H weight of the obs2prior term is sampled with the same number of seeds.
    h_draws = self.prior_H.rsample(num_seeds=num_seeds)
    return (coef_draws, h_draws)
update_variational_mean_fixed(self, new_value)

Updates the fixed part of the mean of the variational distribution.

Parameters:

Name Type Description Default
new_value torch.Tensor

the new value of the fixed part of the mean of the variational distribution.

required
Source code in bemb/model/bayesian_coefficient.py
def update_variational_mean_fixed(self, new_value: torch.Tensor) -> None:
    """Replaces the fixed (non-learnable) part of the variational mean.

    Args:
        new_value (torch.Tensor): the tensor to use as the fixed part of the variational mean; its shape
            must equal the shape of `self.variational_mean_flexible`.
    """
    expected_shape = self.variational_mean_flexible.shape
    assert new_value.shape == expected_shape
    # drop whatever currently lives under this name before registering the new buffer,
    # presumably so that register_buffer can (re)bind it.
    delattr(self, 'variational_mean_fixed')
    self.register_buffer('variational_mean_fixed', new_value)

bayesian_linear

Bayesian tensor object.

BayesianLinear (Module)

Source code in bemb/model/bayesian_linear.py
class BayesianLinear(nn.Module):
    """Linear layer whose weight (and optional bias) are Gaussian random variables.

    Samples of the parameters are drawn with `rsample()` (stochastic) or `dsample()` (deterministic,
    uses the variational means) and cached on the module; `forward()`, `log_prior()` and
    `log_variational()` all operate on the cached samples.
    """

    def __init__(self,
                 in_features: int,
                 out_features: int,
                 bias: bool=True,
                 W_variational_mean_fixed: Optional[torch.Tensor]=None,
                 device=None,
                 dtype=None,
                 W_prior_variance: float=1.0,
                 b_prior_variance: float=1.0
                 ):
        """Linear layer where weight and bias are modelled as distributions.

        Args:
            in_features (int): number of input features (rows of the weight matrix).
            out_features (int): number of output features (columns of the weight matrix).
            bias (bool): whether to add a bias term of shape (out_features,). Defaults to True.
            W_variational_mean_fixed (Optional[torch.Tensor]): optional non-learnable part of the weight's
                variational mean, with shape (in_features, out_features). Defaults to None.
            device: if provided, the module is moved to this device after construction.
            dtype: not supported, must be None.
            W_prior_variance (float): variance of the zero-mean Gaussian prior of each weight entry.
                Defaults to 1.0.
            b_prior_variance (float): variance of the zero-mean Gaussian prior of each bias entry.
                Defaults to 1.0.

        Raises:
            NotImplementedError: if `dtype` is not None.
        """
        super().__init__()
        if dtype is not None:
            raise NotImplementedError('dtype is not Supported yet.')

        self.in_features = in_features  # the same as number of classes before.
        self.out_features = out_features  # the same as latent dimension before.
        self.bias = bias

        # ==============================================================================================================
        # prior distributions for weight and bias.
        # ==============================================================================================================
        # the prior of weights are gaussian distributions independent across in_feature dimensions.
        # the buffers store log standard deviations, and std = sqrt(variance), hence the 0.5 factor.
        self.register_buffer('W_prior_mean', torch.zeros(in_features, out_features))
        self.register_buffer('W_prior_logstd',
                             torch.full((in_features, out_features), 0.5 * float(np.log(W_prior_variance))))

        if self.bias:
            # the bias has shape (out_features,), see b_variational_mean below; its prior must match
            # so that log_prior() can evaluate samples of shape (num_seeds, out_features).
            self.register_buffer('b_prior_mean', torch.zeros(out_features))
            self.register_buffer('b_prior_logstd',
                                 torch.full((out_features,), 0.5 * float(np.log(b_prior_variance))))

        # ==============================================================================================================
        # variational distributions for weight and bias.
        # ==============================================================================================================
        if W_variational_mean_fixed is None:
            self.W_variational_mean_fixed = None
        else:
            assert W_variational_mean_fixed.shape == (in_features, out_features), \
                f'W_variational_mean_fixed tensor should have shape (in_features, out_features), got {W_variational_mean_fixed.shape}'
            self.register_buffer('W_variational_mean_fixed', W_variational_mean_fixed)

        # TODO: optionally add customizable initialization here.
        self.W_variational_mean_flexible = nn.Parameter(torch.randn(in_features, out_features), requires_grad=True)
        self.W_variational_logstd = nn.Parameter(torch.randn(in_features, out_features), requires_grad=True)

        if self.bias:
            self.b_variational_mean = nn.Parameter(torch.randn(out_features), requires_grad=True)
            self.b_variational_logstd = nn.Parameter(torch.randn(out_features), requires_grad=True)

        if device is not None:
            self.to(device)

        # cached samples, populated by rsample() / dsample().
        self.W_sample = None
        self.b_sample = None
        self.num_seeds = None

    @property
    def W_variational_mean(self):
        """The mean of the weight's variational distribution: flexible part plus the optional fixed part."""
        if self.W_variational_mean_fixed is None:
            return self.W_variational_mean_flexible
        else:
            return self.W_variational_mean_fixed + self.W_variational_mean_flexible

    def rsample(self, num_seeds: int=1) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Samples all parameters using the re-parameterization trick and caches them on the module.

        Args:
            num_seeds (int): number of samples to draw. Defaults to 1.

        Returns:
            Tuple[torch.Tensor, Optional[torch.Tensor]]: the weight sample with shape
                (num_seeds, in_features, out_features) and the cached bias sample, which has shape
                (num_seeds, out_features) when `bias` is True.
        """
        self.num_seeds = num_seeds
        self.W_sample = self.W_variational_distribution.rsample(torch.Size([num_seeds]))

        if self.bias:
            self.b_sample = self.b_variational_distribution.rsample(torch.Size([num_seeds]))

        return self.W_sample, self.b_sample

    def dsample(self):
        """Deterministic sample method, set (W, b) sample to the mean of variational distribution."""
        self.num_seeds = 1
        self.W_sample = self.W_variational_mean.unsqueeze(dim=0)
        if self.bias:
            self.b_sample = self.b_variational_mean.unsqueeze(dim=0)
        return self.W_sample, self.b_sample

    def forward(self, x, mode: str='multiply'):
        """
        Forward with the cached parameter samples.
        With mode='multiply', computes out = XW + b. With mode='lookup', `x` is used to index rows of W,
        which behaves like the embedding layer in PyTorch.
        To have deterministic results, call self.dsample() before executing.
        To have stochastic results, call self.rsample() before executing.

        Args:
            x: tensor of shape (N, in_features) if mode='multiply', or indices of shape (N,) if mode='lookup'.
            mode (str): one of 'multiply' and 'lookup'. Defaults to 'multiply'.

        Returns:
            tensor of shape (num_seeds, N, out_features).

        Raises:
            ValueError: if `mode` is not one of 'multiply' and 'lookup'.
        """
        assert self.num_seeds is not None, 'run BayesianLinear.rsample() or dsample() first to sample weight and bias.'

        # if deterministic, num_seeds is set to 1.
        # w: (num_seeds, in_features=num_classes, out_features)
        # b: (num_seeds, out_features)
        # x: (N, in_features) if multiply and (N,) if lookup.
        # output: (num_seeds, N, out_features)

        if mode == 'multiply':
            x = x.view(1, -1, self.in_features).expand(self.num_seeds, -1, -1)  # (num_seeds, N, in_features)
            out = x.bmm(self.W_sample)  # (num_seeds, N, out_features)
        elif mode == 'lookup':
            out = self.W_sample[:, x, :]  # (num_seeds, N, out_features)
        else:
            raise ValueError(f'mode={mode} is not allowed.')

        if self.bias:
            out += self.b_sample.view(self.num_seeds, 1, self.out_features)

        # (num_seeds, N, out_features)
        return out

    @property
    def W_variational_distribution(self):
        """the weight variational distribution."""
        return Normal(loc=self.W_variational_mean, scale=torch.exp(self.W_variational_logstd))

    @property
    def b_variational_distribution(self):
        """the bias variational distribution."""
        return Normal(loc=self.b_variational_mean, scale=torch.exp(self.b_variational_logstd))

    @property
    def device(self) -> torch.device:
        """Returns the device of the weight's variational mean."""
        return self.W_variational_mean.device

    def log_prior(self):
        """Evaluate the likelihood of the cached samples of parameter under the current prior distribution.

        Returns:
            tensor of shape (num_seeds,), the log prior summed over all weight (and bias) entries.
        """
        assert self.num_seeds is not None, 'run BayesianLinear.rsample() or dsample() first to sample weight and bias.'
        num_seeds = self.W_sample.shape[0]
        total_log_prob = torch.zeros(num_seeds, device=self.device)
        # log P(W_sample). shape = (num_seeds,)
        W_prior = Normal(loc=self.W_prior_mean, scale=torch.exp(self.W_prior_logstd))
        total_log_prob += W_prior.log_prob(self.W_sample).sum(dim=[1, 2])

        # log P(b_sample) if applicable.
        if self.bias:
            b_prior = Normal(loc=self.b_prior_mean, scale=torch.exp(self.b_prior_logstd))
            total_log_prob += b_prior.log_prob(self.b_sample).sum(dim=1)

        assert total_log_prob.shape == (num_seeds,)
        return total_log_prob

    def log_variational(self):
        """Evaluate the likelihood of the cached samples of parameter under the current variational distribution.

        Returns:
            tensor of shape (num_seeds,), the log probability summed over all weight (and bias) entries.
        """
        assert self.num_seeds is not None, 'run BayesianLinear.rsample() or dsample() first to sample weight and bias.'
        num_seeds = self.W_sample.shape[0]

        total_log_prob = torch.zeros(num_seeds, device=self.device)
        total_log_prob += self.W_variational_distribution.log_prob(self.W_sample).sum(dim=[1, 2])
        if self.bias:
            total_log_prob += self.b_variational_distribution.log_prob(self.b_sample).sum(dim=1)
        assert total_log_prob.shape == (num_seeds,)
        return total_log_prob

    def __repr__(self):
        """Returns a description of the layer that includes the prior mean and log-std buffers."""
        prior_info = f'W_prior ~ N(mu={self.W_prior_mean}, logstd={self.W_prior_logstd})'
        if self.bias:
            prior_info += f'b_prior ~ N(mu={self.b_prior_mean}, logstd={self.b_prior_logstd})'
        return f"BayesianLinear(in_features={self.in_features}, out_features={self.out_features}, bias={self.bias}, {prior_info})"
W_variational_distribution property readonly

the weight variational distribution.

__init__(self, in_features, out_features, bias=True, W_variational_mean_fixed=None, device=None, dtype=None, W_prior_variance=1.0, b_prior_variance=1.0) special

Linear layer where weight and bias are modelled as distributions.

Source code in bemb/model/bayesian_linear.py
def __init__(self,
             in_features: int,
             out_features: int,
             bias: bool=True,
             W_variational_mean_fixed: Optional[torch.Tensor]=None,
             device=None,
             dtype=None,
             W_prior_variance: float=1.0,
             b_prior_variance: float=1.0
             ):
    """Linear layer where weight and bias are modelled as distributions.

    Args:
        in_features (int): number of input features (rows of the weight).
        out_features (int): number of output features (columns of the weight).
        bias (bool): whether the layer has a bias term. Defaults to True.
        W_variational_mean_fixed (Optional[torch.Tensor]): optional fixed tensor of shape
            (in_features, out_features), registered as a buffer. Defaults to None.
        device: if provided, the module is moved to this device after construction.
        dtype: not supported yet, must be None.
        W_prior_variance (float): variance of the zero-mean Gaussian prior on every weight entry.
        b_prior_variance (float): variance of the zero-mean Gaussian prior on every bias entry.

    Raises:
        NotImplementedError: if `dtype` is not None.
        ValueError: if a prior variance in use is not strictly positive.
    """
    super().__init__()
    if dtype is not None:
        raise NotImplementedError('dtype is not Supported yet.')
    if W_prior_variance <= 0:
        raise ValueError(f'W_prior_variance must be positive, got {W_prior_variance}.')
    if bias and b_prior_variance <= 0:
        raise ValueError(f'b_prior_variance must be positive, got {b_prior_variance}.')

    self.in_features = in_features  # the same as number of classes before.
    self.out_features = out_features  # the same as latent dimension before.
    self.bias = bias

    # ==============================================================================================================
    # prior distributions for mean and bias.
    # ==============================================================================================================
    # the prior of weights are gaussian distributions independent across in_feature dimensions.
    # The buffers store log(std); since std = sqrt(variance), log(std) = 0.5 * log(variance).
    # (Using log(variance) directly would make the prior std equal to the variance.)
    W_prior_logstd = float(0.5 * np.log(W_prior_variance))
    self.register_buffer('W_prior_mean', torch.zeros(in_features, out_features))
    self.register_buffer('W_prior_logstd', torch.full((in_features, out_features), W_prior_logstd))

    if self.bias:
        # The bias has shape (out_features,), so its prior must have the same shape; a prior of shape
        # (in_features, out_features) does not broadcast against (num_seeds, out_features) bias samples.
        # NOTE(review): this changes the buffer shapes, checkpoints saved with the old shapes will not load.
        b_prior_logstd = float(0.5 * np.log(b_prior_variance))
        self.register_buffer('b_prior_mean', torch.zeros(out_features))
        self.register_buffer('b_prior_logstd', torch.full((out_features,), b_prior_logstd))

    # ==============================================================================================================
    # variational distributions for weight and bias.
    # ==============================================================================================================
    if W_variational_mean_fixed is None:
        self.W_variational_mean_fixed = None
    else:
        assert W_variational_mean_fixed.shape == (in_features, out_features), \
            f'W_variational_mean_fixed tensor should have shape (in_features, out_features), got {W_variational_mean_fixed.shape}'
        self.register_buffer('W_variational_mean_fixed', W_variational_mean_fixed)

    # TODO: optionally add customizable initialization here.
    self.W_variational_mean_flexible = nn.Parameter(torch.randn(in_features, out_features), requires_grad=True)
    self.W_variational_logstd = nn.Parameter(torch.randn(in_features, out_features), requires_grad=True)

    if self.bias:
        self.b_variational_mean = nn.Parameter(torch.randn(out_features), requires_grad=True)
        self.b_variational_logstd = nn.Parameter(torch.randn(out_features), requires_grad=True)

    if device is not None:
        self.to(device)

    # samples are populated by rsample() / dsample().
    self.W_sample = None
    self.b_sample = None
    self.num_seeds = None
dsample(self)

Deterministic sample method, set (W, b) sample to the mean of variational distribution.

Source code in bemb/model/bayesian_linear.py
def dsample(self):
    """Use the variational means as the single, deterministic (W, b) sample.

    Sets `num_seeds` to 1 and caches the means with an extra leading dimension
    of size 1. The bias sample is only refreshed when the layer has a bias.

    Returns:
        the cached pair `(W_sample, b_sample)`.
    """
    self.num_seeds = 1

    weight_mean = self.W_variational_mean
    self.W_sample = torch.unsqueeze(weight_mean, dim=0)

    if self.bias:
        bias_mean = self.b_variational_mean
        self.b_sample = torch.unsqueeze(bias_mean, dim=0)

    return self.W_sample, self.b_sample
forward(self, x, mode='multiply')

Forward with weight sampling. By default, forward computes out = XW + b; to make forward() behave like the embedding layer in PyTorch, use mode='lookup'. To have deterministic results, call self.dsample() before executing. To have stochastic results, call self.rsample() before executing. mode must be one of ['multiply', 'lookup'].

output shape: (num_seeds, batch_size, out_features).

Source code in bemb/model/bayesian_linear.py
def forward(self, x, mode: str='multiply'):
    """
    Apply the layer using the currently cached weight (and bias) samples.

    Call self.dsample() beforehand for deterministic results, or self.rsample()
    for stochastic results.

    mode='multiply': computes out = XW + b; x is viewed as (N, in_features).
    mode='lookup': indexes rows of W with x, like an embedding layer.

    output shape: (num_seeds, N, out_features); num_seeds is 1 after dsample().

    Raises ValueError for any other mode.
    """
    assert self.num_seeds is not None, 'run BayesianLinear.rsample() or dsample() first to sample weight and bias.'

    if mode not in ('multiply', 'lookup'):
        raise ValueError(f'mode={mode} is not allowed.')

    # W_sample: (num_seeds, in_features, out_features); b_sample: (num_seeds, out_features).
    if mode == 'lookup':
        # pick the rows of every sampled weight matrix indexed by x.
        out = self.W_sample[:, x, :]  # (num_seeds, N, out_features)
    else:
        # replicate the inputs across samples, then do one batched matmul.
        batched_x = x.view(1, -1, self.in_features).expand(self.num_seeds, -1, -1)  # (num_seeds, N, in_features)
        out = batched_x.bmm(self.W_sample)  # (num_seeds, N, out_features)

    if self.bias:
        # broadcast each sampled bias over the N rows.
        out += self.b_sample.view(self.num_seeds, 1, self.out_features)

    return out
log_prior(self)

Evaluate the likelihood of the provided samples of parameter under the current prior distribution.

Source code in bemb/model/bayesian_linear.py
def log_prior(self):
    """Log-density of the cached (W, b) samples under the prior distribution.

    A prior call to `rsample()` or `dsample()` is required so that the samples exist.

    Returns:
        a tensor of shape (num_seeds,): for every sample, the Gaussian prior
        log-density summed over all weight entries, plus all bias entries
        when the layer has a bias.
    """
    assert self.num_seeds is not None, 'run BayesianLinear.rsample() or dsample() first to sample weight and bias.'
    n_samples = self.W_sample.shape[0]
    log_p = torch.zeros(n_samples, device=self.device)

    # weight term: sum the element-wise log-density over both weight dimensions.
    weight_prior = Normal(loc=self.W_prior_mean, scale=self.W_prior_logstd.exp())
    log_p += weight_prior.log_prob(self.W_sample).sum(dim=[1, 2])

    # bias term, only when the layer has a bias.
    if self.bias:
        bias_prior = Normal(loc=self.b_prior_mean, scale=self.b_prior_logstd.exp())
        log_p += bias_prior.log_prob(self.b_sample).sum(dim=1)

    assert log_p.shape == (n_samples,)
    return log_p
log_variational(self)

Evaluate the likelihood of the provided samples of parameter under the current variational distribution.

Source code in bemb/model/bayesian_linear.py
def log_variational(self):
    """Log-density of the cached (W, b) samples under the variational distribution.

    A prior call to `rsample()` or `dsample()` is required so that the samples exist.

    Returns:
        a tensor of shape (num_seeds,): for every sample, the log-density
        summed over all weight entries, plus all bias entries when the layer
        has a bias.
    """
    assert self.num_seeds is not None, 'run BayesianLinear.rsample() or dsample() first to sample weight and bias.'
    n_samples = self.W_sample.shape[0]

    # (distribution, cached sample, dimensions to sum out) for every parameter group.
    terms = [(self.W_variational_distribution, self.W_sample, [1, 2])]
    if self.bias:
        terms.append((self.b_variational_distribution, self.b_sample, 1))

    log_q = torch.zeros(n_samples, device=self.device)
    for distribution, sample, reduce_dims in terms:
        log_q += distribution.log_prob(sample).sum(dim=reduce_dims)

    assert log_q.shape == (n_samples,)
    return log_q
rsample(self, num_seeds=1)

Sample all parameters using the re-parameterization trick.

Source code in bemb/model/bayesian_linear.py
def rsample(self, num_seeds: int=1) -> Optional[Tuple[torch.Tensor, Optional[torch.Tensor]]]:
    """Draw and cache `num_seeds` reparameterized samples of the weight (and bias).

    The weight is always sampled from its variational distribution; the bias
    sample is only refreshed when the layer has a bias.

    Returns:
        the cached pair `(W_sample, b_sample)`; each freshly drawn sample has
        `num_seeds` as its leading dimension.
    """
    self.num_seeds = num_seeds
    sample_shape = torch.Size([num_seeds])

    self.W_sample = self.W_variational_distribution.rsample(sample_shape)
    if self.bias:
        self.b_sample = self.b_variational_distribution.rsample(sample_shape)

    return self.W_sample, self.b_sample

bemb

The core class of the Bayesian EMBedding (BEMB) model.

Author: Tianyu Du Update: Apr. 28, 2022

BEMBFlex (Module)

Source code in bemb/model/bemb.py
class BEMBFlex(nn.Module):
    # ==================================================================================================================
    # core function as a PyTorch module.
    # ==================================================================================================================
    def __init__(self,
                 utility_formula: str,
                 obs2prior_dict: Dict[str, bool],
                 coef_dim_dict: Dict[str, int],
                 num_items: int,
                 pred_item: bool,
                 prior_mean: Union[float, Dict[str, float]] = 0.0,
                 default_prior_mean: float = 0.0,
                 prior_variance: Union[float, Dict[str, float]] = 1.0,
                 num_users: Optional[int] = None,
                 num_sessions: Optional[int] = None,
                 trace_log_q: bool = False,
                 category_to_item: Optional[Dict[int, List[int]]] = None,
                 # number of observables.
                 num_user_obs: Optional[int] = None,
                 num_item_obs: Optional[int] = None,
                 num_session_obs: Optional[int] = None,
                 num_price_obs: Optional[int] = None,
                 num_taste_obs: Optional[int] = None,
                 # additional modules.
                 additional_modules: Optional[List[nn.Module]] = None
                 ) -> None:
        """
        Args:
            utility_formula (str): a string representing the utility function U[user, item, session].
                See the documentation for more details on the format of the formula.
                Examples:
                    lambda_item
                    lambda_item + theta_user * alpha_item + zeta_user * item_obs
                    lambda_item + theta_user * alpha_item + gamma_user * beta_item * price_obs
                See the doc-string of parse_utility for an example.

            obs2prior_dict (Dict[str, bool]): a dictionary maps coefficient name (e.g., 'lambda_item')
                to a boolean indicating if observable (e.g., item_obs) enters the prior of the coefficient.

            coef_dim_dict (Dict[str, int]): a dictionary maps coefficient name (e.g., 'lambda_item')
                to an integer indicating the dimension of coefficient.
                For standalone coefficients like U = lambda_item, the dim should be 1.
                For factorized coefficients like U = theta_user * alpha_item, the dim should be the
                    latent dimension of theta and alpha.
                For coefficients multiplied with observables like U = zeta_user * item_obs, the dim
                    should be the number of observables in item_obs.
                For factorized coefficient multiplied with observables like U = gamma_user * beta_item * price_obs,
                    the dim should be the latent dim multiplied by number of observables in price_obs.

            num_items (int): number of items.

            pred_item (bool): there are two use cases of this model, suppose we have `user_index[i]` and `item_index[i]`
                for the i-th observation in the dataset.
                Case 1: which item among all items user `user_index[i]` is going to purchase, the prediction label
                    is therefore `item_index[i]`. Equivalently, we can ask what's the likelihood for user `user_index[i]`
                    to purchase `item_index[i]`.
                Case 2: what rating would user `user_index[i]` assign to item `item_index[i]`? In this case, the dataset
                    object needs to contain a separate label.
                    NOTE: for now, we only support binary labels.

            prior_mean (Union[float, Dict[str, float]]): the mean of the prior distribution for
                coefficients. If a float is provided, it is used as the prior mean of every
                coefficient. If a dictionary is provided, keys of prior_mean should be coefficient
                names and prior_mean[coef_name] is the prior mean of that coefficient; coefficients
                missing from the dictionary use `default_prior_mean`.
                Defaults to 0.0; a scalar 0.0 defers to `default_prior_mean`.

            default_prior_mean (float): the prior mean used for coefficients whose prior mean is
                not specified through `prior_mean`; defaults to 0.0.

            prior_variance (Union[float, Dict[str, float]]): the variance of prior distribution for
                coefficients. If a float is provided, all priors will be diagonal matrix with
                prior_variance along the diagonal. If a dictionary is provided, keys of prior_variance
                should be coefficient names, and the variance of prior of coef_name would be a diagonal
                matrix with prior_variance[coef_name] along the diagonal.
                Defaults to 1.0, which means all prior have identity matrix as the covariance matrix.

            num_users (int, optional): number of users, required only if coefficient or observable
                depending on user is in utility. Defaults to None.
            num_sessions (int, optional): number of sessions, required only if coefficient or
                observable depending on session is in utility. Defaults to None.

            trace_log_q (bool, optional): whether to trace the derivative of variational likelihood logQ
                with respect to variational parameters in the ELBO while conducting gradient update.
                Defaults to False.

            category_to_item (Dict[int, List[int]], optional): a dictionary with category ids as keys,
                and category_to_item[C] contains the list of item ids belonging to category C.
                The keys are used as row indices of a (num_categories, max_category_size) tensor,
                so they need to be the integers 0, ..., num_categories - 1.
                If None is provided and `pred_item` is True, all items are put in one category;
                if None is provided and `pred_item` is False, every item forms its own category.
                Defaults to None.

            num_{user, item, session, price, taste}_obs (int, optional): number of observables of
                each type of features, only required if observable enters prior.
                NOTE: currently we only allow coefficient to depend on either user or item, thus only
                user and item observables can enter the prior of coefficient. Hence session, price,
                and taste observables are never required, we include it here for completeness.

            additional_modules (List[nn.Module], optional): must be None for now.

        Raises:
            NotImplementedError: if `additional_modules` is provided (temporarily disabled).
        """
        super(BEMBFlex, self).__init__()
        self.utility_formula = utility_formula
        self.obs2prior_dict = obs2prior_dict
        self.coef_dim_dict = coef_dim_dict
        self.prior_variance = prior_variance
        self.default_prior_mean = default_prior_mean
        self.prior_mean = prior_mean

        self.pred_item = pred_item

        self.num_items = num_items
        self.num_users = num_users
        self.num_sessions = num_sessions

        self.trace_log_q = trace_log_q
        self.category_to_item = category_to_item

        # ==============================================================================================================
        # Category ID to Item ID mapping.
        # Category ID to Category Size mapping.
        # Item ID to Category ID mapping.
        # ==============================================================================================================
        if self.category_to_item is None:
            if self.pred_item:
                # assign all items to the same category if predicting items.
                self.category_to_item = {0: list(np.arange(self.num_items))}
            else:
                # otherwise, for the j-th observation in the dataset, the label[j]
                # only depends on user_index[j] and item_index[j], so we put each
                # item to its own category.
                self.category_to_item = {i: [i] for i in range(self.num_items)}

        self.num_categories = len(self.category_to_item)

        # row c holds the item ids of category c, right-padded with -1 up to the largest category size.
        max_category_size = max(len(x) for x in self.category_to_item.values())
        category_to_item_tensor = torch.full(
            (self.num_categories, max_category_size), -1)
        category_to_size_tensor = torch.empty(self.num_categories)

        for c, item_in_c in self.category_to_item.items():
            category_to_item_tensor[c, :len(
                item_in_c)] = torch.LongTensor(item_in_c)
            category_to_size_tensor[c] = torch.scalar_tensor(len(item_in_c))

        self.register_buffer('category_to_item_tensor',
                             category_to_item_tensor.long())
        self.register_buffer('category_to_size_tensor',
                             category_to_size_tensor.long())

        # item_to_category_tensor[i] = id of the category item i belongs to.
        item_to_category_tensor = torch.zeros(self.num_items)
        for c, items_in_c in self.category_to_item.items():
            item_to_category_tensor[items_in_c] = c
        self.register_buffer('item_to_category_tensor',
                             item_to_category_tensor.long())

        # ==============================================================================================================
        # Create Bayesian Coefficient Objects
        # ==============================================================================================================
        # model configuration.
        self.formula = parse_utility(utility_formula)
        print('BEMB: utility formula parsed:')
        pprint(self.formula)
        self.raw_formula = utility_formula

        # dimension of each observable, this one is used only for obs2prior.
        self.num_obs_dict = {
            'user': num_user_obs,
            'item': num_item_obs,
            'category' : 0,
            'session': num_session_obs,
            'price': num_price_obs,
            'taste': num_taste_obs,
            'constant': 1  # not really used, for dummy variables.
        }

        # how many classes for the variational distribution.
        # for example, beta_item would be `num_items` 10-dimensional gaussian if latent dim = 10.
        variation_to_num_classes = {
            'user': self.num_users,
            'item': self.num_items,
            'constant': 1,
            'category' : self.num_categories,
        }

        coef_dict = dict()
        for additive_term in self.formula:
            for coef_name in additive_term['coefficient']:
                # the suffix of the coefficient name tells what it varies by, e.g. 'theta_user' -> 'user'.
                variation = coef_name.split('_')[-1]

                if isinstance(self.prior_mean, dict):
                    # coefficients missing from the dictionary fall back to the default
                    # instead of raising KeyError.
                    mean = self.prior_mean.get(coef_name, self.default_prior_mean)
                elif self.prior_mean != 0.0:
                    # a scalar prior mean applies to every coefficient (it used to be ignored).
                    mean = self.prior_mean
                else:
                    # the scalar is at its signature default: defer to default_prior_mean, so
                    # callers that only set default_prior_mean keep their behavior.
                    mean = self.default_prior_mean

                s2 = self.prior_variance[coef_name] if isinstance(
                    self.prior_variance, dict) else self.prior_variance
                coef_dict[coef_name] = BayesianCoefficient(variation=variation,
                                                           num_classes=variation_to_num_classes[variation],
                                                           obs2prior=self.obs2prior_dict[coef_name],
                                                           num_obs=self.num_obs_dict[variation],
                                                           dim=self.coef_dim_dict[coef_name],
                                                           prior_mean=mean,
                                                           prior_variance=s2)
        self.coef_dict = nn.ModuleDict(coef_dict)

        # ==============================================================================================================
        # Optional: register additional modules.
        # ==============================================================================================================
        if additional_modules is None:
            self.additional_modules = []
        else:
            # TODO: once re-enabled, register them with nn.ModuleList(additional_modules).
            raise NotImplementedError(
                'Additional modules are temporarily disabled for further development.')

    def __str__(self):
        return f'Bayesian EMBedding Model with U[user, item, session] = {self.raw_formula}\n' \
               + f'Total number of parameters: {self.num_params}.\n' \
               + 'With the following coefficients:\n' \
               + str(self.coef_dict) + '\n' \
               + str(self.additional_modules)

    def posterior_mean(self, coef_name: str) -> torch.Tensor:
        """Look up the variational (estimated posterior) mean of a coefficient.

        Args:
            coef_name (str): name of the coefficient to query.

        Returns:
            torch.Tensor: the `variational_mean` of the coefficient named `coef_name`.

        Raises:
            KeyError: if `coef_name` is not one of the model's coefficients.
        """
        # guard clause: reject unknown names up front.
        if coef_name not in self.coef_dict.keys():
            raise KeyError(f'{coef_name} is not a valid coefficient name in {self.utility_formula}.')
        coefficient = self.coef_dict[coef_name]
        return coefficient.variational_mean

    def ivs(self, batch) -> torch.Tensor:
        """Log-sum-exp of item utilities within each category.

        The means of the variational distributions are used as the single
        parameter sample. Utilities (logits) of all items are computed with
        `log_likelihood_all_items(..., return_logit=True)` and then reduced
        with `scatter_logsumexp` over the items of every category.

        Args:
            batch (dict): a batch of data.

        Returns:
            torch.Tensor: the per-category log-sum-exp of utilities, expected
                shape (len(batch), num_categories).
        """
        # one deterministic "sample" per coefficient: its variational mean, (1, num_*, dim).
        mean_sample_dict = {
            coef_name: coef.variational_distribution.mean.unsqueeze(dim=0)
            for coef_name, coef in self.coef_dict.items()
        }

        # (num_seeds=1, len(batch), num_items) -> (len(batch), num_items)
        logits = self.log_likelihood_all_items(batch, return_logit=True, sample_dict=mean_sample_dict)
        logits = logits.squeeze(0)

        # reduce over the item dimension, grouping items by their category.
        return scatter_logsumexp(logits, self.item_to_category_tensor, dim=-1)

    def sample_choices(self, batch:ChoiceDataset, debug: bool = False, num_seeds: int = 1, **kwargs) -> Tuple[torch.Tensor]:
        """Sample choices given the model parameters and trips.

        The means of the variational distributions are used as the single
        parameter sample, which is passed to `sample_log_likelihoods`.

        Args:
            batch (ChoiceDataset): batch data containing trip information; item choice information is discarded.
            debug (bool): currently unused.
            num_seeds (int): currently unused.

        Returns:
            Tuple[torch.Tensor]: the two tensors returned by `sample_log_likelihoods`
                (within-category maximum and its argmax), each with `.squeeze()` applied.
        """
        # one deterministic "sample" per coefficient: its variational mean, (1, num_*, dim).
        mean_sample_dict = {
            coef_name: coef.variational_distribution.mean.unsqueeze(dim=0)
            for coef_name, coef in self.coef_dict.items()
        }
        category_max, category_argmax = self.sample_log_likelihoods(batch, mean_sample_dict)
        return category_max.squeeze(), category_argmax.squeeze()

    def sample_log_likelihoods(self, batch:ChoiceDataset, sample_dict: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
        """Perturb item utilities with centred Gumbel noise and take the within-category maximum.

        Utilities of all items are computed from `sample_dict`, independent Gumbel(0, 1)
        noise shifted by its mean (the Euler-Mascheroni constant) is added to every
        utility, and `scatter_max` then reduces over the items of every category.

        Args:
            batch (ChoiceDataset): batch data containing trip information; item choice information is discarded.
            sample_dict (Dict[str, torch.Tensor]): sampled coefficient values.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: `(max_by_category, argmax_by_category)`, the largest
                perturbed utility within each category and the index of the item attaining it,
                as returned by `scatter_max` over the last dimension.
        """
        # utilities (logits) of all items.
        utility = self.log_likelihood_all_items(batch, return_logit=True, sample_dict=sample_dict)

        # location and scale of the Gumbel noise; its mean is mu + beta * (Euler-Mascheroni constant).
        mu_gumbel = 0.0
        beta_gumbel = 1.0
        EUL_MAS_CONST = 0.5772156649
        mean_gumbel = torch.tensor([mu_gumbel + beta_gumbel * EUL_MAS_CONST], device=self.device)
        m = torch.distributions.gumbel.Gumbel(torch.tensor([mu_gumbel], device=self.device),
                                              torch.tensor([beta_gumbel], device=self.device))

        # the distribution has batch shape (1,), so samples carry a trailing dimension of size 1.
        gumbel_samples = m.sample(utility.shape).squeeze(-1)
        # centre the noise so that it has zero mean.
        gumbel_samples -= mean_gumbel
        utility += gumbel_samples

        max_by_category, argmax_by_category = scatter_max(utility, self.item_to_category_tensor, dim=-1)
        return max_by_category, argmax_by_category

    def forward(self, batch: ChoiceDataset,
                return_type: str,
                return_scope: str,
                deterministic: bool = True,
                sample_dict: Optional[Dict[str, torch.Tensor]] = None,
                num_seeds: Optional[int] = None
                ) -> torch.Tensor:
        """A combined method for inference with the model.

        NOTE: when `self.pred_item` is True, `batch.label` is overwritten with `batch.item_index`.

        Args:
            batch (ChoiceDataset): batch data containing choice information.
            return_type (str): either 'log_prob' or 'utility'.
                'log_prob': return the log-probability (by within-category log-softmax) for items
                'utility': return the utility value of items.
            return_scope (str): either 'item_index' or 'all_items'.
                'item_index': for each observation i, return log-prob/utility for the chosen item batch.item_index[i] only.
                'all_items': for each observation i, return log-prob/utility for all items.
            deterministic (bool, optional):
                True: expectations of parameter variational distributions are used for inference.
                False: the user needs to supply a dictionary of sampled parameters for inference.
                Defaults to True.
            sample_dict (Optional[Dict[str, torch.Tensor]], optional): sampled parameters for inference task.
                This is not needed when `deterministic` is True.
                When `deterministic` is False, the user can supply a `sample_dict`. If `sample_dict` is not provided,
                this method will create `num_seeds` samples.
                Defaults to None.
            num_seeds (Optional[int]): the number of random samples of parameters to construct. This is only required
                if `deterministic` is False (i.e., stochastic mode) and `sample_dict` is not provided.
                Defaults to None.
        Returns:
            torch.Tensor: a tensor of log-probabilities or utilities, depending on `return_type`.
                The shape of the returned tensor depends on `return_scope` and `deterministic`.
                -------------------------------------------------------------------------
                | `return_scope` | `deterministic` |         Output shape               |
                -------------------------------------------------------------------------
                |   'item_index' |      True       | (len(batch),)                      |
                -------------------------------------------------------------------------
                |   'all_items'  |      True       | (len(batch), num_items)            |
                -------------------------------------------------------------------------
                |   'item_index' |      False      | (num_seeds, len(batch))            |
                -------------------------------------------------------------------------
                |   'all_items'  |      False      | (num_seeds, len(batch), num_items) |
                -------------------------------------------------------------------------

        Raises:
            AssertionError: if `return_type`, `return_scope` or `deterministic` is invalid, or if
                stochastic mode is requested with neither a `sample_dict` nor a positive `num_seeds`.
        """
        # ==============================================================================================================
        # check arguments.
        # ==============================================================================================================
        assert return_type in [
            'log_prob', 'utility'], "return_type must be either 'log_prob' or 'utility'."
        assert return_scope in [
            'item_index', 'all_items'], "return_scope must be either 'item_index' or 'all_items'."
        assert deterministic in [True, False]
        if (not deterministic) and (sample_dict is None):
            # check for None explicitly: `None >= 1` raises a TypeError, which would hide the message below.
            assert num_seeds is not None and num_seeds >= 1, "A positive interger `num_seeds` is required if `deterministic` is False and no `sample_dict` is provided."

        # when pred_item is true, the model is predicting which item is bought (specified by item_index).
        if self.pred_item:
            batch.label = batch.item_index

        # ==============================================================================================================
        # get sample_dict ready.
        # ==============================================================================================================
        if deterministic:
            num_seeds = 1
            # Use the means of variational distributions as the sole deterministic MC sample.
            # NOTE: here we don't need to sample the obs2prior weight H since we only compute the log-likelihood.
            # TODO: is this correct?
            sample_dict = dict()
            for coef_name, coef in self.coef_dict.items():
                sample_dict[coef_name] = coef.variational_distribution.mean.unsqueeze(
                    dim=0)  # (1, num_*, dim)
        else:
            if sample_dict is None:
                # sample stochastic parameters.
                sample_dict = self.sample_coefficient_dictionary(num_seeds)
            else:
                # use the provided sample_dict; the number of seeds is read off its first entry.
                num_seeds = list(sample_dict.values())[0].shape[0]

        # ==============================================================================================================
        # call the sampling method of additional modules.
        # ==============================================================================================================
        for module in self.additional_modules:
            # deterministic sample.
            if deterministic:
                module.dsample()
            else:
                module.rsample(num_seeds=num_seeds)

        # if utility is requested, don't run log-softmax, simply return logit.
        return_logit = (return_type == 'utility')
        if return_scope == 'all_items':
            # (num_seeds, len(batch), num_items)
            out = self.log_likelihood_all_items(
                batch=batch, sample_dict=sample_dict, return_logit=return_logit)
        elif return_scope == 'item_index':
            # (num_seeds, len(batch))
            out = self.log_likelihood_item_index(
                batch=batch, sample_dict=sample_dict, return_logit=return_logit)

        if deterministic:
            # drop the first dimension, which has size of `num_seeds` (equals 1 in the deterministic case).
            # (len(batch), num_items) or (len(batch),)
            return out.squeeze(dim=0)

        return out

    @property
    def num_params(self) -> int:
        """Total number of elements over all tensors yielded by `self.parameters()`."""
        return sum(parameter.numel() for parameter in self.parameters())

    @property
    def device(self) -> torch.device:
        """Device of the first coefficient in `self.coef_dict` (None if there are no coefficients)."""
        return next((coef.device for coef in self.coef_dict.values()), None)

    # ==================================================================================================================
    # helper functions.
    # ==================================================================================================================
    def sample_coefficient_dictionary(self, num_seeds: int) -> Dict[str, torch.Tensor]:
        """Draw `num_seeds` reparameterized samples from every coefficient.

        Args:
            num_seeds (int): number of random samples.

        Returns:
            Dict[str, torch.Tensor]: maps every coefficient name to its sampled tensor.
                For coefficients with `obs2prior` enabled, `rsample` yields a pair and the
                second element (the sampled obs2prior weight) is stored under
                `coef_name + '.H'`.
                Each sample tensor is expected to have shape (num_seeds, num_classes, dim).
        """
        samples = dict()
        for name, coefficient in self.coef_dict.items():
            draw = coefficient.rsample(num_seeds)

            if not coefficient.obs2prior:
                # plain coefficient: rsample gives the realization of the variable only.
                assert torch.is_tensor(draw)
                samples[name] = draw
                continue

            # obs2prior coefficient: rsample gives (realization, obs2prior weight).
            assert isinstance(draw, tuple) and len(draw) == 2
            realization, obs_weight = draw
            samples[name] = realization
            samples[name + '.H'] = obs_weight
        return samples

    @torch.no_grad()
    def get_within_category_accuracy(self, log_p_all_items: torch.Tensor, label: torch.LongTensor) -> Dict[str, float]:
        """A helper function for computing prediction accuracy (i.e., all non-differentiable metrics)
        within category.
        In particular, this method calculates the accuracy, precision, recall and F1 score.

        This method has the same functionality as the following pseudocode:
        for C in categories:
            # get sessions in which item in category C was purchased.
            T <- (t for t in {0,1,..., len(label)-1} if label[t] is in C)
            Y <- label[T]

            predictions = list()
            for t in T:
                # get the prediction within category for this session.
                y_pred = argmax_{items in C} log prob computed before.
                predictions.append(y_pred)

            accuracy = mean(Y == predictions)

        Precision and recall are computed per item and then averaged: an item contributes to the precision
        average only if it was predicted at least once, and to the recall average only if it appears in
        `label` at least once. The F1 score is the harmonic mean of these two averages.

        Args:
            log_p_all_items (torch.Tensor): shape (num_sessions, num_items) the log probability of
                choosing each item in each session.
            label (torch.LongTensor): shape (num_sessions,), the IDs of items purchased in each session.

        Returns:
            Dict[str, float]: A dictionary with keys 'accuracy', 'precision', 'recall' and 'f1score'.
        """
        # argmax: (num_sessions, num_categories), within category argmax.
        # item IDs are consecutive, thus argmax is the same as IDs of the item with highest P.
        _, argmax_by_category = scatter_max(
            log_p_all_items, self.item_to_category_tensor, dim=-1)

        # category_purchased[t] = the category of item label[t].
        # (num_sessions,)
        category_purchased = self.item_to_category_tensor[label]

        # pred[t] = the item with highest utility from the category item label[t] belongs to.
        # (num_sessions,)
        pred_from_category = argmax_by_category[torch.arange(
            len(label)), category_purchased]

        within_category_accuracy = (
            pred_from_category == label).float().mean().item()

        precision = list()
        recall = list()
        for i in range(self.num_items):
            predicted_i = pred_from_category == i
            labelled_i = label == i
            correct_i = torch.logical_and(predicted_i, labelled_i).float().sum()

            # only form a ratio when its denominator is non-zero, so no nan/inf is ever produced.
            if torch.any(predicted_i):
                precision.append(
                    (correct_i / predicted_i.float().sum()).cpu().item())
            if torch.any(labelled_i):
                recall.append(
                    (correct_i / labelled_i.float().sum()).cpu().item())

        precision = float(np.mean(precision))
        recall = float(np.mean(recall))

        if precision == recall == 0:
            # harmonic mean is undefined here (0 / 0); report 0.0 and keep the value a float.
            f1 = 0.0
        else:
            f1 = 2 * precision * recall / (precision + recall)

        return {'accuracy': within_category_accuracy,
                'precision': precision,
                'recall': recall,
                'f1score': f1}

    # ==================================================================================================================
    # Methods for terms in the ELBO: prior, likelihood, and variational.
    # ==================================================================================================================
    def log_likelihood_all_items(self, batch: ChoiceDataset, return_logit: bool, sample_dict: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Compute the utility or the log-probability of `each` item for every observation in `batch`.

        NOTE to developers:
        This method computes utilities for all items available, which is a relatively slow operation. For
        training the model, you only need the utility/log-prob for the chosen/relevant item (i.e., item_index[i]
        for each i-th observation).
        Use this method for inference only.
        Use self.log_likelihood_item_index() for training instead.
        NOTE (akanodia to tianyudu): Is this really slow? Even with log_likelihood you need log_prob, which
        depends on the logits of all items.

        The coefficient values are read from `sample_dict`, which allows for specifying Monte Carlo samples of
        the {user, item} latents for Monte Carlo estimation in ELBO.
        NOTE (akanodia to tianyudu): i.e., this method allows for samples instead of the posterior mean.
        For actual prediction tasks, use the forward() function, which will use means of variational
        distributions for user and item latents.

        Utilities are computed once per unique (user, session) pair in the batch and then mapped back to the
        rows of the batch, so repeated pairs do not cause repeated work.

        Args:
            batch (ChoiceDataset): a ChoiceDataset object containing relevant information.
            return_logit (bool): if set to True, return the logit/utility; otherwise return the log-probability.
            sample_dict (Dict[str, torch.Tensor]): Monte Carlo samples for model coefficients
                (i.e., those Greek letters).
                sample_dict.keys() should be the same as keys of self.obs2prior_dict, i.e., those
                greek letters actually enter the functional form of utility.
                The value of sample_dict should be tensors of shape (num_seeds, num_classes, dim)
                where num_classes in {num_users, num_items, 1}
                and dim in {latent_dim(K), num_item_obs, num_user_obs, 1}.

        Raises:
            ValueError: if a coefficient name or observable name has an unsupported suffix/prefix, or if a term
                of `self.formula` does not match one of the four supported term types.

        Returns:
            torch.Tensor: a tensor of shape (num_seeds, len(batch), self.num_items).
                If `return_logit` is True, out[x, y, z] is the utility of item z for the y-th observation
                conditioned on latents being the x-th Monte Carlo sample.
                Otherwise out[x, y, z] is the corresponding log-probability: a log-softmax computed separately
                within each category when `self.pred_item` is True, and `logsigmoid(utility)` when it is False.
        """
        # the number of Monte Carlo samples is read off the leading dimension of the first sample tensor.
        num_seeds = next(iter(sample_dict.values())).shape[0]

        # avoid repeated work when user purchased several items in the same session.
        user_session_index = torch.stack(
            [batch.user_index, batch.session_index])
        assert user_session_index.shape == (2, len(batch))
        # inverse_indices[k] is the column of unique_user_sess holding the (user, session) pair of the k-th row.
        unique_user_sess, inverse_indices = torch.unique(
            user_session_index, dim=1, return_inverse=True)

        user_index = unique_user_sess[0, :]
        session_index = unique_user_sess[1, :]
        assert len(user_index) == len(session_index)

        # short-hands for easier shape check.
        R = num_seeds
        # P is the number of unique (user, session) pairs, not len(batch) (the number of purchases).
        P = unique_user_sess.shape[1]
        S = self.num_sessions
        U = self.num_users
        I = self.num_items
        NC = self.num_categories

        # ==============================================================================================================
        # Helper Functions for Reshaping.
        # ==============================================================================================================
        def reshape_user_coef_sample(C):
            # input shape (R, U, *)
            C = C.view(R, U, 1, -1).expand(-1, -1, I, -1)  # (R, U, I, *)
            # keep only the users of the unique (user, session) pairs: (R, P, I, *)
            C = C[:, user_index, :, :]
            assert C.shape == (R, P, I, positive_integer)
            return C

        def reshape_item_coef_sample(C):
            # input shape (R, I, *), broadcast over the P pairs: (R, P, I, *)
            C = C.view(R, 1, I, -1).expand(-1, P, -1, -1)
            assert C.shape == (R, P, I, positive_integer)
            return C

        def reshape_category_coef_sample(C):
            # input shape (R, NC, *)
            # repeat the coefficient of each category once per item of that category.
            C = torch.repeat_interleave(C, self.category_to_size_tensor, dim=1)
            # now of shape (R, I, *), broadcast over the P pairs: (R, P, I, *)
            C = C.view(R, 1, I, -1).expand(-1, P, -1, -1)
            assert C.shape == (R, P, I, positive_integer)
            return C

        def reshape_constant_coef_sample(C):
            # input shape (R, *), broadcast over pairs and items: (R, P, I, *)
            C = C.view(R, 1, 1, -1).expand(-1, P, I, -1)
            assert C.shape == (R, P, I, positive_integer)
            return C

        def reshape_coef_sample(sample, name):
            # reshape the monte carlo sample of coefficients to (R, P, I, *).
            # the suffix of the coefficient name decides which dimension the coefficient varies along.
            if name.endswith('_user'):
                # (R, U, *) --> (R, P, I, *)
                return reshape_user_coef_sample(sample)
            elif name.endswith('_item'):
                # (R, I, *) --> (R, P, I, *)
                return reshape_item_coef_sample(sample)
            elif name.endswith('_category'):
                # (R, NC, *) --> (R, P, I, *)
                return reshape_category_coef_sample(sample)
            elif name.endswith('_constant'):
                # (R, *) --> (R, P, I, *)
                return reshape_constant_coef_sample(sample)
            else:
                raise ValueError

        def reshape_observable(obs, name):
            # reshape observable to (R, P, I, *) so that it can be multiplied with monte carlo
            # samples of coefficients.
            # the prefix of the observable name decides the expected input shape.
            O = obs.shape[-1]  # number of observables.
            assert O == positive_integer
            if name.startswith('item_'):
                assert obs.shape == (I, O)
                obs = obs.view(1, 1, I, O).expand(R, P, -1, -1)
            elif name.startswith('user_'):
                assert obs.shape == (U, O)
                obs = obs[user_index, :]  # (P, O)
                obs = obs.view(1, P, 1, O).expand(R, -1, I, -1)
            elif name.startswith('session_'):
                assert obs.shape == (S, O)
                obs = obs[session_index, :]  # (P, O)
                return obs.view(1, P, 1, O).expand(R, -1, I, -1)
            elif name.startswith('price_'):
                assert obs.shape == (S, I, O)
                obs = obs[session_index, :, :]  # (P, I, O)
                return obs.view(1, P, I, O).expand(R, -1, -1, -1)
            elif name.startswith('taste_'):
                assert obs.shape == (U, I, O)
                obs = obs[user_index, :, :]  # (P, I, O)
                return obs.view(1, P, I, O).expand(R, -1, -1, -1)
            else:
                raise ValueError
            # only the item_ and user_ branches reach this final shape check; the others return early.
            assert obs.shape == (R, P, I, O)
            return obs

        # ==============================================================================================================
        # Compute the Utility Term by Term.
        # ==============================================================================================================
        # P is the number of unique (user, session) pairs.
        # (random_seeds, P, num_items).
        utility = torch.zeros(R, P, I, device=self.device)

        # loop over additive term to utility
        for term in self.formula:
            # Type I: single coefficient, e.g., lambda_item or lambda_user.
            if len(term['coefficient']) == 1 and term['observable'] is None:
                # E.g., lambda_item or lambda_user
                coef_name = term['coefficient'][0]
                coef_sample = reshape_coef_sample(
                    sample_dict[coef_name], coef_name)
                assert coef_sample.shape == (R, P, I, 1)
                additive_term = coef_sample.view(R, P, I)

            # Type II: factorized coefficient, e.g., <theta_user, lambda_item>.
            elif len(term['coefficient']) == 2 and term['observable'] is None:
                coef_name_0 = term['coefficient'][0]
                coef_name_1 = term['coefficient'][1]

                coef_sample_0 = reshape_coef_sample(
                    sample_dict[coef_name_0], coef_name_0)
                coef_sample_1 = reshape_coef_sample(
                    sample_dict[coef_name_1], coef_name_1)

                assert coef_sample_0.shape == coef_sample_1.shape == (
                    R, P, I, positive_integer)

                # inner product over the last (latent) dimension.
                additive_term = (coef_sample_0 * coef_sample_1).sum(dim=-1)

            # Type III: single coefficient multiplied by observable, e.g., theta_user * x_obs_item.
            elif len(term['coefficient']) == 1 and term['observable'] is not None:
                coef_name = term['coefficient'][0]
                coef_sample = reshape_coef_sample(
                    sample_dict[coef_name], coef_name)
                assert coef_sample.shape == (R, P, I, positive_integer)

                obs_name = term['observable']
                obs = reshape_observable(getattr(batch, obs_name), obs_name)
                assert obs.shape == (R, P, I, positive_integer)

                additive_term = (coef_sample * obs).sum(dim=-1)

            # Type IV: factorized coefficient multiplied by observable.
            # e.g., gamma_user * beta_item * price_obs.
            elif len(term['coefficient']) == 2 and term['observable'] is not None:
                coef_name_0, coef_name_1 = term['coefficient'][0], term['coefficient'][1]

                coef_sample_0 = reshape_coef_sample(
                    sample_dict[coef_name_0], coef_name_0)
                coef_sample_1 = reshape_coef_sample(
                    sample_dict[coef_name_1], coef_name_1)
                assert coef_sample_0.shape == coef_sample_1.shape == (
                    R, P, I, positive_integer)
                num_obs_times_latent_dim = coef_sample_0.shape[-1]

                obs_name = term['observable']
                obs = reshape_observable(getattr(batch, obs_name), obs_name)
                assert obs.shape == (R, P, I, positive_integer)
                num_obs = obs.shape[-1]  # number of observables.

                # the last dimension of each coefficient is a flattened (num_obs, latent_dim) block.
                assert (num_obs_times_latent_dim % num_obs) == 0
                latent_dim = num_obs_times_latent_dim // num_obs

                coef_sample_0 = coef_sample_0.view(
                    R, P, I, num_obs, latent_dim)
                coef_sample_1 = coef_sample_1.view(
                    R, P, I, num_obs, latent_dim)
                # compute the factorized coefficient with shape (R, P, I, O).
                coef = (coef_sample_0 * coef_sample_1).sum(dim=-1)

                additive_term = (coef * obs).sum(dim=-1)

            else:
                raise ValueError(f'Undefined term type: {term}')

            assert additive_term.shape == (R, P, I)
            utility += additive_term

        # ==============================================================================================================
        # Mask Out Unavailable Items in Each Session.
        # ==============================================================================================================

        if batch.item_availability is not None:
            # expand to the Monte Carlo sample dimension.
            # (S, I) -> (P, I) -> (1, P, I) -> (R, P, I)
            A = batch.item_availability[session_index, :].unsqueeze(
                dim=0).expand(R, -1, -1)
            # use a very negative finite number (half of the dtype's max magnitude) instead of -inf.
            utility[~A] = - (torch.finfo(utility.dtype).max / 2)

        # map from unique (user, session) pairs back to the rows of the batch: (R, P, I) -> (R, len(batch), I)
        utility = utility[:, inverse_indices, :]
        assert utility.shape == (R, len(batch), I)

        # each additional module contributes one value per (seed, observation), shared by all items.
        for module in self.additional_modules:
            additive_term = module(batch)
            assert additive_term.shape == (R, len(batch), 1)
            utility += additive_term.expand(-1, -1, I)

        if return_logit:
            # output shape: (num_seeds, len(batch), num_items)
            return utility
        else:
            # compute log likelihood log p(choosing item i | user, item latents)
            # compute log softmax separately within each category.
            if self.pred_item:
                # output shape: (num_seeds, len(batch), num_items)
                log_p = scatter_log_softmax(
                    utility, self.item_to_category_tensor, dim=-1)
            else:
                # self.pred_item is False: apply an element-wise log-sigmoid to the utilities instead.
                log_p = torch.nn.functional.logsigmoid(utility)
            return log_p

    def log_likelihood_item_index(self, batch: ChoiceDataset, return_logit: bool, sample_dict: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Compute the utility or log-probability restricted to the items relevant to each observation.

        NOTE for developers:
        This method is more efficient than `log_likelihood_all_items`: for the i-th observation it only
        computes utilities of the items in the same category as item_index[i] (the "relevant" items).
        Developers should use `log_likelihood_all_items` for inference purposes and to compute
        log-likelihoods/utilities for ALL items for the i-th observation.

        This method allows for specifying {user, item}_latent_value for Monte Carlo estimation in ELBO.
        For actual prediction tasks, use the forward() function, which will use means of variational
        distributions for user and item latents.

        Below, `total_computation` denotes the total number of relevant items summed over all observations,
        i.e., the sum over observations of the size of the chosen item's category.

        Args:
            batch (ChoiceDataset): a ChoiceDataset object containing relevant information.
            return_logit (bool): if set to True, return the logit/utility; otherwise return the log-probability.
            sample_dict (Dict[str, torch.Tensor]): Monte Carlo samples for model coefficients
                (i.e., those Greek letters).
                sample_dict.keys() should be the same as keys of self.obs2prior_dict, i.e., those
                greek letters actually enter the functional form of utility.
                The value of sample_dict should be tensors of shape (num_seeds, num_classes, dim)
                where num_classes in {num_users, num_items, 1}
                and dim in {latent_dim(K), num_item_obs, num_user_obs, 1}.

        Raises:
            ValueError: if a coefficient name or observable name has an unsupported suffix/prefix, or if a term
                of `self.formula` does not match one of the four supported term types.

        Returns:
            torch.Tensor:
                - `return_logit` is True: the utilities of all relevant items, shape
                  (num_seeds, total_computation). This equals (num_seeds, len(batch)) only when every chosen
                  item is the single member of its category.
                - `return_logit` is False and `self.pred_item` is True: shape (num_seeds, len(batch)), where
                  out[x, y] is the log-probability of choosing item batch.item_index[y] among the items of its
                  category, conditioned on latents being the x-th Monte Carlo sample.
                - `return_logit` is False and `self.pred_item` is False: the negative element-wise binary
                  cross-entropy between sigmoid(utility) (flattened) and `batch.label`.
                  NOTE(review): the flattened utility must have as many entries as `batch.label`, which
                  presumably requires num_seeds == 1 and one relevant item per observation -- confirm.
        """
        # the number of Monte Carlo samples is read off the leading dimension of the first sample tensor.
        num_seeds = next(iter(sample_dict.values())).shape[0]

        # get category id of the item bought in each row of batch.
        cate_index = self.item_to_category_tensor[batch.item_index]

        # get item ids of all items from the same category of each item bought.
        relevant_item_index = self.category_to_item_tensor[cate_index, :]
        relevant_item_index = relevant_item_index.view(-1,)
        # index were padded with -1's, drop those dummy entries.
        relevant_item_index = relevant_item_index[relevant_item_index != -1]

        # the first repeats[0] entries in relevant_item_index are for the category of item_index[0]
        repeats = self.category_to_size_tensor[cate_index]
        # argwhere(reverse_indices == k) are positions in relevant_item_index for the category of item_index[k].
        reverse_indices = torch.repeat_interleave(
            torch.arange(len(batch), device=self.device), repeats)
        # expand the user_index and session_index.
        user_index = torch.repeat_interleave(batch.user_index, repeats)
        repeat_category_index = torch.repeat_interleave(cate_index, repeats)
        session_index = torch.repeat_interleave(batch.session_index, repeats)
        # duplicate the item focused to match.
        item_index_expanded = torch.repeat_interleave(
            batch.item_index, repeats)

        # short-hands for easier shape check.
        R = num_seeds
        # total number of relevant items.
        total_computation = len(session_index)
        S = self.num_sessions
        U = self.num_users
        I = self.num_items
        # ==========================================================================================
        # Helper Functions for Reshaping.
        # ==========================================================================================

        def reshape_coef_sample(sample, name):
            # reshape the monte carlo sample of coefficients to (R, total_computation, *).
            if name.endswith('_user'):
                # (R, U, *) --> (R, total_computation, *)
                return sample[:, user_index, :]
            elif name.endswith('_item'):
                # (R, I, *) --> (R, total_computation, *)
                return sample[:, relevant_item_index, :]
            elif name.endswith('_category'):
                # (R, NC, *) --> (R, total_computation, *)
                return sample[:, repeat_category_index, :]
            elif name.endswith('_constant'):
                # (R, *) --> (R, total_computation, *)
                return sample.view(R, 1, -1).expand(-1, total_computation, -1)
            else:
                raise ValueError(f'Unsupported coefficient name: {name}')

        def reshape_observable(obs, name):
            # reshape observable to (R, total_computation, *) so that it can be multiplied with
            # monte carlo samples of coefficients.
            O = obs.shape[-1]  # number of observables.
            assert O == positive_integer
            if name.startswith('item_'):
                assert obs.shape == (I, O)
                obs = obs[relevant_item_index, :]
            elif name.startswith('user_'):
                assert obs.shape == (U, O)
                obs = obs[user_index, :]
            elif name.startswith('session_'):
                assert obs.shape == (S, O)
                obs = obs[session_index, :]
            elif name.startswith('price_'):
                assert obs.shape == (S, I, O)
                obs = obs[session_index, relevant_item_index, :]
            elif name.startswith('taste_'):
                assert obs.shape == (U, I, O)
                obs = obs[user_index, relevant_item_index, :]
            else:
                raise ValueError(f'Unsupported observable name: {name}')
            assert obs.shape == (total_computation, O)
            return obs.unsqueeze(dim=0).expand(R, -1, -1)

        # ==========================================================================================
        # Compute Components related to users and items only.
        # ==========================================================================================
        utility = torch.zeros(R, total_computation, device=self.device)

        # loop over additive term to utility
        for term in self.formula:
            # Type I: single coefficient, e.g., lambda_item or lambda_user.
            if len(term['coefficient']) == 1 and term['observable'] is None:
                # E.g., lambda_item or lambda_user
                coef_name = term['coefficient'][0]
                coef_sample = reshape_coef_sample(
                    sample_dict[coef_name], coef_name)
                assert coef_sample.shape == (R, total_computation, 1)
                additive_term = coef_sample.view(R, total_computation)

            # Type II: factorized coefficient, e.g., <theta_user, lambda_item>.
            elif len(term['coefficient']) == 2 and term['observable'] is None:
                coef_name_0 = term['coefficient'][0]
                coef_name_1 = term['coefficient'][1]

                coef_sample_0 = reshape_coef_sample(
                    sample_dict[coef_name_0], coef_name_0)
                coef_sample_1 = reshape_coef_sample(
                    sample_dict[coef_name_1], coef_name_1)

                assert coef_sample_0.shape == coef_sample_1.shape == (
                    R, total_computation, positive_integer)

                # inner product over the last (latent) dimension.
                additive_term = (coef_sample_0 * coef_sample_1).sum(dim=-1)

            # Type III: single coefficient multiplied by observable, e.g., theta_user * x_obs_item.
            elif len(term['coefficient']) == 1 and term['observable'] is not None:
                coef_name = term['coefficient'][0]
                coef_sample = reshape_coef_sample(
                    sample_dict[coef_name], coef_name)
                assert coef_sample.shape == (
                    R, total_computation, positive_integer)

                obs_name = term['observable']
                obs = reshape_observable(getattr(batch, obs_name), obs_name)
                assert obs.shape == (R, total_computation, positive_integer)

                additive_term = (coef_sample * obs).sum(dim=-1)

            # Type IV: factorized coefficient multiplied by observable.
            # e.g., gamma_user * beta_item * price_obs.
            elif len(term['coefficient']) == 2 and term['observable'] is not None:
                coef_name_0, coef_name_1 = term['coefficient'][0], term['coefficient'][1]
                coef_sample_0 = reshape_coef_sample(
                    sample_dict[coef_name_0], coef_name_0)
                coef_sample_1 = reshape_coef_sample(
                    sample_dict[coef_name_1], coef_name_1)
                assert coef_sample_0.shape == coef_sample_1.shape == (
                    R, total_computation, positive_integer)
                num_obs_times_latent_dim = coef_sample_0.shape[-1]

                obs_name = term['observable']
                obs = reshape_observable(getattr(batch, obs_name), obs_name)
                assert obs.shape == (R, total_computation, positive_integer)
                num_obs = obs.shape[-1]  # number of observables.

                # the last dimension of each coefficient is a flattened (num_obs, latent_dim) block.
                assert (num_obs_times_latent_dim % num_obs) == 0
                latent_dim = num_obs_times_latent_dim // num_obs

                coef_sample_0 = coef_sample_0.view(
                    R, total_computation, num_obs, latent_dim)
                coef_sample_1 = coef_sample_1.view(
                    R, total_computation, num_obs, latent_dim)
                # compute the factorized coefficient with shape (R, total_computation, O).
                coef = (coef_sample_0 * coef_sample_1).sum(dim=-1)

                additive_term = (coef * obs).sum(dim=-1)

            else:
                raise ValueError(f'Undefined term type: {term}')

            assert additive_term.shape == (R, total_computation)
            utility += additive_term

        # ==========================================================================================
        # Mask Out Unavailable Items in Each Session.
        # ==========================================================================================

        if batch.item_availability is not None:
            # expand to the Monte Carlo sample dimension.
            A = batch.item_availability[session_index, relevant_item_index].unsqueeze(
                dim=0).expand(R, -1)
            # use a very negative finite number (half of the dtype's max magnitude) instead of -inf.
            utility[~A] = - (torch.finfo(utility.dtype).max / 2)

        for module in self.additional_modules:
            # current utility shape: (R, total_computation)
            additive_term = module(batch)
            assert additive_term.shape == (
                R, len(batch)) or additive_term.shape == (R, len(batch), 1)
            if additive_term.shape == (R, len(batch), 1):
                # TODO: need to make this consistent with log_likelihood_all.
                # be tolerant for some customized module with BayesianLinear that returns (R, len(batch), 1).
                additive_term = additive_term.view(R, len(batch))
            # expand to total number of computation, query by reverse_indices.
            # reverse_indices has length total_computation, and reverse_indices[i] correspond to the row-id that this
            # computation is responsible for.
            additive_term = additive_term[:, reverse_indices]
            assert additive_term.shape == (R, total_computation)
            # accumulate the module's contribution, mirroring log_likelihood_all_items; without this line the
            # additional modules would have no effect on the returned utility.
            utility += additive_term

        # compute log likelihood log p(choosing item i | user, item latents)
        if return_logit:
            log_p = utility
        else:
            if self.pred_item:
                # compute the log probability from logits/utilities, normalizing separately over the
                # relevant items of each observation.
                # shape: (num_seeds, total_computation)
                log_p = scatter_log_softmax(utility, reverse_indices, dim=-1)
                # select the log-P of the item actually bought.
                # shape: (num_seeds, len(batch))
                log_p = log_p[:, item_index_expanded == relevant_item_index]
            else:
                # This is the binomial choice situation in which case we just report sigmoid log likelihood
                bce = nn.BCELoss(reduction='none')
                log_p = - bce(torch.sigmoid(utility.view(-1)), batch.label.to(torch.float32))
        return log_p

    def log_prior(self, batch: ChoiceDataset, sample_dict: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Calculates the log-likelihood of Monte Carlo samples of Bayesian coefficients under their
        prior distribution. This method assumes coefficients are statistically independent, so the
        per-coefficient log-priors are summed.

        Args:
            batch (ChoiceDataset): a dataset object containing observables for computing the prior distribution
                if obs2prior is True; `batch.item_obs` is used for coefficients whose name ends with `_item`
                and `batch.user_obs` for those ending with `_user`.
            sample_dict (Dict[str, torch.Tensor]): a dictionary mapping coefficient names to Monte Carlo samples.
                For a coefficient with obs2prior enabled, the sample of its obs2prior weight is read from the
                key `<coefficient name>.H`.

        Raises:
            ValueError: if obs2prior is enabled for a coefficient whose name ends with neither `_item`
                nor `_user`, since no observable is available to build its prior.
            NotImplementedError: if the model has any additional module; priors of additional modules
                are not supported yet.

        Returns:
            torch.Tensor: a tensor with shape (num_seeds,) of [ log P_{prior_distribution}(param[i]) ],
                where param[i] is the i-th Monte Carlo sample.
        """
        # the number of Monte Carlo samples is read off the leading dimension of the first sample tensor.
        num_seeds = next(iter(sample_dict.values())).shape[0]

        total = torch.zeros(num_seeds, device=self.device)

        for coef_name, coef in self.coef_dict.items():
            if self.obs2prior_dict[coef_name]:
                # the prior depends on observables; pick them based on the coefficient's suffix.
                if coef_name.endswith('_item'):
                    x_obs = batch.item_obs
                elif coef_name.endswith('_user'):
                    x_obs = batch.user_obs
                else:
                    raise ValueError(
                        f'No observable found to support obs2prior for {coef_name}.')
                H_sample = sample_dict[coef_name + '.H']
            else:
                x_obs, H_sample = None, None

            # coef.log_prior outputs (num_seeds, num_{items, users}), sum to (num_seeds,).
            total += coef.log_prior(sample=sample_dict[coef_name],
                                    H_sample=H_sample,
                                    x_obs=x_obs).sum(dim=-1)

        if len(self.additional_modules) > 0:
            # fail loudly rather than silently ignoring the prior of the additional modules.
            raise NotImplementedError(
                'log_prior is not implemented for additional modules.')

        return total

    def log_variational(self, sample_dict: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Calculate the log-likelihood of samples in sample_dict under the current variational
        distribution.

        Args:
            sample_dict (Dict[str, torch.Tensor]): a dictionary mapping coefficient names to Monte Carlo
                samples. Only the entries keyed by the names in `self.coef_dict` are used.

        Raises:
            NotImplementedError: if the model has any additional module; variational log-densities of
                additional modules are not supported yet.

        Returns:
            torch.Tensor: a tensor of shape (num_seeds,) of [ log P_{variational_distribution}(param[i]) ],
                where param[i] is the i-th Monte Carlo sample.
        """
        # the number of Monte Carlo samples is read off the leading dimension of the first sample tensor.
        num_seeds = next(iter(sample_dict.values())).shape[0]
        total = torch.zeros(num_seeds, device=self.device)

        for coef_name, coef in self.coef_dict.items():
            # coef.log_variational outputs (num_seeds, num_{items, users}), sum to (num_seeds,).
            total += coef.log_variational(sample_dict[coef_name]).sum(dim=-1)

        if len(self.additional_modules) > 0:
            # fail loudly rather than silently ignoring the additional modules.
            raise NotImplementedError(
                'log_variational is not implemented for additional modules.')

        return total

    def elbo(self, batch: ChoiceDataset, num_seeds: int = 1) -> torch.Tensor:
        """Compute a Monte Carlo estimate of the ELBO for a batch; this is the training objective.

        The estimate is assembled from three pieces, each averaged over the Monte Carlo seeds:
            + log p(latent), from `self.log_prior`;
            + log p(obs | latent), summed over the observations in the batch;
            - log q(latent), from `self.log_variational`, only when `self.trace_log_q` is True.

        Args:
            batch (ChoiceDataset): a ChoiceDataset containing necessary information.
            num_seeds (int, optional): the number of Monte Carlo samples from variational distributions
                to evaluate the expectation in ELBO.
                Defaults to 1.

        Returns:
            torch.Tensor: a scalar tensor of the ELBO estimated from num_seeds Monte Carlo samples.
        """
        # ==============================================================================================================
        # 1. sample latent variables from their variational distributions.
        # ==============================================================================================================
        sample_dict = self.sample_coefficient_dictionary(num_seeds)

        # ==============================================================================================================
        # 2. compute log p(latent) prior.
        # (num_seeds,) --mean--> scalar.
        # ==============================================================================================================
        elbo = self.log_prior(batch, sample_dict).mean(dim=0)

        # ==============================================================================================================
        # 3. compute the log likelihood log p(obs|latent).
        # sum over independent purchase decision for individual observations, mean over MC seeds.
        # the forward() function calls module.rsample(num_seeds) for module in self.additional_modules.
        # ==============================================================================================================
        if self.pred_item:
            # the prediction target is item_index.
            # (num_seeds, len(batch))
            log_prob = self.forward(batch,
                                    return_type='log_prob',
                                    return_scope='item_index',
                                    deterministic=False,
                                    sample_dict=sample_dict)
            # (num_seeds, len(batch)) --> scalar.
            elbo += log_prob.sum(dim=1).mean(dim=0)
        else:
            # the prediction target is binary.
            # TODO: update the prediction function.
            utility = self.forward(batch,
                                   return_type='utility',
                                   return_scope='item_index',
                                   deterministic=False,
                                   sample_dict=sample_dict)  # (num_seeds, len(batch))

            # replicate the binary labels once per seed so they line up with the utilities.
            # (num_seeds, len(batch))
            y_stacked = torch.stack([batch.label] * num_seeds).float()
            assert y_stacked.shape == utility.shape

            # Evaluate the Bernoulli log-likelihood directly on the logits. Applying sigmoid first
            # and then BCELoss saturates for large |utility| (sigmoid rounds to exactly 0 or 1 and
            # BCELoss clamps the log term), which distorts both the value and the gradient.
            nll = nn.functional.binary_cross_entropy_with_logits(
                utility, y_stacked, reduction='none')
            # (num_seeds, len(batch)) --> scalar.
            elbo += - nll.sum(dim=1).mean(dim=0)

        # ==============================================================================================================
        # 4. optionally subtract the log likelihood under variational distributions q(latent).
        # ==============================================================================================================
        if self.trace_log_q:
            elbo -= self.log_variational(sample_dict).mean(dim=0)

        return elbo
__init__(self, utility_formula, obs2prior_dict, coef_dim_dict, num_items, pred_item, prior_mean=0.0, default_prior_mean=0.0, prior_variance=1.0, num_users=None, num_sessions=None, trace_log_q=False, category_to_item=None, num_user_obs=None, num_item_obs=None, num_session_obs=None, num_price_obs=None, num_taste_obs=None, additional_modules=None) special

Parameters:

Name Type Description Default
utility_formula str

a string representing the utility function U[user, item, session]. See the documentation for details on the formula format. Examples: `lambda_item`; `lambda_item + theta_user * alpha_item + zeta_user * item_obs`; `lambda_item + theta_user * alpha_item + gamma_user * beta_item * price_obs`. See the doc-string of parse_utility for an example.

required
obs2prior_dict Dict[str, bool]

a dictionary maps coefficient name (e.g., 'lambda_item') to a boolean indicating if observable (e.g., item_obs) enters the prior of the coefficient.

required
coef_dim_dict Dict[str, int]

a dictionary maps coefficient name (e.g., 'lambda_item') to an integer indicating the dimension of coefficient. For standalone coefficients like U = lambda_item, the dim should be 1. For factorized coefficients like U = theta_user * alpha_item, the dim should be the latent dimension of theta and alpha. For coefficients multiplied with observables like U = zeta_user * item_obs, the dim should be the number of observables in item_obs. For factorized coefficient multiplied with observables like U = gamma_user * beta_item * price_obs, the dim should be the latent dim multiplied by number of observables in price_obs.

required
num_items int

number of items.

required
pred_item bool

there are two use cases of this model, suppose we have user_index[i] and item_index[i] for the i-th observation in the dataset. Case 1: which item among all items user user_index[i] is going to purchase, the prediction label is therefore item_index[i]. Equivalently, we can ask what's the likelihood for user user_index[i] to purchase item_index[i]. Case 2: what rating would user user_index[i] assign to item item_index[i]? In this case, the dataset object needs to contain a separate label. NOTE: for now, we only support binary labels.

required
default_prior_mean float

the default prior mean for coefficients, used if it is not specified in prior_mean; defaults to 0.0.

0.0
prior_mean Union[float, Dict[str, float]]

the mean of the prior distribution for coefficients. If a float is provided, all prior means will be set to the provided value. If a dictionary is provided, keys of prior_mean should be coefficient names, and the mean of the prior of coef_name would be the provided value. Defaults to 0.0, which means all prior means are initialized to 0.0.

0.0
prior_variance Union[float, Dict[str, float]]

the variance of prior distribution for coefficients. If a float is provided, all priors will be diagonal matrix with prior_variance along the diagonal. If a dictionary is provided, keys of prior_variance should be coefficient names, and the variance of prior of coef_name would be a diagonal matrix with prior_variance[coef_name] along the diagonal. Defaults to 1.0, which means all prior have identity matrix as the covariance matrix.

1.0
num_users int

number of users, required only if coefficient or observable depending on user is in utility. Defaults to None.

None
num_sessions int

number of sessions, required only if coefficient or observable depending on session is in utility. Defaults to None.

None
trace_log_q bool

whether to trace the derivative of variational likelihood logQ with respect to variational parameters in the ELBO while conducting gradient update. Defaults to False.

False
category_to_item Dict[str, List[int]]

a dictionary with category id or name as keys, and category_to_item[C] contains the list of item ids belonging to category C. If None is provided, all items are assumed to be in the same category. Defaults to None.

None
num_{user, item, session, price, taste}_obs int

number of observables of each type of features, only required if observable enters prior. NOTE: currently we only allow coefficient to depend on either user or item, thus only user and item observables can enter the prior of coefficient. Hence session, price, and taste observables are never required; they are included here for completeness.

None
Source code in bemb/model/bemb.py
def __init__(self,
             utility_formula: str,
             obs2prior_dict: Dict[str, bool],
             coef_dim_dict: Dict[str, int],
             num_items: int,
             pred_item: bool,
             prior_mean: Union[float, Dict[str, float]] = 0.0,
             default_prior_mean: float = 0.0,
             prior_variance: Union[float, Dict[str, float]] = 1.0,
             num_users: Optional[int] = None,
             num_sessions: Optional[int] = None,
             trace_log_q: bool = False,
             category_to_item: Optional[Dict[int, List[int]]] = None,
             # number of observables.
             num_user_obs: Optional[int] = None,
             num_item_obs: Optional[int] = None,
             num_session_obs: Optional[int] = None,
             num_price_obs: Optional[int] = None,
             num_taste_obs: Optional[int] = None,
             # additional modules.
             additional_modules: Optional[List[nn.Module]] = None
             ) -> None:
    """Build the model: store the configuration, register the category/item lookup buffers,
    and create one BayesianCoefficient per coefficient named in the utility formula.

    Args:
        utility_formula (str): a string representing the utility function U[user, item, session].
            See the documentation for details on the formula format.
            Examples:
                lambda_item
                lambda_item + theta_user * alpha_item + zeta_user * item_obs
                lambda_item + theta_user * alpha_item + gamma_user * beta_item * price_obs
            See the doc-string of parse_utility for an example.

        obs2prior_dict (Dict[str, bool]): a dictionary maps coefficient name (e.g., 'lambda_item')
            to a boolean indicating if observable (e.g., item_obs) enters the prior of the coefficient.

        coef_dim_dict (Dict[str, int]): a dictionary maps coefficient name (e.g., 'lambda_item')
            to an integer indicating the dimension of coefficient.
            For standalone coefficients like U = lambda_item, the dim should be 1.
            For factorized coefficients like U = theta_user * alpha_item, the dim should be the
                latent dimension of theta and alpha.
            For coefficients multiplied with observables like U = zeta_user * item_obs, the dim
                should be the number of observables in item_obs.
            For factorized coefficient multiplied with observables like U = gamma_user * beta_item * price_obs,
                the dim should be the latent dim multiplied by number of observables in price_obs.

        num_items (int): number of items.

        pred_item (bool): there are two use cases of this model, suppose we have `user_index[i]` and `item_index[i]`
            for the i-th observation in the dataset.
            Case 1: which item among all items user `user_index[i]` is going to purchase, the prediction label
                is therefore `item_index[i]`. Equivalently, we can ask what's the likelihood for user `user_index[i]`
                to purchase `item_index[i]`.
            Case 2: what rating would user `user_index[i]` assign to item `item_index[i]`? In this case, the dataset
                object needs to contain a separate label.
                NOTE: for now, we only support binary labels.

        default_prior_mean (float): the prior mean used for every coefficient that is not given
            its own entry in `prior_mean`; defaults to 0.0.

        prior_mean (Union[float, Dict[str, float]]): the mean of the prior distribution for
            coefficients. If a dictionary is provided, its keys should be coefficient names and
            the prior mean of `coef_name` is `prior_mean[coef_name]`; coefficients missing from
            the dictionary fall back to `default_prior_mean`.
            NOTE(review): when a float is passed, the current implementation does not consult
            it and uses `default_prior_mean` for every coefficient. This is kept for backward
            compatibility; set `default_prior_mean` to change the shared prior mean.
            Defaults to 0.0.

        prior_variance (Union[float, Dict[str, float]]): the variance of prior distribution for
            coefficients. If a float is provided, all priors will be diagonal matrix with
            prior_variance along the diagonal. If a dictionary is provided, keys of prior_variance
            should be coefficient names, and the variance of prior of coef_name would be a diagonal
            matrix with prior_variance[coef_name] along the diagonal.
            Defaults to 1.0, which means all prior have identity matrix as the covariance matrix.

        num_users (int, optional): number of users, required only if coefficient or observable
            depending on user is in utility. Defaults to None.
        num_sessions (int, optional): number of sessions, required only if coefficient or
            observable depending on session is in utility. Defaults to None.

        trace_log_q (bool, optional): whether to trace the derivative of variational likelihood logQ
            with respect to variational parameters in the ELBO while conducting gradient update.
            Defaults to False.

        category_to_item (Dict[int, List[int]], optional): a dictionary with category id as keys,
            and category_to_item[C] contains the list of item ids belonging to category C.
            Keys are used directly as row indices of the lookup tensors built below, so they
            need to be integers in range(len(category_to_item)).
            If None is provided and `pred_item` is True, all items are put in a single category;
            if None is provided and `pred_item` is False, every item gets its own category.
            Defaults to None.

        num_{user, item, session, price, taste}_obs (int, optional): number of observables of
            each type of features, only required if observable enters prior.
            NOTE: currently we only allow coefficient to depend on either user or item, thus only
            user and item observables can enter the prior of coefficient. Hence session, price,
            and taste observables are never required, they are included here for completeness.

        additional_modules (List[nn.Module], optional): must be None for now.

    Raises:
        NotImplementedError: if `additional_modules` is not None; additional modules are
            temporarily disabled.
    """
    super(BEMBFlex, self).__init__()
    self.utility_formula = utility_formula
    self.obs2prior_dict = obs2prior_dict
    self.coef_dim_dict = coef_dim_dict
    self.prior_variance = prior_variance
    self.default_prior_mean = default_prior_mean
    self.prior_mean = prior_mean

    self.pred_item = pred_item

    self.num_items = num_items
    self.num_users = num_users
    self.num_sessions = num_sessions

    self.trace_log_q = trace_log_q
    self.category_to_item = category_to_item

    # ==============================================================================================================
    # Category ID to Item ID mapping.
    # Category ID to Category Size mapping.
    # Item ID to Category ID mapping.
    # ==============================================================================================================
    if self.category_to_item is None:
        if self.pred_item:
            # assign all items to the same category if predicting items.
            self.category_to_item = {0: list(np.arange(self.num_items))}
        else:
            # otherwise, for the j-th observation in the dataset, the label[j]
            # only depends on user_index[j] and item_index[j], so we put each
            # item to its own category.
            self.category_to_item = {i: [i] for i in range(self.num_items)}

    self.num_categories = len(self.category_to_item)

    # category_to_item_tensor[c] lists the items of category c, right-padded with -1 up to
    # the size of the largest category; category_to_size_tensor[c] is the number of items in c.
    max_category_size = max(len(x) for x in self.category_to_item.values())
    category_to_item_tensor = torch.full(
        (self.num_categories, max_category_size), -1)
    category_to_size_tensor = torch.empty(self.num_categories)

    for c, item_in_c in self.category_to_item.items():
        category_to_item_tensor[c, :len(
            item_in_c)] = torch.LongTensor(item_in_c)
        category_to_size_tensor[c] = torch.scalar_tensor(len(item_in_c))

    self.register_buffer('category_to_item_tensor',
                         category_to_item_tensor.long())
    self.register_buffer('category_to_size_tensor',
                         category_to_size_tensor.long())

    # item_to_category_tensor[i] is the category of item i; items that appear in no
    # category keep the initial value 0.
    item_to_category_tensor = torch.zeros(self.num_items)
    for c, items_in_c in self.category_to_item.items():
        item_to_category_tensor[items_in_c] = c
    self.register_buffer('item_to_category_tensor',
                         item_to_category_tensor.long())

    # ==============================================================================================================
    # Create Bayesian Coefficient Objects
    # ==============================================================================================================
    # model configuration.
    self.formula = parse_utility(utility_formula)
    print('BEMB: utility formula parsed:')
    pprint(self.formula)
    self.raw_formula = utility_formula

    # dimension of each observable, this one is used only for obs2prior.
    self.num_obs_dict = {
        'user': num_user_obs,
        'item': num_item_obs,
        'category' : 0,
        'session': num_session_obs,
        'price': num_price_obs,
        'taste': num_taste_obs,
        'constant': 1  # not really used, for dummy variables.
    }

    # how many classes for the variational distribution.
    # for example, beta_item would be `num_items` 10-dimensional gaussian if latent dim = 10.
    variation_to_num_classes = {
        'user': self.num_users,
        'item': self.num_items,
        'constant': 1,
        'category' : self.num_categories,
    }

    coef_dict = dict()
    for additive_term in self.formula:
        for coef_name in additive_term['coefficient']:
            # the suffix after the last underscore names what the coefficient varies by,
            # e.g., 'theta_user' --> 'user'.
            variation = coef_name.split('_')[-1]

            if isinstance(self.prior_mean, dict):
                # coefficients without their own entry use the default prior mean
                # instead of raising a KeyError.
                mean = self.prior_mean.get(coef_name, self.default_prior_mean)
            else:
                # NOTE(review): a float `prior_mean` is not consulted here; behavior is kept
                # as-is so callers relying on `default_prior_mean` are unaffected.
                mean = self.default_prior_mean

            if isinstance(self.prior_variance, dict):
                s2 = self.prior_variance[coef_name]
            else:
                s2 = self.prior_variance

            coef_dict[coef_name] = BayesianCoefficient(variation=variation,
                                                       num_classes=variation_to_num_classes[variation],
                                                       obs2prior=self.obs2prior_dict[coef_name],
                                                       num_obs=self.num_obs_dict[variation],
                                                       dim=self.coef_dim_dict[coef_name],
                                                       prior_mean=mean,
                                                       prior_variance=s2)
    self.coef_dict = nn.ModuleDict(coef_dict)

    # ==============================================================================================================
    # Optional: register additional modules.
    # ==============================================================================================================
    if additional_modules is not None:
        # when this feature is re-enabled, the modules should be wrapped in nn.ModuleList
        # so that their parameters are registered.
        raise NotImplementedError(
            'Additional modules are temporarily disabled for further development.')
    self.additional_modules = []
elbo(self, batch, num_seeds=1)

A combined method that computes the current ELBO given a batch; this method is used for training the model.

Parameters:

Name Type Description Default
batch ChoiceDataset

a ChoiceDataset containing necessary information.

required
num_seeds int

the number of Monte Carlo samples from variational distributions to evaluate the expectation in ELBO. Defaults to 1.

1

Returns:

Type Description
torch.Tensor

a scalar tensor of the ELBO estimated from num_seeds Monte Carlo samples.

Source code in bemb/model/bemb.py
def elbo(self, batch: ChoiceDataset, num_seeds: int = 1) -> torch.Tensor:
    """Compute a Monte Carlo estimate of the ELBO for a batch; this is the training objective.

    The estimate is assembled from three pieces, each averaged over the Monte Carlo seeds:
        + log p(latent), from `self.log_prior`;
        + log p(obs | latent), summed over the observations in the batch;
        - log q(latent), from `self.log_variational`, only when `self.trace_log_q` is True.

    Args:
        batch (ChoiceDataset): a ChoiceDataset containing necessary information.
        num_seeds (int, optional): the number of Monte Carlo samples from variational distributions
            to evaluate the expectation in ELBO.
            Defaults to 1.

    Returns:
        torch.Tensor: a scalar tensor of the ELBO estimated from num_seeds Monte Carlo samples.
    """
    # ==============================================================================================================
    # 1. sample latent variables from their variational distributions.
    # ==============================================================================================================
    sample_dict = self.sample_coefficient_dictionary(num_seeds)

    # ==============================================================================================================
    # 2. compute log p(latent) prior.
    # (num_seeds,) --mean--> scalar.
    # ==============================================================================================================
    elbo = self.log_prior(batch, sample_dict).mean(dim=0)

    # ==============================================================================================================
    # 3. compute the log likelihood log p(obs|latent).
    # sum over independent purchase decision for individual observations, mean over MC seeds.
    # the forward() function calls module.rsample(num_seeds) for module in self.additional_modules.
    # ==============================================================================================================
    if self.pred_item:
        # the prediction target is item_index.
        # (num_seeds, len(batch))
        log_prob = self.forward(batch,
                                return_type='log_prob',
                                return_scope='item_index',
                                deterministic=False,
                                sample_dict=sample_dict)
        # (num_seeds, len(batch)) --> scalar.
        elbo += log_prob.sum(dim=1).mean(dim=0)
    else:
        # the prediction target is binary.
        # TODO: update the prediction function.
        utility = self.forward(batch,
                               return_type='utility',
                               return_scope='item_index',
                               deterministic=False,
                               sample_dict=sample_dict)  # (num_seeds, len(batch))

        # replicate the binary labels once per seed so they line up with the utilities.
        # (num_seeds, len(batch))
        y_stacked = torch.stack([batch.label] * num_seeds).float()
        assert y_stacked.shape == utility.shape

        # Evaluate the Bernoulli log-likelihood directly on the logits. Applying sigmoid first
        # and then BCELoss saturates for large |utility| (sigmoid rounds to exactly 0 or 1 and
        # BCELoss clamps the log term), which distorts both the value and the gradient.
        nll = nn.functional.binary_cross_entropy_with_logits(
            utility, y_stacked, reduction='none')
        # (num_seeds, len(batch)) --> scalar.
        elbo += - nll.sum(dim=1).mean(dim=0)

    # ==============================================================================================================
    # 4. optionally subtract the log likelihood under variational distributions q(latent).
    # ==============================================================================================================
    if self.trace_log_q:
        elbo -= self.log_variational(sample_dict).mean(dim=0)

    return elbo
forward(self, batch, return_type, return_scope, deterministic=True, sample_dict=None, num_seeds=None)

A combined method for inference with the model.

Parameters:

Name Type Description Default
batch ChoiceDataset

batch data containing choice information.

required
return_type str

either 'log_prob' or 'utility'. 'log_prob': return the log-probability (by within-category log-softmax) for items 'utility': return the utility value of items.

required
return_scope str

either 'item_index' or 'all_items'. 'item_index': for each observation i, return log-prob/utility for the chosen item batch.item_index[i] only. 'all_items': for each observation i, return log-prob/utility for all items.

required
deterministic bool

True: expectations of parameter variational distributions are used for inference. False: the user needs to supply a dictionary of sampled parameters for inference. Defaults to True.

True
sample_dict Optional[Dict[str, torch.Tensor]]

sampled parameters for inference task. This is not needed when deterministic is True. When deterministic is False, the user can supply a sample_dict. If sample_dict is not provided, this method will create num_seeds samples. Defaults to None.

None
num_seeds Optional[int]

the number of random samples of parameters to construct. This is only required if deterministic is False (i.e., stochastic mode) and sample_dict is not provided. Defaults to None.

None

Returns:

Type Description
torch.Tensor

a tensor of log-probabilities or utilities, depending on return_type. The shape of the returned tensor depends on return_scope and deterministic:

| return_scope | deterministic | Output shape |
|--------------|---------------|--------------|
| 'item_index' | True | (len(batch),) |
| 'all_items' | True | (len(batch), num_items) |
| 'item_index' | False | (num_seeds, len(batch)) |
| 'all_items' | False | (num_seeds, len(batch), num_items) |

Source code in bemb/model/bemb.py
def forward(self, batch: ChoiceDataset,
            return_type: str,
            return_scope: str,
            deterministic: bool = True,
            sample_dict: Optional[Dict[str, torch.Tensor]] = None,
            num_seeds: Optional[int] = None
            ) -> torch.Tensor:
    """A combined method for inference with the model.

    NOTE: when `self.pred_item` is True, this method overwrites `batch.label` with
    `batch.item_index` on the batch object passed in.

    Args:
        batch (ChoiceDataset): batch data containing choice information.
        return_type (str): either 'log_prob' or 'utility'.
            'log_prob': return the log-probability (by within-category log-softmax) for items
            'utility': return the utility value of items.
        return_scope (str): either 'item_index' or 'all_items'.
            'item_index': for each observation i, return log-prob/utility for the chosen item batch.item_index[i] only.
            'all_items': for each observation i, return log-prob/utility for all items.
        deterministic (bool, optional):
            True: expectations of parameter variational distributions are used for inference.
            False: sampled parameters are used, either from `sample_dict` or freshly drawn.
            Defaults to True.
        sample_dict (Optional[Dict[str, torch.Tensor]], optional): sampled parameters for inference task.
            This is not needed when `deterministic` is True (any value passed is replaced).
            When `deterministic` is False, the user can supply a `sample_dict`. If `sample_dict` is not provided,
            this method will create `num_seeds` samples.
            Defaults to None.
        num_seeds (Optional[int]): the number of random samples of parameters to construct. This is only required
            if `deterministic` is False (i.e., stochastic mode) and `sample_dict` is not provided.
            When `sample_dict` is provided, the number of seeds is read from its first tensor instead.
            Defaults to None.

    Returns:
        torch.Tensor: a tensor of log-probabilities or utilities, depending on `return_type`.
            The shape of the returned tensor depends on `return_scope` and `deterministic`.
            -------------------------------------------------------------------------
            | `return_scope` | `deterministic` |         Output shape               |
            -------------------------------------------------------------------------
            |   'item_index' |      True       | (len(batch),)                      |
            -------------------------------------------------------------------------
            |   'all_items'  |      True       | (len(batch), num_items)            |
            -------------------------------------------------------------------------
            |   'item_index' |      False      | (num_seeds, len(batch))            |
            -------------------------------------------------------------------------
            |   'all_items'  |      False      | (num_seeds, len(batch), num_items) |
            -------------------------------------------------------------------------

    Raises:
        AssertionError: if `return_type` or `return_scope` is not one of the accepted values,
            or if stochastic mode is requested with neither `sample_dict` nor a positive `num_seeds`.
    """
    # ==============================================================================================================
    # check arguments.
    # ==============================================================================================================
    assert return_type in [
        'log_prob', 'utility'], "return_type must be either 'log_prob' or 'utility'."
    assert return_scope in [
        'item_index', 'all_items'], "return_scope must be either 'item_index' or 'all_items'."
    assert deterministic in [True, False]
    if (not deterministic) and (sample_dict is None):
        # check for None first so that a missing `num_seeds` triggers the message below
        # rather than a TypeError from comparing None with an integer.
        assert (num_seeds is not None) and (num_seeds >= 1), "A positive interger `num_seeds` is required if `deterministic` is False and no `sample_dict` is provided."

    # when pred_item is true, the model is predicting which item is bought (specified by item_index).
    if self.pred_item:
        batch.label = batch.item_index

    # ==============================================================================================================
    # get sample_dict ready.
    # ==============================================================================================================
    if deterministic:
        num_seeds = 1
        # Use the means of variational distributions as the sole deterministic MC sample.
        # NOTE: here we don't need to sample the obs2prior weight H since we only compute the log-likelihood.
        # TODO: is this correct?
        sample_dict = dict()
        for coef_name, coef in self.coef_dict.items():
            sample_dict[coef_name] = coef.variational_distribution.mean.unsqueeze(
                dim=0)  # (1, num_*, dim)
    elif sample_dict is None:
        # sample stochastic parameters.
        sample_dict = self.sample_coefficient_dictionary(num_seeds)
    else:
        # use the provided sample_dict; the leading dimension of its first tensor is the number of seeds.
        num_seeds = list(sample_dict.values())[0].shape[0]

    # ==============================================================================================================
    # call the sampling method of additional modules.
    # ==============================================================================================================
    for module in self.additional_modules:
        # deterministic sample.
        if deterministic:
            module.dsample()
        else:
            module.rsample(num_seeds=num_seeds)

    # if utility is requested, don't run log-softmax, simply return logit.
    return_logit = (return_type == 'utility')
    if return_scope == 'all_items':
        # (num_seeds, len(batch), num_items)
        out = self.log_likelihood_all_items(
            batch=batch, sample_dict=sample_dict, return_logit=return_logit)
    elif return_scope == 'item_index':
        # (num_seeds, len(batch))
        out = self.log_likelihood_item_index(
            batch=batch, sample_dict=sample_dict, return_logit=return_logit)

    if deterministic:
        # drop the first dimension, which has size of `num_seeds` (equals 1 in the deterministic case).
        # (len(batch), num_items) or (len(batch),)
        return out.squeeze(dim=0)

    return out
get_within_category_accuracy(self, log_p_all_items, label)

A helper function for computing prediction accuracy (i.e., all non-differential metrics) within category. In particular, this method calculates the accuracy, precision, recall and F1 score.

This method has the same functionality as the following pseudocode:

for C in categories:
    # get sessions in which item in category C was purchased.
    T <- (t for t in {0,1,..., len(label)-1} if label[t] is in C)
    Y <- label[T]

predictions = list()
for t in T:
    # get the prediction within category for this session.
    y_pred = argmax_{items in C} log prob computed before.
    predictions.append(y_pred)

accuracy = mean(Y == predictions)

Similarly, this function computes precision, recall and f1score as well.

Parameters:

Name Type Description Default
log_p_all_items torch.Tensor

shape (num_sessions, num_items) the log probability of choosing each item in each session.

required
label torch.LongTensor

shape (num_sessions,), the IDs of items purchased in each session.

required

Returns:

Type Description
[Dict[str, float]]

A dictionary containing performance metrics.

Source code in bemb/model/bemb.py
@torch.no_grad()
def get_within_category_accuracy(self, log_p_all_items: torch.Tensor, label: torch.LongTensor) -> Dict[str, float]:
    """A helper function for computing prediction metrics (i.e., all non-differentiable metrics)
    within category.
    In particular, this method calculates the accuracy, precision, recall and F1 score.

    This method has the same functionality as the following pseudocode:
    for C in categories:
        # get sessions in which item in category C was purchased.
        T <- (t for t in {0,1,..., len(label)-1} if label[t] is in C)
        Y <- label[T]

        predictions = list()
        for t in T:
            # get the prediction within category for this session.
            y_pred = argmax_{items in C} log prob computed before.
            predictions.append(y_pred)

        accuracy = mean(Y == predictions)

    Precision and recall are computed per item and then averaged: precision over the items that
    were predicted at least once, recall over the items that were purchased at least once.
    The F1 score is the harmonic mean of the two averaged values.

    Args:
        log_p_all_items (torch.Tensor): shape (num_sessions, num_items) the log probability of
            choosing each item in each session.
        label (torch.LongTensor): shape (num_sessions,), the IDs of items purchased in each session.

    Returns:
        Dict[str, float]: A dictionary with keys 'accuracy', 'precision', 'recall' and 'f1score'.
            Precision/recall are NaN when there is no item to average over (e.g., empty input).
    """
    # argmax: (num_sessions, num_categories), within category argmax.
    # the argmax is a position along the last dimension of log_p_all_items, i.e., an item ID.
    _, argmax_by_category = scatter_max(
        log_p_all_items, self.item_to_category_tensor, dim=-1)

    # category_purchased[t] = the category of item label[t].
    # (num_sessions,)
    category_purchased = self.item_to_category_tensor[label]

    # pred[t] = the item with highest log-probability from the category item label[t] belongs to.
    # (num_sessions,)
    pred_from_category = argmax_by_category[torch.arange(
        len(label)), category_purchased]

    within_category_accuracy = (
        pred_from_category == label).float().mean().item()

    precision_per_item = list()
    recall_per_item = list()
    for i in range(self.num_items):
        predicted_i = (pred_from_category == i)
        purchased_i = (label == i)
        correct_i = torch.logical_and(predicted_i, purchased_i).float().sum()

        # only divide when the denominator is non-zero, items never predicted (never purchased)
        # do not contribute to the average precision (recall).
        if torch.any(predicted_i):
            precision_per_item.append(
                (correct_i / predicted_i.float().sum()).cpu().item())
        if torch.any(purchased_i):
            recall_per_item.append(
                (correct_i / purchased_i.float().sum()).cpu().item())

    # np.mean of an empty list is NaN (with a RuntimeWarning); return NaN directly instead.
    precision = float(np.mean(precision_per_item)) if precision_per_item else float('nan')
    recall = float(np.mean(recall_per_item)) if recall_per_item else float('nan')

    if precision == recall == 0:
        # avoid 0 / 0; keep the value a float to honor the declared return type.
        f1 = 0.0
    else:
        f1 = 2 * precision * recall / (precision + recall)

    return {'accuracy': within_category_accuracy,
            'precision': precision,
            'recall': recall,
            'f1score': f1}
ivs(self, batch)

The combined method of computing utilities and log probability.

Parameters:

Name Type Description Default
batch dict

a batch of data.

required

Returns:

Type Description
torch.Tensor

the combined utility and log probability.

Source code in bemb/model/bemb.py
def ivs(self, batch) -> torch.Tensor:
    """Compute the log-sum-exp of item utilities within each category.

    Utilities are evaluated at the means of the variational distributions (a single,
    deterministic "sample") for all items, and then reduced with a log-sum-exp over the items
    that belong to the same category.
    NOTE(review): the method name suggests these are the "inclusive values" of each category;
    confirm against callers.

    Args:
        batch (dict): a batch of data.

    Returns:
        torch.Tensor: a tensor of shape (len(batch), num_categories).
    """
    # the variational means act as the sole Monte Carlo sample, each has shape (1, num_*, dim).
    posterior_means = {
        name: coefficient.variational_distribution.mean.unsqueeze(dim=0)
        for name, coefficient in self.coef_dict.items()
    }

    # utilities of all items with a leading seed dimension of size 1:
    # (num_seeds=1, len(batch), num_items)
    utility = self.log_likelihood_all_items(batch, return_logit=True, sample_dict=posterior_means)
    # drop the seed dimension: (len(batch), num_items)
    utility = utility.squeeze(0)

    # reduce over items of the same category: (len(batch), num_categories)
    return scatter_logsumexp(utility, self.item_to_category_tensor, dim=-1)
log_likelihood_all_items(self, batch, return_logit, sample_dict)

NOTE to developers: NOTE (akanodia to tianyudu): Is this really slow; even with log_likelihood you need log_prob which depends on logits of all items? This method computes utilities for all items available, which is a relatively slow operation. For training the model, you only need the utility/log-prob for the chosen/relevant item (i.e., item_index[i] for each i-th observation). Use this method for inference only. Use self.log_likelihood_item_index() for training instead.

Computes the log probability of choosing each item in each session based on current model parameters. NOTE (akanodiadu to tianyudu): What does the next line mean? I think it just says its allowing for samples instead of posterior mean. This method allows for specifying {user, item}_latent_value for Monte Carlo estimation in ELBO. For actual prediction tasks, use the forward() function, which will use means of variational distributions for user and item latents.

Parameters:

Name Type Description Default
batch ChoiceDataset

a ChoiceDataset object containing relevant information.

required
return_logit(bool)

if set to True, return the logit/utility, otherwise return the log-probability.

required
sample_dict(Dict[str, torch.Tensor]

Monte Carlo samples for model coefficients (i.e., those Greek letters). sample_dict.keys() should be the same as keys of self.obs2prior_dict, i.e., those greek letters actually enter the functional form of utility. The value of sample_dict should be tensors of shape (num_seeds, num_classes, dim) where num_classes in {num_users, num_items, 1} and dim in {latent_dim(K), num_item_obs, num_user_obs, 1}.

required

Returns:

Type Description
torch.Tensor

a tensor of shape (num_seeds, len(batch), self.num_items), where out[x, y, z] is the log probability (or the utility, if return_logit is True) of choosing item z in session y conditioned on latents to be the x-th Monte Carlo sample.

Source code in bemb/model/bemb.py
def log_likelihood_all_items(self, batch: ChoiceDataset, return_logit: bool, sample_dict: Dict[str, torch.Tensor]) -> torch.Tensor:
    """
    NOTE to developers:
    NOTE (akanodia to tianyudu): Is this really slow; even with log_likelihood you need log_prob which depends on logits of all items?
    This method computes utilities for all items available, which is a relatively slow operation. For
    training the model, you only need the utility/log-prob for the chosen/relevant item (i.e., item_index[i] for each i-th observation).
    Use this method for inference only.
    Use self.log_likelihood_item_index() for training instead.

    Computes the utility or the log probability of choosing `each` item in each session based on
    the provided Monte Carlo samples of coefficients.
    NOTE (akanodiadu to tianyudu): What does the next line mean? I think it just says its allowing for samples instead of posterior mean.
    This method allows for specifying {user, item}_latent_value for Monte Carlo estimation in ELBO.
    For actual prediction tasks, use the forward() function, which will use means of variational
    distributions for user and item latents.

    Args:
        batch (ChoiceDataset): a ChoiceDataset object containing relevant information.
        return_logit(bool): if set to True, return the logit/utility, otherwise return the log-probability.
        sample_dict(Dict[str, torch.Tensor]): Monte Carlo samples for model coefficients
            (i.e., those Greek letters).
            sample_dict.keys() should be the same as keys of self.obs2prior_dict, i.e., those
            greek letters actually enter the functional form of utility.
            The value of sample_dict should be tensors of shape (num_seeds, num_classes, dim)
            where num_classes in {num_users, num_items, 1}
            and dim in {latent_dim(K), num_item_obs, num_user_obs, 1}.

    Raises:
        ValueError: if a coefficient/observable name has an unsupported suffix/prefix, or a term
            in self.formula does not have one or two coefficients.

    Returns:
        torch.Tensor: a tensor of shape (num_seeds, len(batch), self.num_items), where
            out[x, y, z] is the utility (if return_logit) or the log probability (otherwise) of
            choosing item z in session y conditioned on latents to be the x-th Monte Carlo sample.
            When self.pred_item is False, the log probability is the log-sigmoid of the utility.
    """
    num_seeds = next(iter(sample_dict.values())).shape[0]

    # avoid repeated work when user purchased several items in the same session.
    user_session_index = torch.stack(
        [batch.user_index, batch.session_index])
    assert user_session_index.shape == (2, len(batch))
    unique_user_sess, inverse_indices = torch.unique(
        user_session_index, dim=1, return_inverse=True)

    user_index = unique_user_sess[0, :]
    session_index = unique_user_sess[1, :]
    assert len(user_index) == len(session_index)

    # short-hands for easier shape check.
    R = num_seeds
    # P is the number of unique (user, session) pairs, not len(batch).
    P = unique_user_sess.shape[1]
    S = self.num_sessions
    U = self.num_users
    I = self.num_items

    # ==============================================================================================================
    # Helper Functions for Reshaping.
    # ==============================================================================================================
    def reshape_user_coef_sample(C):
        # input shape (R, U, *)
        C = C.view(R, U, 1, -1).expand(-1, -1, I, -1)  # (R, U, I, *)
        C = C[:, user_index, :, :]
        assert C.shape == (R, P, I, positive_integer)
        return C

    def reshape_item_coef_sample(C):
        # input shape (R, I, *)
        C = C.view(R, 1, I, -1).expand(-1, P, -1, -1)
        assert C.shape == (R, P, I, positive_integer)
        return C

    def reshape_category_coef_sample(C):
        # input shape (R, NC, *)
        # give every item the coefficient of its own category by looking the category up directly;
        # unlike repeating each category by its size, this does not rely on item IDs being sorted
        # by category (and is identical when they are).
        C = C[:, self.item_to_category_tensor]  # (R, I, *)
        C = C.view(R, 1, I, -1).expand(-1, P, -1, -1)
        assert C.shape == (R, P, I, positive_integer)
        return C

    def reshape_constant_coef_sample(C):
        # input shape (R, *)
        C = C.view(R, 1, 1, -1).expand(-1, P, I, -1)
        assert C.shape == (R, P, I, positive_integer)
        return C

    def reshape_coef_sample(sample, name):
        # reshape the monte carlo sample of coefficients to (R, P, I, *).
        if name.endswith('_user'):
            # (R, U, *) --> (R, P, I, *)
            return reshape_user_coef_sample(sample)
        elif name.endswith('_item'):
            # (R, I, *) --> (R, P, I, *)
            return reshape_item_coef_sample(sample)
        elif name.endswith('_category'):
            # (R, NC, *) --> (R, P, I, *)
            return reshape_category_coef_sample(sample)
        elif name.endswith('_constant'):
            # (R, *) --> (R, P, I, *)
            return reshape_constant_coef_sample(sample)
        else:
            raise ValueError(
                f'Unsupported coefficient name: {name}, it should end with _user, _item, _category or _constant.')

    def reshape_observable(obs, name):
        # reshape observable to (R, P, I, *) so that it can be multiplied with monte carlo
        # samples of coefficients.
        O = obs.shape[-1]  # number of observables.
        assert O == positive_integer
        if name.startswith('item_'):
            assert obs.shape == (I, O)
            obs = obs.view(1, 1, I, O).expand(R, P, -1, -1)
        elif name.startswith('user_'):
            assert obs.shape == (U, O)
            obs = obs[user_index, :]  # (P, O)
            obs = obs.view(1, P, 1, O).expand(R, -1, I, -1)
        elif name.startswith('session_'):
            assert obs.shape == (S, O)
            obs = obs[session_index, :]  # (P, O)
            obs = obs.view(1, P, 1, O).expand(R, -1, I, -1)
        elif name.startswith('price_'):
            assert obs.shape == (S, I, O)
            obs = obs[session_index, :, :]  # (P, I, O)
            obs = obs.view(1, P, I, O).expand(R, -1, -1, -1)
        elif name.startswith('taste_'):
            assert obs.shape == (U, I, O)
            obs = obs[user_index, :, :]  # (P, I, O)
            obs = obs.view(1, P, I, O).expand(R, -1, -1, -1)
        else:
            raise ValueError(
                f'Unsupported observable name: {name}, it should start with item_, user_, session_, price_ or taste_.')
        # every branch falls through to the same final shape check.
        assert obs.shape == (R, P, I, O)
        return obs

    # ==============================================================================================================
    # Compute the Utility Term by Term.
    # ==============================================================================================================
    # P is the number of unique (user, session) pairs.
    # (random_seeds, P, num_items).
    utility = torch.zeros(R, P, I, device=self.device)

    # loop over additive term to utility
    for term in self.formula:
        # Type I: single coefficient, e.g., lambda_item or lambda_user.
        if len(term['coefficient']) == 1 and term['observable'] is None:
            # E.g., lambda_item or lambda_user
            coef_name = term['coefficient'][0]
            coef_sample = reshape_coef_sample(
                sample_dict[coef_name], coef_name)
            assert coef_sample.shape == (R, P, I, 1)
            additive_term = coef_sample.view(R, P, I)

        # Type II: factorized coefficient, e.g., <theta_user, lambda_item>.
        elif len(term['coefficient']) == 2 and term['observable'] is None:
            coef_name_0 = term['coefficient'][0]
            coef_name_1 = term['coefficient'][1]

            coef_sample_0 = reshape_coef_sample(
                sample_dict[coef_name_0], coef_name_0)
            coef_sample_1 = reshape_coef_sample(
                sample_dict[coef_name_1], coef_name_1)

            assert coef_sample_0.shape == coef_sample_1.shape == (
                R, P, I, positive_integer)

            # inner product over the latent dimension.
            additive_term = (coef_sample_0 * coef_sample_1).sum(dim=-1)

        # Type III: single coefficient multiplied by observable, e.g., theta_user * x_obs_item.
        elif len(term['coefficient']) == 1 and term['observable'] is not None:
            coef_name = term['coefficient'][0]
            coef_sample = reshape_coef_sample(
                sample_dict[coef_name], coef_name)
            assert coef_sample.shape == (R, P, I, positive_integer)

            obs_name = term['observable']
            obs = reshape_observable(getattr(batch, obs_name), obs_name)
            assert obs.shape == (R, P, I, positive_integer)

            additive_term = (coef_sample * obs).sum(dim=-1)

        # Type IV: factorized coefficient multiplied by observable.
        # e.g., gamma_user * beta_item * price_obs.
        elif len(term['coefficient']) == 2 and term['observable'] is not None:
            coef_name_0, coef_name_1 = term['coefficient'][0], term['coefficient'][1]

            coef_sample_0 = reshape_coef_sample(
                sample_dict[coef_name_0], coef_name_0)
            coef_sample_1 = reshape_coef_sample(
                sample_dict[coef_name_1], coef_name_1)
            assert coef_sample_0.shape == coef_sample_1.shape == (
                R, P, I, positive_integer)
            num_obs_times_latent_dim = coef_sample_0.shape[-1]

            obs_name = term['observable']
            obs = reshape_observable(getattr(batch, obs_name), obs_name)
            assert obs.shape == (R, P, I, positive_integer)
            num_obs = obs.shape[-1]  # number of observables.

            # the last dimension of each coefficient packs (num_obs, latent_dim).
            assert (num_obs_times_latent_dim % num_obs) == 0
            latent_dim = num_obs_times_latent_dim // num_obs

            coef_sample_0 = coef_sample_0.view(
                R, P, I, num_obs, latent_dim)
            coef_sample_1 = coef_sample_1.view(
                R, P, I, num_obs, latent_dim)
            # compute the factorized coefficient with shape (R, P, I, O).
            coef = (coef_sample_0 * coef_sample_1).sum(dim=-1)

            additive_term = (coef * obs).sum(dim=-1)

        else:
            raise ValueError(f'Undefined term type: {term}')

        assert additive_term.shape == (R, P, I)
        utility += additive_term

    # ==============================================================================================================
    # Mask Out Unavailable Items in Each Session.
    # ==============================================================================================================

    if batch.item_availability is not None:
        # expand to the Monte Carlo sample dimension.
        # (S, I) -> (P, I) -> (1, P, I) -> (R, P, I)
        A = batch.item_availability[session_index, :].unsqueeze(
            dim=0).expand(R, -1, -1)
        # use a very negative (but finite) utility so that unavailable items get ~0 probability.
        utility[~A] = - (torch.finfo(utility.dtype).max / 2)

    # map utilities of unique (user, session) pairs back to rows of the batch.
    utility = utility[:, inverse_indices, :]
    assert utility.shape == (R, len(batch), I)

    for module in self.additional_modules:
        additive_term = module(batch)
        if additive_term.shape == (R, len(batch)):
            # be tolerant with modules returning (R, len(batch)), consistent with
            # log_likelihood_item_index.
            additive_term = additive_term.unsqueeze(dim=-1)
        assert additive_term.shape == (R, len(batch), 1)
        # the same term is added to the utility of every item.
        utility += additive_term.expand(-1, -1, I)

    if return_logit:
        # output shape: (num_seeds, len(batch), num_items)
        return utility
    else:
        # compute log likelihood log p(choosing item i | user, item latents)
        # compute log softmax separately within each category.
        if self.pred_item:
            # output shape: (num_seeds, len(batch), num_items)
            log_p = scatter_log_softmax(
                utility, self.item_to_category_tensor, dim=-1)
        else:
            log_p = torch.nn.functional.logsigmoid(utility)
        return log_p
log_likelihood_item_index(self, batch, return_logit, sample_dict)

NOTE for developers: This method is more efficient and only computes the log-likelihood/logit (utility) for the item in item_index[i] for each i-th observation. Developers should use log_likelihood_all_items for inference purposes and to compute log-likelihoods/utilities for ALL items for the i-th observation.

Computes the log probability of choosing item_index[i] in each session based on current model parameters. This method allows for specifying {user, item}_latent_value for Monte Carlo estimation in ELBO. For actual prediction tasks, use the forward() function, which will use means of variational distributions for user and item latents.

Parameters:

Name Type Description Default
batch ChoiceDataset

a ChoiceDataset object containing relevant information.

required
return_logit(bool)

if set to True, return the logit/utility, otherwise return the log-probability.

required
sample_dict(Dict[str, torch.Tensor]

Monte Carlo samples for model coefficients (i.e., those Greek letters). sample_dict.keys() should be the same as keys of self.obs2prior_dict, i.e., those greek letters actually enter the functional form of utility. The value of sample_dict should be tensors of shape (num_seeds, num_classes, dim) where num_classes in {num_users, num_items, 1} and dim in {latent_dim(K), num_item_obs, num_user_obs, 1}.

required

Returns:

Type Description
torch.Tensor

a tensor of shape (num_seeds, len(batch)), where out[x, y] is the probabilities of choosing item batch.item[y] in session y conditioned on latents to be the x-th Monte Carlo sample.

Source code in bemb/model/bemb.py
def log_likelihood_item_index(self, batch: ChoiceDataset, return_logit: bool, sample_dict: Dict[str, torch.Tensor]) -> torch.Tensor:
    """
    NOTE for developers:
    This method is more efficient and only computes utilities for items in the same category as
    item_index[i] for each i-th observation.
    Developers should use `log_likelihood_all_items` for inference purposes and to compute
    log-likelihoods/utilities for ALL items for the i-th observation.

    Computes the log probability of choosing item_index[i] in each session based on the provided
    Monte Carlo samples of coefficients.
    This method allows for specifying {user, item}_latent_value for Monte Carlo estimation in ELBO.
    For actual prediction tasks, use the forward() function, which will use means of variational
    distributions for user and item latents.

    Args:
        batch (ChoiceDataset): a ChoiceDataset object containing relevant information.
        return_logit(bool): if set to True, return the logit/utility, otherwise return the log-probability.
        sample_dict(Dict[str, torch.Tensor]): Monte Carlo samples for model coefficients
            (i.e., those Greek letters).
            sample_dict.keys() should be the same as keys of self.obs2prior_dict, i.e., those
            greek letters actually enter the functional form of utility.
            The value of sample_dict should be tensors of shape (num_seeds, num_classes, dim)
            where num_classes in {num_users, num_items, 1}
            and dim in {latent_dim(K), num_item_obs, num_user_obs, 1}.

    Raises:
        ValueError: if a coefficient/observable name has an unsupported suffix/prefix, or a term
            in self.formula does not have one or two coefficients.

    Returns:
        torch.Tensor: if return_logit is False and self.pred_item is True, the log probabilities
            of the purchased items; out[x, y] is the log probability of choosing item
            batch.item_index[y] conditioned on latents to be the x-th Monte Carlo sample
            (presumably shape (num_seeds, len(batch)) since each purchased item appears once among
            the items of its own category).
            If return_logit is True, the utilities of all relevant items are returned instead, with
            shape (num_seeds, total_computation), where total_computation is the sum of category
            sizes of the purchased items.
            If self.pred_item is False, the negative binary cross entropy between
            sigmoid(utility) (flattened) and batch.label is returned.
    """
    num_seeds = next(iter(sample_dict.values())).shape[0]

    # get category id of the item bought in each row of batch.
    cate_index = self.item_to_category_tensor[batch.item_index]

    # get item ids of all items from the same category of each item bought.
    relevant_item_index = self.category_to_item_tensor[cate_index, :]
    relevant_item_index = relevant_item_index.view(-1,)
    # index were padded with -1's, drop those dummy entries.
    relevant_item_index = relevant_item_index[relevant_item_index != -1]

    # the first repeats[0] entries in relevant_item_index are for the category of item_index[0]
    repeats = self.category_to_size_tensor[cate_index]
    # argwhere(reverse_indices == k) are positions in relevant_item_index for the category of item_index[k].
    reverse_indices = torch.repeat_interleave(
        torch.arange(len(batch), device=self.device), repeats)
    # expand the user_index and session_index.
    user_index = torch.repeat_interleave(batch.user_index, repeats)
    repeat_category_index = torch.repeat_interleave(cate_index, repeats)
    session_index = torch.repeat_interleave(batch.session_index, repeats)
    # duplicate the item focused to match.
    item_index_expanded = torch.repeat_interleave(
        batch.item_index, repeats)

    # short-hands for easier shape check.
    R = num_seeds
    # total number of relevant items.
    total_computation = len(session_index)
    S = self.num_sessions
    U = self.num_users
    I = self.num_items
    # ==========================================================================================
    # Helper Functions for Reshaping.
    # ==========================================================================================

    def reshape_coef_sample(sample, name):
        # reshape the monte carlo sample of coefficients to (R, total_computation, *).
        if name.endswith('_user'):
            # (R, U, *) --> (R, total_computation, *)
            return sample[:, user_index, :]
        elif name.endswith('_item'):
            # (R, I, *) --> (R, total_computation, *)
            return sample[:, relevant_item_index, :]
        elif name.endswith('_category'):
            # (R, NC, *) --> (R, total_computation, *)
            return sample[:, repeat_category_index, :]
        elif name.endswith('_constant'):
            # (R, *) --> (R, total_computation, *)
            return sample.view(R, 1, -1).expand(-1, total_computation, -1)
        else:
            raise ValueError(
                f'Unsupported coefficient name: {name}, it should end with _user, _item, _category or _constant.')

    def reshape_observable(obs, name):
        # reshape observable to (R, total_computation, *) so that it can be multiplied with
        # monte carlo samples of coefficients.
        O = obs.shape[-1]  # number of observables.
        assert O == positive_integer
        if name.startswith('item_'):
            assert obs.shape == (I, O)
            obs = obs[relevant_item_index, :]
        elif name.startswith('user_'):
            assert obs.shape == (U, O)
            obs = obs[user_index, :]
        elif name.startswith('session_'):
            assert obs.shape == (S, O)
            obs = obs[session_index, :]
        elif name.startswith('price_'):
            assert obs.shape == (S, I, O)
            obs = obs[session_index, relevant_item_index, :]
        elif name.startswith('taste_'):
            assert obs.shape == (U, I, O)
            obs = obs[user_index, relevant_item_index, :]
        else:
            raise ValueError(
                f'Unsupported observable name: {name}, it should start with item_, user_, session_, price_ or taste_.')
        assert obs.shape == (total_computation, O)
        return obs.unsqueeze(dim=0).expand(R, -1, -1)

    # ==========================================================================================
    # Compute the Utility Term by Term.
    # ==========================================================================================
    utility = torch.zeros(R, total_computation, device=self.device)

    # loop over additive term to utility
    for term in self.formula:
        # Type I: single coefficient, e.g., lambda_item or lambda_user.
        if len(term['coefficient']) == 1 and term['observable'] is None:
            # E.g., lambda_item or lambda_user
            coef_name = term['coefficient'][0]
            coef_sample = reshape_coef_sample(
                sample_dict[coef_name], coef_name)
            assert coef_sample.shape == (R, total_computation, 1)
            additive_term = coef_sample.view(R, total_computation)

        # Type II: factorized coefficient, e.g., <theta_user, lambda_item>.
        elif len(term['coefficient']) == 2 and term['observable'] is None:
            coef_name_0 = term['coefficient'][0]
            coef_name_1 = term['coefficient'][1]

            coef_sample_0 = reshape_coef_sample(
                sample_dict[coef_name_0], coef_name_0)
            coef_sample_1 = reshape_coef_sample(
                sample_dict[coef_name_1], coef_name_1)

            assert coef_sample_0.shape == coef_sample_1.shape == (
                R, total_computation, positive_integer)

            # inner product over the latent dimension.
            additive_term = (coef_sample_0 * coef_sample_1).sum(dim=-1)

        # Type III: single coefficient multiplied by observable, e.g., theta_user * x_obs_item.
        elif len(term['coefficient']) == 1 and term['observable'] is not None:
            coef_name = term['coefficient'][0]
            coef_sample = reshape_coef_sample(
                sample_dict[coef_name], coef_name)
            assert coef_sample.shape == (
                R, total_computation, positive_integer)

            obs_name = term['observable']
            obs = reshape_observable(getattr(batch, obs_name), obs_name)
            assert obs.shape == (R, total_computation, positive_integer)

            additive_term = (coef_sample * obs).sum(dim=-1)

        # Type IV: factorized coefficient multiplied by observable.
        # e.g., gamma_user * beta_item * price_obs.
        elif len(term['coefficient']) == 2 and term['observable'] is not None:
            coef_name_0, coef_name_1 = term['coefficient'][0], term['coefficient'][1]
            coef_sample_0 = reshape_coef_sample(
                sample_dict[coef_name_0], coef_name_0)
            coef_sample_1 = reshape_coef_sample(
                sample_dict[coef_name_1], coef_name_1)
            assert coef_sample_0.shape == coef_sample_1.shape == (
                R, total_computation, positive_integer)
            num_obs_times_latent_dim = coef_sample_0.shape[-1]

            obs_name = term['observable']
            obs = reshape_observable(getattr(batch, obs_name), obs_name)
            assert obs.shape == (R, total_computation, positive_integer)
            num_obs = obs.shape[-1]  # number of observables.

            # the last dimension of each coefficient packs (num_obs, latent_dim).
            assert (num_obs_times_latent_dim % num_obs) == 0
            latent_dim = num_obs_times_latent_dim // num_obs

            coef_sample_0 = coef_sample_0.view(
                R, total_computation, num_obs, latent_dim)
            coef_sample_1 = coef_sample_1.view(
                R, total_computation, num_obs, latent_dim)
            # compute the factorized coefficient with shape (R, total_computation, O).
            coef = (coef_sample_0 * coef_sample_1).sum(dim=-1)

            additive_term = (coef * obs).sum(dim=-1)

        else:
            raise ValueError(f'Undefined term type: {term}')

        assert additive_term.shape == (R, total_computation)
        utility += additive_term

    # ==========================================================================================
    # Mask Out Unavailable Items in Each Session.
    # ==========================================================================================

    if batch.item_availability is not None:
        # expand to the Monte Carlo sample dimension.
        A = batch.item_availability[session_index, relevant_item_index].unsqueeze(
            dim=0).expand(R, -1)
        # use a very negative (but finite) utility so that unavailable items get ~0 probability.
        utility[~A] = - (torch.finfo(utility.dtype).max / 2)

    for module in self.additional_modules:
        # current utility shape: (R, total_computation)
        additive_term = module(batch)
        assert additive_term.shape == (
            R, len(batch)) or additive_term.shape == (R, len(batch), 1)
        if additive_term.shape == (R, len(batch), 1):
            # be tolerant for some customized module with BayesianLinear that returns (R, len(batch), 1).
            additive_term = additive_term.view(R, len(batch))
        # expand to total number of computation, query by reverse_indices.
        # reverse_indices has length total_computation, and reverse_indices[i] correspond to the row-id that this
        # computation is responsible for.
        additive_term = additive_term[:, reverse_indices]
        assert additive_term.shape == (R, total_computation)
        # actually add the module's contribution to the utility (as log_likelihood_all_items does),
        # otherwise additional modules would be silently ignored.
        utility += additive_term

    # compute log likelihood log p(choosing item i | user, item latents)
    if return_logit:
        log_p = utility
    else:
        if self.pred_item:
            # compute the log probability from logits/utilities, the softmax is taken over the
            # relevant items of each observation separately.
            log_p = scatter_log_softmax(utility, reverse_indices, dim=-1)
            # select the log-P of the item actually bought.
            log_p = log_p[:, item_index_expanded == relevant_item_index]
        else:
            # This is the binomial choice situation in which case we just report sigmoid log likelihood
            bce = nn.BCELoss(reduction='none')
            log_p = - bce(torch.sigmoid(utility.view(-1)), batch.label.to(torch.float32))
    return log_p
log_prior(self, batch, sample_dict)

Calculates the log-likelihood of Monte Carlo samples of Bayesian coefficients under their prior distribution. This method assume coefficients are statistically independent.

Parameters:

Name Type Description Default
batch ChoiceDataset

a dataset object contains observables for computing the prior distribution if obs2prior is True.

required
sample_dict Dict[str, torch.Tensor]

a dictionary mapping coefficient names to Monte Carlo samples.

required

Exceptions:

Type Description
ValueError

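raised when obs2prior is enabled for a coefficient whose name ends with neither `_item` nor `_user`, so no observable is available for its prior.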

Returns:

Type Description
torch.scalar_tensor

a tensor with shape (num_seeds,) of [ log P_{prior_distribution}(param[i]) ], where param[i] is the i-th Monte Carlo sample.

Source code in bemb/model/bemb.py
def log_prior(self, batch: ChoiceDataset, sample_dict: Dict[str, torch.Tensor]) -> torch.Tensor:
    """Calculates the log-likelihood of Monte Carlo samples of Bayesian coefficients under their
    prior distribution. This method assumes coefficients are statistically independent.

    Args:
        batch (ChoiceDataset): a dataset object containing observables for computing the prior distribution
            if obs2prior is True. Only `batch.item_obs` / `batch.user_obs` are read here.
        sample_dict (Dict[str, torch.Tensor]): a dictionary mapping coefficient names to Monte Carlo samples.
            For a coefficient with obs2prior enabled, the entry `coef_name + '.H'` is read as well.

    Raises:
        ValueError: if obs2prior is enabled for a coefficient whose name ends with neither `_item`
            nor `_user`, so no observable can be selected for it.
        NotImplementedError: if the model has any additional modules; their priors are not supported.

    Returns:
        torch.Tensor: a tensor with shape (num_seeds,) of [ log P_{prior_distribution}(param[i]) ],
            where param[i] is the i-th Monte Carlo sample.
    """
    # the leading dimension of every sample is the number of Monte Carlo seeds.
    num_seeds = next(iter(sample_dict.values())).shape[0]

    total = torch.zeros(num_seeds, device=self.device)

    for coef_name, coef in self.coef_dict.items():
        if self.obs2prior_dict[coef_name]:
            # pick the observable that drives the prior mean, based on the coefficient's suffix.
            if coef_name.endswith('_item'):
                x_obs = batch.item_obs
            elif coef_name.endswith('_user'):
                x_obs = batch.user_obs
            else:
                raise ValueError(
                    f'No observable found to support obs2prior for {coef_name}.')

            total += coef.log_prior(sample=sample_dict[coef_name],
                                    H_sample=sample_dict[coef_name + '.H'],
                                    x_obs=x_obs).sum(dim=-1)
        else:
            # log_prob outputs (num_seeds, num_{items, users}), sum to (num_seeds).
            total += coef.log_prior(
                sample=sample_dict[coef_name], H_sample=None, x_obs=None).sum(dim=-1)

    # priors of additional modules are not implemented; fail loudly rather than silently ignore them.
    if len(self.additional_modules) > 0:
        raise NotImplementedError()

    return total
log_variational(self, sample_dict)

Calculate the log-likelihood of samples in sample_dict under the current variational distribution.

Parameters:

Name Type Description Default
sample_dict Dict[str, torch.Tensor]

a dictionary mapping coefficient names to Monte Carlo samples.

required

Returns:

Type Description
torch.Tensor

a tensor of shape (num_seeds) of [ log P_{variational_distribution}(param[i]) ], where param[i] is the i-th Monte Carlo sample.

Source code in bemb/model/bemb.py
def log_variational(self, sample_dict: Dict[str, torch.Tensor]) -> torch.Tensor:
    """Calculate the log-likelihood of samples in sample_dict under the current variational
    distribution.

    Args:
        sample_dict (Dict[str, torch.Tensor]): a dictionary mapping coefficient names to Monte Carlo
            samples. Only the entries keyed by names in `self.coef_dict` are read.

    Raises:
        NotImplementedError: if the model has any additional modules; their variational
            distributions are not supported.

    Returns:
        torch.Tensor: a tensor of shape (num_seeds) of [ log P_{variational_distribution}(param[i]) ],
            where param[i] is the i-th Monte Carlo sample.
    """
    # the leading dimension of every sample is the number of Monte Carlo seeds.
    num_seeds = list(sample_dict.values())[0].shape[0]
    total = torch.zeros(num_seeds, device=self.device)

    for coef_name, coef in self.coef_dict.items():
        # log_prob outputs (num_seeds, num_{items, users}), sum to (num_seeds).
        total += coef.log_variational(sample_dict[coef_name]).sum(dim=-1)

    # additional modules are not supported yet; fail loudly rather than silently ignore them.
    if len(self.additional_modules) > 0:
        raise NotImplementedError()

    return total
posterior_mean(self, coef_name)

Returns the mean of estimated posterior distribution of coefficient coef_name.

Parameters:

Name Type Description Default
coef_name str

name of the coefficient to query.

required

Returns:

Type Description
torch.Tensor

mean of the estimated posterior distribution of coef_name.

Source code in bemb/model/bemb.py
def posterior_mean(self, coef_name: str) -> torch.Tensor:
    """Look up the variational mean stored for the coefficient named `coef_name`.

    Args:
        coef_name (str): name of the coefficient to query; must be a key of `self.coef_dict`.

    Raises:
        KeyError: if `coef_name` is not one of the model's coefficients.

    Returns:
        torch.Tensor: the `variational_mean` attribute of the requested coefficient.
    """
    # guard clause: reject unknown names with a message that shows the utility formula.
    if coef_name not in self.coef_dict:
        raise KeyError(f'{coef_name} is not a valid coefficient name in {self.utility_formula}.')
    coefficient = self.coef_dict[coef_name]
    return coefficient.variational_mean
sample_choices(self, batch, debug=False, num_seeds=1, **kwargs)

Samples choices given model parameters and trips

batch(ChoiceDataset): batch data containing trip information; item choice information is discarded debug(bool): whether to print debug information

Tuple[torch.Tensor]: sampled choices; shape: (batch_size, num_categories)

Source code in bemb/model/bemb.py
def sample_choices(self, batch:ChoiceDataset, debug: bool = False, num_seeds: int = 1, **kwargs) -> Tuple[torch.Tensor]:
    """Samples choices given model parameters and trips.

    The mean of each coefficient's variational distribution is used as the single
    Monte Carlo sample; no coefficients are drawn at random here.

    Args:
        batch (ChoiceDataset): batch data containing trip information; item choice information is discarded.
        debug (bool): accepted for interface compatibility; not used in this method.
        num_seeds (int): accepted for interface compatibility; not used in this method.
        **kwargs: accepted for interface compatibility; not used in this method.

    Returns:
        Tuple[torch.Tensor]: the two tensors returned by `self.sample_log_likelihoods`, each with
            its size-one dimensions squeezed out.
    """
    # one "sample" per coefficient: its variational mean with a leading seed dimension -> (1, num_*, dim)
    mean_samples = {
        name: coefficient.variational_distribution.mean.unsqueeze(dim=0)
        for name, coefficient in self.coef_dict.items()
    }
    per_category_max, per_category_argmax = self.sample_log_likelihoods(batch, mean_samples)
    return per_category_max.squeeze(), per_category_argmax.squeeze()
sample_coefficient_dictionary(self, num_seeds)

A helper function to sample parameters from coefficients.

Parameters:

Name Type Description Default
num_seeds int

number of random samples.

required

Returns:

Type Description
Dict[str, torch.Tensor]

a dictionary maps coefficient names to tensor of sampled coefficient parameters, where the first dimension of the sampled tensor has size num_seeds. Each sample tensor has shape (num_seeds, num_classes, dim).

Source code in bemb/model/bemb.py
def sample_coefficient_dictionary(self, num_seeds: int) -> Dict[str, torch.Tensor]:
    """Draw `num_seeds` samples from every coefficient and collect them by name.

    Args:
        num_seeds (int): number of random samples, passed to each coefficient's `rsample`.

    Returns:
        Dict[str, torch.Tensor]: a dictionary mapping each coefficient name to its sampled tensor.
            For a coefficient with `obs2prior` set, `rsample` must return a pair; the first element
            is stored under the coefficient name and the second under `name + '.H'`.
            Each sample tensor has shape (num_seeds, num_classes, dim).
    """
    draws = {}
    for name, coefficient in self.coef_dict.items():
        drawn = coefficient.rsample(num_seeds)
        if not coefficient.obs2prior:
            # plain coefficient: rsample yields just the realization of the variable.
            assert torch.is_tensor(drawn)
            draws[name] = drawn
            continue
        # obs2prior coefficient: rsample yields (realization, obs2prior weight H).
        assert isinstance(drawn, tuple) and len(drawn) == 2
        draws[name], draws[name + '.H'] = drawn
    return draws
sample_log_likelihoods(self, batch, sample_dict)

Samples log likelihoods given model parameters and trips

batch(ChoiceDataset): batch data containing trip information; item choice information is discarded sample_dict(Dict[str, torch.Tensor]): sampled coefficient values

Tuple[torch.Tensor]: sampled log likelihoods; shape: (batch_size, num_categories)

Source code in bemb/model/bemb.py
def sample_log_likelihoods(self, batch: ChoiceDataset, sample_dict: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
    """Perturbs the utilities of all items with centered Gumbel noise and takes the maximum
    within each category.

    Args:
        batch (ChoiceDataset): batch data containing trip information; item choice information is discarded.
        sample_dict (Dict[str, torch.Tensor]): sampled coefficient values, forwarded to
            `self.log_likelihood_all_items`.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: the pair returned by `scatter_max` over the last dimension
            of the perturbed utilities, grouped by `self.item_to_category_tensor`: the per-category
            maximum and the per-category argmax.
    """
    # utilities (logits) of all items; return_logit=True skips the log-softmax.
    utility = self.log_likelihood_all_items(batch, return_logit=True, sample_dict=sample_dict)

    # The mean of Gumbel(mu, beta) is mu + beta * (Euler-Mascheroni constant); it is subtracted
    # below so that the added noise is centered at zero.
    mu_gumbel = 0.0
    beta_gumbel = 1.0
    EUL_MAS_CONST = 0.5772156649
    mean_gumbel = torch.tensor([mu_gumbel + beta_gumbel * EUL_MAS_CONST], device=self.device)
    m = torch.distributions.gumbel.Gumbel(torch.tensor([mu_gumbel], device=self.device),
                                          torch.tensor([beta_gumbel], device=self.device))
    # the distribution has batch shape (1,), so samples carry a trailing singleton dimension.
    gumbel_samples = m.sample(utility.shape).squeeze(-1)
    gumbel_samples -= mean_gumbel
    # NOTE: in-place update, the tensor returned by log_likelihood_all_items is modified.
    utility += gumbel_samples
    max_by_category, argmax_by_category = scatter_max(utility, self.item_to_category_tensor, dim=-1)
    return max_by_category, argmax_by_category

parse_utility(utility_string)

A helper function parse utility string into a list of additive terms.

Examples:

utility_string = 'lambda_item + theta_user * alpha_item + gamma_user * beta_item * price_obs' output = [ { 'coefficient': ['lambda_item'], 'observable': None }, { 'coefficient': ['theta_user', 'alpha_item'], 'observable': None }, { 'coefficient': ['gamma_user', 'beta_item'], 'observable': 'price_obs' } ]

Source code in bemb/model/bemb.py
def parse_utility(utility_string: str) -> List[Dict[str, Union[List[str], None]]]:
    """
    A helper function that parses a utility string into a list of additive terms.

    Additive terms are separated by `+` and the factors within a term by `*`; whitespace around
    either operator and around the names is optional and ignored. A factor whose name ends with
    `_item`, `_user`, `_constant` or `_category` is a coefficient; otherwise a factor whose name
    starts with `item_`, `user_`, `session_`, `price_` or `taste_` is an observable. If a term
    lists several observables, the last one is kept.

    Example:
        utility_string = 'lambda_item + theta_user * alpha_item + gamma_user * beta_item * price_obs'
        output = [
            {
                'coefficient': ['lambda_item'],
                'observable': None
            },
            {
                'coefficient': ['theta_user', 'alpha_item'],
                'observable': None
            },
            {
                'coefficient': ['gamma_user', 'beta_item'],
                'observable': 'price_obs'
            }
            ]

    Args:
        utility_string (str): the utility formula to parse.

    Raises:
        ValueError: if a factor is neither a coefficient nor an observable (this includes empty
            factors, e.g. from a trailing `+`).

    Returns:
        List[Dict[str, Union[List[str], None]]]: one dictionary per additive term, with the keys
            'coefficient' (list of coefficient names) and 'observable' (a name or None).
    """
    # str.endswith / str.startswith accept a tuple of candidates.
    coefficient_suffix = ('_item', '_user', '_constant', '_category')
    observable_prefix = ('item_', 'user_', 'session_', 'price_', 'taste_')

    additive_decomposition = list()
    # split on the bare operators and strip, so 'a_item+b_user' is not mistaken for one name.
    for term in utility_string.split('+'):
        atom = {'coefficient': [], 'observable': None}
        # split multiplicative terms.
        for factor in term.split('*'):
            x = factor.strip()
            if x.endswith(coefficient_suffix):
                atom['coefficient'].append(x)
            elif x.startswith(observable_prefix):
                atom['observable'] = x
            else:
                raise ValueError(f'{x} term cannot be classified.')
        additive_decomposition.append(atom)
    return additive_decomposition

bemb_flex_lightning

PyTorch Lightning wrapper for the BEMB Flex model, which allows for smoother model training and inference. You can still use this package without using LitBEMBFlex.

Author: Tianyu Du Update: Apr. 29, 2022

LitBEMBFlex (LightningModule)

Source code in bemb/model/bemb_flex_lightning.py
class LitBEMBFlex(pl.LightningModule):
    """PyTorch Lightning wrapper around a `BEMBFlex` model.

    The wrapped model is stored in `self.model`; training maximizes the ELBO it reports, and
    validation/testing log accuracy and log-likelihood computed by `_get_performance_dict`.
    """

    def __init__(self, learning_rate: float = 0.3, num_seeds: int = 1, **kwargs):
        """The initialization method of the wrapper model.

        Args:
            learning_rate (float, optional): the learning rate of optimization. Defaults to 0.3.
            num_seeds (int, optional): number of random seeds for the Monte Carlo estimation in the variational inference.
                Defaults to 1.
            **kwargs: all keyword arguments used for constructing the wrapped BEMB model.
        """
        # use kwargs to pass parameter to BEMB Torch.
        super().__init__()
        self.model = BEMBFlex(**kwargs)
        # NOTE(review): the attribute is spelled `num_needs` although it stores `num_seeds`;
        # the name is kept because `training_step` reads it.
        self.num_needs = num_seeds
        self.learning_rate = learning_rate

    def __str__(self) -> str:
        """Return the string representation of the wrapped model."""
        return str(self.model)

    def forward(self, args, kwargs):
        """Calls the forward method of the wrapped BEMB model, please refer to the documentation of the BEMB class
            for detailed definitions of the arguments.

        Note that `args` and `kwargs` are two ordinary parameters (a sequence and a mapping) that are
        unpacked into the call on the wrapped model.

        Args:
            args (_type_): arguments passed to the forward method of the wrapped BEMB model.
            kwargs (_type_): keyword arguments passed to the forward method of the wrapped BEMB model.

        Returns:
            _type_: returns whatever the wrapped BEMB model returns.
        """
        return self.model(*args, **kwargs)

    def training_step(self, batch, batch_idx):
        """Compute the ELBO on `batch`, log it as 'train_elbo', and return its negative as the loss."""
        elbo = self.model.elbo(batch, num_seeds=self.num_needs)
        self.log('train_elbo', elbo)
        loss = - elbo
        return loss

    def _get_performance_dict(self, batch):
        """Compute accuracy ('acc') and log-likelihood ('ll', the negated log loss) on `batch`.

        When `self.model.pred_item` is set, predictions are the argmax of the log-probabilities over
        all items and the targets are `batch.item_index`; otherwise the sigmoid of the utility is
        thresholded at 0.5 and compared with `batch.label`.
        """
        if self.model.pred_item:
            log_p = self.model(batch, return_type='log_prob',
                               return_scope='all_items', deterministic=True).cpu().numpy()
            num_classes = log_p.shape[1]
            y_pred = np.argmax(log_p, axis=1)
            y_true = batch.item_index.cpu().numpy()
            performance = {'acc': metrics.accuracy_score(y_true=y_true, y_pred=y_pred),
                           'll': - metrics.log_loss(y_true=y_true, y_pred=np.exp(log_p), labels=np.arange(num_classes))}
        else:
            # binary classification of the label.
            pred = self.model(batch, return_type='utility',
                              return_scope='item_index', deterministic=True)
            y_pred = torch.sigmoid(pred).cpu().numpy()
            y_true = batch.label.cpu().numpy()
            performance = {'acc': metrics.accuracy_score(y_true=y_true, y_pred=(y_pred >= 0.5).astype(int)),
                           'll': - metrics.log_loss(y_true=y_true, y_pred=y_pred, eps=1E-5, labels=[0, 1]),
                           }
        return performance

    def validation_step(self, batch, batch_idx):
        """Log every performance metric of `batch` under the prefix 'val_'."""
        for key, val in self._get_performance_dict(batch).items():
            self.log('val_' + key, val, prog_bar=True, batch_size=len(batch))

    def test_step(self, batch, batch_idx):
        """Log every performance metric of `batch` under the prefix 'test_'."""
        for key, val in self._get_performance_dict(batch).items():
            self.log('test_' + key, val, prog_bar=True, batch_size=len(batch))

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters, using `self.learning_rate`."""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        return optimizer

    def fit_model(self, dataset_list: List[ChoiceDataset], batch_size: int=-1, num_epochs: int=10, num_workers: int=8, **kwargs) -> "LitBEMBFlex":
        """A standard pipeline of model training and evaluation.

        Args:
            dataset_list (List[ChoiceDataset]): train_dataset, validation_test, and test_dataset in a list of length 3.
            batch_size (int, optional): batch_size for training and evaluation. Defaults to -1, which indicates full-batch training.
            num_epochs (int, optional): number of epochs for training. Defaults to 10.
            num_workers (int, optional): passed to `create_data_loader` for all three data loaders. Defaults to 8.
            **kwargs: additional keyword argument for the pytorch-lightning Trainer.

        Returns:
            LitBEMBFlex: the trained bemb model.
        """

        def section_print(input_text):
            """Helper function for printing"""
            print('=' * 20, input_text, '=' * 20)
        # present a summary of the model received.
        section_print('model received')
        print(self)

        # present a summary of datasets received.
        section_print('data set received')
        print('[Training dataset]', dataset_list[0])
        print('[Validation dataset]', dataset_list[1])
        print('[Testing dataset]', dataset_list[2])

        # create pytorch dataloader objects.
        train = create_data_loader(dataset_list[0], batch_size=batch_size, shuffle=True, num_workers=num_workers)
        validation = create_data_loader(dataset_list[1], batch_size=batch_size, shuffle=False, num_workers=num_workers)
        # WARNING: the test step takes extensive memory cost since it computes likelihood for all items.
        # we run the test step with a much smaller batch_size.
        # A positive batch_size below 10 would floor to 0, so it is clamped to 1; non-positive
        # values (e.g. the full-batch sentinel -1, for which -1 // 10 == -1) are passed as before.
        test_batch_size = batch_size // 10
        if batch_size > 0:
            test_batch_size = max(test_batch_size, 1)
        test = create_data_loader(dataset_list[2], batch_size=test_batch_size, shuffle=False, num_workers=num_workers)

        section_print('train the model')
        trainer = pl.Trainer(gpus=1 if ('cuda' in str(self)) else 0,  # use GPU if the model is currently on the GPU.
                            max_epochs=num_epochs,
                            check_val_every_n_epoch=1,
                            log_every_n_steps=1,
                            **kwargs)
        start_time = time.time()
        trainer.fit(self, train_dataloaders=train, val_dataloaders=validation)
        print(f'time taken: {time.time() - start_time}')

        section_print('test performance')
        trainer.test(self, dataloaders=test)
        return self
__init__(self, learning_rate=0.3, num_seeds=1, **kwargs) special

The initialization method of the wrapper model.

Parameters:

Name Type Description Default
learning_rate float

the learning rate of optimization. Defaults to 0.3.

0.3
num_seeds int

number of random seeds for the Monte Carlo estimation in the variational inference. Defaults to 1.

1
**kwargs

all keyword arguments used for constructing the wrapped BEMB model.

{}
Source code in bemb/model/bemb_flex_lightning.py
def __init__(self, learning_rate: float = 0.3, num_seeds: int = 1, **kwargs):
    """Build the Lightning wrapper and the BEMB model it wraps.

    Args:
        learning_rate (float, optional): learning rate stored on the wrapper for the optimizer.
            Defaults to 0.3.
        num_seeds (int, optional): number of random seeds for the Monte Carlo estimation in the
            variational inference; stored as `self.num_needs`. Defaults to 1.
        **kwargs: forwarded unchanged to the `BEMBFlex` constructor.
    """
    super().__init__()
    # plain hyper-parameters first, then the wrapped model built from the remaining kwargs.
    self.learning_rate = learning_rate
    self.num_needs = num_seeds
    self.model = BEMBFlex(**kwargs)
configure_optimizers(self)

Choose what optimizers and learning-rate schedulers to use in your optimization. Normally you'd need one. But in the case of GANs or similar you might have multiple.

Returns:

Type Description

Any of these 6 options.

  • Single optimizer.
  • List or Tuple of optimizers.
  • Two lists - The first list has multiple optimizers, and the second has multiple LR schedulers (or multiple lr_scheduler_config).
  • Dictionary, with an "optimizer" key, and (optionally) a "lr_scheduler" key whose value is a single LR scheduler or lr_scheduler_config.
  • Tuple of dictionaries as described above, with an optional "frequency" key.
  • None - Fit will run without any optimizer.

The lr_scheduler_config is a dictionary which contains the scheduler and its associated configuration. The default configuration is shown below.

.. code-block:: python

lr_scheduler_config = {
    # REQUIRED: The scheduler instance
    "scheduler": lr_scheduler,
    # The unit of the scheduler's step size, could also be 'step'.
    # 'epoch' updates the scheduler on epoch end whereas 'step'
    # updates it after a optimizer update.
    "interval": "epoch",
    # How many epochs/steps should pass between calls to
    # `scheduler.step()`. 1 corresponds to updating the learning
    # rate after every epoch/step.
    "frequency": 1,
    # Metric to to monitor for schedulers like `ReduceLROnPlateau`
    "monitor": "val_loss",
    # If set to `True`, will enforce that the value specified 'monitor'
    # is available when the scheduler is updated, thus stopping
    # training if not found. If set to `False`, it will only produce a warning
    "strict": True,
    # If using the `LearningRateMonitor` callback to monitor the
    # learning rate progress, this keyword can be used to specify
    # a custom logged name
    "name": None,
}

When there are schedulers in which the .step() method is conditioned on a value, such as the :class:torch.optim.lr_scheduler.ReduceLROnPlateau scheduler, Lightning requires that the lr_scheduler_config contains the keyword "monitor" set to the metric name that the scheduler should be conditioned on.

.. testcode::

# The ReduceLROnPlateau scheduler requires a monitor
def configure_optimizers(self):
    optimizer = Adam(...)
    return {
        "optimizer": optimizer,
        "lr_scheduler": {
            "scheduler": ReduceLROnPlateau(optimizer, ...),
            "monitor": "metric_to_track",
            "frequency": "indicates how often the metric is updated"
            # If "monitor" references validation metrics, then "frequency" should be set to a
            # multiple of "trainer.check_val_every_n_epoch".
        },
    }


# In the case of two optimizers, only one using the ReduceLROnPlateau scheduler
def configure_optimizers(self):
    optimizer1 = Adam(...)
    optimizer2 = SGD(...)
    scheduler1 = ReduceLROnPlateau(optimizer1, ...)
    scheduler2 = LambdaLR(optimizer2, ...)
    return (
        {
            "optimizer": optimizer1,
            "lr_scheduler": {
                "scheduler": scheduler1,
                "monitor": "metric_to_track",
            },
        },
        {"optimizer": optimizer2, "lr_scheduler": scheduler2},
    )

Metrics can be made available to monitor by simply logging it using self.log('metric_to_track', metric_val) in your :class:~pytorch_lightning.core.lightning.LightningModule.

!!! note The frequency value specified in a dict along with the optimizer key is an int corresponding to the number of sequential batches optimized with the specific optimizer. It should be given to none or to all of the optimizers. There is a difference between passing multiple optimizers in a list, and passing multiple optimizers in dictionaries with a frequency of 1:

    - In the former case, all optimizers will operate on the given batch in each optimization step.
    - In the latter, only one optimizer will operate on the given batch at every step.

This is different from the ``frequency`` value specified in the ``lr_scheduler_config`` mentioned above.

.. code-block:: python

    def configure_optimizers(self):
        optimizer_one = torch.optim.SGD(self.model.parameters(), lr=0.01)
        optimizer_two = torch.optim.SGD(self.model.parameters(), lr=0.01)
        return [
            {"optimizer": optimizer_one, "frequency": 5},
            {"optimizer": optimizer_two, "frequency": 10},
        ]

In this example, the first optimizer will be used for the first 5 steps,
the second optimizer for the next 10 steps and that cycle will continue.
If an LR scheduler is specified for an optimizer using the ``lr_scheduler`` key in the above dict,
the scheduler will only be updated when its optimizer is being used.

Examples::

# most cases. no learning rate scheduler
def configure_optimizers(self):
    return Adam(self.parameters(), lr=1e-3)

# multiple optimizer case (e.g.: GAN)
def configure_optimizers(self):
    gen_opt = Adam(self.model_gen.parameters(), lr=0.01)
    dis_opt = Adam(self.model_dis.parameters(), lr=0.02)
    return gen_opt, dis_opt

# example with learning rate schedulers
def configure_optimizers(self):
    gen_opt = Adam(self.model_gen.parameters(), lr=0.01)
    dis_opt = Adam(self.model_dis.parameters(), lr=0.02)
    dis_sch = CosineAnnealing(dis_opt, T_max=10)
    return [gen_opt, dis_opt], [dis_sch]

# example with step-based learning rate schedulers
# each optimizer has its own scheduler
def configure_optimizers(self):
    gen_opt = Adam(self.model_gen.parameters(), lr=0.01)
    dis_opt = Adam(self.model_dis.parameters(), lr=0.02)
    gen_sch = {
        'scheduler': ExponentialLR(gen_opt, 0.99),
        'interval': 'step'  # called after each training step
    }
    dis_sch = CosineAnnealing(dis_opt, T_max=10) # called every epoch
    return [gen_opt, dis_opt], [gen_sch, dis_sch]

# example with optimizer frequencies
# see training procedure in `Improved Training of Wasserstein GANs`, Algorithm 1
# https://arxiv.org/abs/1704.00028
def configure_optimizers(self):
    gen_opt = Adam(self.model_gen.parameters(), lr=0.01)
    dis_opt = Adam(self.model_dis.parameters(), lr=0.02)
    n_critic = 5
    return (
        {'optimizer': dis_opt, 'frequency': n_critic},
        {'optimizer': gen_opt, 'frequency': 1}
    )

!!! note Some things to know:

- Lightning calls ``.backward()`` and ``.step()`` on each optimizer and learning rate scheduler as needed.
- If you use 16-bit precision (``precision=16``), Lightning will automatically handle the optimizers.
- If you use multiple optimizers, :meth:`training_step` will have an additional ``optimizer_idx`` parameter.
- If you use :class:`torch.optim.LBFGS`, Lightning handles the closure function automatically for you.
- If you use multiple optimizers, gradients will be calculated only for the parameters of current optimizer
  at each training step.
- If you need to control how often those optimizers step or override the default ``.step()`` schedule,
  override the :meth:`optimizer_step` hook.
Source code in bemb/model/bemb_flex_lightning.py
def configure_optimizers(self):
    """Return a single Adam optimizer over `self.parameters()` with learning rate `self.learning_rate`."""
    return torch.optim.Adam(self.parameters(), lr=self.learning_rate)
fit_model(self, dataset_list, batch_size=-1, num_epochs=10, num_workers=8, **kwargs)

A standard pipeline of model training and evaluation.

Parameters:

Name Type Description Default
dataset_list List[ChoiceDataset]

train_dataset, validation_test, and test_dataset in a list of length 3.

required
batch_size int

batch_size for training and evaluation. Defaults to -1, which indicates full-batch training.

-1
num_epochs int

number of epochs for training. Defaults to 10.

10
**kwargs

additional keyword argument for the pytorch-lightning Trainer.

{}

Returns:

Type Description
LitBEMBFlex

the trained bemb model.

Source code in bemb/model/bemb_flex_lightning.py
def fit_model(self, dataset_list: List[ChoiceDataset], batch_size: int=-1, num_epochs: int=10, num_workers: int=8, **kwargs) -> "LitBEMBFlex":
    """A standard pipeline of model training and evaluation.

    Args:
        dataset_list (List[ChoiceDataset]): train_dataset, validation_test, and test_dataset in a list of length 3.
        batch_size (int, optional): batch_size for training and evaluation. Defaults to -1, which indicates full-batch training.
        num_epochs (int, optional): number of epochs for training. Defaults to 10.
        num_workers (int, optional): passed to `create_data_loader` for all three data loaders. Defaults to 8.
        **kwargs: additional keyword argument for the pytorch-lightning Trainer.

    Returns:
        LitBEMBFlex: the trained bemb model.
    """

    def section_print(input_text):
        """Helper function for printing"""
        print('=' * 20, input_text, '=' * 20)
    # present a summary of the model received.
    section_print('model received')
    print(self)

    # present a summary of datasets received.
    section_print('data set received')
    print('[Training dataset]', dataset_list[0])
    print('[Validation dataset]', dataset_list[1])
    print('[Testing dataset]', dataset_list[2])

    # create pytorch dataloader objects.
    train = create_data_loader(dataset_list[0], batch_size=batch_size, shuffle=True, num_workers=num_workers)
    validation = create_data_loader(dataset_list[1], batch_size=batch_size, shuffle=False, num_workers=num_workers)
    # WARNING: the test step takes extensive memory cost since it computes likelihood for all items.
    # we run the test step with a much smaller batch_size.
    # A positive batch_size below 10 would floor to 0, so it is clamped to 1; non-positive
    # values (e.g. the full-batch sentinel -1, for which -1 // 10 == -1) are passed as before.
    test_batch_size = batch_size // 10
    if batch_size > 0:
        test_batch_size = max(test_batch_size, 1)
    test = create_data_loader(dataset_list[2], batch_size=test_batch_size, shuffle=False, num_workers=num_workers)

    section_print('train the model')
    trainer = pl.Trainer(gpus=1 if ('cuda' in str(self)) else 0,  # use GPU if the model is currently on the GPU.
                        max_epochs=num_epochs,
                        check_val_every_n_epoch=1,
                        log_every_n_steps=1,
                        **kwargs)
    start_time = time.time()
    trainer.fit(self, train_dataloaders=train, val_dataloaders=validation)
    print(f'time taken: {time.time() - start_time}')

    section_print('test performance')
    trainer.test(self, dataloaders=test)
    return self
forward(self, args, kwargs)

Calls the forward method of the wrapped BEMB model, please refer to the documentation of the BEMB class for detailed definitions of the arguments.

Parameters:

Name Type Description Default
args _type_

arguments passed to the forward method of the wrapped BEMB model.

required
kwargs _type_

keyword arguments passed to the forward method of the wrapped BEMB model.

required

Returns:

Type Description
_type_

returns whatever the wrapped BEMB model returns.

Source code in bemb/model/bemb_flex_lightning.py
def forward(self, args, kwargs):
    """Delegate to the wrapped BEMB model; see the BEMB class for what the arguments mean.

    Args:
        args (_type_): positional arguments, given as one sequence, unpacked into the call on
            the wrapped model.
        kwargs (_type_): keyword arguments, given as one mapping, unpacked into the call on
            the wrapped model.

    Returns:
        _type_: whatever the wrapped BEMB model returns.
    """
    wrapped = self.model
    result = wrapped(*args, **kwargs)
    return result
test_step(self, batch, batch_idx)

Operates on a single batch of data from the test set. In this step you'd normally generate examples or calculate anything of interest such as accuracy.

.. code-block:: python

# the pseudocode for these calls
test_outs = []
for test_batch in test_data:
    out = test_step(test_batch)
    test_outs.append(out)
test_epoch_end(test_outs)

Parameters:

Name Type Description Default
batch

The output of your :class:~torch.utils.data.DataLoader.

required
batch_idx

The index of this batch.

required
dataloader_id

The index of the dataloader that produced this batch. (only if multiple test dataloaders used).

required

Returns:

Type Description

Any of.

  • Any object or value
  • None - Testing will skip to the next batch

.. code-block:: python

# if you have one test dataloader:
def test_step(self, batch, batch_idx):
    ...


# if you have multiple test dataloaders:
def test_step(self, batch, batch_idx, dataloader_idx=0):
    ...

Examples::

# CASE 1: A single test dataset
def test_step(self, batch, batch_idx):
    x, y = batch

    # implement your own
    out = self(x)
    loss = self.loss(out, y)

    # log 6 example images
    # or generated text... or whatever
    sample_imgs = x[:6]
    grid = torchvision.utils.make_grid(sample_imgs)
    self.logger.experiment.add_image('example_images', grid, 0)

    # calculate acc
    labels_hat = torch.argmax(out, dim=1)
    test_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)

    # log the outputs!
    self.log_dict({'test_loss': loss, 'test_acc': test_acc})

If you pass in multiple test dataloaders, :meth:test_step will have an additional argument. We recommend setting the default value of 0 so that you can quickly switch between single and multiple dataloaders.

.. code-block:: python

# CASE 2: multiple test dataloaders
def test_step(self, batch, batch_idx, dataloader_idx=0):
    # dataloader_idx tells you which dataset this is.
    ...

!!! note If you don't need to test you don't need to implement this method.

!!! note When the :meth:test_step is called, the model has been put in eval mode and PyTorch gradients have been disabled. At the end of the test epoch, the model goes back to training mode and gradients are enabled.

Source code in bemb/model/bemb_flex_lightning.py
def test_step(self, batch, batch_idx):
    """Log the performance metrics of one test batch.

    Every entry returned by ``self._get_performance_dict(batch)`` is logged
    under its key prefixed with ``'test_'``, shown on the progress bar, with
    ``len(batch)`` reported as the batch size.

    Args:
        batch: the test batch; passed to ``self._get_performance_dict``.
        batch_idx: index of the batch; not used here.

    Returns:
        None.
    """
    metrics = self._get_performance_dict(batch)
    for name, value in metrics.items():
        self.log('test_' + name, value, prog_bar=True, batch_size=len(batch))
training_step(self, batch, batch_idx)

Here you compute and return the training loss and some additional metrics for e.g. the progress bar or logger.

Parameters:

Name Type Description Default
batch

:class:~torch.Tensor | (:class:~torch.Tensor, ...) | [:class:~torch.Tensor, ...]: The output of your :class:~torch.utils.data.DataLoader. A tensor, tuple or list.

required
batch_idx ``int``

Integer displaying index of this batch

required
optimizer_idx ``int``

When using multiple optimizers, this argument will also be present.

required
hiddens ``Any``

Passed in if :paramref:~pytorch_lightning.core.lightning.LightningModule.truncated_bptt_steps > 0.

required

Returns:

Type Description
Any of.

  • :class:~torch.Tensor - The loss tensor
  • dict - A dictionary. Can include any keys, but must include the key 'loss'
  • None - Training will skip to the next batch. This is only for automatic optimization. This is not supported for multi-GPU, TPU, IPU, or DeepSpeed.

In this step you'd normally do the forward pass and calculate the loss for a batch. You can also do fancier things like multiple forward passes or something model specific.

Example::

def training_step(self, batch, batch_idx):
    x, y, z = batch
    out = self.encoder(x)
    loss = self.loss(out, x)
    return loss

If you define multiple optimizers, this step will be called with an additional optimizer_idx parameter.

.. code-block:: python

# Multiple optimizers (e.g.: GANs)
def training_step(self, batch, batch_idx, optimizer_idx):
    if optimizer_idx == 0:
        # do training_step with encoder
        ...
    if optimizer_idx == 1:
        # do training_step with decoder
        ...

If you add truncated back propagation through time you will also get an additional argument with the hidden states of the previous step.

.. code-block:: python

# Truncated back-propagation through time
def training_step(self, batch, batch_idx, hiddens):
    # hiddens are the hidden states from the previous truncated backprop step
    out, hiddens = self.lstm(data, hiddens)
    loss = ...
    return {"loss": loss, "hiddens": hiddens}

!!! note The loss value shown in the progress bar is smoothed (averaged) over the last values, so it differs from the actual loss returned in train/validation step.

Source code in bemb/model/bemb_flex_lightning.py
def training_step(self, batch, batch_idx):
    """Compute the training loss of one batch as the negative ELBO.

    The ELBO returned by ``self.model.elbo`` is logged under
    ``'train_elbo'`` and its negation is returned as the loss.

    Args:
        batch: the training batch; passed to ``self.model.elbo``.
        batch_idx: index of the batch; not used here.

    Returns:
        The negated ELBO of the batch.
    """
    # NOTE(review): the attribute is spelled `num_needs` although it is fed to
    # `num_seeds` -- confirm it matches the name assigned in __init__.
    batch_elbo = self.model.elbo(batch, num_seeds=self.num_needs)
    self.log('train_elbo', batch_elbo)
    # The value is returned as the loss, hence the sign flip.
    return -batch_elbo
validation_step(self, batch, batch_idx)

Operates on a single batch of data from the validation set. In this step you might generate examples or calculate anything of interest like accuracy.

.. code-block:: python

# the pseudocode for these calls
val_outs = []
for val_batch in val_data:
    out = validation_step(val_batch)
    val_outs.append(out)
validation_epoch_end(val_outs)

Parameters:

Name Type Description Default
batch

The output of your :class:~torch.utils.data.DataLoader.

required
batch_idx

The index of this batch.

required
dataloader_idx

The index of the dataloader that produced this batch. (only if multiple val dataloaders used)

required

Returns:

Type Description
  • Any object or value
  • None - Validation will skip to the next batch

.. code-block:: python

# pseudocode of order
val_outs = []
for val_batch in val_data:
    out = validation_step(val_batch)
    if defined("validation_step_end"):
        out = validation_step_end(out)
    val_outs.append(out)
val_outs = validation_epoch_end(val_outs)

.. code-block:: python

# if you have one val dataloader:
def validation_step(self, batch, batch_idx):
    ...


# if you have multiple val dataloaders:
def validation_step(self, batch, batch_idx, dataloader_idx=0):
    ...

Examples::

# CASE 1: A single validation dataset
def validation_step(self, batch, batch_idx):
    x, y = batch

    # implement your own
    out = self(x)
    loss = self.loss(out, y)

    # log 6 example images
    # or generated text... or whatever
    sample_imgs = x[:6]
    grid = torchvision.utils.make_grid(sample_imgs)
    self.logger.experiment.add_image('example_images', grid, 0)

    # calculate acc
    labels_hat = torch.argmax(out, dim=1)
    val_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)

    # log the outputs!
    self.log_dict({'val_loss': loss, 'val_acc': val_acc})

If you pass in multiple val dataloaders, :meth:validation_step will have an additional argument. We recommend setting the default value of 0 so that you can quickly switch between single and multiple dataloaders.

.. code-block:: python

# CASE 2: multiple validation dataloaders
def validation_step(self, batch, batch_idx, dataloader_idx=0):
    # dataloader_idx tells you which dataset this is.
    ...

!!! note If you don't need to validate you don't need to implement this method.

!!! note When the :meth:validation_step is called, the model has been put in eval mode and PyTorch gradients have been disabled. At the end of validation, the model goes back to training mode and gradients are enabled.

Source code in bemb/model/bemb_flex_lightning.py
def validation_step(self, batch, batch_idx):
    """Log the performance metrics of one validation batch.

    Every entry returned by ``self._get_performance_dict(batch)`` is logged
    under its key prefixed with ``'val_'``, shown on the progress bar, with
    ``len(batch)`` reported as the batch size.

    Args:
        batch: the validation batch; passed to ``self._get_performance_dict``.
        batch_idx: index of the batch; not used here.

    Returns:
        None.
    """
    metrics = self._get_performance_dict(batch)
    for name, value in metrics.items():
        self.log('val_' + name, value, prog_bar=True, batch_size=len(batch))