{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "Tutorial4_Advanced_PyTorch.ipynb",
      "provenance": [],
      "collapsed_sections": [],
      "toc_visible": true
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "cells": [
    {
      "cell_type": "code",
      "metadata": {
        "id": "l6QKttSZtnDg"
      },
      "source": [
        "import time\n",
        "\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "import torch.nn.functional as F\n",
        "import torchvision.transforms as transforms\n",
        "from torch.utils.data import Dataset, DataLoader\n",
        "\n",
        "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
        "torch.manual_seed(0)  # make the random tensors in this notebook reproducible\n",
        "\n",
        "def get_time():\n",
        "  \"\"\"Return a timestamp in seconds for measuring elapsed time.\n",
        "\n",
        "  CUDA kernels are launched asynchronously, so pending GPU work is waited\n",
        "  for first; otherwise a timed region would only measure launch overhead.\n",
        "  time.perf_counter() is monotonic and high-resolution, unlike time.time(),\n",
        "  which follows the adjustable wall clock.\n",
        "  \"\"\"\n",
        "  if torch.cuda.is_available():\n",
        "    torch.cuda.synchronize()\n",
        "  return time.perf_counter()"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "x_ZmSrbMtzBN"
      },
      "source": [
        "# Vectorization"
      ]
    },
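    {
      "cell_type": "markdown",
      "metadata": {
        "id": "FrosVALmzb3g"
      },
      "source": [
        "## Elementwise Operations\n",
        "\n",
        "Elementwise operations such as `+`, `-`, `*`, `/`, `torch.minimum` and `torch.maximum` are vectorized: each output element depends only on the corresponding input elements, so PyTorch processes the whole tensor in optimized native code instead of a Python loop."
      ]
    },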
    {
      "cell_type": "code",
      "metadata": {
        "id": "UX5b1PCXzcR0"
      },
      "source": [
        "# two random matrices with identical shapes\n",
        "a = torch.rand(10000, 2000, device=device)\n",
        "b = torch.rand_like(a)\n",
        "\n",
        "# every output element depends only on the matching elements of a and b\n",
        "c = torch.sub(a, b)\n",
        "d = torch.add(a, b)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "TFe9uBYiuCg4"
      },
      "source": [
        "## Linear Algebra\n",
        "\n",
        "[`torch.matmul`](https://pytorch.org/docs/stable/generated/torch.matmul.html): Matrix multiplication with broadcasting."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "jtR7M_25y9Vx"
      },
      "source": [
        "# MATRIX-VECTOR\n",
        "N, M = 1000, 2000\n",
        "mat = torch.rand(size=(N, M), device=device)\n",
        "vec = torch.rand(size=(M,), device=device)\n",
        "\n",
        "# warm-up run: keeps one-time GPU setup (e.g. cuBLAS initialization) out of the timing\n",
        "out = mat @ vec\n",
        "\n",
        "t0 = get_time()\n",
        "out = mat @ vec\n",
        "t1 = get_time()\n",
        "print('vectorized: %.4f sec' % (t1 - t0,))\n",
        "\n",
        "# naive reference: out[n] = sum_m mat[n, m] * vec[m], one element at a time\n",
        "t0 = get_time()\n",
        "ref = torch.zeros(size=(N,), device=device)\n",
        "for n in range(N):\n",
        "  for m in range(M):\n",
        "    ref[n] += mat[n, m] * vec[m]\n",
        "t1 = get_time()\n",
        "print('naive:      %.4f sec' % (t1 - t0,))\n",
        "\n",
        "# assert_close replaces the deprecated torch.testing.assert_allclose\n",
        "try:\n",
        "  torch.testing.assert_close(out, ref)\n",
        "except AssertionError:\n",
        "  print('test: failed (out != ref)')\n",
        "else:\n",
        "  print('test: passed (out == ref)')"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "txjotf9Ht1Ox"
      },
      "source": [
        "# MATRIX-MATRIX\n",
        "N, M, K = 100, 200, 50\n",
        "mat1 = torch.rand(size=(N, K), device=device)\n",
        "mat2 = torch.rand(size=(K, M), device=device)\n",
        "\n",
        "t0 = get_time()\n",
        "out = mat1 @ mat2\n",
        "t1 = get_time()\n",
        "print('vectorized: %.4f sec' % (t1 - t0,))\n",
        "\n",
        "# naive reference: out[n, m] = sum_k mat1[n, k] * mat2[k, m], one element at a time\n",
        "t0 = get_time()\n",
        "ref = torch.zeros(size=(N, M), device=device)\n",
        "for n in range(N):\n",
        "  for m in range(M):\n",
        "    for k in range(K):\n",
        "      ref[n, m] += mat1[n, k] * mat2[k, m]\n",
        "t1 = get_time()\n",
        "print('naive:      %.4f sec' % (t1 - t0,))\n",
        "\n",
        "# assert_close replaces the deprecated torch.testing.assert_allclose\n",
        "try:\n",
        "  torch.testing.assert_close(out, ref)\n",
        "except AssertionError:\n",
        "  print('test: failed (out != ref)')\n",
        "else:\n",
        "  print('test: passed (out == ref)')"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "CquUnnTQyO4_"
      },
      "source": [
        "# BATCH MATRIX-MATRIX\n",
        "B, N, M, K = 10, 100, 200, 50\n",
        "bmat1 = torch.rand(size=(B, N, K), device=device)\n",
        "bmat2 = torch.rand(size=(B, K, M), device=device)\n",
        "\n",
        "t0 = get_time()\n",
        "out = bmat1 @ bmat2\n",
        "t1 = get_time()\n",
        "print('vectorized: %.4f sec' % (t1 - t0,))\n",
        "\n",
        "# naive reference: out[b, n, m] = sum_k bmat1[b, n, k] * bmat2[b, k, m]\n",
        "t0 = get_time()\n",
        "ref = torch.zeros(size=(B, N, M), device=device)\n",
        "for b in range(B):\n",
        "  for n in range(N):\n",
        "    for m in range(M):\n",
        "      for k in range(K):\n",
        "        ref[b, n, m] += bmat1[b, n, k] * bmat2[b, k, m]\n",
        "t1 = get_time()\n",
        "print('naive:      %.4f sec' % (t1 - t0,))\n",
        "\n",
        "# assert_close replaces the deprecated torch.testing.assert_allclose\n",
        "try:\n",
        "  torch.testing.assert_close(out, ref)\n",
        "except AssertionError:\n",
        "  print('test: failed (out != ref)')\n",
        "else:\n",
        "  print('test: passed (out == ref)')"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "YSMY8SwluFVh"
      },
      "source": [
        "## Broadcasting"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "ZUvlRgpZ4EDY"
      },
      "source": [
        "# ADD A MATRIX AND A ROW\n",
        "# the 1×M row is broadcast (virtually repeated) along the first dimension\n",
        "N, M = 100, 200\n",
        "\n",
        "a = torch.rand(N, M)       # N×M\n",
        "b = torch.rand(1, M)       # 1×M\n",
        "out = torch.add(a, b)      # N×M: out[n, m] = a[n, m] + b[0, m]\n",
        "print(out.size())"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "hQpnjltsuGtQ"
      },
      "source": [
        "# ADD A COLUMN AND A ROW\n",
        "# both operands are broadcast: the N×1 column along dim 1, the 1×M row along dim 0\n",
        "N, M = 100, 200\n",
        "\n",
        "a = torch.rand(N, 1)       # N×1\n",
        "b = torch.rand(1, M)       # 1×M\n",
        "out = torch.add(a, b)      # N×M: out[n, m] = a[n, 0] + b[0, m]\n",
        "print(out.size())"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "w_0pDWvmGMVe"
      },
      "source": [
        "# MULTIPLY A BATCH OF MATRICES BY A MATRIX\n",
        "# matmul broadcasts the K×M matrix across the batch dimension B\n",
        "B, N, M, K = 10, 100, 200, 50\n",
        "\n",
        "bmat1 = torch.rand(B, N, K)       # B×N×K\n",
        "mat2 = torch.rand(K, M)           # K×M\n",
        "out = torch.matmul(bmat1, mat2)   # B×N×M: out[b] = bmat1[b] @ mat2\n",
        "print(out.size())"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "3nsyGgyhuHBr"
      },
      "source": [
        "## Advanced Tensor Multiplication\n",
        "\n",
        "Sometimes matrix multiplication is not enough, and a more sophisticated summation is needed. In such cases, [`torch.einsum`](https://pytorch.org/docs/stable/generated/torch.einsum.html) might be useful.\n",
        "\n",
        "The first argument is the summation scheme. It is composed of two parts - operands and result - separated by `->`. The subscripts of different operands are separated by `,`. After the summation scheme, the function receives the operands (`*operands`).\n",
        "\n",
        "Each index in the summation scheme is represented by an English letter, and the same letter always denotes the same index. Indices that appear in the operands part but not in the result part are summed over. For example, in `'ij,j->i'` the index `j` is summed over, which yields a matrix-vector product."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "dffPek2A29un"
      },
      "source": [
        "### Matrix-Vector Multiplication\n",
        "\n",
        "$$ \\mathbf{out}[n] = \\sum_m \\mathbf{mat}[n,m] \\cdot \\mathbf{vec}[m]$$"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "2p5Zpl_q2-Lf"
      },
      "source": [
        "# MATRIX-VECTOR\n",
        "N, M = 1000, 2000\n",
        "mat = torch.rand(size=(N, M), device=device)\n",
        "vec = torch.rand(size=(M,), device=device)\n",
        "\n",
        "# 'ij,j->i': j is absent from the result, so it is summed over\n",
        "t0 = get_time()\n",
        "out = torch.einsum('ij,j->i', mat, vec)\n",
        "t1 = get_time()\n",
        "print('einsum:     %.4f sec' % (t1 - t0,))\n",
        "\n",
        "t0 = get_time()\n",
        "ref = mat @ vec\n",
        "t1 = get_time()\n",
        "print('vectorized: %.4f sec' % (t1 - t0,))\n",
        "\n",
        "# assert_close replaces the deprecated torch.testing.assert_allclose\n",
        "try:\n",
        "  torch.testing.assert_close(out, ref)\n",
        "except AssertionError:\n",
        "  print('test: failed (out != ref)')\n",
        "else:\n",
        "  print('test: passed (out == ref)')"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "49ZI4B0S2-c-"
      },
      "source": [
        "### Batch Matrix-Matrix Multiplication\n",
        "\n",
        "$$ \\mathbf{out}[b,n,m] = \\sum_k \\mathbf{bmat1}[b,n,k] \\cdot \\mathbf{bmat2}[b,k,m]$$"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "jedyUz8T2_C0"
      },
      "source": [
        "# BATCH MATRIX-MATRIX\n",
        "B, N, M, K = 10, 100, 200, 50\n",
        "bmat1 = torch.rand(size=(B, N, K), device=device)\n",
        "bmat2 = torch.rand(size=(B, K, M), device=device)\n",
        "\n",
        "# 'bnk,bkm->bnm': b is kept (batch), k is absent from the result and is summed over\n",
        "t0 = get_time()\n",
        "out = torch.einsum('bnk,bkm->bnm', bmat1, bmat2)\n",
        "t1 = get_time()\n",
        "print('einsum:     %.4f sec' % (t1 - t0,))\n",
        "\n",
        "t0 = get_time()\n",
        "ref = bmat1 @ bmat2\n",
        "t1 = get_time()\n",
        "print('vectorized: %.4f sec' % (t1 - t0,))\n",
        "\n",
        "# assert_close replaces the deprecated torch.testing.assert_allclose\n",
        "try:\n",
        "  torch.testing.assert_close(out, ref)\n",
        "except AssertionError:\n",
        "  print('test: failed (out != ref)')\n",
        "else:\n",
        "  print('test: passed (out == ref)')"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "M4-hnAtv2_Se"
      },
      "source": [
        "### Complex Example\n",
        "\n",
        "$$ \\mathbf{out}[i,j,k] = \\sum_\\ell \\sum_m \\mathbf{x}[i,k,m,\\ell,j] \\cdot \\mathbf{y}[i,\\ell,j,k] \\cdot \\mathbf{z}[i] $$"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "U3txfq2SuL3s"
      },
      "source": [
        "# inputs\n",
        "I, J, K, L, M = 5, 10, 20, 30, 40\n",
        "x = torch.rand(size=(I, K, M, L, J), device=device)\n",
        "y = torch.rand(size=(I, L, J, K), device=device)\n",
        "z = torch.rand(size=(I,), device=device)\n",
        "\n",
        "# l and m appear only in the operands, so both are summed over\n",
        "t0 = get_time()\n",
        "out = torch.einsum('ikmlj,iljk,i->ijk', x, y, z)\n",
        "t1 = get_time()\n",
        "print('einsum:     %.4f sec' % (t1 - t0,))\n",
        "\n",
        "# naive reference: the same double sum written out element by element\n",
        "t0 = get_time()\n",
        "ref = torch.zeros(size=(I, J, K), device=device)\n",
        "for i in range(I):\n",
        "  for j in range(J):\n",
        "    for k in range(K):\n",
        "      for l in range(L):\n",
        "        for m in range(M):\n",
        "          ref[i, j, k] += x[i, k, m, l, j] * y[i, l, j, k] * z[i]\n",
        "t1 = get_time()\n",
        "print('naive:      %.4f sec' % (t1 - t0,))\n",
        "\n",
        "# assert_close replaces the deprecated torch.testing.assert_allclose\n",
        "try:\n",
        "  torch.testing.assert_close(out, ref)\n",
        "except AssertionError:\n",
        "  print('test: failed (out != ref)')\n",
        "else:\n",
        "  print('test: passed (out == ref)')"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "4z811MWBuMZ0"
      },
      "source": [
        "## Sampling by Index"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "_ct204E58HWN"
      },
      "source": [
        "### Gather"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Yg6-9OVVuPZz"
      },
      "source": [
        "dim = 0\n",
        "num_samples = 1\n",
        "src = torch.rand(size=(10, 100, 1000), device=device)\n",
        "index = torch.randint(low=0, high=src.size(dim), size=(num_samples, src.size(1), src.size(2)), device=device)\n",
        "\n",
        "# gather along dim 0: out[i, j, k] = src[index[i, j, k], j, k]\n",
        "t0 = get_time()\n",
        "out = torch.gather(src, dim, index)\n",
        "t1 = get_time()\n",
        "print('vectorized: %.4f sec' % (t1 - t0,))\n",
        "\n",
        "t0 = get_time()\n",
        "ref = torch.zeros(size=index.size(), device=device)\n",
        "for i in range(index.size(0)):\n",
        "  for j in range(index.size(1)):\n",
        "    for k in range(index.size(2)):\n",
        "      ref[i, j, k] = src[index[i, j, k], j, k]\n",
        "t1 = get_time()\n",
        "print('naive:      %.4f sec' % (t1 - t0,))\n",
        "\n",
        "# assert_close replaces the deprecated torch.testing.assert_allclose\n",
        "try:\n",
        "  torch.testing.assert_close(out, ref)\n",
        "except AssertionError:\n",
        "  print('test: failed (out != ref)')\n",
        "else:\n",
        "  print('test: passed (out == ref)')"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "OM9gz7m-1s41"
      },
      "source": [
        "dim = 1\n",
        "num_samples = 50\n",
        "src = torch.rand(size=(10, 100, 1000), device=device)\n",
        "index = torch.randint(low=0, high=src.size(dim), size=(src.size(0), num_samples, src.size(2)), device=device)\n",
        "\n",
        "# gather along dim 1: out[i, j, k] = src[i, index[i, j, k], k]\n",
        "t0 = get_time()\n",
        "out = torch.gather(src, dim, index)\n",
        "t1 = get_time()\n",
        "print('vectorized: %.4f sec' % (t1 - t0,))\n",
        "\n",
        "t0 = get_time()\n",
        "ref = torch.zeros(size=index.size(), device=device)\n",
        "for i in range(index.size(0)):\n",
        "  for j in range(index.size(1)):\n",
        "    for k in range(index.size(2)):\n",
        "      ref[i, j, k] = src[i, index[i, j, k], k]\n",
        "t1 = get_time()\n",
        "print('naive:      %.4f sec' % (t1 - t0,))\n",
        "\n",
        "# assert_close replaces the deprecated torch.testing.assert_allclose\n",
        "try:\n",
        "  torch.testing.assert_close(out, ref)\n",
        "except AssertionError:\n",
        "  print('test: failed (out != ref)')\n",
        "else:\n",
        "  print('test: passed (out == ref)')"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "idSSZ3om8gM2"
      },
      "source": [
        "### Scatter"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "8QjK-fIp8gnt"
      },
      "source": [
        "dim = 0\n",
        "size = 10\n",
        "src = torch.rand(size=(1, 100, 1000), device=device)\n",
        "index = torch.randint(low=0, high=size, size=src.size(), device=device)\n",
        "\n",
        "# scatter_add_ along dim 0: out[index[i, j, k], j, k] += src[i, j, k]\n",
        "t0 = get_time()\n",
        "# allocate on the same device as src/index: scatter_add_ requires all three to share a device\n",
        "out = torch.zeros(size=(size, src.size(1), src.size(2)), device=device)\n",
        "out.scatter_add_(dim, index, src)\n",
        "t1 = get_time()\n",
        "print('vectorized: %.4f sec' % (t1 - t0,))\n",
        "\n",
        "t0 = get_time()\n",
        "ref = torch.zeros(size=(size, src.size(1), src.size(2)), device=device)\n",
        "for i in range(index.size(0)):\n",
        "  for j in range(index.size(1)):\n",
        "    for k in range(index.size(2)):\n",
        "      ref[index[i, j, k], j, k] += src[i, j, k]\n",
        "t1 = get_time()\n",
        "print('naive:      %.4f sec' % (t1 - t0,))\n",
        "\n",
        "# assert_close replaces the deprecated torch.testing.assert_allclose\n",
        "try:\n",
        "  torch.testing.assert_close(out, ref)\n",
        "except AssertionError:\n",
        "  print('test: failed (out != ref)')\n",
        "else:\n",
        "  print('test: passed (out == ref)')"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "L3gaRGSl8g_I"
      },
      "source": [
        "dim = 1\n",
        "size = 100\n",
        "src = torch.rand(size=(10, 50, 1000), device=device)\n",
        "index = torch.randint(low=0, high=size, size=src.size(), device=device)\n",
        "\n",
        "# scatter_add_ along dim 1: out[i, index[i, j, k], k] += src[i, j, k]\n",
        "t0 = get_time()\n",
        "# allocate on the same device as src/index: scatter_add_ requires all three to share a device\n",
        "out = torch.zeros(size=(src.size(0), size, src.size(2)), device=device)\n",
        "out.scatter_add_(dim, index, src)\n",
        "t1 = get_time()\n",
        "print('vectorized: %.4f sec' % (t1 - t0,))\n",
        "\n",
        "t0 = get_time()\n",
        "ref = torch.zeros(size=(src.size(0), size, src.size(2)), device=device)\n",
        "for i in range(index.size(0)):\n",
        "  for j in range(index.size(1)):\n",
        "    for k in range(index.size(2)):\n",
        "      ref[i, index[i, j, k], k] += src[i, j, k]\n",
        "t1 = get_time()\n",
        "print('naive:      %.4f sec' % (t1 - t0,))\n",
        "\n",
        "# assert_close replaces the deprecated torch.testing.assert_allclose\n",
        "try:\n",
        "  torch.testing.assert_close(out, ref)\n",
        "except AssertionError:\n",
        "  print('test: failed (out != ref)')\n",
        "else:\n",
        "  print('test: passed (out == ref)')"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "yzizlEts-jji"
      },
      "source": [
        "### Index Select"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "_1TzLkro-kA1"
      },
      "source": [
        "dim = 0\n",
        "num_samples = 50\n",
        "src = torch.rand(size=(10, 100, 1000), device=device)\n",
        "index = torch.randint(low=0, high=src.size(dim), size=(num_samples,), device=device)\n",
        "\n",
        "# index_select along dim 0: out[i, j, k] = src[index[i], j, k]\n",
        "t0 = get_time()\n",
        "out = torch.index_select(src, dim, index)\n",
        "t1 = get_time()\n",
        "print('vectorized: %.4f sec' % (t1 - t0,))\n",
        "\n",
        "t0 = get_time()\n",
        "ref = torch.zeros(size=(index.size(0), src.size(1), src.size(2)), device=device)\n",
        "for i in range(ref.size(0)):\n",
        "  for j in range(ref.size(1)):\n",
        "    for k in range(ref.size(2)):\n",
        "      ref[i, j, k] = src[index[i], j, k]\n",
        "t1 = get_time()\n",
        "print('naive:      %.4f sec' % (t1 - t0,))\n",
        "\n",
        "# assert_close replaces the deprecated torch.testing.assert_allclose\n",
        "try:\n",
        "  torch.testing.assert_close(out, ref)\n",
        "except AssertionError:\n",
        "  print('test: failed (out != ref)')\n",
        "else:\n",
        "  print('test: passed (out == ref)')"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "23zG2gOb_CAh"
      },
      "source": [
        "dim = 1\n",
        "num_samples = 50\n",
        "src = torch.rand(size=(10, 100, 1000), device=device)\n",
        "index = torch.randint(low=0, high=src.size(dim), size=(num_samples,), device=device)\n",
        "\n",
        "# index_select along dim 1: out[i, j, k] = src[i, index[j], k]\n",
        "t0 = get_time()\n",
        "out = torch.index_select(src, dim, index)\n",
        "t1 = get_time()\n",
        "print('vectorized: %.4f sec' % (t1 - t0,))\n",
        "\n",
        "t0 = get_time()\n",
        "ref = torch.zeros(size=(src.size(0), index.size(0), src.size(2)), device=device)\n",
        "for i in range(ref.size(0)):\n",
        "  for j in range(ref.size(1)):\n",
        "    for k in range(ref.size(2)):\n",
        "      ref[i, j, k] = src[i, index[j], k]\n",
        "t1 = get_time()\n",
        "print('naive:      %.4f sec' % (t1 - t0,))\n",
        "\n",
        "# assert_close replaces the deprecated torch.testing.assert_allclose\n",
        "try:\n",
        "  torch.testing.assert_close(out, ref)\n",
        "except AssertionError:\n",
        "  print('test: failed (out != ref)')\n",
        "else:\n",
        "  print('test: passed (out == ref)')"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "3_iQP62A-kTd"
      },
      "source": [
        "### Index Add"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "RnYBCxGZ-krb"
      },
      "source": [
        "dim = 0\n",
        "size = 10\n",
        "src = torch.rand(size=(50, 100, 1000), device=device)\n",
        "index = torch.randint(low=0, high=size, size=(src.size(dim),), device=device)\n",
        "\n",
        "# index_add_ along dim 0: out[index[i], j, k] += src[i, j, k]\n",
        "t0 = get_time()\n",
        "# allocate on the same device as src/index: index_add_ requires all three to share a device\n",
        "out = torch.zeros(size=(size, src.size(1), src.size(2)), device=device)\n",
        "out.index_add_(dim, index, src)\n",
        "t1 = get_time()\n",
        "print('vectorized: %.4f sec' % (t1 - t0,))\n",
        "\n",
        "t0 = get_time()\n",
        "ref = torch.zeros(size=(size, src.size(1), src.size(2)), device=device)\n",
        "for i in range(src.size(0)):\n",
        "  for j in range(src.size(1)):\n",
        "    for k in range(src.size(2)):\n",
        "      ref[index[i], j, k] += src[i, j, k]\n",
        "t1 = get_time()\n",
        "print('naive:      %.4f sec' % (t1 - t0,))\n",
        "\n",
        "# assert_close replaces the deprecated torch.testing.assert_allclose\n",
        "try:\n",
        "  torch.testing.assert_close(out, ref)\n",
        "except AssertionError:\n",
        "  print('test: failed (out != ref)')\n",
        "else:\n",
        "  print('test: passed (out == ref)')"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "n5LXkL4S_Q1B"
      },
      "source": [
        "dim = 1\n",
        "size = 100\n",
        "src = torch.rand(size=(10, 50, 1000), device=device)\n",
        "index = torch.randint(low=0, high=size, size=(src.size(dim),), device=device)\n",
        "\n",
        "# index_add_ along dim 1: out[i, index[j], k] += src[i, j, k]\n",
        "t0 = get_time()\n",
        "# allocate on the same device as src/index: index_add_ requires all three to share a device\n",
        "out = torch.zeros(size=(src.size(0), size, src.size(2)), device=device)\n",
        "out.index_add_(dim, index, src)\n",
        "t1 = get_time()\n",
        "print('vectorized: %.4f sec' % (t1 - t0,))\n",
        "\n",
        "t0 = get_time()\n",
        "ref = torch.zeros(size=(src.size(0), size, src.size(2)), device=device)\n",
        "for i in range(src.size(0)):\n",
        "  for j in range(src.size(1)):\n",
        "    for k in range(src.size(2)):\n",
        "      ref[i, index[j], k] += src[i, j, k]\n",
        "t1 = get_time()\n",
        "print('naive:      %.4f sec' % (t1 - t0,))\n",
        "\n",
        "# assert_close replaces the deprecated torch.testing.assert_allclose\n",
        "try:\n",
        "  torch.testing.assert_close(out, ref)\n",
        "except AssertionError:\n",
        "  print('test: failed (out != ref)')\n",
        "else:\n",
        "  print('test: passed (out == ref)')"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "OFMmWF_it3FC"
      },
      "source": [
        "# Convolution-like Operations"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ODmArmV_uZAX"
      },
      "source": [
        "## `Conv2d` and `MaxPool2d`"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "5S8ty3syt5JF"
      },
      "source": [
        "in_channels, out_channels = 3, 16\n",
        "# output size per spatial dimension: floor((in + 2 * padding - kernel_size) / stride) + 1\n",
        "\n",
        "# plain 3×3 convolution (stride 1, no padding): shrinks H and W by 2\n",
        "conv1 = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3)\n",
        "\n",
        "# stride 2: roughly halves H and W\n",
        "conv2 = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=2)\n",
        "\n",
        "# padding 1 with a 3×3 kernel: keeps H and W unchanged\n",
        "conv3 = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, padding=1)\n",
        "\n",
        "# separate (H, W) settings: 3×1 kernel, stride 2 and padding 1 along H only\n",
        "conv4 = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=(3, 1), stride=(2, 1), padding=(1, 0))\n",
        "\n",
        "# 2×2 max pooling with stride 2: halves H and W (rounding down)\n",
        "Pool1 = nn.MaxPool2d(kernel_size=2, stride=2)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "BCCm1KvbueCK"
      },
      "source": [
        "## `F.unfold`"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "jalehbFnwwHA"
      },
      "source": [
        "image = torch.rand(2, 3, 3, 4)  # (batch, channels, H, W)\n",
        "\n",
        "# extract every 2×2 block (stride 1)\n",
        "# shape: (batch, channels * kernel_h * kernel_w, num_blocks)\n",
        "blocks = F.unfold(image, kernel_size=2)\n",
        "print(blocks.size())  # (2, 3 * 2 * 2, 2 * 3): 2 block rows × 3 block columns\n",
        "\n",
        "# split the flattened block dimension into (channels, kernel_h * kernel_w);\n",
        "# this view relies on the block dimension being laid out channel-major\n",
        "blocks = blocks.view(image.size(0), image.size(1), -1, blocks.size(2))\n",
        "print(blocks.size())  # (2, 3, 2 * 2, 2 * 3)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "mOlZT_1Jwt3y"
      },
      "source": [
        "## `F.fold`"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "qpxz0vaJuhrl"
      },
      "source": [
        "image = torch.ones(size=(1, 1, 3, 4))\n",
        "output_size = image.shape[2:]  # (H, W) = (3, 4)\n",
        "\n",
        "# extract blocks\n",
        "# shape: (batch, block_size, num_blocks)\n",
        "blocks = F.unfold(image, kernel_size=2)\n",
        "print(blocks.size())  # blocks.shape == (1, 1 * 2 * 2, 2 * 3)\n",
        "\n",
        "# fold back into an image: overlapping block entries are SUMMED, so fold is not\n",
        "# the inverse of unfold. For this all-ones image each output pixel equals the\n",
        "# number of 2×2 blocks covering it:\n",
        "#   [[1, 2, 2, 1],\n",
        "#    [2, 4, 4, 2],\n",
        "#    [1, 2, 2, 1]]\n",
        "# shape: (batch, channels, *output_size)\n",
        "output = F.fold(blocks, output_size, kernel_size=2)\n",
        "print(output.size())  # output.shape == (1, 1, 3, 4)\n",
        "print(output)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "D4u8JyBpt53p"
      },
      "source": [
        "# Data Loading and Handling"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "o1setzq_BpC4"
      },
      "source": [
        "import io\n",
        "\n",
        "from PIL import Image\n",
        "import requests\n",
        "\n",
        "def load_image_from_url(url, timeout=10):\n",
        "  \"\"\"Download the image at `url` and return it as a PIL image.\n",
        "\n",
        "  Raises requests.HTTPError on an error status (instead of handing an error\n",
        "  page to PIL) and requests.Timeout if the server does not answer within\n",
        "  `timeout` seconds.\n",
        "  \"\"\"\n",
        "  response = requests.get(url, timeout=timeout)\n",
        "  response.raise_for_status()\n",
        "  # decode from an in-memory copy of the body so no open connection is left behind\n",
        "  return Image.open(io.BytesIO(response.content))"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Exb1pUKdt7VV"
      },
      "source": [
        "class MyDataset(Dataset):\n",
        "  \"\"\"Map-style dataset that downloads one image per URL on access.\n",
        "\n",
        "  Args:\n",
        "    img_urls: sequence of image URLs; item `index` is fetched from img_urls[index].\n",
        "    transform: callable applied to the RGB PIL image (e.g. a torchvision transform).\n",
        "  \"\"\"\n",
        "\n",
        "  def __init__(self, img_urls, transform):\n",
        "    self.img_urls = img_urls\n",
        "    self.transform = transform\n",
        "\n",
        "  def __len__(self):             # == len(dataset)\n",
        "    return len(self.img_urls)\n",
        "\n",
        "  def __getitem__(self, index):  # == dataset[index]\n",
        "    img = load_image_from_url(self.img_urls[index])\n",
        "    # images may be RGBA, palette or grayscale; normalize to 3 channels so all\n",
        "    # samples share one shape and the DataLoader can stack them into a batch\n",
        "    img = img.convert('RGB')\n",
        "    img = self.transform(img)\n",
        "    return {\"img\": img}"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "1EwutDVzBZAt"
      },
      "source": [
        "transform = transforms.Compose([\n",
        "  transforms.CenterCrop(256),  # central 256×256 crop (zero-padded if the image is smaller)\n",
        "  transforms.ToTensor()        # PIL image -> float tensor of shape (C, H, W)\n",
        "])\n",
        "img_urls = [\n",
        "  r\"https://3.bp.blogspot.com/-W__wiaHUjwI/Vt3Grd8df0I/AAAAAAAAA78/7xqUNj8ujtY/s1600/image02.png\",\n",
        "  r\"https://3.bp.blogspot.com/-W__wiaHUjwI/Vt3Grd8df0I/AAAAAAAAA78/7xqUNj8ujtY/s1600/image02.png\",\n",
        "  r\"https://3.bp.blogspot.com/-W__wiaHUjwI/Vt3Grd8df0I/AAAAAAAAA78/7xqUNj8ujtY/s1600/image02.png\",\n",
        "  r\"https://3.bp.blogspot.com/-W__wiaHUjwI/Vt3Grd8df0I/AAAAAAAAA78/7xqUNj8ujtY/s1600/image02.png\",\n",
        "  r\"https://3.bp.blogspot.com/-W__wiaHUjwI/Vt3Grd8df0I/AAAAAAAAA78/7xqUNj8ujtY/s1600/image02.png\",\n",
        "]\n",
        "dataset = MyDataset(img_urls=img_urls, transform=transform)\n",
        "\n",
        "# 5 samples with batch_size=2 -> batches of 2, 2 and a final partial batch of 1\n",
        "dataloader = DataLoader(dataset, batch_size=2)\n",
        "\n",
        "for batch in dataloader:\n",
        "  img = batch[\"img\"]  # default collation stacks the per-sample \"img\" tensors\n",
        "  img = img.to(device=device)\n",
        "  # shape (assuming RGB images): (2, 3, 256, 256); the last batch is (1, 3, 256, 256)\n",
        "  print(type(img), img.dtype, img.size())\n",
        "  # do something"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "4pnsNXkLt8Ac"
      },
      "source": [
        "# Intermediate Results and Hooks"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "KshScWSEE_It"
      },
      "source": [
        "class MyNet(nn.Module):\n",
        "  \"\"\"Two stacked 3×3 convolutions with a ReLU in between.\n",
        "\n",
        "  Channels: 2 -> 4 (conv1) -> 8 (conv2). Without padding each convolution\n",
        "  trims 2 pixels from H and W, e.g. a 12×12 input yields an 8×8 output.\n",
        "  \"\"\"\n",
        "\n",
        "  def __init__(self):\n",
        "    super().__init__()\n",
        "    self.conv1 = nn.Conv2d(2, 4, kernel_size=3)\n",
        "    self.conv2 = nn.Conv2d(4, 8, kernel_size=3)\n",
        "\n",
        "  def forward(self, x):\n",
        "    hidden = F.relu(self.conv1(x))\n",
        "    return self.conv2(hidden)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Pqvu6vwkFXhy"
      },
      "source": [
        "# instantiate the network on the selected device\n",
        "net = MyNet().to(device)\n",
        "# a batch of 3 random 2-channel 12×12 inputs, created directly on that device\n",
        "x = torch.randn(3, 2, 12, 12, device=device)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "vVoLMuDRuAcp"
      },
      "source": [
        "def my_hook(module, input, output):\n",
        "  # module: the module being hooked (here: net.conv1)\n",
        "  # input:  a tuple of the positional arguments passed to module.forward\n",
        "  # output: whatever module.forward returned - for nn.Conv2d a single Tensor,\n",
        "  #         so output[0] would be the first sample of the batch, not the output\n",
        "  print(\"hook!\", input[0].shape, output.shape)\n",
        "\n",
        "# register the hook\n",
        "handle = net.conv1.register_forward_hook(my_hook)\n",
        "\n",
        "# use the hook\n",
        "print('with hook')\n",
        "y = net(x)\n",
        "# printed (for x of shape (3, 2, 12, 12)): \"hook! torch.Size([3, 2, 12, 12]) torch.Size([3, 4, 10, 10])\"\n",
        "\n",
        "# remove the hook\n",
        "handle.remove()\n",
        "\n",
        "# after the hook is removed\n",
        "print('without hook')\n",
        "y = net(x)\n",
        "# nothing is printed"
      ],
      "execution_count": null,
      "outputs": []
    }
  ]
}