{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Using Machine Learning to Classify Handwritten Digits\n", "\n", "Things we know about the data:\n", "#### - We have 10 digits.\n", "#### - The images are grayscale (a single channel, no color).\n", "#### - Image size is 28 pixels by 28 pixels.\n", "\n", "![Handwritten Digits](resources/MnistExamples.png \"Handwritten Digits Example\")

Figure 1: Handwritten Digits Example

\n", "\n", "#### Let us start by downloading the data.\n", "\n", "Set `USERNAME` in the cell below to the username you created when registering for SciServer.\n", "\n", "For example, my username is `'adi'`, so I will type `USERNAME = 'adi'`." ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [ "import torch\n", "import torchvision\n", "\n", "# change USERNAME to be your SciServer username.\n", "USERNAME = 'adi'\n", "\n", "# DO NOT CHANGE!\n", "DOWNLOAD_FOLDER = '/home/idies/workspace/Temporary/' + USERNAME + '/scratch'\n", "\n", "# MNIST is the name of the handwritten dataset.\n", "dataset = torchvision.datasets.MNIST(DOWNLOAD_FOLDER, download=True, transform=torchvision.transforms.ToTensor())" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now that the dataset is downloaded, we can examine some of the images.\n", "\n", "To do this, we will need to import a visualization library. We will use `matplotlib` in this tutorial." ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "torch.Size([10, 1, 28, 28])\n" ] }, { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAagAAACzCAYAAADc1IgsAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjMsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+AADFEAAAgAElEQVR4nO3deXQV9f3/8edHVhHqQgK4hFRBoSxFC6dGEY24YLUIsSQoHKzYotQqsklFUakoWkEQOVrqBpXaSkCB4lFJPV83iiDgrxWkSqmyKMhiRXayfX5/DDO9uQlwk8zcmXvzepwzh5u5kzufvLjJ+85nPvMZY61FREQkao4LuwEiIiJVUYESEZFIUoESEZFIUoESEZFIUoESEZFIUoESEZFIUoESEZFI8rVAGWNuM8Z8YYw5aIxZZYzp4efr10XK1H/K1H/KNBh1PVffCpQxpj8wDZgInAcsBd4wxrT2ax91jTL1nzL1nzINhnIF49dMEsaY5cDH1tohMev+Dcyz1o71ZSd1jDL1nzL1nzINhnL16QjKGNMQ6AoUxT1VBFzoxz7qGmXqP2XqP2UaDOXqqO/T62QA9YBtceu3AZfHb2yMuQW45fCXXX1qQ1B2WmszQ9ivMvWfMvWfMg1Gwrmmc6Z+FShXfH+hqWId1tpngGcAjDFRn612Y8j7V6b+U6b+U6bBOGau6ZypX4MkdgJlQKu49S2o/AlAEqNM/adM/adMg6Fc8alAWWuLgVXAFXFPXYEz8iQw3bt3Z/ny5Sxfvpzy8vIKS8uWLYPcdaDCzDRdKVP/KdNgKFeHn118U4DZxpgPgb8DQ4HTgBk+7qOuUab+U6b+U6bBUK7WWt8W4DZgA3AIp/pfnMD32JosAwcOtAMHDrTffvutLSsr85Z9+/Z5jwcPHlyj145bVvqZUZQzTeJSJzOdOHGinThxov3kk09s//79lWn478NIZ1qTXCOQma+Z+nYdVE3V5KRes2bNWLZsGQA/+MEPWL16NQD33Xcfubm53HnnnQD06tWLt956q7ZNXGWt7VbbF0mmoE+U5uTkADBy5EjOOOMMAPr378/mzZsTfYk6mWlRkTNi+PLLnUFYxx3n60QudTLTgClT/1UrU83FJyIikeT3MPOkmDBhAu3btwdgw4YNXHXVVQBs3bqVdevWMWSIc+F1Kg+SiJqsrCwA5syZwwUXXOCtnzt3LkB1jp7ksAULFgBw2223sWXLlpBbI+LIzs4GYOrUqfTp04eJEycC8Oijj7Jv376ktiWlCpQb3ODBg711b7/9Nlu3bvW+/vTTT7npppsA2LatzozGDExWVlalovTBBx8AThef29UqR9eiRQvv/evq06cPAPPmzeNPf/pTGM1KeU2bNgXgpptuokOHDnz/+98HoEePHuzYsQNwuqS3b98eVhNTSlZWFq+//joA7dq1w1rLPffcA8AJJ5zAyJEjk9qelCpQ06ZNA5xzUO+//z4Av/jFLypt98orryS1XekmKyuLfv36ATBixAiysrJUlGpp+/btbNzoXKN49tlnc/PNN/PGG28A8N1334XZtJTVtWtXHnvsMQB69uxZ6Xm3eA0cOJCpU6cmtW2ppkWLFgC8+eabtGvXrspthg4dyg9/+EMAOnXqxJo1awBYsmQJCxYs4OOPPwagvLzct3bpHJSIiERSSh1BuSMOrbWsW7cu5Nakn4KCAgAmT57snXPavHkzI0eO1CfQWurSpQtXXOFcc/nRRx8xd+7cpPfnp4srr7wScD7tG2MAKCsr45NPPuGTTz4B4OOPP/by3bBhQyjtTCV33XUXwBGPngAaNmxIbm6u97X7+NJLL2XcuHF07epMA/jPf/7Tt3alVIGS4BQUFDBnzhzv6ylTpgAwatSosJqUVg4cOMCePXsAOOmkk6hfX796NdGsWTOeffZZAIwxzJw5E4DHHnuMTz/9lJdeeglwuk1///vfh9bOVDJhwgTv3NLRLjtauHC
h9+Hg+OOPr/S8e07VzwKlLj4REYmklPkYV79+fe+kJ+jEst/cQRHgDB3XkZO/1q1b5w0uufzyy2ncuLHewzXQs2dPWrd2big7adIk7r77bsAZEDFx4kSvm7ply5Y6gjoGtxv/+uuv97pKY+3fv5/XXnuNhx9+GIA1a9Z4R/7t2rXzLpNo06YNu3fv5p133vG9jSlToE466aQKI3XcESSJys7OrvD9f//733UeK8aUKVO8GSLy8/O9Q/25c+cyb948CgsLw2yeCADnnXceBw8eBOC5555j9OjRAJx++ukMGzbM265t27Y0bNgQgOLi4uQ3NOIaN27Myy+/DMCZZ55Z4fz+2rVrAXjggQeYP39+he/r1KkT4FwPeeaZZ3rfs379elasWOF7O9XFJyIikZQyR1B79+71Tr6de+65dO/eHYBZs2Yd8Xvatm3L008/Dfxv/jPX9u3b6dWrF+DvSb1UtWzZMq/rpKCgwOvyy8/PJz8/n8mTJwPO1eUa0Sdhaty4MeCM4nM/xcfLzs6mQYMGgI6gqtKyZUvOP//8Sut37drFmDFjACffWI0bN/a6+9q0aeOt37t3L7Nnz+bAgQO+tzNlCtTBgwe9Cx27dOlCo0aNqtyuV69ejB071tvOPW+1atUqb+LYc845h7y8PO+C3g4dOuhNHKOwsLBCl547mwQ4XYHurBJuf78kxu3nr6q/XxLz1ltveZNBZ2Zmsnv3bgBmzJjB2rVr+d3vfgdomrNjOfXUUyt87Xbl/epXv/Jm4HCdcMIJAIwePdr7UB9r0qRJPPnkk8E0NALTySc8Vfu0adPstGnTKtxeIzs72zZq1Mg2atTI3nHHHfbAgQPeczt27LDjxo2z48aNq/RaRUVF3nbXXnttpKfcDzLT6i6FhYXWVVBQkLK3MQgj06KiIltUVGTLy8ttp06dQr2NQRSWmv6szZs3t82bN7eZmZne4j7n/k5ba+1ZZ51lzzrrLGVaxdKsWTM7ffp0O336dNu9e/ejbvvggw/aBx980JaWllZYZs+ebWfPnh1opjoHJSIi0ZQqFR+wbdu2tW3btrV79uzxPik9/fTTNiMjw2ZkZHjrFi9ebBcvXmy7det2xNd67rnnvO1nzJhRJz9F1XRxLV26VJ9Mq7HEHkE99thjfv+/1MlM45fYI6j8/Hybn5+vTGuxrF692rpie67mzZuXlExT5hwUwPr16wFn9ucbb7wRgAEDBvDuu+962+zatYtf//rXFbavyjXXXOM9dmfvleqJneFcquemm27yTkaLP9q3b6/zez4aP3482dnZ3uSv1loOHToE4A2aCpq6+EREJJJS6gjKNX78ePLy8gA48cQT+ctf/uI9d/LJJzN06FAA7yK+eC1btqRVq1beyL3//ve/Abc4PekmhdUTe8lDRkZGyK1JP61bt65wBBXEhaN1gTtit3fv3t7FzuCMhL7uuusA+Oqrr5LSlpQsUBs3bvSGkj/11FNu3ysbN25kypQp/OMf/6j0PY0aNeKMM84AnC5Cay1FRUWAcz8TSYw7PQqg+0JV065duwBn5u369et7M0hPmjQpzGaljdji9MUXXyTtj2g66devH7179wbwitOHH34IwLXXXltpCHrQ1MUnIiKRlJJHUIA3EeT3vvc9Jk6cCDhXj996663eRWPNmzf3tu/WrZt36+L333+fMWPGMGPGjCS3OvU9/vjj3mNNKFs97mSa7777Lj179iQzMxNw7gx72WWXAXh3iJXqi+2OWr9+PSUlJSG2JrX07dsXgBdeeKFCjgAvvvgiQNKPniCFC5RrypQp3lRFzzzzDB06dPAKjzHG6/7bs2ePNz3HV199pTdvNbmzRuTn5zN37lxA56Bq6o9//CM9e/b0ZkRo166dN32PClTNxU5n5s4aI8f28ssve7/f7t9LV+fOnb3JY8OQ8gWqpKTEmzPq3HPPZdCgQd6Muz169PDmh3r++ed1Z80aevzxx70bmm3evFlTHNXS559/zsG
DB72bvvXu3Zu//e1vIbcq9cXeBHLr1q0htiT6GjduzPTp0wG4+uqrvcIUO5T89ttv96aXC4vOQYmISDQlcGXyWGAFsBvYASwCOsVtM4vKVwwvC+tqcp8X368mj2qmOTk5Nicnx44YMcIWFBTYgoICu3TpUmuttZs2bbKbNm2yWVlZytSH5f777/eu0C8vL7cTJkywEyZMiFSmQefqZ55ZWVn2wIEDtri42BYXF9u2bdtG8n0alUznzp1bYV49N7etW7favn372r59+/r6fq9ppomEuRgYDHQCOgPzga+BU+LC/BvQKmY5JaEGBBNCaIEm+DNHMtOsrCyblZVlY23atMmOGDFCmYb/Pkx6pkHn6ufPf84553jv102bNinTYyxff/11hQLlTg8XtffpMc9BWWsrzK9ujBkEfAd0x6n8rkPW2q+P9XqiTIOgTIOhXP2nTBNXk0ESzXDOXX0bt/4iY8x2YBfwLnCvtXZ7LdtXV0QiU3dUXprMZxaJTNNQpHNdvnx5snfph6RnumbNGnJzcwG47777+MMf/uDHy/quJgVqGvAP4IOYdW8CrwJfAN8HHgL+zxjT1Vp7KP4FjDG3ALfUYN/pSpn6T5kGo1a5BpVpx44dAfjmm2/8fulkSHqm8XcYj6xq9p1OAbYAZx1ju9OAEuA69e0rU2Wa+pkGkaufP39eXp611toZM2Yc6/Y5yjSF3qcJH0EZY6YC1wOXWms/P9q21totxpgvgbMTff26SJn6T5kGI+q5zp8/P+W6pqOeaRQkVKCMMdNwgsy11n6awPYZwOlAIlfL7QT2Hf43CjKo2JbsIHaiTP2nTIMRYK5RyxQq5qpM/VHzTBM4BH0KZ7x+TyoOeWx6+PmmwGTgApy+0lycvtQvgWYJHuZG5s6VyWiLMlWmqZBpMnKNUqbJao8yrcb3JvDiR+pLHH/4+eNxxvVvB4qBjThj+LOi9KaIUluUqTJNhUyTkWuUMk1We5Rp4ksi10EdtWPXWnsA6HW0baQiZeo/ZRoM5eo/ZZq4qMzF90zYDYgRpbbURpR+jii1pTai9HNEqS21EbWfI2rtqYmo/Qw1bo85fAgmIiISKVE5ghIREalABUpERCJJBUpERCIp1AJljLnNGPOFMeagMWaVMaZHEvY51hizwhiz2xizwxizyBjTKW6bWcYYG7csC7ptfggj08P7TdtclWkw9Pvvv3TLNLQCZYzpjzNJ4kTgPGAp8IYxpnXAu84FngYuxLlQrhR4yxhzStx2bwGnxixXB9yuWgsxU0jTXJVpMPT777+0zDTEi7eWA8/Grfs38EiS29EUKAN6x6ybBbwW9gVuqZppOuWqTNM7V2Ua7UxDOYIyxjQEugJFcU8V4VThZDrqvViMMeuMMc8aY1okuV3VErFMIQ1yVabBiFiuytR/vmUaVhdfBlAP2Ba3fhvOnFTJdKR7sdwIXAaMAn6Mcy+WRkluW3VEKVNIj1yVaTCilKsy9Z9vmdbkhoV+ir9K2FSxLjDGmCnARcBF1toyr1HWvhyz2WpjzCqc+bCuwbmJWJSFmimkZa7KNBj6/fdfWmUa1hHUTpw+yvjK3oLKnwACYZx7sdwA9LQJ3IsFZybhKN+LJfRMIe1yVabBCD1XZeq/IDINpUBZa4uBVcAVcU9dgTPyJFDGuRfLAJwg/b5vUCjCzhTSL1dlGoywc1Wm/gss0xBHnPTHmUr+l8APcPot9wLZAe838PsG1bVM0zlXZZpeuSrT1Mo07EBvAzYAh3Cq/8VJ2Gfg9w2qa5mme67KNH1yVaaplalmMxcRkUjSXHwiIhJJKlAiIhJJKlAiIhJJKlAiIhJJKlAiIhJJKlAiIhJJKlAiIhJJKlAiIhJJKlAiIhJJKlAiIhJJKlAiIhJJKlAiIhJJKlAiIhJJKlAiIhJJKlAiIhJJKlAiIhJJKlAiIhJJKlAiIhJJKlAiIhJJKlAiIhJJKlAiIhJJKlAiIhJJKlAiIhJJKlAiIhJJKlAiIhJ
JKlAiIhJJKlAiIhJJKlAiIhJJKlAiIhJJKlAiIhJJKlAiIhJJKlAiIhJJKlAiIhJJKlAiIhJJKlAiIhJJKlAiIhJJKlAiIhJJvhYoY8xtxpgvjDEHjTGrjDE9/Hz9ukiZ+k+Z+k+ZBqOu5+pbgTLG9AemAROB84ClwBvGmNZ+7aOuUab+U6b+U6bBUK5grLX+vJAxy4GPrbVDYtb9G5hnrR3ry07qGGXqP2XqP2UaDOUK9f14EWNMQ6ArMDnuqSLgwmN8rz8VMjg7rbWZyd6pMvWfMvWfMg1GTXNNt0z96uLLAOoB2+LWbwNaxW9sjLnFGLPSGLPSp/0HaWNI+1Wm/lOm/lOmwUg413TO1JcjqBjx1dtUsQ5r7TPAM5ASFT9sytR/ytR/yjQYx8w1nTP16whqJ1BG5U9MLaj8CUASo0z9p0z9F8lMu3btyrZt29i2bRu33HJLWM2ojUjmmmy+FChrbTGwCrgi7qkrcEaeSDUpU/8pU/9FLdO8vDzy8vJ4/fXXad68Oc2bN2fIkCFkZGSQkZGR7ObUWNRyDYufXXxTgNnGmA+BvwNDgdOAGT7uo65Rpv5Tpv5TpsGo87n6VqCstXOMMc2BccCpwBrgamttmCcaU1rUMz3hhBN4++23AWjQoAHnnXdeyC06tqhnmoqikumECRO45557ADDG4F5C0759e+/oaefOnclsUq1EJddQWWtDXXBO+CV1+eUvf2mttXbMmDF2zJgxx9p+ZdgZRTXT0aNH27KyMltWVmY/+uij6nyvMvV/qdOZ3nvvvbasrMyWlpba0tLSCo/Hjh2rTKOzVCtTzcUnIiKR5Pcw80g47jin7p5++un8+Mc/9ta99tprABQUFADQokWLcBqYJlq3rjMzrvguJyfHy89aS2FhIQDl5eUYYwB4/PHHueuuu0JrYyrIzs4GYNiwYV5u4HTxzZ8/H4BHHnkklLZJ7aVdgTrllFN4/vnnAejbt6+3vry8nNWrVwPQpUsX1q1bx+TJ8RdpSyJ++tOfAlQYvvviiy+G1ZyUkJOTw/DhwwHnj+f5559PVlYW4Lw3y8vLvcfuB6zhw4ezZcsWpk6dGk6jU0CPHs7cqc2bN4/t5mLnzp2MHDkyzKaljeOPPx6Afv36Vfn89u3bWbx4cSD7TosC1aBBAwBGjx7NhAkTqFevnvfcmjVrAOjUqRNdunTx1hcWFvL1118nt6FpYuxYZxqw+vXrs3//fgDee++9MJsUKveI3Fpb4eR8YWGhV3hi17uP3U/8xx133BEf5+TkJPVnSTUXXXQR4GRqjPEGQVxyySVs2rQpzKalhbZt2zJ69GgAhgwZUuU23333Hddffz1FRUW+71/noEREJJJS/giqTZs2PPzwwwD079+f0tJS73Bz2LBhnHLKKQB88MEH3vc89NBDTJw4MfmNTQN5eXkVPtVv374dgI8++iisJoVqxIgRXlex2z0X213nPo5d7z52u/KO9tg96pKq5eXlAc7R686dO/nJT34CwKeffhpms1Ka23V/9tlnM3ToUJo0aXLU7U888cTAzuenfIEaNWoU/fv3B+DLL79k0KBBvPPOO97zH3/8sfd47ty5gFOgDh06lNR2pjr3hP6TTz7p/dHcsWMHvXv3DrNZoXG79SZPnuwVFOCIj40x3tfu40S6+GJP/EtFK1asIDPTmRjbWsuTTz5ZZz8o1UazZs246qqrAGeKqBEjRgBOFz7AkiVLANi6dSvXXHMNQIWitWPHDjZs2BBI29TFJyIikZTyR1AAe/fuBSA3N5f//Oc/3qiTwsJCOnfu7D0eMGAAAGVlZeE0NEU1atSI++67D4BTTz3VW//iiy+ydu3asJoVqqVLl3r/Xnihc3ue2K4919G6+G644QYA7rzzzgqvoS6+o3O79dq3b+/l8+qrr3pd/VI9nTt35uWXX660fteuXbz33ntel19JSYl32cPdd9/tbbdkyRLvKMtvaVGgSkpKANi8eTMtWrTwrn+48MI
LmTNnDgCDBw9WYaqhO++8k8GDB1daX5eHln/55ZeA84fRHTVaVlbmjSoDmDdvnvcHtKCggGXLlnnfM3XqVPLz8wHnfaouvsS98sorgNOt544i/fOf/1xpO7f7b8eOHclrXAo6//zzq1w/ZcqUCkV/3LhxFQqTy51eKgjq4hMRkUhK+SOojRs3cvLJJwPw0ksv8c0333jdJR9++CE33ngjAMXFxaG1MZV169bNGxnlcj9F/etf/wqjSZEydepUb/ANOBfkukdN7id9cI6m3CMo9+jLvXD3aKP4dJFuRbHdetZabzTu/PnzyczM9K7Ry8zM9I5mR44c6fWqiDPq7pJLLvG+PnjwoPd45cqV3HzzzQBs2LCBXr16eRc8X3DBBd52+/fv9wZIff7559SrV88bQHH//fdz+umnAxVPCdRIqk9uWK9ePTtr1iw7a9Ys69q9e7fdvXu3bdKkiSbhrOWyaNEia631JoWdMWOGzczMtJmZmcq0FsucOXNseXm5LS8vt2VlZVU+fv/995Vp3HLvvfdWyMpdn5eXZ9euXeu9T93ny8rK7Nq1a22TJk1q8vcgrTJ1M5g5c6aXTVXL0qVL7dKlS+3ChQvt/v37j7jdunXr7Lp16+zChQvtokWLqtymtpmm/BFUw4YNj9hP756bkupzT0Tn5uZSXl7uXZX/29/+Vn36Poj5g3LEI6gnnngitPZFVd++fb3cXn31VWbPnu2tb9Kkifcc4D1u164d7du3B+ru9Xrg9DABXHvttUfd7kjnpOK1adOmwr+uLVu28MILL9SghZXpHJSIiERSyh9BFRYWepOXfvjhh8yaNYtHH30UcEah3HHHHWE2LyV17dqVmTNnAs5EkcXFxTz44IOAc7Ge1E5+fj75+flHHLn31VdfAXj/inMuFOBHP/qRl9V1113nPbbWVprNvKrHdVmfPn0AKhxl+mX16tXe34Zx48axatUqX143JQtUgwYN+Otf/wrAVVdd5U1S2Lt3b4qLi7n11lsBuPrqq70T+vv27QunsSlo+PDhNG3a1Pv6iSee8AqW1N7w4cOPOjDCvcbKHVQh//ujWlUXXiLPS+Lcv5WlpaWMGjWKPXv2HPN7VqxYwcaN/t/oV118IiISSSl5BJWbm+vNHbV161ZGjRoFOEPJO3bsyDnnnAPAtm3bQmtjKpo+fToAP/vZz7x1K1eu5KGHHgqrSWnFvWj8ggsuqNAlFdvFd8MNNzBv3rzQ2hhVbj6xcxpWdRRa1cwd06ZNq9ODI1yx3aHx3PlLV65cybPPPgvA+vXrk9a2I0mpAuVe0/TUU09593K69NJL+eyzzwBo1aoVhYWF3lRHb775prr2EtSqVStyc3MBZ2oj99qI3r17K0OfxHZDHamLT91SVXNHju7YsaPCBLFuEXIfx46MjB3tJ85oRoBBgwZRWlpK48aNAZg5c6aX77fffhta+6qiLj4REYmmKF9YFrv069fPuvbs2WM7duxoO3bsaAF78cUX24svvtiuXr3aWmvtggUL7IIFC2zjxo39uKgyrS7WO9KyevVqW1paaktLS+2uXbtsnz59bJ8+fXy9OLWuZRq75OfnV7qINPaC0zlz5tg5c+Yo02Msb7zxhnWVlZVVeuxmvGfPHtu+fXvbvn17ZRqtJf0u1G3atCnDhg3zvr7xxhu9br28vDxvOpjs7GwWLVrEwIEDgYpTeEhlp512Gj//+c8B6NChgzdaZ/DgwSxcuDDMpqUdd+QeVH3DQk1plJiHH36YK6+8EnC67tzzKn379mXGjBnedtOmTdNNC9NAShSok08+mR49enhf79u3z7tDbrdu3bwZI1544QWGDh2qGSQSdPvttzNmzBjv6/HjxwNo3jIfuXcfdgdGQOUbFt5www0aUp6gJUuWeLPHS/o75jkoY8xYY8wKY8xuY8wOY8wiY0ynuG1mGWNs3KLfuCNQpv5TpsFQrv5TpolL5AgqF3gaWAEY4EHgLWNMB2vtf2O2ewsYFPO
1b9OHl5WVcfDgQW/Uyfz5871bDn/22Wf85je/AUilbqlcQs4U/ne/HIAJEyZU6CJJQblEINN47ozl1tojdvG5R1YRlUsEc01xuSjThByzQFlre8V+bYwZBHwHdAcWxTx1yFr7tb/Nc2zZsoUBAwZ4ty8wxvDAAw8A8Nxzz7Fly5YgdhuYKGQKMGTIEIYMGRLUyydVVDKNt3z5csC5YaFbkNwuvlSY0iiquaYyZZq4mgwzb3b4++IHzF9kjNlujFlnjHnWGNOi9s2rM5Sp/5RpMJSr/5TpkdRgGGMh8P+AejHrrgeuBToDvYF/AmuARkd4jVuAlYeXsIc9hj7UVJmmf6aTJk2yJSUltqSkxJaVldmSkhKbk5Njc3JyUiZTP3LV+1SZViubagY5BdgCnHWM7U4DSoDrEnjNsAML9U2qTJVpKmQaRK4RyEyZRjzThLv4jDFTgRuAntbaz4+2rbV2C/AlcHair18XKVP/KdNgKFf/KdNjS+g6KGPMNJxDzlxr7TGvfjPGZACnA7p50BEoU/8p02AoV/8p08Qcs0AZY57CGerYF/jWGNPq8FN7rbV7jTFNgfHAKzjhfR94BNgOJHLF505g3+F/oyCDim3J9nsHylSZ+iDwTCHwXKOWKVTMVZn6o+aZ1qLvffzh548HFuOEVwxsBGYBWdXoh43MnFfJaIsyVaapkGkyco1SpslqjzJNfEnkOqij3i/ZWnsA6HW0baQiZeo/ZRoM5eo/ZZo43W5DREQiKSoF6pmwGxAjSm2pjSj9HFFqS21E6eeIUltqI2o/R9TaUxNR+xlq3B5zuI9QREQkUqJyBCUiIlKBCpSIiERSqAXKGHObMeYLY8xBY8wqY0yPY39XrfeZ1vdiCSPTw/tN21yVaTD0+++/dMs0tAJljOkPTAMmAucBS4E3jDGtA951Ls69WC4EegKlOPdiOSVuu7eAU2OWqwNuV62FmCmkaa7KNBj6/fdfWmYa4sVby4Fn49b9G3gkye1oCpQBvWPWzQJeC/sCt1TNNJ1yVabpnasyjXamoRxBGWMaAl2BorininCqcDKlxb1YIpYppEGuyjQYEctVmfrPt0zD6uLLAOoB2+LWbwNaVd48UNOAf0hJGdIAAAFZSURBVAAfxKx7E7gRuAwYBfwY+D9jTKMkt606opQppEeuyjQYUcpVmfrPt0wTms08QPEXYZkq1gXGGDMFuAi4yFpb5jXK2pdjNlttjFmFMx/WNcCryWpfDYWaKaRlrso0GPr9919aZRrWEdROnD7K+MregsqfAAJh0u9eLKFnCmmXqzINRui5KlP/BZFpKAXKWlsMrAKuiHvqCpyRJ4Eyzr1YBuAEmRb3Ygk7U0i/XJVpMMLOVZn6L7BMQxxx0h9nKvlfAj/A6bfcC2QHvN+ngN04wyFbxSxNY0agTAYuwLkPSy5OX+qXQLOw8opypumcqzJNr1yVaWplGnagtwEbgEM41f/iJOwz8PsG1bVM0z1XZZo+uSrT1MpUk8WKiEgkaS4+ERGJJBUoERGJJBUoERGJJBUoERGJJBUoERGJJBUoERGJJBUoERGJJBUoERGJJBUoERGJpP8Pt65bT11uRlcAAAAASUVORK5CYII=\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "# import the visualization package.\n", "import matplotlib.pyplot as plt\n", "%matplotlib inline\n", "\n", "# Create a DataLoader that will read the images we downloaded in the previous step.\n", "data_loader = torch.utils.data.DataLoader(dataset, batch_size=10, shuffle=True)\n", "\n", "# Read 10 images from the data_loader.\n", "examples = enumerate(data_loader)\n", "_, (example_data, example_labels) = next(examples)\n", "\n", "# Print some information about the example_data we just read.\n", "print(example_data.shape)\n", "\n", "# Create a figure.\n", "figure = plt.figure()\n", "\n", "# Plot the 10 images contained in example_data.\n", "for i in range(10):\n", " plt.subplot(3, 5, i+1)\n", " plt.tight_layout()\n", " plt.imshow(example_data[i][0], cmap='gray')\n", "plt.show()\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Setting up the machine learning model" ] }, { "cell_type": "code", "execution_count": 14, "metadata": { "slideshow": { "slide_type": "slide" } }, "outputs": [], "source": [ "import torch\n", "import torch.nn as nn # This is the Neural Networks package. 
We use this to create our Machine Learning Models.\n", "\n", "class SimpleModel(nn.Module):\n", " def __init__(self, input_dimension, output_dimension):\n", " super(SimpleModel, self).__init__()\n", " \n", " self.input_dimension = input_dimension\n", " self.linear_layer = nn.Linear(input_dimension, output_dimension)\n", " \n", " def forward(self, input_data):\n", " \n", " output = self.linear_layer(input_data.view(-1, self.input_dimension))\n", " #output = nn.functional.softmax(output, dim=-1)\n", " return output" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "SimpleModel(\n", " (linear_layer): Linear(in_features=784, out_features=10, bias=True)\n", ")\n" ] } ], "source": [ "# Set the dimensions\n", "input_dimension = 28*28\n", "output_dimension = 10\n", "\n", "# Create the SimpleModel\n", "model = None\n", "model = SimpleModel(input_dimension, output_dimension)\n", "\n", "# Print the model\n", "print(model)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\n", "\n", "*

Figure 2: How the Simple Model Works

*" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Now we are ready to train a simple machine learning model!\n", "\n", "\n", "\n", "##### Exercises\n", "1. Train the SimpleModel with the default settings. What can you say about the loss and accuracy?\n", "\n", "\n", "2. Change the `BATCH_SIZE`. How does it affect loss, accuracy, and running time?\n", "\n", "| Batch Size | Effect on Loss | Effect on Accuracy | Effect on Running Time |\n", ":------------:|:--------------:|:-------------------:|:----------------------:|\n", "10 (Default) | \n", "1 |\n", "5 |\n", "20 |\n", "40 |\n", "\n", "3. Change the `LEARNING_RATE`. How does it affect loss, accuracy, and running time?\n", "\n", "| Learning Rate | Effect on Loss | Effect on Accuracy | Effect on Running Time |\n", ":---------------:|:--------------:|:-------------------:|:----------------------:|\n", "0.001 (Default) | \n", "0.01 |\n", "0.1 |\n", "1 |\n", "0.0001 |\n" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "data": { "application/javascript": [ "/* Put everything inside the global mpl namespace */\n", "window.mpl = {};\n", "\n", "\n", "mpl.get_websocket_type = function() {\n", " if (typeof(WebSocket) !== 'undefined') {\n", " return WebSocket;\n", " } else if (typeof(MozWebSocket) !== 'undefined') {\n", " return MozWebSocket;\n", " } else {\n", " alert('Your browser does not have WebSocket support. ' +\n", " 'Please try Chrome, Safari or Firefox ≥ 6. 
' +\n", " 'Firefox 4 and 5 are also supported but you ' +\n", " 'have to enable WebSockets in about:config.');\n", " };\n", "}\n", "\n", "mpl.figure = function(figure_id, websocket, ondownload, parent_element) {\n", " this.id = figure_id;\n", "\n", " this.ws = websocket;\n", "\n", " this.supports_binary = (this.ws.binaryType != undefined);\n", "\n", " if (!this.supports_binary) {\n", " var warnings = document.getElementById(\"mpl-warnings\");\n", " if (warnings) {\n", " warnings.style.display = 'block';\n", " warnings.textContent = (\n", " \"This browser does not support binary websocket messages. \" +\n", " \"Performance may be slow.\");\n", " }\n", " }\n", "\n", " this.imageObj = new Image();\n", "\n", " this.context = undefined;\n", " this.message = undefined;\n", " this.canvas = undefined;\n", " this.rubberband_canvas = undefined;\n", " this.rubberband_context = undefined;\n", " this.format_dropdown = undefined;\n", "\n", " this.image_mode = 'full';\n", "\n", " this.root = $('
');\n", " this._root_extra_style(this.root)\n", " this.root.attr('style', 'display: inline-block');\n", "\n", " $(parent_element).append(this.root);\n", "\n", " this._init_header(this);\n", " this._init_canvas(this);\n", " this._init_toolbar(this);\n", "\n", " var fig = this;\n", "\n", " this.waiting = false;\n", "\n", " this.ws.onopen = function () {\n", " fig.send_message(\"supports_binary\", {value: fig.supports_binary});\n", " fig.send_message(\"send_image_mode\", {});\n", " if (mpl.ratio != 1) {\n", " fig.send_message(\"set_dpi_ratio\", {'dpi_ratio': mpl.ratio});\n", " }\n", " fig.send_message(\"refresh\", {});\n", " }\n", "\n", " this.imageObj.onload = function() {\n", " if (fig.image_mode == 'full') {\n", " // Full images could contain transparency (where diff images\n", " // almost always do), so we need to clear the canvas so that\n", " // there is no ghosting.\n", " fig.context.clearRect(0, 0, fig.canvas.width, fig.canvas.height);\n", " }\n", " fig.context.drawImage(fig.imageObj, 0, 0);\n", " };\n", "\n", " this.imageObj.onunload = function() {\n", " fig.ws.close();\n", " }\n", "\n", " this.ws.onmessage = this._make_on_message_function(this);\n", "\n", " this.ondownload = ondownload;\n", "}\n", "\n", "mpl.figure.prototype._init_header = function() {\n", " var titlebar = $(\n", " '
');\n", " var titletext = $(\n", " '
');\n", " titlebar.append(titletext)\n", " this.root.append(titlebar);\n", " this.header = titletext[0];\n", "}\n", "\n", "\n", "\n", "mpl.figure.prototype._canvas_extra_style = function(canvas_div) {\n", "\n", "}\n", "\n", "\n", "mpl.figure.prototype._root_extra_style = function(canvas_div) {\n", "\n", "}\n", "\n", "mpl.figure.prototype._init_canvas = function() {\n", " var fig = this;\n", "\n", " var canvas_div = $('
');\n", "\n", " canvas_div.attr('style', 'position: relative; clear: both; outline: 0');\n", "\n", " function canvas_keyboard_event(event) {\n", " return fig.key_event(event, event['data']);\n", " }\n", "\n", " canvas_div.keydown('key_press', canvas_keyboard_event);\n", " canvas_div.keyup('key_release', canvas_keyboard_event);\n", " this.canvas_div = canvas_div\n", " this._canvas_extra_style(canvas_div)\n", " this.root.append(canvas_div);\n", "\n", " var canvas = $('');\n", " canvas.addClass('mpl-canvas');\n", " canvas.attr('style', \"left: 0; top: 0; z-index: 0; outline: 0\")\n", "\n", " this.canvas = canvas[0];\n", " this.context = canvas[0].getContext(\"2d\");\n", "\n", " var backingStore = this.context.backingStorePixelRatio ||\n", "\tthis.context.webkitBackingStorePixelRatio ||\n", "\tthis.context.mozBackingStorePixelRatio ||\n", "\tthis.context.msBackingStorePixelRatio ||\n", "\tthis.context.oBackingStorePixelRatio ||\n", "\tthis.context.backingStorePixelRatio || 1;\n", "\n", " mpl.ratio = (window.devicePixelRatio || 1) / backingStore;\n", "\n", " var rubberband = $('');\n", " rubberband.attr('style', \"position: absolute; left: 0; top: 0; z-index: 1;\")\n", "\n", " var pass_mouse_events = true;\n", "\n", " canvas_div.resizable({\n", " start: function(event, ui) {\n", " pass_mouse_events = false;\n", " },\n", " resize: function(event, ui) {\n", " fig.request_resize(ui.size.width, ui.size.height);\n", " },\n", " stop: function(event, ui) {\n", " pass_mouse_events = true;\n", " fig.request_resize(ui.size.width, ui.size.height);\n", " },\n", " });\n", "\n", " function mouse_event_fn(event) {\n", " if (pass_mouse_events)\n", " return fig.mouse_event(event, event['data']);\n", " }\n", "\n", " rubberband.mousedown('button_press', mouse_event_fn);\n", " rubberband.mouseup('button_release', mouse_event_fn);\n", " // Throttle sequential mouse events to 1 every 20ms.\n", " rubberband.mousemove('motion_notify', mouse_event_fn);\n", "\n", " 
rubberband.mouseenter('figure_enter', mouse_event_fn);\n", " rubberband.mouseleave('figure_leave', mouse_event_fn);\n", "\n", " canvas_div.on(\"wheel\", function (event) {\n", " event = event.originalEvent;\n", " event['data'] = 'scroll'\n", " if (event.deltaY < 0) {\n", " event.step = 1;\n", " } else {\n", " event.step = -1;\n", " }\n", " mouse_event_fn(event);\n", " });\n", "\n", " canvas_div.append(canvas);\n", " canvas_div.append(rubberband);\n", "\n", " this.rubberband = rubberband;\n", " this.rubberband_canvas = rubberband[0];\n", " this.rubberband_context = rubberband[0].getContext(\"2d\");\n", " this.rubberband_context.strokeStyle = \"#000000\";\n", "\n", " this._resize_canvas = function(width, height) {\n", " // Keep the size of the canvas, canvas container, and rubber band\n", " // canvas in synch.\n", " canvas_div.css('width', width)\n", " canvas_div.css('height', height)\n", "\n", " canvas.attr('width', width * mpl.ratio);\n", " canvas.attr('height', height * mpl.ratio);\n", " canvas.attr('style', 'width: ' + width + 'px; height: ' + height + 'px;');\n", "\n", " rubberband.attr('width', width);\n", " rubberband.attr('height', height);\n", " }\n", "\n", " // Set the figure to an initial 600x600px, this will subsequently be updated\n", " // upon first draw.\n", " this._resize_canvas(600, 600);\n", "\n", " // Disable right mouse context menu.\n", " $(this.rubberband_canvas).bind(\"contextmenu\",function(e){\n", " return false;\n", " });\n", "\n", " function set_focus () {\n", " canvas.focus();\n", " canvas_div.focus();\n", " }\n", "\n", " window.setTimeout(set_focus, 100);\n", "}\n", "\n", "mpl.figure.prototype._init_toolbar = function() {\n", " var fig = this;\n", "\n", " var nav_element = $('
');\n", " nav_element.attr('style', 'width: 100%');\n", " this.root.append(nav_element);\n", "\n", " // Define a callback function for later on.\n", " function toolbar_event(event) {\n", " return fig.toolbar_button_onclick(event['data']);\n", " }\n", " function toolbar_mouse_event(event) {\n", " return fig.toolbar_button_onmouseover(event['data']);\n", " }\n", "\n", " for(var toolbar_ind in mpl.toolbar_items) {\n", " var name = mpl.toolbar_items[toolbar_ind][0];\n", " var tooltip = mpl.toolbar_items[toolbar_ind][1];\n", " var image = mpl.toolbar_items[toolbar_ind][2];\n", " var method_name = mpl.toolbar_items[toolbar_ind][3];\n", "\n", " if (!name) {\n", " // put a spacer in here.\n", " continue;\n", " }\n", " var button = $('');\n", " button.click(method_name, toolbar_event);\n", " button.mouseover(tooltip, toolbar_mouse_event);\n", " nav_element.append(button);\n", " }\n", "\n", " // Add the status bar.\n", " var status_bar = $('');\n", " nav_element.append(status_bar);\n", " this.message = status_bar[0];\n", "\n", " // Add the close button to the window.\n", " var buttongrp = $('
');\n", " var button = $('');\n", " button.click(function (evt) { fig.handle_close(fig, {}); } );\n", " button.mouseover('Stop Interaction', toolbar_mouse_event);\n", " buttongrp.append(button);\n", " var titlebar = this.root.find($('.ui-dialog-titlebar'));\n", " titlebar.prepend(buttongrp);\n", "}\n", "\n", "mpl.figure.prototype._root_extra_style = function(el){\n", " var fig = this\n", " el.on(\"remove\", function(){\n", "\tfig.close_ws(fig, {});\n", " });\n", "}\n", "\n", "mpl.figure.prototype._canvas_extra_style = function(el){\n", " // this is important to make the div 'focusable\n", " el.attr('tabindex', 0)\n", " // reach out to IPython and tell the keyboard manager to turn it's self\n", " // off when our div gets focus\n", "\n", " // location in version 3\n", " if (IPython.notebook.keyboard_manager) {\n", " IPython.notebook.keyboard_manager.register_events(el);\n", " }\n", " else {\n", " // location in version 2\n", " IPython.keyboard_manager.register_events(el);\n", " }\n", "\n", "}\n", "\n", "mpl.figure.prototype._key_event_extra = function(event, name) {\n", " var manager = IPython.notebook.keyboard_manager;\n", " if (!manager)\n", " manager = IPython.keyboard_manager;\n", "\n", " // Check for shift+enter\n", " if (event.shiftKey && event.which == 13) {\n", " this.canvas_div.blur();\n", " // select the cell after this one\n", " var index = IPython.notebook.find_cell_index(this.cell_info[0]);\n", " IPython.notebook.select(index + 1);\n", " }\n", "}\n", "\n", "mpl.figure.prototype.handle_save = function(fig, msg) {\n", " fig.ondownload(fig, null);\n", "}\n", "\n", "\n", "mpl.find_output_cell = function(html_output) {\n", " // Return the cell and output element which can be found *uniquely* in the notebook.\n", " // Note - this is a bit hacky, but it is done because the \"notebook_saving.Notebook\"\n", " // IPython event is triggered only after the cells have been serialised, which for\n", " // our purposes (turning an active figure into a static one), is 
too late.\n", " var cells = IPython.notebook.get_cells();\n", " var ncells = cells.length;\n", " for (var i=0; i= 3 moved mimebundle to data attribute of output\n", " data = data.data;\n", " }\n", " if (data['text/html'] == html_output) {\n", " return [cell, data, j];\n", " }\n", " }\n", " }\n", " }\n", "}\n", "\n", "// Register the function which deals with the matplotlib target/channel.\n", "// The kernel may be null if the page has been refreshed.\n", "if (IPython.notebook.kernel != null) {\n", " IPython.notebook.kernel.comm_manager.register_target('matplotlib', mpl.mpl_figure_comm);\n", "}\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/javascript": [ "/* Put everything inside the global mpl namespace */\n", "window.mpl = {};\n", "\n", "\n", "mpl.get_websocket_type = function() {\n", " if (typeof(WebSocket) !== 'undefined') {\n", " return WebSocket;\n", " } else if (typeof(MozWebSocket) !== 'undefined') {\n", " return MozWebSocket;\n", " } else {\n", " alert('Your browser does not have WebSocket support. ' +\n", " 'Please try Chrome, Safari or Firefox ≥ 6. ' +\n", " 'Firefox 4 and 5 are also supported but you ' +\n", " 'have to enable WebSockets in about:config.');\n", " };\n", "}\n", "\n", "mpl.figure = function(figure_id, websocket, ondownload, parent_element) {\n", " this.id = figure_id;\n", "\n", " this.ws = websocket;\n", "\n", " this.supports_binary = (this.ws.binaryType != undefined);\n", "\n", " if (!this.supports_binary) {\n", " var warnings = document.getElementById(\"mpl-warnings\");\n", " if (warnings) {\n", " warnings.style.display = 'block';\n", " warnings.textContent = (\n", " \"This browser does not support binary websocket messages. 
\" +\n", " \"Performance may be slow.\");\n", " }\n", " }\n", "\n", " this.imageObj = new Image();\n", "\n", " this.context = undefined;\n", " this.message = undefined;\n", " this.canvas = undefined;\n", " this.rubberband_canvas = undefined;\n", " this.rubberband_context = undefined;\n", " this.format_dropdown = undefined;\n", "\n", " this.image_mode = 'full';\n", "\n", " this.root = $('
');\n", " this._root_extra_style(this.root)\n", " this.root.attr('style', 'display: inline-block');\n", "\n", " $(parent_element).append(this.root);\n", "\n", " this._init_header(this);\n", " this._init_canvas(this);\n", " this._init_toolbar(this);\n", "\n", " var fig = this;\n", "\n", " this.waiting = false;\n", "\n", " this.ws.onopen = function () {\n", " fig.send_message(\"supports_binary\", {value: fig.supports_binary});\n", " fig.send_message(\"send_image_mode\", {});\n", " if (mpl.ratio != 1) {\n", " fig.send_message(\"set_dpi_ratio\", {'dpi_ratio': mpl.ratio});\n", " }\n", " fig.send_message(\"refresh\", {});\n", " }\n", "\n", " this.imageObj.onload = function() {\n", " if (fig.image_mode == 'full') {\n", " // Full images could contain transparency (where diff images\n", " // almost always do), so we need to clear the canvas so that\n", " // there is no ghosting.\n", " fig.context.clearRect(0, 0, fig.canvas.width, fig.canvas.height);\n", " }\n", " fig.context.drawImage(fig.imageObj, 0, 0);\n", " };\n", "\n", " this.imageObj.onunload = function() {\n", " fig.ws.close();\n", " }\n", "\n", " this.ws.onmessage = this._make_on_message_function(this);\n", "\n", " this.ondownload = ondownload;\n", "}\n", "\n", "mpl.figure.prototype._init_header = function() {\n", " var titlebar = $(\n", " '
');\n", " var titletext = $(\n", " '
');\n", " titlebar.append(titletext)\n", " this.root.append(titlebar);\n", " this.header = titletext[0];\n", "}\n", "\n", "\n", "\n", "mpl.figure.prototype._canvas_extra_style = function(canvas_div) {\n", "\n", "}\n", "\n", "\n", "mpl.figure.prototype._root_extra_style = function(canvas_div) {\n", "\n", "}\n", "\n", "mpl.figure.prototype._init_canvas = function() {\n", " var fig = this;\n", "\n", " var canvas_div = $('
');\n", "\n", " canvas_div.attr('style', 'position: relative; clear: both; outline: 0');\n", "\n", " function canvas_keyboard_event(event) {\n", " return fig.key_event(event, event['data']);\n", " }\n", "\n", " canvas_div.keydown('key_press', canvas_keyboard_event);\n", " canvas_div.keyup('key_release', canvas_keyboard_event);\n", " this.canvas_div = canvas_div\n", " this._canvas_extra_style(canvas_div)\n", " this.root.append(canvas_div);\n", "\n", " var canvas = $('');\n", " canvas.addClass('mpl-canvas');\n", " canvas.attr('style', \"left: 0; top: 0; z-index: 0; outline: 0\")\n", "\n", " this.canvas = canvas[0];\n", " this.context = canvas[0].getContext(\"2d\");\n", "\n", " var backingStore = this.context.backingStorePixelRatio ||\n", "\tthis.context.webkitBackingStorePixelRatio ||\n", "\tthis.context.mozBackingStorePixelRatio ||\n", "\tthis.context.msBackingStorePixelRatio ||\n", "\tthis.context.oBackingStorePixelRatio ||\n", "\tthis.context.backingStorePixelRatio || 1;\n", "\n", " mpl.ratio = (window.devicePixelRatio || 1) / backingStore;\n", "\n", " var rubberband = $('');\n", " rubberband.attr('style', \"position: absolute; left: 0; top: 0; z-index: 1;\")\n", "\n", " var pass_mouse_events = true;\n", "\n", " canvas_div.resizable({\n", " start: function(event, ui) {\n", " pass_mouse_events = false;\n", " },\n", " resize: function(event, ui) {\n", " fig.request_resize(ui.size.width, ui.size.height);\n", " },\n", " stop: function(event, ui) {\n", " pass_mouse_events = true;\n", " fig.request_resize(ui.size.width, ui.size.height);\n", " },\n", " });\n", "\n", " function mouse_event_fn(event) {\n", " if (pass_mouse_events)\n", " return fig.mouse_event(event, event['data']);\n", " }\n", "\n", " rubberband.mousedown('button_press', mouse_event_fn);\n", " rubberband.mouseup('button_release', mouse_event_fn);\n", " // Throttle sequential mouse events to 1 every 20ms.\n", " rubberband.mousemove('motion_notify', mouse_event_fn);\n", "\n", " 
rubberband.mouseenter('figure_enter', mouse_event_fn);\n", " rubberband.mouseleave('figure_leave', mouse_event_fn);\n", "\n", " canvas_div.on(\"wheel\", function (event) {\n", " event = event.originalEvent;\n", " event['data'] = 'scroll'\n", " if (event.deltaY < 0) {\n", " event.step = 1;\n", " } else {\n", " event.step = -1;\n", " }\n", " mouse_event_fn(event);\n", " });\n", "\n", " canvas_div.append(canvas);\n", " canvas_div.append(rubberband);\n", "\n", " this.rubberband = rubberband;\n", " this.rubberband_canvas = rubberband[0];\n", " this.rubberband_context = rubberband[0].getContext(\"2d\");\n", " this.rubberband_context.strokeStyle = \"#000000\";\n", "\n", " this._resize_canvas = function(width, height) {\n", " // Keep the size of the canvas, canvas container, and rubber band\n", " // canvas in synch.\n", " canvas_div.css('width', width)\n", " canvas_div.css('height', height)\n", "\n", " canvas.attr('width', width * mpl.ratio);\n", " canvas.attr('height', height * mpl.ratio);\n", " canvas.attr('style', 'width: ' + width + 'px; height: ' + height + 'px;');\n", "\n", " rubberband.attr('width', width);\n", " rubberband.attr('height', height);\n", " }\n", "\n", " // Set the figure to an initial 600x600px, this will subsequently be updated\n", " // upon first draw.\n", " this._resize_canvas(600, 600);\n", "\n", " // Disable right mouse context menu.\n", " $(this.rubberband_canvas).bind(\"contextmenu\",function(e){\n", " return false;\n", " });\n", "\n", " function set_focus () {\n", " canvas.focus();\n", " canvas_div.focus();\n", " }\n", "\n", " window.setTimeout(set_focus, 100);\n", "}\n", "\n", "mpl.figure.prototype._init_toolbar = function() {\n", " var fig = this;\n", "\n", " var nav_element = $('
');\n", " nav_element.attr('style', 'width: 100%');\n", " this.root.append(nav_element);\n", "\n", " // Define a callback function for later on.\n", " function toolbar_event(event) {\n", " return fig.toolbar_button_onclick(event['data']);\n", " }\n", " function toolbar_mouse_event(event) {\n", " return fig.toolbar_button_onmouseover(event['data']);\n", " }\n", "\n", " for(var toolbar_ind in mpl.toolbar_items) {\n", " var name = mpl.toolbar_items[toolbar_ind][0];\n", " var tooltip = mpl.toolbar_items[toolbar_ind][1];\n", " var image = mpl.toolbar_items[toolbar_ind][2];\n", " var method_name = mpl.toolbar_items[toolbar_ind][3];\n", "\n", " if (!name) {\n", " // put a spacer in here.\n", " continue;\n", " }\n", " var button = $('');\n", " button.click(method_name, toolbar_event);\n", " button.mouseover(tooltip, toolbar_mouse_event);\n", " nav_element.append(button);\n", " }\n", "\n", " // Add the status bar.\n", " var status_bar = $('');\n", " nav_element.append(status_bar);\n", " this.message = status_bar[0];\n", "\n", " // Add the close button to the window.\n", " var buttongrp = $('
');\n", " var button = $('');\n", " button.click(function (evt) { fig.handle_close(fig, {}); } );\n", " button.mouseover('Stop Interaction', toolbar_mouse_event);\n", " buttongrp.append(button);\n", " var titlebar = this.root.find($('.ui-dialog-titlebar'));\n", " titlebar.prepend(buttongrp);\n", "}\n", "\n", "mpl.figure.prototype._root_extra_style = function(el){\n", " var fig = this\n", " el.on(\"remove\", function(){\n", "\tfig.close_ws(fig, {});\n", " });\n", "}\n", "\n", "mpl.figure.prototype._canvas_extra_style = function(el){\n", " // this is important to make the div 'focusable\n", " el.attr('tabindex', 0)\n", " // reach out to IPython and tell the keyboard manager to turn it's self\n", " // off when our div gets focus\n", "\n", " // location in version 3\n", " if (IPython.notebook.keyboard_manager) {\n", " IPython.notebook.keyboard_manager.register_events(el);\n", " }\n", " else {\n", " // location in version 2\n", " IPython.keyboard_manager.register_events(el);\n", " }\n", "\n", "}\n", "\n", "mpl.figure.prototype._key_event_extra = function(event, name) {\n", " var manager = IPython.notebook.keyboard_manager;\n", " if (!manager)\n", " manager = IPython.keyboard_manager;\n", "\n", " // Check for shift+enter\n", " if (event.shiftKey && event.which == 13) {\n", " this.canvas_div.blur();\n", " // select the cell after this one\n", " var index = IPython.notebook.find_cell_index(this.cell_info[0]);\n", " IPython.notebook.select(index + 1);\n", " }\n", "}\n", "\n", "mpl.figure.prototype.handle_save = function(fig, msg) {\n", " fig.ondownload(fig, null);\n", "}\n", "\n", "\n", "mpl.find_output_cell = function(html_output) {\n", " // Return the cell and output element which can be found *uniquely* in the notebook.\n", " // Note - this is a bit hacky, but it is done because the \"notebook_saving.Notebook\"\n", " // IPython event is triggered only after the cells have been serialised, which for\n", " // our purposes (turning an active figure into a static one), is 
too late.\n", " var cells = IPython.notebook.get_cells();\n", " var ncells = cells.length;\n", " for (var i=0; i= 3 moved mimebundle to data attribute of output\n", " data = data.data;\n", " }\n", " if (data['text/html'] == html_output) {\n", " return [cell, data, j];\n", " }\n", " }\n", " }\n", " }\n", "}\n", "\n", "// Register the function which deals with the matplotlib target/channel.\n", "// The kernel may be null if the page has been refreshed.\n", "if (IPython.notebook.kernel != null) {\n", " IPython.notebook.kernel.comm_manager.register_target('matplotlib', mpl.mpl_figure_comm);\n", "}\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/javascript": [ "/* Put everything inside the global mpl namespace */\n", "window.mpl = {};\n", "\n", "\n", "mpl.get_websocket_type = function() {\n", " if (typeof(WebSocket) !== 'undefined') {\n", " return WebSocket;\n", " } else if (typeof(MozWebSocket) !== 'undefined') {\n", " return MozWebSocket;\n", " } else {\n", " alert('Your browser does not have WebSocket support. ' +\n", " 'Please try Chrome, Safari or Firefox ≥ 6. ' +\n", " 'Firefox 4 and 5 are also supported but you ' +\n", " 'have to enable WebSockets in about:config.');\n", " };\n", "}\n", "\n", "mpl.figure = function(figure_id, websocket, ondownload, parent_element) {\n", " this.id = figure_id;\n", "\n", " this.ws = websocket;\n", "\n", " this.supports_binary = (this.ws.binaryType != undefined);\n", "\n", " if (!this.supports_binary) {\n", " var warnings = document.getElementById(\"mpl-warnings\");\n", " if (warnings) {\n", " warnings.style.display = 'block';\n", " warnings.textContent = (\n", " \"This browser does not support binary websocket messages. 
\" +\n", " \"Performance may be slow.\");\n", " }\n", " }\n", "\n", " this.imageObj = new Image();\n", "\n", " this.context = undefined;\n", " this.message = undefined;\n", " this.canvas = undefined;\n", " this.rubberband_canvas = undefined;\n", " this.rubberband_context = undefined;\n", " this.format_dropdown = undefined;\n", "\n", " this.image_mode = 'full';\n", "\n", " this.root = $('
');\n", " this._root_extra_style(this.root)\n", " this.root.attr('style', 'display: inline-block');\n", "\n", " $(parent_element).append(this.root);\n", "\n", " this._init_header(this);\n", " this._init_canvas(this);\n", " this._init_toolbar(this);\n", "\n", " var fig = this;\n", "\n", " this.waiting = false;\n", "\n", " this.ws.onopen = function () {\n", " fig.send_message(\"supports_binary\", {value: fig.supports_binary});\n", " fig.send_message(\"send_image_mode\", {});\n", " if (mpl.ratio != 1) {\n", " fig.send_message(\"set_dpi_ratio\", {'dpi_ratio': mpl.ratio});\n", " }\n", " fig.send_message(\"refresh\", {});\n", " }\n", "\n", " this.imageObj.onload = function() {\n", " if (fig.image_mode == 'full') {\n", " // Full images could contain transparency (where diff images\n", " // almost always do), so we need to clear the canvas so that\n", " // there is no ghosting.\n", " fig.context.clearRect(0, 0, fig.canvas.width, fig.canvas.height);\n", " }\n", " fig.context.drawImage(fig.imageObj, 0, 0);\n", " };\n", "\n", " this.imageObj.onunload = function() {\n", " fig.ws.close();\n", " }\n", "\n", " this.ws.onmessage = this._make_on_message_function(this);\n", "\n", " this.ondownload = ondownload;\n", "}\n", "\n", "mpl.figure.prototype._init_header = function() {\n", " var titlebar = $(\n", " '
');\n", " var titletext = $(\n", " '
');\n", " titlebar.append(titletext)\n", " this.root.append(titlebar);\n", " this.header = titletext[0];\n", "}\n", "\n", "\n", "\n", "mpl.figure.prototype._canvas_extra_style = function(canvas_div) {\n", "\n", "}\n", "\n", "\n", "mpl.figure.prototype._root_extra_style = function(canvas_div) {\n", "\n", "}\n", "\n", "mpl.figure.prototype._init_canvas = function() {\n", " var fig = this;\n", "\n", " var canvas_div = $('
');\n", "\n", " canvas_div.attr('style', 'position: relative; clear: both; outline: 0');\n", "\n", " function canvas_keyboard_event(event) {\n", " return fig.key_event(event, event['data']);\n", " }\n", "\n", " canvas_div.keydown('key_press', canvas_keyboard_event);\n", " canvas_div.keyup('key_release', canvas_keyboard_event);\n", " this.canvas_div = canvas_div\n", " this._canvas_extra_style(canvas_div)\n", " this.root.append(canvas_div);\n", "\n", " var canvas = $('');\n", " canvas.addClass('mpl-canvas');\n", " canvas.attr('style', \"left: 0; top: 0; z-index: 0; outline: 0\")\n", "\n", " this.canvas = canvas[0];\n", " this.context = canvas[0].getContext(\"2d\");\n", "\n", " var backingStore = this.context.backingStorePixelRatio ||\n", "\tthis.context.webkitBackingStorePixelRatio ||\n", "\tthis.context.mozBackingStorePixelRatio ||\n", "\tthis.context.msBackingStorePixelRatio ||\n", "\tthis.context.oBackingStorePixelRatio ||\n", "\tthis.context.backingStorePixelRatio || 1;\n", "\n", " mpl.ratio = (window.devicePixelRatio || 1) / backingStore;\n", "\n", " var rubberband = $('');\n", " rubberband.attr('style', \"position: absolute; left: 0; top: 0; z-index: 1;\")\n", "\n", " var pass_mouse_events = true;\n", "\n", " canvas_div.resizable({\n", " start: function(event, ui) {\n", " pass_mouse_events = false;\n", " },\n", " resize: function(event, ui) {\n", " fig.request_resize(ui.size.width, ui.size.height);\n", " },\n", " stop: function(event, ui) {\n", " pass_mouse_events = true;\n", " fig.request_resize(ui.size.width, ui.size.height);\n", " },\n", " });\n", "\n", " function mouse_event_fn(event) {\n", " if (pass_mouse_events)\n", " return fig.mouse_event(event, event['data']);\n", " }\n", "\n", " rubberband.mousedown('button_press', mouse_event_fn);\n", " rubberband.mouseup('button_release', mouse_event_fn);\n", " // Throttle sequential mouse events to 1 every 20ms.\n", " rubberband.mousemove('motion_notify', mouse_event_fn);\n", "\n", " 
rubberband.mouseenter('figure_enter', mouse_event_fn);\n", " rubberband.mouseleave('figure_leave', mouse_event_fn);\n", "\n", " canvas_div.on(\"wheel\", function (event) {\n", " event = event.originalEvent;\n", " event['data'] = 'scroll'\n", " if (event.deltaY < 0) {\n", " event.step = 1;\n", " } else {\n", " event.step = -1;\n", " }\n", " mouse_event_fn(event);\n", " });\n", "\n", " canvas_div.append(canvas);\n", " canvas_div.append(rubberband);\n", "\n", " this.rubberband = rubberband;\n", " this.rubberband_canvas = rubberband[0];\n", " this.rubberband_context = rubberband[0].getContext(\"2d\");\n", " this.rubberband_context.strokeStyle = \"#000000\";\n", "\n", " this._resize_canvas = function(width, height) {\n", " // Keep the size of the canvas, canvas container, and rubber band\n", " // canvas in synch.\n", " canvas_div.css('width', width)\n", " canvas_div.css('height', height)\n", "\n", " canvas.attr('width', width * mpl.ratio);\n", " canvas.attr('height', height * mpl.ratio);\n", " canvas.attr('style', 'width: ' + width + 'px; height: ' + height + 'px;');\n", "\n", " rubberband.attr('width', width);\n", " rubberband.attr('height', height);\n", " }\n", "\n", " // Set the figure to an initial 600x600px, this will subsequently be updated\n", " // upon first draw.\n", " this._resize_canvas(600, 600);\n", "\n", " // Disable right mouse context menu.\n", " $(this.rubberband_canvas).bind(\"contextmenu\",function(e){\n", " return false;\n", " });\n", "\n", " function set_focus () {\n", " canvas.focus();\n", " canvas_div.focus();\n", " }\n", "\n", " window.setTimeout(set_focus, 100);\n", "}\n", "\n", "mpl.figure.prototype._init_toolbar = function() {\n", " var fig = this;\n", "\n", " var nav_element = $('
');\n", " nav_element.attr('style', 'width: 100%');\n", " this.root.append(nav_element);\n", "\n", " // Define a callback function for later on.\n", " function toolbar_event(event) {\n", " return fig.toolbar_button_onclick(event['data']);\n", " }\n", " function toolbar_mouse_event(event) {\n", " return fig.toolbar_button_onmouseover(event['data']);\n", " }\n", "\n", " for(var toolbar_ind in mpl.toolbar_items) {\n", " var name = mpl.toolbar_items[toolbar_ind][0];\n", " var tooltip = mpl.toolbar_items[toolbar_ind][1];\n", " var image = mpl.toolbar_items[toolbar_ind][2];\n", " var method_name = mpl.toolbar_items[toolbar_ind][3];\n", "\n", " if (!name) {\n", " // put a spacer in here.\n", " continue;\n", " }\n", " var button = $('');\n", " button.click(method_name, toolbar_event);\n", " button.mouseover(tooltip, toolbar_mouse_event);\n", " nav_element.append(button);\n", " }\n", "\n", " // Add the status bar.\n", " var status_bar = $('');\n", " nav_element.append(status_bar);\n", " this.message = status_bar[0];\n", "\n", " // Add the close button to the window.\n", " var buttongrp = $('
');\n", " var button = $('');\n", " button.click(function (evt) { fig.handle_close(fig, {}); } );\n", " button.mouseover('Stop Interaction', toolbar_mouse_event);\n", " buttongrp.append(button);\n", " var titlebar = this.root.find($('.ui-dialog-titlebar'));\n", " titlebar.prepend(buttongrp);\n", "}\n", "\n", "mpl.figure.prototype._root_extra_style = function(el){\n", " var fig = this\n", " el.on(\"remove\", function(){\n", "\tfig.close_ws(fig, {});\n", " });\n", "}\n", "\n", "mpl.figure.prototype._canvas_extra_style = function(el){\n", " // this is important to make the div 'focusable\n", " el.attr('tabindex', 0)\n", " // reach out to IPython and tell the keyboard manager to turn it's self\n", " // off when our div gets focus\n", "\n", " // location in version 3\n", " if (IPython.notebook.keyboard_manager) {\n", " IPython.notebook.keyboard_manager.register_events(el);\n", " }\n", " else {\n", " // location in version 2\n", " IPython.keyboard_manager.register_events(el);\n", " }\n", "\n", "}\n", "\n", "mpl.figure.prototype._key_event_extra = function(event, name) {\n", " var manager = IPython.notebook.keyboard_manager;\n", " if (!manager)\n", " manager = IPython.keyboard_manager;\n", "\n", " // Check for shift+enter\n", " if (event.shiftKey && event.which == 13) {\n", " this.canvas_div.blur();\n", " // select the cell after this one\n", " var index = IPython.notebook.find_cell_index(this.cell_info[0]);\n", " IPython.notebook.select(index + 1);\n", " }\n", "}\n", "\n", "mpl.figure.prototype.handle_save = function(fig, msg) {\n", " fig.ondownload(fig, null);\n", "}\n", "\n", "\n", "mpl.find_output_cell = function(html_output) {\n", " // Return the cell and output element which can be found *uniquely* in the notebook.\n", " // Note - this is a bit hacky, but it is done because the \"notebook_saving.Notebook\"\n", " // IPython event is triggered only after the cells have been serialised, which for\n", " // our purposes (turning an active figure into a static one), is 
too late.\n", " var cells = IPython.notebook.get_cells();\n", " var ncells = cells.length;\n", " for (var i=0; i= 3 moved mimebundle to data attribute of output\n", " data = data.data;\n", " }\n", " if (data['text/html'] == html_output) {\n", " return [cell, data, j];\n", " }\n", " }\n", " }\n", " }\n", "}\n", "\n", "// Register the function which deals with the matplotlib target/channel.\n", "// The kernel may be null if the page has been refreshed.\n", "if (IPython.notebook.kernel != null) {\n", " IPython.notebook.kernel.comm_manager.register_target('matplotlib', mpl.mpl_figure_comm);\n", "}\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Elapsed Time: 12.11 sec, Iteration: 500, Loss: 1.535986, Accuracy: 55.35\n", "Elapsed Time: 35.32 sec, Iteration: 1000, Loss: 1.273219, Accuracy: 60.71\n", "Elapsed Time: 70.83 sec, Iteration: 1500, Loss: 0.955550, Accuracy: 73.40\n", "Elapsed Time: 117.70 sec, Iteration: 2000, Loss: 0.878937, Accuracy: 63.17\n", "Elapsed Time: 177.48 sec, Iteration: 2500, Loss: 0.861457, Accuracy: 72.68\n", "Elapsed Time: 251.17 sec, Iteration: 3000, Loss: 0.804850, Accuracy: 71.46\n", "Elapsed Time: 338.51 sec, Iteration: 3500, Loss: 0.757068, Accuracy: 72.93\n", "Elapsed Time: 437.80 sec, Iteration: 4000, Loss: 0.766040, Accuracy: 71.48\n", "Elapsed Time: 549.39 sec, Iteration: 4500, Loss: 0.772797, Accuracy: 73.08\n", "Elapsed Time: 675.82 sec, Iteration: 5000, Loss: 0.720864, Accuracy: 74.24\n", "Elapsed Time: 817.01 sec, Iteration: 5500, Loss: 0.711215, Accuracy: 84.04\n", "Elapsed Time: 971.54 sec, Iteration: 6000, Loss: 0.572375, Accuracy: 82.40\n", "\n", "Elapsed Time: 1138.08 sec, Iteration: 6500, Loss: 0.541466, Accuracy: 83.72\n", "Elapsed Time: 1317.88 sec, Iteration: 7000, Loss: 0.516570, Accuracy: 82.87\n", "Elapsed Time: 1511.88 sec, 
Iteration: 7500, Loss: 0.541272, Accuracy: 81.82\n", "Elapsed Time: 1720.43 sec, Iteration: 8000, Loss: 0.513905, Accuracy: 84.16\n", "Elapsed Time: 1943.13 sec, Iteration: 8500, Loss: 0.474660, Accuracy: 83.48\n", "Elapsed Time: 2179.50 sec, Iteration: 9000, Loss: 0.518962, Accuracy: 80.86\n", "Elapsed Time: 2429.46 sec, Iteration: 9500, Loss: 0.490592, Accuracy: 83.49\n", "Elapsed Time: 2691.77 sec, Iteration: 10000, Loss: 0.501373, Accuracy: 83.75\n", "Elapsed Time: 2966.99 sec, Iteration: 10500, Loss: 0.483805, Accuracy: 83.59\n", "Elapsed Time: 3255.63 sec, Iteration: 11000, Loss: 0.506277, Accuracy: 83.44\n", "Elapsed Time: 3557.19 sec, Iteration: 11500, Loss: 0.487032, Accuracy: 81.93\n", "Elapsed Time: 3872.74 sec, Iteration: 12000, Loss: 0.470573, Accuracy: 84.98\n", "\n", "Elapsed Time: 4201.61 sec, Iteration: 12500, Loss: 0.483386, Accuracy: 83.82\n", "Elapsed Time: 4544.19 sec, Iteration: 13000, Loss: 0.444444, Accuracy: 84.85\n", "Elapsed Time: 4899.57 sec, Iteration: 13500, Loss: 0.454052, Accuracy: 84.50\n", "Elapsed Time: 5266.12 sec, Iteration: 14000, Loss: 0.459888, Accuracy: 84.63\n", "Elapsed Time: 5645.00 sec, Iteration: 14500, Loss: 0.452329, Accuracy: 84.38\n", "Elapsed Time: 6038.67 sec, Iteration: 15000, Loss: 0.446735, Accuracy: 85.48\n", "Elapsed Time: 6446.14 sec, Iteration: 15500, Loss: 0.464642, Accuracy: 84.41\n", "Elapsed Time: 6867.68 sec, Iteration: 16000, Loss: 0.396820, Accuracy: 94.39\n", "Elapsed Time: 7300.95 sec, Iteration: 16500, Loss: 0.304140, Accuracy: 93.91\n", "Elapsed Time: 7747.63 sec, Iteration: 17000, Loss: 0.246138, Accuracy: 93.47\n", "Elapsed Time: 8206.81 sec, Iteration: 17500, Loss: 0.243343, Accuracy: 93.21\n", "Elapsed Time: 8678.98 sec, Iteration: 18000, Loss: 0.281900, Accuracy: 92.72\n", "\n", "Elapsed Time: 9162.78 sec, Iteration: 18500, Loss: 0.232396, Accuracy: 94.04\n", "Elapsed Time: 9658.83 sec, Iteration: 19000, Loss: 0.236963, Accuracy: 94.14\n", "Elapsed Time: 10168.35 sec, Iteration: 19500, Loss: 
0.233781, Accuracy: 94.16\n", "Elapsed Time: 10690.40 sec, Iteration: 20000, Loss: 0.246233, Accuracy: 92.18\n", "Elapsed Time: 11224.67 sec, Iteration: 20500, Loss: 0.258796, Accuracy: 94.46\n", "Elapsed Time: 11771.90 sec, Iteration: 21000, Loss: 0.215138, Accuracy: 94.10\n", "Elapsed Time: 12332.56 sec, Iteration: 21500, Loss: 0.244385, Accuracy: 94.47\n", "Elapsed Time: 12904.76 sec, Iteration: 22000, Loss: 0.243434, Accuracy: 92.84\n", "Elapsed Time: 13489.77 sec, Iteration: 22500, Loss: 0.231695, Accuracy: 93.53\n", "Elapsed Time: 14086.80 sec, Iteration: 23000, Loss: 0.233429, Accuracy: 93.95\n", "Elapsed Time: 14695.48 sec, Iteration: 23500, Loss: 0.226609, Accuracy: 94.42\n", "Elapsed Time: 15317.17 sec, Iteration: 24000, Loss: 0.232742, Accuracy: 93.93\n", "\n", "Elapsed Time: 15951.81 sec, Iteration: 24500, Loss: 0.228075, Accuracy: 93.16\n", "Elapsed Time: 16598.18 sec, Iteration: 25000, Loss: 0.248222, Accuracy: 93.10\n", "Elapsed Time: 17256.05 sec, Iteration: 25500, Loss: 0.244612, Accuracy: 91.86\n", "Elapsed Time: 17925.64 sec, Iteration: 26000, Loss: 0.214745, Accuracy: 94.58\n", "Elapsed Time: 18609.42 sec, Iteration: 26500, Loss: 0.208128, Accuracy: 94.68\n", "Elapsed Time: 19306.60 sec, Iteration: 27000, Loss: 0.197430, Accuracy: 94.31\n", "Elapsed Time: 20015.81 sec, Iteration: 27500, Loss: 0.214973, Accuracy: 91.31\n", "Elapsed Time: 20737.47 sec, Iteration: 28000, Loss: 0.244899, Accuracy: 94.05\n", "Elapsed Time: 21471.54 sec, Iteration: 28500, Loss: 0.231276, Accuracy: 94.60\n", "Elapsed Time: 22218.03 sec, Iteration: 29000, Loss: 0.202793, Accuracy: 93.38\n", "Elapsed Time: 22976.35 sec, Iteration: 29500, Loss: 0.225060, Accuracy: 94.54\n", "Elapsed Time: 23746.41 sec, Iteration: 30000, Loss: 0.240444, Accuracy: 94.25\n", "\n", "Best Accuracy: 94.68\n", "Best Loss: 0.197430\n", "Elapsed Running Time: 23746.41\n" ] } ], "source": [ "import time\n", "import torch\n", "import torchvision\n", "import numpy as np\n", "import matplotlib.pyplot 
as plt\n", "%matplotlib notebook\n", "\n", "BATCH_SIZE = 10\n", "LEARNING_RATE = 0.01\n", "ITERATIONS = 5\n", "transforms = torchvision.transforms.Compose([torchvision.transforms.ToTensor(),\n", " torchvision.transforms.Normalize((0.5,), (0.5,))])\n", "\n", "# The SimpleNeuralNetworkModel needs to be trained on data so that it can learn the different digits.\n", "training_dataset = torchvision.datasets.MNIST(DOWNLOAD_FOLDER, train=True, transform=transforms)\n", "training_loader = torch.utils.data.DataLoader(training_dataset, batch_size=BATCH_SIZE, shuffle=True)\n", "\n", "# Testing data will help us see how good the SimpleModel is at classifying digits. \n", "testing_dataset = torchvision.datasets.MNIST(DOWNLOAD_FOLDER, train=False, transform=transforms)\n", "testing_loader = torch.utils.data.DataLoader(testing_dataset, batch_size = 1000, shuffle=False)\n", "\n", "criterion = nn.CrossEntropyLoss()\n", "optimizer = torch.optim.SGD(model.parameters(), lr = LEARNING_RATE, momentum=0.9)\n", "\n", "iters = 0\n", "cum_loss = 0\n", "print_frequency = 500\n", "elapsed_time = 0\n", "\n", "loss_history = []\n", "accuracy_history = []\n", "\n", "fig,ax = plt.subplots(1,1)\n", "ax.set_xlabel('Iterations')\n", "ax.set_ylabel('Loss')\n", "ax.set_xlim(0,len(training_loader)*ITERATIONS)\n", "ax.set_ylim(0, 1.5)\n", "\n", "fig2, ax2 = plt.subplots(1,1)\n", "ax2.set_xlabel('Iterations')\n", "ax2.set_ylabel('Accuracy')\n", "ax2.set_xlim(0,len(training_loader)*ITERATIONS)\n", "ax2.set_ylim(65, 100)\n", "\n", "fig3, ax3 = plt.subplots(2,2, figsize=(8,6))\n", "fig3.tight_layout()\n", "\n", "start = time.time()\n", "for i in range(ITERATIONS):\n", " total = 0\n", " correct = 0\n", " for j , (images, labels) in enumerate(training_loader):\n", " model.train()\n", " optimizer.zero_grad()\n", " outputs = model(images)\n", " \n", " loss = criterion(outputs, labels)\n", " \n", " loss.backward()\n", " cum_loss += loss.item()\n", " \n", " optimizer.step()\n", " iters += 1\n", "\n", " if iters % 
print_frequency == 0:\n", " plot_and_print()\n", " cum_loss = 0\n", " print('')\n", " \n", "print('Best Accuracy: %.2f' % max(accuracy_history))\n", "print('Best Loss: %f' % min(loss_history))\n", "print('Elapsed Running Time: %.2f' % elapsed_time)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3.7 (py37)", "language": "python", "name": "py37" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.4" } }, "nbformat": 4, "nbformat_minor": 2 }