Skip to content

Commit c859e8e

Browse files
committed
Build Fixed, Nice way to avoid hidden files
Style Fixed Clean Up Why is config there even though I did a rebase Add data split (internal), need implementation only for field type Style Fixed need to seperate the declaration Fixed style fixes, commits also need clean up Remove template from Augmentation Fixed the duplicacy error Add mat type support (there is invalid read) Mat type gives invalid read Hmm, style fixes, commits also need clean up Typo -> invalid read -> Fixed Remove extra lines Rename cell to image Style Fixes and use mlpack URL Make augmentation case insensitive Allow multiple Resize Transform Add unknown augmentation warning Style Fixes and boundary checks Increase dataset size, adjust comments, change URL Style Fixes and boundary checks
1 parent f860d62 commit c859e8e

File tree

12 files changed

+411
-138
lines changed

12 files changed

+411
-138
lines changed

augmentation/augmentation.hpp

Lines changed: 43 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -11,53 +11,74 @@
1111
*/
1212

1313
#include <mlpack/methods/ann/layer/bilinear_interpolation.hpp>
14+
#include <mlpack/core/util/to_lower.hpp>
1415
#include <boost/regex.hpp>
1516

1617
#ifndef MODELS_AUGMENTATION_HPP
1718
#define MODELS_AUGMENTATION_HPP
1819

1920
/**
20-
* Augmentation class used to perform augmentations / transform the data.
21+
* Augmentation class used to perform augmentations by transforming the data.
2122
* For the list of supported augmentation, take a look at our wiki page.
2223
*
2324
* @code
24-
* Augmentation<> augmentation({"horizontal-flip", "resize = (224, 224)"}, 0.2);
25+
* Augmentation augmentation({"horizontal-flip", "resize = (224, 224)"}, 0.2);
2526
* augmentation.Transform(dataloader.TrainFeatures);
2627
* @endcode
27-
*
28-
* @tparam DatasetType Datatype on which augmentation will be done.
2928
*/
30-
template<typename DatasetType = arma::mat>
3129
class Augmentation
3230
{
3331
public:
3432
//! Create the augmentation class object.
35-
Augmentation();
33+
Augmentation() :
34+
augmentations(std::vector<std::string>()),
35+
augmentationProbability(0.2)
36+
{
37+
// Nothing to do here.
38+
}
3639

3740
/**
3841
* Constructor for augmentation class.
3942
*
4043
* @param augmentations List of strings containing one of the supported
41-
* augmentations.
44+
* augmentations.
4245
* @param augmentationProbability Probability of applying augmentation on
4346
* the dataset.
4447
* NOTE : This doesn't apply to augmentations
4548
* such as resize.
4649
*/
4750
Augmentation(const std::vector<std::string>& augmentations,
48-
const double augmentationProbability);
51+
const double augmentationProbability) :
52+
augmentations(augmentations),
53+
augmentationProbability(augmentationProbability)
54+
{
55+
// Convert strings to lower case.
56+
for (size_t i = 0; i < augmentations.size(); i++)
57+
mlpack::util::ToLower(augmentations[i], this->augmentations[i]);
58+
59+
// Sort the vector to place resize parameter to the front of the string.
60+
// This prevents constant lookups for resize.
61+
sort(this->augmentations.begin(), this->augmentations.end(), [](
62+
std::string& str1, std::string& str2)
63+
{
64+
return str1.find("resize") != std::string::npos;
65+
});
66+
}
4967

5068
/**
5169
* Applies augmentation to the passed dataset.
5270
*
71+
* @tparam DatasetType Datatype on which augmentation will be done.
72+
*
5373
* @param dataset Dataset on which augmentation will be applied.
5474
* @param datapointWidth Width of a single data point i.e.
55-
* Since each column represents a seperate data
56-
* point.
75+
* Since each column represents a seperate data
76+
* point.
5777
* @param datapointHeight Height of a single data point.
58-
* @param datapointDepth Depth of a single data point. For 2-dimensional
59-
* data point, set it to 1. Defaults to 1.
78+
* @param datapointDepth Depth of a single data point. For one 2-dimensional
79+
* data point, set it to 1. Defaults to 1.
6080
*/
81+
template<typename DatasetType>
6182
void Transform(DatasetType& dataset,
6283
const size_t datapointWidth,
6384
const size_t datapointHeight,
@@ -66,29 +87,28 @@ class Augmentation
6687
/**
6788
* Applies resize transform to the entire dataset.
6889
*
90+
* @tparam DatasetType Datatype on which augmentation will be done.
91+
*
6992
* @param dataset Dataset on which augmentation will be applied.
7093
* @param datapointWidth Width of a single data point i.e.
71-
* Since each column represents a seperate data
72-
* point.
94+
* Since each column represents a seperate data
95+
* point.
7396
* @param datapointHeight Height of a single data point.
74-
* @param datapointDepth Depth of a single data point. For 2-dimensional
75-
* data point, set it to 1. Defaults to 1.
97+
* @param datapointDepth Depth of a single data point. For one 2-dimensional
98+
* data point, set it to 1. Defaults to 1.
7699
* @param augmentation String containing the transform.
77100
*/
101+
template<typename DatasetType>
78102
void ResizeTransform(DatasetType& dataset,
79103
const size_t datapointWidth,
80104
const size_t datapointHeight,
81105
const size_t datapointDepth,
82106
const std::string& augmentation);
83107

84108
private:
85-
/**
86-
* Initializes augmentation map for the class.
87-
*/
88-
void InitializeAugmentationMap();
89-
90109
/**
91110
* Function to determine if augmentation has Resize function.
111+
*
92112
* @param augmentation Optional argument to check if a string has
93113
* resize substring.
94114
*/
@@ -118,11 +138,10 @@ class Augmentation
118138
if (!HasResizeParam())
119139
return;
120140

121-
122141
outWidth = 0;
123142
outHeight = 0;
124143

125-
// Use regex to find one / two numbers. If only one provided
144+
// Use regex to find one or two numbers. If only one provided
126145
// set output width equal to output height.
127146
boost::regex regex{"[0-9]+"};
128147

@@ -151,16 +170,12 @@ class Augmentation
151170
}
152171
}
153172

154-
//! Locally held augmentations / transforms that need to be applied.
173+
//! Locally held augmentations and transforms that need to be applied.
155174
std::vector<std::string> augmentations;
156175

157176
//! Locally held value of augmentation probability.
158177
double augmentationProbability;
159178

160-
//! Locally help map for mapping functions and strings.
161-
std::unordered_map<std::string, void(*)(DatasetType&,
162-
size_t, size_t, size_t, std::string&)> augmentationMap;
163-
164179
// The dataloader class should have access to internal functions of
165180
// the dataloader.
166181
template<typename DatasetX, typename DatasetY, class ScalerType>

augmentation/augmentation_impl.hpp

Lines changed: 20 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -9,74 +9,51 @@
99
* 3-clause BSD license along with mlpack. If not, see
1010
* http://www.opensource.org/licenses/BSD-3-Clause for more information.
1111
*/
12+
1213
// Incase it has not been included already.
1314
#include "augmentation.hpp"
1415

1516
#ifndef MODELS_AUGMENTATION_IMPL_HPP
1617
#define MODELS_AUGMENTATION_IMPL_HPP
1718

1819
template<typename DatasetType>
19-
Augmentation<DatasetType>::Augmentation() :
20-
augmentations(std::vector<std::string>()),
21-
augmentationProbability(0.2)
22-
{
23-
// Nothing to do here.
24-
}
25-
26-
template<typename DatasetType>
27-
Augmentation<DatasetType>::Augmentation(
28-
const std::vector<std::string>& augmentations,
29-
const double augmentationProbability) :
30-
augmentations(augmentations),
31-
augmentationProbability(augmentationProbability)
20+
void Augmentation::Transform(DatasetType& dataset,
21+
const size_t datapointWidth,
22+
const size_t datapointHeight,
23+
const size_t datapointDepth)
3224
{
33-
// Sort the vector to place resize parameter to the front of the string.
34-
// This prevents constant lookups for resize.
35-
sort(this->augmentations.begin(), this->augmentations.end(), [](
36-
std::string& str1, std::string& str2)
37-
{
38-
return str1.find("resize") != std::string::npos;
39-
});
25+
// Initialize the augmentation map.
26+
std::unordered_map<std::string, void(*)(DatasetType&,
27+
size_t, size_t, size_t, std::string&)> augmentationMap;
4028

41-
// Fill augmentation map with supported augmentations other than resize.
42-
InitializeAugmentationMap();
43-
}
44-
45-
template<typename DatasetType>
46-
void Augmentation<DatasetType>::Transform(DatasetType& dataset,
47-
const size_t datapointWidth,
48-
const size_t datapointHeight,
49-
const size_t datapointDepth)
50-
{
51-
size_t i = 0;
52-
if (this->HasResizeParam())
53-
{
54-
this->ResizeTransform(dataset, datapointWidth, datapointHeight,
55-
datapointDepth, augmentations[0]);
56-
i++;
57-
}
58-
59-
for (; i < augmentations.size(); i++)
29+
for (size_t i = 0; i < augmentations.size(); i++)
6030
{
6131
if (augmentationMap.count(augmentations[i]))
6232
{
6333
augmentationMap[augmentations[i]](dataset, datapointWidth,
6434
datapointHeight, datapointDepth, augmentations[i]);
6535
}
36+
else if (this->HasResizeParam(augmentations[i]))
37+
{
38+
this->ResizeTransform(dataset, datapointWidth, datapointHeight,
39+
datapointDepth, augmentations[i]);
40+
}
41+
else
42+
{
43+
mlpack::Log::Warn << "Unknown augmentation : \'" <<
44+
augmentations[i] << "\' not found!" << std::endl;
45+
}
6646
}
6747
}
6848

6949
template<typename DatasetType>
70-
void Augmentation<DatasetType>::ResizeTransform(
50+
void Augmentation::ResizeTransform(
7151
DatasetType& dataset,
7252
const size_t datapointWidth,
7353
const size_t datapointHeight,
7454
const size_t datapointDepth,
7555
const std::string& augmentation)
7656
{
77-
if (!this->HasResizeParam(augmentation))
78-
return;
79-
8057
size_t outputWidth = 0, outputHeight = 0;
8158

8259
// Get output width and output height.
@@ -88,16 +65,9 @@ void Augmentation<DatasetType>::ResizeTransform(
8865
datapointWidth, datapointHeight, outputWidth, outputHeight,
8966
datapointDepth);
9067

91-
// Not sure how to avoid a copy here.
9268
DatasetType output;
9369
resizeLayer.Forward(dataset, output);
9470
dataset = std::move(output);
9571
}
9672

97-
template<typename DatasetType>
98-
void Augmentation<DatasetType>::InitializeAugmentationMap()
99-
{
100-
// Fill the map here.
101-
}
102-
10373
#endif

data/PASCAL-VOC-Test.tar.gz

10.9 MB
Binary file not shown.

data/cifar-test.tar.gz

2.11 MB
Binary file not shown.

0 commit comments

Comments
 (0)