{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Problem sheet 3\n", "The previous exercises gave an introduction to Python, NumPy and pandas. Beginning with this exercise, we shift our focus to statistical learning itself. To this end, we will use the module scikit-learn, which offers many functions that we will cover during the remainder of the semester.\n", "\n", "If you have not already done so, please download the file [Advertising.csv](https://www.tu-chemnitz.de/mathematik/numa/lehre/ds-2018/exercises/Advertising.csv) and move it into a subfolder called `datasets`." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Exercise 1:\n", "We start this exercise with the Advertising dataset known from the lecture." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We read the dataset using pandas:" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "scrolled": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " TV radio newspaper sales\n", "1 230.1 37.8 69.2 22.1\n", "2 44.5 39.3 45.1 10.4\n", "3 17.2 45.9 69.3 9.3\n" ] } ], "source": [ "import pandas as pd\n", "import numpy as np\n", "%matplotlib inline\n", "\n", "adv = pd.read_csv('./datasets/Advertising.csv', index_col=0)\n", "\n", "# Print first entries of adv\n", "print(adv.head(3))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "For convenience, we extract the values from this pandas DataFrame:" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "X = adv.values[:,0:3]\n", "tv, radio, newspaper = np.hsplit(X,3)\n", "Y = adv.values[:,3]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Part (a)\n", "For each of the three predictor variables **TV**, **radio** and **newspaper**, compute a simple (one-dimensional) linear regression, e.g.\n", "\n", "$$ y^{TV}_i \\approx \\beta_0^{TV} + \\beta_1^{TV} \\, x_i^{TV}$$\n", "\n", "Use the following class:\n", "\n", "    from sklearn.linear_model 
import LinearRegression\n", " \n", "You can use a command similar to\n", "\n", "    print('y = %5.4f + %5.4f x TV' % (intercept, lincoef))\n", " \n", "to print your results neatly." ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "y = 7.0326 + 0.0475 x TV\n" ] } ], "source": [ "from sklearn.linear_model import LinearRegression\n", "\n", "reg_tv = LinearRegression().fit(tv, Y)\n", "print('y = %5.4f + %5.4f x TV' % (reg_tv.intercept_, reg_tv.coef_[0]))" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "scrolled": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "y = 9.3116 + 0.2025 x radio\n" ] } ], "source": [ "reg_radio = LinearRegression().fit(radio, Y)\n", "print('y = %5.4f + %5.4f x radio' % (reg_radio.intercept_, reg_radio.coef_[0]))" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "y = 12.3514 + 0.0547 x newspaper\n" ] } ], "source": [ "reg_newspaper = LinearRegression().fit(newspaper, Y)\n", "print('y = %5.4f + %5.4f x newspaper' % (reg_newspaper.intercept_, reg_newspaper.coef_[0]))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "You should observe that the regression coefficients for **TV** and **newspaper** are very similar.\n", "As you already know from the lecture, it is mathematically unsatisfying to judge the predictors solely by the magnitude of their coefficients.\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Part (b)\n", "\n", "In the lecture, you learned about different measures for assessing the quality of a linear fit.\n", "In the last exercise, we already implemented a function to compute the mean squared error (MSE).\n", "\n", "This time, we want to compare the $R^2$ scores. 
You can use the method `score()` of a fitted `LinearRegression` object to obtain the $R^2$ values.\n", "Remember that this value is the proportion of variability in $Y$ explained using **TV**, **radio** or **newspaper** as the predictor in a one-dimensional linear regression fit." ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "R^2 for TV: 0.611875050850071\n", "R^2 for radio: 0.33203245544529525\n", "R^2 for newspaper: 0.05212044544430516\n" ] } ], "source": [ "print(\"R^2 for TV: \", reg_tv.score(tv,Y))\n", "print(\"R^2 for radio: \", reg_radio.score(radio,Y))\n", "print(\"R^2 for newspaper: \", reg_newspaper.score(newspaper,Y))\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Part (c)\n", "Now we want to compute the predicted sales when the prediction is restricted to a single input, i.e. **TV**, **radio** or **newspaper**, respectively.\n", "Predict the values $\\hat{y}^{TV}$, $\\hat{y}^{radio}$ and $\\hat{y}^{newspaper}$ using the method `predict()`." ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "y_tv = reg_tv.predict(tv)\n", "y_radio = reg_radio.predict(radio)\n", "y_newspaper = reg_newspaper.predict(newspaper)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Part (d)\n", "\n", "Plot the data points as well as the corresponding regression line for each of the inputs **TV**, **radio** and **newspaper**.\n", "\n", "You can use the functions `subplots` or `fig.add_subplot` to arrange the plots in one figure." 
] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "data": { "application/javascript": [ "/* Put everything inside the global mpl namespace */\n", "window.mpl = {};\n", "\n", "\n", "mpl.get_websocket_type = function() {\n", " if (typeof(WebSocket) !== 'undefined') {\n", " return WebSocket;\n", " } else if (typeof(MozWebSocket) !== 'undefined') {\n", " return MozWebSocket;\n", " } else {\n", " alert('Your browser does not have WebSocket support.' +\n", " 'Please try Chrome, Safari or Firefox ≥ 6. ' +\n", " 'Firefox 4 and 5 are also supported but you ' +\n", " 'have to enable WebSockets in about:config.');\n", " };\n", "}\n", "\n", "mpl.figure = function(figure_id, websocket, ondownload, parent_element) {\n", " this.id = figure_id;\n", "\n", " this.ws = websocket;\n", "\n", " this.supports_binary = (this.ws.binaryType != undefined);\n", "\n", " if (!this.supports_binary) {\n", " var warnings = document.getElementById(\"mpl-warnings\");\n", " if (warnings) {\n", " warnings.style.display = 'block';\n", " warnings.textContent = (\n", " \"This browser does not support binary websocket messages. \" +\n", " \"Performance may be slow.\");\n", " }\n", " }\n", "\n", " this.imageObj = new Image();\n", "\n", " this.context = undefined;\n", " this.message = undefined;\n", " this.canvas = undefined;\n", " this.rubberband_canvas = undefined;\n", " this.rubberband_context = undefined;\n", " this.format_dropdown = undefined;\n", "\n", " this.image_mode = 'full';\n", "\n", " this.root = $('
');\n", " this._root_extra_style(this.root)\n", " this.root.attr('style', 'display: inline-block');\n", "\n", " $(parent_element).append(this.root);\n", "\n", " this._init_header(this);\n", " this._init_canvas(this);\n", " this._init_toolbar(this);\n", "\n", " var fig = this;\n", "\n", " this.waiting = false;\n", "\n", " this.ws.onopen = function () {\n", " fig.send_message(\"supports_binary\", {value: fig.supports_binary});\n", " fig.send_message(\"send_image_mode\", {});\n", " if (mpl.ratio != 1) {\n", " fig.send_message(\"set_dpi_ratio\", {'dpi_ratio': mpl.ratio});\n", " }\n", " fig.send_message(\"refresh\", {});\n", " }\n", "\n", " this.imageObj.onload = function() {\n", " if (fig.image_mode == 'full') {\n", " // Full images could contain transparency (where diff images\n", " // almost always do), so we need to clear the canvas so that\n", " // there is no ghosting.\n", " fig.context.clearRect(0, 0, fig.canvas.width, fig.canvas.height);\n", " }\n", " fig.context.drawImage(fig.imageObj, 0, 0);\n", " };\n", "\n", " this.imageObj.onunload = function() {\n", " fig.ws.close();\n", " }\n", "\n", " this.ws.onmessage = this._make_on_message_function(this);\n", "\n", " this.ondownload = ondownload;\n", "}\n", "\n", "mpl.figure.prototype._init_header = function() {\n", " var titlebar = $(\n", " '');\n", " var titletext = $(\n", " '');\n", " titlebar.append(titletext)\n", " this.root.append(titlebar);\n", " this.header = titletext[0];\n", "}\n", "\n", "\n", "\n", "mpl.figure.prototype._canvas_extra_style = function(canvas_div) {\n", "\n", "}\n", "\n", "\n", "mpl.figure.prototype._root_extra_style = function(canvas_div) {\n", "\n", "}\n", "\n", "mpl.figure.prototype._init_canvas = function() {\n", " var fig = this;\n", "\n", " var canvas_div = $('');\n", "\n", " canvas_div.attr('style', 'position: relative; clear: both; outline: 0');\n", "\n", " function canvas_keyboard_event(event) {\n", " return fig.key_event(event, event['data']);\n", " }\n", "\n", " 
canvas_div.keydown('key_press', canvas_keyboard_event);\n", " canvas_div.keyup('key_release', canvas_keyboard_event);\n", " this.canvas_div = canvas_div\n", " this._canvas_extra_style(canvas_div)\n", " this.root.append(canvas_div);\n", "\n", " var canvas = $('');\n", " canvas.addClass('mpl-canvas');\n", " canvas.attr('style', \"left: 0; top: 0; z-index: 0; outline: 0\")\n", "\n", " this.canvas = canvas[0];\n", " this.context = canvas[0].getContext(\"2d\");\n", "\n", " var backingStore = this.context.backingStorePixelRatio ||\n", "\tthis.context.webkitBackingStorePixelRatio ||\n", "\tthis.context.mozBackingStorePixelRatio ||\n", "\tthis.context.msBackingStorePixelRatio ||\n", "\tthis.context.oBackingStorePixelRatio ||\n", "\tthis.context.backingStorePixelRatio || 1;\n", "\n", " mpl.ratio = (window.devicePixelRatio || 1) / backingStore;\n", "\n", " var rubberband = $('');\n", " rubberband.attr('style', \"position: absolute; left: 0; top: 0; z-index: 1;\")\n", "\n", " var pass_mouse_events = true;\n", "\n", " canvas_div.resizable({\n", " start: function(event, ui) {\n", " pass_mouse_events = false;\n", " },\n", " resize: function(event, ui) {\n", " fig.request_resize(ui.size.width, ui.size.height);\n", " },\n", " stop: function(event, ui) {\n", " pass_mouse_events = true;\n", " fig.request_resize(ui.size.width, ui.size.height);\n", " },\n", " });\n", "\n", " function mouse_event_fn(event) {\n", " if (pass_mouse_events)\n", " return fig.mouse_event(event, event['data']);\n", " }\n", "\n", " rubberband.mousedown('button_press', mouse_event_fn);\n", " rubberband.mouseup('button_release', mouse_event_fn);\n", " // Throttle sequential mouse events to 1 every 20ms.\n", " rubberband.mousemove('motion_notify', mouse_event_fn);\n", "\n", " rubberband.mouseenter('figure_enter', mouse_event_fn);\n", " rubberband.mouseleave('figure_leave', mouse_event_fn);\n", "\n", " canvas_div.on(\"wheel\", function (event) {\n", " event = event.originalEvent;\n", " event['data'] = 'scroll'\n", 
" if (event.deltaY < 0) {\n", " event.step = 1;\n", " } else {\n", " event.step = -1;\n", " }\n", " mouse_event_fn(event);\n", " });\n", "\n", " canvas_div.append(canvas);\n", " canvas_div.append(rubberband);\n", "\n", " this.rubberband = rubberband;\n", " this.rubberband_canvas = rubberband[0];\n", " this.rubberband_context = rubberband[0].getContext(\"2d\");\n", " this.rubberband_context.strokeStyle = \"#000000\";\n", "\n", " this._resize_canvas = function(width, height) {\n", " // Keep the size of the canvas, canvas container, and rubber band\n", " // canvas in synch.\n", " canvas_div.css('width', width)\n", " canvas_div.css('height', height)\n", "\n", " canvas.attr('width', width * mpl.ratio);\n", " canvas.attr('height', height * mpl.ratio);\n", " canvas.attr('style', 'width: ' + width + 'px; height: ' + height + 'px;');\n", "\n", " rubberband.attr('width', width);\n", " rubberband.attr('height', height);\n", " }\n", "\n", " // Set the figure to an initial 600x600px, this will subsequently be updated\n", " // upon first draw.\n", " this._resize_canvas(600, 600);\n", "\n", " // Disable right mouse context menu.\n", " $(this.rubberband_canvas).bind(\"contextmenu\",function(e){\n", " return false;\n", " });\n", "\n", " function set_focus () {\n", " canvas.focus();\n", " canvas_div.focus();\n", " }\n", "\n", " window.setTimeout(set_focus, 100);\n", "}\n", "\n", "mpl.figure.prototype._init_toolbar = function() {\n", " var fig = this;\n", "\n", " var nav_element = $('')\n", " nav_element.attr('style', 'width: 100%');\n", " this.root.append(nav_element);\n", "\n", " // Define a callback function for later on.\n", " function toolbar_event(event) {\n", " return fig.toolbar_button_onclick(event['data']);\n", " }\n", " function toolbar_mouse_event(event) {\n", " return fig.toolbar_button_onmouseover(event['data']);\n", " }\n", "\n", " for(var toolbar_ind in mpl.toolbar_items) {\n", " var name = mpl.toolbar_items[toolbar_ind][0];\n", " var tooltip = 
mpl.toolbar_items[toolbar_ind][1];\n", " var image = mpl.toolbar_items[toolbar_ind][2];\n", " var method_name = mpl.toolbar_items[toolbar_ind][3];\n", "\n", " if (!name) {\n", " // put a spacer in here.\n", " continue;\n", " }\n", " var button = $('');\n", " button.addClass('ui-button ui-widget ui-state-default ui-corner-all ' +\n", " 'ui-button-icon-only');\n", " button.attr('role', 'button');\n", " button.attr('aria-disabled', 'false');\n", " button.click(method_name, toolbar_event);\n", " button.mouseover(tooltip, toolbar_mouse_event);\n", "\n", " var icon_img = $('');\n", " icon_img.addClass('ui-button-icon-primary ui-icon');\n", " icon_img.addClass(image);\n", " icon_img.addClass('ui-corner-all');\n", "\n", " var tooltip_span = $('');\n", " tooltip_span.addClass('ui-button-text');\n", " tooltip_span.html(tooltip);\n", "\n", " button.append(icon_img);\n", " button.append(tooltip_span);\n", "\n", " nav_element.append(button);\n", " }\n", "\n", " var fmt_picker_span = $('');\n", "\n", " var fmt_picker = $('');\n", " fmt_picker.addClass('mpl-toolbar-option ui-widget ui-widget-content');\n", " fmt_picker_span.append(fmt_picker);\n", " nav_element.append(fmt_picker_span);\n", " this.format_dropdown = fmt_picker[0];\n", "\n", " for (var ind in mpl.extensions) {\n", " var fmt = mpl.extensions[ind];\n", " var option = $(\n", " '', {selected: fmt === mpl.default_extension}).html(fmt);\n", " fmt_picker.append(option)\n", " }\n", "\n", " // Add hover states to the ui-buttons\n", " $( \".ui-button\" ).hover(\n", " function() { $(this).addClass(\"ui-state-hover\");},\n", " function() { $(this).removeClass(\"ui-state-hover\");}\n", " );\n", "\n", " var status_bar = $('