Your First Machine Learning Models

Looking at Your Data with matplotlib

Before you model anything, look at it. A plot reveals shape, spread, and outliers that a table of numbers hides. The standard tool is matplotlib; its pyplot module is conventionally imported as plt.

The modern way to make a plot is to ask for a figure and one or more axes, then draw onto the axes. An axes is a single plotting area:

import matplotlib.pyplot as plt

fig, ax = plt.subplots()       # one figure, one axes
ax.plot([1, 2, 3], [4, 5, 6])
ax.set_title("My chart")

You can ask for a grid of axes too. plt.subplots(1, 2) gives one row of two, which you can unpack:

fig, (left, right) = plt.subplots(1, 2)
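To see what the unpacking above is doing, here is a minimal sketch. It assumes a non-interactive backend ("Agg") so it runs without a display; the panel titles are made up for illustration.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless environment, no window needed
import matplotlib.pyplot as plt

# With one row and two columns, subplots returns a 1-D array of two axes.
fig, axes = plt.subplots(1, 2)
print(axes.shape)

# Unpacking the array gives one name per plotting area.
left, right = axes
left.set_title("Left panel")    # illustrative titles
right.set_title("Right panel")

# The figure keeps track of every axes it owns.
print(len(fig.axes))
```

Writing fig, (left, right) = plt.subplots(1, 2) simply performs this unpacking in one step.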

Two workhorse plots

A histogram shows the distribution of one variable by bucketing values into bins and counting how many fall in each:

ax.hist(values, bins=4)
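As a rough illustration of the bucketing, the sketch below runs a histogram on a small made-up list. The data and the "Agg" backend are assumptions for the example, not part of the lesson's exercise.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless environment
import matplotlib.pyplot as plt

values = [1, 2, 2, 3, 3, 3, 4, 4, 5, 9]  # made-up sample data

fig, ax = plt.subplots()
# hist returns the per-bin counts, the bin edges, and the drawn bars.
counts, edges, bars = ax.hist(values, bins=4)

print(counts)           # counts per bin: 3, 5, 1, 1 for this data
print(edges)            # 5 edges (1, 3, 5, 7, 9) delimit the 4 bins
print(len(ax.patches))  # one rectangle per bin, so 4
```

The bins split the range from the smallest to the largest value into equal widths, which is why a single far-off value like 9 stretches the bins and leaves one nearly empty.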

A scatter plot sets one variable against another, one dot per row, to reveal relationships:

ax.scatter(x_values, y_values)
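A scatter call returns an object that remembers the points it drew, which gives one way to read them back. The sketch below uses made-up data and assumes the "Agg" backend.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless environment
import matplotlib.pyplot as plt

x_values = [1, 2, 3, 4, 5]            # made-up sample data
y_values = [2.0, 4.1, 5.9, 8.2, 9.8]

fig, ax = plt.subplots()
points = ax.scatter(x_values, y_values)

# get_offsets() returns one (x, y) row per plotted dot.
offsets = points.get_offsets()
print(offsets.shape)  # (5, 2): five dots, two coordinates each
print(len(offsets))   # 5
```

The same collection is also reachable later through ax.collections, so the count can be recovered from the axes alone.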

One rule for this course

Never call plt.show() here. There is no window to pop up in the grader, and your code is checked by inspecting the figure's data, not its pixels. For example, a histogram with 4 bins creates 4 bar patches, which you can count with len(ax.patches). So build the figure, set titles, and record what you need as plain numbers.
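One way to record what you need as plain numbers is sketched below. The helper describe_axes is hypothetical (it is not part of matplotlib or this course's grader), and the data is made up.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless environment, nothing is displayed
import matplotlib.pyplot as plt


def describe_axes(ax):
    """Hypothetical helper: summarize an axes as plain Python values."""
    return {
        "title": ax.get_title(),
        "n_bars": len(ax.patches),
        "n_collections": len(ax.collections),
    }


fig, ax = plt.subplots()
ax.hist([10, 12, 15, 21, 22, 30], bins=3)  # made-up data
ax.set_title("Demo histogram")

summary = describe_axes(ax)
print(summary)

plt.close(fig)  # release the figure; no plt.show() anywhere
```

Everything in summary comes from the figure's stored data, so it can be checked without rendering a single pixel.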

Your turn

Two arrays are given: ages and spend (12 values each). Make a figure with one row of two axes via fig, (ax_hist, ax_scatter) = plt.subplots(1, 2). On ax_hist draw a histogram of ages with bins=4 and give it a title. On ax_scatter draw a scatter of ages (x) vs spend (y) and give it a title. Finally set n_bars = len(ax_hist.patches) and n_points to the number of plotted scatter points.
