# Source code for src.tastenet.models

import torch
import torch.nn as nn
import torch.nn.functional as F


def get_act(nl_func):
    """Return a fresh activation module for the given name.

    Args:
        nl_func: name of the non-linearity; "tanh", "relu" and "sigmoid"
            are recognised.

    Returns:
        A new ``nn.Tanh``, ``nn.ReLU`` or ``nn.Sigmoid`` instance, or
        ``None`` when the name is not one of the recognised values.
    """
    known_activations = (
        ("tanh", nn.Tanh),
        ("relu", nn.ReLU),
        ("sigmoid", nn.Sigmoid),
    )
    for label, factory in known_activations:
        if nl_func == label:
            return factory()
    return None
class TasteNet(nn.Module):
    """TasteNet-MNL model for Swissmetro.

    Combines an optional taste-parameter network (``TasteParams``) with a
    utility module (``Utility``) and, when a single latent value is used,
    an ordinal output layer (``Coral_layer``).
    """

    def __init__(
        self,
        args,
        num_alt_features,
        num_sd_chars,
        num_classes,
        num_latent_vals=None,
        utility_structure=None,
    ):
        """
        Initialize the TasteNet class.

        Args:
            args (argparse.Namespace): command line arguments. This class
                reads ``functional_intercept``, ``functional_params``,
                ``layer_sizes`` and (in the norm helpers) ``device``.
            num_alt_features (int): number of alternative features.
            num_sd_chars (int): number of socio-demographic characteristics.
            num_classes (int): number of classes.
            num_latent_vals (int, optional): number of latent values.
                Defaults to None, in which case ``num_classes`` is used.
                Useful only for ordinal regression problems, since the
                number of latent vars is 1.
            utility_structure (str, optional): structure of the utility
                function. Defaults to None.
        """
        super(TasteNet, self).__init__()
        self.func_intercept = args.functional_intercept
        self.func_params = args.functional_params

        if not num_latent_vals:
            num_latent_vals = num_classes

        # The taste network only exists when something is functional.
        if self.func_intercept or self.func_params:
            self.params_module = TasteParams(
                args.layer_sizes,
                args,
                num_alt_features,
                num_latent_vals,
                num_sd_chars,
                self.func_intercept,
                self.func_params,
            )
        self.util_module = Utility(
            args,
            num_alt_features,
            num_latent_vals,
            self.func_intercept,
            self.func_params,
            utility_structure=utility_structure,
        )
        # A single latent value means an ordinal model: add the Coral layer.
        self.ordinal_module = Coral_layer(num_classes) if num_latent_vals == 1 else None
        self.args = args
        self.num_classes = num_classes

    def forward(self, x, z=None):
        """Compute the logits (no softmax is applied here).

        Args:
            x: alternative attributes, passed to the utility module.
            z: input of the taste network; only used when intercepts or
                parameters are functional.

        Returns:
            The utilities, or the output of the ordinal layer when one
            was created.
        """
        if self.func_intercept or self.func_params:
            b = self.params_module(z)  # taste parameters
            # The dataset-specific constraint is selected from the number of
            # classes and the width of the taste vector.
            if self.num_classes == 3 and self.func_params:
                b = self.monotonic_constraints(b)
            elif self.num_classes == 4 and self.func_params and b.shape[1] > 10:
                b = self.lpmc_monotonic_constraints(b)
            elif self.num_classes == 4 and self.func_params:
                b = self.synthetic_monotonic_constraints(b)
        else:
            b = None

        v = self.util_module(x, b)  # no softmax here

        if self.ordinal_module is None:
            logits = v
        else:
            logits = self.ordinal_module(v)  # (N, J-1)
        return logits

    def _with_intercepts(self, b, parts):
        """Concatenate ``parts`` column-wise, appending the trailing
        ``num_classes`` intercept columns of ``b`` (left unconstrained)
        when intercepts are functional."""
        if self.func_intercept:
            parts.append(b[:, -self.num_classes :].view(-1, self.num_classes))
        return torch.cat(parts, dim=1)

    def monotonic_constraints(self, b):
        """
        Put transformation for the sake of constraints on the value of times.

        This is only for the SwissMetro dataset and needs to be adapted for
        other datasets.

        Args:
            b: individual taste parameters, one row per individual. Columns
                0-5 and 7-8 are forced to be non-positive, column 6 is left
                free; the last ``num_classes`` columns are kept as
                intercepts when intercepts are functional.

        Returns:
            The constrained taste parameters.
        """
        return self._with_intercepts(
            b,
            [
                -F.relu(-b[:, :6]),
                b[:, 6].view(-1, 1),
                -F.relu(-b[:, 7:9]),
            ],
        )

    def lpmc_monotonic_constraints(self, b):
        """
        Put transformation for the sake of constraints on the value of times.

        This is only for the LPMC dataset and needs to be adapted for other
        datasets.

        Args:
            b: individual taste parameters, one row per individual. Columns
                0-1, 4-5, 8-15 and 18-22 are forced to be non-positive;
                columns 2-3, 6-7, 16-17 and 23-24 are left free; the last
                ``num_classes`` columns are kept as intercepts when
                intercepts are functional.

        Returns:
            The constrained taste parameters.
        """
        return self._with_intercepts(
            b,
            [
                -F.relu(-b[:, :2]),
                b[:, 2:4].view(-1, 2),
                -F.relu(-b[:, 4:6]),
                b[:, 6:8].view(-1, 2),
                -F.relu(-b[:, 8:16]),
                b[:, 16:18].view(-1, 2),
                -F.relu(-b[:, 18:23]),
                b[:, 23:25].view(-1, 2),
            ],
        )

    def synthetic_monotonic_constraints(self, b):
        """
        Put transformation for the sake of constraints on the value of times.

        This is only for the synthetic dataset and needs to be adapted for
        other datasets.

        Args:
            b: individual taste parameters, one row per individual. Columns
                0-3 are forced to be non-positive; the last ``num_classes``
                columns are kept as intercepts when intercepts are
                functional.

        Returns:
            The constrained taste parameters.
        """
        return self._with_intercepts(b, [-F.relu(-b[:, :4])])

    def _regularised_weights(self):
        """Yield the taste-network parameters that are not biases.

        Biases are recognised by parameter name rather than by position, so
        the result does not depend on every sub-module exposing exactly one
        weight followed by one bias. Yields nothing when the model has no
        taste network.
        """
        params_module = getattr(self, "params_module", None)
        if params_module is None:
            return
        for name, param in params_module.named_parameters():
            if name.rsplit(".", 1)[-1] == "bias":
                continue
            yield param

    def l2_norm(self):
        """
        L2 norm (sum of squares) of the taste network, not including bias.

        Returns:
            A tensor of shape (1,) on ``args.device``; zero when the model
            has no taste network.
        """
        norm = torch.zeros(1).to(device=torch.device(self.args.device))
        for params in self._regularised_weights():
            norm += (params**2).sum()
        return norm

    def l1_norm(self):
        """
        L1 norm (sum of absolute values) of the taste network, not
        including bias.

        Returns:
            A tensor of shape (1,) on ``args.device``; zero when the model
            has no taste network.
        """
        norm = torch.zeros(1).to(device=torch.device(self.args.device))
        for params in self._regularised_weights():
            norm += torch.abs(params).sum()
        return norm
class Utility(nn.Module):
    """Utility module: an ``MNL_layer`` plus alternative-specific intercepts.

    The intercepts are either a learned parameter of this module (when
    ``func_intercept`` is false) or taken from the taste parameters passed
    to ``forward``.
    """

    def __init__(
        self,
        args,
        num_alt_features,
        num_classes,
        func_intercept=True,
        func_params=True,
        utility_structure=None,
    ):
        """Store the configuration and build the MNL layer.

        Args:
            args: command line arguments, forwarded to ``MNL_layer``.
            num_alt_features: number of alternative features.
            num_classes: number of intercepts / utility columns.
            func_intercept: whether intercepts come from the taste
                parameters.
            func_params: whether coefficients come from the taste
                parameters.
            utility_structure: forwarded to ``MNL_layer``.
        """
        super().__init__()
        self.args = args
        self.func_intercept = func_intercept
        self.func_params = func_params
        self.num_classes = num_classes
        self.num_alt_features = num_alt_features
        self.utility_structure = utility_structure
        self.mnl = MNL_layer(utility_structure, args)
        if not self.func_intercept:
            # Intercepts are ordinary learned parameters, initialised at zero.
            self.intercept = nn.Parameter(torch.zeros(num_classes))  # (1, J)

    def forward(self, x, b=None):
        """Compute the utilities.

        Args:
            x: attributes of each alternative, including the intercept
                (N, K); J alternatives, each have K attributes.
            b: individual taste parameters. When both coefficients and
                intercepts are functional, the last ``num_classes`` columns
                are the intercepts and the remaining columns the
                coefficients.

        Returns:
            The utilities produced by the MNL layer plus the intercepts.
        """
        n_cls = self.num_classes

        if self.func_intercept:
            if self.func_params:
                # Split b into coefficients and trailing intercepts.
                return self.mnl(x, b[:, :-n_cls]) + b[:, -n_cls:]
            # b holds only the intercepts.
            return self.mnl(x) + b.view(-1, n_cls)

        # Learned intercepts; b (if used) holds only coefficients.
        systematic = self.mnl(x, b) if self.func_params else self.mnl(x)
        return systematic + self.intercept.view(1, n_cls)
class TasteParams(nn.Module):
    """
    Network for tastes: a stack of linear layers mapping socio-demographic
    characteristics to taste parameters and/or intercepts.
    """

    def __init__(
        self,
        layer_sizes,
        args,
        num_alt_features,
        num_classes,
        num_sd_chars,
        func_intercept=True,
        func_params=True,
    ):
        """Initialize the TasteParams class.

        Args:
            layer_sizes (list[tuple]): hidden layer sizes. The input size
                (``num_sd_chars``) and the output size are added here; the
                argument itself is not modified.
            args (argparse.Namespace): command line arguments. This class
                reads ``act_func``, ``dropout`` and ``batch_norm``.
            num_alt_features (int): number of alternative features.
            num_classes (int): number of classes.
            num_sd_chars (int): number of socio-demographic characteristics.
            func_intercept (bool): whether to include functional intercepts.
            func_params (bool): whether to include functional taste
                parameters.

        Raises:
            ValueError: if both ``func_intercept`` and ``func_params`` are
                false (the network would have no output).
        """
        if not func_intercept and not func_params:
            raise ValueError(
                "At least one of func_intercept or func_params must be True."
            )

        all_layers = list(layer_sizes)
        all_layers.insert(0, num_sd_chars)
        # Booleans act as 0/1 switches: one output per functional
        # coefficient plus one per functional intercept.
        all_layers.append(num_alt_features * func_params + num_classes * func_intercept)

        super(TasteParams, self).__init__()
        self.seq = nn.Sequential()
        for i, (in_size, out_size) in enumerate(zip(all_layers[:-1], all_layers[1:])):
            self.seq.add_module(
                name=f"L{i+1}", module=nn.Linear(in_size, out_size, bias=True)
            )
            # Everything below applies to hidden layers only, never to the
            # output layer.
            # NOTE(review): the dropout / batch-norm nesting was reconstructed
            # from a whitespace-collapsed listing; confirm they are meant for
            # hidden layers only.
            if i < len(all_layers) - 2:
                activation = get_act(args.act_func)
                # get_act returns None for unrecognised names. Registering
                # None in a Sequential makes forward() fail, so such a name
                # is treated as "no activation".
                if activation is not None:
                    self.seq.add_module(name=f"A{i+1}", module=activation)
                if args.dropout > 0:
                    self.seq.add_module(name=f"D{i+1}", module=nn.Dropout(args.dropout))
                if args.batch_norm:
                    self.seq.add_module(
                        name=f"BN{i+1}", module=nn.BatchNorm1d(out_size)
                    )
        self.args = args

    def forward(self, z):
        """
        Parameters:
            z: (N,D) # batch size, input dimension
        Returns:
            (N,K) # taste parameters, K being the size of the last layer
        """
        return self.seq(z)  # (N,K)
class Coral_layer(nn.Module):
    """Ordinal (Coral) output layer: adds one learned bias per threshold."""

    def __init__(self, n_choices):
        """Initialize the Ordinal_layer class (Coral layer).

        Args:
            n_choices (int): number of choice alternatives; the layer holds
                ``n_choices - 1`` biases, all initialised to one.
        """
        super().__init__()
        n_thresholds = n_choices - 1
        self.coral_bias = nn.Parameter(torch.ones((n_thresholds,)))

    def forward(self, x):
        """Return the output of the Coral layer.

        Args:
            x: tensor to which the biases are added (broadcast against the
                ``n_choices - 1`` biases).

        Returns:
            ``x`` shifted by the learned biases.
        """
        shifted = x + self.coral_bias  # (N, J-1)
        return shifted
class MNL_layer(nn.Module):
    """Linear-in-parameters utilities, one column per alternative."""

    def __init__(self, utility_structure, args):
        """Initialize the MNL_complex_layer class.

        Args:
            utility_structure: mapping whose values hold, at positions 0 and
                1, the start and end column of each alternative's features.
                One bias-free linear layer is created per entry, in
                iteration order.
            args (argparse.Namespace): command line arguments; ``device`` is
                read in ``forward``.
        """
        super().__init__()
        self.args = args
        self.mnl = nn.ModuleList()
        self.utility_structure = utility_structure
        for bounds in utility_structure.values():
            n_features = bounds[1] - bounds[0]
            self.mnl.append(nn.Linear(n_features, 1, bias=False))

    def forward(self, x, b=None):
        """Return the output of the MNL complex layer.

        Args:
            x: alternative attributes; each alternative reads its own
                column range.
            b: taste parameters aligned column-wise with ``x``. When given
                they replace the learned linear layers.

        Returns:
            A tensor with one row per row of ``x`` and one column per entry
            of ``utility_structure``, created on ``args.device``.
        """
        target_device = torch.device(self.args.device)
        n_alternatives = len(self.utility_structure)
        logits = torch.zeros(x.shape[0], n_alternatives).to(device=target_device)

        # Keys are used both as the output column and as the index of the
        # matching linear layer.
        for alt, bounds in self.utility_structure.items():
            start, stop = bounds[0], bounds[1]
            if b is None:
                column = self.mnl[alt](x[:, start:stop]).view(-1)
            else:
                column = (x[:, start:stop] * b[:, start:stop]).sum(dim=1)
            logits[:, alt] = column
        return logits