1111 */
1212
1313#include < mlpack/methods/ann/layer/bilinear_interpolation.hpp>
14+ #include < mlpack/core/util/to_lower.hpp>
1415#include < boost/regex.hpp>
1516
1617#ifndef MODELS_AUGMENTATION_HPP
1718#define MODELS_AUGMENTATION_HPP
1819
1920/* *
20- * Augmentation class used to perform augmentations / transform the data.
21+ * Augmentation class used to perform augmentations by transforming the data.
2122 * For the list of supported augmentation, take a look at our wiki page.
2223 *
2324 * @code
24- * Augmentation<> augmentation({"horizontal-flip", "resize = (224, 224)"}, 0.2);
25+ * Augmentation augmentation({"horizontal-flip", "resize = (224, 224)"}, 0.2);
2526 * augmentation.Transform(dataloader.TrainFeatures);
2627 * @endcode
27- *
28- * @tparam DatasetType Datatype on which augmentation will be done.
2928 */
30- template <typename DatasetType = arma::mat>
3129class Augmentation
3230{
3331 public:
3432 // ! Create the augmentation class object.
35- Augmentation ();
33+ Augmentation () :
34+ augmentations (std::vector<std::string>()),
35+ augmentationProbability (0.2 )
36+ {
37+ // Nothing to do here.
38+ }
3639
3740 /* *
3841 * Constructor for augmentation class.
3942 *
4043 * @param augmentations List of strings containing one of the supported
41- * augmentations.
44+ * augmentations.
4245 * @param augmentationProbability Probability of applying augmentation on
4346 * the dataset.
4447 * NOTE : This doesn't apply to augmentations
4548 * such as resize.
4649 */
4750 Augmentation (const std::vector<std::string>& augmentations,
48- const double augmentationProbability);
51+ const double augmentationProbability) :
52+ augmentations (augmentations),
53+ augmentationProbability (augmentationProbability)
54+ {
55+ // Convert strings to lower case.
56+ for (size_t i = 0 ; i < augmentations.size (); i++)
57+ mlpack::util::ToLower (augmentations[i], this ->augmentations [i]);
58+
59+ // Sort the vector to place resize parameter to the front of the string.
60+ // This prevents constant lookups for resize.
61+ sort (this ->augmentations .begin (), this ->augmentations .end (), [](
62+ std::string& str1, std::string& str2)
63+ {
64+ return str1.find (" resize" ) != std::string::npos;
65+ });
66+ }
4967
5068 /* *
5169 * Applies augmentation to the passed dataset.
5270 *
71+ * @tparam DatasetType Datatype on which augmentation will be done.
72+ *
5373 * @param dataset Dataset on which augmentation will be applied.
5474 * @param datapointWidth Width of a single data point i.e.
55- * Since each column represents a seperate data
56- * point.
75+ * Since each column represents a seperate data
76+ * point.
5777 * @param datapointHeight Height of a single data point.
58- * @param datapointDepth Depth of a single data point. For 2-dimensional
59- * data point, set it to 1. Defaults to 1.
78+ * @param datapointDepth Depth of a single data point. For one 2-dimensional
79+ * data point, set it to 1. Defaults to 1.
6080 */
81+ template <typename DatasetType>
6182 void Transform (DatasetType& dataset,
6283 const size_t datapointWidth,
6384 const size_t datapointHeight,
@@ -66,29 +87,28 @@ class Augmentation
6687 /* *
6788 * Applies resize transform to the entire dataset.
6889 *
90+ * @tparam DatasetType Datatype on which augmentation will be done.
91+ *
6992 * @param dataset Dataset on which augmentation will be applied.
7093 * @param datapointWidth Width of a single data point i.e.
71- * Since each column represents a seperate data
72- * point.
94+ * Since each column represents a seperate data
95+ * point.
7396 * @param datapointHeight Height of a single data point.
74- * @param datapointDepth Depth of a single data point. For 2-dimensional
75- * data point, set it to 1. Defaults to 1.
97+ * @param datapointDepth Depth of a single data point. For one 2-dimensional
98+ * data point, set it to 1. Defaults to 1.
7699 * @param augmentation String containing the transform.
77100 */
101+ template <typename DatasetType>
78102 void ResizeTransform (DatasetType& dataset,
79103 const size_t datapointWidth,
80104 const size_t datapointHeight,
81105 const size_t datapointDepth,
82106 const std::string& augmentation);
83107
84108 private:
85- /* *
86- * Initializes augmentation map for the class.
87- */
88- void InitializeAugmentationMap ();
89-
90109 /* *
91110 * Function to determine if augmentation has Resize function.
111+ *
92112 * @param augmentation Optional argument to check if a string has
93113 * resize substring.
94114 */
@@ -118,11 +138,10 @@ class Augmentation
118138 if (!HasResizeParam ())
119139 return ;
120140
121-
122141 outWidth = 0 ;
123142 outHeight = 0 ;
124143
125- // Use regex to find one / two numbers. If only one provided
144+ // Use regex to find one or two numbers. If only one provided
126145 // set output width equal to output height.
127146 boost::regex regex{" [0-9]+" };
128147
@@ -151,16 +170,12 @@ class Augmentation
151170 }
152171 }
153172
154- // ! Locally held augmentations / transforms that need to be applied.
173+ // ! Locally held augmentations and transforms that need to be applied.
155174 std::vector<std::string> augmentations;
156175
157176 // ! Locally held value of augmentation probability.
158177 double augmentationProbability;
159178
160- // ! Locally help map for mapping functions and strings.
161- std::unordered_map<std::string, void (*)(DatasetType&,
162- size_t , size_t , size_t , std::string&)> augmentationMap;
163-
164179 // The dataloader class should have access to internal functions of
165180 // the dataloader.
166181 template <typename DatasetX, typename DatasetY, class ScalerType >
0 commit comments