diff --git a/.github/.keep b/.github/.keep
new file mode 100644
index 00000000..e69de29b
diff --git a/.github/workflows/classroom.yml b/.github/workflows/classroom.yml
new file mode 100644
index 00000000..e7879f7d
--- /dev/null
+++ b/.github/workflows/classroom.yml
@@ -0,0 +1,69 @@
+name: Autograding Tests
+'on':
+- workflow_dispatch
+- repository_dispatch
+permissions:
+  checks: write
+  actions: read
+  contents: read
+jobs:
+  run-autograding-tests:
+    runs-on: ubuntu-latest
+    if: github.actor != 'github-classroom[bot]'
+    steps:
+    - name: Checkout code
+      uses: actions/checkout@v4
+    - name: Setup
+      id: setup
+      uses: classroom-resources/autograding-command-grader@v1
+      with:
+        test-name: Setup
+        setup-command: sudo -H pip3 install -qr requirements.txt; sudo -H pip3 install
+          flake8==5.0.4
+        command: flake8 --ignore "N801, E203, E266, E501, W503, F812, E741, N803,
+          N802, N806" minitorch/ tests/ project/; mypy minitorch/*
+        timeout: 10
+        max-score: 10
+    - name: Task 1.1
+      id: task-1-1
+      uses: classroom-resources/autograding-command-grader@v1
+      with:
+        test-name: Task 1.1
+        setup-command: sudo -H pip3 install -qr requirements.txt
+        command: pytest -m task1_1
+        timeout: 10
+        max-score: 10
+    - name: Task 1.2
+      id: task-1-2
+      uses: classroom-resources/autograding-command-grader@v1
+      with:
+        test-name: Task 1.2
+        setup-command: sudo -H pip3 install -qr requirements.txt
+        command: pytest -m task1_2
+        timeout: 10
+    - name: Task 1.3
+      id: task-1-3
+      uses: classroom-resources/autograding-command-grader@v1
+      with:
+        test-name: Task 1.3
+        setup-command: sudo -H pip3 install -qr requirements.txt
+        command: pytest -m task1_3
+        timeout: 10
+    - name: Task 1.4
+      id: task-1-4
+      uses: classroom-resources/autograding-command-grader@v1
+      with:
+        test-name: Task 1.4
+        setup-command: sudo -H pip3 install -qr requirements.txt
+        command: pytest -m task1_4
+        timeout: 10
+    - name: Autograding Reporter
+      uses: classroom-resources/autograding-grading-reporter@v1
+      env:
+        SETUP_RESULTS: "${{steps.setup.outputs.result}}"
+        TASK-1-1_RESULTS: "${{steps.task-1-1.outputs.result}}"
+        TASK-1-2_RESULTS: "${{steps.task-1-2.outputs.result}}"
+        TASK-1-3_RESULTS: "${{steps.task-1-3.outputs.result}}"
+        TASK-1-4_RESULTS: "${{steps.task-1-4.outputs.result}}"
+      with:
+        runners: setup,task-1-1,task-1-2,task-1-3,task-1-4
diff --git a/.github/workflows/minitorch.yml b/.github/workflows/minitorch.yml
new file mode 100644
index 00000000..776051ae
--- /dev/null
+++ b/.github/workflows/minitorch.yml
@@ -0,0 +1,45 @@
+name: CI (Module 1)
+
+on:
+  push:
+  pull_request:
+
+jobs:
+  build:
+    runs-on: ubuntu-latest
+    strategy:
+      matrix:
+        python-version: ["3.10"]
+
+    steps:
+      - name: Checkout
+        uses: actions/checkout@v4
+
+      - name: Set up Python ${{ matrix.python-version }}
+        uses: actions/setup-python@v5
+        with:
+          python-version: ${{ matrix.python-version }}
+          cache: pip
+
+      - name: Install dependencies
+        run: |
+          sudo apt-get update
+          sudo apt-get install -y graphviz
+          python -m pip install --upgrade pip
+          pip install flake8 pytest pep8-naming
+          if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
+          # install the package itself as well, just in case
+          pip install -e .
+
+      - name: Lint with flake8
+        run: |
+          # stop the build if there are Python syntax errors or undefined names
+          flake8 --ignore "N801,E203,E266,E501,W503,W504,F812,F401,F841,E741,N803,N802,N806,E128,E302" minitorch/ tests/ project/
+
+      - name: Test with pytest (Module 1)
+        run: |
+          echo "Module 1"
+          pytest tests -q -x -m task1_1
+          pytest tests -q -x -m task1_2
+          pytest tests -q -x -m task1_3
+          pytest tests -q -x -m task1_4
diff --git a/README.md b/README.md
index 46933775..933b4849 100644
--- a/README.md
+++ b/README.md
@@ -1,3 +1,4 @@
+[![Open in Visual Studio Code](https://classroom.github.com/assets/open-in-vscode-2e0aaae1b6195c2367325f4f02e2d04e9abb55f0b24a779b69b11b9e10269abc.svg)](https://classroom.github.com/online_ide?assignment_repo_id=20763289&assignment_repo_type=AssignmentRepo)
 # MiniTorch Module 1
@@ -15,3 +16,664 @@ python sync_previous_module.py previous-module-dir current-module-dir
 The files that will be synced are:
         minitorch/operators.py minitorch/module.py tests/test_module.py tests/test_operators.py project/run_manual.py
+
+
+# Scalar Training
+
+Below are the results of training the network on the different datasets:
+
+## Simple
+**Config:** `PTS=50, HIDDEN=2, RATE=0.5`
+
+**Logs (excerpt):**
+
+Epoch 10 loss 32.99451624777131 correct 29
+Epoch 20 loss 29.67106517619935 correct 44
+Epoch 30 loss 20.505422569195055 correct 49
+Epoch 40 loss 12.040371792254232 correct 49
+Epoch 50 loss 8.05779782525025 correct 50
+Epoch 60 loss 5.776877152171056 correct 50
+Epoch 70 loss 4.32599156373445 correct 50
+Epoch 80 loss 3.391007016923641 correct 50
+Epoch 90 loss 2.7390332746561117 correct 50
+Epoch 100 loss 2.2638205155347118 correct 50
+Epoch 110 loss 1.9058296475262912 correct 50
+Epoch 130 loss 1.4126621296878168 correct 50
+Epoch 140 loss 1.2384903100021274 correct 50
+Epoch 150 loss 1.0965209310276882 correct 50
+Epoch 160 loss 0.9791433120667529 correct 50
+Epoch 170 loss 0.8810811523651478 correct 50
+Epoch 180 loss 0.798041390403876 correct 50
+Epoch 190 loss 0.7270314986517517 correct 50
+Epoch 200 loss 0.6657803100113434 correct 50
+Epoch 210 loss 0.6127190836144141 correct 50
+Epoch 220 loss 0.5666541785345667 correct 50
+Epoch 230 loss 0.5259841239283145 correct 50
+Epoch 240 loss 0.4899498393389119 correct 50
+Epoch 250 loss 0.4580529051584932 correct 50
+Epoch 260 loss 0.42947990336786396 correct 50
+Epoch 270 loss 0.40371825574841774 correct 50
+Epoch 280 loss 0.3803927096251919 correct 50
+Epoch 290 loss 0.3591923758308195 correct 50
+Epoch 300 loss 0.3398567923406658 correct 50
+Epoch 310 loss 0.32216577622298964 correct 50
+Epoch 320 loss 0.305931791725392 correct 50
+Epoch 330 loss 0.29099407053442744 correct 50
+Epoch 340 loss 0.2772139984633007 correct 50
+Epoch 350 loss 0.2644714439861456 correct 50
+Epoch 360 loss 0.25266180270088984 correct 50
+Epoch 370 loss 0.2416935953168232 correct 50
+Epoch 380 loss 0.23148649944459582 correct 50
+Epoch 390 loss 0.22196972517478103 correct 50
+Epoch 400 loss 0.2130806656978178 correct 50
+Epoch 410 loss 0.20476376978405275 correct 50
+Epoch 420 loss 0.19696959454012272 correct 50
+Epoch 430 loss 0.18965400562421794 correct 50
+Epoch 440 loss 0.18277749880928493 correct 50
+Epoch 450 loss 0.17630462196666688 correct 50
+Epoch 460 loss 0.17020348058519222 correct 50
+Epoch 470 loss 0.16444531311832322 correct 50
+Epoch 480 loss 0.15900412496832433 correct 50
+Epoch 490 loss 0.15385637192186644 correct 50
+Epoch 500 loss 0.14898068545998813 correct 50
+
+
+## Diag
+**Config:** `PTS=50, HIDDEN=3, RATE=0.5, EPOCHS=800`
+
+**Logs:**
+
+Epoch 10 loss 18.41767648121497 correct 44
+Epoch 20 loss 18.241374745482055 correct 44
+Epoch 30 loss 18.077694322397104 correct 44
+Epoch 40 loss 17.787496076171966 correct 44
+Epoch 50 loss 17.22622063447276 correct 44
+Epoch 60 loss 16.218299532974434 correct 44
+Epoch 70 loss 14.60746253340782 correct 44
+Epoch 80 loss 12.23724317791127 correct 44
+Epoch 90 loss 9.544850994180615 correct 46
+Epoch 100 loss 7.562065011671856 correct 46
+Epoch 110 loss 6.112602137050321 correct 48
+Epoch 120 loss 5.0048459933990905 correct 48
+Epoch 130 loss 4.162684014938122 correct 48
+Epoch 140 loss 3.52155744883408 correct 50
+Epoch 150 loss 3.0258024023642136 correct 50
+Epoch 160 loss 2.7067221081928436 correct 50
+Epoch 170 loss 2.982213947123389 correct 50
+Epoch 180 loss 3.571243230756708 correct 49
+Epoch 190 loss 3.085185630314938 correct 50
+Epoch 200 loss 2.6323532616598326 correct 50
+Epoch 210 loss 2.2357765540748717 correct 50
+Epoch 220 loss 1.7922072530790387 correct 50
+Epoch 230 loss 1.4067375619879097 correct 50
+Epoch 240 loss 1.2235048686696837 correct 50
+Epoch 250 loss 1.1016726862974993 correct 50
+Epoch 260 loss 1.0035961711498014 correct 50
+Epoch 270 loss 0.9205055811305474 correct 50
+Epoch 280 loss 0.8454999986534688 correct 50
+Epoch 290 loss 0.7841042944912567 correct 50
+Epoch 300 loss 0.729316067557551 correct 50
+Epoch 310 loss 0.6798828131913032 correct 50
+Epoch 320 loss 0.6351716025665827 correct 50
+Epoch 330 loss 0.5946362524099438 correct 50
+Epoch 340 loss 0.5578028898126859 correct 50
+Epoch 350 loss 0.5242590057316765 correct 50
+Epoch 360 loss 0.4936445178733028 correct 50
+Epoch 370 loss 0.46564435582316355 correct 50
+Epoch 380 loss 0.4399822374613626 correct 50
+Epoch 390 loss 0.41641540204658256 correct 50
+Epoch 400 loss 0.39473012676172964 correct 50
+Epoch 410 loss
0.37473789411177993 correct 50 +Epoch 420 loss 0.3562721055229054 correct 50 +Epoch 430 loss 0.3391852565769399 correct 50 +Epoch 440 loss 0.3233465043336502 correct 50 +Epoch 450 loss 0.3086395688231362 correct 50 +Epoch 460 loss 0.29496092005563224 correct 50 +Epoch 470 loss 0.28221820943708226 correct 50 +Epoch 480 loss 0.2703289107133867 correct 50 +Epoch 490 loss 0.2592191407757518 correct 50 +Epoch 500 loss 0.24882263504409974 correct 50 +Epoch 510 loss 0.23907985585290037 correct 50 +Epoch 520 loss 0.22993721540795956 correct 50 +Epoch 530 loss 0.22134639755421673 correct 50 +Epoch 540 loss 0.21326376486729373 correct 50 +Epoch 550 loss 0.20564983951648994 correct 50 +Epoch 560 loss 0.19846884799540454 correct 50 +Epoch 570 loss 0.19168832122143856 correct 50 +Epoch 580 loss 0.1852787427037586 correct 50 +Epoch 590 loss 0.1792132385019649 correct 50 +Epoch 600 loss 0.1734673035708013 correct 50 +Epoch 610 loss 0.16801855983232233 correct 50 +Epoch 620 loss 0.16284654195522322 correct 50 +Epoch 630 loss 0.15793250736690068 correct 50 +Epoch 640 loss 0.15325926749207958 correct 50 +Epoch 650 loss 0.14881103761299838 correct 50 +Epoch 660 loss 0.14457330309083513 correct 50 +Epoch 670 loss 0.14053269998435156 correct 50 +Epoch 680 loss 0.13667690835659957 correct 50 +Epoch 690 loss 0.13299455678041916 correct 50 +Epoch 700 loss 0.1294751367428776 correct 50 +Epoch 710 loss 0.12610892581277108 correct 50 +Epoch 720 loss 0.12288691857692009 correct 50 +Epoch 730 loss 0.11980076447385773 correct 50 +Epoch 740 loss 0.11684271175994544 correct 50 +Epoch 750 loss 0.11400555693557328 correct 50 +Epoch 760 loss 0.11128259903951022 correct 50 +Epoch 770 loss 0.10866759828968314 correct 50 +Epoch 780 loss 0.10615473860975419 correct 50 +Epoch 790 loss 0.10373859363438033 correct 50 +Epoch 800 loss 0.10141409583270015 correct 50 + + +## Split +**Config:** `PTS=50, HIDDEN=6, RATE=0.5, EPOCHS=800` + +**Logs:** + +Epoch 10 loss 33.75235919418314 correct 27 +Epoch 20 loss 32.85948268142515 correct 30 +Epoch 30 loss 31.5066623771144 correct 42 +Epoch 40 loss 29.65750160610111 correct 39 +Epoch 50 loss 27.80131987403149 correct 38 +Epoch 60 loss 26.167785037919668 correct 38 +Epoch 70 loss 24.66200397481523 correct 38 +Epoch 80 loss 22.884712684200203 correct 39 +Epoch 90 loss 20.657544351214334 correct 41 +Epoch 100 loss 19.112345931137142 correct 41 +Epoch 110 loss 19.459879663637146 correct 40 +Epoch 120 loss 17.83294384764803 correct 40 +Epoch 130 loss 16.679996723418938 correct 40 +Epoch 140 loss 15.719071510137669 correct 40 +Epoch 150 loss 15.703546633170133 correct 40 +Epoch 160 loss 14.804445623923408 correct 40 +Epoch 170 loss 13.066678291825822 correct 40 +Epoch 180 loss 11.403197834137822 correct 43 +Epoch 190 loss 9.605408576185999 correct 44 +Epoch 200 loss 7.285074421009864 correct 48 +Epoch 210 loss 4.304000386495314 correct 49 +Epoch 220 loss 3.4198006693554395 correct 50 +Epoch 230 loss 3.0136028308413305 correct 50 +Epoch 240 loss 2.711368668392037 correct 50 +Epoch 250 loss 2.467331442167649 correct 50 +Epoch 260 loss 2.263274347933186 correct 50 +Epoch 270 loss 2.0900745732594412 correct 50 +Epoch 280 loss 1.9396301277348078 correct 50 +Epoch 290 loss 1.8091389841229173 correct 50 +Epoch 300 loss 1.6932449867147714 correct 50 +Epoch 310 loss 1.5903559022792013 correct 50 +Epoch 320 loss 1.497552606799441 correct 50 +Epoch 330 loss 1.4144004570703967 correct 50 +Epoch 340 loss 1.3385405082291562 correct 50 +Epoch 350 loss 1.2697716062501032 correct 50 +Epoch 360 loss 
1.2069461436387507 correct 50 +Epoch 370 loss 1.1493443884652235 correct 50 +Epoch 380 loss 1.0967117952245755 correct 50 +Epoch 390 loss 1.0476488254983098 correct 50 +Epoch 400 loss 1.0020368204487586 correct 50 +Epoch 410 loss 0.9598708689621344 correct 50 +Epoch 420 loss 0.9206122736420168 correct 50 +Epoch 430 loss 0.8839076650282672 correct 50 +Epoch 440 loss 0.8495254396937801 correct 50 +Epoch 450 loss 0.8172635150111653 correct 50 +Epoch 460 loss 0.7869203350795496 correct 50 +Epoch 470 loss 0.7583766147677934 correct 50 +Epoch 480 loss 0.7314697635123658 correct 50 +Epoch 490 loss 0.7061127014776227 correct 50 +Epoch 500 loss 0.6821342136253884 correct 50 +Epoch 510 loss 0.6594411318814467 correct 50 +Epoch 520 loss 0.6379700071761089 correct 50 +Epoch 530 loss 0.6175500667433639 correct 50 +Epoch 540 loss 0.5981613319799653 correct 50 +Epoch 550 loss 0.5797782059585791 correct 50 +Epoch 560 loss 0.5622985183204681 correct 50 +Epoch 570 loss 0.5456702017959787 correct 50 +Epoch 580 loss 0.5297827226238305 correct 50 +Epoch 590 loss 0.5146217885303698 correct 50 +Epoch 600 loss 0.5001534644463455 correct 50 +Epoch 610 loss 0.4863190738562789 correct 50 +Epoch 620 loss 0.4730769346467114 correct 50 +Epoch 630 loss 0.46042134263569323 correct 50 +Epoch 640 loss 0.44827767350160963 correct 50 +Epoch 650 loss 0.4366712480135445 correct 50 +Epoch 660 loss 0.42552859095855566 correct 50 +Epoch 670 loss 0.4148430557641768 correct 50 +Epoch 680 loss 0.404564552057625 correct 50 +Epoch 690 loss 0.3947008701407381 correct 50 +Epoch 700 loss 0.3852010314076847 correct 50 +Epoch 710 loss 0.37607071986325724 correct 50 +Epoch 720 loss 0.36728463323116706 correct 50 +Epoch 730 loss 0.3588245684683126 correct 50 +Epoch 740 loss 0.3506886629139211 correct 50 +Epoch 750 loss 0.34284684715298863 correct 50 +Epoch 760 loss 0.33527444467360507 correct 50 +Epoch 770 loss 0.3279736730074926 correct 50 +Epoch 780 loss 0.32092303189405386 correct 50 +Epoch 790 loss 0.31410711315851425 correct 50 +Epoch 800 loss 0.30752631825797144 correct 50 + + +## XOR +**Config:** `PTS=50, HIDDEN=10, RATE=0.5, EPOCHS=1000` + +**Logs:** + +Epoch 10 loss 31.710840814280708 correct 37 +Epoch 20 loss 27.619434727361654 correct 44 +Epoch 30 loss 35.15578215814027 correct 27 +Epoch 40 loss 19.65814036712436 correct 43 +Epoch 50 loss 19.86269657424437 correct 41 +Epoch 60 loss 15.914213912439719 correct 44 +Epoch 70 loss 14.545861680601343 correct 45 +Epoch 80 loss 14.912667936081437 correct 45 +Epoch 90 loss 15.580266441239873 correct 44 +Epoch 100 loss 13.417874983848856 correct 45 +Epoch 110 loss 12.252437753240343 correct 45 +Epoch 120 loss 11.6841781074567 correct 45 +Epoch 130 loss 10.916858638621648 correct 46 +Epoch 140 loss 10.28385039442434 correct 47 +Epoch 150 loss 9.707644708218638 correct 47 +Epoch 160 loss 9.235796909717186 correct 48 +Epoch 170 loss 8.772067053248017 correct 48 +Epoch 180 loss 8.139314506793943 correct 48 +Epoch 190 loss 7.460210138703441 correct 48 +Epoch 200 loss 6.927219607486519 correct 48 +Epoch 210 loss 6.376593067036393 correct 48 +Epoch 220 loss 5.913140927983057 correct 49 +Epoch 230 loss 5.483373888652583 correct 49 +Epoch 240 loss 5.061634295028014 correct 49 +Epoch 250 loss 4.682677816397556 correct 49 +Epoch 260 loss 4.335117380899207 correct 49 +Epoch 270 loss 3.99520944553466 correct 49 +Epoch 280 loss 3.683357234181038 correct 49 +Epoch 290 loss 3.416203378252851 correct 49 +Epoch 300 loss 3.1486328702689423 correct 49 +Epoch 310 loss 2.9083076827214844 correct 49 +Epoch 320 
loss 2.6780420929150224 correct 49 +Epoch 330 loss 2.487639668979288 correct 49 +Epoch 340 loss 2.3241246076104614 correct 49 +Epoch 350 loss 2.1616183847637296 correct 50 +Epoch 360 loss 2.0043098043445395 correct 50 +Epoch 370 loss 1.880820235858625 correct 50 +Epoch 380 loss 1.7791692991943602 correct 50 +Epoch 390 loss 1.6651766645908703 correct 50 +Epoch 400 loss 1.557542048850534 correct 50 +Epoch 410 loss 1.4744329886289733 correct 50 +Epoch 420 loss 1.385967388946977 correct 50 +Epoch 430 loss 1.3124665776994984 correct 50 +Epoch 440 loss 1.2241953077783376 correct 50 +Epoch 450 loss 1.163476776544279 correct 50 +Epoch 460 loss 1.1045837738465025 correct 50 +Epoch 470 loss 1.0523798442926309 correct 50 +Epoch 480 loss 1.0028078975391308 correct 50 +Epoch 490 loss 0.9589416345016335 correct 50 +Epoch 500 loss 0.9115663638093799 correct 50 +Epoch 510 loss 0.8741984577090567 correct 50 +Epoch 520 loss 0.8307487866205087 correct 50 +Epoch 530 loss 0.8015547549050164 correct 50 +Epoch 540 loss 0.7668909377698404 correct 50 +Epoch 550 loss 0.7407965319031712 correct 50 +Epoch 560 loss 0.7137931412422537 correct 50 +Epoch 570 loss 0.6839435648496994 correct 50 +Epoch 580 loss 0.6620958530429077 correct 50 +Epoch 590 loss 0.6424177379021753 correct 50 +Epoch 600 loss 0.6222351333348232 correct 50 +Epoch 610 loss 0.5984583263793786 correct 50 +Epoch 620 loss 0.5755716347010625 correct 50 +Epoch 630 loss 0.5694832067543669 correct 50 +Epoch 640 loss 0.5463241229098806 correct 50 +Epoch 650 loss 0.5332457454154678 correct 50 +Epoch 660 loss 0.5172149217294235 correct 50 +Epoch 670 loss 0.4966068054997729 correct 50 +Epoch 680 loss 0.48983871229031295 correct 50 +Epoch 690 loss 0.468972200777932 correct 50 +Epoch 700 loss 0.45349750377209086 correct 50 +Epoch 710 loss 0.44426382343021786 correct 50 +Epoch 720 loss 0.4352274800525545 correct 50 +Epoch 730 loss 0.4266537576558136 correct 50 +Epoch 740 loss 0.4178681085078735 correct 50 +Epoch 750 loss 0.4010334948330733 correct 50 +Epoch 760 loss 0.3944971270018335 correct 50 +Epoch 770 loss 0.38194219976990984 correct 50 +Epoch 780 loss 0.37775906698736694 correct 50 +Epoch 790 loss 0.3663828777720555 correct 50 +Epoch 800 loss 0.3606369644047365 correct 50 +Epoch 810 loss 0.345528962682325 correct 50 +Epoch 820 loss 0.3403264929221114 correct 50 +Epoch 830 loss 0.3379484035700957 correct 50 +Epoch 840 loss 0.3235983903768161 correct 50 +Epoch 850 loss 0.320870540554103 correct 50 +Epoch 860 loss 0.3172260906216512 correct 50 +Epoch 870 loss 0.3052846719025846 correct 50 +Epoch 880 loss 0.30153148360617515 correct 50 +Epoch 890 loss 0.296975810407801 correct 50 +Epoch 900 loss 0.29088183227443853 correct 50 +Epoch 910 loss 0.28702131065022113 correct 50 +Epoch 920 loss 0.28158399140948914 correct 50 +Epoch 930 loss 0.2799092483854836 correct 50 +Epoch 940 loss 0.2767137545306454 correct 50 +Epoch 950 loss 0.268825390007598 correct 50 +Epoch 960 loss 0.25776259786465877 correct 50 +Epoch 970 loss 0.26017808392686775 correct 50 +Epoch 980 loss 0.25749061311077026 correct 50 +Epoch 990 loss 0.25034927599141327 correct 50 +Epoch 1000 loss 0.24044277636451425 correct 50 + + +## Circle +**Config:** `PTS=50, HIDDEN=12, RATE=0.5, EPOCHS=1200` + +**Logs:** + +Epoch 10 loss 29.873557603025002 correct 34 +Epoch 20 loss 28.637278694749938 correct 34 +Epoch 30 loss 26.693378698986294 correct 34 +Epoch 40 loss 22.641189507191495 correct 38 +Epoch 50 loss 19.450614781242184 correct 42 +Epoch 60 loss 22.21288506941197 correct 36 +Epoch 70 loss 
20.02815629412635 correct 39 +Epoch 80 loss 18.485933278115393 correct 41 +Epoch 90 loss 18.938354344463555 correct 40 +Epoch 100 loss 15.801384177517944 correct 42 +Epoch 110 loss 20.54504173116277 correct 38 +Epoch 120 loss 12.0860808864857 correct 47 +Epoch 130 loss 9.89058856137845 correct 46 +Epoch 140 loss 9.747368374017055 correct 47 +Epoch 150 loss 17.02054927862798 correct 41 +Epoch 160 loss 8.417497986638514 correct 48 +Epoch 170 loss 7.135484223361151 correct 49 +Epoch 180 loss 6.527563031251834 correct 49 +Epoch 190 loss 6.861019469954264 correct 47 +Epoch 200 loss 10.417714715506913 correct 45 +Epoch 210 loss 7.005945422699021 correct 46 +Epoch 220 loss 6.260572947445757 correct 46 +Epoch 230 loss 7.410931169681428 correct 45 +Epoch 240 loss 5.539103577833562 correct 46 +Epoch 250 loss 6.62250431049139 correct 46 +Epoch 260 loss 5.598421674830261 correct 46 +Epoch 270 loss 7.864933680599687 correct 46 +Epoch 280 loss 4.997206273167954 correct 47 +Epoch 290 loss 7.574590575558654 correct 45 +Epoch 300 loss 6.562601431616851 correct 45 +Epoch 310 loss 4.87326656356385 correct 46 +Epoch 320 loss 2.6038494048825527 correct 50 +Epoch 330 loss 2.315534036694687 correct 50 +Epoch 340 loss 8.82038939619125 correct 46 +Epoch 350 loss 2.9707709095829307 correct 50 +Epoch 360 loss 2.0873986412435075 correct 50 +Epoch 370 loss 1.8453401831690723 correct 50 +Epoch 380 loss 1.7092717007677 correct 50 +Epoch 390 loss 1.6084180871732447 correct 50 +Epoch 400 loss 1.5280966883231553 correct 50 +Epoch 410 loss 1.3874388714658201 correct 50 +Epoch 420 loss 1.3605726791706867 correct 50 +Epoch 430 loss 1.2963766574678814 correct 50 +Epoch 440 loss 1.2152162535923534 correct 50 +Epoch 450 loss 1.1423989193296675 correct 50 +Epoch 460 loss 1.0568357842479614 correct 50 +Epoch 470 loss 1.019782761902968 correct 50 +Epoch 480 loss 0.9570502900552107 correct 50 +Epoch 490 loss 0.9173540575276944 correct 50 +Epoch 500 loss 0.8703217240424889 correct 50 +Epoch 510 loss 0.836076807034688 correct 50 +Epoch 520 loss 0.7989480024734267 correct 50 +Epoch 530 loss 0.7616411283425616 correct 50 +Epoch 540 loss 0.7322332364904421 correct 50 +Epoch 550 loss 0.7026049137321023 correct 50 +Epoch 560 loss 0.6734059435788355 correct 50 +Epoch 570 loss 0.6514508905857581 correct 50 +Epoch 580 loss 0.6253657491441774 correct 50 +Epoch 590 loss 0.6016419327474636 correct 50 +Epoch 600 loss 0.5808687869112696 correct 50 +Epoch 610 loss 0.5612780091994775 correct 50 +Epoch 620 loss 0.5418942382173502 correct 50 +Epoch 630 loss 0.5236758362305461 correct 50 +Epoch 640 loss 0.5075390663097562 correct 50 +Epoch 650 loss 0.4911208096803893 correct 50 +Epoch 660 loss 0.476981659802743 correct 50 +Epoch 670 loss 0.46137354360866134 correct 50 +Epoch 680 loss 0.44798499696600164 correct 50 +Epoch 690 loss 0.43528270400602687 correct 50 +Epoch 700 loss 0.4222083739843732 correct 50 +Epoch 710 loss 0.4104886183463712 correct 50 +Epoch 720 loss 0.39895653616710464 correct 50 +Epoch 730 loss 0.38878342428216206 correct 50 +Epoch 740 loss 0.37886167579833563 correct 50 +Epoch 750 loss 0.368570697147489 correct 50 +Epoch 760 loss 0.3589496718668607 correct 50 +Epoch 770 loss 0.3499901651590006 correct 50 +Epoch 780 loss 0.3413641537631822 correct 50 +Epoch 790 loss 0.3332250243507555 correct 50 +Epoch 800 loss 0.3251315015013947 correct 50 +Epoch 810 loss 0.3178095612824394 correct 50 +Epoch 820 loss 0.31041912002011607 correct 50 +Epoch 830 loss 0.30319841873983694 correct 50 +Epoch 840 loss 0.2964133607017596 correct 50 +Epoch 850 
loss 0.29008566506978545 correct 50 +Epoch 860 loss 0.2836831216252566 correct 50 +Epoch 870 loss 0.27764790133186296 correct 50 +Epoch 880 loss 0.2718419682640759 correct 50 +Epoch 890 loss 0.26620205028629884 correct 50 +Epoch 900 loss 0.26070236509732264 correct 50 +Epoch 910 loss 0.2555126126082473 correct 50 +Epoch 920 loss 0.25049285536606997 correct 50 +Epoch 930 loss 0.2454331491339338 correct 50 +Epoch 940 loss 0.2407241228091517 correct 50 +Epoch 950 loss 0.23600127713997474 correct 50 +Epoch 960 loss 0.2315757283112603 correct 50 +Epoch 970 loss 0.22730738692423827 correct 50 +Epoch 980 loss 0.22314738218284338 correct 50 +Epoch 990 loss 0.21922294110060087 correct 50 +Epoch 1000 loss 0.21508616884109666 correct 50 +Epoch 1010 loss 0.21131815650591748 correct 50 +Epoch 1020 loss 0.20752792859416885 correct 50 +Epoch 1030 loss 0.20404733796003943 correct 50 +Epoch 1040 loss 0.20054067812529294 correct 50 +Epoch 1050 loss 0.19711702710764678 correct 50 +Epoch 1060 loss 0.19374560749859682 correct 50 +Epoch 1070 loss 0.19050864188466443 correct 50 +Epoch 1080 loss 0.18745335981221015 correct 50 +Epoch 1090 loss 0.18441274555208223 correct 50 +Epoch 1100 loss 0.18148126665642772 correct 50 +Epoch 1110 loss 0.17869402481144853 correct 50 +Epoch 1120 loss 0.1758290920335748 correct 50 +Epoch 1130 loss 0.17310900893611622 correct 50 +Epoch 1140 loss 0.17041855936055722 correct 50 +Epoch 1150 loss 0.16786323307966283 correct 50 +Epoch 1160 loss 0.16529510067226796 correct 50 +Epoch 1170 loss 0.16290304693243238 correct 50 +Epoch 1180 loss 0.16048035852185277 correct 50 +Epoch 1190 loss 0.15815382883096868 correct 50 +Epoch 1200 loss 0.15586020629583122 correct 50 + + +## Spiral +**Config:** `PTS=100, HIDDEN=20, RATE=0.5, EPOCHS=1500` + +**Logs:** + +Epoch 10 loss 68.71511608037055 correct 61 +Epoch 20 loss 68.30588930020942 correct 65 +Epoch 30 loss 68.00332527042707 correct 61 +Epoch 40 loss 67.70521657149985 correct 58 +Epoch 50 loss 67.47138855991601 correct 58 +Epoch 60 loss 67.26976357327436 correct 59 +Epoch 70 loss 67.12410826307537 correct 58 +Epoch 80 loss 67.00707320601778 correct 58 +Epoch 90 loss 66.91497041301285 correct 58 +Epoch 100 loss 66.83572070894556 correct 59 +Epoch 110 loss 66.76864379993027 correct 59 +Epoch 120 loss 66.7022574516325 correct 59 +Epoch 130 loss 66.62232360866038 correct 59 +Epoch 140 loss 66.54962864111374 correct 59 +Epoch 150 loss 66.48262495947542 correct 58 +Epoch 160 loss 66.41083748439603 correct 59 +Epoch 170 loss 66.3664138202179 correct 59 +Epoch 180 loss 66.31367467459017 correct 59 +Epoch 190 loss 66.2907758410631 correct 63 +Epoch 200 loss 66.98089399261403 correct 52 +Epoch 210 loss 67.06177733008565 correct 51 +Epoch 220 loss 66.33192033937969 correct 60 +Epoch 230 loss 66.15057941730524 correct 61 +Epoch 240 loss 66.10070291524839 correct 61 +Epoch 250 loss 66.01192857278852 correct 61 +Epoch 260 loss 65.97473727265219 correct 61 +Epoch 270 loss 65.92903057040971 correct 60 +Epoch 280 loss 65.89971517954437 correct 59 +Epoch 290 loss 65.89941068415979 correct 57 +Epoch 300 loss 65.87273454966903 correct 56 +Epoch 310 loss 65.79172252179588 correct 56 +Epoch 320 loss 65.61825667056381 correct 59 +Epoch 330 loss 65.49223403247998 correct 61 +Epoch 340 loss 65.45810445493414 correct 59 +Epoch 350 loss 65.4260714391121 correct 58 +Epoch 360 loss 65.42439498679903 correct 57 +Epoch 370 loss 65.12558840760677 correct 59 +Epoch 380 loss 65.12263175748602 correct 59 +Epoch 390 loss 65.14940130288834 correct 58 +Epoch 400 loss 
65.00537990021363 correct 57 +Epoch 410 loss 64.89748060567388 correct 57 +Epoch 420 loss 64.7885364700252 correct 57 +Epoch 430 loss 64.84280871319638 correct 54 +Epoch 440 loss 64.81893998341585 correct 53 +Epoch 450 loss 64.61214677100281 correct 54 +Epoch 460 loss 64.39150353735262 correct 55 +Epoch 470 loss 64.23123804962226 correct 59 +Epoch 480 loss 64.29466188549844 correct 56 +Epoch 490 loss 64.47516768150778 correct 51 +Epoch 500 loss 64.48738324125286 correct 53 +Epoch 510 loss 64.32293607711219 correct 51 +Epoch 520 loss 64.00080281055355 correct 56 +Epoch 530 loss 63.961607833575904 correct 55 +Epoch 540 loss 63.93100987939193 correct 52 +Epoch 550 loss 63.82236663127581 correct 55 +Epoch 560 loss 63.74195256644413 correct 52 +Epoch 570 loss 63.64833957025323 correct 52 +Epoch 580 loss 63.48868699116711 correct 52 +Epoch 590 loss 63.3862750280804 correct 51 +Epoch 600 loss 63.28512438287466 correct 51 +Epoch 610 loss 63.22309156486676 correct 51 +Epoch 620 loss 63.113414011986976 correct 55 +Epoch 630 loss 62.97173427564117 correct 53 +Epoch 640 loss 62.986479652566075 correct 51 +Epoch 650 loss 63.09358328696164 correct 52 +Epoch 660 loss 63.12407676671131 correct 53 +Epoch 670 loss 62.948836110456575 correct 53 +Epoch 680 loss 62.83578457174591 correct 53 +Epoch 690 loss 62.69799755893202 correct 52 +Epoch 700 loss 62.74274934766089 correct 55 +Epoch 710 loss 62.682703581533374 correct 54 +Epoch 720 loss 62.58973462419676 correct 54 +Epoch 730 loss 62.29942212185198 correct 54 +Epoch 740 loss 62.256309115529454 correct 55 +Epoch 750 loss 62.67994497845996 correct 57 +Epoch 760 loss 62.35083030795144 correct 55 +Epoch 770 loss 62.25955261906933 correct 56 +Epoch 780 loss 62.15185472664485 correct 58 +Epoch 790 loss 61.83685263176922 correct 55 +Epoch 800 loss 61.66471757719838 correct 56 +Epoch 810 loss 62.56581629317485 correct 62 +Epoch 820 loss 61.89923803491597 correct 59 +Epoch 830 loss 61.53355410830513 correct 58 +Epoch 840 loss 61.80143358072765 correct 62 +Epoch 850 loss 61.82515412052376 correct 58 +Epoch 860 loss 61.68465874367569 correct 60 +Epoch 870 loss 61.281307286153734 correct 59 +Epoch 880 loss 60.74353654336668 correct 56 +Epoch 890 loss 61.276936859656935 correct 61 +Epoch 900 loss 61.23453721339674 correct 60 +Epoch 910 loss 60.68750857246043 correct 60 +Epoch 920 loss 60.52102737524052 correct 60 +Epoch 930 loss 60.92190154409212 correct 61 +Epoch 940 loss 60.80069844809939 correct 64 +Epoch 950 loss 60.551295478144404 correct 61 +Epoch 960 loss 61.35047329759586 correct 62 +Epoch 970 loss 61.201626802506546 correct 61 +Epoch 980 loss 60.276827760684796 correct 62 +Epoch 990 loss 59.539994015615534 correct 64 +Epoch 1000 loss 60.80067748852008 correct 61 +Epoch 1010 loss 61.67082447175194 correct 61 +Epoch 1020 loss 60.43011429772951 correct 62 +Epoch 1030 loss 60.24125283261178 correct 61 +Epoch 1040 loss 60.216744151713804 correct 64 +Epoch 1050 loss 59.89960620304444 correct 63 +Epoch 1060 loss 60.3740579607568 correct 63 +Epoch 1070 loss 60.07679982710797 correct 64 +Epoch 1080 loss 60.00566934436829 correct 65 +Epoch 1090 loss 59.70365346169148 correct 62 +Epoch 1100 loss 59.91358188159955 correct 64 +Epoch 1110 loss 59.473717162404114 correct 65 +Epoch 1120 loss 59.78432820756897 correct 62 +Epoch 1130 loss 59.47552114117245 correct 68 +Epoch 1140 loss 59.10096301256793 correct 63 +Epoch 1150 loss 59.630560045185945 correct 65 +Epoch 1160 loss 59.28043412404348 correct 67 +Epoch 1170 loss 59.073255536559856 correct 62 +Epoch 1180 loss 
59.16488798942276 correct 62 +Epoch 1190 loss 58.8962009236161 correct 66 +Epoch 1200 loss 58.56811011623168 correct 66 +Epoch 1210 loss 58.86981297452959 correct 66 +Epoch 1220 loss 58.93937821166227 correct 67 +Epoch 1230 loss 58.536200435628366 correct 66 +Epoch 1240 loss 58.91699684381775 correct 68 +Epoch 1250 loss 58.85785353833471 correct 66 +Epoch 1260 loss 58.543703602179086 correct 66 +Epoch 1270 loss 58.030444406997006 correct 66 +Epoch 1290 loss 57.81490957159442 correct 67 +Epoch 1300 loss 58.00140726690014 correct 66 +Epoch 1310 loss 57.934715432543896 correct 66 +Epoch 1320 loss 57.99325114718681 correct 65 +Epoch 1330 loss 58.607947056997354 correct 67 +Epoch 1340 loss 56.50163935605804 correct 68 +Epoch 1350 loss 57.565455468309366 correct 65 +Epoch 1360 loss 58.327873310157486 correct 63 +Epoch 1370 loss 56.52132160696343 correct 70 +Epoch 1380 loss 55.87205126197996 correct 68 +Epoch 1390 loss 57.31080309061276 correct 68 +Epoch 1400 loss 57.35315618763012 correct 67 +Epoch 1410 loss 59.84551149331509 correct 65 +Epoch 1420 loss 56.448622670255965 correct 63 +Epoch 1430 loss 59.42861474237802 correct 58 +Epoch 1440 loss 55.098119209084246 correct 70 +Epoch 1450 loss 53.33741456176819 correct 70 +Epoch 1460 loss 57.80002625049885 correct 66 +Epoch 1470 loss 61.88064947510794 correct 59 +Epoch 1480 loss 55.72888569952491 correct 66 +Epoch 1490 loss 58.11701591073116 correct 59 +Epoch 1500 loss 56.503936394488356 correct 63 \ No newline at end of file diff --git a/logs/circle.txt b/logs/circle.txt new file mode 100644 index 00000000..0fb91ea0 Binary files /dev/null and b/logs/circle.txt differ diff --git a/logs/diag.txt b/logs/diag.txt new file mode 100644 index 00000000..7d983a43 Binary files /dev/null and b/logs/diag.txt differ diff --git a/logs/simple.txt b/logs/simple.txt new file mode 100644 index 00000000..ee4031ce Binary files /dev/null and b/logs/simple.txt differ diff --git a/logs/spiral.txt b/logs/spiral.txt new file mode 100644 index 00000000..fc015803 Binary files /dev/null and b/logs/spiral.txt differ diff --git a/logs/split.txt b/logs/split.txt new file mode 100644 index 00000000..87a4a538 Binary files /dev/null and b/logs/split.txt differ diff --git a/logs/xor.txt b/logs/xor.txt new file mode 100644 index 00000000..6491adbe Binary files /dev/null and b/logs/xor.txt differ diff --git a/minitorch/autodiff.py b/minitorch/autodiff.py index 2b69873b..7f90bf24 100644 --- a/minitorch/autodiff.py +++ b/minitorch/autodiff.py @@ -22,8 +22,16 @@ def central_difference(f: Any, *vals: Any, arg: int = 0, epsilon: float = 1e-6) Returns: An approximation of $f'_i(x_0, \ldots, x_{n-1})$ """ - # TODO: Implement for Task 1.1. - raise NotImplementedError("Need to implement for Task 1.1") + vals_pos = list(vals) + vals_neg = list(vals) + + vals_pos[arg] = vals[arg] + epsilon + vals_neg[arg] = vals[arg] - epsilon + + f_pos = f(*vals_pos) + f_neg = f(*vals_neg) + + return (f_pos - f_neg) / (2.0 * epsilon) variable_count = 1 @@ -61,8 +69,22 @@ def topological_sort(variable: Variable) -> Iterable[Variable]: Returns: Non-constant Variables in topological order starting from the right. """ - # TODO: Implement for Task 1.4. 
- raise NotImplementedError("Need to implement for Task 1.4") + visited = set() + post: List[Variable] = [] + + def dfs(v: Variable) -> None: + uid = v.unique_id + if uid in visited: + return + visited.add(uid) + if v.is_constant(): + return + for p in v.parents: + dfs(p) + post.append(v) + + dfs(variable) + return list(reversed(post)) def backpropagate(variable: Variable, deriv: Any) -> None: @@ -76,8 +98,20 @@ def backpropagate(variable: Variable, deriv: Any) -> None: No return. Should write to its results to the derivative values of each leaf through `accumulate_derivative`. """ - # TODO: Implement for Task 1.4. - raise NotImplementedError("Need to implement for Task 1.4") + topo = list(topological_sort(variable)) + + grads: dict[int, float] = {} + grads[variable.unique_id] = float(deriv) + + for v in topo: + g_out = float(grads.get(v.unique_id, 0.0)) + if v.is_leaf(): + v.accumulate_derivative(g_out) + continue + + for parent, g_local in v.chain_rule(g_out): + pid = parent.unique_id + grads[pid] = grads.get(pid, 0.0) + float(g_local) @dataclass diff --git a/minitorch/module.py b/minitorch/module.py index 11fc1f39..dcdcffab 100644 --- a/minitorch/module.py +++ b/minitorch/module.py @@ -1,14 +1,14 @@ from __future__ import annotations -from typing import Any, Dict, Optional, Sequence, Tuple +from typing import Any, Dict, Optional, Sequence, Tuple, List class Module: - """ - Modules form a tree that store parameters and other + """Modules form a tree that store parameters and other submodules. They make up the basis of neural network stacks. - Attributes: + Attributes + ---------- _modules : Storage of the child modules _parameters : Storage of the module's parameters training : Whether the module is in training mode or evaluation mode @@ -25,42 +25,57 @@ def __init__(self) -> None: self.training = True def modules(self) -> Sequence[Module]: - "Return the direct child modules of this module." + """Return the direct child modules of this module.""" m: Dict[str, Module] = self.__dict__["_modules"] return list(m.values()) def train(self) -> None: - "Set the mode of this module and all descendent modules to `train`." - raise NotImplementedError("Need to include this file from past assignment.") + """Set the mode of this module and all descendent modules to `train`.""" + self.training = True + for child in self._modules.values(): + child.train() def eval(self) -> None: - "Set the mode of this module and all descendent modules to `eval`." - raise NotImplementedError("Need to include this file from past assignment.") + """Set the mode of this module and all descendent modules to `eval`.""" + self.training = False + for child in self._modules.values(): + child.eval() def named_parameters(self) -> Sequence[Tuple[str, Parameter]]: - """ - Collect all the parameters of this module and its descendents. + """Collect all the parameters of this module and its descendents. - - Returns: + Returns + ------- The name and `Parameter` of each ancestor parameter. + """ - raise NotImplementedError("Need to include this file from past assignment.") + out: List[Tuple[str, Parameter]] = [] + + for name, p in self._parameters.items(): + out.append((name, p)) + + for child_name, child_mod in self._modules.items(): + for sub_name, p in child_mod.named_parameters(): + out.append((f"{child_name}.{sub_name}", p)) + + return out def parameters(self) -> Sequence[Parameter]: - "Enumerate over all the parameters of this module and its descendents." 
- raise NotImplementedError("Need to include this file from past assignment.") + """Enumerate over all the parameters of this module and its descendents.""" + return [p for _, p in self.named_parameters()] def add_parameter(self, k: str, v: Any) -> Parameter: - """ - Manually add a parameter. Useful helper for scalar parameters. + """Manually add a parameter. Useful helper for scalar parameters. Args: + ---- k: Local name of the parameter. v: Value for the parameter. Returns: + ------- Newly created parameter. + """ val = Parameter(v, k) self.__dict__["_parameters"][k] = val @@ -114,8 +129,7 @@ def _addindent(s_: str, numSpaces: int) -> str: class Parameter: - """ - A Parameter is a special container stored in a `Module`. + """A Parameter is a special container stored in a `Module`. It is designed to hold a `Variable`, but we allow it to hold any value for testing. @@ -130,7 +144,7 @@ def __init__(self, x: Any, name: Optional[str] = None) -> None: self.value.name = self.name def update(self, x: Any) -> None: - "Update the parameter value." + """Update the parameter value.""" self.value = x if hasattr(x, "requires_grad_"): self.value.requires_grad_(True) diff --git a/minitorch/operators.py b/minitorch/operators.py index 895ae82d..efba54d9 100644 --- a/minitorch/operators.py +++ b/minitorch/operators.py @@ -1,185 +1,164 @@ -""" -Collection of the core mathematical operators used throughout the code base. -""" +"""Collection of the core mathematical operators used throughout the code base.""" import math -from typing import Callable, Iterable # ## Task 0.1 +from typing import Callable, Iterable, Iterator, List, TypeVar, Sequence + # # Implementation of a prelude of elementary functions. +# Mathematical functions: +# - mul +# - id +# - add +# - neg +# - lt +# - eq +# - max +# - is_close +# - sigmoid +# - relu +# - log +# - exp +# - log_back +# - inv +# - inv_back +# - relu_back +# +# For sigmoid calculate as: +# $f(x) = \frac{1.0}{(1.0 + e^{-x})}$ if x >=0 else $\frac{e^x}{(1.0 + e^{x})}$ +# For is_close: +# $f(x) = |x - y| < 1e-2$ + def mul(x: float, y: float) -> float: - "$f(x, y) = x * y$" - raise NotImplementedError("Need to include this file from past assignment.") + return x * y def id(x: float) -> float: - "$f(x) = x$" - raise NotImplementedError("Need to include this file from past assignment.") + return x def add(x: float, y: float) -> float: - "$f(x, y) = x + y$" - raise NotImplementedError("Need to include this file from past assignment.") + return x + y def neg(x: float) -> float: - "$f(x) = -x$" - raise NotImplementedError("Need to include this file from past assignment.") + return -x def lt(x: float, y: float) -> float: - "$f(x) =$ 1.0 if x is less than y else 0.0" - raise NotImplementedError("Need to include this file from past assignment.") + return 1.0 if x < y else 0.0 def eq(x: float, y: float) -> float: - "$f(x) =$ 1.0 if x is equal to y else 0.0" - raise NotImplementedError("Need to include this file from past assignment.") + return 1.0 if x == y else 0.0 def max(x: float, y: float) -> float: - "$f(x) =$ x if x is greater than y else y" - raise NotImplementedError("Need to include this file from past assignment.") + return x if x > y else y -def is_close(x: float, y: float) -> float: - "$f(x) = |x - y| < 1e-2$" - raise NotImplementedError("Need to include this file from past assignment.") +def is_close(x: float, y: float) -> bool: + return abs(x - y) < 1e-2 def sigmoid(x: float) -> float: - r""" - $f(x) = \frac{1.0}{(1.0 + e^{-x})}$ - - (See 
https://en.wikipedia.org/wiki/Sigmoid_function ) - - Calculate as - - $f(x) = \frac{1.0}{(1.0 + e^{-x})}$ if x >=0 else $\frac{e^x}{(1.0 + e^{x})}$ - - for stability. - """ - raise NotImplementedError("Need to include this file from past assignment.") + if x >= 0.0: + z = math.exp(-x) + return 1.0 / (1.0 + z) + else: + z = math.exp(x) + return z / (1.0 + z) def relu(x: float) -> float: - """ - $f(x) =$ x if x is greater than 0, else 0 - - (See https://en.wikipedia.org/wiki/Rectifier_(neural_networks) .) - """ - raise NotImplementedError("Need to include this file from past assignment.") - - -EPS = 1e-6 + return x if x > 0.0 else 0.0 def log(x: float) -> float: - "$f(x) = log(x)$" - return math.log(x + EPS) + return math.log(x) def exp(x: float) -> float: - "$f(x) = e^{x}$" return math.exp(x) -def log_back(x: float, d: float) -> float: - r"If $f = log$ as above, compute $d \times f'(x)$" - raise NotImplementedError("Need to include this file from past assignment.") +def inv(x: float) -> float: + return 1.0 / x -def inv(x: float) -> float: - "$f(x) = 1/x$" - raise NotImplementedError("Need to include this file from past assignment.") +def log_back(a: float, b: float) -> float: + return b / a -def inv_back(x: float, d: float) -> float: - r"If $f(x) = 1/x$ compute $d \times f'(x)$" - raise NotImplementedError("Need to include this file from past assignment.") +def inv_back(a: float, b: float) -> float: + return -b / (a * a) -def relu_back(x: float, d: float) -> float: - r"If $f = relu$ compute $d \times f'(x)$" - raise NotImplementedError("Need to include this file from past assignment.") +def relu_back(a: float, b: float) -> float: + return b if a > 0.0 else 0.0 # ## Task 0.3 # Small practice library of elementary higher-order functions. +# Implement the following core functions +# - map +# - zipWith +# - reduce +# +# Use these to implement +# - negList : negate a list +# - addLists : add two lists together +# - sum: sum lists +# - prod: take the product of lists -def map(fn: Callable[[float], float]) -> Callable[[Iterable[float]], Iterable[float]]: - """ - Higher-order map. - - See https://en.wikipedia.org/wiki/Map_(higher-order_function) - - Args: - fn: Function from one value to one value. - - Returns: - A function that takes a list, applies `fn` to each element, and returns a - new list - """ - raise NotImplementedError("Need to include this file from past assignment.") - - -def negList(ls: Iterable[float]) -> Iterable[float]: - "Use `map` and `neg` to negate each element in `ls`" - raise NotImplementedError("Need to include this file from past assignment.") - - -def zipWith( - fn: Callable[[float, float], float] -) -> Callable[[Iterable[float], Iterable[float]], Iterable[float]]: - """ - Higher-order zipwith (or map2). - See https://en.wikipedia.org/wiki/Map_(higher-order_function) +T = TypeVar("T") +U = TypeVar("U") - Args: - fn: combine two values - Returns: - Function that takes two equally sized lists `ls1` and `ls2`, produce a new list by - applying fn(x, y) on each pair of elements. 
+def map(fn: Callable[[T], U], it: Iterable[T]) -> List[U]: + out: List[U] = [] + for v in it: + out.append(fn(v)) + return out - """ - raise NotImplementedError("Need to include this file from past assignment.") +def zipWith(fn: Callable[[T, U], T], a: Iterable[T], b: Iterable[U]) -> List[T]: + out: List[T] = [] + ia = iter(a) + ib = iter(b) + while True: + try: + va = next(ia) + vb = next(ib) + except StopIteration: + break + out.append(fn(va, vb)) + return out -def addLists(ls1: Iterable[float], ls2: Iterable[float]) -> Iterable[float]: - "Add the elements of `ls1` and `ls2` using `zipWith` and `add`" - raise NotImplementedError("Need to include this file from past assignment.") +def reduce(fn: Callable[[T, T], T], it: Iterable[T], start: T) -> T: + acc: T = start + for v in it: + acc = fn(acc, v) + return acc -def reduce( - fn: Callable[[float, float], float], start: float -) -> Callable[[Iterable[float]], float]: - r""" - Higher-order reduce. +def negList(ls: Iterable[float]) -> List[float]: + return map(neg, ls) - Args: - fn: combine two values - start: start value $x_0$ - Returns: - Function that takes a list `ls` of elements - $x_1 \ldots x_n$ and computes the reduction :math:`fn(x_3, fn(x_2, - fn(x_1, x_0)))` - """ - raise NotImplementedError("Need to include this file from past assignment.") +def addLists(a: Iterable[float], b: Iterable[float]) -> List[float]: + return zipWith(add, a, b) def sum(ls: Iterable[float]) -> float: - "Sum up a list using `reduce` and `add`." - raise NotImplementedError("Need to include this file from past assignment.") + return reduce(add, ls, 0.0) def prod(ls: Iterable[float]) -> float: - "Product of a list using `reduce` and `mul`." - raise NotImplementedError("Need to include this file from past assignment.") + return reduce(mul, ls, 1.0) diff --git a/minitorch/scalar.py b/minitorch/scalar.py index f5abbe9e..6e4b1ec6 100644 --- a/minitorch/scalar.py +++ b/minitorch/scalar.py @@ -92,31 +92,25 @@ def __rtruediv__(self, b: ScalarLike) -> Scalar: return Mul.apply(b, Inv.apply(self)) def __add__(self, b: ScalarLike) -> Scalar: - # TODO: Implement for Task 1.2. - raise NotImplementedError("Need to implement for Task 1.2") + return Add.apply(self, b) def __bool__(self) -> bool: return bool(self.data) def __lt__(self, b: ScalarLike) -> Scalar: - # TODO: Implement for Task 1.2. - raise NotImplementedError("Need to implement for Task 1.2") + return LT.apply(self, b) def __gt__(self, b: ScalarLike) -> Scalar: - # TODO: Implement for Task 1.2. - raise NotImplementedError("Need to implement for Task 1.2") + return LT.apply(b, self) def __eq__(self, b: ScalarLike) -> Scalar: # type: ignore[override] - # TODO: Implement for Task 1.2. - raise NotImplementedError("Need to implement for Task 1.2") + return EQ.apply(self, b) def __sub__(self, b: ScalarLike) -> Scalar: - # TODO: Implement for Task 1.2. - raise NotImplementedError("Need to implement for Task 1.2") + return Add.apply(self, Neg.apply(b)) def __neg__(self) -> Scalar: - # TODO: Implement for Task 1.2. - raise NotImplementedError("Need to implement for Task 1.2") + return Neg.apply(self) def __radd__(self, b: ScalarLike) -> Scalar: return self + b @@ -125,20 +119,16 @@ def __rmul__(self, b: ScalarLike) -> Scalar: return self * b def log(self) -> Scalar: - # TODO: Implement for Task 1.2. - raise NotImplementedError("Need to implement for Task 1.2") + return Log.apply(self) def exp(self) -> Scalar: - # TODO: Implement for Task 1.2. 
- raise NotImplementedError("Need to implement for Task 1.2") + return Exp.apply(self) def sigmoid(self) -> Scalar: - # TODO: Implement for Task 1.2. - raise NotImplementedError("Need to implement for Task 1.2") + return Sigmoid.apply(self) def relu(self) -> Scalar: - # TODO: Implement for Task 1.2. - raise NotImplementedError("Need to implement for Task 1.2") + return ReLU.apply(self) # Variable elements for backprop @@ -173,8 +163,13 @@ def chain_rule(self, d_output: Any) -> Iterable[Tuple[Variable, Any]]: assert h.last_fn is not None assert h.ctx is not None - # TODO: Implement for Task 1.3. - raise NotImplementedError("Need to implement for Task 1.3") + local_grads = h.last_fn._backward(h.ctx, d_output) + + out: list[tuple[Variable, Any]] = [] + for inp, g in zip(h.inputs, local_grads): + if not inp.is_constant(): + out.append((inp, g)) + return out def backward(self, d_output: Optional[float] = None) -> None: """ diff --git a/minitorch/scalar_functions.py b/minitorch/scalar_functions.py index d8d2307b..92ee9ea2 100644 --- a/minitorch/scalar_functions.py +++ b/minitorch/scalar_functions.py @@ -103,13 +103,13 @@ class Mul(ScalarFunction): @staticmethod def forward(ctx: Context, a: float, b: float) -> float: - # TODO: Implement for Task 1.2. - raise NotImplementedError("Need to implement for Task 1.2") + ctx.save_for_backward(a, b) + return float(operators.mul(a, b)) @staticmethod def backward(ctx: Context, d_output: float) -> Tuple[float, float]: - # TODO: Implement for Task 1.4. - raise NotImplementedError("Need to implement for Task 1.4") + a, b = ctx.saved_values + return float(d_output * b), float(d_output * a) class Inv(ScalarFunction): @@ -117,13 +117,13 @@ class Inv(ScalarFunction): @staticmethod def forward(ctx: Context, a: float) -> float: - # TODO: Implement for Task 1.2. - raise NotImplementedError("Need to implement for Task 1.2") + ctx.save_for_backward(a) + return float(operators.inv(a)) @staticmethod def backward(ctx: Context, d_output: float) -> float: - # TODO: Implement for Task 1.4. - raise NotImplementedError("Need to implement for Task 1.4") + (a,) = ctx.saved_values + return float(operators.inv_back(a, d_output)) class Neg(ScalarFunction): @@ -131,13 +131,11 @@ class Neg(ScalarFunction): @staticmethod def forward(ctx: Context, a: float) -> float: - # TODO: Implement for Task 1.2. - raise NotImplementedError("Need to implement for Task 1.2") + return float(operators.neg(a)) @staticmethod def backward(ctx: Context, d_output: float) -> float: - # TODO: Implement for Task 1.4. - raise NotImplementedError("Need to implement for Task 1.4") + return float(-d_output) class Sigmoid(ScalarFunction): @@ -145,13 +143,14 @@ class Sigmoid(ScalarFunction): @staticmethod def forward(ctx: Context, a: float) -> float: - # TODO: Implement for Task 1.2. - raise NotImplementedError("Need to implement for Task 1.2") + s = operators.sigmoid(a) + ctx.save_for_backward(s) + return float(s) @staticmethod def backward(ctx: Context, d_output: float) -> float: - # TODO: Implement for Task 1.4. - raise NotImplementedError("Need to implement for Task 1.4") + (s,) = ctx.saved_values + return float(d_output * s * (1.0 - s)) class ReLU(ScalarFunction): @@ -159,13 +158,13 @@ class ReLU(ScalarFunction): @staticmethod def forward(ctx: Context, a: float) -> float: - # TODO: Implement for Task 1.2. 
- raise NotImplementedError("Need to implement for Task 1.2") + ctx.save_for_backward(a) + return float(operators.relu(a)) @staticmethod def backward(ctx: Context, d_output: float) -> float: - # TODO: Implement for Task 1.4. - raise NotImplementedError("Need to implement for Task 1.4") + (a,) = ctx.saved_values + return float(operators.relu_back(a, d_output)) class Exp(ScalarFunction): @@ -173,13 +172,14 @@ class Exp(ScalarFunction): @staticmethod def forward(ctx: Context, a: float) -> float: - # TODO: Implement for Task 1.2. - raise NotImplementedError("Need to implement for Task 1.2") + e = operators.exp(a) + ctx.save_for_backward(e) + return float(e) @staticmethod def backward(ctx: Context, d_output: float) -> float: - # TODO: Implement for Task 1.4. - raise NotImplementedError("Need to implement for Task 1.4") + (e,) = ctx.saved_values + return float(d_output * e) class LT(ScalarFunction): @@ -187,13 +187,11 @@ class LT(ScalarFunction): @staticmethod def forward(ctx: Context, a: float, b: float) -> float: - # TODO: Implement for Task 1.2. - raise NotImplementedError("Need to implement for Task 1.2") + return float(operators.lt(a, b)) @staticmethod def backward(ctx: Context, d_output: float) -> Tuple[float, float]: - # TODO: Implement for Task 1.4. - raise NotImplementedError("Need to implement for Task 1.4") + return 0.0, 0.0 class EQ(ScalarFunction): @@ -201,10 +199,8 @@ class EQ(ScalarFunction): @staticmethod def forward(ctx: Context, a: float, b: float) -> float: - # TODO: Implement for Task 1.2. - raise NotImplementedError("Need to implement for Task 1.2") + return float(operators.eq(a, b)) @staticmethod def backward(ctx: Context, d_output: float) -> Tuple[float, float]: - # TODO: Implement for Task 1.4. - raise NotImplementedError("Need to implement for Task 1.4") + return 0.0, 0.0 diff --git a/project/run_scalar.py b/project/run_scalar.py index 7ce5207b..26703cb0 100644 --- a/project/run_scalar.py +++ b/project/run_scalar.py @@ -10,8 +10,9 @@ class Network(minitorch.Module): def __init__(self, hidden_layers): super().__init__() - # TODO: Implement for Task 1.5. - raise NotImplementedError("Need to implement for Task 1.5") + self.layer1 = Linear(in_size=2, out_size=hidden_layers) + self.layer2 = Linear(in_size=hidden_layers, out_size=hidden_layers) + self.layer3 = Linear(in_size=hidden_layers, out_size=1) def forward(self, x): middle = [h.relu() for h in self.layer1.forward(x)] @@ -40,9 +41,15 @@ def __init__(self, in_size, out_size): ) def forward(self, inputs): - # TODO: Implement for Task 1.5. - raise NotImplementedError("Need to implement for Task 1.5") - + xs = list(inputs) + out = [] + out_size = len(self.bias) + for j in range(out_size): + s = self.bias[j].value + for i, x in enumerate(xs): + s = s + x * self.weights[i][j].value + out.append(s) + return out def default_log_fn(epoch, total_loss, correct, losses): print("Epoch ", epoch, " loss ", total_loss, "correct", correct) @@ -101,7 +108,7 @@ def train(self, data, learning_rate, max_epochs=500, log_fn=default_log_fn): if __name__ == "__main__": PTS = 50 - HIDDEN = 2 + HIDDEN = 6 RATE = 0.5 - data = minitorch.datasets["Simple"](PTS) - ScalarTrain(HIDDEN).train(data, RATE) + data = minitorch.datasets["Split"](PTS) + ScalarTrain(HIDDEN).train(data, RATE, max_epochs=800)
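Note on reproducing the training runs listed in the README above: each log was produced by `project/run_scalar.py`, whose `__main__` block hard-codes one configuration at a time. Below is a minimal sketch of that block for the "Simple" run (PTS=50, HIDDEN=2, RATE=0.5; `max_epochs` defaults to 500 in `ScalarTrain.train`). The other runs swap in their own constants; dataset keys other than the "Simple" and "Split" keys visible in this diff (for example "Xor") are assumptions about `minitorch.datasets`.

if __name__ == "__main__":
    PTS = 50      # number of data points, as in the README config
    HIDDEN = 2    # hidden layer size for the Simple run
    RATE = 0.5    # learning rate
    data = minitorch.datasets["Simple"](PTS)
    # max_epochs defaults to 500, matching the length of the Simple log
    ScalarTrain(HIDDEN).train(data, RATE)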