When analyzing data, a regression line helps us see the trend in the data. In this article, we introduce how to plot regression lines with Seaborn and Plotly Express.
The complete code can be found in .
Example Datasets
In the example of this article, we use two datasets, tips and anscombe. Let’s take a look at how to load them first.
Both Seaborn and Plotly Express include the tips dataset. It can be loaded with either library as follows.
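```python
import seaborn as sns

# Seaborn downloads the dataset on first use and caches it locally
tips = sns.load_dataset('tips')
tips.head()
```

```python
import plotly.express as px

# Plotly Express ships the tips data with the package
tips = px.data.tips()
tips.head()
```

Either way, the first five rows are: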
total_bill | tip | sex | smoker | day | time | size |
---|---|---|---|---|---|---|
16.99 | 1.01 | Female | No | Sun | Dinner | 2 |
10.34 | 1.66 | Male | No | Sun | Dinner | 3 |
21.01 | 3.50 | Male | No | Sun | Dinner | 3 |
23.68 | 3.31 | Male | No | Sun | Dinner | 2 |
24.59 | 3.61 | Female | No | Sun | Dinner | 4 |
Seaborn also has a built-in anscombe dataset, which can be loaded as follows.
```python
import seaborn as sns

anscombe = sns.load_dataset("anscombe")
anscombe.head()
```
dataset | x | y |
---|---|---|
I | 10.0 | 8.04 |
I | 8.0 | 6.95 |
I | 13.0 | 7.58 |
I | 9.0 | 8.81 |
I | 11.0 | 8.33 |
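The anscombe data is Anscombe's quartet: four small datasets (I, II, III and IV) that share almost identical summary statistics and the same least-squares line, yet look very different when plotted. That is what makes it useful for comparing regression models below. As a minimal check, the numpy sketch here hard-codes the published quartet values (so it runs without downloading anything) and computes those statistics; the names quartet and summarize are illustrative, not part of Seaborn.

```python
import numpy as np

# Anscombe's quartet: datasets I, II and III share the same x values.
x_123 = np.array([10, 8, 13, 9, 11, 14, 6, 4, 12, 7, 5], dtype=float)
x_4 = np.array([8, 8, 8, 8, 8, 8, 8, 19, 8, 8, 8], dtype=float)
quartet = {
    "I": (x_123, np.array([8.04, 6.95, 7.58, 8.81, 8.33, 9.96, 7.24, 4.26, 10.84, 4.82, 5.68])),
    "II": (x_123, np.array([9.14, 8.14, 8.74, 8.77, 9.26, 8.10, 6.13, 3.10, 9.13, 7.26, 4.74])),
    "III": (x_123, np.array([7.46, 6.77, 12.74, 7.11, 7.81, 8.84, 6.08, 5.39, 8.15, 6.42, 5.73])),
    "IV": (x_4, np.array([6.58, 5.76, 7.71, 8.84, 8.47, 7.04, 5.25, 12.50, 5.56, 7.91, 6.89])),
}


def summarize(x, y):
    """Mean of y, Pearson correlation, and the least-squares line y = intercept + slope * x."""
    slope, intercept = np.polyfit(x, y, deg=1)
    return {
        "mean_y": float(y.mean()),
        "corr": float(np.corrcoef(x, y)[0, 1]),
        "slope": float(slope),
        "intercept": float(intercept),
    }


stats = {name: summarize(x, y) for name, (x, y) in quartet.items()}
for name, s in stats.items():
    print(name, {key: round(value, 2) for key, value in s.items()})
```

All four datasets come out at roughly the same numbers (mean of y about 7.50, correlation about 0.82, line about y = 3.00 + 0.50x), which is why the examples below pick one specific dataset before plotting.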
Seaborn
regplot()
Seaborn provides regplot() for plotting regression lines. Its declaration is as follows; please refer to the official documentation for other parameters.
seaborn.regplot(data=None, x=None, y=None, order=1, logistic=False, lowess=False, robust=False, logx=False, color=None, marker='o')
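- data: The input DataFrame.
- x: Column name (or data) for the x-axis.
- y: Column name (or data) for the y-axis.
- order: If greater than 1, fit a polynomial regression of that degree.
- logistic: If True, treat y as binary and use statsmodels to fit a logistic regression.
- lowess: If True, use statsmodels to fit a LOWESS (locally weighted scatterplot smoothing) curve.
- robust: If True, use statsmodels to fit a robust regression that de-weights outliers.
- logx: If True, fit y ~ log(x) but plot the result in the original x scale; x must be positive.
- color: Color of the points and the regression line.
- marker: Marker symbol for the scatter points.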
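The logx option in the declaration is easy to overlook. Per the Seaborn documentation, it fits y against log(x) while still drawing the curve in the original x scale. A minimal numpy sketch of that idea follows; fit_logx and predict_logx are hypothetical helper names, and this is an illustration rather than Seaborn's own code.

```python
import numpy as np


def fit_logx(x, y):
    """Fit y = a + b * log(x) by ordinary least squares; x must be positive."""
    x = np.asarray(x, dtype=float)
    if np.any(x <= 0):
        raise ValueError("logx requires strictly positive x values")
    b, a = np.polyfit(np.log(x), y, deg=1)  # slope first, then intercept
    return a, b


def predict_logx(a, b, x):
    """Evaluate the fitted curve on the original (not logged) x scale for plotting."""
    return a + b * np.log(np.asarray(x, dtype=float))


# Hypothetical data that follows y = 2 + 3 * log(x) exactly.
x = np.array([1.0, 2.0, 5.0, 10.0, 20.0, 50.0])
y = 2 + 3 * np.log(x)
a, b = fit_logx(x, y)
grid = np.linspace(1, 50, 100)
curve = predict_logx(a, b, grid)
```

The fit happens in log space, but the returned curve is evaluated on a regular grid of the original x values, which is why it appears curved on a linear axis.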
Linear Regression
regplot() uses linear regression by default.
```python
ax = sns.regplot(data=tips, x='total_bill', y='tip')
ax.set_title('Linear Regression')
```
regplot() first draws a scatter plot, then draws a linear regression line on it together with a translucent band showing the 95% confidence interval of the fit, which Seaborn estimates by bootstrapping.
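To make the confidence band less of a black box, here is a minimal numpy sketch of the idea: fit a least-squares line, refit it on many bootstrap resamples of the rows, and take the 2.5th and 97.5th percentiles of the refitted lines at each x. It mirrors Seaborn's defaults (n_boot=1000, ci=95) but is not Seaborn's actual code; bootstrap_band is a hypothetical helper, and the data is dataset I of anscombe, hard-coded so the example runs offline.

```python
import numpy as np

# Anscombe dataset I, hard-coded instead of loading it with Seaborn.
x = np.array([10, 8, 13, 9, 11, 14, 6, 4, 12, 7, 5], dtype=float)
y = np.array([8.04, 6.95, 7.58, 8.81, 8.33, 9.96, 7.24, 4.26, 10.84, 4.82, 5.68])


def bootstrap_band(x, y, grid, n_boot=1000, ci=95, seed=0):
    """Fit an OLS line, then bootstrap the rows to get a pointwise confidence band."""
    rng = np.random.default_rng(seed)
    slope, intercept = np.polyfit(x, y, deg=1)
    fit = intercept + slope * grid
    boot = np.empty((n_boot, grid.size))
    for i in range(n_boot):
        idx = rng.integers(0, x.size, size=x.size)  # resample rows with replacement
        b_slope, b_intercept = np.polyfit(x[idx], y[idx], deg=1)
        boot[i] = b_intercept + b_slope * grid
    tail = (100 - ci) / 2
    lower, upper = np.percentile(boot, [tail, 100 - tail], axis=0)
    return fit, lower, upper


grid = np.linspace(x.min(), x.max(), 50)
fit, lower, upper = bootstrap_band(x, y, grid)
```

Because each resample can drop or repeat the extreme points, the band is usually widest near the ends of the x range, which matches the bow-tie shape regplot draws.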
The following example shows how to change the color and the symbol of the points on a scatter plot.
```python
ax = sns.regplot(data=tips, x='total_bill', y='tip', color='g', marker='+')
ax.set_title('Linear Regression')
```
Logistic Regression
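Logistic regression needs a binary y variable. Here we add a column big_tip that is True when the tip is more than 20% of the total bill, then fit it against total_bill.

```python
# A tip above 20% of the bill counts as a big tip
tips['big_tip'] = (tips['tip'] / tips['total_bill']) > .2
# logistic=True requires statsmodels
ax = sns.regplot(data=tips, x='total_bill', y='big_tip', logistic=True)
ax.set_title('Logistic Regression')
```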
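Seaborn hands logistic=True to statsmodels. To show what is being estimated, here is a from-scratch numpy sketch that fits the same kind of model, P(y=1) = 1 / (1 + exp(-(b0 + b1 * x))), with Newton's method. The data is synthetic and hypothetical (generated with a known slope so the fit can be checked), and fit_logistic is our own helper name, not a library function.

```python
import numpy as np


def fit_logistic(x, y, n_iter=25):
    """Fit P(y=1) = 1 / (1 + exp(-(b0 + b1 * x))) by Newton-Raphson."""
    X = np.column_stack([np.ones_like(x), x])  # design matrix with an intercept column
    beta = np.zeros(2)
    for _ in range(n_iter):
        p = 1.0 / (1.0 + np.exp(-X @ beta))         # current predicted probabilities
        grad = X.T @ (y - p)                         # gradient of the log-likelihood
        info = X.T @ (X * (p * (1 - p))[:, None])    # Fisher information (negative Hessian)
        beta = beta + np.linalg.solve(info, grad)    # Newton step
    return beta


# Hypothetical data: the chance of y = 1 rises with x.
rng = np.random.default_rng(42)
x = rng.uniform(0, 50, size=2000)
true_p = 1.0 / (1.0 + np.exp(-(-3.0 + 0.12 * x)))
y = (rng.uniform(size=x.size) < true_p).astype(float)

b0, b1 = fit_logistic(x, y)
```

The fitted S-shaped curve 1 / (1 + exp(-(b0 + b1 * x))) is what regplot draws over the 0/1 points.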
Robust Regression
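Dataset III of anscombe lies on a straight line except for a single outlier. With robust=True, regplot() down-weights that outlier, so the line follows the other ten points.

```python
# robust=True requires statsmodels
ax = sns.regplot(data=anscombe[anscombe['dataset'] == 'III'], x='x', y='y', robust=True)
ax.set_title('Robust Regression')
```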
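Seaborn delegates robust fitting to statsmodels. The numpy sketch below reimplements the general idea, Huber weights refitted by iteratively reweighted least squares with a MAD-based scale, so you can see why the outlier stops dominating. It is an illustration, not the library's code; huber_irls is a hypothetical helper and c=1.345 is the conventional Huber tuning constant.

```python
import numpy as np

# Anscombe dataset III: ten points on a line plus one outlier at x=13.
x = np.array([10, 8, 13, 9, 11, 14, 6, 4, 12, 7, 5], dtype=float)
y = np.array([7.46, 6.77, 12.74, 7.11, 7.81, 8.84, 6.08, 5.39, 8.15, 6.42, 5.73])
X = np.column_stack([np.ones_like(x), x])  # intercept column plus x


def huber_irls(X, y, c=1.345, n_iter=50):
    """Robust linear fit: Huber weights via iteratively reweighted least squares."""
    beta, *_ = np.linalg.lstsq(X, y, rcond=None)  # start from the ordinary fit
    for _ in range(n_iter):
        resid = y - X @ beta
        scale = np.median(np.abs(resid)) / 0.6745  # MAD-based scale estimate
        if scale == 0:
            break
        u = np.abs(resid) / scale
        w = c / np.maximum(u, c)  # weight 1 for small residuals, c/u for large ones
        sw = np.sqrt(w)
        beta, *_ = np.linalg.lstsq(X * sw[:, None], y * sw, rcond=None)
    return beta  # [intercept, slope]


ols_intercept, ols_slope = np.linalg.lstsq(X, y, rcond=None)[0]
rob_intercept, rob_slope = huber_irls(X, y)
```

The ordinary fit has a slope of about 0.50 because the outlier pulls it up, while the reweighted fit settles near the slope of the ten well-behaved points (about 0.35).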
LOWESS (Locally Weighted Scatterplot Smoothing)
```python
# lowess=True requires statsmodels; no confidence band is drawn for this model
ax = sns.regplot(data=tips, x='total_bill', y='tip', lowess=True)
ax.set_title('Lowess Model')
```
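LOWESS fits a separate weighted straight line around every point, giving nearby points more weight, and joins the fitted values into a smooth curve. Seaborn delegates this to statsmodels. The numpy sketch below shows the core idea with tricube weights over the frac * n nearest neighbours; it deliberately skips the robustness iterations that statsmodels performs by default, so treat it as an illustration rather than a drop-in replacement. The function name lowess here is our own.

```python
import numpy as np


def lowess(x, y, frac=2 / 3):
    """Locally weighted linear regression evaluated at each x (no robustness iterations)."""
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    n = x.size
    k = min(n, max(int(np.ceil(frac * n)), 3))  # neighbours used in each local fit
    fitted = np.empty(n)
    for i, x0 in enumerate(x):
        dist = np.abs(x - x0)
        h = max(np.sort(dist)[k - 1], 1e-12)  # distance to the k-th nearest neighbour
        w = np.clip(1 - (dist / h) ** 3, 0, None) ** 3  # tricube weights, 0 outside window
        sw = np.sqrt(w)
        X = np.column_stack([np.ones(n), x - x0])  # local line centred at x0
        beta, *_ = np.linalg.lstsq(X * sw[:, None], y * sw, rcond=None)
        fitted[i] = beta[0]  # intercept of the centred line = fitted value at x0
    return fitted


# Hypothetical usage: smooth a noisy sine curve.
rng = np.random.default_rng(0)
x = np.linspace(0, 10, 60)
y = np.sin(x) + rng.normal(scale=0.3, size=x.size)
smoothed = lowess(x, y, frac=0.3)
```

A smaller frac uses fewer neighbours per local fit, so the curve follows the data more closely; a larger frac gives a smoother curve.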
Polynomial Regression
In the following example, we fit a degree-2 polynomial to dataset II of anscombe, whose points follow a curve rather than a straight line.
```python
ax = sns.regplot(data=anscombe[anscombe['dataset'] == 'II'], x='x', y='y', order=2)
ax.set_title('Polynomial Regression with degree 2')
```
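To see why order=2 suits this dataset, the numpy sketch below fits dataset II (hard-coded from Anscombe's quartet) with a straight line and with a degree-2 polynomial via numpy.polyfit, then compares the largest residual of each. max_abs_residual is a hypothetical helper for illustration only.

```python
import numpy as np

# Anscombe dataset II: the points follow a smooth curve rather than a line.
x = np.array([10, 8, 13, 9, 11, 14, 6, 4, 12, 7, 5], dtype=float)
y = np.array([9.14, 8.14, 8.74, 8.77, 9.26, 8.10, 6.13, 3.10, 9.13, 7.26, 4.74])


def max_abs_residual(x, y, order):
    """Fit a polynomial of the given order and return its largest absolute residual."""
    coefs = np.polyfit(x, y, deg=order)
    return float(np.max(np.abs(y - np.polyval(coefs, x))))


linear_err = max_abs_residual(x, y, order=1)
quadratic_err = max_abs_residual(x, y, order=2)
```

The straight line misses some points by more than 1, while the quadratic passes within rounding error of every point.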
jointplot() with kind='reg'
In addition to the main chart, jointplot() also draws the distributions of the x-axis and y-axis data along the top and right sides of the main chart. When the parameter kind is set to 'reg', it uses regplot() for the main chart. Its declaration is as follows; please refer to the official documentation for other parameters.
seaborn.jointplot(x=None, y=None, data=None, kind='scatter', hue=None)
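- data: The input DataFrame.
- x: Column name for the x-axis.
- y: Column name for the y-axis.
- kind: Chart type: 'scatter', 'kde', 'hist', 'hex', 'reg', or 'resid'.
- hue: Column used to group the data by color. Note that hue cannot be combined with kind='reg'.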
The following example uses jointplot() to draw a linear regression line.
```python
# jointplot() returns a JointGrid, not a matplotlib Axes
g = sns.jointplot(data=tips, x='total_bill', y='tip', kind='reg')
```
pairplot() with kind='reg'
pairplot() draws a grid of subplots for a dataset, one for each pair of columns, which lets us compare different columns of the dataset side by side. When the parameter kind is set to 'reg', it uses regplot() to draw the pairwise plots. Its declaration is as follows; please refer to the official documentation for other parameters.
seaborn.pairplot(data, hue=None, x_vars=None, y_vars=None, kind='scatter')
- data: The input DataFrame.
- x_vars: List of column names to use on the x-axes (one column of subplots per name).
- y_vars: List of column names to use on the y-axes (one row of subplots per name).
- kind: Chart type: 'scatter', 'kde', 'hist', or 'reg'.
- hue: Column used to group the data by color.
The following example uses pairplot() to draw linear regression lines of tip against total_bill and against size, with separate lines for smokers and non-smokers.
```python
# pairplot() returns a PairGrid; height is the size of each subplot in inches
g = sns.pairplot(data=tips, x_vars=['total_bill', 'size'], y_vars=['tip'],
                 kind='reg', hue='smoker', height=5)
```
Plotly Express
As in Seaborn, the regression line is drawn on top of a scatter plot, so Plotly Express draws it directly with scatter(). Its declaration is as follows; please refer to the official documentation for other parameters.
plotly.express.scatter(data_frame=None, x=None, y=None, color=None, facet_row=None, facet_col=None, trendline=None, title=None)
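- data_frame: The input DataFrame.
- x: Column name for the x-axis.
- y: Column name for the y-axis.
- color: Column used to group the data; each group gets its own color and its own trendline.
- facet_row: Column used to split the data into subplots stacked vertically, one row per value.
- facet_col: Column used to split the data into subplots placed side by side, one column per value.
- trendline: Type of trendline, for example 'ols' or 'lowess'. Both require statsmodels to be installed.
- title: Chart title.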
OLS (Ordinary Least Squares Regression)
We only need to set the parameter trendline to ols.
```python
px.scatter(tips, x='total_bill', y='tip', trendline='ols',
           title='Ordinary Least Squares Regression')
```
LOWESS (Locally Weighted Scatterplot Smoothing)
In addition to OLS, scatter() also supports LOWESS.
```python
px.scatter(tips, x='total_bill', y='tip', trendline='lowess',
           title='Locally Weighted Scatterplot Smoothing')
```
Multiple Subplots
Similar to Seaborn's pairplot(), Plotly Express's scatter() can also split a dataset across multiple subplots, using the facet_row and facet_col parameters.
In the following example, the x-axis and y-axis are total_bill and tip. The parameter color is set to sex, so points are colored by sex and a separate regression line is fitted for each sex. In addition, facet_col is set to smoker, so the data for smoker=No and smoker=Yes are drawn in two side-by-side subplots. This makes it easy to compare smokers with non-smokers.
```python
px.scatter(tips, x='total_bill', y='tip', trendline='ols', color='sex',
           facet_col='smoker', title='Ordinary Least Squares Regression')
```
The following example uses facet_row instead, so the two subplots are stacked vertically.
```python
px.scatter(tips, x='total_bill', y='tip', trendline='ols', color='sex',
           facet_row='smoker', title='Ordinary Least Squares Regression')
```
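By default, Plotly Express fits one trendline per trace, and each combination of color and facet values becomes its own trace. The numpy sketch below reproduces that grouping logic by hand on just the five tips rows shown at the start of the article, so every row falls in the smoker=No facet. trendlines_per_group is a hypothetical helper, not part of Plotly.

```python
from collections import defaultdict

import numpy as np

# The five tips rows from the table above: (total_bill, tip, sex, smoker).
rows = [
    (16.99, 1.01, "Female", "No"),
    (10.34, 1.66, "Male", "No"),
    (21.01, 3.50, "Male", "No"),
    (23.68, 3.31, "Male", "No"),
    (24.59, 3.61, "Female", "No"),
]


def trendlines_per_group(rows):
    """Fit one OLS line per (facet, color) pair, mimicking one trendline per trace."""
    groups = defaultdict(lambda: ([], []))
    for bill, tip, sex, smoker in rows:
        xs, ys = groups[(smoker, sex)]  # facet value first, then color value
        xs.append(bill)
        ys.append(tip)
    lines = {}
    for key, (xs, ys) in groups.items():
        if len(set(xs)) < 2:
            continue  # a line needs at least two distinct x values
        slope, intercept = np.polyfit(xs, ys, deg=1)
        lines[key] = (float(slope), float(intercept))
    return lines


lines = trendlines_per_group(rows)
```

With the full dataset, the same grouping would produce four lines: one per sex in each of the two smoker subplots.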
Conclusion
A regression line shows the trend in the data and helps us analyze it. Seaborn and Plotly Express both provide convenient functions that draw regression lines in a single call, and the resulting charts look quite polished.