from torch.utils.data import Dataset
import torch
# [docs]  (Sphinx "view source" link residue; not part of the module's code)
class TasteNetDataset(Dataset):
    """Map-style dataset that serves ``(x, y, z)`` triples row by row.

    ``x`` holds the alternative-specific variables, ``y`` the choice
    outcome and ``z`` the socio-demographic variables of one observation.
    """

    def __init__(self, x, y, alt_spec_features, socio_demo_features):
        """
        Parameters:
        ----------
        x : pandas DataFrame
            DataFrame containing both the alternative-specific and the
            socio-demographic feature columns (both groups are selected
            from it by name).
        y : pandas Series
            Series containing the choice outcomes, one per row of ``x``.
        alt_spec_features : list
            List of alternative-specific feature names.
        socio_demo_features : list
            List of socio-demographic feature names.

        Raises:
        ------
        ValueError
            If ``y`` does not have the same number of rows as ``x``.
        """
        super().__init__()

        # Cast to float32 on the pandas side: with mixed column dtypes
        # (e.g. bool next to numeric) ``.values`` yields an object array
        # that ``torch.from_numpy`` rejects.
        alt_spec = x.loc[:, alt_spec_features].to_numpy(dtype="float32")
        socio_demo = x.loc[:, socio_demo_features].to_numpy(dtype="float32")

        # N,A alternative-specific variables
        self.x = torch.from_numpy(alt_spec).to(dtype=torch.float32)
        self.x_names = alt_spec_features
        self.N = len(self.x)

        # Choice outcomes; the dtype of ``y`` is kept as provided.
        self.y = torch.from_numpy(y.to_numpy())
        if len(self.y) != self.N:
            raise ValueError(
                f"x and y must have the same number of rows, "
                f"got {self.N} and {len(self.y)}"
            )

        # N,D socio-demo variables
        self.z = torch.from_numpy(socio_demo).to(dtype=torch.float32)
        _, self.D = self.z.size()  # z size = (N,D)

    def __len__(self):
        """Return the number of observations (rows of ``x``)."""
        return self.N

    def __getitem__(self, idx):
        """Return the ``(x, y, z)`` entries stored at position ``idx``."""
        return self.x[idx], self.y[idx], self.z[idx]