{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "---" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "s4wplUzDYJcx" }, "source": [ "

Multilayer network in PyTorch

" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "---" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "J2msuyHTYJcx" }, "source": [ "In this notebook we will learn how to write our own neural networks in the PyTorch framework; specifically, we will look at how to write a multilayer fully connected network (Fully-Connected, FC)." ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "9xJnMEZrYJcz" }, "source": [ "

Neural network components

" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "InwacmvIYJc0" }, "source": [ "This is a good moment to recall what plays a key role in building any ***neural network*** (we choose all of these *by hand*, ourselves): \n", "\n", "- the **architecture** of the network itself (including the type of activation function of each neuron);\n", "- the **initialization** of each layer's weights;\n", "- the **optimization** method (including how the `learning_rate` is changed during training);\n", "- the **batch** size (`batch_size`);\n", "- the number of training **epochs** (`num_epochs`);\n", "- the **loss function** (`loss`); \n", "- the type of **regularization** (each layer can have its own); \n", "\n", "Things related to the ***data and the task***: \n", "- the **quality** of the dataset (consistency, cleanliness, a correct problem statement); \n", "- the **size** of the dataset; " ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "tXujEOB0YJc1" }, "source": [ "

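As a rough illustration of where each of these components shows up in code, here is a minimal PyTorch sketch on hypothetical random data. The specific choices (Xavier initialization, SGD with `weight_decay` as L2 regularization, a `StepLR` learning-rate schedule, all sizes and numbers) are our own assumptions for illustration, not recommendations.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)

# Hypothetical toy data: 120 objects, 2 features, 3 classes
X = torch.randn(120, 2)
y = torch.randint(0, 3, (120,))

# Architecture (including the activation functions)
model = nn.Sequential(nn.Linear(2, 16), nn.ReLU(), nn.Linear(16, 3))

# Initialization of each layer's weights (Xavier is an assumed choice)
for layer in model:
    if isinstance(layer, nn.Linear):
        nn.init.xavier_uniform_(layer.weight)
        nn.init.zeros_(layer.bias)

# Loss function
loss_fn = nn.CrossEntropyLoss()

# Optimization method; weight_decay adds an L2 penalty (regularization) to all parameters
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, weight_decay=1e-4)
# How the learning rate changes: halve it every 10 epochs
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)

# Batch size
batch_size = 32
loader = DataLoader(TensorDataset(X, y), batch_size=batch_size, shuffle=True)

# Number of epochs
num_epochs = 30
for epoch in range(num_epochs):
    for xb, yb in loader:
        optimizer.zero_grad()
        loss = loss_fn(model(xb), yb)
        loss.backward()
        optimizer.step()
    scheduler.step()

print(optimizer.param_groups[0]['lr'])
```

The data-related components (quality and size of the dataset) have no direct counterpart in this code: they are settled before training starts.
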
Multilayer neural network

" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "mnxH-DajYJc3" }, "source": [ "As the name suggests, a multilayer neural network consists of several **layers**, and each layer consists of **neurons**. We have already written our own neuron in NumPy; an ***MLP (Multi-Layer Perceptron)*** is built from exactly such neurons. A multilayer network in which every neuron of one layer is connected to every neuron of the next layer is also called a ***Fully-Connected network*** (or a ***Dense network***). \n", "\n", "Let's look at how these networks are built in more detail:" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "onjJUneMYJc5" }, "source": [ "* This is what a two-layer network looks like (the input layer is not counted, because it is not really a layer):" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "owRRulLzYJc6" }, "source": [ "" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "tFNxGGBEYJc8" }, "source": [ "* And this is a three-layer network:" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "zRaKX35eYJc9" }, "source": [ "" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "6w1FTkO1YJc-" }, "source": [ "...and so on for a larger number of layers." ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "8iKV7m5YYJc_" }, "source": [ "**Note:** neurons of layer $L_{i-1}$ are connected to neurons of layer $L_{i}$, but there are **no** connections between neurons within the same layer." ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "URV9qWkfYJdA" }, "source": [ "The **input layer** is the data itself (a matrix of shape $(N, M)$)." ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "zK8tWuHHYJdB" }, "source": [ "Layers that are neither input nor output layers are called **hidden layers**."
] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "fz9clUlCYJdC" }, "source": [ "When solving a ***regression problem***, the **output layer** usually has a single neuron that returns the predicted numbers (one number per object). \n", "\n", "For a ***classification problem***, the **output layer** usually has one neuron for binary classification and $K$ neurons for $K$-class classification." ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "nJblXqY5YJdE" }, "source": [ "#### Forward pass in an MLP" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "D87xoAl8YJdF" }, "source": [ "Each layer of a multilayer network is a weight matrix whose columns are neurons: one column holds the weights of one neuron." ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "collapsed": true, "id": "RyXUqCfVYJdG" }, "source": [ "Suppose we are solving a $K$-class classification problem ($K$ neurons in the last layer). Let's see what the network's `forward_pass` looks like in this case:" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "YO6gHbOjYJdH" }, "source": [ "* Input: $$X =\n", "\\left(\n", "\\begin{matrix} \n", "x_{11} & ... & x_{1M} \\\\\n", "... & \\ddots & ...\\\\\n", "x_{N1} & ... & x_{NM} \n", "\\end{matrix}\n", "\\right)\n", "$$\n", "\n", "-- a matrix of shape $(N, M)$" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "collapsed": true, "id": "XdImZiQkYJdI" }, "source": [ "* Network structure: many layers, each with many neurons. The first layer (after the input) looks like this:" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "U2c2M4MJYJdJ" }, "source": [ "$$ W^1 =\n", "\\left(\n", "\\begin{matrix} \n", "w_{11} & ... & w_{1L_1} \\\\\n", "... & \\ddots & ...\\\\\n", "w_{M1} & ... & w_{ML_1} \n", "\\end{matrix}\n", "\\right)\n", "$$\n", "\n", "-- a matrix of shape $(M, L_1)$" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "UUCdeLN0YJdK" }, "source": [ "That is, exactly $L_1$ neurons, each with its own $M$ weights." ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "ixHtlMwKYJdL" }, "source": [ "Recall that a neuron is a linear transformation followed by a nonlinear activation function applied to its result. In multilayer networks, however, the `Linear` part and the `Activation` are usually separated: a layer is the set of neuron weights, and it is followed by a separate activation layer. The activation function is the same for all neurons in a layer; frameworks usually do not let you give one neuron in a layer an activation function different from the others, but this is easy to achieve by declaring a layer that consists of a single neuron." ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "AUt1NgTvYJdN" }, "source": [ "* The other layers look exactly like the first one. For example, the second layer is:" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "IdtSTvTmYJdN" }, "source": [ "$$ W^2 =\n", "\\left(\n", "\\begin{matrix} \n", "w_{11} & ... & w_{1L_2} \\\\\n", "... & \\ddots & ...\\\\\n", "w_{L_11} & ... & w_{L_1L_2} \n", "\\end{matrix}\n", "\\right)\n", "$$\n", "\n", "-- a matrix of shape $(L_1, L_2)$" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "R3nGnHoLYJdP" }, "source": [ "That is, exactly $L_2$ neurons, each with its own $L_1$ weights." ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "bQhOAQjSYJdR" }, "source": [ "* Output layer: \n", "\n", "Suppose the network has $t$ layers before the output layer. Then the output layer has the form:" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "fWvqm-K0YJdT" }, "source": [ "$$ W^{out} =\n", "\\left(\n", "\\begin{matrix} \n", "w_{11} & ... & w_{1K} \\\\\n", "... & \\ddots & ...\\\\\n", "w_{L_t1} & ... & w_{L_tK} \n", "\\end{matrix}\n", "\\right)\n", "$$\n", "\n", "-- a matrix of shape $(L_t, K)$, where $L_t$ is the number of neurons in layer $t$ and $K$ is the number of classes." ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "2z5tO89NYJdU" }, "source": [ "So *for the `forward_pass` we simply multiply the matrices one after another, applying the corresponding activation function after each multiplication*." ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "IZT4GgsCYJdV" }, "source": [ "*Note*: you can think of each multiplication by the next weight matrix as a move to a **new feature space**. Indeed, when we feed in the matrix $X$ and multiply it by the first layer's matrix, we get a matrix of shape $(N, L_1)$, as if each object now had $L_1$ \"new\" features (linear combinations of the old ones before the activation function, and nonlinear ones after it). It is worth recalling here that Deep Learning is a subfield of Representation Learning, i.e. it learns new representations of the data." ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "4RhJ4fsHYJdW" }, "source": [ "**Backward pass in an MLP**" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "MYN043DbYJdX" }, "source": [ "An MLP is trained with the \"Error Backpropagation\" method ([Backpropagation, Russian Wikipedia](https://ru.wikipedia.org/wiki/%D0%9C%D0%B5%D1%82%D0%BE%D0%B4_%D0%BE%D0%B1%D1%80%D0%B0%D1%82%D0%BD%D0%BE%D0%B3%D0%BE_%D1%80%D0%B0%D1%81%D0%BF%D1%80%D0%BE%D1%81%D1%82%D1%80%D0%B0%D0%BD%D0%B5%D0%BD%D0%B8%D1%8F_%D0%BE%D1%88%D0%B8%D0%B1%D0%BA%D0%B8)). The principle is very similar to how we trained a single neuron: it is gradient descent, but over \"the whole network\" at once. 
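
To make the forward and backward pass concrete, here is a minimal NumPy sketch for a two-layer network, written under our own assumptions: no bias terms (to match the weight matrices above), ReLU in the hidden layer, softmax with cross-entropy summed over objects at the output, and hypothetical toy sizes and random data. The backward pass applies the chain rule layer by layer, and each gradient-descent step then updates all weight matrices at once.

```python
import numpy as np

# Hypothetical toy setup: N objects with M features, H hidden neurons, K classes
N, M, H, K = 6, 2, 5, 3
rng = np.random.default_rng(0)
X = rng.normal(size=(N, M))
y = rng.integers(0, K, size=N)        # class labels 0..K-1
W1 = rng.normal(size=(M, H)) * 0.5    # hidden layer: matrix (M, H), one column per neuron
W2 = rng.normal(size=(H, K)) * 0.5    # output layer: matrix (H, K)


def forward_backward(X, y, W1, W2):
    rows = np.arange(len(y))
    # Forward pass: matrix product, then activation, layer after layer
    z1 = X @ W1                                   # (N, H)
    a1 = np.maximum(z1, 0.0)                      # ReLU
    logits = a1 @ W2                              # (N, K)
    shifted = logits - logits.max(axis=1, keepdims=True)  # for numerical stability
    p = np.exp(shifted) / np.exp(shifted).sum(axis=1, keepdims=True)  # softmax
    loss = -np.log(p[rows, y]).sum()              # cross-entropy summed over objects

    # Backward pass: chain rule from the loss back to each weight matrix
    d_logits = p.copy()
    d_logits[rows, y] -= 1.0                      # dLoss/dlogits for softmax + cross-entropy
    dW2 = a1.T @ d_logits                         # (H, K)
    d_a1 = d_logits @ W2.T                        # (N, H)
    d_z1 = d_a1 * (z1 > 0)                        # ReLU passes the gradient only where z1 > 0
    dW1 = X.T @ d_z1                              # (M, H)
    return loss, dW1, dW2


learning_rate = 1e-2
losses = []
for step in range(100):
    loss, dW1, dW2 = forward_backward(X, y, W1, W2)
    losses.append(loss)
    # One gradient-descent step updates all layers at once
    W1 -= learning_rate * dW1
    W2 -= learning_rate * dW2

print(losses[0], losses[-1])
```

In PyTorch, autograd computes such gradients automatically when `loss.backward()` is called, so a manual backward pass like this one is only meant to show what happens under the hood.
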
" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "collapsed": true, "id": "oK7Vi4bxYJdZ" }, "source": [ "Backpropagation works because of the ***chain rule*** (the rule for differentiating a composite function): \n", "\n", "If $h(x) = f(g(x))$, then: \n", "\n", "$$\\frac{\\partial{h}}{\\partial{x}} = \\frac{\\partial{f}}{\\partial{g}} \\frac{\\partial{g}}{\\partial{x}}$$" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "collapsed": true, "id": "WKvMsaEBYJda" }, "source": [ "You can read more about backpropagation here (in English): https://mattmazur.com/2015/03/17/a-step-by-step-backpropagation-example/" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "z5dyjPVNYJdc" }, "source": [ "

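A quick numerical check of the chain rule, as a minimal sketch: the composite function `h(x) = sin(x**2)`, i.e. `g(x) = x**2` and `f(u) = sin(u)`, is a hypothetical example chosen only for illustration. The derivative given by the chain rule, `cos(g(x)) * 2x`, should agree with a central finite-difference estimate.

```python
import math

# Hypothetical example: g(x) = x**2, f(u) = sin(u), so h(x) = f(g(x)) = sin(x**2)
def g(x):
    return x ** 2

def f(u):
    return math.sin(u)

def h(x):
    return f(g(x))

def dh_chain_rule(x):
    # dh/dx = df/dg * dg/dx = cos(g(x)) * 2x
    return math.cos(g(x)) * 2 * x

def dh_numeric(x, eps=1e-6):
    # Central finite difference, used only as an independent sanity check
    return (h(x + eps) - h(x - eps)) / (2 * eps)

x0 = 1.3
print(dh_chain_rule(x0), dh_numeric(x0))
```

Backpropagation does the same thing for a whole network: it multiplies such local derivatives along the path from the loss back to each weight.
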
Multilayer neural network in PyTorch

" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "T9ufk3ECYJde" }, "source": [ "Once again, let us recall the main components of a neural network:\n", "\n", "- the **architecture** of the network itself (including the type of activation function of each neuron);\n", "- the **initialization** of each layer's weights;\n", "- the **optimization** method (including how the `learning_rate` is changed during training);\n", "- the **batch** size (`batch_size`);\n", "- the number of training **epochs** (`num_epochs`);\n", "- the **loss function** (`loss`); \n", "- the type of **regularization** (each layer can have its own); \n", "\n", "Things related to the ***data and the task***: \n", "- the **quality** of the dataset (consistency, cleanliness, a correct problem statement); \n", "- the **size** of the dataset; " ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "9KrWarqTYJdf" }, "source": [ "Let's create a two-layer network with 100 hidden neurons:" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "colab": {}, "colab_type": "code", "id": "bLjkPg19YJdg" }, "outputs": [], "source": [ "import matplotlib.pyplot as plt\n", "import numpy as np\n", "\n", "import torch" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "iCDVRvQJYJdl" }, "source": [ "Generating the dataset (a spiral with $K = 3$ classes):" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "colab": {}, "colab_type": "code", "id": "k0J27RcLYJdm" }, "outputs": [], "source": [ "N = 100  # points per class\n", "D = 2    # number of features\n", "K = 3    # number of classes\n", "X = np.zeros((N * K, D))\n", "y = np.zeros(N * K, dtype='uint8')\n", "\n", "for j in range(K):\n", "    ix = range(N * j, N * (j + 1))\n", "    r = np.linspace(0.0, 1, N)  # radius\n", "    t = np.linspace(j * 4, (j + 1) * 4, N) + np.random.randn(N) * 0.2  # theta (angle) with noise\n", "    X[ix] = np.c_[r * np.sin(t), r * np.cos(t)]\n", "    y[ix] = j" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "X9FHhqX_YJdp" }, "source": [ "Don't forget to convert the data to PyTorch tensors (the model and the loss function work with tensors, not NumPy arrays):" ]
}, { "cell_type": "code", "execution_count": 3, "metadata": { "colab": {}, "colab_type": "code", "id": "pQINaQqZYJdq" }, "outputs": [], "source": [ "X = torch.tensor(X, dtype=torch.float32)\n", "y = torch.tensor(y, dtype=torch.long)  # class indices for CrossEntropyLoss" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "colab": {}, "colab_type": "code", "id": "Who9mS8oYJdu", "outputId": "1ec6e30a-2cd9-4bd6-d0e0-3b7c936d1b8a" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "torch.Size([300, 2]) torch.Size([300])\n" ] } ], "source": [ "print(X.shape, y.shape)" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "I-9dFW5CYJd0" }, "source": [ "The cell with the network itself and its training loop:" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "colab": {}, "colab_type": "code", "id": "kf-YapleYJd1", "outputId": "d71315ae-670a-48e3-bcf7-b425ebe198fe" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "c:\\users\\izakharkin\\anaconda3\\envs\\vision\\lib\\site-packages\\torch\\nn\\functional.py:52: UserWarning: size_average and reduce args will be deprecated, please use reduction='sum' instead.\n", " warnings.warn(warning.format(ret))\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "0 330.8737487792969\n", "1 328.732177734375\n", "2 326.73291015625\n", "3 324.8538513183594\n", "4 323.0766296386719\n", "5 321.386962890625\n", "6 319.77215576171875\n", "7 318.22314453125\n", "8 316.7312316894531\n", "9 315.2906799316406\n", "10 313.89508056640625\n", "11 312.5408935546875\n", "12 311.2237854003906\n", "13 309.9412536621094\n", "14 308.69000244140625\n", "15 307.4679260253906\n", "16 306.2733154296875\n", "17 305.10418701171875\n", "18 303.9598083496094\n", "19 302.8381652832031\n", "20 301.7386779785156\n", "21 300.6603088378906\n", "22 299.60205078125\n", "23 298.5636901855469\n", "24 297.5439758300781\n", "25 296.5423583984375\n", "26 295.5587158203125\n", "27 
294.5920104980469\n", "28 293.6417541503906\n", "29 292.70819091796875\n", "30 291.79022216796875\n", "31 290.88787841796875\n", "32 290.00054931640625\n", "33 289.1278991699219\n", "34 288.2695007324219\n", "35 287.4254150390625\n", "36 286.59490966796875\n", "37 285.7780456542969\n", "38 284.9742431640625\n", "39 284.18316650390625\n", "40 283.4047546386719\n", "41 282.638671875\n", "42 281.8844909667969\n", "43 281.1424255371094\n", "44 280.4117736816406\n", "45 279.6929016113281\n", "46 278.98492431640625\n", "47 278.2878723144531\n", "48 277.6017150878906\n", "49 276.9258117675781\n", "50 276.2603759765625\n", "51 275.60498046875\n", "52 274.9594421386719\n", "53 274.3237609863281\n", "54 273.69744873046875\n", "55 273.0805358886719\n", "56 272.4727783203125\n", "57 271.8741149902344\n", "58 271.28436279296875\n", "59 270.7031555175781\n", "60 270.13055419921875\n", "61 269.5662536621094\n", "62 269.0101623535156\n", "63 268.46209716796875\n", "64 267.9218444824219\n", "65 267.3895568847656\n", "66 266.8648681640625\n", "67 266.3474426269531\n", "68 265.83758544921875\n", "69 265.3347473144531\n", "70 264.8389892578125\n", "71 264.35028076171875\n", "72 263.8682556152344\n", "73 263.3929443359375\n", "74 262.9244689941406\n", "75 262.462158203125\n", "76 262.0062255859375\n", "77 261.5567321777344\n", "78 261.11297607421875\n", "79 260.6752624511719\n", "80 260.24371337890625\n", "81 259.8177795410156\n", "82 259.39764404296875\n", "83 258.9831237792969\n", "84 258.5740051269531\n", "85 258.17047119140625\n", "86 257.77239990234375\n", "87 257.3795166015625\n", "88 256.9918518066406\n", "89 256.6094055175781\n", "90 256.2320861816406\n", "91 255.85975646972656\n", "92 255.4922332763672\n", "93 255.12928771972656\n", "94 254.7711181640625\n", "95 254.41737365722656\n", "96 254.0686492919922\n", "97 253.7240753173828\n", "98 253.38406372070312\n", "99 253.048095703125\n", "100 252.71652221679688\n", "101 252.38906860351562\n", "102 252.06591796875\n", "103 
251.74659729003906\n", "104 251.43118286132812\n", "105 251.1195526123047\n", "106 250.81195068359375\n", "107 250.50804138183594\n", "108 250.20799255371094\n", "109 249.91146850585938\n", "110 249.61854553222656\n", "111 249.32940673828125\n", "112 249.04347229003906\n", "113 248.76107788085938\n", "114 248.4823760986328\n", "115 248.2066192626953\n", "116 247.93441772460938\n", "117 247.66543579101562\n", "118 247.3993377685547\n", "119 247.13636779785156\n", "120 246.8765106201172\n", "121 246.61959838867188\n", "122 246.36578369140625\n", "123 246.11471557617188\n", "124 245.86680603027344\n", "125 245.62159729003906\n", "126 245.3793182373047\n", "127 245.1397705078125\n", "128 244.90293884277344\n", "129 244.668701171875\n", "130 244.4371795654297\n", "131 244.2081298828125\n", "132 243.9817657470703\n", "133 243.7580108642578\n", "134 243.53665161132812\n", "135 243.31785583496094\n", "136 243.10150146484375\n", "137 242.8875274658203\n", "138 242.67599487304688\n", "139 242.46681213378906\n", "140 242.2598876953125\n", "141 242.05548095703125\n", "142 241.85304260253906\n", "143 241.65306091308594\n", "144 241.455322265625\n", "145 241.25962829589844\n", "146 241.06610107421875\n", "147 240.87461853027344\n", "148 240.68505859375\n", "149 240.497802734375\n", "150 240.3124542236328\n", "151 240.12908935546875\n", "152 239.94769287109375\n", "153 239.76808166503906\n", "154 239.5904541015625\n", "155 239.4147491455078\n", "156 239.24102783203125\n", "157 239.06932067871094\n", "158 238.8993377685547\n", "159 238.73109436035156\n", "160 238.56472778320312\n", "161 238.39988708496094\n", "162 238.23672485351562\n", "163 238.07516479492188\n", "164 237.91534423828125\n", "165 237.7571258544922\n", "166 237.60055541992188\n", "167 237.4454345703125\n", "168 237.2916259765625\n", "169 237.1393585205078\n", "170 236.9887237548828\n", "171 236.8394775390625\n", "172 236.69174194335938\n", "173 236.5454864501953\n", "174 236.40061950683594\n", "175 
236.25697326660156\n", "176 236.1148223876953\n", "177 235.97398376464844\n", "178 235.83470153808594\n", "179 235.69656372070312\n", "180 235.55975341796875\n", "181 235.4245147705078\n", "182 235.29046630859375\n", "183 235.1576690673828\n", "184 235.02626037597656\n", "185 234.8958282470703\n", "186 234.7667694091797\n", "187 234.6388702392578\n", "188 234.51202392578125\n", "189 234.3862762451172\n", "190 234.26170349121094\n", "191 234.13832092285156\n", "192 234.0160675048828\n", "193 233.89483642578125\n", "194 233.7747344970703\n", "195 233.6556396484375\n", "196 233.53765869140625\n", "197 233.4207763671875\n", "198 233.30471801757812\n", "199 233.1900177001953\n", "200 233.07615661621094\n", "201 232.9632110595703\n", "202 232.85122680664062\n", "203 232.74021911621094\n", "204 232.63035583496094\n", "205 232.5215606689453\n", "206 232.4134521484375\n", "207 232.306396484375\n", "208 232.20028686523438\n", "209 232.09500122070312\n", "210 231.9905242919922\n", "211 231.8870391845703\n", "212 231.78427124023438\n", "213 231.68247985839844\n", "214 231.58154296875\n", "215 231.48130798339844\n", "216 231.38198852539062\n", "217 231.28353881835938\n", "218 231.18572998046875\n", "219 231.08877563476562\n", "220 230.99240112304688\n", "221 230.8970489501953\n", "222 230.80233764648438\n", "223 230.7084503173828\n", "224 230.6151885986328\n", "225 230.5227508544922\n", "226 230.43099975585938\n", "227 230.33998107910156\n", "228 230.24993896484375\n", "229 230.16046142578125\n", "230 230.07156372070312\n", "231 229.98330688476562\n", "232 229.89585876464844\n", "233 229.80921936035156\n", "234 229.72317504882812\n", "235 229.63787841796875\n", "236 229.55319213867188\n", "237 229.46926879882812\n", "238 229.38583374023438\n", "239 229.30311584472656\n", "240 229.22097778320312\n", "241 229.13951110839844\n", "242 229.05859375\n", "243 228.97821044921875\n", "244 228.89834594726562\n", "245 228.81919860839844\n", "246 228.7405242919922\n", "247 
228.66250610351562\n", "248 228.5851287841797\n", "249 228.50808715820312\n", "250 228.43177795410156\n", "251 228.35577392578125\n", "252 228.2805633544922\n", "253 228.20558166503906\n", "254 228.1313934326172\n", "255 228.05772399902344\n", "256 227.9844512939453\n", "257 227.9118194580078\n", "258 227.83944702148438\n", "259 227.76771545410156\n", "260 227.6964874267578\n", "261 227.62564086914062\n", "262 227.55517578125\n", "263 227.48536682128906\n", "264 227.41592407226562\n", "265 227.34690856933594\n", "266 227.27841186523438\n", "267 227.21031188964844\n", "268 227.1426239013672\n", "269 227.07553100585938\n", "270 227.00869750976562\n", "271 226.9425048828125\n", "272 226.8767852783203\n", "273 226.81141662597656\n", "274 226.74658203125\n", "275 226.6818389892578\n", "276 226.61788940429688\n", "277 226.55419921875\n", "278 226.4910430908203\n", "279 226.4283447265625\n", "280 226.36598205566406\n", "281 226.30397033691406\n", "282 226.2423553466797\n", "283 226.1811981201172\n", "284 226.12063598632812\n", "285 226.060302734375\n", "286 226.00042724609375\n", "287 225.94081115722656\n", "288 225.88162231445312\n", "289 225.82281494140625\n", "290 225.76438903808594\n", "291 225.7063446044922\n", "292 225.64842224121094\n", "293 225.59112548828125\n", "294 225.5340576171875\n", "295 225.47738647460938\n", "296 225.42083740234375\n", "297 225.36495971679688\n", "298 225.30914306640625\n", "299 225.25381469726562\n", "300 225.19866943359375\n", "301 225.1439666748047\n", "302 225.08958435058594\n", "303 225.03536987304688\n", "304 224.98138427734375\n", "305 224.9280242919922\n", "306 224.874755859375\n", "307 224.8217315673828\n", "308 224.7690887451172\n", "309 224.7168426513672\n", "310 224.66493225097656\n", "311 224.6132354736328\n", "312 224.5619354248047\n", "313 224.5107879638672\n", "314 224.4600372314453\n", "315 224.4093017578125\n", "316 224.35906982421875\n", "317 224.30905151367188\n", "318 224.25927734375\n", "319 224.20980834960938\n", 
"320 224.16049194335938\n", "321 224.11146545410156\n", "322 224.0626983642578\n", "323 224.0141143798828\n", "324 223.9657440185547\n", "325 223.91778564453125\n", "326 223.87010192871094\n", "327 223.82252502441406\n", "328 223.77523803710938\n", "329 223.72813415527344\n", "330 223.6814727783203\n", "331 223.63462829589844\n", "332 223.58831787109375\n", "333 223.5421600341797\n", "334 223.4960174560547\n", "335 223.45028686523438\n", "336 223.40487670898438\n", "337 223.35946655273438\n", "338 223.31427001953125\n", "339 223.26914978027344\n", "340 223.22439575195312\n", "341 223.1798553466797\n", "342 223.13540649414062\n", "343 223.09117126464844\n", "344 223.0472412109375\n", "345 223.00338745117188\n", "346 222.9597930908203\n", "347 222.91632080078125\n", "348 222.8732147216797\n", "349 222.8301239013672\n", "350 222.7873077392578\n", "351 222.7447509765625\n", "352 222.70228576660156\n", "353 222.659912109375\n", "354 222.61798095703125\n", "355 222.5760498046875\n", "356 222.53427124023438\n", "357 222.49267578125\n", "358 222.45132446289062\n", "359 222.41017150878906\n", "360 222.3690948486328\n", "361 222.32810974121094\n", "362 222.2873077392578\n", "363 222.24671936035156\n", "364 222.20628356933594\n", "365 222.16615295410156\n", "366 222.12588500976562\n", "367 222.08590698242188\n", "368 222.04608154296875\n", "369 222.00643920898438\n", "370 221.96688842773438\n", "371 221.92747497558594\n", "372 221.88816833496094\n", "373 221.84906005859375\n", "374 221.81005859375\n", "375 221.7711944580078\n", "376 221.73239135742188\n", "377 221.6939697265625\n", "378 221.65554809570312\n", "379 221.61721801757812\n", "380 221.57923889160156\n", "381 221.5411834716797\n", "382 221.50323486328125\n", "383 221.46563720703125\n", "384 221.42796325683594\n", "385 221.39044189453125\n", "386 221.35302734375\n", "387 221.31573486328125\n", "388 221.27880859375\n", "389 221.24176025390625\n", "390 221.20477294921875\n", "391 221.16806030273438\n", "392 
221.13133239746094\n", "393 221.09475708007812\n", "394 221.05816650390625\n", "395 221.02174377441406\n", "396 220.98556518554688\n", "397 220.94918823242188\n", "398 220.91311645507812\n", "399 220.87728881835938\n", "400 220.84132385253906\n", "401 220.80552673339844\n", "402 220.76998901367188\n", "403 220.7344512939453\n", "404 220.69898986816406\n", "405 220.66366577148438\n", "406 220.6285400390625\n", "407 220.59347534179688\n", "408 220.55848693847656\n", "409 220.52359008789062\n", "410 220.4888153076172\n", "411 220.45404052734375\n", "412 220.41934204101562\n", "413 220.38478088378906\n", "414 220.35031127929688\n", "415 220.31590270996094\n", "416 220.28167724609375\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "417 220.2473907470703\n", "418 220.2133331298828\n", "419 220.1792449951172\n", "420 220.14540100097656\n", "421 220.11141967773438\n", "422 220.0778350830078\n", "423 220.04403686523438\n", "424 220.01052856445312\n", "425 219.97705078125\n", "426 219.943603515625\n", "427 219.910400390625\n", "428 219.87709045410156\n", "429 219.84405517578125\n", "430 219.81103515625\n", "431 219.7781982421875\n", "432 219.7453155517578\n", "433 219.71253967285156\n", "434 219.67999267578125\n", "435 219.6472625732422\n", "436 219.6147003173828\n", "437 219.58233642578125\n", "438 219.5499267578125\n", "439 219.51748657226562\n", "440 219.48532104492188\n", "441 219.45327758789062\n", "442 219.4212646484375\n", "443 219.38925170898438\n", "444 219.35740661621094\n", "445 219.3255157470703\n", "446 219.29376220703125\n", "447 219.26210021972656\n", "448 219.2305450439453\n", "449 219.19888305664062\n", "450 219.1672821044922\n", "451 219.13607788085938\n", "452 219.1046905517578\n", "453 219.07330322265625\n", "454 219.04208374023438\n", "455 219.0108184814453\n", "456 218.97976684570312\n", "457 218.94859313964844\n", "458 218.91758728027344\n", "459 218.88671875\n", "460 218.85562133789062\n", "461 218.8248291015625\n", "462 
218.7939453125\n", "463 218.7631072998047\n", "464 218.73252868652344\n", "465 218.7017364501953\n", "466 218.67117309570312\n", "467 218.64051818847656\n", "468 218.61012268066406\n", "469 218.57965087890625\n", "470 218.54920959472656\n", "471 218.51881408691406\n", "472 218.48854064941406\n", "473 218.45822143554688\n", "474 218.4281005859375\n", "475 218.39797973632812\n", "476 218.36767578125\n", "477 218.3377685546875\n", "478 218.3079833984375\n", "479 218.2783203125\n", "480 218.2487335205078\n", "481 218.21925354003906\n", "482 218.1898651123047\n", "483 218.1603546142578\n", "484 218.13104248046875\n", "485 218.10177612304688\n", "486 218.07244873046875\n", "487 218.04319763183594\n", "488 218.01397705078125\n", "489 217.9847869873047\n", "490 217.95562744140625\n", "491 217.92665100097656\n", "492 217.89743041992188\n", "493 217.868408203125\n", "494 217.8395233154297\n", "495 217.81057739257812\n", "496 217.78160095214844\n", "497 217.7528533935547\n", "498 217.7239990234375\n", "499 217.6951446533203\n" ] } ], "source": [ "# N - batch size (not used below: the whole dataset is passed at once, i.e. full-batch gradient descent);\n", "# D_in - input dimension (number of features per object);\n", "# H - hidden layer size;\n", "# D_out - output layer size (the number of classes)\n", "N, D_in, H, D_out = 64, 2, 100, 3\n", "\n", "# Use the nn package to define our model and loss function.\n", "two_layer_net = torch.nn.Sequential(\n", "    torch.nn.Linear(D_in, H),\n", "    torch.nn.ReLU(),\n", "    torch.nn.Linear(H, D_out),\n", ")\n", "\n", "# reduction='sum' sums the loss over objects (replaces the deprecated size_average=False)\n", "loss_fn = torch.nn.CrossEntropyLoss(reduction='sum')\n", "\n", "learning_rate = 1e-4\n", "optimizer = torch.optim.SGD(two_layer_net.parameters(), lr=learning_rate)\n", "for t in range(500):\n", "    # forward\n", "    y_pred = two_layer_net(X)\n", "\n", "    # loss\n", "    loss = loss_fn(y_pred, y)\n", "    print(t, loss.item())\n", "\n", "    # zero the gradients (otherwise they accumulate from the previous step)\n", "    optimizer.zero_grad()\n", "\n", "    # backward\n", "    loss.backward()\n", "\n", "    # update the weights\n", "    optimizer.step()" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "zTQiN0bcYJd6" }, "source": [ "**Note:** even though this is a 3-class classification problem, for which $y$ would normally be one-hot encoded, we passed a plain column of class indices 0, 1 and 2, and everything worked. This is because `torch.nn.CrossEntropyLoss` expects class indices as targets: for each object it takes the log-probability of the correct class, which is equivalent to using a one-hot target, so no explicit encoding is needed." ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "apsWWq17YJd8" }, "source": [ "Let's check how good our network with 100 hidden neurons is:" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "colab": {}, "colab_type": "code", "id": "X0ICB6Z-YJd-" }, "outputs": [], "source": [ "# Back to NumPy for plotting\n", "X = X.numpy()\n", "y = y.numpy()" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "colab": {}, "colab_type": "code", "id": "JvjfvgbPYJeB", "outputId": "cf5545ff-3315-45c9-c674-af44f63a346b" }, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "h = 0.02  # step of the grid used to draw the decision regions\n", "x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1\n", "y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1\n", "\n", "xx, yy = np.meshgrid(np.arange(x_min, x_max, h),\n", "                     np.arange(y_min, y_max, h))\n", "grid_tensor = torch.tensor(np.c_[xx.ravel(), yy.ravel()], dtype=torch.float32)\n", "\n", "# predict a class for every grid point (no gradients are needed here)\n", "with torch.no_grad():\n", "    Z = two_layer_net(grid_tensor)\n", "Z = Z.numpy()\n", "Z = np.argmax(Z, axis=1)\n", "Z = Z.reshape(xx.shape)\n", "\n", "plt.figure(figsize=(10, 8))\n", "\n", "plt.contourf(xx, yy, Z, cmap=plt.cm.rainbow, alpha=0.3)\n", "plt.scatter(X[:, 0], X[:, 1], c=y, s=40, cmap=plt.cm.rainbow)\n", "\n", "plt.xlim(xx.min(), xx.max())\n", "plt.ylim(yy.min(), yy.max())\n", "\n", "plt.title('Spiral', fontsize=15)\n", "plt.xlabel('$x$', fontsize=14)\n", "plt.ylabel('$y$', fontsize=14)\n", "plt.show();" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "G2wNgoh5YJeG" }, "source": [ "Let's pick better hyperparameters:" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "X = torch.tensor(X, dtype=torch.float32)\n", "y = torch.tensor(y, dtype=torch.long)" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "c:\\users\\izakharkin\\anaconda3\\envs\\vision\\lib\\site-packages\\torch\\nn\\functional.py:52: UserWarning: size_average and reduce args will be deprecated, please use reduction='sum' instead.\n", " warnings.warn(warning.format(ret))\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "0 329.6793212890625\n", "1 327.8143615722656\n", "2 326.1004943847656\n", "3 324.5147399902344\n", "4 322.9456787109375\n", "5 321.3118896484375\n", "6 319.5968933105469\n", "7 317.75592041015625\n", "8 315.75390625\n", "9 313.5726318359375\n", "10 311.1824645996094\n", "11 308.5584411621094\n", "12 305.68798828125\n", 
"13 302.5540771484375\n", "14 299.15460205078125\n", "15 295.4794006347656\n", "16 291.5202331542969\n", "17 287.2996826171875\n", "18 282.85137939453125\n", "19 278.21435546875\n", "20 273.4761657714844\n", "21 268.6789245605469\n", "22 263.89453125\n", "23 259.163818359375\n", "24 254.56350708007812\n", "25 250.1515655517578\n", "26 245.91949462890625\n", "27 241.89158630371094\n", "28 238.07766723632812\n", "29 234.496826171875\n", "30 231.16111755371094\n", "31 228.04730224609375\n", "32 225.1465301513672\n", "33 222.45156860351562\n", "34 219.9502716064453\n", "35 217.6254119873047\n", "36 215.45823669433594\n", "37 213.44068908691406\n", "38 211.57652282714844\n", "39 209.91677856445312\n", "40 208.7138671875\n", "41 208.98919677734375\n", "42 214.6230010986328\n", "43 240.55369567871094\n", "44 314.333984375\n", "45 303.1317443847656\n", "46 260.13043212890625\n", "47 220.28932189941406\n", "48 214.5643768310547\n", "49 212.7597198486328\n", "50 213.45379638671875\n", "51 222.7734375\n", "52 235.19393920898438\n", "53 258.9502868652344\n", "54 235.3881072998047\n", "55 222.33709716796875\n", "56 205.35061645507812\n", "57 201.09136962890625\n", "58 197.60365295410156\n", "59 196.84959411621094\n", "60 196.8997802734375\n", "61 203.0827178955078\n", "62 210.37013244628906\n", "63 233.3351287841797\n", "64 223.5889129638672\n", "65 227.52284240722656\n", "66 199.73655700683594\n", "67 195.20997619628906\n", "68 192.6437225341797\n", "69 197.1069793701172\n", "70 223.88955688476562\n", "71 250.8140106201172\n", "72 302.1340026855469\n", "73 234.83021545410156\n", "74 192.74407958984375\n", "75 184.17343139648438\n", "76 178.7435760498047\n", "77 175.32286071777344\n", "78 173.30526733398438\n", "79 173.79928588867188\n", "80 178.83067321777344\n", "81 195.47805786132812\n", "82 224.01499938964844\n", "83 254.59860229492188\n", "84 197.6997528076172\n", "85 184.15213012695312\n", "86 178.74087524414062\n", "87 177.8366241455078\n", "88 196.80772399902344\n", "89 
216.80592346191406\n", "90 267.1333312988281\n", "91 198.99403381347656\n", "92 174.82217407226562\n", "93 172.68698120117188\n", "94 166.46255493164062\n", "95 171.16470336914062\n", "96 170.53395080566406\n", "97 183.16978454589844\n", "98 180.73060607910156\n", "99 191.34986877441406\n", "100 179.35696411132812\n", "101 175.2283172607422\n", "102 172.541748046875\n", "103 171.90914916992188\n", "104 184.98558044433594\n", "105 205.13856506347656\n", "106 205.8452911376953\n", "107 224.54994201660156\n", "108 158.3968505859375\n", "109 156.24256896972656\n", "110 147.74032592773438\n", "111 150.6797637939453\n", "112 150.395751953125\n", "113 159.10755920410156\n", "114 170.22654724121094\n", "115 180.6762237548828\n", "116 191.1690216064453\n", "117 163.69064331054688\n", "118 163.85519409179688\n", "119 139.80825805664062\n", "120 140.6570587158203\n", "121 144.08004760742188\n", "122 158.80630493164062\n", "123 188.1263885498047\n", "124 195.49148559570312\n", "125 192.2193603515625\n", "126 153.9271240234375\n", "127 132.1674041748047\n", "128 124.47053527832031\n", "129 116.3742904663086\n", "130 114.36947631835938\n", "131 110.64471435546875\n", "132 111.86954498291016\n", "133 110.50808715820312\n", "134 116.09949493408203\n", "135 116.46121978759766\n", "136 128.56101989746094\n", "137 129.96461486816406\n", "138 145.54727172851562\n", "139 147.5155487060547\n", "140 151.8369140625\n", "141 163.9568634033203\n", "142 143.47850036621094\n", "143 166.7104949951172\n", "144 153.11297607421875\n", "145 158.97427368164062\n", "146 162.9993896484375\n", "147 142.5881805419922\n", "148 129.16439819335938\n", "149 110.0616683959961\n", "150 99.67420196533203\n", "151 91.19140625\n", "152 85.92176055908203\n", "153 82.2587890625\n", "154 79.37641143798828\n", "155 77.57473754882812\n", "156 75.71873474121094\n", "157 74.76695251464844\n", "158 73.61392974853516\n", "159 73.7373046875\n", "160 73.32074737548828\n", "161 74.79610443115234\n", "162 
75.15473937988281\n", "163 78.76358032226562\n", "164 80.19961547851562\n", "165 87.09074401855469\n", "166 91.6589126586914\n", "167 105.99201965332031\n", "168 132.52023315429688\n", "169 199.60317993164062\n", "170 285.26385498046875\n", "171 186.7301025390625\n", "172 85.42977905273438\n", "173 71.33895111083984\n", "174 66.64602661132812\n", "175 63.440467834472656\n", "176 60.96210479736328\n", "177 58.90080261230469\n", "178 57.13398361206055\n", "179 55.62022018432617\n", "180 54.284812927246094\n", "181 53.065895080566406\n", "182 51.94929122924805\n", "183 50.92884063720703\n", "184 50.01615524291992\n", "185 49.2723388671875\n", "186 48.88428497314453\n", "187 49.18997573852539\n", "188 51.66621780395508\n", "189 56.960655212402344\n", "190 74.84833526611328\n", "191 96.08964538574219\n", "192 175.0045623779297\n", "193 258.491455078125\n", "194 184.846435546875\n", "195 83.15020751953125\n", "196 69.83650207519531\n", "197 51.21598815917969\n", "198 48.549232482910156\n", "199 46.50739669799805\n", "200 45.08198547363281\n", "201 43.89152145385742\n", "202 42.86724853515625\n", "203 41.932674407958984\n", "204 41.0839958190918\n", "205 40.29037094116211\n", "206 39.5553092956543\n", "207 38.86370086669922\n", "208 38.21051025390625\n", "209 37.590450286865234\n", "210 37.0007438659668\n", "211 36.43656921386719\n", "212 35.89463806152344\n", "213 35.36964416503906\n", "214 34.8619499206543\n", "215 34.37443542480469\n", "216 33.90838623046875\n", "217 33.463294982910156\n", "218 33.03306579589844\n", "219 32.62254333496094\n", "220 32.2387580871582\n", "221 31.8708438873291\n", "222 31.53622055053711\n", "223 31.22540855407715\n", "224 30.986581802368164\n", "225 30.75267219543457\n", "226 30.672260284423828\n", "227 30.66810417175293\n", "228 31.07099723815918\n", "229 31.388729095458984\n", "230 32.891693115234375\n", "231 33.72759246826172\n", "232 37.617942810058594\n", "233 38.177757263183594\n", "234 45.63814163208008\n", "235 
43.17341613769531\n", "236 52.84267807006836\n", "237 45.67348861694336\n", "238 53.69830322265625\n", "239 43.98344039916992\n", "240 47.329803466796875\n", "241 39.6602668762207\n", "242 39.626487731933594\n", "243 35.451988220214844\n", "244 34.44646072387695\n", "245 32.15733337402344\n", "246 31.359241485595703\n", "247 29.95440101623535\n", "248 29.509944915771484\n", "249 28.538223266601562\n", "250 28.30735206604004\n", "251 27.50522804260254\n", "252 27.50213623046875\n", "253 26.80497932434082\n", "254 26.98086929321289\n", "255 26.31241226196289\n", "256 26.672256469726562\n", "257 25.982328414916992\n", "258 26.511144638061523\n", "259 25.812694549560547\n", "260 26.602962493896484\n", "261 25.768213272094727\n", "262 26.81290054321289\n", "263 25.837427139282227\n", "264 27.06724739074707\n", "265 25.927724838256836\n", "266 27.30327606201172\n", "267 25.987638473510742\n", "268 27.363279342651367\n", "269 26.01933479309082\n", "270 27.32394027709961\n", "271 25.930755615234375\n", "272 27.011472702026367\n", "273 25.699975967407227\n", "274 26.50739097595215\n", "275 25.408933639526367\n", "276 25.86797332763672\n", "277 25.097475051879883\n", "278 25.256912231445312\n", "279 24.849056243896484\n", "280 24.65272331237793\n", "281 24.661697387695312\n", "282 24.140518188476562\n", "283 24.619226455688477\n", "284 23.73824119567871\n", "285 24.717910766601562\n", "286 23.54086685180664\n", "287 24.9752197265625\n", "288 23.630443572998047\n", "289 25.48272132873535\n", "290 24.02497100830078\n", "291 26.01576805114746\n", "292 24.780500411987305\n", "293 26.469703674316406\n", "294 25.79985809326172\n", "295 26.6370906829834\n", "296 26.846729278564453\n", "297 26.385406494140625\n", "298 27.688945770263672\n", "299 25.86838150024414\n", "300 28.163406372070312\n", "301 25.150251388549805\n", "302 28.116533279418945\n", "303 24.517921447753906\n", "304 27.596027374267578\n", "305 23.859285354614258\n", "306 26.56470489501953\n", "307 
23.3745059967041\n", "308 25.282028198242188\n", "309 22.86705207824707\n", "310 23.85765266418457\n", "311 22.422100067138672\n", "312 22.59556770324707\n", "313 22.051166534423828\n", "314 21.59400177001953\n", "315 21.73813247680664\n", "316 20.74985122680664\n", "317 21.347631454467773\n", "318 20.129060745239258\n", "319 20.965091705322266\n", "320 19.616971969604492\n", "321 20.59461212158203\n", "322 19.323686599731445\n", "323 20.284162521362305\n", "324 19.114233016967773\n", "325 20.032251358032227\n", "326 18.95831871032715\n", "327 19.813901901245117\n", "328 18.809864044189453\n", "329 19.587026596069336\n", "330 18.633268356323242\n", "331 19.3393611907959\n", "332 18.45094108581543\n", "333 19.10856056213379\n", "334 18.29006004333496\n", "335 18.963394165039062\n", "336 18.093769073486328\n", "337 18.76136016845703\n", "338 17.90596580505371\n", "339 18.58025360107422\n", "340 17.717267990112305\n", "341 18.402299880981445\n", "342 17.53641700744629\n", "343 18.24544906616211\n", "344 17.375978469848633\n", "345 18.122867584228516\n", "346 17.215097427368164\n", "347 18.01441192626953\n", "348 17.10610008239746\n", "349 17.929285049438477\n", "350 17.017274856567383\n", "351 17.8439884185791\n", "352 16.94769859313965\n", "353 17.71332550048828\n", "354 16.843292236328125\n", "355 17.566287994384766\n", "356 16.777320861816406\n", "357 17.416345596313477\n", "358 16.69339942932129\n", "359 17.225542068481445\n", "360 16.65900230407715\n", "361 17.003944396972656\n", "362 16.512805938720703\n", "363 16.707984924316406\n", "364 16.380939483642578\n", "365 16.481826782226562\n", "366 16.23741340637207\n", "367 16.224605560302734\n", "368 16.116172790527344\n", "369 16.01961326599121\n", "370 16.0587215423584\n", "371 15.799931526184082\n", "372 15.911197662353516\n", "373 15.567693710327148\n", "374 15.816048622131348\n", "375 15.395733833312988\n", "376 15.712669372558594\n", "377 15.229640007019043\n", "378 15.624187469482422\n", "379 
15.078832626342773\n", "380 15.593881607055664\n", "381 14.984262466430664\n", "382 15.602352142333984\n", "383 14.913544654846191\n", "384 15.634354591369629\n", "385 14.903152465820312\n", "386 15.721226692199707\n", "387 14.883771896362305\n", "388 15.738739013671875\n", "389 14.934103012084961\n", "390 15.740025520324707\n", "391 15.051740646362305\n", "392 15.747553825378418\n", "393 15.295642852783203\n", "394 15.794942855834961\n", "395 15.707908630371094\n", "396 15.878894805908203\n", "397 16.251039505004883\n", "398 15.95378303527832\n", "399 16.912405014038086\n", "400 16.121116638183594\n", "401 17.57708740234375\n", "402 16.237180709838867\n", "403 18.057422637939453\n", "404 16.25248146057129\n", "405 18.1812801361084\n", "406 16.16710662841797\n", "407 18.015974044799805\n", "408 16.019521713256836\n", "409 17.67422103881836\n", "410 15.807389259338379\n", "411 17.291345596313477\n", "412 15.551980018615723\n", "413 16.870981216430664\n", "414 15.28683090209961\n", "415 16.45754051208496\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "416 15.046100616455078\n", "417 16.132770538330078\n", "418 14.80198860168457\n", "419 15.807101249694824\n", "420 14.587957382202148\n", "421 15.535334587097168\n", "422 14.381044387817383\n", "423 15.229350090026855\n", "424 14.211679458618164\n", "425 14.969184875488281\n", "426 14.052995681762695\n", "427 14.728175163269043\n", "428 13.93752670288086\n", "429 14.523200988769531\n", "430 13.845670700073242\n", "431 14.32778549194336\n", "432 13.768282890319824\n", "433 14.132892608642578\n", "434 13.77458381652832\n", "435 14.004599571228027\n", "436 13.83255386352539\n", "437 13.862512588500977\n", "438 13.929853439331055\n", "439 13.736422538757324\n", "440 14.140033721923828\n", "441 13.69277572631836\n", "442 14.489221572875977\n", "443 13.784162521362305\n", "444 15.030045509338379\n", "445 14.053020477294922\n", "446 15.616470336914062\n", "447 14.566765785217285\n", "448 16.124311447143555\n", 
"449 15.121051788330078\n", "450 16.293582916259766\n", "451 15.538890838623047\n", "452 16.134872436523438\n", "453 15.629945755004883\n", "454 15.781015396118164\n", "455 15.652036666870117\n", "456 15.342903137207031\n", "457 15.427617073059082\n", "458 14.877690315246582\n", "459 14.979843139648438\n", "460 14.347681045532227\n", "461 14.468661308288574\n", "462 13.883293151855469\n", "463 14.067682266235352\n", "464 13.53783130645752\n", "465 13.715188026428223\n", "466 13.294391632080078\n", "467 13.479684829711914\n", "468 13.086114883422852\n", "469 13.345418930053711\n", "470 12.981571197509766\n", "471 13.237014770507812\n", "472 12.89077091217041\n", "473 13.044453620910645\n", "474 12.756986618041992\n", "475 13.012262344360352\n", "476 12.732881546020508\n", "477 12.824119567871094\n", "478 12.588905334472656\n", "479 12.721162796020508\n", "480 12.548263549804688\n", "481 12.599344253540039\n", "482 12.476131439208984\n", "483 12.607063293457031\n", "484 12.505374908447266\n", "485 12.477144241333008\n", "486 12.475809097290039\n", "487 12.60508918762207\n", "488 12.559510231018066\n", "489 12.446664810180664\n", "490 12.48794937133789\n", "491 12.326469421386719\n", "492 12.461620330810547\n", "493 12.414360046386719\n", "494 12.549243927001953\n", "495 12.268865585327148\n", "496 12.52327823638916\n", "497 12.245977401733398\n", "498 12.541507720947266\n", "499 12.180692672729492\n", "500 12.569136619567871\n", "501 12.120124816894531\n", "502 12.593598365783691\n", "503 12.06467056274414\n", "504 12.618692398071289\n", "505 12.018089294433594\n", "506 12.653897285461426\n", "507 11.985318183898926\n", "508 12.665456771850586\n", "509 11.875547409057617\n", "510 12.63729476928711\n", "511 11.951704025268555\n", "512 12.78012466430664\n", "513 11.966864585876465\n", "514 12.805120468139648\n", "515 11.917511940002441\n", "516 12.759134292602539\n", "517 11.91922664642334\n", "518 12.74145221710205\n", "519 11.870694160461426\n", "520 
12.648548126220703\n", "521 11.858591079711914\n", "522 12.608646392822266\n", "523 11.837157249450684\n", "524 12.503793716430664\n", "525 11.827508926391602\n", "526 12.392881393432617\n", "527 11.836199760437012\n", "528 12.351949691772461\n", "529 11.827587127685547\n", "530 12.291703224182129\n", "531 11.845207214355469\n", "532 12.238033294677734\n", "533 11.851987838745117\n", "534 12.129100799560547\n", "535 11.895650863647461\n", "536 12.006834983825684\n", "537 12.071447372436523\n", "538 11.862838745117188\n", "539 12.36300277709961\n", "540 11.838622093200684\n", "541 12.923316955566406\n", "542 12.066741943359375\n", "543 13.796222686767578\n", "544 12.682493209838867\n", "545 14.618804931640625\n", "546 13.567913055419922\n", "547 15.201546669006348\n", "548 14.514939308166504\n", "549 15.157068252563477\n", "550 14.894901275634766\n", "551 14.82116985321045\n", "552 14.973808288574219\n", "553 14.36568832397461\n", "554 14.864335060119629\n", "555 13.79199504852295\n", "556 14.384416580200195\n", "557 13.285518646240234\n", "558 13.8139009475708\n", "559 12.818103790283203\n", "560 13.309890747070312\n", "561 12.457561492919922\n", "562 12.768548011779785\n", "563 12.109518051147461\n", "564 12.345666885375977\n", "565 11.838518142700195\n", "566 11.98867416381836\n", "567 11.639802932739258\n", "568 11.519742965698242\n", "569 11.370664596557617\n", "570 11.15427017211914\n", "571 11.204188346862793\n", "572 10.994608879089355\n", "573 11.155523300170898\n", "574 10.941986083984375\n", "575 11.204839706420898\n", "576 10.84663200378418\n", "577 11.213382720947266\n", "578 10.813859939575195\n", "579 11.298030853271484\n", "580 10.721686363220215\n", "581 11.294883728027344\n", "582 10.642816543579102\n", "583 11.290977478027344\n", "584 10.621570587158203\n", "585 11.330678939819336\n", "586 10.58786392211914\n", "587 11.338900566101074\n", "588 10.599321365356445\n", "589 11.343742370605469\n", "590 10.636480331420898\n", "591 
11.348445892333984\n", "592 10.661396980285645\n", "593 11.305547714233398\n", "594 10.69072437286377\n", "595 11.231929779052734\n", "596 10.729141235351562\n", "597 11.120464324951172\n", "598 10.836509704589844\n", "599 11.02585220336914\n", "600 11.147380828857422\n", "601 11.012079238891602\n", "602 11.700508117675781\n", "603 11.138029098510742\n", "604 12.44657039642334\n", "605 11.493819236755371\n", "606 13.279351234436035\n", "607 11.980254173278809\n", "608 13.828882217407227\n", "609 12.335530281066895\n", "610 13.890178680419922\n", "611 12.377365112304688\n", "612 13.672966003417969\n", "613 12.314437866210938\n", "614 13.351038932800293\n", "615 12.199349403381348\n", "616 13.001585960388184\n", "617 11.737051010131836\n", "618 12.477154731750488\n", "619 11.315387725830078\n", "620 12.10653305053711\n", "621 11.122674942016602\n", "622 11.923060417175293\n", "623 11.009233474731445\n", "624 11.797826766967773\n", "625 10.912956237792969\n", "626 11.667257308959961\n", "627 10.87093734741211\n", "628 11.612075805664062\n", "629 10.753063201904297\n", "630 11.500226020812988\n", "631 10.666036605834961\n", "632 11.389327049255371\n", "633 10.641620635986328\n", "634 11.34609603881836\n", "635 10.538703918457031\n", "636 11.240482330322266\n", "637 10.502617835998535\n", "638 11.186229705810547\n", "639 10.399940490722656\n", "640 11.07362174987793\n", "641 10.29592514038086\n", "642 11.001121520996094\n", "643 10.407917022705078\n", "644 11.066526412963867\n", "645 10.300006866455078\n", "646 10.924904823303223\n", "647 10.182451248168945\n", "648 10.847217559814453\n", "649 10.238099098205566\n", "650 10.836588859558105\n", "651 10.119577407836914\n", "652 10.745954513549805\n", "653 10.167619705200195\n", "654 10.740652084350586\n", "655 10.065475463867188\n", "656 10.654346466064453\n", "657 10.109355926513672\n", "658 10.62813949584961\n", "659 9.988624572753906\n", "660 10.55061149597168\n", "661 10.056044578552246\n", "662 10.53543472290039\n", 
"663 9.937051773071289\n", "664 10.44735050201416\n", "665 10.042400360107422\n", "666 10.486113548278809\n", "667 9.922344207763672\n", "668 10.333057403564453\n", "669 9.806156158447266\n", "670 10.256221771240234\n", "671 9.865896224975586\n", "672 10.259897232055664\n", "673 9.773486137390137\n", "674 10.196406364440918\n", "675 9.832138061523438\n", "676 10.19542121887207\n", "677 9.742813110351562\n", "678 10.134439468383789\n", "679 9.804479598999023\n", "680 10.11970329284668\n", "681 9.704248428344727\n", "682 10.07077407836914\n", "683 9.851264953613281\n", "684 10.099723815917969\n", "685 9.732878684997559\n", "686 9.96037483215332\n", "687 9.626123428344727\n", "688 9.859709739685059\n", "689 9.604263305664062\n", "690 9.819677352905273\n", "691 9.51850700378418\n", "692 9.78436279296875\n", "693 9.615999221801758\n", "694 9.804388999938965\n", "695 9.610206604003906\n", "696 9.761518478393555\n", "697 9.523723602294922\n", "698 9.675700187683105\n", "699 9.515082359313965\n", "700 9.639348983764648\n", "701 9.448809623718262\n", "702 9.563665390014648\n", "703 9.449939727783203\n", "704 9.554956436157227\n", "705 9.402902603149414\n", "706 9.48503303527832\n", "707 9.410470962524414\n", "708 9.466278076171875\n", "709 9.42030143737793\n", "710 9.462084770202637\n", "711 9.388456344604492\n", "712 9.390506744384766\n", "713 9.411766052246094\n", "714 9.380949020385742\n", "715 9.453214645385742\n", "716 9.358305931091309\n", "717 9.496848106384277\n", "718 9.340471267700195\n", "719 9.500185012817383\n", "720 9.268125534057617\n", "721 9.6624174118042\n", "722 9.285665512084961\n", "723 9.766561508178711\n", "724 9.265359878540039\n", "725 9.91065788269043\n", "726 9.266630172729492\n", "727 10.196619033813477\n", "728 9.33184814453125\n", "729 10.31056022644043\n", "730 9.4459228515625\n", "731 10.629362106323242\n", "732 9.821006774902344\n", "733 11.032565116882324\n", "734 10.545709609985352\n", "735 11.271682739257812\n", "736 
11.432579040527344\n", "737 11.432348251342773\n", "738 12.460997581481934\n", "739 11.330101013183594\n", "740 13.059946060180664\n", "741 11.198677062988281\n", "742 13.333955764770508\n", "743 10.929497718811035\n", "744 13.135268211364746\n", "745 10.77116584777832\n", "746 12.882686614990234\n", "747 10.58820915222168\n", "748 12.530218124389648\n", "749 10.332746505737305\n", "750 12.048186302185059\n", "751 10.15908432006836\n", "752 11.673738479614258\n", "753 9.969121932983398\n", "754 11.312966346740723\n", "755 9.854571342468262\n", "756 11.113723754882812\n", "757 9.780363082885742\n", "758 10.999980926513672\n", "759 9.710268020629883\n", "760 10.88806438446045\n", "761 9.661212921142578\n", "762 10.785196304321289\n", "763 9.597941398620605\n", "764 10.69247055053711\n", "765 9.532906532287598\n", "766 10.603074073791504\n", "767 9.484548568725586\n", "768 10.532205581665039\n", "769 9.43923568725586\n", "770 10.460396766662598\n", "771 9.406749725341797\n", "772 10.359094619750977\n", "773 9.35378646850586\n", "774 10.313379287719727\n", "775 9.339227676391602\n", "776 10.270576477050781\n", "777 9.321992874145508\n", "778 10.17398452758789\n", "779 9.28928279876709\n", "780 10.135478973388672\n", "781 9.282571792602539\n", "782 10.083499908447266\n", "783 9.312825202941895\n", "784 10.074797630310059\n", "785 9.29830551147461\n", "786 9.953083038330078\n", "787 9.271284103393555\n", "788 9.887079238891602\n", "789 9.270550727844238\n", "790 9.776187896728516\n", "791 9.274053573608398\n", "792 9.742859840393066\n", "793 9.29629135131836\n", "794 9.691925048828125\n", "795 9.311891555786133\n", "796 9.628150939941406\n", "797 9.317649841308594\n", "798 9.54599666595459\n", "799 9.326900482177734\n", "800 9.46576976776123\n", "801 9.333112716674805\n", "802 9.316043853759766\n", "803 9.245628356933594\n", "804 9.174911499023438\n", "805 9.130691528320312\n", "806 9.026357650756836\n", "807 9.047698974609375\n", "808 8.901134490966797\n", "809 
8.962682723999023\n", "810 8.791932106018066\n", "811 8.87315559387207\n", "812 8.687568664550781\n", "813 8.800554275512695\n", "814 8.623148918151855\n", "815 8.716794967651367\n", "816 8.543046951293945\n", "817 8.56474494934082\n", "818 8.397558212280273\n", "819 8.531536102294922\n", "820 8.37232780456543\n", "821 8.428328514099121\n", "822 8.274665832519531\n", "823 8.431962966918945\n", "824 8.270166397094727\n", "825 8.379289627075195\n", "826 8.207696914672852\n", "827 8.401199340820312\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "828 8.20716381072998\n", "829 8.469131469726562\n", "830 8.236584663391113\n", "831 8.424156188964844\n", "832 8.168313980102539\n", "833 8.484416961669922\n", "834 8.172074317932129\n", "835 8.536100387573242\n", "836 8.157014846801758\n", "837 8.581727027893066\n", "838 8.13553524017334\n", "839 8.696174621582031\n", "840 8.188529014587402\n", "841 8.957174301147461\n", "842 8.366853713989258\n", "843 9.087259292602539\n", "844 8.61898422241211\n", "845 9.600666046142578\n", "846 9.278848648071289\n", "847 9.91394329071045\n", "848 10.14834976196289\n", "849 10.423606872558594\n", "850 11.416866302490234\n", "851 10.418106079101562\n", "852 12.281963348388672\n", "853 10.481473922729492\n", "854 12.87403678894043\n", "855 10.396160125732422\n", "856 12.916685104370117\n", "857 10.285785675048828\n", "858 12.771074295043945\n", "859 10.352560997009277\n", "860 12.81513500213623\n", "861 10.118997573852539\n", "862 12.345061302185059\n", "863 10.030735969543457\n", "864 12.146442413330078\n", "865 9.814250946044922\n", "866 11.731513977050781\n", "867 9.791189193725586\n", "868 11.683649063110352\n", "869 9.596960067749023\n", "870 11.304178237915039\n", "871 9.476587295532227\n", "872 11.074976921081543\n", "873 9.46214485168457\n", "874 11.042204856872559\n", "875 9.287127494812012\n", "876 10.719624519348145\n", "877 9.23538589477539\n", "878 10.631698608398438\n", "879 9.11007308959961\n", "880 
10.393648147583008\n", "881 9.10506820678711\n", "882 10.377655029296875\n", "883 8.979528427124023\n", "884 10.174203872680664\n", "885 8.9866943359375\n", "886 10.169689178466797\n", "887 8.958577156066895\n", "888 10.132844924926758\n", "889 8.853513717651367\n", "890 9.939226150512695\n", "891 8.794721603393555\n", "892 9.840280532836914\n", "893 8.682388305664062\n", "894 9.676387786865234\n", "895 8.69548225402832\n", "896 9.683582305908203\n", "897 8.684111595153809\n", "898 9.66401195526123\n", "899 8.594156265258789\n", "900 9.512399673461914\n", "901 8.573217391967773\n", "902 9.489818572998047\n", "903 8.463871002197266\n", "904 9.332828521728516\n", "905 8.482060432434082\n", "906 9.340717315673828\n", "907 8.466593742370605\n", "908 9.309955596923828\n", "909 8.376811981201172\n", "910 9.1652250289917\n", "911 8.364412307739258\n", "912 9.147921562194824\n", "913 8.264249801635742\n", "914 9.025073051452637\n", "915 8.291736602783203\n", "916 9.049306869506836\n", "917 8.298337936401367\n", "918 9.050772666931152\n", "919 8.218606948852539\n", "920 8.925788879394531\n", "921 8.221832275390625\n", "922 8.929916381835938\n", "923 8.125205993652344\n", "924 8.825401306152344\n", "925 8.168115615844727\n", "926 8.862433433532715\n", "927 8.189055442810059\n", "928 8.871030807495117\n", "929 8.194257736206055\n", "930 8.863895416259766\n", "931 8.089027404785156\n", "932 8.718398094177246\n", "933 8.125476837158203\n", "934 8.73470687866211\n", "935 8.16848373413086\n", "936 8.743915557861328\n", "937 8.072606086730957\n", "938 8.597414016723633\n", "939 8.140182495117188\n", "940 8.633056640625\n", "941 8.234619140625\n", "942 8.650079727172852\n", "943 8.373706817626953\n", "944 8.667745590209961\n", "945 8.419292449951172\n", "946 8.564188003540039\n", "947 8.71246337890625\n", "948 8.499603271484375\n", "949 8.754026412963867\n", "950 8.378808975219727\n", "951 8.707267761230469\n", "952 8.223587036132812\n", "953 8.973301887512207\n", "954 
8.249604225158691\n", "955 9.120798110961914\n", "956 8.240376472473145\n", "957 9.567197799682617\n", "958 8.391382217407227\n", "959 9.934947967529297\n", "960 8.551895141601562\n", "961 9.747421264648438\n", "962 8.641094207763672\n", "963 9.514118194580078\n", "964 8.857434272766113\n", "965 9.758781433105469\n", "966 9.328777313232422\n", "967 9.747819900512695\n", "968 9.824708938598633\n", "969 9.374284744262695\n", "970 10.072290420532227\n", "971 9.379064559936523\n", "972 10.451343536376953\n", "973 9.188913345336914\n", "974 10.514688491821289\n", "975 9.137795448303223\n", "976 10.573251724243164\n", "977 8.986966133117676\n", "978 10.422307968139648\n", "979 8.891181945800781\n", "980 10.281803131103516\n", "981 8.728313446044922\n", "982 10.022692680358887\n", "983 8.668171882629395\n", "984 9.946191787719727\n", "985 8.618570327758789\n", "986 9.84320068359375\n", "987 8.564799308776855\n", "988 9.743132591247559\n", "989 8.508552551269531\n", "990 9.638229370117188\n", "991 8.444400787353516\n", "992 9.526444435119629\n", "993 8.40636157989502\n", "994 9.446041107177734\n", "995 8.32719898223877\n", "996 9.311468124389648\n", "997 8.256400108337402\n", "998 9.157243728637695\n", "999 8.188175201416016\n" ] } ], "source": [ "# D_in: input features, H: hidden width, D_out: number of classes (N is not used in this cell)\n", "N, D_in, H, D_out = 64, 2, 100, 3\n", "\n", "better_net = torch.nn.Sequential(\n", "    torch.nn.Linear(D_in, H),\n", "    torch.nn.ReLU(),\n", "    torch.nn.Linear(H, H),\n", "    torch.nn.ReLU(),\n", "    torch.nn.Linear(H, H),\n", "    torch.nn.ReLU(),\n", "    torch.nn.Linear(H, D_out),\n", ")\n", "\n", "# reduction='sum' replaces the deprecated size_average=False: losses are summed over all samples\n", "loss_fn = torch.nn.CrossEntropyLoss(reduction='sum')\n", "\n", "learning_rate = 1e-3\n", "optimizer = torch.optim.SGD(better_net.parameters(), lr=learning_rate)\n", "for t in range(1000):\n", "    # forward pass\n", "    y_pred = better_net(X)\n", "\n", "    # loss\n", "    loss = loss_fn(y_pred, y)\n", "    print('{} {}'.format(t, loss.item()))\n", "\n", "    # zero the gradients (so nothing is left over from the previous step)\n", "    optimizer.zero_grad()\n", "\n",
"    # backward pass\n", "    loss.backward()\n", "\n", "    # update the weights\n", "    optimizer.step()" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "h = 0.02\n", "x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1\n", "y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1\n", "\n", "xx, yy = np.meshgrid(np.arange(x_min, x_max, h),\n", "                     np.arange(y_min, y_max, h))\n", "grid_tensor = torch.FloatTensor(np.c_[xx.ravel(), yy.ravel()])\n", "\n", "# no gradients are needed to draw the decision regions\n", "with torch.no_grad():\n", "    Z = better_net(grid_tensor)\n", "Z = Z.numpy()\n", "Z = np.argmax(Z, axis=1)\n", "Z = Z.reshape(xx.shape)\n", "\n", "plt.figure(figsize=(10, 8))\n", "\n", "plt.contourf(xx, yy, Z, cmap=plt.cm.rainbow, alpha=0.3)\n", "plt.scatter(X[:, 0], X[:, 1], c=y, s=40, cmap=plt.cm.rainbow)\n", "\n", "plt.xlim(xx.min(), xx.max())\n", "plt.ylim(yy.min(), yy.max())\n", "\n", "plt.title('Spiral', fontsize=15)\n", "plt.xlabel('$x$', fontsize=14)\n", "plt.ylabel('$y$', fontsize=14)\n", "plt.show();" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Work out on your own what was changed to improve the model's quality (and *take note of it*)." ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "6N_9wfvPYJeK" }, "source": [ "---" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "H_thmQJOYJeK" }, "source": [ "

Useful links

" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "RpSrLf9FYJeL" }, "source": [ "1). *Examples of building neural networks in PyTorch (official tutorials, in English): https://pytorch.org/tutorials/beginner/pytorch_with_examples.html#examples \n", "https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html*\n", "\n", "2). Stanford University course CS231n: http://cs231n.github.io/\n", "\n", "3). Near-exhaustive coverage of neural network fundamentals (from CS231n, in English): \n", "\n", "http://cs231n.github.io/neural-networks-1/, \n", "http://cs231n.github.io/neural-networks-2/, \n", "http://cs231n.github.io/neural-networks-3/, \n", "http://cs231n.github.io/neural-networks-case-study/#linear\n", "\n", "4). *Good articles on neural network fundamentals (in English): http://neuralnetworksanddeeplearning.com/chap1.html*\n", "\n", "5). *A visual demonstration of how neural networks learn: https://cs.stanford.edu/people/karpathy/convnetjs/*" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "1Qldb1U5YJeM" }, "source": [ "6). *More on backpropagation, an article on Medium: https://medium.com/autonomous-agents/backpropagation-how-neural-networks-learn-complex-behaviors-9572ac161670*" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "g1HSslRhYJeN" }, "source": [ "7). *An online text on backpropagation: http://page.mi.fu-berlin.de/rojas/neural/chapter/K7.pdf*" ] } ], "metadata": { "colab": { "name": "[seminar]mlp_pytorch.ipynb", "provenance": [], "version": "0.3.2" }, "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.5" } }, "nbformat": 4, "nbformat_minor": 1 }