diff --git a/LICENSE b/LICENSE deleted file mode 100644 index e04b480..0000000 --- a/LICENSE +++ /dev/null @@ -1,427 +0,0 @@ -Attribution-ShareAlike 4.0 International - -======================================================================= - -Creative Commons Corporation ("Creative Commons") is not a law firm and -does not provide legal services or legal advice. Distribution of -Creative Commons public licenses does not create a lawyer-client or -other relationship. Creative Commons makes its licenses and related -information available on an "as-is" basis. Creative Commons gives no -warranties regarding its licenses, any material licensed under their -terms and conditions, or any related information. Creative Commons -disclaims all liability for damages resulting from their use to the -fullest extent possible. - -Using Creative Commons Public Licenses - -Creative Commons public licenses provide a standard set of terms and -conditions that creators and other rights holders may use to share -original works of authorship and other material subject to copyright -and certain other rights specified in the public license below. The -following considerations are for informational purposes only, are not -exhaustive, and do not form part of our licenses. - - Considerations for licensors: Our public licenses are - intended for use by those authorized to give the public - permission to use material in ways otherwise restricted by - copyright and certain other rights. Our licenses are - irrevocable. Licensors should read and understand the terms - and conditions of the license they choose before applying it. - Licensors should also secure all rights necessary before - applying our licenses so that the public can reuse the - material as expected. Licensors should clearly mark any - material not subject to the license. This includes other CC- - licensed material, or material used under an exception or - limitation to copyright. 
More considerations for licensors: - wiki.creativecommons.org/Considerations_for_licensors - - Considerations for the public: By using one of our public - licenses, a licensor grants the public permission to use the - licensed material under specified terms and conditions. If - the licensor's permission is not necessary for any reason--for - example, because of any applicable exception or limitation to - copyright--then that use is not regulated by the license. Our - licenses grant only permissions under copyright and certain - other rights that a licensor has authority to grant. Use of - the licensed material may still be restricted for other - reasons, including because others have copyright or other - rights in the material. A licensor may make special requests, - such as asking that all changes be marked or described. - Although not required by our licenses, you are encouraged to - respect those requests where reasonable. More considerations - for the public: - wiki.creativecommons.org/Considerations_for_licensees - -======================================================================= - -Creative Commons Attribution-ShareAlike 4.0 International Public -License - -By exercising the Licensed Rights (defined below), You accept and agree -to be bound by the terms and conditions of this Creative Commons -Attribution-ShareAlike 4.0 International Public License ("Public -License"). To the extent this Public License may be interpreted as a -contract, You are granted the Licensed Rights in consideration of Your -acceptance of these terms and conditions, and the Licensor grants You -such rights in consideration of benefits the Licensor receives from -making the Licensed Material available under these terms and -conditions. - - -Section 1 -- Definitions. - - a. 
Adapted Material means material subject to Copyright and Similar - Rights that is derived from or based upon the Licensed Material - and in which the Licensed Material is translated, altered, - arranged, transformed, or otherwise modified in a manner requiring - permission under the Copyright and Similar Rights held by the - Licensor. For purposes of this Public License, where the Licensed - Material is a musical work, performance, or sound recording, - Adapted Material is always produced where the Licensed Material is - synched in timed relation with a moving image. - - b. Adapter's License means the license You apply to Your Copyright - and Similar Rights in Your contributions to Adapted Material in - accordance with the terms and conditions of this Public License. - - c. BY-SA Compatible License means a license listed at - creativecommons.org/compatiblelicenses, approved by Creative - Commons as essentially the equivalent of this Public License. - - d. Copyright and Similar Rights means copyright and/or similar rights - closely related to copyright including, without limitation, - performance, broadcast, sound recording, and Sui Generis Database - Rights, without regard to how the rights are labeled or - categorized. For purposes of this Public License, the rights - specified in Section 2(b)(1)-(2) are not Copyright and Similar - Rights. - - e. Effective Technological Measures means those measures that, in the - absence of proper authority, may not be circumvented under laws - fulfilling obligations under Article 11 of the WIPO Copyright - Treaty adopted on December 20, 1996, and/or similar international - agreements. - - f. Exceptions and Limitations means fair use, fair dealing, and/or - any other exception or limitation to Copyright and Similar Rights - that applies to Your use of the Licensed Material. - - g. License Elements means the license attributes listed in the name - of a Creative Commons Public License. 
The License Elements of this - Public License are Attribution and ShareAlike. - - h. Licensed Material means the artistic or literary work, database, - or other material to which the Licensor applied this Public - License. - - i. Licensed Rights means the rights granted to You subject to the - terms and conditions of this Public License, which are limited to - all Copyright and Similar Rights that apply to Your use of the - Licensed Material and that the Licensor has authority to license. - - j. Licensor means the individual(s) or entity(ies) granting rights - under this Public License. - - k. Share means to provide material to the public by any means or - process that requires permission under the Licensed Rights, such - as reproduction, public display, public performance, distribution, - dissemination, communication, or importation, and to make material - available to the public including in ways that members of the - public may access the material from a place and at a time - individually chosen by them. - - l. Sui Generis Database Rights means rights other than copyright - resulting from Directive 96/9/EC of the European Parliament and of - the Council of 11 March 1996 on the legal protection of databases, - as amended and/or succeeded, as well as other essentially - equivalent rights anywhere in the world. - - m. You means the individual or entity exercising the Licensed Rights - under this Public License. Your has a corresponding meaning. - - -Section 2 -- Scope. - - a. License grant. - - 1. Subject to the terms and conditions of this Public License, - the Licensor hereby grants You a worldwide, royalty-free, - non-sublicensable, non-exclusive, irrevocable license to - exercise the Licensed Rights in the Licensed Material to: - - a. reproduce and Share the Licensed Material, in whole or - in part; and - - b. produce, reproduce, and Share Adapted Material. - - 2. Exceptions and Limitations. 
For the avoidance of doubt, where - Exceptions and Limitations apply to Your use, this Public - License does not apply, and You do not need to comply with - its terms and conditions. - - 3. Term. The term of this Public License is specified in Section - 6(a). - - 4. Media and formats; technical modifications allowed. The - Licensor authorizes You to exercise the Licensed Rights in - all media and formats whether now known or hereafter created, - and to make technical modifications necessary to do so. The - Licensor waives and/or agrees not to assert any right or - authority to forbid You from making technical modifications - necessary to exercise the Licensed Rights, including - technical modifications necessary to circumvent Effective - Technological Measures. For purposes of this Public License, - simply making modifications authorized by this Section 2(a) - (4) never produces Adapted Material. - - 5. Downstream recipients. - - a. Offer from the Licensor -- Licensed Material. Every - recipient of the Licensed Material automatically - receives an offer from the Licensor to exercise the - Licensed Rights under the terms and conditions of this - Public License. - - b. Additional offer from the Licensor -- Adapted Material. - Every recipient of Adapted Material from You - automatically receives an offer from the Licensor to - exercise the Licensed Rights in the Adapted Material - under the conditions of the Adapter's License You apply. - - c. No downstream restrictions. You may not offer or impose - any additional or different terms or conditions on, or - apply any Effective Technological Measures to, the - Licensed Material if doing so restricts exercise of the - Licensed Rights by any recipient of the Licensed - Material. - - 6. No endorsement. 
Nothing in this Public License constitutes or - may be construed as permission to assert or imply that You - are, or that Your use of the Licensed Material is, connected - with, or sponsored, endorsed, or granted official status by, - the Licensor or others designated to receive attribution as - provided in Section 3(a)(1)(A)(i). - - b. Other rights. - - 1. Moral rights, such as the right of integrity, are not - licensed under this Public License, nor are publicity, - privacy, and/or other similar personality rights; however, to - the extent possible, the Licensor waives and/or agrees not to - assert any such rights held by the Licensor to the limited - extent necessary to allow You to exercise the Licensed - Rights, but not otherwise. - - 2. Patent and trademark rights are not licensed under this - Public License. - - 3. To the extent possible, the Licensor waives any right to - collect royalties from You for the exercise of the Licensed - Rights, whether directly or through a collecting society - under any voluntary or waivable statutory or compulsory - licensing scheme. In all other cases the Licensor expressly - reserves any right to collect such royalties. - - -Section 3 -- License Conditions. - -Your exercise of the Licensed Rights is expressly made subject to the -following conditions. - - a. Attribution. - - 1. If You Share the Licensed Material (including in modified - form), You must: - - a. retain the following if it is supplied by the Licensor - with the Licensed Material: - - i. identification of the creator(s) of the Licensed - Material and any others designated to receive - attribution, in any reasonable manner requested by - the Licensor (including by pseudonym if - designated); - - ii. a copyright notice; - - iii. a notice that refers to this Public License; - - iv. a notice that refers to the disclaimer of - warranties; - - v. a URI or hyperlink to the Licensed Material to the - extent reasonably practicable; - - b. 
indicate if You modified the Licensed Material and - retain an indication of any previous modifications; and - - c. indicate the Licensed Material is licensed under this - Public License, and include the text of, or the URI or - hyperlink to, this Public License. - - 2. You may satisfy the conditions in Section 3(a)(1) in any - reasonable manner based on the medium, means, and context in - which You Share the Licensed Material. For example, it may be - reasonable to satisfy the conditions by providing a URI or - hyperlink to a resource that includes the required - information. - - 3. If requested by the Licensor, You must remove any of the - information required by Section 3(a)(1)(A) to the extent - reasonably practicable. - - b. ShareAlike. - - In addition to the conditions in Section 3(a), if You Share - Adapted Material You produce, the following conditions also apply. - - 1. The Adapter's License You apply must be a Creative Commons - license with the same License Elements, this version or - later, or a BY-SA Compatible License. - - 2. You must include the text of, or the URI or hyperlink to, the - Adapter's License You apply. You may satisfy this condition - in any reasonable manner based on the medium, means, and - context in which You Share Adapted Material. - - 3. You may not offer or impose any additional or different terms - or conditions on, or apply any Effective Technological - Measures to, Adapted Material that restrict exercise of the - rights granted under the Adapter's License You apply. - - -Section 4 -- Sui Generis Database Rights. - -Where the Licensed Rights include Sui Generis Database Rights that -apply to Your use of the Licensed Material: - - a. for the avoidance of doubt, Section 2(a)(1) grants You the right - to extract, reuse, reproduce, and Share all or a substantial - portion of the contents of the database; - - b. 
if You include all or a substantial portion of the database - contents in a database in which You have Sui Generis Database - Rights, then the database in which You have Sui Generis Database - Rights (but not its individual contents) is Adapted Material, - - including for purposes of Section 3(b); and - c. You must comply with the conditions in Section 3(a) if You Share - all or a substantial portion of the contents of the database. - -For the avoidance of doubt, this Section 4 supplements and does not -replace Your obligations under this Public License where the Licensed -Rights include other Copyright and Similar Rights. - - -Section 5 -- Disclaimer of Warranties and Limitation of Liability. - - a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE - EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS - AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF - ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS, - IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION, - WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR - PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS, - ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT - KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT - ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU. - - b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE - TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION, - NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT, - INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES, - COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR - USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN - ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR - DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR - IN PART, THIS LIMITATION MAY NOT APPLY TO YOU. - - c. 
The disclaimer of warranties and limitation of liability provided - above shall be interpreted in a manner that, to the extent - possible, most closely approximates an absolute disclaimer and - waiver of all liability. - - -Section 6 -- Term and Termination. - - a. This Public License applies for the term of the Copyright and - Similar Rights licensed here. However, if You fail to comply with - this Public License, then Your rights under this Public License - terminate automatically. - - b. Where Your right to use the Licensed Material has terminated under - Section 6(a), it reinstates: - - 1. automatically as of the date the violation is cured, provided - it is cured within 30 days of Your discovery of the - violation; or - - 2. upon express reinstatement by the Licensor. - - For the avoidance of doubt, this Section 6(b) does not affect any - right the Licensor may have to seek remedies for Your violations - of this Public License. - - c. For the avoidance of doubt, the Licensor may also offer the - Licensed Material under separate terms or conditions or stop - distributing the Licensed Material at any time; however, doing so - will not terminate this Public License. - - d. Sections 1, 5, 6, 7, and 8 survive termination of this Public - License. - - -Section 7 -- Other Terms and Conditions. - - a. The Licensor shall not be bound by any additional or different - terms or conditions communicated by You unless expressly agreed. - - b. Any arrangements, understandings, or agreements regarding the - Licensed Material not stated herein are separate from and - independent of the terms and conditions of this Public License. - - -Section 8 -- Interpretation. - - a. For the avoidance of doubt, this Public License does not, and - shall not be interpreted to, reduce, limit, restrict, or impose - conditions on any use of the Licensed Material that could lawfully - be made without permission under this Public License. - - b. 
To the extent possible, if any provision of this Public License is - deemed unenforceable, it shall be automatically reformed to the - minimum extent necessary to make it enforceable. If the provision - cannot be reformed, it shall be severed from this Public License - without affecting the enforceability of the remaining terms and - conditions. - - c. No term or condition of this Public License will be waived and no - failure to comply consented to unless expressly agreed to by the - Licensor. - - d. Nothing in this Public License constitutes or may be interpreted - as a limitation upon, or waiver of, any privileges and immunities - that apply to the Licensor or You, including from the legal - processes of any jurisdiction or authority. - - -======================================================================= - -Creative Commons is not a party to its public -licenses. Notwithstanding, Creative Commons may elect to apply one of -its public licenses to material it publishes and in those instances -will be considered the “Licensor.” The text of the Creative Commons -public licenses is dedicated to the public domain under the CC0 Public -Domain Dedication. Except for the limited purpose of indicating that -material is shared under a Creative Commons public license or as -otherwise permitted by the Creative Commons policies published at -creativecommons.org/policies, Creative Commons does not authorize the -use of the trademark "Creative Commons" or any other trademark or logo -of Creative Commons without its prior written consent including, -without limitation, in connection with any unauthorized modifications -to any of its public licenses or any other arrangements, -understandings, or agreements concerning use of licensed material. For -the avoidance of doubt, this paragraph does not form part of the -public licenses. - -Creative Commons may be contacted at creativecommons.org.
diff --git a/README.md b/README.md index f000432..8c16f70 100644 --- a/README.md +++ b/README.md @@ -1,133 +1,98 @@ -# Kuzushiji-MNIST +# Kuzushiji-MNIST Classification -[![License: CC BY-SA 4.0](https://img.shields.io/badge/License-CC%20BY--SA%204.0-blue.svg)](https://creativecommons.org/licenses/by-sa/4.0/) -📚 [Read the paper](https://arxiv.org/abs/1812.01718) to learn more about Kuzushiji, the datasets and our motivations for making them! +This repository implements classification for [Kuzushiji-MNIST](https://github.com/rois-codh/kmnist) in PyTorch with ResNet and ResMLP models. -## News and Updates -**IMPORTANT:** If you downloaded the KMNIST or K49 dataset before **5 February 2019**, please re-download the dataset and run your code again. We fixed minor image processing bugs and released an updated version, we find that the updated version gives slightly better performance. Thanks to [#1](https://github.com/rois-codh/kmnist/issues/1) and [#5](https://github.com/rois-codh/kmnist/issues/5) for bringing this to our attention. -## The Dataset -**Kuzushiji-MNIST** is a drop-in replacement for the MNIST dataset (28x28 grayscale, 70,000 images), provided in the original MNIST format as well as a NumPy format. Since MNIST restricts us to 10 classes, we chose one character to represent each of the 10 rows of Hiragana when creating Kuzushiji-MNIST. +## Download the dataset -**Kuzushiji-49**, as the name suggests, has 49 classes (28x28 grayscale, 270,912 images), is a much larger, but imbalanced dataset containing 48 Hiragana characters and one Hiragana iteration mark. +(1) Change into the project folder in a terminal. -**Kuzushiji-Kanji** is an imbalanced dataset of total 3832 Kanji characters (64x64 grayscale, 140,426 images), ranging from 1,766 examples to only a single example per class. +(2) Run -

- - The 10 classes of Kuzushiji-MNIST, with the first column showing each character's modern hiragana counterpart. -

+```shell +wget http://codh.rois.ac.jp/kmnist/dataset/kmnist/kmnist-train-imgs.npz +wget http://codh.rois.ac.jp/kmnist/dataset/kmnist/kmnist-train-labels.npz +wget http://codh.rois.ac.jp/kmnist/dataset/kmnist/kmnist-test-imgs.npz +wget http://codh.rois.ac.jp/kmnist/dataset/kmnist/kmnist-test-labels.npz +``` -## Get the data 💾 -🌟 You can run [`python download_data.py`](download_data.py) to interactively select and download any of these datasets! -### Kuzushiji-MNIST +## Visualization and Unsupervised Models -Kuzushiji-MNIST contains 70,000 28x28 grayscale images spanning 10 classes (one from each column of [hiragana](https://upload.wikimedia.org/wikipedia/commons/thumb/2/28/Table_hiragana.svg/768px-Table_hiragana.svg.png)), and is perfectly balanced like the original MNIST dataset (6k/1k train/test for each class). -| File | Examples | Download (MNIST format) | Download (NumPy format) | -|-----------------|--------------------|----------------------------|------------------------------| -| Training images | 60,000 | [train-images-idx3-ubyte.gz](http://codh.rois.ac.jp/kmnist/dataset/kmnist/train-images-idx3-ubyte.gz) (18MB) | [kmnist-train-imgs.npz](http://codh.rois.ac.jp/kmnist/dataset/kmnist/kmnist-train-imgs.npz) (18MB) | -| Training labels | 60,000 | [train-labels-idx1-ubyte.gz](http://codh.rois.ac.jp/kmnist/dataset/kmnist/train-labels-idx1-ubyte.gz) (30KB) | [kmnist-train-labels.npz](http://codh.rois.ac.jp/kmnist/dataset/kmnist/kmnist-train-labels.npz) (30KB) | -| Testing images | 10,000 | [t10k-images-idx3-ubyte.gz](http://codh.rois.ac.jp/kmnist/dataset/kmnist/t10k-images-idx3-ubyte.gz) (3MB) | [kmnist-test-imgs.npz](http://codh.rois.ac.jp/kmnist/dataset/kmnist/kmnist-test-imgs.npz) (3MB) | -| Testing labels | 10,000 | [t10k-labels-idx1-ubyte.gz](http://codh.rois.ac.jp/kmnist/dataset/kmnist/t10k-labels-idx1-ubyte.gz) (5KB) | [kmnist-test-labels.npz](http://codh.rois.ac.jp/kmnist/dataset/kmnist/kmnist-test-labels.npz) (5KB) | +**You will need to install these
packages to run the specific functions.** -Mapping from class indices to characters: [kmnist_classmap.csv](http://codh.rois.ac.jp/kmnist/dataset/kmnist/kmnist_classmap.csv) (1KB) +For basic functions, such as loading data and displaying images: -We recommend using standard top-1 accuracy on the test set for evaluating on Kuzushiji-MNIST. +``` +pip install numpy +pip install matplotlib +pip install seaborn +pip install pandas +``` -##### Which format do I download? -If you're looking for a drop-in replacement for the MNIST or Fashion-MNIST dataset (for tools that currently work with these datasets), download the data in MNIST format. +For Unsupervised Model, PCA and Evaluation: -Otherwise, it's recommended to download in NumPy format, which can be loaded into an array as easy as: -`arr = np.load(filename)['arr_0']`. +``` +pip install -U scikit-learn +``` -### Kuzushiji-49 +**Run ./unsupervised.ipynb** -Kuzushiji-49 contains 270,912 images spanning 49 classes, and is an extension of the Kuzushiji-MNIST dataset. -| File | Examples | Download (NumPy format) | -|-----------------|--------------------|----------------------------| -| Training images | 232,365 | [k49-train-imgs.npz](http://codh.rois.ac.jp/kmnist/dataset/k49/k49-train-imgs.npz) (63MB) | -| Training labels | 232,365 | [k49-train-labels.npz](http://codh.rois.ac.jp/kmnist/dataset/k49/k49-train-labels.npz) (200KB) | -| Testing images | 38,547 | [k49-test-imgs.npz](http://codh.rois.ac.jp/kmnist/dataset/k49/k49-test-imgs.npz) (11MB) | -| Testing labels | 38,547 | [k49-test-labels.npz](http://codh.rois.ac.jp/kmnist/dataset/k49/k49-test-labels.npz) (50KB) | +## MLP/CNN Model -Mapping from class indices to characters: [k49_classmap.csv](http://codh.rois.ac.jp/kmnist/dataset/k49/k49_classmap.csv) (1KB) +**You will need GPU and PyTorch packages to run the following code.** -We recommend using **balanced accuracy** on the test set for evaluating on Kuzushiji-49. 
-We use the following implementation of balanced accuracy: -```python -p_test = # Model predictions of class index -y_test = # Ground truth class indices +**Results will be stored in `./log`** + +**Model parameters will be stored in `./models`** -accs = [] -for cls in range(49): - mask = (y_test == cls) - cls_acc = (p_test == cls)[mask].mean() # Accuracy for rows of class cls - accs.append(cls_acc) - -accs = np.mean(accs) # Final balanced accuracy -``` -### Kuzushiji-Kanji -Kuzushiji-Kanji is a large and highly imbalanced 64x64 dataset of 3832 Kanji characters, containing 140,426 images of both common and rare characters. +Available model names: -The full dataset is available for download [here](http://codh.rois.ac.jp/kmnist/dataset/kkanji/kkanji.tar) (310MB). -We plan to release a train/test split version as a low-shot learning dataset very soon. +``` +ResMLP-12, ResMLP-24, ResNet-18, ResNet-34 +``` -![Examples of Kuzushiji-Kanji classes](images/kkanji_examples.png) +To train and test the model: -## Benchmarks & Results 📈 +```shell +python classification.py --model [model name] --gpu [GPU No.] --train 1 --test 1 --train_batch [train batch size] --test_batch [test batch size] --epoch [number of training epochs] +``` -Have more results to add to the table? Feel free to submit an [issue](https://github.com/rois-codh/kmnist/issues/new) or [pull request](https://github.com/rois-codh/kmnist/compare)!
+Only to test the model -|Model | MNIST | Kuzushiji-MNIST | Kuzushiji-49 | Credit -|---------------------------------|-------|--------|-----|---| -|[4-Nearest Neighbour Baseline](benchmarks/kuzushiji_mnist_knn.py) |97.14% | 92.10% | 83.65% | -|[PCA + 4-kNN](https://github.com/rois-codh/kmnist/issues/10) | 97.76% | 93.98% | 86.80% | [dzisandy](https://github.com/dzisandy) -|[Tuned SVM (RBF kernel)](https://github.com/rois-codh/kmnist/issues/3) | 98.57% | 92.82%\* | 85.61%\* | [TomZephire](https://github.com/TomZephire) -|[Keras Simple CNN Benchmark](benchmarks/kuzushiji_mnist_cnn.py) |99.06% | 94.63% | 89.36% | -|PreActResNet-18 |99.56% | 97.82%\* |96.64%\*| -|PreActResNet-18 + Input Mixup |99.54% | 98.41%\* |97.04%\*| -|PreActResNet-18 + Manifold Mixup |99.54% | 98.83%\* | 97.33%\* | -|[ResNet18 + VGG Ensemble](https://github.com/ranihorev/Kuzushiji_MNIST) | 99.60% | 98.90%\* | | [Rani Horev](https://twitter.com/HorevRani) -|[DenseNet-100 (k=12)](https://github.com/kurapan/pytorch_image_classification) | | | 97.32% | [Jan Zdenek](https://github.com/kurapan) -|[Shake-Shake-26 2x96d (cutout 14)](https://github.com/kurapan/pytorch_image_classification) | | | **98.29%** | [Jan Zdenek](https://github.com/kurapan) -|[shake-shake-26 2x96d (S-S-I), Cutout 14](https://github.com/hysts/pytorch_image_classification#results-on-kuzushiji-mnist) | **99.76%** | **99.34%\*** | | [hysts](https://github.com/hysts) +```shell +python classification.py --model [model name] --gpu [GPU No.] --test_batch [test batch size] +``` -_\* These results were obtained using an old version of the dataset, which gave slightly lower performance numbers_ -For MNIST and Kuzushiji-MNIST we use a standard accuracy metric, while Kuzushiji-49 is evaluated using balanced accuracy (so that all classes have equal weight). 
-## Citing Kuzushiji-MNIST +For ResMLP-12 -If you use any of the Kuzushiji datasets in your work, we would appreciate a reference to our paper: +```shell +python classification.py --model ResMLP-12 --gpu 0 --train 1 --test 1 --train_batch 64 --test_batch 500 --epoch 30 +``` -**Deep Learning for Classical Japanese Literature. Tarin Clanuwat et al. [arXiv:1812.01718](https://arxiv.org/abs/1812.01718)** +For ResMLP-24 -```latex -@online{clanuwat2018deep, - author = {Tarin Clanuwat and Mikel Bober-Irizar and Asanobu Kitamoto and Alex Lamb and Kazuaki Yamamoto and David Ha}, - title = {Deep Learning for Classical Japanese Literature}, - date = {2018-12-03}, - year = {2018}, - eprintclass = {cs.CV}, - eprinttype = {arXiv}, - eprint = {cs.CV/1812.01718}, -} +```shell +python classification.py --model ResMLP-24 --gpu 0 --train 1 --test 1 --train_batch 64 --test_batch 500 --epoch 30 ``` -## License +For ResNet-18 -Both the dataset itself and the contents of this repo are licensed under a permissive [CC BY-SA 4.0](https://creativecommons.org/licenses/by-sa/4.0/) license, except where specified within some benchmark scripts. CC BY-SA 4.0 license requires attribution, and we would suggest to use the following attribution to the KMNIST dataset. +```shell +python classification.py --model ResNet-18 --gpu 0 --train 1 --test 1 --train_batch 64 --test_batch 500 --epoch 30 +``` -"KMNIST Dataset" (created by CODH), adapted from "Kuzushiji Dataset" -(created by NIJL and others), doi:10.20676/00000341 +For ResNet-34 -## Related datasets +```shell +python classification.py --model ResNet-34 --gpu 0 --train 1 --test 1 --train_batch 64 --test_batch 500 --epoch 30 +``` -Kuzushiji Dataset http://codh.rois.ac.jp/char-shape/ offers 4,328 character types and 1,086,326 character images (November 2019) with CSV files containing the bounding box of characters on the original page images. 
At this moment, the description of the dataset is available only in Japanese, but the English version will be available soon. diff --git a/benchmarks/kuzushiji_mnist_cnn.py b/benchmarks/kuzushiji_mnist_cnn.py deleted file mode 100644 index 485d8bb..0000000 --- a/benchmarks/kuzushiji_mnist_cnn.py +++ /dev/null @@ -1,73 +0,0 @@ -# Based on MNIST CNN from Keras' examples: https://github.com/keras-team/keras/blob/master/examples/mnist_cnn.py (MIT License) - -from __future__ import print_function -import keras -from keras.models import Sequential -from keras.layers import Dense, Dropout, Flatten -from keras.layers import Conv2D, MaxPooling2D -from keras import backend as K -import numpy as np - -batch_size = 128 -num_classes = 10 -epochs = 12 - -# input image dimensions -img_rows, img_cols = 28, 28 - -def load(f): - return np.load(f)['arr_0'] - -# Load the data -x_train = load('kmnist-train-imgs.npz') -x_test = load('kmnist-test-imgs.npz') -y_train = load('kmnist-train-labels.npz') -y_test = load('kmnist-test-labels.npz') - -if K.image_data_format() == 'channels_first': - x_train = x_train.reshape(x_train.shape[0], 1, img_rows, img_cols) - x_test = x_test.reshape(x_test.shape[0], 1, img_rows, img_cols) - input_shape = (1, img_rows, img_cols) -else: - x_train = x_train.reshape(x_train.shape[0], img_rows, img_cols, 1) - x_test = x_test.reshape(x_test.shape[0], img_rows, img_cols, 1) - input_shape = (img_rows, img_cols, 1) - -x_train = x_train.astype('float32') -x_test = x_test.astype('float32') -x_train /= 255 -x_test /= 255 -print('{} train samples, {} test samples'.format(len(x_train), len(x_test))) - -# Convert class vectors to binary class matrices -y_train = keras.utils.to_categorical(y_train, num_classes) -y_test = keras.utils.to_categorical(y_test, num_classes) - -model = Sequential() -model.add(Conv2D(32, kernel_size=(3, 3), - activation='relu', - input_shape=input_shape)) -model.add(Conv2D(64, (3, 3), activation='relu')) -model.add(MaxPooling2D(pool_size=(2, 2))) 
-model.add(Dropout(0.25)) -model.add(Flatten()) -model.add(Dense(128, activation='relu')) -model.add(Dropout(0.5)) -model.add(Dense(num_classes, activation='softmax')) - -model.compile(loss=keras.losses.categorical_crossentropy, - optimizer=keras.optimizers.Adadelta(), - metrics=['accuracy']) - -model.fit(x_train, y_train, - batch_size=batch_size, - epochs=epochs, - verbose=1, - validation_data=(x_test, y_test)) - -train_score = model.evaluate(x_train, y_train, verbose=0) -test_score = model.evaluate(x_test, y_test, verbose=0) -print('Train loss:', train_score[0]) -print('Train accuracy:', train_score[1]) -print('Test loss:', test_score[0]) -print('Test accuracy:', test_score[1]) diff --git a/benchmarks/kuzushiji_mnist_knn.py b/benchmarks/kuzushiji_mnist_knn.py deleted file mode 100644 index 2b04a58..0000000 --- a/benchmarks/kuzushiji_mnist_knn.py +++ /dev/null @@ -1,26 +0,0 @@ -# kNN with neighbors=4 benchmark for Kuzushiji-MNIST -# Acheives 92.10% test accuracy - -from sklearn.neighbors import KNeighborsClassifier -import numpy as np - -def load(f): - return np.load(f)['arr_0'] - -# Load the data -x_train = load('kmnist-train-imgs.npz') -x_test = load('kmnist-test-imgs.npz') -y_train = load('kmnist-train-labels.npz') -y_test = load('kmnist-test-labels.npz') - -# Flatten images -x_train = x_train.reshape(-1, 784) -x_test = x_test.reshape(-1, 784) - -clf = KNeighborsClassifier(n_neighbors=4, weights='distance', n_jobs=-1) -print('Fitting', clf) -clf.fit(x_train, y_train) -print('Evaluating', clf) - -test_score = clf.score(x_test, y_test) -print('Test accuracy:', test_score) diff --git a/classification.py b/classification.py new file mode 100644 index 0000000..f7740f1 --- /dev/null +++ b/classification.py @@ -0,0 +1,198 @@ +# encoding: utf-8 +import argparse +import numpy as np +import torch +import torch.nn as nn +from torch.optim import lr_scheduler +import torch.optim as optim +import torch.backends.cudnn as cudnn +from dataset.KMNIST import get_train_dataloader, 
get_validation_dataloader, get_test_dataloader +import os +# self-defined +from utils.init import * +from utils.logger import get_logger +from utils.evaluation import avg_accuracy, class_metric, visualize_val_accuracy, visualize_train_loss, \ + visualize_confusion_matrix +from config import * + + +class Classification: + def __init__(self, model='ResNet-18', train_batch=64, test_batch=1000, epoch=30, ckpt_path='./models/', + class_num=10, log_img_path=''): + # get dataloader + self.dataloader_train = get_train_dataloader(batch_size=train_batch, shuffle=True, num_workers=4) + self.dataloader_val = get_validation_dataloader(batch_size=train_batch, shuffle=False, num_workers=4) + self.dataloader_test = get_test_dataloader(batch_size=test_batch, shuffle=False, num_workers=4) + self.model = nn.DataParallel(get_model(model)).cuda() + torch.backends.cudnn.benchmark = True + # set loss criterion + self.criterion = nn.CrossEntropyLoss().cuda() + # set optimizer + self.optimizer = optim.Adam(self.model.parameters(), lr=1e-4, betas=(0.9, 0.999), eps=1e-08, weight_decay=1e-5) + # set scheduler + self.lr_scheduler_model = lr_scheduler.StepLR(self.optimizer, step_size=10, gamma=1) + # set No. 
of epoch + self.max_epoch = epoch + self.loss_log = [] + self.accuracy_log = [] + # get model parameter path + self.ckpt_path = ckpt_path + model + '.pkl' + self.class_num = class_num + # get path to log result + self.log_img_path = log_img_path + + def train_epoch(self): + # set model to training mode + self.model.train() + train_loss = [] + with torch.autograd.enable_grad(): + for batch_idx, (img, lbl) in enumerate(self.dataloader_train): + # put them to gpu + image = torch.autograd.Variable(img).cuda() + label = torch.autograd.Variable(lbl).cuda() + # set gradient to 0 + self.optimizer.zero_grad() + # get prediction + output = self.model(image) + if type(output) is tuple: + output = output[0] + # compute loss + loss_tensor = self.criterion.forward(output, label) + # backward + loss_tensor.backward() + # update parameters + self.optimizer.step() + train_loss.append(loss_tensor.item()) + + return train_loss + + def test_epoch(self, dataloader=None): + # set model to evaluation mode + self.model.eval() + # set dataloader + if dataloader is None: + dataloader = self.dataloader_test + # initialize + gt = torch.FloatTensor().cuda() + pred = torch.FloatTensor().cuda() + loss_test = [] + + with torch.autograd.no_grad(): + for batch_idx, (img, lbl) in enumerate(dataloader): + # put them to gpu + image = torch.autograd.Variable(img).cuda() + label = torch.autograd.Variable(lbl).cuda() + # get prediction + output = self.model(image) + if type(output) is tuple: + output = output[0] + # compute loss + loss_tensor = self.criterion.forward(output, label) + loss_test.append(loss_tensor.item()) + _, pred_label = torch.max(output.data, 1) + + gt = torch.cat((gt, label.data), 0) + pred = torch.cat((pred, pred_label.data), 0) + + return np.mean(loss_test), gt.cpu().numpy(), pred.cpu().numpy() + + def val_epoch(self): + loss, gt, pred = self.test_epoch(self.dataloader_val) + acc = avg_accuracy(gt, pred) + return loss, acc + + def train_model(self): + 
logger.info('********************begin training!********************') + accuracy_max = 0.0 + for epoch in range(self.max_epoch): + # train + # get train loss for one epoch + train_loss = self.train_epoch() + train_loss = np.mean(train_loss) + self.loss_log.append(train_loss) + + # log train loss + logger.info("Eopch: %5d train loss = %.6f" % (epoch + 1, train_loss)) + self.lr_scheduler_model.step() + + # validation + # get validation loss and accuracy + val_loss, val_accuracy = self.val_epoch() + + # log validation loss and accuracy + logger.info("Eopch: %5d valuation loss = %.6f, ACC = %.6f" % (epoch + 1, val_loss, val_accuracy)) + self.accuracy_log.append(val_accuracy) + + # save best checkpoint + if accuracy_max < val_accuracy: + accuracy_max = val_accuracy + torch.save(self.model.state_dict(), self.ckpt_path) # Saving torch.nn.DataParallel Models + logger.info(' Epoch: {} model has been already save!'.format(epoch + 1)) + + logger.info( + 'Training epoch: {} completed.'.format(epoch + 1)) + + # visualize train loss and validation accuracy + visualize_train_loss(self.loss_log, logger, self.log_img_path) + visualize_val_accuracy(self.accuracy_log, logger, self.log_img_path) + logger.info('Train Loss:') + logger.info(','.join([str(x) for x in self.loss_log])) + logger.info('Validation Accuracy:') + logger.info(','.join([str(x) for x in self.accuracy_log])) + + def test_model(self): + # test + # load model parameters + if os.path.isfile(self.ckpt_path): + checkpoint = torch.load(self.ckpt_path) + self.model.load_state_dict(checkpoint) + logger.info("=> loaded model checkpoint: " + self.ckpt_path) + + logger.info('******* begin testing!*********') + + # get test loss, groud truth, prediction + loss, gt, pred = self.test_epoch() + # log test loss + logger.info("Test Averaged Loss = %.6f" % (loss)) + # compute average accuracy + test_acc = avg_accuracy(gt, pred) + # log test accuracy + logger.info("Test Averaged Accuracy = %.6f" % (test_acc)) + # plot confusion 
matrix + cm = visualize_confusion_matrix(gt, pred, logger, self.log_img_path) + + # compute accuracy, precision, recall, f1 score for each class + for i in range(self.class_num): + acc, precision, recall, f_score = class_metric(cm, i) + logger.info("Class: %5d Accuracy = %.6f Precision = %.6f Recall = %.6f f-score = %.6f" % ( + i, acc, precision, recall, f_score)) + + +if __name__ == '__main__': + # command parameters + parser = argparse.ArgumentParser() + parser.add_argument('--model', type=str, default='ResMLP-12') + parser.add_argument('--gpu', type=str, default=config['CUDA_VISIBLE_DEVICES']) + parser.add_argument('--train_batch', type=int, default=64) + parser.add_argument('--test_batch', type=int, default=500) + parser.add_argument('--epoch', type=int, default=50) + parser.add_argument('--train', type=int, default=0) + parser.add_argument('--test', type=int, default=1) + parser.add_argument('--class_num', type=int, default=None) + parser.add_argument('--ckpt_path', type=str, default='./models/') + args = parser.parse_args() + # set log + logger = get_logger(config['LOG_PATH'] + args.model) + os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu + + if args.class_num is None: + class_num = config['N_CLASSES'] + else: + class_num = args.class_num + classification = Classification(model=args.model, train_batch=args.train_batch, test_batch=args.test_batch, + epoch=args.epoch, ckpt_path=args.ckpt_path, class_num=class_num, + log_img_path=config['LOG_PATH'] + args.model + '/') + if args.train == 1: + classification.train_model() + if args.test == 1: + classification.test_model() diff --git a/config.py b/config.py new file mode 100644 index 0000000..83cd86b --- /dev/null +++ b/config.py @@ -0,0 +1,15 @@ +PROJECT_PATH = './' + +config = { + 'LOG_PATH': PROJECT_PATH + 'log/', + 'TRAIN_FILE': PROJECT_PATH + 'kmnist-train-imgs.npz', + 'TEST_FILE': PROJECT_PATH + 'kmnist-test-imgs.npz', + 'TRAIN_LABEL': PROJECT_PATH + 'kmnist-train-labels.npz', + 'TEST_LABEL': PROJECT_PATH + 
'kmnist-test-labels.npz', + 'TRAIN_NUM': 54000, + 'CUDA_VISIBLE_DEVICES': "0", + 'TRAN_SIZE': 224, + 'TRAN_CROP': 224, + 'N_CLASSES': 10, + 'MODEL': 'ResMLP-12' +} diff --git a/dataset/KMNIST.py b/dataset/KMNIST.py new file mode 100644 index 0000000..058cc0b --- /dev/null +++ b/dataset/KMNIST.py @@ -0,0 +1,71 @@ +import torch +from torch.utils.data import Dataset +from torch.utils.data import DataLoader +import torchvision.transforms as transforms +import numpy as np + +from config import * + +# set dataset path +PATH_TO_TRAIN_FILE = config['TRAIN_FILE'] +PATH_TO_TEST_FILE = config['TEST_FILE'] +PATH_TO_TRAIN_LABEL_FILE = config['TRAIN_LABEL'] +PATH_TO_TEST_LABEL_FILE = config['TEST_LABEL'] +TRAIN_IMAGE_NUM = config['TRAIN_NUM'] + + +# load data +class MyDataset(Dataset): + def __init__(self, data, target, transform=None): + self.data = torch.from_numpy(data).float() + self.data = self.data.unsqueeze(1) + self.data = torch.cat((self.data, self.data, self.data), 1) + self.target = torch.from_numpy(target).long() + self.transform = transform + + def __getitem__(self, index): + x = self.data[index] + y = self.target[index] + + if self.transform: + x = self.transform(x) + + return x, y + + def __len__(self): + return len(self.data) + + +# preprocess images +transform = transforms.Compose([transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), + transforms.Resize((config['TRAN_SIZE'], config['TRAN_SIZE']))]) + + +# configure dataloader for train data +def get_train_dataloader(batch_size, shuffle, num_workers): + train_images = np.load(PATH_TO_TRAIN_FILE)['arr_0'][:TRAIN_IMAGE_NUM] + train_labels = np.load(PATH_TO_TRAIN_LABEL_FILE)['arr_0'][:TRAIN_IMAGE_NUM] + train_dataset = MyDataset(train_images, train_labels, transform) + train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers) + + return train_loader + + +# configure dataloader for validation data +def get_validation_dataloader(batch_size, shuffle, num_workers): + 
val_images = np.load(PATH_TO_TRAIN_FILE)['arr_0'][TRAIN_IMAGE_NUM:] + val_labels = np.load(PATH_TO_TRAIN_LABEL_FILE)['arr_0'][TRAIN_IMAGE_NUM:] + val_dataset = MyDataset(val_images, val_labels, transform) + val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers) + + return val_loader + + +# configure dataloader for test data +def get_test_dataloader(batch_size, shuffle, num_workers): + test_images = np.load(PATH_TO_TEST_FILE)['arr_0'] + test_labels = np.load(PATH_TO_TEST_LABEL_FILE)['arr_0'] + test_dataset = MyDataset(test_images, test_labels, transform) + test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers) + + return test_loader diff --git a/dataset/__pycache__/KMNIST.cpython-38.pyc b/dataset/__pycache__/KMNIST.cpython-38.pyc new file mode 100644 index 0000000..6647318 Binary files /dev/null and b/dataset/__pycache__/KMNIST.cpython-38.pyc differ diff --git a/download_data.py b/download_data.py deleted file mode 100644 index 1b831de..0000000 --- a/download_data.py +++ /dev/null @@ -1,80 +0,0 @@ -import requests - -try: - from tqdm import tqdm -except ImportError: - tqdm = lambda x, total, unit: x # If tqdm doesn't exist, replace it with a function that does nothing - print('**** Could not import tqdm. Please install tqdm for download progressbars! 
(pip install tqdm) ****') - -# Python2 compatibility -try: - input = raw_input -except NameError: - pass - -download_dict = { - '1) Kuzushiji-MNIST (10 classes, 28x28, 70k examples)': { - '1) MNIST data format (ubyte.gz)': - ['http://codh.rois.ac.jp/kmnist/dataset/kmnist/train-images-idx3-ubyte.gz', - 'http://codh.rois.ac.jp/kmnist/dataset/kmnist/train-labels-idx1-ubyte.gz', - 'http://codh.rois.ac.jp/kmnist/dataset/kmnist/t10k-images-idx3-ubyte.gz', - 'http://codh.rois.ac.jp/kmnist/dataset/kmnist/t10k-labels-idx1-ubyte.gz'], - '2) NumPy data format (.npz)': - ['http://codh.rois.ac.jp/kmnist/dataset/kmnist/kmnist-train-imgs.npz', - 'http://codh.rois.ac.jp/kmnist/dataset/kmnist/kmnist-train-labels.npz', - 'http://codh.rois.ac.jp/kmnist/dataset/kmnist/kmnist-test-imgs.npz', - 'http://codh.rois.ac.jp/kmnist/dataset/kmnist/kmnist-test-labels.npz'], - }, - '2) Kuzushiji-49 (49 classes, 28x28, 270k examples)': { - '1) NumPy data format (.npz)': - ['http://codh.rois.ac.jp/kmnist/dataset/k49/k49-train-imgs.npz', - 'http://codh.rois.ac.jp/kmnist/dataset/k49/k49-train-labels.npz', - 'http://codh.rois.ac.jp/kmnist/dataset/k49/k49-test-imgs.npz', - 'http://codh.rois.ac.jp/kmnist/dataset/k49/k49-test-labels.npz'], - }, - '3) Kuzushiji-Kanji (3832 classes, 64x64, 140k examples)': { - '1) Folders of images (.tar)': - ['http://codh.rois.ac.jp/kmnist/dataset/kkanji/kkanji.tar'], - } - -} - -# Download a list of files -def download_list(url_list): - for url in url_list: - path = url.split('/')[-1] - r = requests.get(url, stream=True) - with open(path, 'wb') as f: - total_length = int(r.headers.get('content-length')) - print('Downloading {} - {:.1f} MB'.format(path, (total_length / 1024000))) - - for chunk in tqdm(r.iter_content(chunk_size=1024), total=int(total_length / 1024) + 1, unit="KB"): - if chunk: - f.write(chunk) - print('All dataset files downloaded!') - -# Ask the user about which path to take down the dict -def traverse_dict(d): - print('Please select a download option:') 
- keys = sorted(d.keys()) # Print download options - for key in keys: - print(key) - - userinput = input('> ').strip() - - try: - selection = int(userinput) - 1 - except ValueError: - print('Your selection was not valid') - traverse_dict(d) # Try again if input was not valid - return - - selected = keys[selection] - - next_level = d[selected] - if isinstance(next_level, list): # If we've hit a list of downloads, download that list - download_list(next_level) - else: - traverse_dict(next_level) # Otherwise, repeat with the next level - -traverse_dict(download_dict) diff --git a/images/kkanji_examples.png b/images/kkanji_examples.png deleted file mode 100644 index b093571..0000000 Binary files a/images/kkanji_examples.png and /dev/null differ diff --git a/images/kmnist_examples.png b/images/kmnist_examples.png deleted file mode 100644 index 7df90cb..0000000 Binary files a/images/kmnist_examples.png and /dev/null differ diff --git a/kmnist_classmap.csv b/kmnist_classmap.csv new file mode 100644 index 0000000..ba15c5e --- /dev/null +++ b/kmnist_classmap.csv @@ -0,0 +1,11 @@ +index,codepoint,char +0,U+304A,お +1,U+304D,き +2,U+3059,す +3,U+3064,つ +4,U+306A,な +5,U+306F,は +6,U+307E,ま +7,U+3084,や +8,U+308C,れ +9,U+3092,を diff --git a/log/ResMLP-12/confusion_matrix.png b/log/ResMLP-12/confusion_matrix.png new file mode 100644 index 0000000..c35cfeb Binary files /dev/null and b/log/ResMLP-12/confusion_matrix.png differ diff --git a/log/ResMLP-12/log.txt b/log/ResMLP-12/log.txt new file mode 100644 index 0000000..cc434d6 --- /dev/null +++ b/log/ResMLP-12/log.txt @@ -0,0 +1,118 @@ +********************begin training!******************** +Eopch: 1 train loss = 1.474743 +Eopch: 1 valuation loss = 0.159035, ACC = 0.950333 + Epoch: 1 model has been already save! +Training epoch: 1 completed. +Eopch: 2 train loss = 0.108675 +Eopch: 2 valuation loss = 0.088274, ACC = 0.973833 + Epoch: 2 model has been already save! +Training epoch: 2 completed. 
+Eopch: 3 train loss = 0.066299 +Eopch: 3 valuation loss = 0.085802, ACC = 0.978167 + Epoch: 3 model has been already save! +Training epoch: 3 completed. +Eopch: 4 train loss = 0.053803 +Eopch: 4 valuation loss = 0.069033, ACC = 0.982500 + Epoch: 4 model has been already save! +Training epoch: 4 completed. +Eopch: 5 train loss = 0.053161 +Eopch: 5 valuation loss = 0.064765, ACC = 0.983000 + Epoch: 5 model has been already save! +Training epoch: 5 completed. +Eopch: 6 train loss = 0.039282 +Eopch: 6 valuation loss = 0.076107, ACC = 0.984500 + Epoch: 6 model has been already save! +Training epoch: 6 completed. +Eopch: 7 train loss = 0.048531 +Eopch: 7 valuation loss = 0.070968, ACC = 0.980333 +Training epoch: 7 completed. +Eopch: 8 train loss = 0.046734 +Eopch: 8 valuation loss = 0.101906, ACC = 0.974667 +Training epoch: 8 completed. +Eopch: 9 train loss = 0.057909 +Eopch: 9 valuation loss = 0.091453, ACC = 0.982500 +Training epoch: 9 completed. +Eopch: 10 train loss = 0.068523 +Eopch: 10 valuation loss = 0.100253, ACC = 0.980667 +Training epoch: 10 completed. +Eopch: 11 train loss = 0.093497 +Eopch: 11 valuation loss = 0.107574, ACC = 0.977167 +Training epoch: 11 completed. +Eopch: 12 train loss = 0.044268 +Eopch: 12 valuation loss = 0.078949, ACC = 0.983667 +Training epoch: 12 completed. +Eopch: 13 train loss = 0.623085 +Eopch: 13 valuation loss = 0.157252, ACC = 0.978500 +Training epoch: 13 completed. +Eopch: 14 train loss = 0.037385 +Eopch: 14 valuation loss = 0.110624, ACC = 0.982500 +Training epoch: 14 completed. +Eopch: 15 train loss = 0.020321 +Eopch: 15 valuation loss = 0.297462, ACC = 0.973000 +Training epoch: 15 completed. +Eopch: 16 train loss = 0.052521 +Eopch: 16 valuation loss = 0.225200, ACC = 0.980667 +Training epoch: 16 completed. +Eopch: 17 train loss = 0.126149 +Eopch: 17 valuation loss = 0.223529, ACC = 0.979167 +Training epoch: 17 completed. 
+Eopch: 18 train loss = 0.064713 +Eopch: 18 valuation loss = 0.177399, ACC = 0.983833 +Training epoch: 18 completed. +Eopch: 19 train loss = 0.114511 +Eopch: 19 valuation loss = 0.253915, ACC = 0.980500 +Training epoch: 19 completed. +Eopch: 20 train loss = 0.074080 +Eopch: 20 valuation loss = 0.226443, ACC = 0.979333 +Training epoch: 20 completed. +Eopch: 21 train loss = 0.080519 +Eopch: 21 valuation loss = 0.377549, ACC = 0.975333 +Training epoch: 21 completed. +Eopch: 22 train loss = 0.105814 +Eopch: 22 valuation loss = 0.296062, ACC = 0.980833 +Training epoch: 22 completed. +Eopch: 23 train loss = 0.117329 +Eopch: 23 valuation loss = 0.336344, ACC = 0.980333 +Training epoch: 23 completed. +Eopch: 24 train loss = 0.048457 +Eopch: 24 valuation loss = 0.182534, ACC = 0.983333 +Training epoch: 24 completed. +Eopch: 25 train loss = 0.110891 +Eopch: 25 valuation loss = 0.194350, ACC = 0.984167 +Training epoch: 25 completed. +Eopch: 26 train loss = 0.093364 +Eopch: 26 valuation loss = 0.332007, ACC = 0.977000 +Training epoch: 26 completed. +Eopch: 27 train loss = 0.079099 +Eopch: 27 valuation loss = 0.309876, ACC = 0.980833 +Training epoch: 27 completed. +Eopch: 28 train loss = 0.097644 +Eopch: 28 valuation loss = 0.295460, ACC = 0.981500 +Training epoch: 28 completed. +Eopch: 29 train loss = 0.102000 +Eopch: 29 valuation loss = 0.293703, ACC = 0.982833 +Training epoch: 29 completed. +Eopch: 30 train loss = 0.080658 +Eopch: 30 valuation loss = 0.261712, ACC = 0.983833 +Training epoch: 30 completed. 
+Train Loss Visualization is saved to /userhome/cs2/mingzeng/codes/kmnist/log/ResMLP-12/train_loss.png +Validation Accuracy Visualization is saved to /userhome/cs2/mingzeng/codes/kmnist/log/ResMLP-12/val_acc.png +Train Loss: +1.474743462770171,0.1086748668005893,0.06629924054679094,0.05380273102917496,0.05316117941462244,0.0392818591715359,0.04853067634755173,0.04673448809202136,0.057909078096927434,0.06852341221147061,0.09349699254253382,0.044267768561718206,0.6230845790207216,0.03738521428460376,0.020321206379019432,0.052520529435868546,0.12614923562628944,0.0647127392301931,0.11451104171100272,0.074079771555273,0.0805192425736352,0.1058142589081407,0.11732911902895017,0.04845723011321656,0.11089138890054262,0.09336423637522687,0.0790989557700349,0.0976435722336647,0.10200038970579342,0.08065755875858951 +Validation Accuracy: +0.9503333333333334,0.9738333333333333,0.9781666666666666,0.9825,0.983,0.9845,0.9803333333333333,0.9746666666666667,0.9825,0.9806666666666667,0.9771666666666666,0.9836666666666667,0.9785,0.9825,0.973,0.9806666666666667,0.9791666666666666,0.9838333333333333,0.9805,0.9793333333333333,0.9753333333333334,0.9808333333333333,0.9803333333333333,0.9833333333333333,0.9841666666666666,0.977,0.9808333333333333,0.9815,0.9828333333333333,0.9838333333333333 +=> loaded model checkpoint: /userhome/cs2/mingzeng/codes/kmnist/models/ResMLP-12.pkl +******* begin testing!********* +Test Averaged Loss = 0.263454 +Test Averaged Accuracy = 0.948300 +Confusion Matrix Visualization is saved to /userhome/cs2/mingzeng/codes/kmnist/log/ResMLP-12/confusion_matrix.png +Class: 0 Accuracy = 0.991500 Precision = 0.951629 Recall = 0.964000 f-score = 0.957774 +Class: 1 Accuracy = 0.990000 Precision = 0.961066 Recall = 0.938000 f-score = 0.949393 +Class: 2 Accuracy = 0.978500 Precision = 0.874166 Recall = 0.917000 f-score = 0.895071 +Class: 3 Accuracy = 0.991500 Precision = 0.941176 Recall = 0.976000 f-score = 0.958272 +Class: 4 Accuracy = 0.988800 Precision = 0.957732 Recall = 
0.929000 f-score = 0.943147 +Class: 5 Accuracy = 0.987100 Precision = 0.963791 Recall = 0.905000 f-score = 0.933471 +Class: 6 Accuracy = 0.989700 Precision = 0.925926 Recall = 0.975000 f-score = 0.949830 +Class: 7 Accuracy = 0.994500 Precision = 0.983623 Recall = 0.961000 f-score = 0.972180 +Class: 8 Accuracy = 0.991100 Precision = 0.945259 Recall = 0.967000 f-score = 0.956006 +Class: 9 Accuracy = 0.993900 Precision = 0.987539 Recall = 0.951000 f-score = 0.968925 \ No newline at end of file diff --git a/log/ResMLP-12/train_loss.png b/log/ResMLP-12/train_loss.png new file mode 100644 index 0000000..425945a Binary files /dev/null and b/log/ResMLP-12/train_loss.png differ diff --git a/log/ResMLP-12/val_acc.png b/log/ResMLP-12/val_acc.png new file mode 100644 index 0000000..dc48edd Binary files /dev/null and b/log/ResMLP-12/val_acc.png differ diff --git a/log/ResMLP-24/confusion_matrix.png b/log/ResMLP-24/confusion_matrix.png new file mode 100644 index 0000000..8ead357 Binary files /dev/null and b/log/ResMLP-24/confusion_matrix.png differ diff --git a/log/ResMLP-24/log.txt b/log/ResMLP-24/log.txt new file mode 100644 index 0000000..fd80bc8 --- /dev/null +++ b/log/ResMLP-24/log.txt @@ -0,0 +1,117 @@ +********************begin training!******************** +Eopch: 1 train loss = 0.258324 +Eopch: 1 valuation loss = 0.103075, ACC = 0.970500 + Epoch: 1 model has been already save! +Training epoch: 1 completed. +Eopch: 2 train loss = 0.060491 +Eopch: 2 valuation loss = 0.055647, ACC = 0.983667 + Epoch: 2 model has been already save! +Training epoch: 2 completed. +Eopch: 3 train loss = 0.039388 +Eopch: 3 valuation loss = 0.061271, ACC = 0.982000 +Training epoch: 3 completed. +Eopch: 4 train loss = 0.034410 +Eopch: 4 valuation loss = 0.066298, ACC = 0.983500 +Training epoch: 4 completed. +Eopch: 5 train loss = 0.024529 +Eopch: 5 valuation loss = 0.102639, ACC = 0.977667 +Training epoch: 5 completed. 
+Eopch: 6 train loss = 0.024616 +Eopch: 6 valuation loss = 0.090551, ACC = 0.980000 +Training epoch: 6 completed. +Eopch: 7 train loss = 0.063829 +Eopch: 7 valuation loss = 0.049286, ACC = 0.988833 + Epoch: 7 model has been already save! +Training epoch: 7 completed. +Eopch: 8 train loss = 0.010175 +Eopch: 8 valuation loss = 0.084901, ACC = 0.983167 +Training epoch: 8 completed. +Eopch: 9 train loss = 0.008173 +Eopch: 9 valuation loss = 0.066344, ACC = 0.988500 +Training epoch: 9 completed. +Eopch: 10 train loss = 0.010041 +Eopch: 10 valuation loss = 0.081577, ACC = 0.986000 +Training epoch: 10 completed. +Eopch: 11 train loss = 0.017877 +Eopch: 11 valuation loss = 0.096666, ACC = 0.984000 +Training epoch: 11 completed. +Eopch: 12 train loss = 0.016719 +Eopch: 12 valuation loss = 0.072074, ACC = 0.985833 +Training epoch: 12 completed. +Eopch: 13 train loss = 0.016574 +Eopch: 13 valuation loss = 0.046974, ACC = 0.988167 +Training epoch: 13 completed. +Eopch: 14 train loss = 0.012572 +Eopch: 14 valuation loss = 0.078121, ACC = 0.982167 +Training epoch: 14 completed. +Eopch: 15 train loss = 0.013178 +Eopch: 15 valuation loss = 0.101448, ACC = 0.981667 +Training epoch: 15 completed. +Eopch: 16 train loss = 0.010542 +Eopch: 16 valuation loss = 0.068331, ACC = 0.991000 + Epoch: 16 model has been already save! +Training epoch: 16 completed. +Eopch: 17 train loss = 0.008493 +Eopch: 17 valuation loss = 0.096219, ACC = 0.983833 +Training epoch: 17 completed. +Eopch: 18 train loss = 0.013357 +Eopch: 18 valuation loss = 0.061578, ACC = 0.987333 +Training epoch: 18 completed. +Eopch: 19 train loss = 0.009315 +Eopch: 19 valuation loss = 0.078054, ACC = 0.986833 +Training epoch: 19 completed. +Eopch: 20 train loss = 0.006614 +Eopch: 20 valuation loss = 0.065001, ACC = 0.989833 +Training epoch: 20 completed. +Eopch: 21 train loss = 0.009743 +Eopch: 21 valuation loss = 0.058913, ACC = 0.986833 +Training epoch: 21 completed. 
+Eopch: 22 train loss = 0.004101 +Eopch: 22 valuation loss = 0.071325, ACC = 0.989000 +Training epoch: 22 completed. +Eopch: 23 train loss = 0.011880 +Eopch: 23 valuation loss = 0.061699, ACC = 0.988000 +Training epoch: 23 completed. +Eopch: 24 train loss = 0.005898 +Eopch: 24 valuation loss = 0.052160, ACC = 0.989667 +Training epoch: 24 completed. +Eopch: 25 train loss = 0.007055 +Eopch: 25 valuation loss = 0.060927, ACC = 0.986500 +Training epoch: 25 completed. +Eopch: 26 train loss = 0.006872 +Eopch: 26 valuation loss = 0.047085, ACC = 0.991167 + Epoch: 26 model has been already save! +Training epoch: 26 completed. +Eopch: 27 train loss = 0.006201 +Eopch: 27 valuation loss = 0.062409, ACC = 0.989333 +Training epoch: 27 completed. +Eopch: 28 train loss = 0.004974 +Eopch: 28 valuation loss = 0.079944, ACC = 0.989667 +Training epoch: 28 completed. +Eopch: 29 train loss = 0.009411 +Eopch: 29 valuation loss = 0.058826, ACC = 0.988333 +Training epoch: 29 completed. +Eopch: 30 train loss = 0.002585 +Eopch: 30 valuation loss = 0.071575, ACC = 0.986667 +Training epoch: 30 completed. 
+Train Loss Visualization is saved to /userhome/cs2/mingzeng/codes/kmnist/log/ResMLP-24/train_loss.png +Validation Accuracy Visualization is saved to /userhome/cs2/mingzeng/codes/kmnist/log/ResMLP-24/val_acc.png +Train Loss: +0.25832390663422583,0.06049111991284456,0.03938772366459338,0.03440955887405244,0.0245294870920013,0.024616345635608633,0.06382881039148937,0.010174585332210257,0.00817297156926667,0.010040507038089059,0.01787675555819019,0.01671852239838346,0.016574336527537304,0.012571598617393875,0.01317848334338329,0.010541742407758438,0.008492613288001888,0.013356601136628456,0.009315182970803984,0.006614003157931455,0.009743359711699696,0.004101245476077177,0.011880251484419804,0.005898466958283593,0.007055483976123433,0.0068718988866095405,0.006201463303413665,0.004973818432163777,0.009411264056348005,0.0025846648623085136 +Validation Accuracy: +0.9705,0.9836666666666667,0.982,0.9835,0.9776666666666667,0.98,0.9888333333333333,0.9831666666666666,0.9885,0.986,0.984,0.9858333333333333,0.9881666666666666,0.9821666666666666,0.9816666666666667,0.991,0.9838333333333333,0.9873333333333333,0.9868333333333333,0.9898333333333333,0.9868333333333333,0.989,0.988,0.9896666666666667,0.9865,0.9911666666666666,0.9893333333333333,0.9896666666666667,0.9883333333333333,0.9866666666666667 +=> loaded model checkpoint: /userhome/cs2/mingzeng/codes/kmnist/models/ResMLP-24.pkl +******* begin testing!********* +Test Averaged Loss = 0.205754 +Test Averaged Accuracy = 0.971000 +Confusion Matrix Visualization is saved to /userhome/cs2/mingzeng/codes/kmnist/log/ResMLP-24/confusion_matrix.png +Class: 0 Accuracy = 0.995300 Precision = 0.965787 Recall = 0.988000 f-score = 0.976767 +Class: 1 Accuracy = 0.994200 Precision = 0.986570 Recall = 0.955000 f-score = 0.970528 +Class: 2 Accuracy = 0.991800 Precision = 0.980126 Recall = 0.937000 f-score = 0.958078 +Class: 3 Accuracy = 0.996100 Precision = 0.980020 Recall = 0.981000 f-score = 0.980510 +Class: 4 Accuracy = 0.991200 Precision = 
0.957831 Recall = 0.954000 f-score = 0.955912 +Class: 5 Accuracy = 0.995000 Precision = 0.972167 Recall = 0.978000 f-score = 0.975075 +Class: 6 Accuracy = 0.989400 Precision = 0.913889 Recall = 0.987000 f-score = 0.949038 +Class: 7 Accuracy = 0.995600 Precision = 0.988753 Recall = 0.967000 f-score = 0.977755 +Class: 8 Accuracy = 0.996000 Precision = 0.983871 Recall = 0.976000 f-score = 0.979920 +Class: 9 Accuracy = 0.997400 Precision = 0.987000 Recall = 0.987000 f-score = 0.987000 \ No newline at end of file diff --git a/log/ResMLP-24/train_loss.png b/log/ResMLP-24/train_loss.png new file mode 100644 index 0000000..22f0460 Binary files /dev/null and b/log/ResMLP-24/train_loss.png differ diff --git a/log/ResMLP-24/val_acc.png b/log/ResMLP-24/val_acc.png new file mode 100644 index 0000000..478eba5 Binary files /dev/null and b/log/ResMLP-24/val_acc.png differ diff --git a/log/ResNet-18/confusion_matrix.png b/log/ResNet-18/confusion_matrix.png new file mode 100644 index 0000000..ba4e838 Binary files /dev/null and b/log/ResNet-18/confusion_matrix.png differ diff --git a/log/ResNet-18/log.txt b/log/ResNet-18/log.txt new file mode 100644 index 0000000..410c501 --- /dev/null +++ b/log/ResNet-18/log.txt @@ -0,0 +1,536 @@ +********************begin training!******************** +Eopch: 1 train loss = 0.112472 +Eopch: 1 valuation loss = 0.032543, ACC = 0.991333 + Epoch: 1 model has been already save! +Training epoch: 1 completed. +Eopch: 2 train loss = 0.016170 +Eopch: 2 valuation loss = 0.033873, ACC = 0.991000 +Training epoch: 2 completed. +Eopch: 3 train loss = 0.010401 +Eopch: 3 valuation loss = 0.033791, ACC = 0.991500 + Epoch: 3 model has been already save! +Training epoch: 3 completed. +Eopch: 4 train loss = 0.008944 +Eopch: 4 valuation loss = 0.021961, ACC = 0.994333 + Epoch: 4 model has been already save! +Training epoch: 4 completed. +Eopch: 5 train loss = 0.006552 +Eopch: 5 valuation loss = 0.020410, ACC = 0.995500 + Epoch: 5 model has been already save! 
+Training epoch: 5 completed. +Eopch: 6 train loss = 0.004441 +Eopch: 6 valuation loss = 0.028272, ACC = 0.992833 +Training epoch: 6 completed. +Eopch: 7 train loss = 0.007019 +Eopch: 7 valuation loss = 0.029333, ACC = 0.992500 +Training epoch: 7 completed. +Eopch: 8 train loss = 0.003278 +Eopch: 8 valuation loss = 0.021766, ACC = 0.993500 +Training epoch: 8 completed. +Eopch: 9 train loss = 0.006029 +Eopch: 9 valuation loss = 0.023517, ACC = 0.993667 +Training epoch: 9 completed. +Eopch: 10 train loss = 0.005264 +Eopch: 10 valuation loss = 0.021146, ACC = 0.993833 +Training epoch: 10 completed. +Eopch: 11 train loss = 0.002638 +Eopch: 11 valuation loss = 0.017968, ACC = 0.995500 +Training epoch: 11 completed. +Eopch: 12 train loss = 0.002231 +Eopch: 12 valuation loss = 0.028438, ACC = 0.993167 +Training epoch: 12 completed. +Eopch: 13 train loss = 0.004573 +Eopch: 13 valuation loss = 0.024170, ACC = 0.994167 +Training epoch: 13 completed. +Eopch: 14 train loss = 0.002521 +Eopch: 14 valuation loss = 0.022760, ACC = 0.995500 +Training epoch: 14 completed. +Eopch: 15 train loss = 0.001604 +Eopch: 15 valuation loss = 0.029690, ACC = 0.992000 +Training epoch: 15 completed. +Eopch: 16 train loss = 0.006733 +Eopch: 16 valuation loss = 0.024623, ACC = 0.994000 +Training epoch: 16 completed. +Eopch: 17 train loss = 0.002092 +Eopch: 17 valuation loss = 0.024397, ACC = 0.993833 +Training epoch: 17 completed. +Eopch: 18 train loss = 0.001126 +Eopch: 18 valuation loss = 0.018558, ACC = 0.994500 +Training epoch: 18 completed. +Eopch: 19 train loss = 0.003236 +Eopch: 19 valuation loss = 0.019188, ACC = 0.995333 +Training epoch: 19 completed. +Eopch: 20 train loss = 0.003076 +Eopch: 20 valuation loss = 0.021177, ACC = 0.994833 +Training epoch: 20 completed. +Eopch: 21 train loss = 0.001987 +Eopch: 21 valuation loss = 0.027134, ACC = 0.992667 +Training epoch: 21 completed. 
+Eopch: 22 train loss = 0.002780 +Eopch: 22 valuation loss = 0.036788, ACC = 0.992167 +Training epoch: 22 completed. +Eopch: 23 train loss = 0.002600 +Eopch: 23 valuation loss = 0.026954, ACC = 0.994000 +Training epoch: 23 completed. +Eopch: 24 train loss = 0.002509 +Eopch: 24 valuation loss = 0.024505, ACC = 0.994333 +Training epoch: 24 completed. +Eopch: 25 train loss = 0.002124 +Eopch: 25 valuation loss = 0.027661, ACC = 0.992333 +Training epoch: 25 completed. +Eopch: 26 train loss = 0.001353 +Eopch: 26 valuation loss = 0.023039, ACC = 0.994833 +Training epoch: 26 completed. +Eopch: 27 train loss = 0.001375 +Eopch: 27 valuation loss = 0.019753, ACC = 0.995667 + Epoch: 27 model has been already save! +Training epoch: 27 completed. +Eopch: 28 train loss = 0.002047 +Eopch: 28 valuation loss = 0.022252, ACC = 0.994500 +Training epoch: 28 completed. +Eopch: 29 train loss = 0.003442 +Eopch: 29 valuation loss = 0.026994, ACC = 0.993333 +Training epoch: 29 completed. +Eopch: 30 train loss = 0.000458 +Eopch: 30 valuation loss = 0.019145, ACC = 0.995167 +Training epoch: 30 completed. +Eopch: 31 train loss = 0.000170 +Eopch: 31 valuation loss = 0.019967, ACC = 0.995833 + Epoch: 31 model has been already save! +Training epoch: 31 completed. +Eopch: 32 train loss = 0.004198 +Eopch: 32 valuation loss = 0.022859, ACC = 0.994000 +Training epoch: 32 completed. +Eopch: 33 train loss = 0.001504 +Eopch: 33 valuation loss = 0.032546, ACC = 0.992500 +Training epoch: 33 completed. +Eopch: 34 train loss = 0.000715 +Eopch: 34 valuation loss = 0.021275, ACC = 0.995333 +Training epoch: 34 completed. +Eopch: 35 train loss = 0.001949 +Eopch: 35 valuation loss = 0.024419, ACC = 0.992667 +Training epoch: 35 completed. +Eopch: 36 train loss = 0.002625 +Eopch: 36 valuation loss = 0.022674, ACC = 0.993500 +Training epoch: 36 completed. +Eopch: 37 train loss = 0.001432 +Eopch: 37 valuation loss = 0.020469, ACC = 0.994500 +Training epoch: 37 completed. 
+Eopch: 38 train loss = 0.000919 +Eopch: 38 valuation loss = 0.019236, ACC = 0.995333 +Training epoch: 38 completed. +Eopch: 39 train loss = 0.002806 +Eopch: 39 valuation loss = 0.036104, ACC = 0.992500 +Training epoch: 39 completed. +Eopch: 40 train loss = 0.000593 +Eopch: 40 valuation loss = 0.017297, ACC = 0.996000 + Epoch: 40 model has been already save! +Training epoch: 40 completed. +Eopch: 41 train loss = 0.000074 +Eopch: 41 valuation loss = 0.018258, ACC = 0.996000 +Training epoch: 41 completed. +Eopch: 42 train loss = 0.000043 +Eopch: 42 valuation loss = 0.016699, ACC = 0.996333 + Epoch: 42 model has been already save! +Training epoch: 42 completed. +Eopch: 43 train loss = 0.004258 +Eopch: 43 valuation loss = 0.022825, ACC = 0.993833 +Training epoch: 43 completed. +Eopch: 44 train loss = 0.001685 +Eopch: 44 valuation loss = 0.019110, ACC = 0.995500 +Training epoch: 44 completed. +Eopch: 45 train loss = 0.000996 +Eopch: 45 valuation loss = 0.018361, ACC = 0.994833 +Training epoch: 45 completed. +Eopch: 46 train loss = 0.000111 +Eopch: 46 valuation loss = 0.017752, ACC = 0.996000 +Training epoch: 46 completed. +Eopch: 47 train loss = 0.002375 +Eopch: 47 valuation loss = 0.021209, ACC = 0.994667 +Training epoch: 47 completed. +Eopch: 48 train loss = 0.002433 +Eopch: 48 valuation loss = 0.022865, ACC = 0.995333 +Training epoch: 48 completed. +Eopch: 49 train loss = 0.000814 +Eopch: 49 valuation loss = 0.019654, ACC = 0.995333 +Training epoch: 49 completed. +Eopch: 50 train loss = 0.002580 +Eopch: 50 valuation loss = 0.026501, ACC = 0.993667 +Training epoch: 50 completed. 
+Train Loss Visualization is saved to /userhome/cs2/mingzeng/codes/kmnist/log/ResNet-18/train_loss.png +Validation Accuracy Visualization is saved to /userhome/cs2/mingzeng/codes/kmnist/log/ResNet-18/val_acc.png +Train Loss: +0.1124717586737797,0.016170259801807738,0.010400650284713205,0.008944160013339069,0.0065518872841114845,0.004440692789629065,0.007019181146336049,0.0032781219610485625,0.006029157559703597,0.0052644157093864686,0.0026384786645780125,0.002230725335097287,0.0045732613302939965,0.0025210315351931367,0.0016043741853499992,0.006733059158194363,0.002091662949198499,0.0011263878866893246,0.0032356474211569525,0.003076039213504813,0.0019870159624117358,0.0027797034472573034,0.0026003861377827384,0.0025087218852057145,0.0021240568647468206,0.0013532344944920081,0.0013749040911510686,0.0020474040254814243,0.0034417064467761865,0.0004576196333709312,0.00016953872533785995,0.0041983963904424,0.0015040430597704152,0.0007154906393583548,0.0019489095852944982,0.0026249186220027878,0.0014320352600488563,0.0009191627323662895,0.002806055992076538,0.000592689474532082,7.439634051884694e-05,4.266309739758123e-05,0.0042578651909786535,0.001685497372659699,0.0009962033815523893,0.00011079841955785807,0.0023747650443307396,0.0024328762135416336,0.0008143995424636108,0.00257967325731004 +Validation Accuracy: 
+0.9913333333333333,0.991,0.9915,0.9943333333333333,0.9955,0.9928333333333333,0.9925,0.9935,0.9936666666666667,0.9938333333333333,0.9955,0.9931666666666666,0.9941666666666666,0.9955,0.992,0.994,0.9938333333333333,0.9945,0.9953333333333333,0.9948333333333333,0.9926666666666667,0.9921666666666666,0.994,0.9943333333333333,0.9923333333333333,0.9948333333333333,0.9956666666666667,0.9945,0.9933333333333333,0.9951666666666666,0.9958333333333333,0.994,0.9925,0.9953333333333333,0.9926666666666667,0.9935,0.9945,0.9953333333333333,0.9925,0.996,0.996,0.9963333333333333,0.9938333333333333,0.9955,0.9948333333333333,0.996,0.9946666666666667,0.9953333333333333,0.9953333333333333,0.9936666666666667 +=> loaded model checkpoint: /userhome/cs2/mingzeng/codes/kmnist/models/ResNet-18.pkl +******* begin testing!********* +Test Averaged Loss = 0.057160 +Test Averaged Accuracy = 0.988300 +Confusion Matrix Visualization is saved to /userhome/cs2/mingzeng/codes/kmnist/log/ResNet-18/confusion_matrix.png +Class: 0 Accuracy = 0.996600 Precision = 0.977273 Recall = 0.989000 f-score = 0.983101 +Class: 1 Accuracy = 0.998100 Precision = 0.994955 Recall = 0.986000 f-score = 0.990457 +Class: 2 Accuracy = 0.996100 Precision = 0.986829 Recall = 0.974000 f-score = 0.980372 +Class: 3 Accuracy = 0.998300 Precision = 0.989055 Recall = 0.994000 f-score = 0.991521 +Class: 4 Accuracy = 0.995300 Precision = 0.983756 Recall = 0.969000 f-score = 0.976322 +Class: 5 Accuracy = 0.998000 Precision = 0.990982 Recall = 0.989000 f-score = 0.989990 +Class: 6 Accuracy = 0.997900 Precision = 0.985134 Recall = 0.994000 f-score = 0.989547 +Class: 7 Accuracy = 0.998600 Precision = 0.990060 Recall = 0.996000 f-score = 0.993021 +Class: 8 Accuracy = 0.998600 Precision = 0.989087 Recall = 0.997000 f-score = 0.993028 +Class: 9 Accuracy = 0.999100 Precision = 0.995996 Recall = 0.995000 f-score = 0.995498 +=> loaded model checkpoint: /userhome/cs2/mingzeng/codes/kmnist/models/ResNet-18.pkl +********************begin 
training!******************** +Eopch: 1 train loss = 0.002380 +Eopch: 1 valuation loss = 0.020418, ACC = 0.994667 + Epoch: 1 model has been already save! +Training epoch: 1 completed. +Eopch: 2 train loss = 0.001551 +Eopch: 2 valuation loss = 0.034197, ACC = 0.992000 +Training epoch: 2 completed. +Eopch: 3 train loss = 0.001335 +Eopch: 3 valuation loss = 0.028909, ACC = 0.992833 +Training epoch: 3 completed. +Eopch: 4 train loss = 0.002054 +Eopch: 4 valuation loss = 0.023759, ACC = 0.995000 + Epoch: 4 model has been already save! +Training epoch: 4 completed. +Eopch: 5 train loss = 0.001485 +Eopch: 5 valuation loss = 0.022776, ACC = 0.994500 +Training epoch: 5 completed. +Eopch: 6 train loss = 0.001384 +Eopch: 6 valuation loss = 0.024340, ACC = 0.993667 +Training epoch: 6 completed. +Eopch: 7 train loss = 0.002727 +Eopch: 7 valuation loss = 0.021776, ACC = 0.994333 +Training epoch: 7 completed. +Eopch: 8 train loss = 0.001291 +Eopch: 8 valuation loss = 0.018708, ACC = 0.996333 + Epoch: 8 model has been already save! +Training epoch: 8 completed. +Eopch: 9 train loss = 0.001168 +Eopch: 9 valuation loss = 0.041369, ACC = 0.990000 +Training epoch: 9 completed. +Eopch: 10 train loss = 0.003025 +Eopch: 10 valuation loss = 0.022515, ACC = 0.994833 +Training epoch: 10 completed. +Eopch: 11 train loss = 0.001245 +Eopch: 11 valuation loss = 0.021370, ACC = 0.995333 +Training epoch: 11 completed. +Eopch: 12 train loss = 0.000126 +Eopch: 12 valuation loss = 0.018098, ACC = 0.996167 +Training epoch: 12 completed. +Eopch: 13 train loss = 0.000051 +Eopch: 13 valuation loss = 0.017615, ACC = 0.996333 +Training epoch: 13 completed. +Eopch: 14 train loss = 0.000036 +Eopch: 14 valuation loss = 0.017552, ACC = 0.996667 + Epoch: 14 model has been already save! +Training epoch: 14 completed. +Eopch: 15 train loss = 0.004902 +Eopch: 15 valuation loss = 0.023053, ACC = 0.995333 +Training epoch: 15 completed. 
+Eopch: 16 train loss = 0.001666 +Eopch: 16 valuation loss = 0.016388, ACC = 0.996000 +Training epoch: 16 completed. +Eopch: 17 train loss = 0.000240 +Eopch: 17 valuation loss = 0.015207, ACC = 0.996500 +Training epoch: 17 completed. +Eopch: 18 train loss = 0.000062 +Eopch: 18 valuation loss = 0.014461, ACC = 0.996667 +Training epoch: 18 completed. +Eopch: 19 train loss = 0.000047 +Eopch: 19 valuation loss = 0.015068, ACC = 0.996333 +Training epoch: 19 completed. +Eopch: 20 train loss = 0.003623 +Eopch: 20 valuation loss = 0.032521, ACC = 0.992167 +Training epoch: 20 completed. +Eopch: 21 train loss = 0.003394 +Eopch: 21 valuation loss = 0.017815, ACC = 0.995167 +Training epoch: 21 completed. +Eopch: 22 train loss = 0.000386 +Eopch: 22 valuation loss = 0.016767, ACC = 0.995667 +Training epoch: 22 completed. +Eopch: 23 train loss = 0.002935 +Eopch: 23 valuation loss = 0.022401, ACC = 0.995000 +Training epoch: 23 completed. +Eopch: 24 train loss = 0.001660 +Eopch: 24 valuation loss = 0.023005, ACC = 0.995000 +Training epoch: 24 completed. +Eopch: 25 train loss = 0.000494 +Eopch: 25 valuation loss = 0.026597, ACC = 0.994500 +Training epoch: 25 completed. +Eopch: 26 train loss = 0.002265 +Eopch: 26 valuation loss = 0.024182, ACC = 0.994333 +Training epoch: 26 completed. +Eopch: 27 train loss = 0.000340 +Eopch: 27 valuation loss = 0.017949, ACC = 0.996167 +Training epoch: 27 completed. +Eopch: 28 train loss = 0.000819 +Eopch: 28 valuation loss = 0.027741, ACC = 0.993167 +Training epoch: 28 completed. +Eopch: 29 train loss = 0.001746 +Eopch: 29 valuation loss = 0.022587, ACC = 0.995167 +Training epoch: 29 completed. +Eopch: 30 train loss = 0.000783 +Eopch: 30 valuation loss = 0.023898, ACC = 0.995667 +Training epoch: 30 completed. +Eopch: 31 train loss = 0.001118 +Eopch: 31 valuation loss = 0.021154, ACC = 0.995167 +Training epoch: 31 completed. +Eopch: 32 train loss = 0.003023 +Eopch: 32 valuation loss = 0.026357, ACC = 0.993500 +Training epoch: 32 completed. 
+Eopch: 33 train loss = 0.000544 +Eopch: 33 valuation loss = 0.023024, ACC = 0.994833 +Training epoch: 33 completed. +Eopch: 34 train loss = 0.000282 +Eopch: 34 valuation loss = 0.020969, ACC = 0.996167 +Training epoch: 34 completed. +Eopch: 35 train loss = 0.001814 +Eopch: 35 valuation loss = 0.024345, ACC = 0.994667 +Training epoch: 35 completed. +Eopch: 36 train loss = 0.002169 +Eopch: 36 valuation loss = 0.026725, ACC = 0.994333 +Training epoch: 36 completed. +Eopch: 37 train loss = 0.000258 +Eopch: 37 valuation loss = 0.021777, ACC = 0.995500 +Training epoch: 37 completed. +Eopch: 38 train loss = 0.002167 +Eopch: 38 valuation loss = 0.017834, ACC = 0.995500 +Training epoch: 38 completed. +Eopch: 39 train loss = 0.000532 +Eopch: 39 valuation loss = 0.020227, ACC = 0.995500 +Training epoch: 39 completed. +Eopch: 40 train loss = 0.000691 +Eopch: 40 valuation loss = 0.017154, ACC = 0.996333 +Training epoch: 40 completed. +Eopch: 41 train loss = 0.000897 +Eopch: 41 valuation loss = 0.025298, ACC = 0.993333 +Training epoch: 41 completed. +Eopch: 42 train loss = 0.001715 +Eopch: 42 valuation loss = 0.025247, ACC = 0.994833 +Training epoch: 42 completed. +Eopch: 43 train loss = 0.000493 +Eopch: 43 valuation loss = 0.023587, ACC = 0.995667 +Training epoch: 43 completed. +Eopch: 44 train loss = 0.000088 +Eopch: 44 valuation loss = 0.020639, ACC = 0.995500 +Training epoch: 44 completed. +Eopch: 45 train loss = 0.000049 +Eopch: 45 valuation loss = 0.020038, ACC = 0.996000 +Training epoch: 45 completed. +Eopch: 46 train loss = 0.000034 +Eopch: 46 valuation loss = 0.019470, ACC = 0.995667 +Training epoch: 46 completed. +Eopch: 47 train loss = 0.002782 +Eopch: 47 valuation loss = 0.041904, ACC = 0.988167 +Training epoch: 47 completed. +Eopch: 48 train loss = 0.002982 +Eopch: 48 valuation loss = 0.020116, ACC = 0.995500 +Training epoch: 48 completed. +Eopch: 49 train loss = 0.000191 +Eopch: 49 valuation loss = 0.018578, ACC = 0.996000 +Training epoch: 49 completed. 
+Eopch: 50 train loss = 0.000624 +Eopch: 50 valuation loss = 0.023107, ACC = 0.994667 +Training epoch: 50 completed. +Train Loss Visualization is saved to /userhome/cs2/mingzeng/codes/kmnist/log/ResNet-18/train_loss.png +Validation Accuracy Visualization is saved to /userhome/cs2/mingzeng/codes/kmnist/log/ResNet-18/val_acc.png +Train Loss: +0.0023798978278849223,0.001551491871023827,0.0013354753494879885,0.002054083236969754,0.0014850462116790097,0.0013839491533943866,0.0027274314285686938,0.0012908909813588265,0.0011675883227027773,0.0030252575809148878,0.0012450252825608912,0.00012600182976248581,5.131643628639778e-05,3.556391739841912e-05,0.004902001174867005,0.0016656818849003822,0.0002404311794494762,6.223151201156164e-05,4.731706620064671e-05,0.00362265661686718,0.0033936699260611156,0.0003855040930247019,0.0029351176442366277,0.0016602201061659066,0.0004939092476814802,0.002265196821997021,0.00034032644487177533,0.0008193693273553988,0.0017461086846763667,0.000783329249710952,0.001117981775017776,0.0030230789899898867,0.0005442745805398923,0.000281953724842536,0.0018135306640964424,0.002168529378855163,0.00025785759732562403,0.0021670936117468833,0.0005320045849255673,0.0006905365555163453,0.0008968770830824293,0.0017150963560410139,0.0004934786079238362,8.781648941732989e-05,4.8519646186183976e-05,3.3700352101059855e-05,0.0027815273267490404,0.002981613644168263,0.0001908360637885186,0.0006241627901026818 +Validation Accuracy: 
+0.9946666666666667,0.992,0.9928333333333333,0.995,0.9945,0.9936666666666667,0.9943333333333333,0.9963333333333333,0.99,0.9948333333333333,0.9953333333333333,0.9961666666666666,0.9963333333333333,0.9966666666666667,0.9953333333333333,0.996,0.9965,0.9966666666666667,0.9963333333333333,0.9921666666666666,0.9951666666666666,0.9956666666666667,0.995,0.995,0.9945,0.9943333333333333,0.9961666666666666,0.9931666666666666,0.9951666666666666,0.9956666666666667,0.9951666666666666,0.9935,0.9948333333333333,0.9961666666666666,0.9946666666666667,0.9943333333333333,0.9955,0.9955,0.9955,0.9963333333333333,0.9933333333333333,0.9948333333333333,0.9956666666666667,0.9955,0.996,0.9956666666666667,0.9881666666666666,0.9955,0.996,0.9946666666666667 +=> loaded model checkpoint: /userhome/cs2/mingzeng/codes/kmnist/models/ResNet-18.pkl +******* begin testing!********* +Test Averaged Loss = 0.059358 +Test Averaged Accuracy = 0.988100 +Confusion Matrix Visualization is saved to /userhome/cs2/mingzeng/codes/kmnist/log/ResNet-18/confusion_matrix.png +Class: 0 Accuracy = 0.996300 Precision = 0.975321 Recall = 0.988000 f-score = 0.981619 +Class: 1 Accuracy = 0.998100 Precision = 0.994955 Recall = 0.986000 f-score = 0.990457 +Class: 2 Accuracy = 0.996300 Precision = 0.991828 Recall = 0.971000 f-score = 0.981304 +Class: 3 Accuracy = 0.998600 Precision = 0.991036 Recall = 0.995000 f-score = 0.993014 +Class: 4 Accuracy = 0.995300 Precision = 0.985729 Recall = 0.967000 f-score = 0.976275 +Class: 5 Accuracy = 0.997900 Precision = 0.989990 Recall = 0.989000 f-score = 0.989495 +Class: 6 Accuracy = 0.997600 Precision = 0.980315 Recall = 0.996000 f-score = 0.988095 +Class: 7 Accuracy = 0.999000 Precision = 0.993028 Recall = 0.997000 f-score = 0.995010 +Class: 8 Accuracy = 0.998200 Precision = 0.984221 Recall = 0.998000 f-score = 0.991063 +Class: 9 Accuracy = 0.998900 Precision = 0.994995 Recall = 0.994000 f-score = 0.994497 +********************begin training!******************** +Eopch: 1 train loss = 
0.110339 +Eopch: 1 valuation loss = 0.030959, ACC = 0.991500 + Epoch: 1 model has been already save! +Training epoch: 1 completed. +Eopch: 2 train loss = 0.015871 +Eopch: 2 valuation loss = 0.026959, ACC = 0.992667 + Epoch: 2 model has been already save! +Training epoch: 2 completed. +Eopch: 3 train loss = 0.010047 +Eopch: 3 valuation loss = 0.030732, ACC = 0.991167 +Training epoch: 3 completed. +Eopch: 4 train loss = 0.009027 +Eopch: 4 valuation loss = 0.027057, ACC = 0.992167 +Training epoch: 4 completed. +Eopch: 5 train loss = 0.007910 +Eopch: 5 valuation loss = 0.028243, ACC = 0.992333 +Training epoch: 5 completed. +Eopch: 6 train loss = 0.007268 +Eopch: 6 valuation loss = 0.039620, ACC = 0.988167 +Training epoch: 6 completed. +Eopch: 7 train loss = 0.005927 +Eopch: 7 valuation loss = 0.027992, ACC = 0.992833 + Epoch: 7 model has been already save! +Training epoch: 7 completed. +Eopch: 8 train loss = 0.004416 +Eopch: 8 valuation loss = 0.023292, ACC = 0.994500 + Epoch: 8 model has been already save! +Training epoch: 8 completed. +Eopch: 9 train loss = 0.004000 +Eopch: 9 valuation loss = 0.031309, ACC = 0.992500 +Training epoch: 9 completed. +Eopch: 10 train loss = 0.004315 +Eopch: 10 valuation loss = 0.038368, ACC = 0.990333 +Training epoch: 10 completed. +Eopch: 11 train loss = 0.005083 +Eopch: 11 valuation loss = 0.026671, ACC = 0.992500 +Training epoch: 11 completed. +Eopch: 12 train loss = 0.001932 +Eopch: 12 valuation loss = 0.020445, ACC = 0.994833 + Epoch: 12 model has been already save! +Training epoch: 12 completed. +Eopch: 13 train loss = 0.003449 +Eopch: 13 valuation loss = 0.034891, ACC = 0.991333 +Training epoch: 13 completed. +Eopch: 14 train loss = 0.003702 +Eopch: 14 valuation loss = 0.020935, ACC = 0.994833 +Training epoch: 14 completed. +Eopch: 15 train loss = 0.001845 +Eopch: 15 valuation loss = 0.030075, ACC = 0.993833 +Training epoch: 15 completed. 
+Eopch: 16 train loss = 0.004842 +Eopch: 16 valuation loss = 0.025842, ACC = 0.993833 +Training epoch: 16 completed. +Eopch: 17 train loss = 0.003709 +Eopch: 17 valuation loss = 0.025759, ACC = 0.994667 +Training epoch: 17 completed. +Eopch: 18 train loss = 0.002032 +Eopch: 18 valuation loss = 0.021540, ACC = 0.994500 +Training epoch: 18 completed. +Eopch: 19 train loss = 0.002395 +Eopch: 19 valuation loss = 0.033867, ACC = 0.993500 +Training epoch: 19 completed. +Eopch: 20 train loss = 0.002707 +Eopch: 20 valuation loss = 0.019205, ACC = 0.994667 +Training epoch: 20 completed. +Eopch: 21 train loss = 0.002436 +Eopch: 21 valuation loss = 0.028644, ACC = 0.992500 +Training epoch: 21 completed. +Eopch: 22 train loss = 0.002499 +Eopch: 22 valuation loss = 0.017397, ACC = 0.996500 + Epoch: 22 model has been already save! +Training epoch: 22 completed. +Eopch: 23 train loss = 0.001254 +Eopch: 23 valuation loss = 0.030186, ACC = 0.992833 +Training epoch: 23 completed. +Eopch: 24 train loss = 0.003710 +Eopch: 24 valuation loss = 0.024254, ACC = 0.993000 +Training epoch: 24 completed. +Eopch: 25 train loss = 0.001489 +Eopch: 25 valuation loss = 0.019749, ACC = 0.994667 +Training epoch: 25 completed. +Eopch: 26 train loss = 0.004095 +Eopch: 26 valuation loss = 0.020992, ACC = 0.995167 +Training epoch: 26 completed. +Eopch: 27 train loss = 0.000797 +Eopch: 27 valuation loss = 0.016143, ACC = 0.996333 +Training epoch: 27 completed. +Eopch: 28 train loss = 0.000089 +Eopch: 28 valuation loss = 0.018128, ACC = 0.995500 +Training epoch: 28 completed. +Eopch: 29 train loss = 0.002590 +Eopch: 29 valuation loss = 0.019603, ACC = 0.995500 +Training epoch: 29 completed. +Eopch: 30 train loss = 0.003278 +Eopch: 30 valuation loss = 0.025397, ACC = 0.994667 +Training epoch: 30 completed. +Eopch: 31 train loss = 0.001702 +Eopch: 31 valuation loss = 0.018519, ACC = 0.996167 +Training epoch: 31 completed. 
+Eopch: 32 train loss = 0.000383 +Eopch: 32 valuation loss = 0.022197, ACC = 0.994667 +Training epoch: 32 completed. +Eopch: 33 train loss = 0.003521 +Eopch: 33 valuation loss = 0.018160, ACC = 0.995500 +Training epoch: 33 completed. +Eopch: 34 train loss = 0.000882 +Eopch: 34 valuation loss = 0.044170, ACC = 0.990333 +Training epoch: 34 completed. +Eopch: 35 train loss = 0.000976 +Eopch: 35 valuation loss = 0.017261, ACC = 0.996167 +Training epoch: 35 completed. +Eopch: 36 train loss = 0.003609 +Eopch: 36 valuation loss = 0.021408, ACC = 0.994500 +Training epoch: 36 completed. +Eopch: 37 train loss = 0.002237 +Eopch: 37 valuation loss = 0.015329, ACC = 0.996333 +Training epoch: 37 completed. +Eopch: 38 train loss = 0.000244 +Eopch: 38 valuation loss = 0.016126, ACC = 0.996833 + Epoch: 38 model has been already save! +Training epoch: 38 completed. +Eopch: 39 train loss = 0.000064 +Eopch: 39 valuation loss = 0.015117, ACC = 0.996667 +Training epoch: 39 completed. +Eopch: 40 train loss = 0.000035 +Eopch: 40 valuation loss = 0.014981, ACC = 0.996667 +Training epoch: 40 completed. +Eopch: 41 train loss = 0.000154 +Eopch: 41 valuation loss = 0.036214, ACC = 0.993000 +Training epoch: 41 completed. +Eopch: 42 train loss = 0.007359 +Eopch: 42 valuation loss = 0.025046, ACC = 0.994333 +Training epoch: 42 completed. +Eopch: 43 train loss = 0.000664 +Eopch: 43 valuation loss = 0.018365, ACC = 0.996167 +Training epoch: 43 completed. +Eopch: 44 train loss = 0.000205 +Eopch: 44 valuation loss = 0.016666, ACC = 0.996500 +Training epoch: 44 completed. +Eopch: 45 train loss = 0.002770 +Eopch: 45 valuation loss = 0.028044, ACC = 0.993333 +Training epoch: 45 completed. +Eopch: 46 train loss = 0.002998 +Eopch: 46 valuation loss = 0.024017, ACC = 0.995333 +Training epoch: 46 completed. +Eopch: 47 train loss = 0.000574 +Eopch: 47 valuation loss = 0.023372, ACC = 0.995167 +Training epoch: 47 completed. 
+Eopch: 48 train loss = 0.002458 +Eopch: 48 valuation loss = 0.016478, ACC = 0.996333 +Training epoch: 48 completed. +Eopch: 49 train loss = 0.002323 +Eopch: 49 valuation loss = 0.019090, ACC = 0.994500 +Training epoch: 49 completed. +Eopch: 50 train loss = 0.000743 +Eopch: 50 valuation loss = 0.017156, ACC = 0.995833 +Training epoch: 50 completed. +Train Loss Visualization is saved to /userhome/cs2/mingzeng/codes/kmnist/log/ResNet-18/train_loss.png +Validation Accuracy Visualization is saved to /userhome/cs2/mingzeng/codes/kmnist/log/ResNet-18/val_acc.png +Train Loss: +0.11033941893198337,0.01587141261952466,0.010047460450098123,0.009026822771629323,0.007909631987312126,0.007267851441115432,0.005926719363539861,0.004415530911910753,0.004000338429366786,0.004315067960562365,0.005082801300376003,0.0019316070164466302,0.003448847651242608,0.0037018515641234467,0.0018449491861234294,0.0048415431339483015,0.0037093547556111294,0.0020316048078758238,0.002395397147667406,0.0027070195585829683,0.0024358899397665702,0.0024986996267306973,0.001254476565973013,0.003710166680174427,0.0014893779088541088,0.0040947873539281315,0.0007969619218517213,8.857132731382576e-05,0.002590080609492829,0.0032777426852391178,0.001701802904918349,0.0003830976405296501,0.0035213949273196557,0.0008816417004870466,0.0009755956402146495,0.003609171024395415,0.002236520517689965,0.0002444293147632207,6.380908616249523e-05,3.4506582773181185e-05,0.000154251702148673,0.007359382428508738,0.000663765362646506,0.00020465109137051844,0.002769523832936044,0.0029981745003645937,0.0005742085878933566,0.0024575316408295207,0.002322727980387896,0.0007429241633507924 +Validation Accuracy: 
+0.9915,0.9926666666666667,0.9911666666666666,0.9921666666666666,0.9923333333333333,0.9881666666666666,0.9928333333333333,0.9945,0.9925,0.9903333333333333,0.9925,0.9948333333333333,0.9913333333333333,0.9948333333333333,0.9938333333333333,0.9938333333333333,0.9946666666666667,0.9945,0.9935,0.9946666666666667,0.9925,0.9965,0.9928333333333333,0.993,0.9946666666666667,0.9951666666666666,0.9963333333333333,0.9955,0.9955,0.9946666666666667,0.9961666666666666,0.9946666666666667,0.9955,0.9903333333333333,0.9961666666666666,0.9945,0.9963333333333333,0.9968333333333333,0.9966666666666667,0.9966666666666667,0.993,0.9943333333333333,0.9961666666666666,0.9965,0.9933333333333333,0.9953333333333333,0.9951666666666666,0.9963333333333333,0.9945,0.9958333333333333 +=> loaded model checkpoint: /userhome/cs2/mingzeng/codes/kmnist/models/ResNet-18.pkl +******* begin testing!********* +Test Averaged Loss = 0.061179 +Test Averaged Accuracy = 0.986900 +Confusion Matrix Visualization is saved to /userhome/cs2/mingzeng/codes/kmnist/log/ResNet-18/confusion_matrix.png +Class: 0 Accuracy = 0.996500 Precision = 0.967992 Recall = 0.998000 f-score = 0.982767 +Class: 1 Accuracy = 0.998200 Precision = 0.992972 Recall = 0.989000 f-score = 0.990982 +Class: 2 Accuracy = 0.995500 Precision = 0.984772 Recall = 0.970000 f-score = 0.977330 +Class: 3 Accuracy = 0.997700 Precision = 0.986070 Recall = 0.991000 f-score = 0.988529 +Class: 4 Accuracy = 0.995000 Precision = 0.992739 Recall = 0.957000 f-score = 0.974542 +Class: 5 Accuracy = 0.997200 Precision = 0.988934 Recall = 0.983000 f-score = 0.985958 +Class: 6 Accuracy = 0.997400 Precision = 0.981225 Recall = 0.993000 f-score = 0.987078 +Class: 7 Accuracy = 0.998300 Precision = 0.988083 Recall = 0.995000 f-score = 0.991530 +Class: 8 Accuracy = 0.999000 Precision = 0.993028 Recall = 0.997000 f-score = 0.995010 +Class: 9 Accuracy = 0.999000 Precision = 0.994012 Recall = 0.996000 f-score = 0.995005 diff --git a/log/ResNet-18/train_loss.png 
b/log/ResNet-18/train_loss.png new file mode 100644 index 0000000..aaf4751 Binary files /dev/null and b/log/ResNet-18/train_loss.png differ diff --git a/log/ResNet-18/val_acc.png b/log/ResNet-18/val_acc.png new file mode 100644 index 0000000..187ed35 Binary files /dev/null and b/log/ResNet-18/val_acc.png differ diff --git a/log/ResNet-34/confusion_matrix.png b/log/ResNet-34/confusion_matrix.png new file mode 100644 index 0000000..9990fda Binary files /dev/null and b/log/ResNet-34/confusion_matrix.png differ diff --git a/log/ResNet-34/log.txt b/log/ResNet-34/log.txt new file mode 100644 index 0000000..465ce6b --- /dev/null +++ b/log/ResNet-34/log.txt @@ -0,0 +1,178 @@ +********************begin training!******************** +Eopch: 1 train loss = 0.106633 +Eopch: 1 valuation loss = 0.036541, ACC = 0.990000 + Epoch: 1 model has been already save! +Training epoch: 1 completed. +Eopch: 2 train loss = 0.023560 +Eopch: 2 valuation loss = 0.029485, ACC = 0.991833 + Epoch: 2 model has been already save! +Training epoch: 2 completed. +Eopch: 3 train loss = 0.014101 +Eopch: 3 valuation loss = 0.026434, ACC = 0.993167 + Epoch: 3 model has been already save! +Training epoch: 3 completed. +Eopch: 4 train loss = 0.011178 +Eopch: 4 valuation loss = 0.033568, ACC = 0.990500 +Training epoch: 4 completed. +Eopch: 5 train loss = 0.011963 +Eopch: 5 valuation loss = 0.028671, ACC = 0.993167 +Training epoch: 5 completed. +Eopch: 6 train loss = 0.008491 +Eopch: 6 valuation loss = 0.021085, ACC = 0.995000 + Epoch: 6 model has been already save! +Training epoch: 6 completed. +Eopch: 7 train loss = 0.008863 +Eopch: 7 valuation loss = 0.023528, ACC = 0.994500 +Training epoch: 7 completed. +Eopch: 8 train loss = 0.007288 +Eopch: 8 valuation loss = 0.026976, ACC = 0.992500 +Training epoch: 8 completed. +Eopch: 9 train loss = 0.005660 +Eopch: 9 valuation loss = 0.043079, ACC = 0.987667 +Training epoch: 9 completed. 
+Eopch: 10 train loss = 0.005652 +Eopch: 10 valuation loss = 0.032521, ACC = 0.990833 +Training epoch: 10 completed. +Eopch: 11 train loss = 0.007269 +Eopch: 11 valuation loss = 0.022062, ACC = 0.994167 +Training epoch: 11 completed. +Eopch: 12 train loss = 0.004372 +Eopch: 12 valuation loss = 0.036558, ACC = 0.992167 +Training epoch: 12 completed. +Eopch: 13 train loss = 0.005474 +Eopch: 13 valuation loss = 0.023405, ACC = 0.993833 +Training epoch: 13 completed. +Eopch: 14 train loss = 0.003349 +Eopch: 14 valuation loss = 0.027550, ACC = 0.993833 +Training epoch: 14 completed. +Eopch: 15 train loss = 0.004899 +Eopch: 15 valuation loss = 0.023654, ACC = 0.994333 +Training epoch: 15 completed. +Eopch: 16 train loss = 0.003552 +Eopch: 16 valuation loss = 0.022079, ACC = 0.995000 +Training epoch: 16 completed. +Eopch: 17 train loss = 0.004018 +Eopch: 17 valuation loss = 0.021004, ACC = 0.995167 + Epoch: 17 model has been already save! +Training epoch: 17 completed. +Eopch: 18 train loss = 0.004144 +Eopch: 18 valuation loss = 0.017145, ACC = 0.996500 + Epoch: 18 model has been already save! +Training epoch: 18 completed. +Eopch: 19 train loss = 0.002511 +Eopch: 19 valuation loss = 0.022749, ACC = 0.995167 +Training epoch: 19 completed. +Eopch: 20 train loss = 0.002511 +Eopch: 20 valuation loss = 0.033824, ACC = 0.992500 +Training epoch: 20 completed. +Eopch: 21 train loss = 0.005574 +Eopch: 21 valuation loss = 0.018856, ACC = 0.996000 +Training epoch: 21 completed. +Eopch: 22 train loss = 0.002727 +Eopch: 22 valuation loss = 0.019269, ACC = 0.995167 +Training epoch: 22 completed. +Eopch: 23 train loss = 0.002617 +Eopch: 23 valuation loss = 0.023354, ACC = 0.995167 +Training epoch: 23 completed. +Eopch: 24 train loss = 0.002244 +Eopch: 24 valuation loss = 0.018489, ACC = 0.995167 +Training epoch: 24 completed. +Eopch: 25 train loss = 0.002358 +Eopch: 25 valuation loss = 0.019329, ACC = 0.995000 +Training epoch: 25 completed. 
+Eopch: 26 train loss = 0.003287 +Eopch: 26 valuation loss = 0.021972, ACC = 0.995167 +Training epoch: 26 completed. +Eopch: 27 train loss = 0.004589 +Eopch: 27 valuation loss = 0.024749, ACC = 0.993333 +Training epoch: 27 completed. +Eopch: 28 train loss = 0.000912 +Eopch: 28 valuation loss = 0.021506, ACC = 0.995500 +Training epoch: 28 completed. +Eopch: 29 train loss = 0.002646 +Eopch: 29 valuation loss = 0.021657, ACC = 0.995667 +Training epoch: 29 completed. +Eopch: 30 train loss = 0.001185 +Eopch: 30 valuation loss = 0.027115, ACC = 0.994667 +Training epoch: 30 completed. +Eopch: 31 train loss = 0.004833 +Eopch: 31 valuation loss = 0.021278, ACC = 0.995333 +Training epoch: 31 completed. +Eopch: 32 train loss = 0.001897 +Eopch: 32 valuation loss = 0.022322, ACC = 0.994500 +Training epoch: 32 completed. +Eopch: 33 train loss = 0.003539 +Eopch: 33 valuation loss = 0.018811, ACC = 0.995833 +Training epoch: 33 completed. +Eopch: 34 train loss = 0.000455 +Eopch: 34 valuation loss = 0.021088, ACC = 0.995833 +Training epoch: 34 completed. +Eopch: 35 train loss = 0.002819 +Eopch: 35 valuation loss = 0.027828, ACC = 0.993667 +Training epoch: 35 completed. +Eopch: 36 train loss = 0.002950 +Eopch: 36 valuation loss = 0.022639, ACC = 0.995000 +Training epoch: 36 completed. +Eopch: 37 train loss = 0.002021 +Eopch: 37 valuation loss = 0.014453, ACC = 0.996500 +Training epoch: 37 completed. +Eopch: 38 train loss = 0.000929 +Eopch: 38 valuation loss = 0.020020, ACC = 0.996167 +Training epoch: 38 completed. +Eopch: 39 train loss = 0.003934 +Eopch: 39 valuation loss = 0.021908, ACC = 0.995667 +Training epoch: 39 completed. +Eopch: 40 train loss = 0.002856 +Eopch: 40 valuation loss = 0.022252, ACC = 0.995333 +Training epoch: 40 completed. +Eopch: 41 train loss = 0.001765 +Eopch: 41 valuation loss = 0.019979, ACC = 0.995333 +Training epoch: 41 completed. +Eopch: 42 train loss = 0.002859 +Eopch: 42 valuation loss = 0.019498, ACC = 0.995500 +Training epoch: 42 completed. 
+Eopch: 43 train loss = 0.001988 +Eopch: 43 valuation loss = 0.020122, ACC = 0.995833 +Training epoch: 43 completed. +Eopch: 44 train loss = 0.001906 +Eopch: 44 valuation loss = 0.015427, ACC = 0.996000 +Training epoch: 44 completed. +Eopch: 45 train loss = 0.000617 +Eopch: 45 valuation loss = 0.023248, ACC = 0.995833 +Training epoch: 45 completed. +Eopch: 46 train loss = 0.002288 +Eopch: 46 valuation loss = 0.024685, ACC = 0.994333 +Training epoch: 46 completed. +Eopch: 47 train loss = 0.002961 +Eopch: 47 valuation loss = 0.029439, ACC = 0.994000 +Training epoch: 47 completed. +Eopch: 48 train loss = 0.002208 +Eopch: 48 valuation loss = 0.027046, ACC = 0.994833 +Training epoch: 48 completed. +Eopch: 49 train loss = 0.000916 +Eopch: 49 valuation loss = 0.020738, ACC = 0.996000 +Training epoch: 49 completed. +Eopch: 50 train loss = 0.000336 +Eopch: 50 valuation loss = 0.021995, ACC = 0.994667 +Training epoch: 50 completed. +Train Loss Visualization is saved to /userhome/cs2/mingzeng/codes/kmnist/log/ResNet-34/train_loss.png +Validation Accuracy Visualization is saved to /userhome/cs2/mingzeng/codes/kmnist/log/ResNet-34/val_acc.png +Train Loss: 
+0.10663288501016818,0.023560329713782524,0.014101241604837157,0.011177892464259832,0.011962616963440859,0.008491031523219514,0.008862906449553844,0.0072877223670445536,0.005659631196351981,0.0056515910035211045,0.007268956035215386,0.004371746765634369,0.005474027861992732,0.0033488901729209165,0.00489867321382667,0.0035519160211847887,0.0040176142923510315,0.0041439567963489805,0.0025106570646627733,0.0025113279428380718,0.005573548964850264,0.0027273199622795284,0.0026173889789269113,0.0022438014260685135,0.002357653659443726,0.003287135360845659,0.0045886932596953065,0.0009121494066920027,0.002645777721855446,0.001184784014578749,0.004833152677004954,0.0018972325255269696,0.0035389326350548506,0.00045536609219931846,0.002819422161926344,0.00295049640826091,0.0020209123628466595,0.0009293903282567419,0.003933526405523504,0.0028557769325891352,0.001764970838518952,0.0028587067907089748,0.001987672652010898,0.0019058688251576848,0.0006169046146797123,0.0022881885661847966,0.0029608939319989026,0.0022084624991614947,0.0009156058300263428,0.0003358277108293135 +Validation Accuracy: +0.99,0.9918333333333333,0.9931666666666666,0.9905,0.9931666666666666,0.995,0.9945,0.9925,0.9876666666666667,0.9908333333333333,0.9941666666666666,0.9921666666666666,0.9938333333333333,0.9938333333333333,0.9943333333333333,0.995,0.9951666666666666,0.9965,0.9951666666666666,0.9925,0.996,0.9951666666666666,0.9951666666666666,0.9951666666666666,0.995,0.9951666666666666,0.9933333333333333,0.9955,0.9956666666666667,0.9946666666666667,0.9953333333333333,0.9945,0.9958333333333333,0.9958333333333333,0.9936666666666667,0.995,0.9965,0.9961666666666666,0.9956666666666667,0.9953333333333333,0.9953333333333333,0.9955,0.9958333333333333,0.996,0.9958333333333333,0.9943333333333333,0.994,0.9948333333333333,0.996,0.9946666666666667 +=> loaded model checkpoint: /userhome/cs2/mingzeng/codes/kmnist/models/ResNet-34.pkl +******* begin testing!********* +Test Averaged Loss = 0.070697 +Test Averaged Accuracy = 
0.986100 +Confusion Matrix Visualization is saved to /userhome/cs2/mingzeng/codes/kmnist/log/ResNet-34/confusion_matrix.png +Class: 0 Accuracy = 0.995800 Precision = 0.965953 Recall = 0.993000 f-score = 0.979290 +Class: 1 Accuracy = 0.998000 Precision = 0.994949 Recall = 0.985000 f-score = 0.989950 +Class: 2 Accuracy = 0.994100 Precision = 0.972864 Recall = 0.968000 f-score = 0.970426 +Class: 3 Accuracy = 0.998000 Precision = 0.984190 Recall = 0.996000 f-score = 0.990060 +Class: 4 Accuracy = 0.995000 Precision = 0.985685 Recall = 0.964000 f-score = 0.974722 +Class: 5 Accuracy = 0.996700 Precision = 0.993871 Recall = 0.973000 f-score = 0.983325 +Class: 6 Accuracy = 0.997700 Precision = 0.983185 Recall = 0.994000 f-score = 0.988563 +Class: 7 Accuracy = 0.999100 Precision = 0.994018 Recall = 0.997000 f-score = 0.995507 +Class: 8 Accuracy = 0.998800 Precision = 0.992032 Recall = 0.996000 f-score = 0.994012 +Class: 9 Accuracy = 0.999000 Precision = 0.995000 Recall = 0.995000 f-score = 0.995000 diff --git a/log/ResNet-34/train_loss.png b/log/ResNet-34/train_loss.png new file mode 100644 index 0000000..29480ff Binary files /dev/null and b/log/ResNet-34/train_loss.png differ diff --git a/log/ResNet-34/val_acc.png b/log/ResNet-34/val_acc.png new file mode 100644 index 0000000..673c719 Binary files /dev/null and b/log/ResNet-34/val_acc.png differ diff --git a/models/.gitignore b/models/.gitignore new file mode 100644 index 0000000..e69de29 diff --git a/nets/ResNet.py b/nets/ResNet.py new file mode 100644 index 0000000..0a24f47 --- /dev/null +++ b/nets/ResNet.py @@ -0,0 +1,213 @@ +import torch.nn as nn +import torchvision +import torch +import math +import torch.utils.model_zoo as model_zoo + +# construct model +__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', + 'resnet152'] + +model_urls = { + 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', + 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', + 'resnet50': 
'https://download.pytorch.org/models/resnet50-19c8e357.pth', + 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', + 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', +} + + +def conv3x3(in_planes, out_planes, stride=1): + "3x3 convolution with padding" + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, + padding=1, bias=False) + + +class BasicBlock(nn.Module): + expansion = 1 + + def __init__(self, inplanes, planes, stride=1, downsample=None): + super(BasicBlock, self).__init__() + self.conv1 = conv3x3(inplanes, planes, stride) + self.bn1 = nn.BatchNorm2d(planes) + self.relu = nn.ReLU(inplace=True) + self.conv2 = conv3x3(planes, planes) + self.bn2 = nn.BatchNorm2d(planes) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1, downsample=None): + super(Bottleneck, self).__init__() + self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) + self.bn1 = nn.BatchNorm2d(planes) + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, + padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(planes) + self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) + self.bn3 = nn.BatchNorm2d(planes * 4) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += 
residual + out = self.relu(out) + + return out + + +class ResNet(nn.Module): + + def __init__(self, block, layers, num_classes=1000): + self.inplanes = 64 + super(ResNet, self).__init__() + self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, + bias=False) + self.bn1 = nn.BatchNorm2d(64) + self.relu = nn.ReLU(inplace=True) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.layer1 = self._make_layer(block, 64, layers[0]) + self.layer2 = self._make_layer(block, 128, layers[1], stride=2) + self.layer3 = self._make_layer(block, 256, layers[2], stride=2) + self.layer4 = self._make_layer(block, 512, layers[3], stride=2) + self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) + self.fc = nn.Linear(512 * block.expansion, num_classes) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + m.weight.data.normal_(0, math.sqrt(2. / n)) + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1) + m.bias.data.zero_() + + def _make_layer(self, block, planes, blocks, stride=1): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2d(self.inplanes, planes * block.expansion, + kernel_size=1, stride=stride, bias=False), + nn.BatchNorm2d(planes * block.expansion), + ) + + layers = [] + layers.append(block(self.inplanes, planes, stride, downsample)) + self.inplanes = planes * block.expansion + for _ in range(1, blocks): + layers.append(block(self.inplanes, planes)) + + return nn.Sequential(*layers) + + def forward(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + x = self.maxpool(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + y = x + x = self.avgpool(x) + x = x.view(x.size(0), -1) + x = self.fc(x) + + return torch.sigmoid(x), y + + +def resnet18(t_num_classes=46, pretrained=False, **kwargs): + """Constructs a ResNet-18 model. 
+ Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + """ + model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) + if pretrained: + model.load_state_dict(model_zoo.load_url(model_urls['resnet18'])) + num_fc_ftr = model.fc.in_features # overwrite the fc layer + model.fc = torch.nn.Linear(num_fc_ftr, t_num_classes) + return model + + +def resnet34(pretrained=False, **kwargs): + """Constructs a ResNet-34 model. + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + """ + model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs) + if pretrained: + model.load_state_dict(model_zoo.load_url(model_urls['resnet34'])) + return model + + +def resnet50(t_num_classes=5, pretrained=False, **kwargs): + """Constructs a ResNet-50 model. + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + """ + model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) + if pretrained: + model.load_state_dict(model_zoo.load_url(model_urls['resnet50'])) + num_fc_ftr = model.fc.in_features # overwrite the fc layer + model.fc = torch.nn.Linear(num_fc_ftr, t_num_classes) + return model + + +def resnet101(pretrained=False, **kwargs): + """Constructs a ResNet-101 model. + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + """ + model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs) + if pretrained: + model.load_state_dict(model_zoo.load_url(model_urls['resnet101'])) + return model + + +def resnet152(pretrained=False, **kwargs): + """Constructs a ResNet-152 model. 
+ Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + """ + model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs) + if pretrained: + model.load_state_dict(model_zoo.load_url(model_urls['resnet152'])) + return model diff --git a/unsupervised.ipynb b/unsupervised.ipynb new file mode 100644 index 0000000..7008a97 --- /dev/null +++ b/unsupervised.ipynb @@ -0,0 +1,1827 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "7ba70409", + "metadata": {}, + "source": [ + "# Unsupervised Models" + ] + }, + { + "cell_type": "markdown", + "id": "dd2ab19d", + "metadata": {}, + "source": [ + "## 1. Data and Clusters' Examples Visualization" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "0422e6d9", + "metadata": {}, + "outputs": [], + "source": [ + "# download the dataset in NumPy format\n", + "import numpy as np\n", + "def load(f):\n", + " return np.load(f)['arr_0']\n", + "\n", + "# Load the data\n", + "x_train = load('kmnist-train-imgs.npz')\n", + "x_test = load('kmnist-test-imgs.npz')\n", + "y_train = load('kmnist-train-labels.npz')\n", + "y_test = load('kmnist-test-labels.npz')" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "4bf07b89", + "metadata": {}, + "outputs": [], + "source": [ + "# Flatten images\n", + "# Each element in x_train and x_test is in the form of 28x28,\n", + "# so reshaping them in the form of 1x784\n", + "x_trainf = x_train.reshape(-1, 784)\n", + "x_testf = x_test.reshape(-1, 784)" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "93c6ef34", + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
indexcodepointchar
00U+304Aお
11U+304Dき
22U+3059す
33U+3064つ
44U+306Aな
55U+306Fは
66U+307Eま
77U+3084や
88U+308Cれ
99U+3092を
\n", + "
" + ], + "text/plain": [ + " index codepoint char\n", + "0 0 U+304A お\n", + "1 1 U+304D き\n", + "2 2 U+3059 す\n", + "3 3 U+3064 つ\n", + "4 4 U+306A な\n", + "5 5 U+306F は\n", + "6 6 U+307E ま\n", + "7 7 U+3084 や\n", + "8 8 U+308C れ\n", + "9 9 U+3092 を" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import matplotlib.pyplot as plt\n", + "import seaborn as sns\n", + "import pandas as pd\n", + "# import the map to see the given 10 clusters\n", + "map = pd.read_csv(\"kmnist_classmap.csv\")\n", + "map" + ] + }, + { + "cell_type": "markdown", + "id": "ca9557a5", + "metadata": {}, + "source": [ + "### visualize characters in each cluster" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "039fead3", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "# cluster 1\n", + "plt.figure()\n", + "fig, (ax,ax2) = plt.subplots(ncols=2)\n", + "fig.subplots_adjust(wspace=0.01)\n", + "\n", + "sns.heatmap(x_train[2], cmap = \"gist_gray\", cbar = False, ax = ax)\n", + "sns.heatmap(x_train[12], cmap = \"gist_gray\", cbar = False, ax = ax2)\n", + "ax2.yaxis.tick_right()\n", + "ax2.tick_params(rotation=0)\n", + "\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "046a75eb", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "# cluster 2\n", + "plt.figure()\n", + "fig, (ax,ax2) = plt.subplots(ncols=2)\n", + "fig.subplots_adjust(wspace=0.01)\n", + "sns.heatmap(x_train[3], cmap = \"gist_gray\", cbar = False, ax = ax)\n", + "sns.heatmap(x_train[8], cmap = \"gist_gray\", cbar = False, ax = ax2)\n", + "ax2.yaxis.tick_right()\n", + "ax2.tick_params(rotation=0)\n", + "\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "da357b54", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "# cluster 3\n", + "plt.figure()\n", + "fig, (ax,ax2) = plt.subplots(ncols=2)\n", + "fig.subplots_adjust(wspace=0.01)\n", + "sns.heatmap(x_train[5], cmap = \"gist_gray\", cbar = False, ax = ax)\n", + "sns.heatmap(x_train[26], cmap = \"gist_gray\", cbar = False, ax = ax2)\n", + "ax2.yaxis.tick_right()\n", + "ax2.tick_params(rotation=0)\n", + "\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "7286cdbb", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "# cluster 4\n", + "plt.figure()\n", + "fig, (ax,ax2) = plt.subplots(ncols=2)\n", + "fig.subplots_adjust(wspace=0.01)\n", + "sns.heatmap(x_train[21], cmap = \"gist_gray\", cbar = False, ax = ax)\n", + "sns.heatmap(x_train[54], cmap = \"gist_gray\", cbar = False, ax = ax2)\n", + "ax2.yaxis.tick_right()\n", + "ax2.tick_params(rotation=0)\n", + "\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "38866d42", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "# cluster 5 \n", + "plt.figure()\n", + "fig, (ax,ax2) = plt.subplots(ncols=2)\n", + "fig.subplots_adjust(wspace=0.01)\n", + "sns.heatmap(x_train[4], cmap = \"gist_gray\", cbar = False, ax = ax)\n", + "sns.heatmap(x_train[6], cmap = \"gist_gray\", cbar = False, ax = ax2)\n", + "ax2.yaxis.tick_right()\n", + "ax2.tick_params(rotation=0)\n", + "\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "e0c292a2", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "# cluster 6\n", + "plt.figure()\n", + "fig, (ax,ax2) = plt.subplots(ncols=2)\n", + "fig.subplots_adjust(wspace=0.01)\n", + "sns.heatmap(x_train[10], cmap = \"gist_gray\", cbar = False, ax = ax)\n", + "sns.heatmap(x_train[13], cmap = \"gist_gray\", cbar = False, ax = ax2)\n", + "ax2.yaxis.tick_right()\n", + "ax2.tick_params(rotation=0)\n", + "\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "5c9e5d1f", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "# cluster 7\n", + "plt.figure()\n", + "fig, (ax,ax2) = plt.subplots(ncols=2)\n", + "fig.subplots_adjust(wspace=0.01)\n", + "sns.heatmap(x_train[24], cmap = \"gist_gray\", cbar = False, ax = ax)\n", + "sns.heatmap(x_train[25], cmap = \"gist_gray\", cbar = False, ax = ax2)\n", + "ax2.yaxis.tick_right()\n", + "ax2.tick_params(rotation=0)\n", + "\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "523d5272", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "# cluster 8\n", + "plt.figure()\n", + "fig, (ax,ax2) = plt.subplots(ncols=2)\n", + "fig.subplots_adjust(wspace=0.01)\n", + "sns.heatmap(x_train[1], cmap = \"gist_gray\", cbar = False, ax = ax)\n", + "sns.heatmap(x_train[14], cmap = \"gist_gray\", cbar = False, ax = ax2)\n", + "ax2.yaxis.tick_right()\n", + "ax2.tick_params(rotation=0)\n", + "\n", + "plt.show()\n", + "#sns.heatmap(x_train[17], cmap = \"gist_gray\")" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "a53a0ed4", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "# cluster 9\n", + "plt.figure()\n", + "fig, (ax,ax2) = plt.subplots(ncols=2)\n", + "fig.subplots_adjust(wspace=0.01)\n", + "sns.heatmap(x_train[0], cmap = \"gist_gray\", cbar = False, ax = ax)\n", + "sns.heatmap(x_train[7], cmap = \"gist_gray\", cbar = False, ax = ax2)\n", + "ax2.yaxis.tick_right()\n", + "ax2.tick_params(rotation=0)\n", + "\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "e7d36a15", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "# cluster 10\n", + "plt.figure()\n", + "fig, (ax,ax2) = plt.subplots(ncols=2)\n", + "fig.subplots_adjust(wspace=0.01)\n", + "sns.heatmap(x_train[19], cmap = \"gist_gray\", cbar = False, ax = ax)\n", + "sns.heatmap(x_train[30], cmap = \"gist_gray\", cbar = False, ax = ax2)\n", + "ax2.yaxis.tick_right()\n", + "ax2.tick_params(rotation=0)\n", + "\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "2786386d", + "metadata": {}, + "source": [ + "## 2. Models" + ] + }, + { + "cell_type": "markdown", + "id": "3fc3e521", + "metadata": {}, + "source": [ + "### 2.1 Kmeans" + ] + }, + { + "cell_type": "markdown", + "id": "9bef3c89", + "metadata": {}, + "source": [ + "###    2.1.1 Fitting Models" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "289d8d9e", + "metadata": {}, + "outputs": [], + "source": [ + "# import KMeans from sklearn \n", + "from sklearn.cluster import KMeans\n", + "\n", + "# assigned 10 clusters \n", + "# fit() to compute clustering for K-Means\n", + "# fit_predict() to make predictions\n", + "kmeans = KMeans(n_clusters = 10).fit(x_trainf)\n", + "y_pred_kmean = kmeans.predict(x_testf)" + ] + }, + { + "cell_type": "markdown", + "id": "67ccad57", + "metadata": {}, + "source": [ + "###    2.2.2 Clustering results evaluations" + ] + }, + { + "cell_type": "markdown", + "id": "f2e23193", + "metadata": {}, + "source": [ + "####     NMI (Normaliszed Mutual Information)" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "955c19c9", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Test score:0.3075\n" + ] + } + ], + "source": [ + "from sklearn.metrics.cluster import normalized_mutual_info_score\n", + "print(\"Test score:{:.4f}\".format(normalized_mutual_info_score(y_test, y_pred_kmean,average_method='arithmetic')))" + ] + }, + { + "cell_type": 
"code", + "execution_count": 13, + "id": "40e0dea9", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Test score:0.3204\n" + ] + } + ], + "source": [ + "from sklearn.metrics.cluster import normalized_mutual_info_score\n", + "print(\"Test score:{:.4f}\".format(normalized_mutual_info_score(y_test, y_pred_kmean,average_method='min')))" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "01329398", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Test score:0.3077\n" + ] + } + ], + "source": [ + "from sklearn.metrics.cluster import normalized_mutual_info_score\n", + "print(\"Test score:{:.4f}\".format(normalized_mutual_info_score(y_test, y_pred_kmean,average_method='geometric')))" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "ac0103e6", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Test score:0.2956\n" + ] + } + ], + "source": [ + "from sklearn.metrics.cluster import normalized_mutual_info_score\n", + "print(\"Test score:{:.4f}\".format(normalized_mutual_info_score(y_test, y_pred_kmean,average_method='max')))" + ] + }, + { + "cell_type": "markdown", + "id": "5d7c0620", + "metadata": {}, + "source": [ + "####     ARI (Adjusted Rand Index)" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "779c4a21", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Test score:0.1616\n" + ] + } + ], + "source": [ + "from sklearn.metrics.cluster import adjusted_rand_score\n", + "print(\"Test score:{:.4f}\".format(adjusted_rand_score(y_test, y_pred_kmean)))" + ] + }, + { + "cell_type": "markdown", + "id": "eaf905f4", + "metadata": {}, + "source": [ + "####     Confusion Matrix" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "id": "49e50762", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": 
"stream", + "text": [ + "Confusion matrix: \n", + "[[259 106 0 60 7 3 523 21 3 18]\n", + " [170 0 79 2 6 75 1 95 0 572]\n", + " [229 4 137 16 14 355 4 8 14 219]\n", + " [549 108 3 9 3 3 1 9 256 59]\n", + " [ 83 5 68 11 108 38 73 414 2 198]\n", + " [ 62 251 56 1 0 411 6 8 10 195]\n", + " [ 81 0 560 32 4 24 5 106 3 185]\n", + " [ 38 1 14 391 27 116 9 314 19 71]\n", + " [401 19 75 2 1 361 4 6 0 131]\n", + " [222 1 14 10 301 191 1 22 0 238]]\n" + ] + } + ], + "source": [ + "from sklearn.metrics import confusion_matrix\n", + "kmeans_confusion = confusion_matrix(y_test,y_pred_kmean)\n", + "print('Confusion matrix: \\n{}'.format(kmeans_confusion))" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "id": "ba637eb3", + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# visualize the confusion matrix, the deeper color represents the greater values\n", + "import seaborn as sns\n", + "sns.heatmap(kmeans_confusion, annot=False, cmap='Blues')\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "ab647317", + "metadata": {}, + "source": [ + "### 2.2 PCA (Principal component analysis)" + ] + }, + { + "cell_type": "markdown", + "id": "06872c16", + "metadata": {}, + "source": [ + "###    2.2.1 Scatter Plots in 2D and 3D" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "id": "64add44b", + "metadata": {}, + "outputs": [], + "source": [ + "# import PCA from sklearn\n", + "from sklearn.decomposition import PCA\n", + "\n", + "# implement PCA and keep the first two principal components only\n", + "pca = PCA(n_components = 2, whiten = True)\n", + "pca.fit(x_trainf)\n", + "\n", + "# transform data to reduce dimensions\n", + "x_train_pca = pca.transform(x_trainf)\n", + "x_test_pca = pca.transform(x_testf)" + ] + }, + { + "cell_type": "code", + "execution_count": 60, + "id": "fc524704", + "metadata": { + "scrolled": false + }, + "outputs": [ + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "import matplotlib.pyplot as plt\n", + "\n", + "%matplotlib inline\n", + "plt.scatter(x_train_pca[:,0], x_train_pca[:,1], c = y_train)\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "id": "8baaf415", + "metadata": {}, + "outputs": [], + "source": [ + "# implement PCA and keep the first three principal components only\n", + "pca3 = PCA(n_components = 3, whiten = True)\n", + "pca3.fit(x_trainf)\n", + "# transform data\n", + "x_train_pca3 = pca3.transform(x_trainf)\n", + "x_test_pca3 = pca3.transform(x_testf)" + ] + }, + { + "cell_type": "code", + "execution_count": 76, + "id": "a7c2b5b8", + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "from mpl_toolkits.mplot3d import Axes3D\n", + "plt.figure(figsize=(8,8))\n", + "ax = plt.axes(projection=\"3d\")\n", + "ax.scatter3D(x_train_pca3[:,0], x_train_pca3[:,1],x_train_pca3[:,2], c = y_train)\n", + "ax.view_init(10, 60)" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "id": "9f414ded", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[0.10869168 0.05363875]\n", + "[0.10869168 0.05363875 0.0409124 ]\n" + ] + } + ], + "source": [ + "print(pca.explained_variance_ratio_)\n", + "print(pca3.explained_variance_ratio_)" + ] + }, + { + "cell_type": "markdown", + "id": "4b55fa60", + "metadata": {}, + "source": [ + "###    2.2.2 KMeans using PCA" + ] + }, + { + "cell_type": "markdown", + "id": "5fbf7294", + "metadata": {}, + "source": [ + "### ARI" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "id": "4b695f26", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Test score:0.0848\n" + ] + } + ], + "source": [ + "from sklearn.cluster import KMeans\n", + "\n", + "# implement KMeans on transformed data\n", + "kmeans = KMeans(n_clusters = 10).fit(x_train_pca)\n", + "y_pred_trans_kmean = kmeans.predict(x_test_pca)\n", + "test_score = adjusted_rand_score(y_test, y_pred_trans_kmean)\n", + "print(\"Test score:{:.4f}\".format(test_score))" + ] + }, + { + "cell_type": "markdown", + "id": "42ef2a8d", + "metadata": {}, + "source": [ + "### NMI" + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "id": "981f11be", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Test score:0.1828\n" + ] + } + ], + "source": [ + "test_score = normalized_mutual_info_score(y_test, y_pred_trans_kmean)\n", + "print(\"Test score:{:.4f}\".format(test_score))" + ] + }, + { + "cell_type": "markdown", 
+ "id": "54dc4e2d", + "metadata": {}, + "source": [ + "####     Confusion Matrix" + ] + }, + { + "cell_type": "code", + "execution_count": 30, + "id": "57eb2033", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Confusion matrix: \n", + "[[ 1 2 398 105 41 34 103 82 234 0]\n", + " [369 99 23 0 61 15 73 28 18 314]\n", + " [337 39 33 5 146 5 38 114 7 276]\n", + " [ 13 3 207 170 187 7 99 283 31 0]\n", + " [ 77 198 38 1 165 134 188 36 35 128]\n", + " [331 51 55 213 34 0 16 23 9 268]\n", + " [333 129 4 0 66 45 83 18 8 314]\n", + " [ 18 185 32 2 190 209 150 177 15 22]\n", + " [107 104 71 23 148 6 82 161 17 281]\n", + " [119 138 12 1 266 40 124 162 4 134]]\n" + ] + } + ], + "source": [ + "from sklearn.metrics import confusion_matrix\n", + "kmeans_pca_confusion = confusion_matrix(y_test,y_pred_kmean_pca)\n", + "print('Confusion matrix: \\n{}'.format(kmeans_pca_confusion))" + ] + }, + { + "cell_type": "code", + "execution_count": 31, + "id": "3eef5fab", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 31, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "import seaborn as sns\n", + "sns.heatmap(kmeans_pca_confusion, annot=False, cmap='Blues')" + ] + }, + { + "cell_type": "markdown", + "id": "0ec36a9a", + "metadata": {}, + "source": [ + "### 2.3 MiniBatchKMeans" + ] + }, + { + "cell_type": "markdown", + "id": "78ddd888", + "metadata": {}, + "source": [ + "###    2.3.1 Fitting Models" + ] + }, + { + "cell_type": "code", + "execution_count": 32, + "id": "c4f713ae", + "metadata": {}, + "outputs": [], + "source": [ + "from sklearn.cluster import MiniBatchKMeans\n", + "from sklearn.metrics.cluster import normalized_mutual_info_score\n", + "from sklearn.metrics.cluster import adjusted_rand_score\n", + "\n", + "# implement Mini-Batch K-Means\n", + "# assign 10 clusters \n", + "kmeans = MiniBatchKMeans(n_clusters=10,batch_size=5120).fit(x_train_pca)\n", + "y_pred_mbkm = kmeans.predict(x_test_pca)" + ] + }, + { + "cell_type": "markdown", + "id": "d5af13fa", + "metadata": {}, + "source": [ + "###    2.3.2 Clustering results evaluations" + ] + }, + { + "cell_type": "markdown", + "id": "a14ff10a", + "metadata": {}, + "source": [ + "####     NMI (Normaliszed Mutual Information) and ARI (Adjusted Rand Index)" + ] + }, + { + "cell_type": "code", + "execution_count": 33, + "id": "178c9a4a", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "NMI:0.1729\n", + "ARI:0.0817\n" + ] + } + ], + "source": [ + "print(\"NMI:{:.4f}\".format(normalized_mutual_info_score(y_test, y_pred_mbkm,average_method='arithmetic')))\n", + "print(\"ARI:{:.4f}\".format(adjusted_rand_score(y_test, y_pred_mbkm)))" + ] + }, + { + "cell_type": "markdown", + "id": "05d87e3f", + "metadata": {}, + "source": [ + "####     Confusion Matrix" + ] + }, + { + "cell_type": "code", + "execution_count": 34, + "id": "91486cd4", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Confusion matrix: 
\n", + "[[127 27 1 80 238 2 156 310 0 59]\n", + " [ 2 60 459 20 76 75 56 9 229 14]\n", + " [ 2 96 418 9 43 22 158 5 209 38]\n", + " [112 112 11 16 129 2 377 81 0 160]\n", + " [ 2 241 128 156 112 161 60 27 105 8]\n", + " [185 32 416 0 17 26 31 32 201 60]\n", + " [ 0 76 395 55 56 83 28 1 270 36]\n", + " [ 2 179 23 186 101 201 228 8 24 48]\n", + " [ 7 96 172 9 95 56 204 29 258 74]\n", + " [ 0 260 139 45 80 90 167 0 130 89]]\n" + ] + } + ], + "source": [ + "mbkm_confusion = confusion_matrix(y_test,y_pred_mbkm)\n", + "print('Confusion matrix: \\n{}'.format(mbkm_confusion))" + ] + }, + { + "cell_type": "code", + "execution_count": 35, + "id": "c9ec58f1", + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "sns.heatmap(mbkm_confusion, annot=False, cmap='Blues')\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "e7bb3593", + "metadata": {}, + "source": [ + "### 2.4 Birch" + ] + }, + { + "cell_type": "markdown", + "id": "103a4fad", + "metadata": {}, + "source": [ + "###    2.4.1 Fitting Models" + ] + }, + { + "cell_type": "code", + "execution_count": 37, + "id": "296530dd", + "metadata": {}, + "outputs": [], + "source": [ + "from sklearn.cluster import Birch\n", + "\n", + "# implement birch\n", + "brc = Birch(n_clusters = 10).fit(x_train_pca)\n", + "y_pred_brc = brc.predict(x_test_pca)" + ] + }, + { + "cell_type": "markdown", + "id": "c21d97b8", + "metadata": {}, + "source": [ + "###    2.4.2 Clustering results evaluations" + ] + }, + { + "cell_type": "markdown", + "id": "cac10047", + "metadata": {}, + "source": [ + "####     NMI (Normaliszed Mutual Information) and ARI (Adjusted Rand Index)" + ] + }, + { + "cell_type": "code", + "execution_count": 38, + "id": "857fee12", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "NMI:0.1808\n", + "ARI:0.0910\n" + ] + } + ], + "source": [ + "from sklearn.metrics.cluster import normalized_mutual_info_score\n", + "from sklearn.metrics.cluster import adjusted_rand_score\n", + "print(\"NMI:{:.4f}\".format(normalized_mutual_info_score(y_test, y_pred_brc,average_method='arithmetic')))\n", + "print(\"ARI:{:.4f}\".format(adjusted_rand_score(y_test, y_pred_brc)))" + ] + }, + { + "cell_type": "markdown", + "id": "81466971", + "metadata": {}, + "source": [ + "####     Confusion Matrix" + ] + }, + { + "cell_type": "code", + "execution_count": 39, + "id": "b89536eb", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Confusion matrix: \n", + "[[ 41 451 48 2 94 0 1 73 289 1]\n", + " [230 39 5 16 0 0 603 35 72 0]\n", + " [171 33 1 2 10 0 545 12 
226 0]\n", + " [101 132 1 2 158 0 15 32 559 0]\n", + " [502 63 15 71 2 1 131 116 98 1]\n", + " [137 43 0 0 218 0 538 5 59 0]\n", + " [275 17 10 27 0 1 573 45 50 2]\n", + " [362 43 7 156 3 0 25 91 309 4]\n", + " [250 83 0 6 38 0 321 17 284 1]\n", + " [439 26 2 18 3 0 187 43 282 0]]\n" + ] + } + ], + "source": [ + "brc_confusion = confusion_matrix(y_test,y_pred_brc)\n", + "print('Confusion matrix: \\n{}'.format(brc_confusion))" + ] + }, + { + "cell_type": "code", + "execution_count": 40, + "id": "c75904f9", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 40, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAggAAAGgCAYAAADPW599AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8qNh9FAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAzuklEQVR4nO3df1hVZb738c+WH1tEIIVkQ6KhMWWCToHHETM8qThO/hqfGTVrstEmTTOJSGOc80Q/ZKud1DlZTlZHScehc03R2Dyl0kxDkVMp5YySqR3phwVRDQEabhTW80dX+8zea4Ns3bA2nvera11XrHWz+GBqX773fa9lMwzDEAAAwD/pYXUAAAAQfCgQAACACQUCAAAwoUAAAAAmFAgAAMCEAgEAAJhQIAAAABMKBAAAYEKBAAAATCgQAACACQUCAABB5NNPP9VNN92k2NhY9erVS9///vdVUVHhvm4YhgoKCpSYmKiIiAiNHTtWlZWVHvdwuVxasmSJ4uLiFBkZqalTp+r48eN+5aBAAAAgSNTV1Wn06NEKCwvTyy+/rPfee0+PPPKILrroIveYNWvWaO3atdqwYYP27t0rh8OhCRMmqLGx0T0mJydHJSUlKi4uVnl5uU6cOKHJkyerpaWlw1lswfKypi8az1gdoU0j79ttdYR27SmYYHWEdl3UK8zqCG36+MtvrI7QrgFxvayO0G3V1J+yOkK7HvvrR1ZHaNeDP7zc6gjt6hnaufePuOqOgN2r6d0NHR5777336o033tDrr7/u87phGEpMTFROTo6WL18u6dtuQXx8vFavXq0FCxaovr5eF198sbZu3apZs2ZJkj777DMlJSXppZde0sSJEzuUhQ4CAADebD0CdrhcLjU0NHgcLpfL55fdsWOHMjIy9NOf/lT9+vXTVVddpSeffNJ9vaqqSjU1NcrOznafs9vtysrK0p49eyRJFRUVOn36tMeYxMREpaamusd0BAUCAACdyOl0KiYmxuNwOp0+xx47dkwbN25USkqKdu3apYULF+rOO+/UM888I0mqqamRJMXHx3t8Xnx8vPtaTU2NwsPD1adPnzbHdEQnN2kAAOiGbLaA3So/P1+5ubke5+x2u8+xra2tysjIUGFhoSTpqquuUmVlpTZu3Kibb775n+J55jMMw3TOW0fG/DM6CAAAeAvgFIPdbld0dLTH0VaBkJCQoCuvvNLj3JAhQ/Txxx9LkhwOhySZOgG1tbXuroLD4VBzc7Pq6uraHNMRFAgAAHiz2QJ3+GH06NE6fPiwx7kjR45o4MCBkq
Tk5GQ5HA6Vlpa6rzc3N6usrEyZmZmSpPT0dIWFhXmMqa6u1sGDB91jOoIpBgAAgsRdd92lzMxMFRYWaubMmXr77be1adMmbdq0SdK3Uws5OTkqLCxUSkqKUlJSVFhYqF69emnOnDmSpJiYGM2fP1933323YmNj1bdvX+Xl5SktLU3jx4/vcBYKBAAAvNmsabCPGDFCJSUlys/P1wMPPKDk5GStX79eN954o3vMsmXL1NTUpEWLFqmurk4jR47U7t27FRUV5R6zbt06hYaGaubMmWpqatK4ceO0ZcsWhYSEdDgLz0HoAJ6DcH54DsK54zkI547nIJyf//XPQRh5T8Du1fTWwwG7V1diDQIAADBhigEAAG8WTTEEEwoEAAC8BfA5CN0VJRIAADChgwAAgDemGCgQAAAwYYqBKQYAAGBGBwEAAG9MMfhfIBw/flwbN27Unj17VFNTI5vNpvj4eGVmZmrhwoVKSko66z1cLpfpXdiu5pA2X14BAECXYorBvymG8vJyDRkyRCUlJRo+fLhuvvlm3XTTTRo+fLheeOEFDR06VG+88cZZ7+Pr3di/fmT1OX8TAAAEVADf5thd+dVBuOuuu3Trrbdq3bp1bV7PycnR3r17272Pr3djNzR3/PnQAACgc/lVIBw8eFDbtm1r8/qCBQv0m9/85qz3sdvtpukEVxC/iwEA8L9MN/7JP1D8+hVISEjQnj172rz+17/+VQkJCecdCgAAS/WwBe7opvzqIOTl5WnhwoWqqKjQhAkTFB8fL5vNppqaGpWWluqpp57S+vXrOykqAADoKn4VCIsWLVJsbKzWrVunJ554Qi0tLZKkkJAQpaen65lnntHMmTM7JSgAAF2GKQb/tznOmjVLs2bN0unTp/Xll19KkuLi4hQWFhbwcAAAWIJtjuf+oKSwsDDWGwAAcIHiSYoAAHhjioECAQAAE6YYeFkTAAAwo4MAAIA3phgoEAAAMGGKgQIBAAATOgisQQAAAGZ0EAAA8MYUAwUCAAAmTDEwxQAAAMzoIAAA4I0phuApEHYc+szqCG1664FsqyO0K9IeYnWEdp1uabU6QpsGxPWyOkK31mfEHVZHaNPx8vVWR2jX4lEDrY6A9jDFwBQDAAAwC5oOAgAAQYMOAgUCAAAmrEFgigEAAJjRQQAAwBtTDBQIAACYMMVAgQAAgAkdBNYgAAAAMzoIAAB4Y4qBAgEAAG82CgSmGAAAgBkdBAAAvNBBoEAAAMCM+oApBgAAYEYHAQAAL0wxUCAAAGBCgWBRgeByueRyuTzOnW52KSzcbkUcAADgJeBrED755BPNmzev3TFOp1MxMTEex0tFjwc6CgAA58RmswXs6K4CXiD84x//UFFRUbtj8vPzVV9f73H8aO6iQEcBAOCcUCCcwxTDjh072r1+7Nixs97DbrfLbvecTggL/9rfKAAAdI7u+//1gPG7QJg+fbpsNpsMw2hzTHeumAAAwDlMMSQkJOi5555Ta2urz+Odd97pjJwAAHQZq6YYCgoKTJ/vcDjc1w3DUEFBgRITExUREaGxY8eqsrLS4x4ul0tLlixRXFycIiMjNXXqVB0/ftzvXwO/C4T09PR2i4CzdRcAAAh2Vq5BGDp0qKqrq93HgQMH3NfWrFmjtWvXasOGDdq7d68cDocmTJigxsZG95icnByVlJSouLhY5eXlOnHihCZPnqyWlha/cvg9xXDPPffo5MmTbV6/7LLL9Oqrr/p7WwAAICk0NNSja/AdwzC0fv16rVixQjNmzJAkFRUVKT4+Xtu3b9eCBQtUX1+vp59+Wlu3btX48eMlSdu2bVNSUpJeeeUVTZw4scM5/O4gjBkzRj/84Q/bvB4ZGamsrCx/bwsAQNAIZAfB5XKpoaHB4/B+FtA/O3r0qBITE5WcnKzZs2e7F/9XVVWppqZG2dnZ7rF2u11ZWVnas2ePJKmiokKnT5/2GJOYmKjU1FT3mI7iXQwAAHgJZIHg69k/TqfT59cdOX
KknnnmGe3atUtPPvmkampqlJmZqa+++ko1NTWSpPj4eI/PiY+Pd1+rqalReHi4+vTp0+aYjuJRywAAdKL8/Hzl5uZ6nPPe6v+dSZMmuf89LS1No0aN0uDBg1VUVKQf/OAHksw7BQ3DOOtah46M8UYHAQAAb7bAHXa7XdHR0R5HWwWCt8jISKWlpeno0aPudQnenYDa2lp3V8HhcKi5uVl1dXVtjukoCgQAALwEy5MUXS6XDh06pISEBCUnJ8vhcKi0tNR9vbm5WWVlZcrMzJT07U7DsLAwjzHV1dU6ePCge0xHMcUAAECQyMvL05QpUzRgwADV1tbqoYceUkNDg+bOnSubzaacnBwVFhYqJSVFKSkpKiwsVK9evTRnzhxJUkxMjObPn6+7775bsbGx6tu3r/Ly8pSWlube1dBRFAgAAHix6onAx48f1w033KAvv/xSF198sX7wgx/ozTff1MCBAyVJy5YtU1NTkxYtWqS6ujqNHDlSu3fvVlRUlPse69atU2hoqGbOnKmmpiaNGzdOW7ZsUUhIiF9ZbEaQPNXo6bc/tjpCm6ZemWh1hHZF2v37j97VWlqD4reYT2EhzLKdjz4j7rA6QpuOl6+3OkK7Gk+dsTpCuxwxPa2O0K6enfzjbb95/xWwe9X+58yA3asr0UEAAMAbrxRikSIAADCjgwAAgBfeSkyBAACACQVCEBUIl0b3sjpCm3oE+e+THkH+G7lHSPDmC+YFlJIUEuy/+RJSrE7QJntocC/effHYZ1ZHaNfsqwZYHQEWC5oCAQCAYEEHgQIBAAATCgR2MQAAAB/oIAAA4I0GAgUCAADemGJgigEAAPhABwEAAC90ECgQAAAwoUCgQAAAwIz6gDUIAADAjA4CAABemGKgQAAAwIQCgSkGAADgAx0EAAC80EGgQAAAwIQCgSkGAADgg98FQlNTk8rLy/Xee++Zrp06dUrPPPPMWe/hcrnU0NDgcTQ3u/yNAgBA57AF8Oim/CoQjhw5oiFDhujaa69VWlqaxo4dq+rqavf1+vp6/fznPz/rfZxOp2JiYjyO3236tf/pAQDoBDabLWBHd+VXgbB8+XKlpaWptrZWhw8fVnR0tEaPHq2PP/7Yry+an5+v+vp6j+OG25b6dQ8AANB5/FqkuGfPHr3yyiuKi4tTXFycduzYocWLF2vMmDF69dVXFRkZ2aH72O122e12j3Ph4c3+RAEAoNN055/8A8WvAqGpqUmhoZ6f8thjj6lHjx7KysrS9u3bAxoOAAArUB/4WSBcccUV2rdvn4YMGeJx/tFHH5VhGJo6dWpAwwEAYAU6CH6uQfjxj3+s3/3udz6vbdiwQTfccIMMwwhIMAAAYB2/CoT8/Hy99NJLbV5//PHH1draet6hAACwks0WuKO74kmKAAB4YYqBJykCAAAf6CAAAOCFBgIFAgAAJj16UCEwxQAAAEzoIAAA4IUpBgoEAABM2MXAFAMAAPCBDgIAAF5oIFAgAABgwhQDBQIAACYUCKxBAAAAPgRNByExJsLqCG368ItvrI7QrqH9o62O0G298+HXVkdo14hBfayO0K7akiVWR2hT85ngfnHc7fnbrY7Qrtk777U6gqVoIARRgQAAQLBgioEpBgAA4AMdBAAAvNBAoEAAAMCEKQamGAAAgA90EAAA8EIDgQIBAAATphiYYgAAAD7QQQAAwAsNBDoIAACY2Gy2gB3nyul0ymazKScnx33OMAwVFBQoMTFRERERGjt2rCorKz0+z+VyacmSJYqLi1NkZKSmTp2q48eP+/31KRAAAPBiswXuOBd79+7Vpk2bNGzYMI/za9as0dq1a7Vhwwbt3btXDodDEyZMUGNjo3tMTk6OSkpKVFxcrPLycp04cUKTJ09WS0uLXxkoEAAA6EQul0sNDQ0eh8vlanP8iRMndOONN+rJJ59Unz7/8z4WwzC0fv16rVixQjNmzFBqaqqKior0zTffaPv2b9/tUV9fr6efflqPPPKIxo
8fr6uuukrbtm3TgQMH9Morr/iVmwIBAAAvgZxicDqdiomJ8TicTmebX3vx4sW6/vrrNX78eI/zVVVVqqmpUXZ2tvuc3W5XVlaW9uzZI0mqqKjQ6dOnPcYkJiYqNTXVPaajWKQIAICXQC5SzM/PV25ursc5u93uc2xxcbEqKiq0b98+07WamhpJUnx8vMf5+Ph4ffTRR+4x4eHhHp2H78Z89/kdRYEAAEAnstvtbRYE/+yTTz7R0qVLtXv3bvXs2bPNcd4LHw3DOOtiyI6M8cYUAwAAXqzYxVBRUaHa2lqlp6crNDRUoaGhKisr03/8x38oNDTU3Tnw7gTU1ta6rzkcDjU3N6uurq7NMR3ld4Fw6NAhbd68We+//74k6f3339ftt9+uefPm6c9//nOH7uFrwUZzOws2AADoSlbsYhg3bpwOHDig/fv3u4+MjAzdeOON2r9/vwYNGiSHw6HS0lL35zQ3N6usrEyZmZmSpPT0dIWFhXmMqa6u1sGDB91jOsqvKYadO3dq2rRp6t27t7755huVlJTo5ptv1vDhw2UYhiZOnKhdu3bpuuuua/c+TqdT999/v8e5Rbn5uiNvhV/hAQC4UERFRSk1NdXjXGRkpGJjY93nc3JyVFhYqJSUFKWkpKiwsFC9evXSnDlzJEkxMTGaP3++7r77bsXGxqpv377Ky8tTWlqaadHj2fhVIDzwwAO655579NBDD6m4uFhz5szR7bffrpUrV0qSVqxYoVWrVp21QPC1YKPqH2f8Cg4AQGcJ1ncxLFu2TE1NTVq0aJHq6uo0cuRI7d69W1FRUe4x69atU2hoqGbOnKmmpiaNGzdOW7ZsUUhIiF9fy2YYhtHRwTExMaqoqNBll12m1tZW2e12vfXWW7r66qslSQcPHtT48eP9XikpSYeqT/r9OV3lG5d/D5foakP7R1sdodvae6zu7IMsNGJQn7MPstDpllarI7Tp9JkO/9VmiUumPWx1hHbV7bzX6gjt6tnJS+yvXftGwO71Wu7ogN2rK53zIsUePXqoZ8+euuiii9znoqKiVF9fH4hcAADAQn4VCJdeeqk++OAD98d//etfNWDAAPfHn3zyiRISEgKXDgAAC1j9qOVg4FeT5vbbb/d4lrP3YoqXX375rOsPAAAIdsG6BqEr+VUgLFy4sN3r3y1WBACgO6M+4EFJAADABx61DACAF6YYKBAAADChPmCKAQAA+EAHAQAALz1oIVAgAADgjfqAKQYAAOADHQQAALywi4ECAQAAkx7UBxQIAAB4o4PAGgQAAOADHQQAALzQQAiiAiHznhesjtCmA4/+xOoI7Trd0mp1hHZ9daLZ6ghtGjGoj9URurWj1SesjtCm0JDgbpC+u/VOqyOgHTZRIQT3nyAAAGCJoOkgAAAQLNjFQIEAAIAJuxiYYgAAAD7QQQAAwAsNBAoEAABMeJsjUwwAAMAHOggAAHihgUCBAACACbsYKBAAADChPmANAgAA8IEOAgAAXtjFQIEAAIAJ5QFTDAAAwIeAdBAMw2DFJwDggsH/0wLUQbDb7Tp06FAgbgUAgOV62AJ3dFd+dRByc3N9nm9padGqVasUGxsrSVq7dm2793G5XHK5XB7njJbTsoWE+RMHAAB0Er8KhPXr12v48OG66KKLPM4bhqFDhw4pMjKyQ20Zp9Op+++/3+OcPe3Hihj2f/yJAwBAp2CKwc8CYeXKlXryySf1yCOP6LrrrnOfDwsL05YtW3TllVd26D75+fmmbsSABc/7EwUAgE5DfeDnGoT8/Hw9++yzuv3225WXl6fTp0+f0xe12+2Kjo72OJheAAAgePi9SHHEiBGqqKjQF198ofT0dB04cIBWDADggmKz2QJ2dFfntM2xd+/eKioqUnFxsSZMmKCWlpZA5wIAwDLdefdBoJzXcxBmz56ta665RhUVFRo4cGCgMgEAYKnu/JN/oJz3g5L69++v/v37ByILAAAIEryLAQAAL/QPKBAAADDhbY68rA
kAAPhABwEAAC80ECgQAAAwYRcDUwwAAMAHOggAAHihgUCBAACACbsYmGIAAAA+UCAAAODFZgvc4Y+NGzdq2LBh7jcdjxo1Si+//LL7umEYKigoUGJioiIiIjR27FhVVlZ63MPlcmnJkiWKi4tTZGSkpk6dquPHj/v9a0CBAACAF6ve5ti/f3+tWrVK+/bt0759+3Tddddp2rRp7iJgzZo1Wrt2rTZs2KC9e/fK4XBowoQJamxsdN8jJydHJSUlKi4uVnl5uU6cOKHJkyf7/WJFm2EYhl+f0Un2VtVbHaFNA+N6WR2hXdERYVZH6LZ+/zf/q+qu9JPhwf2ekz4/ftzqCG366vnbrY7Qrn+cOG11hHbFRYVbHaFdPTt5Bd2SkkMBu9e//2iQXC6Xxzm73S673d6hz+/bt68efvhhzZs3T4mJicrJydHy5cslfdstiI+P1+rVq7VgwQLV19fr4osv1tatWzVr1ixJ0meffaakpCS99NJLmjhxYodz00EAAKATOZ1OxcTEeBxOp/Osn9fS0qLi4mKdPHlSo0aNUlVVlWpqapSdne0eY7fblZWVpT179kiSKioqdPr0aY8xiYmJSk1NdY/pKHYxAADgJZAPSsrPz1dubq7Hufa6BwcOHNCoUaN06tQp9e7dWyUlJbryyivd/4OPj4/3GB8fH6+PPvpIklRTU6Pw8HD16dPHNKampsav3BQIAAB46RHAXY7+TCdI0uWXX679+/fr66+/1nPPPae5c+eqrKzMfd27eDEM46wFTUfGeGOKAQCAIBIeHq7LLrtMGRkZcjqdGj58uH7961/L4XBIkqkTUFtb6+4qOBwONTc3q66urs0xHUWBAACAlx62wB3nyzAMuVwuJScny+FwqLS01H2tublZZWVlyszMlCSlp6crLCzMY0x1dbUOHjzoHtNRTDEAAODFqpc1/fKXv9SkSZOUlJSkxsZGFRcX6y9/+Yt27twpm82mnJwcFRYWKiUlRSkpKSosLFSvXr00Z84cSVJMTIzmz5+vu+++W7Gxserbt6/y8vKUlpam8ePH+5WFAgEAgCDx+eef62c/+5mqq6sVExOjYcOGaefOnZowYYIkadmyZWpqatKiRYtUV1enkSNHavfu3YqKinLfY926dQoNDdXMmTPV1NSkcePGacuWLQoJCfErC89B6ACeg3Dh4jkI54fnIJw7noNwfjr7OQj3/PFwwO718OTLA3avrkQHAQAAL7yriUWKAADABzoIAAB44XXPFAgAAJjQXqdAAADAhAYCRRIAAPCBDgIAAF5Yg0CBAACACfUBUwwAAMCH8+og1NXVqaioSEePHlVCQoLmzp2rpKSks36ey+WSy+XyONfscincj9dhAgDQWQL5uufuyq8OQmJior766itJUlVVla688kqtXr1aR48e1RNPPKG0tDS9//77Z72P0+lUTEyMx7Fl49pz+w4AAAiwHjZbwI7uyq8CoaamRi0tLZK+fePUFVdcof/+7//W7t279cEHH2jMmDH6t3/7t7PeJz8/X/X19R7HLbfnntt3AAAAAu6cpxjeeustPfXUU+rV69sXGdntdv3qV7/ST37yk7N+rt1ul91rOiH8q6B4ZxQAACxS1DkUCN+9I9vlcik+Pt7jWnx8vL744ovAJAMAwCKsQTiHAmHcuHEKDQ1VQ0ODjhw5oqFDh7qvffzxx4qLiwtoQAAA0PX8KhDuu+8+j4+/m174zosvvqgxY8acfyoAACxkEy2E8yoQvD388MPnFQYAgGDAFANPUgQAwIQCgScpAgAAH+ggAADgxcY+RwoEAAC8McXAFAMAAPCBDgIAAF6YYaBAAADApDu/ZClQmGIAAAAmdBAAAPDCIkUKBAAATJhhYIoBAAD4EDQdhBeP1FodoU33JAy2OkK7WloNqyO0q/lMq9UR2vST4f2tjtC9fV1jdYI2NTadsTpCu0LpYQe1HrysKXgKBAAAggVTDBQIAACY0OBhDQIAAPCBDgIAAF
54UBIFAgAAJtQHTDEAAAAf6CAAAOCFKQYKBAAATKgPmGIAAAA+0EEAAMALPz1TIAAAYGJjjoEiCQAAmNFBAADAC/0DCgQAAEzY5kiBAACACeUBaxAAAIAPdBAAAPDCDIOfHYR3331XVVVV7o+3bdum0aNHKykpSddcc42Ki4s7dB+Xy6WGhgaP40xzs3/JAQDoJDabLWBHd+VXgTB//nx9+OGHkqSnnnpKt912mzIyMrRixQqNGDFCv/jFL/Sf//mfZ72P0+lUTEyMx/HGs785p28AAAAEnl9TDIcPH9bgwYMlSY8//rjWr1+v2267zX19xIgRWrlypebNm9fuffLz85Wbm+txblXZJ/5EAQCg07BAz88CISIiQl988YUGDBigTz/9VCNHjvS4PnLkSI8piLbY7XbZ7XbPIOHh/kQBAKDTdOepgUDxq0iaNGmSNm7cKEnKysrS73//e4/r//Vf/6XLLrsscOkAAPhfxOl0asSIEYqKilK/fv00ffp0HT582GOMYRgqKChQYmKiIiIiNHbsWFVWVnqMcblcWrJkieLi4hQZGampU6fq+PHjfmXxq0BYvXq1/vSnPykrK0tJSUl65JFHNGbMGN12223KyspSQUGBVq1a5VcAAACCjS2Ahz/Kysq0ePFivfnmmyotLdWZM2eUnZ2tkydPusesWbNGa9eu1YYNG7R37145HA5NmDBBjY2N7jE5OTkqKSlRcXGxysvLdeLECU2ePFktLS0dzuLXFENiYqLeffddrVq1Si+++KIMw9Dbb7+tTz75RKNHj9Ybb7yhjIwMf24JAEDQsWqKYefOnR4fb968Wf369VNFRYWuvfZaGYah9evXa8WKFZoxY4YkqaioSPHx8dq+fbsWLFig+vp6Pf3009q6davGjx8v6dtdh0lJSXrllVc0ceLEDmXxex3GRRddpFWrVqmyslJNTU1yuVz68MMP9dvf/pbiAAAAL7629rtcrg59bn19vSSpb9++kqSqqirV1NQoOzvbPcZutysrK0t79uyRJFVUVOj06dMeYxITE5Wamuoe0xEs1AQAwEuPAB6+tvY7nc6zZjAMQ7m5ubrmmmuUmpoqSaqpqZEkxcfHe4yNj493X6upqVF4eLj69OnT5piO4EmKAAB4CeQUg6+t/d47+Xy544479Pe//13l5eVnzWcYxlkzd2TMP6ODAACAl0AuUrTb7YqOjvY4zlYgLFmyRDt27NCrr76q/v37u887HA5JMnUCamtr3V0Fh8Oh5uZm1dXVtTmmIygQAAAIEoZh6I477tDzzz+vP//5z0pOTva4npycLIfDodLSUve55uZmlZWVKTMzU5KUnp6usLAwjzHV1dU6ePCge0xHMMUAAIAXq56TtHjxYm3fvl1/+MMfFBUV5e4UxMTEKCIiQjabTTk5OSosLFRKSopSUlJUWFioXr16ac6cOe6x8+fP1913363Y2Fj17dtXeXl5SktLc+9q6AgKBAAAvPTw+wkGgfHdwwjHjh3rcX7z5s265ZZbJEnLli1TU1OTFi1apLq6Oo0cOVK7d+9WVFSUe/y6desUGhqqmTNnqqmpSePGjdOWLVsUEhLS4Sw2wzCM8/6OAuD/7jpqdYQ23ZM12OoI7erRI7gfCdp8ptXqCG2KCO/4HxaY9fnX/2t1hDZ9+P/+zeoI7QqOv3nbdlFkmNUR2tWzk3+8ffHA5wG715S0js/7BxM6CAAAeOFVDBQIAACY2CyaYggm7GIAAAAmdBAAAPDCFEMQFQgVH9adfZBFmjI7/vYrK/Tu7NU656ny0warI7Qp/dI+Zx9koWD/S6p6931WR2hTbUPHnnVvlUder7I6QrvWTxtqdQRLWbWLIZgwxQAAAEyC+0dPAAAsEOzdu65AgQAAgBcKBAoEAABM2ObIGgQAAOADHQQAALwE+RPsuwQFAgAAXphiYIoBAAD4QAcBAAAv7GKgQAAAwIQpBqYYAACAD3QQAADwwi4GCgQAAEyYYmCKAQ
AA+EAHAQAAL+xioEAAAMCE+oACAQAAkx60EPxbg7BkyRK9/vrr5/1FXS6XGhoaPI6W083nfV8AABAYfhUIjz32mMaOHavvfe97Wr16tWpqas7pizqdTsXExHgc/72z6JzuBQBAoNkCeHRXfu9i2L17t370ox/p3//93zVgwABNmzZNf/zjH9Xa2trhe+Tn56u+vt7jGPzDuf5GAQCgc1Ah+F8gpKWlaf369frss8+0bds2uVwuTZ8+XUlJSVqxYoU++OCDs97DbrcrOjra4wgJCz+nbwAAAATeOT8HISwsTDNnztTOnTt17Ngx/eIXv9Bvf/tbXX755YHMBwBAl7MF8J/uKiAPShowYIAKCgpUVVWlnTt3BuKWAABYxmYL3NFd+VUgDBw4UCEhIW1et9lsmjBhwnmHAgAA1vLrOQhVVVWdlQMAgKDRjX/wDxgelAQAgDcqBF7WBAAAzOggAADgpTvvPggUCgQAALx0590HgUKBAACAF+oD1iAAAAAf6CAAAOCNFgIFAgAA3likyBQDAADwgQ4CAABe2MVAgQAAgAn1QRAVCOOGxFkdoU2NTWesjtCu6IgwqyO06+pLL7I6QptOuIL7v21Uz6D5I+rTXX94z+oIbfqP6UOtjtCuPr2C+88tP0EjuP/2AQDAChRIFAgAAHhjFwO7GAAAgA90EAAA8MIaDAoEAABMqA8oEAAAMKNCYA0CAAAwo4MAAIAXdjHQQQAAwMRmC9zhj9dee01TpkxRYmKibDabXnjhBY/rhmGooKBAiYmJioiI0NixY1VZWekxxuVyacmSJYqLi1NkZKSmTp2q48eP+/1rQIEAAECQOHnypIYPH64NGzb4vL5mzRqtXbtWGzZs0N69e+VwODRhwgQ1Nja6x+Tk5KikpETFxcUqLy/XiRMnNHnyZLW0tPiVhSkGAAC8WDXBMGnSJE2aNMnnNcMwtH79eq1YsUIzZsyQJBUVFSk+Pl7bt2/XggULVF9fr6efflpbt27V+PHjJUnbtm1TUlKSXnnlFU2cOLHDWeggAADgzRa4w+VyqaGhweNwuVx+R6qqqlJNTY2ys7Pd5+x2u7KysrRnzx5JUkVFhU6fPu0xJjExUampqe4xHUWBAABAJ3I6nYqJifE4nE6n3/epqamRJMXHx3ucj4+Pd1+rqalReHi4+vTp0+aYjmKKAQAAL4HcxZCfn6/c3FyPc3a7/ZzvZ/Na+WgYhumct46M8UYHAQAAL4HcxWC32xUdHe1xnEuB4HA4JMnUCaitrXV3FRwOh5qbm1VXV9fmmI6iQAAAoBtITk6Ww+FQaWmp+1xzc7PKysqUmZkpSUpPT1dYWJjHmOrqah08eNA9pqOYYgAAwItVuxhOnDihDz74wP1xVVWV9u/fr759+2rAgAHKyclRYWGhUlJSlJKSosLCQvXq1Utz5syRJMXExGj+/Pm6++67FRsbq759+yovL09paWnuXQ0d5XeB8Oijj2rfvn26/vrrNXPmTG3dulVOp1Otra2aMWOGHnjgAYWGtn9bl8tlWsF5ptml0PBzn5MBACBgLKoQ9u3bp3/91391f/zd2oW5c+dqy5YtWrZsmZqamrRo0SLV1dVp5MiR2r17t6Kiotyfs27dOoWGhmrmzJlqamrSuHHjtGXLFoWEhPiVxa8C4cEHH9TDDz+s7OxsLV26VFVVVXr44Yd11113qUePHlq3bp3CwsJ0//33t3sfp9NpGpN9y52a+POlfoUHAKAzWPWo5bFjx8owjDav22w2FRQUqKCgoM0xPXv21KOPPqpHH330vLL4VSBs2bJFW7Zs0YwZM/S3v/1N6enpKioq0o033ihJuuKKK7Rs2bKzFgi+VnT+5u1P/YwOAAA6i18FQnV1tTIyMiRJw4cPV48ePfT973/fff3qq6/WZ599dtb72O120wrO0PAv/YkCAECn8fcdChciv3YxOBwOvffee5Kko0ePqqWlxf2xJFVWVqpfv36BTQgAQBcL4IMUuy
2/Oghz5szRzTffrGnTpulPf/qTli9frry8PH311Vey2WxauXKlfvKTn3RWVgAA0EX8KhDuv/9+RURE6M0339SCBQu0fPlyDRs2TMuWLdM333yjKVOm6MEHH+ysrAAAdI3u/KN/gPhVIISEhGjFihUe52bPnq3Zs2cHNBQAAFayahdDMOFJigAAwIQnKQIA4IVdDBQIAACYUB8wxQAAAHyggwAAgDdaCBQIAAB4YxcDBQIAACYsUmQNAgAA8IEOAgAAXmggUCAAAGDCFANTDAAAwIeg6SA88f+OWh2hTT/PGGh1hG6tpcWwOkKbetuD5o9At/R/UoP39e5Np1usjtAufjoLdrQQ+NsRAAAvTDFQxAIAAB/oIAAA4IUGAgUCAAAmTDEwxQAAAHyggwAAgBfexUCBAACAGfUBBQIAAN6oD1iDAAAAfKCDAACAF3YxUCAAAGDCIkWmGAAAgA90EAAA8EYDgQIBAABv1AdMMQAAAB/oIAAA4IVdDOdQIFRXV2vjxo0qLy9XdXW1QkJClJycrOnTp+uWW25RSEhIZ+QEAKDLsIvBzymGffv2aciQIXrxxRd16tQpHTlyRFdffbUiIyOVl5enMWPGqLGx8az3cblcamho8DhazzSf8zcBAAACy68CIScnR3fddZfeffdd7dmzR0VFRTpy5IiKi4t17NgxNTU16Ve/+tVZ7+N0OhUTE+Nx/OPNZ8/5mwAAIJBstsAd3ZVfBcI777yjn/3sZ+6P58yZo3feeUeff/65+vTpozVr1uj3v//9We+Tn5+v+vp6j6PvD2b5nx4AAHQKv9Yg9OvXT9XV1Ro0aJAk6fPPP9eZM2cUHR0tSUpJSdE//vGPs97HbrfLbrd7nOsRGu5PFAAAOk13/sk/UPzqIEyfPl0LFy7Uzp079eqrr+rGG29UVlaWIiIiJEmHDx/WJZdc0ilBAQBA1/Grg/DQQw+purpaU6ZMUUtLi0aNGqVt27a5r9tsNjmdzoCHBACgK7GLwc8CoXfv3nr22Wd16tQpnTlzRr179/a4np2dHdBwAABYgSmGc3xQUs+ePQOdAwAABBGepAgAgBcaCBQIAACYUSHwsiYAAGBGBwEAAC/sYqBAAADAhF0MTDEAAAAf6CAAAOCFBgIFAgAAZlQITDEAAODNFsB//PX4448rOTlZPXv2VHp6ul5//fVO+A7PjgIBAIAg8eyzzyonJ0crVqzQu+++qzFjxmjSpEn6+OOPuzwLBQIAAF5stsAdLpdLDQ0NHofL5fL5ddeuXav58+fr1ltv1ZAhQ7R+/XolJSVp48aNXfwrIMm4AJ06dcq47777jFOnTlkdxSSYsxkG+c5HMGczDPKdj2DOZhjkC3b33XefIcnjuO+++0zjXC6XERISYjz//PMe5++8807j2muv7aK0/8NmGIbR9WVJ52poaFBMTIzq6+sVHR1tdRwPwZxNIt/5COZsEvnORzBnk8gX7Fwul6ljYLfbZbfbPc599tlnuuSSS/TGG28oMzPTfb6wsFBFRUU6fPhwl+T9DrsYAADoRL6KgfbYvJ7SZBiG6VxXYA0CAABBIC4uTiEhIaqpqfE4X1tbq/j4+C7PQ4EAAEAQCA8PV3p6ukpLSz3Ol5aWekw5dJULcorBbrfrvvvu86ul01WCOZtEvvMRzNkk8p2PYM4mke9Ckpubq5/97GfKyMjQqFGjtGnTJn388cdauHBhl2e5IBcpAgDQXT3++ONas2aNqqurlZqaqnXr1unaa6/t8hwUCAAAwIQ1CAAAwIQCAQAAmFAgAAAAEwoEAABgcsEVCMHymkxvr732mqZMmaLExETZbDa98MILVkdyczqdGjFihKKiotSvXz9Nnz69yx/p2Z6NGzdq2LBhio6OVnR0tEaNGqWXX37Z6lhtcjqdstlsysnJsTqKJKmgoEA2m83jcDgcVsdy+/TTT3XTTTcpNjZWvXr10ve//31VVFRYHUuSdOmll5p+7Ww2mxYvXmx1NEnSmT
Nn9Ktf/UrJycmKiIjQoEGD9MADD6i1tdXqaJKkxsZG5eTkaODAgYqIiFBmZqb27t1rdSx00AVVIATTazK9nTx5UsOHD9eGDRusjmJSVlamxYsX680331RpaanOnDmj7OxsnTx50upokqT+/ftr1apV2rdvn/bt26frrrtO06ZNU2VlpdXRTPbu3atNmzZp2LBhVkfxMHToUFVXV7uPAwcOWB1JklRXV6fRo0crLCxML7/8st577z098sgjuuiii6yOJunb/57//Ov23QNsfvrTn1qc7FurV6/Wb37zG23YsEGHDh3SmjVr9PDDD+vRRx+1Opok6dZbb1Vpaam2bt2qAwcOKDs7W+PHj9enn35qdTR0RJe/HqoT/cu//IuxcOFCj3NXXHGFce+991qUyDdJRklJidUx2lRbW2tIMsrKyqyO0qY+ffoYTz31lNUxPDQ2NhopKSlGaWmpkZWVZSxdutTqSIZhfPsmueHDh1sdw6fly5cb11xzjdUxOmzp0qXG4MGDjdbWVqujGIZhGNdff70xb948j3MzZswwbrrpJosS/Y9vvvnGCAkJMf74xz96nB8+fLixYsUKi1LBHxdMB6G5uVkVFRXKzs72OJ+dna09e/ZYlKp7qq+vlyT17dvX4iRmLS0tKi4u1smTJzVq1Cir43hYvHixrr/+eo0fP97qKCZHjx5VYmKikpOTNXv2bB07dszqSJKkHTt2KCMjQz/96U/Vr18/XXXVVXryySetjuVTc3Oztm3bpnnz5lny4hxfrrnmGv3pT3/SkSNHJEl/+9vfVF5erh/96EcWJ/t2+qOlpUU9e/b0OB8REaHy8nKLUsEfF8yjlr/88ku1tLSYXmgRHx9vevEF2mYYhnJzc3XNNdcoNTXV6jhuBw4c0KhRo3Tq1Cn17t1bJSUluvLKK62O5VZcXKyKigrt27fP6igmI0eO1DPPPKPvfe97+vzzz/XQQw8pMzNTlZWVio2NtTTbsWPHtHHjRuXm5uqXv/yl3n77bd15552y2+26+eabLc3m7YUXXtDXX3+tW265xeoobsuXL1d9fb2uuOIKhYSEqKWlRStXrtQNN9xgdTRFRUVp1KhRevDBBzVkyBDFx8frd7/7nd566y2lpKRYHQ8dcMEUCN8Jltdkdld33HGH/v73vwddhX/55Zdr//79+vrrr/Xcc89p7ty5KisrC4oi4ZNPPtHSpUu1e/du009LwWDSpEnuf09LS9OoUaM0ePBgFRUVKTc318JkUmtrqzIyMlRYWChJuuqqq1RZWamNGzcGXYHw9NNPa9KkSUpMTLQ6ituzzz6rbdu2afv27Ro6dKj279+vnJwcJSYmau7cuVbH09atWzVv3jxdcsklCgkJ0dVXX605c+bonXfesToaOuCCKRCC7TWZ3dGSJUu0Y8cOvfbaa+rfv7/VcTyEh4frsssukyRlZGRo7969+vWvf60nnnjC4mRSRUWFamtrlZ6e7j7X0tKi1157TRs2bJDL5VJISIiFCT1FRkYqLS1NR48etTqKEhISTEXekCFD9Nxzz1mUyLePPvpIr7zyip5//nmro3i45557dO+992r27NmSvi0AP/roIzmdzqAoEAYPHqyysjKdPHlSDQ0NSkhI0KxZs5ScnGx1NHTABbMGIdhek9mdGIahO+64Q88//7z+/Oc/d4s/vIZhyOVyWR1DkjRu3DgdOHBA+/fvdx8ZGRm68cYbtX///qAqDiTJ5XLp0KFDSkhIsDqKRo8ebdpSe+TIEQ0cONCiRL5t3rxZ/fr10/XXX291FA/ffPONevTw/Gs8JCQkaLY5ficyMlIJCQmqq6vTrl27NG3aNKsjoQMumA6CFFyvyfR24sQJffDBB+6Pq6qqtH//fvXt21cDBgywMNm3i+u2b9+uP/zhD4qKinJ3YWJiYhQREWFpNkn65S9/qUmTJikpKUmNjY0qLi7WX/7yF+3cudPqaJK+nWv1Xq8RGRmp2NjYoFjHkZeXpylTpm
jAgAGqra3VQw89pIaGhqD4CfOuu+5SZmamCgsLNXPmTL399tvatGmTNm3aZHU0t9bWVm3evFlz585VaGhw/ZU5ZcoUrVy5UgMGDNDQoUP17rvvau3atZo3b57V0SRJu3btkmEYuvzyy/XBBx/onnvu0eWXX66f//znVkdDR1i6h6ITPPbYY8bAgQON8PBw4+qrrw6arXqvvvqqIcl0zJ071+poPnNJMjZv3mx1NMMwDGPevHnu/6YXX3yxMW7cOGP37t1Wx2pXMG1znDVrlpGQkGCEhYUZiYmJxowZM4zKykqrY7m9+OKLRmpqqmG3240rrrjC2LRpk9WRPOzatcuQZBw+fNjqKCYNDQ3G0qVLjQEDBhg9e/Y0Bg0aZKxYscJwuVxWRzMMwzCeffZZY9CgQUZ4eLjhcDiMxYsXG19//bXVsdBBvO4ZAACYXDBrEAAAQOBQIAAAABMKBAAAYEKBAAAATCgQAACACQUCAAAwoUAAAAAmFAgAAMCEAgEAAJhQIAAAABMKBAAAYPL/ARxspbFfL9QlAAAAAElFTkSuQmCC\n", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "sns.heatmap(brc_confusion, annot=False, cmap='Blues')" + ] + }, + { + "cell_type": "markdown", + "id": "3a50f7f3", + "metadata": {}, + "source": [ + "### 2.5 Gaussian mixture" + ] + }, + { + "cell_type": "markdown", + "id": "33c3d7e7", + "metadata": {}, + "source": [ + "###    2.5.1 Fitting Models" + ] + }, + { + "cell_type": "code", + "execution_count": 44, + "id": "e98317fb", + "metadata": {}, + "outputs": [], + "source": [ + "from sklearn.mixture import GaussianMixture\n", + "gm = GaussianMixture(n_components = 10).fit(x_train_pca)\n", + "y_pred_gm = gm.predict(x_test_pca)" + ] + }, + { + "cell_type": "markdown", + "id": "b4aa5c01", + "metadata": {}, + "source": [ + "###    2.5.2 Clustering results evaluations" + ] + }, + { + "cell_type": "markdown", + "id": "547683c2", + "metadata": {}, + "source": [ + "####     NMI (Normaliszed Mutual Information) and ARI (Adjusted Rand Index)" + ] + }, + { + "cell_type": "code", + "execution_count": 45, + "id": "52f39cb4", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "NMI:0.1823\n", + "ARI:0.0836\n" + ] + } + ], + "source": [ + "print(\"NMI:{:.4f}\".format(normalized_mutual_info_score(y_test, y_pred_gm,average_method='arithmetic')))\n", + "print(\"ARI:{:.4f}\".format(adjusted_rand_score(y_test, y_pred_gm)))" + ] + }, + { + "cell_type": "markdown", + "id": "337975a0", + "metadata": {}, + "source": [ + "####     Confusion Matrix" + ] + }, + { + "cell_type": "code", + "execution_count": 47, + "id": "09af68b9", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Confusion matrix: \n", + "[[ 1 2 398 105 41 34 103 82 234 0]\n", + " [369 99 23 0 61 15 73 28 18 314]\n", + " [337 39 33 5 146 5 38 114 7 276]\n", + " [ 13 3 207 170 187 7 99 283 31 0]\n", + " [ 77 198 38 1 165 134 188 36 35 128]\n", + " [331 51 55 213 34 0 16 23 9 268]\n", + " [333 129 4 
0 66 45 83 18 8 314]\n", + " [ 18 185 32 2 190 209 150 177 15 22]\n", + " [107 104 71 23 148 6 82 161 17 281]\n", + " [119 138 12 1 266 40 124 162 4 134]]\n" + ] + } + ], + "source": [ + "gm_confusion = confusion_matrix(y_test,y_pred_kmean_pca)\n", + "print('Confusion matrix: \\n{}'.format(gm_confusion))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "971194b5", + "metadata": {}, + "outputs": [], + "source": [ + "sns.heatmap(gm_confusion, annot=False, cmap='Blues')\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "3a8306de", + "metadata": {}, + "source": [ + "## 3. Appendix" + ] + }, + { + "cell_type": "markdown", + "id": "ca1adee0", + "metadata": {}, + "source": [ + "### 3.1 KNN" + ] + }, + { + "cell_type": "markdown", + "id": "2efec242", + "metadata": {}, + "source": [ + "#### We also tried other models including KNN and various regression models to see how they perform. Note that these are not unsupervised models, and labels were used to train the models. Therefore, these implementations are not included in the final report. 
" + ] + }, + { + "cell_type": "code", + "execution_count": 33, + "id": "c27cb6d8", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Fitting KNeighborsClassifier(n_jobs=-1, n_neighbors=4, weights='distance')\n", + "Evaluating KNeighborsClassifier(n_jobs=-1, n_neighbors=4, weights='distance')\n" + ] + } + ], + "source": [ + "from sklearn.neighbors import KNeighborsClassifier\n", + "\n", + "clf = KNeighborsClassifier(n_neighbors=4, weights='distance', n_jobs=-1)\n", + "print('Fitting', clf)\n", + "clf.fit(x_trainf, y_train)\n", + "print('Evaluating', clf)\n", + "\n", + "y_pred_knn = clf.predict(x_testf)" + ] + }, + { + "cell_type": "code", + "execution_count": 34, + "id": "31b937dc", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Test accuracy: 0.921\n" + ] + } + ], + "source": [ + "test_score = clf.score(x_testf, y_test)\n", + "print('Test accuracy:', test_score)" + ] + }, + { + "cell_type": "code", + "execution_count": 35, + "id": "df13d267", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "0.8231031975085954" + ] + }, + "execution_count": 35, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from sklearn.metrics.cluster import normalized_mutual_info_score\n", + "normalized_mutual_info_score(y_test, y_pred_knn,average_method='arithmetic')" + ] + }, + { + "cell_type": "code", + "execution_count": 36, + "id": "2226f1b5", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "0.8234272963838932" + ] + }, + "execution_count": 36, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from sklearn.metrics.cluster import normalized_mutual_info_score\n", + "normalized_mutual_info_score(y_test, y_pred_knn,average_method='min')" + ] + }, + { + "cell_type": "code", + "execution_count": 37, + "id": "09c6c71e", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + 
"0.823103261265742" + ] + }, + "execution_count": 37, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from sklearn.metrics.cluster import normalized_mutual_info_score\n", + "normalized_mutual_info_score(y_test, y_pred_knn,average_method='geometric')" + ] + }, + { + "cell_type": "code", + "execution_count": 38, + "id": "cb776587", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "0.8227793536618937" + ] + }, + "execution_count": 38, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from sklearn.metrics.cluster import normalized_mutual_info_score\n", + "normalized_mutual_info_score(y_test, y_pred_knn,average_method='max')" + ] + }, + { + "cell_type": "markdown", + "id": "12d38ecb", + "metadata": {}, + "source": [ + "####    Confusion Matrix" + ] + }, + { + "cell_type": "code", + "execution_count": 93, + "id": "4bef3167", + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Confusion matrix: \n", + "[[910 0 1 1 3 26 0 15 44 0]\n", + " [ 1 910 27 1 4 3 39 0 8 7]\n", + " [ 9 6 880 45 8 15 14 3 18 2]\n", + " [ 1 1 18 969 0 5 3 1 2 0]\n", + " [ 13 12 11 16 885 10 14 2 28 9]\n", + " [ 1 5 36 8 2 931 12 0 3 2]\n", + " [ 3 2 20 6 8 3 951 2 2 3]\n", + " [ 1 8 11 4 9 6 14 912 21 14]\n", + " [ 0 13 9 6 0 7 12 0 952 1]\n", + " [ 2 24 10 2 3 4 13 5 27 910]]\n" + ] + } + ], + "source": [ + "from sklearn.metrics import confusion_matrix\n", + "knn_confusion = confusion_matrix(y_test,y_pred_knn)\n", + "print('Confusion matrix: \\n{}'.format(knn_confusion))" + ] + }, + { + "cell_type": "code", + "execution_count": 102, + "id": "2bc87615", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 102, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "import seaborn as sns\n", + "sns.heatmap(knn_confusion, annot=True, cmap='Blues')" + ] + }, + { + "cell_type": "markdown", + "id": "98588e59", + "metadata": {}, + "source": [ + "### 3.2 Four Regression Models" + ] + }, + { + "cell_type": "code", + "execution_count": 86, + "id": "84d84800", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Training data set score: 0.337\n", + "Test data set score: 0.187\n" + ] + } + ], + "source": [ + "# Linear regression\n", + "from sklearn.linear_model import LinearRegression\n", + "lr = LinearRegression().fit(x_trainf,y_train)\n", + "print('Training data set score: {:.3f}'.format(lr.score(x_trainf, y_train)))\n", + "print('Test data set score: {:.3f}'.format(lr.score(x_testf, y_test))) # overfit" + ] + }, + { + "cell_type": "code", + "execution_count": 87, + "id": "0f2cb297", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Training data set score: 0.34\n", + "Test data set score: 0.19\n" + ] + } + ], + "source": [ + "# Ridge Regression\n", + "from sklearn.linear_model import Ridge\n", + "ridge = Ridge().fit(x_trainf, y_train)\n", + "print('Training data set score: {:.2f}'.format(ridge.score(x_trainf, y_train)))\n", + "print('Test data set score: {:.2f}'.format(ridge.score(x_testf, y_test))) # overfit" + ] + }, + { + "cell_type": "code", + "execution_count": 88, + "id": "b3bee2c1", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Training data set score: 0.34\n", + "Test data set score: 0.19\n" + ] + } + ], + "source": [ + "# Lasso Regression\n", + "from sklearn.linear_model import Lasso\n", + "lasso = Lasso(alpha = 0.1, max_iter = 100000).fit(x_trainf, y_train)\n", + "print('Training data set score: {:.2f}'.format(lasso.score(x_trainf, y_train)))\n", + "print('Test 
data set score: {:.2f}'.format(lasso.score(x_testf, y_test))) # overfit" + ] + }, + { + "cell_type": "code", + "execution_count": 91, + "id": "1f7c891d", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Training data set score: 0.838\n", + "Test data set score: 0.691\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "D:\\ANA\\lib\\site-packages\\sklearn\\linear_model\\_logistic.py:814: ConvergenceWarning: lbfgs failed to converge (status=1):\n", + "STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.\n", + "\n", + "Increase the number of iterations (max_iter) or scale the data as shown in:\n", + " https://scikit-learn.org/stable/modules/preprocessing.html\n", + "Please also refer to the documentation for alternative solver options:\n", + " https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression\n", + " n_iter_i = _check_optimize_result(\n" + ] + } + ], + "source": [ + "# Logistic Regression\n", + "from sklearn.linear_model import LogisticRegression\n", + "logreg = LogisticRegression(max_iter=100).fit(x_trainf, y_train)\n", + "print('Training data set score: {:.3f}'.format(logreg.score(x_trainf, y_train)))\n", + "print('Test data set score: {:.3f}'.format(logreg.score(x_testf, y_test))) # overfit" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.12" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/utils/__init__.py b/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/utils/__pycache__/__init__.cpython-38.pyc b/utils/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000..3160c29 Binary files 
# NOTE: this span is the tail of a whitespace-collapsed git patch. The three new
# files it adds are laid out below as formatted Python; the patch headers are
# kept as comments so no information from the original lines is lost.
#
# (tail of the preceding binary-file entry)
# /dev/null and b/utils/__pycache__/__init__.cpython-38.pyc differ

# =============================================================================
# diff --git a/utils/evaluation.py b/utils/evaluation.py
# new file mode 100644, index 0000000..5ab52f7
# =============================================================================
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import recall_score, f1_score, classification_report, confusion_matrix, ConfusionMatrixDisplay


def avg_accuracy(gt, pred):
    """
    Return the fraction of predictions that match the ground truth.

    gt: ground truth
    pred: prediction; must support element-wise ``==`` against ``gt`` and
          expose ``.shape`` (its first dimension is used as the sample count)
    """
    correct_cnt = (pred == gt).sum()
    acc = correct_cnt * 1.0 / pred.shape[0]
    return acc


def class_metric(confusion_matrix, class_id):
    """
    Compute one-vs-rest metrics for a single class.

    confusion_matrix: confusion matrix of multi-class classification; the row
                      of ``class_id`` is read as ground truth (false negatives)
                      and its column as predictions (false positives)
    class_id: id of a particular class
    :return: tuple ``(accuracy, precision, recall, f_score)``.

    Denominators are not guarded: a class with no predicted or no true samples
    divides by zero in float64 arithmetic.
    """
    # NOTE: the parameter name shadows sklearn's ``confusion_matrix`` inside
    # this function; it is kept because it is part of the public signature.
    confusion_matrix = np.float64(confusion_matrix)
    TP = confusion_matrix[class_id, class_id]
    FN = np.sum(confusion_matrix[class_id]) - TP
    FP = np.sum(confusion_matrix[:, class_id]) - TP
    TN = np.sum(confusion_matrix) - TP - FN - FP

    accuracy = (TP + TN) / (TP + FP + FN + TN)
    precision = TP / (TP + FP)
    recall = TP / (TP + FN)
    f_score = 2 * precision * recall / (precision + recall)
    return accuracy, precision, recall, f_score


def _save_curve(values, label, file_name, description, logger, log_img_path):
    """Plot ``values`` as a line, save it under ``log_img_path`` and log the path."""
    # NOTE(review): the label is applied to the x axis although it names the
    # plotted quantity; kept as-is so the saved images do not change.
    plt.xlabel(label)
    plt.plot(values)
    # Plain string concatenation: ``log_img_path`` must already end with a
    # path separator (or be a file-name prefix).
    path = log_img_path + file_name
    plt.savefig(path)
    plt.clf()
    logger.info(description + ' is saved to ' + path)


# plot train loss
def visualize_train_loss(train_loss, logger, log_img_path):
    """Save the training-loss curve as ``train_loss.png`` and log where it went."""
    _save_curve(train_loss, 'Train Loss', 'train_loss.png',
                'Train Loss Visualization', logger, log_img_path)


# plot validation accuracy
def visualize_val_accuracy(val_acc, logger, log_img_path):
    """Save the validation-accuracy curve as ``val_acc.png`` and log where it went."""
    _save_curve(val_acc, 'Validation Accuracy', 'val_acc.png',
                'Validation Accuracy Visualization', logger, log_img_path)


# plot confusion matrix
def visualize_confusion_matrix(gt, pred, logger, log_img_path):
    """
    Save the confusion matrix of ``pred`` against ``gt`` as an image.

    :return: the confusion matrix computed by ``sklearn.metrics.confusion_matrix``.
    """
    cm = confusion_matrix(gt, pred)
    disp = ConfusionMatrixDisplay(cm).plot()
    path = log_img_path + 'confusion_matrix.png'
    plt.savefig(path)
    # ``plot()`` opens a fresh figure on every call; close it instead of only
    # clearing it so repeated evaluations do not accumulate open figures.
    plt.close(disp.figure_)
    logger.info('Confusion Matrix Visualization is saved to ' + path)
    return cm


# =============================================================================
# diff --git a/utils/init.py b/utils/init.py
# new file mode 100644, index 0000000..97c6e26
# =============================================================================
from nets import ResNet
import timm
from config import *


# Display name accepted by get_model -> timm architecture identifier.
_TIMM_ARCHITECTURES = {
    'ResNet-18': 'resnet18',
    'ResNet-34': 'resnet34',
    'ResMLP-12': 'resmlp_12_224',
    'ResMLP-24': 'resmlp_24_224',
}


# initialize model
def get_model(model):
    """
    Build a pretrained timm model for the given display name.

    model: one of 'ResNet-18', 'ResNet-34', 'ResMLP-12', 'ResMLP-24'
    :return: the timm model with ``config['N_CLASSES']`` outputs, or ``None``
             (after printing a notice) when the name is not recognised.
    """
    arch = _TIMM_ARCHITECTURES.get(model)
    if arch is None:
        print('No required model')
        return None
    # ``config`` is presumably provided by ``from config import *`` -- TODO confirm.
    return timm.create_model(arch, pretrained=True, num_classes=config['N_CLASSES'])


# =============================================================================
# diff --git a/utils/logger.py b/utils/logger.py
# new file mode 100644, index 0000000..2d5084d
# =============================================================================
import sys
import logging
import os


def get_logger(exp_dir):
    """
    creates logger instance. writing out info to file and to terminal.
    :param exp_dir: experiment directory, where the log.txt file is stored.
    :return: logger instance.
    """
    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(exp_dir, exist_ok=True)

    logger = logging.getLogger('ChestXray_detection')
    logger.setLevel(logging.DEBUG)
    log_file = os.path.join(exp_dir, 'log.txt')

    # logging.getLogger returns the same object for the same name, so calling
    # this function twice used to attach duplicate handlers and every message
    # was written several times. Only add what is not attached yet.
    target = os.path.abspath(log_file)
    already_logging_to_file = any(
        isinstance(h, logging.FileHandler) and h.baseFilename == target
        for h in logger.handlers
    )
    if not already_logging_to_file:
        logger.addHandler(logging.FileHandler(log_file))
    print('Logging to {}'.format(log_file))
    if not any(isinstance(h, ColorHandler) for h in logger.handlers):
        logger.addHandler(ColorHandler())
    logger.propagate = False
    return logger


class _AnsiColorizer:
    """
    A colorizer is an object that loosely wraps around a stream, allowing
    callers to write text to the stream in a particular color.

    Colorizer classes must implement C{supported()} and C{write(text, color)}.
    """
    # Color label -> ANSI SGR foreground code.
    _colors = dict(black=30, red=31, green=32, yellow=33,
                   blue=34, magenta=35, cyan=36, white=37, default=39)

    def __init__(self, stream):
        self.stream = stream

    @classmethod
    def supported(cls, stream=sys.stdout):
        """
        A class method that returns True if the current platform supports
        coloring terminal output using this method. Returns False otherwise.
        """
        if not stream.isatty():
            return False  # auto color only on TTYs
        try:
            import curses
        except ImportError:
            return False
        try:
            try:
                return curses.tigetnum("colors") > 2
            except curses.error:
                # terminfo not initialised yet: set it up and ask again
                curses.setupterm()
                return curses.tigetnum("colors") > 2
        except Exception:
            # guess false in case of error (the previous bare ``except: raise``
            # re-raised instead, leaving this fallback unreachable)
            return False

    def write(self, text, color):
        """
        Write the given text to the stream in the given color.

        @param text: Text to be written to the stream.

        @param color: A string label for a color. e.g. 'red', 'white'.
        """
        color = self._colors[color]
        self.stream.write('\x1b[%sm%s\x1b[0m' % (color, text))


class ColorHandler(logging.StreamHandler):
    """Stream handler that writes each record through an ``_AnsiColorizer``."""

    # Levels not listed here (e.g. CRITICAL) fall back to "blue" in emit().
    _LEVEL_COLORS = {
        logging.DEBUG: "green",
        logging.INFO: "default",
        logging.WARNING: "red",
        logging.ERROR: "red",
    }

    def __init__(self, stream=sys.stdout):
        super().__init__(_AnsiColorizer(stream))

    def emit(self, record):
        """Write the formatted record, colored by level, followed by a newline."""
        try:
            color = self._LEVEL_COLORS.get(record.levelno, "blue")
            # format() merges %-style args and copes with non-string messages;
            # the raw ``record.msg`` ignored args and raised TypeError for
            # anything that is not a str.
            self.stream.write(self.format(record) + "\n", color)
        except Exception:
            # standard logging contract: a failing handler must not crash the caller
            self.handleError(record)