Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 44 additions & 1 deletion augmentation/augmentation.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

#include <mlpack/methods/ann/layer/bilinear_interpolation.hpp>
#include <mlpack/core/util/to_lower.hpp>
#include <mlpack/core/data/split_data.hpp>
#include <boost/regex.hpp>

#ifndef MODELS_AUGMENTATION_HPP
Expand Down Expand Up @@ -105,7 +106,49 @@ class Augmentation
const size_t datapointDepth,
const std::string& augmentation);

private:
/**
* Applies horizontal flip transform to the splited dataset.
*
* @tparam DatasetType Datatype on which augmentation will be done.
*
* @param dataset Dataset on which augmentation will be applied.
* @param datapointWidth Width of a single data point i.e.
* Since each column represents a seperate data
* point.
* @param datapointHeight Height of a single data point.
* @param datapointDepth Depth of a single data point. For one 2-dimensional
* data point, set it to 1. Defaults to 1.
* @param augmentation String containing the transform.
*/
template<typename DatasetType>
void HorizontalFlipTransform(DatasetType& dataset,
const size_t datapointWidth,
const size_t datapointHeight,
const size_t datapointDepth,
const std::string& augmentation);

/**
* Applies verticle flip transform to the splited dataset.
*
* @tparam DatasetType Datatype on which augmentation will be done.
*
* @param dataset Dataset on which augmentation will be applied.
* @param datapointWidth Width of a single data point i.e.
* Since each column represents a seperate data
* point.
* @param datapointHeight Height of a single data point.
* @param datapointDepth Depth of a single data point. For one 2-dimensional
* data point, set it to 1. Defaults to 1.
* @param augmentation String containing the transform.
*/
template<typename DatasetType>
void VerticalFlipTransform(DatasetType& dataset,
const size_t datapointWidth,
const size_t datapointHeight,
const size_t datapointDepth,
const std::string& augmentation);

private:
/**
* Function to determine if augmentation has Resize function.
*
Expand Down
36 changes: 35 additions & 1 deletion augmentation/augmentation_impl.hpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
/**
* @file augmentation_impl.hpp
* @author Kartik Dutt
* @author Kartik Dutt, Ritu Raj Singh
*
* Implementation of Augmentation class for augmenting data.
*
Expand Down Expand Up @@ -70,4 +70,38 @@ void Augmentation::ResizeTransform(
dataset = std::move(output);
}

template<typename DatasetType>
void Augmentation::HorizontalFlipTransform(
DatasetType& dataset,
const size_t datapointWidth,
const size_t datapointHeight,
const size_t datapointDepth,
const std::string& augmentation)
{
// We will use mlpack's split to split the dataset.
auto splitResult = mlpack::data::Split(dataset, augmentationProbability);
// We will use arma's fliplr to flip the columns.
std::get<1>(splitResult) = (arma::fliplr(std::get<1>(splitResult)));
dataset = arma::join_rows( std::get<0>(splitResult), std::get<1>(splitResult) );
dataset = std::move(dataset);

}

template<typename DatasetType>
void Augmentation::VerticalFlipTransform(
DatasetType& dataset,
const size_t datapointWidth,
const size_t datapointHeight,
const size_t datapointDepth,
const std::string& augmentation)
{
// We will use mlpack's split to split the dataset.
auto splitResult = mlpack::data::Split(dataset, augmentationProbability);
// We will use arma's flipud to flip the rows.
std::get<1>(splitResult) = (arma::flipud(std::get<1>(splitResult)));
dataset = arma::join_rows( std::get<0>(splitResult), std::get<1>(splitResult) );
dataset = std::move(dataset);

}

#endif