{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Homework 10 - Spline Regression" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In this homework, you will implement a general spline regression.\n", "Recall that a spline of degree $d$ is a piecewise polynomial of degree $d$ with continuity of derivatives of orders $0, 1, \\ldots, d-1$.\n", "\n", "In the case of a cubic spline ($d = 3$) with $K$ knots, the regression function can be expressed as\n", "\n", "$$ f(x_i) = \\beta_0 + \\beta_1 \\, b_1(x_i) + \\ldots + \\beta_{K+3} \\, b_{K+3}(x_i) $$\n", "\n", "using appropriate basis functions $b_1, \\ldots, b_{K+3}$.\n", "\n", "From Slide 352, you know that we can start off with the monomials $x, x^2, x^3$ and then add, for each knot $\\xi$, one **truncated monomial**\n", "\n", "$$ h(x,\\xi) = (x-\\xi)_+^3 = \\begin{cases} (x - \\xi)^3 & \\text{if } x > \\xi, \\\\ 0 & \\text{otherwise.} \\end{cases} $$\n", "\n", "\n", "A one-dimensional example is provided below." ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "scrolled": true }, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import numpy as np\n", "import matplotlib.pyplot as plt\n", "\n", "def y_func(x):\n", " return np.sin(8*x) / np.exp((5*x))\n", "n = 100\n", "eps = 0.1\n", "np.random.seed(1)\n", "\n", "x = np.random.rand(n)\n", "y = y_func(x) + eps * np.random.randn(n)\n", "plt.scatter(x, y, label='Samples')\n", "xlin = np.linspace(0,1,100)\n", "plt.plot(xlin, y_func(xlin),'r--', label='Population line')\n", "plt.legend()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Task**: Write a function that implements spline regression of degree $d$, and use it to estimate the function `y_func` using **cubic** spline regression with $K$ equidistant knots (use $K=4$ in your tests)." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Solution**:" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "def generateSplineRegressionMatrix(x, xi, d = 3):\n", "\n", " # Get number of samples\n", " n = len(x)\n", " \n", " # Get number of knots\n", " K = len(xi)\n", " \n", " # Initialize matrix with predictor variables 1\n", " X = np.ones((n,1))\n", "\n", " # Append columns for the monomials x^1, ..., x^d\n", " for i in range(d):\n", " X = np.hstack((X,np.power(x,i+1).reshape(n,1)))\n", " \n", " # Append one column for each knot using the \n", " # truncated monomial (x - xi_k)_+^d\n", " for k in range(K):\n", " h = np.power(np.maximum(x-xi[k],0), d)\n", " X = np.hstack((X,h.reshape(n,1)))\n", " return X" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[0.2 0.4 0.6 0.8]\n" ] } ], "source": [ "K = 4\n", "xi = np.linspace(0,1,K+2)\n", "xi = xi[1:-1]\n", "print(xi)" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "d = 3\n", "X = generateSplineRegressionMatrix(x, xi, d)\n", "beta = np.linalg.solve(X.T @ X, X.T @ y)\n", "plt.scatter(x, y, label = 'Samples')\n", "Xspline = generateSplineRegressionMatrix(xlin, xi, d)\n", "plt.plot(xlin, y_func(xlin), 'r--', label='Population line')\n", "plt.plot(xlin, Xspline @ beta, 'g-', label='Regression line')\n", "plt.legend();" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Hint**: If you have no idea how to get started, the following code cells might help you.\n", "They perform a **global cubic regression** and can be adapted to perform a **cubic spline regression**." ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "def generatePolynomialRegressionMatrix(x, d = 3):\n", " \n", " # Get number of samples\n", " n = len(x)\n", " \n", " # Initialize matrix with predictor variables 1 (for the intercept)\n", " X = np.ones((n,1))\n", "\n", " # Append columns for the monomials x^1, ..., x^d\n", " for i in range(d):\n", " X = np.hstack((X,np.power(x,i+1).reshape(n,1)))\n", " \n", " return X" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "# Set degree of spline regression\n", "d = 3\n", "\n", "# Generate Spline matrix\n", "X = generatePolynomialRegressionMatrix(x, d)\n", "\n", "# Solve the normal equation, remember the @ sign\n", "# performs the ordinary matrix multiplication for numpy arrays.\n", "# You should also avoid to invert the matrix X^T * X!\n", "\n", "beta = np.linalg.solve(X.T @ X, X.T @ y)" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "plt.scatter(x, y, label='Samples')\n", "\n", "Xreg = generatePolynomialRegressionMatrix(xlin,d)\n", "\n", "plt.plot(xlin, Xreg @ beta, 'g-', label = 'Regression line')\n", "plt.plot(xlin, y_func(xlin), 'r--', label='Population line')\n", "plt.legend()" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([ 3.82794677e-02, 8.94074990e+00, -5.46596826e+01, 8.98472825e+01,\n", " -7.81583544e+01, -2.18941451e+01, 9.11085904e+00, 1.75846110e+00])" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "beta" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.6" } }, "nbformat": 4, "nbformat_minor": 2 }