这是论文Unsupervised Domain Adaptation by Backpropagation的复现代码,完成了MNIST与MNIST-M数据集之间的迁移训练
- tensorflow=2.4.0
- opencv
- numpy
- pickle
- skimage
checkpoints存放训练过程中模型权重;logs存放模型训练过程中相关日志文件;config存放参数配置类脚本及训练过程中参数配置文件;model存放网络模型定义脚本;model_data存放包括但不限于数据集、预训练模型等文件;utils存放包括但不限于数据集和模型训练相关工具类和工具脚本;image存放tensorboard可视化截图;create_mnistm.py是根据MNIST数据集生成MNIST-M数据集的脚本;train_MNIST2MNIST_M.py是利用MNIST和MNIST-M数据集进行DANN自适应模型训练的脚本;
首先下载BSDS500数据集 ,放在model_data/dataset路径下。其下载路径如下:
- 官网:BSDS500数据集
- github:BSDS500数据集
然后执行python create_mnistm.py生成MNIST-M数据集,根据自己需要修改create_mnistm.py中BST_PATH、mnist_dir和mnistm _dir ,默认路径如下:
BST_PATH = os.path.abspath('./model_data/dataset/BSR_bsds500.tgz')
mnist_dir = os.path.abspath("model_data/dataset/MNIST")
mnistm_dir = os.path.abspath("model_data/dataset/MNIST_M")最后运行如下命令进行MNIST和MNIST-M数据集之间的自适应模型训练,根据自己的需要进行修改相关超参数,例如init_learning_rate、momentum_rate、batch_size、epoch、pre_model_path、source_dataset_path和target_dataset_path。
python train_MNIST2MNIST_M.py下面主要包括了MNIST和MNIST-M数据集在自适应训练过程中学习率、梯度反转层参数
首先是超参数学习率和梯度反转层参数
接着是训练数据集和验证数据集的图像分类精度和域分类精度在训练过程中的数据可视化,其中蓝色代表训练集,红色代表验证集。训练精度是在源域数据集即MNIST数据集上的统计结果,验证精度是在目标域数据集即MNIST-M数据集上的统计结果。
最后是训练数据集和验证数据集的图像分类损失和域分类损失在训练过程中的数据可视化,其中蓝色代表训练集,红色代表验证集。
CSDN博客链接:
知乎专栏链接:


