Description
The following call appears in src\experiments\tabular_data\adult_census_data\train.py to train on adult_census_data:
SLASHobj.learn(dataList=dataList, queryList=queryList, epoch=1, batchSize=exp_dict['bs'], p_num=exp_dict['p_num'], method='exact') # 'network_prediction'
However, the learn function defined in slash.py does not accept dataList or queryList as parameters:
def learn(self, dataset_loader, epoch, method='exact', lr=0.01, opt=False, batchSize=1, use_em=False, train_slot=False, slot_net=None, p_num=1, marginalisation_masks=None):
learn has no parameters through which to pass dataList (a tensor array) or queryList. Its current definition appears to be written for mnist_addition. How should SLASH be used to train on tabular data?
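The mismatch can be checked without running any training. The sketch below is only an illustration: `SLASHStub` and `check_call` are hypothetical names, the stub copies nothing but the `learn` signature quoted above, and the argument values are placeholders rather than the real data.

```python
import inspect


class SLASHStub:
    """Hypothetical stand-in; only the signature is copied from slash.py."""

    def learn(self, dataset_loader, epoch, method='exact', lr=0.01, opt=False,
              batchSize=1, use_em=False, train_slot=False, slot_net=None,
              p_num=1, marginalisation_masks=None):
        pass


def check_call(func, **kwargs):
    """Return None if kwargs bind to func's signature, else the error text."""
    try:
        inspect.signature(func).bind(**kwargs)
        return None
    except TypeError as err:
        return str(err)


stub = SLASHStub()

# Keyword arguments as used in train.py (values are placeholders).
train_py_kwargs = dict(dataList=[], queryList=[], epoch=1,
                       batchSize=32, p_num=1, method='exact')
error = check_call(stub.learn, **train_py_kwargs)

# A call shaped like the current signature binds without error.
ok = check_call(stub.learn, dataset_loader=[], epoch=1)

print("train.py call:", error)
print("dataset_loader call:", ok)
```

The first call fails to bind because `dataset_loader` is required and `dataList`/`queryList` are not accepted, which matches the error seen when running train.py against the current slash.py.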