
Error in SLASH learn function when training on tabular data #1

@Saraavana


The following code in src\experiments\tabular_data\adult_census_data\train.py performs training on adult_census_data:

SLASHobj.learn(dataList=dataList, queryList=queryList, epoch=1, batchSize=exp_dict['bs'], p_num=exp_dict['p_num'], method='exact') # 'network_prediction'
However, the learn function defined in slash.py does not accept dataList or queryList as parameters:

def learn(self, dataset_loader, epoch, method='exact', lr=0.01, opt=False, batchSize=1, use_em=False, train_slot=False, slot_net=None, p_num=1, marginalisation_masks=None):
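To confirm the mismatch without running any training, here is a minimal sketch using the standard-library `inspect` module. `learn_stub` is a hypothetical stand-in that copies only the parameter list of the learn signature quoted above (its body is omitted), and `unsupported_kwargs` is a hypothetical helper; the keyword values mirror the call in train.py with placeholders for the data.

```python
import inspect


def learn_stub(self, dataset_loader, epoch, method='exact', lr=0.01, opt=False,
               batchSize=1, use_em=False, train_slot=False, slot_net=None,
               p_num=1, marginalisation_masks=None):
    """Hypothetical stub: same parameter list as learn() quoted from slash.py."""
    raise NotImplementedError


def unsupported_kwargs(func, **kwargs):
    """Return the sorted keyword names that func's signature does not declare."""
    params = inspect.signature(func).parameters
    # A function with **kwargs would accept any keyword, so nothing is unsupported.
    if any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()):
        return []
    return sorted(name for name in kwargs if name not in params)


# Keyword arguments used by the call in train.py (values are placeholders).
call_kwargs = dict(dataList=None, queryList=None, epoch=1, batchSize=64,
                   p_num=1, method='exact')

print(unsupported_kwargs(learn_stub, **call_kwargs))  # ['dataList', 'queryList']

# Binding the same call (with self=None) fails with a TypeError, just as calling
# the real method would; the call also never supplies the required dataset_loader.
try:
    inspect.signature(learn_stub).bind(None, **call_kwargs)
except TypeError as err:
    print("bind failed:", err)
```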

There are no arguments for passing dataList (a tensor array) or queryList. This learn function is tailored to mnist_addition. How should we use SLASH to train on tabular data?
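One possible workaround, sketched under an unverified assumption: if learn simply iterates over dataset_loader and each batch pairs input data with the matching queries (as the mnist_addition setup might), the two lists could be wrapped in a loader-like object and passed as dataset_loader. The class `PairedBatchLoader` and the `(data_batch, query_batch)` batch layout are hypothetical, not taken from slash.py, so check the batch format learn actually consumes before relying on this. In a PyTorch project, a `torch.utils.data.Dataset` wrapped in a `DataLoader` would serve the same purpose; plain Python is used here to keep the sketch self-contained.

```python
from typing import Any, Iterator, List, Sequence, Tuple


class PairedBatchLoader:
    """Hypothetical loader yielding (data_batch, query_batch) tuples.

    Assumption: learn() iterates over dataset_loader and expects this
    layout. This is NOT confirmed by slash.py.
    """

    def __init__(self, data_list: Sequence[Any], query_list: Sequence[Any],
                 batch_size: int = 1):
        if len(data_list) != len(query_list):
            raise ValueError("data_list and query_list must have the same length")
        if batch_size < 1:
            raise ValueError("batch_size must be positive")
        self.data_list = data_list
        self.query_list = query_list
        self.batch_size = batch_size

    def __len__(self) -> int:
        # Number of batches, counting a final partial batch.
        return (len(self.data_list) + self.batch_size - 1) // self.batch_size

    def __iter__(self) -> Iterator[Tuple[List[Any], List[Any]]]:
        # Slice both sequences with the same bounds so rows stay aligned with queries.
        for start in range(0, len(self.data_list), self.batch_size):
            end = start + self.batch_size
            yield list(self.data_list[start:end]), list(self.query_list[start:end])


# Toy placeholder rows and queries; real data would come from the adult_census_data pipeline.
data = [[39, 13], [50, 13], [38, 9]]
queries = ["query_0", "query_1", "query_2"]
loader = PairedBatchLoader(data, queries, batch_size=2)

print(len(loader))  # 2
for data_batch, query_batch in loader:
    print(data_batch, query_batch)

# Hypothetical call, using parameter names from the quoted learn signature:
# SLASHobj.learn(dataset_loader=loader, epoch=1, batchSize=exp_dict['bs'],
#                p_num=exp_dict['p_num'], method='exact')
```

Keeping batching inside the loader means the call site changes only in which argument carries the data; whether batchSize should then be 1 or the real batch size depends on how learn uses it internally.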
