Pivot tables work much like the pivot tables you may have used in spreadsheets. The difference between a pivot table and the GroupBy operation: "Pivot table is essentially a multi-dimensional version of GroupBy aggregation." That is, you split-apply-combine, but both the split and the combine happen across a two-dimensional grid rather than a one-dimensional index.
import numpy as np
import pandas as pd
import seaborn as sns
🛳 Titanic dataset for demonstration
# importing dataset for demonstration
titanic = pd.read_csv('data/titanic.csv')
titanic.head()
   survived  pclass     sex   age     fare embarked    who  embark_town alive  alone
0         0       3    male  22.0   7.2500        S    man  Southampton    no  False
1         1       1  female  38.0  71.2833        C  woman    Cherbourg   yes  False
2         1       3  female  26.0   7.9250        S  woman  Southampton   yes   True
3         1       1  female  35.0  53.1000        S  woman  Southampton   yes  False
4         0       3    male  35.0   8.0500        S    man  Southampton    no   True
1. WHAT IF WE USE GROUPBY
1.1. Finding survival rate by Gender
Essentially:
group (split) by sex,
select survived, and
apply mean
titanic.groupby('sex')['survived'].mean()
sex
female 0.742038
male 0.188908
Name: survived, dtype: float64
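To make the split, select, and apply steps concrete, here is a minimal sketch that performs them by hand and compares the result with the one-line GroupBy call. The small DataFrame is made-up toy data, not the Titanic dataset, and the names `manual` and `auto` are only illustrative.

```python
import pandas as pd

# Toy data (hypothetical values, not the Titanic dataset)
df = pd.DataFrame({
    "sex": ["female", "male", "female", "male", "male"],
    "survived": [1, 0, 1, 1, 0],
})

# Split: iterate over one sub-frame per distinct value of "sex"
# Apply: take the mean of "survived" inside each sub-frame
# Combine: collect the per-group results into a single Series
manual = pd.Series(
    {key: group["survived"].mean() for key, group in df.groupby("sex")},
    name="survived",
)

# The one-liner performs the same three steps internally
auto = df.groupby("sex")["survived"].mean()

print(manual)
print(auto)
```

Both Series should hold the same per-group means; the GroupBy version is simply shorter.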
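Grouping by both sex and pclass gives the survival rate for each combination:
titanic.groupby(['sex', 'pclass'])['survived'].mean()
sex     pclass
female  1         0.968085
        2         0.921053
        3         0.500000
male    1         0.368852
        2         0.157407
        3         0.135447
Name: survived, dtype: float64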
# unstack the result for better presentation
titanic.groupby(['sex','pclass'])['survived'].mean().unstack()
pclass         1         2         3
sex
female  0.968085  0.921053  0.500000
male    0.368852  0.157407  0.135447
**Conclusion:** Two-dimensional GroupBy works, but the code becomes long and harder to read. Pandas offers a better tool for this, pivot_table.
2. USING PIVOT TABLE
The two-dimensional GroupBy result above can be produced more directly with pivot_table. We will use the DataFrame.pivot_table() method, whose default aggfunc is 'mean':
titanic.pivot_table('survived', index='sex', columns='pclass')
pclass         1         2         3
sex
female  0.968085  0.921053  0.500000
male    0.368852  0.157407  0.135447
We can also get the same result by passing index and columns as positional arguments instead of keywords:
titanic.pivot_table('survived', 'sex', 'pclass')
pclass         1         2         3
sex
female  0.968085  0.921053  0.500000
male    0.368852  0.157407  0.135447
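As a quick check that the two approaches really agree, the sketch below builds the same table both ways on a small made-up DataFrame (toy values, not the Titanic data). The variable names `via_groupby` and `via_pivot` are illustrative only.

```python
import pandas as pd

# Toy data (hypothetical values, not the Titanic dataset)
df = pd.DataFrame({
    "sex":      ["female", "female", "male", "male", "male", "female"],
    "pclass":   [1, 3, 1, 3, 3, 1],
    "survived": [1, 0, 0, 0, 1, 1],
})

# Two-key GroupBy, then move the inner key (pclass) into the columns
via_groupby = df.groupby(["sex", "pclass"])["survived"].mean().unstack()

# The same table in a single pivot_table call (default aggfunc is the mean)
via_pivot = df.pivot_table("survived", index="sex", columns="pclass")

print(via_groupby)
print(via_pivot)
```

The two printed tables should match cell for cell; pivot_table just states the row key, column key, and value in one place.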
2.1. Multilevel Pivot Table
Suppose we want to group by age and sex and get the mean of survived for each pclass. Instead of using each age value as a separate group, we will bin the ages into age groups. To do this, we first use the pd.cut function to segment the age column. To choose the bin edges, first look at the min and max age in our dataset:
titanic['age'].min(), titanic['age'].max()
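The binning step can be sketched on its own as follows. The ages here are made-up values chosen to fall on both sides of the bin edge, and the edges 0, 18, 80 are an assumption taken from the intervals that appear in the pivot table output in this section.

```python
import pandas as pd

# Hypothetical ages, not taken from the Titanic dataset
ages = pd.Series([0.42, 5, 18, 19, 40, 80], name="age")

# Inspect the range first so the bin edges cover every value
print(ages.min(), ages.max())

# Edges 0, 18, 80 give two right-closed intervals: (0, 18] and (18, 80]
age_group = pd.cut(ages, bins=[0, 18, 80])

# Count how many ages fall into each interval, in interval order
counts = age_group.value_counts(sort=False)
print(counts)
```

Because the intervals are closed on the right, an age of exactly 18 lands in (0, 18], and any value outside the outer edges would become NaN.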
Now we apply pivot_table on sex and age (through the newly created age_group). The other arguments stay the same: we still compute the mean of survived for each pclass.
age_group = pd.cut(titanic['age'], [0, 18, 80])
titanic.pivot_table('survived', index=['sex', age_group], columns='pclass')
pclass                  1         2         3
sex    age
female (0, 18]   0.909091  1.000000  0.511628
       (18, 80]  0.972973  0.900000  0.423729
male   (0, 18]   0.800000  0.600000  0.215686
       (18, 80]  0.375000  0.071429  0.133663
2.2. Additional Pivot Table Options
a. Parameters of pivot_table
Parameter       Default
values=         None
index=          None
aggfunc=        'mean'
margins=        False
dropna=         True
margins_name=   'All'
b. aggfunc
Suppose we want the sum of the survived column and the mean of the fare column for each sex and pclass:
titanic.pivot_table(index='sex', columns='pclass', aggfunc={'survived': 'sum', 'fare': 'mean'})
# omitted the values keyword;
# when you’re specifying a mapping for aggfunc, this is determined automatically.
              fare                       survived
pclass           1          2          3        1   2   3
sex
female  106.125798  21.970121  16.118810       91  70  72
male     67.226127  19.741782  12.661633       45  17  47
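The resulting columns form a two-level index: the value column on top and pclass underneath. A minimal sketch on made-up data (toy values, not the Titanic dataset) shows how to pass the mapping and how to read a single cell back out.

```python
import pandas as pd

# Toy data (hypothetical values, not the Titanic dataset)
df = pd.DataFrame({
    "sex":      ["female", "female", "female", "male", "male", "male"],
    "pclass":   [1, 1, 2, 1, 2, 2],
    "survived": [1, 1, 0, 0, 1, 0],
    "fare":     [80.0, 100.0, 20.0, 60.0, 10.0, 30.0],
})

# A dict maps each value column to its own aggregation;
# the values argument is inferred from the dict keys
table = df.pivot_table(
    index="sex",
    columns="pclass",
    aggfunc={"survived": "sum", "fare": "mean"},
)
print(table)

# Columns are (value column, pclass) pairs, so cells are addressed with a tuple
print(table.loc["female", ("fare", 1)])      # mean fare, female, class 1
print(table.loc["male", ("survived", 2)])    # survivor count, male, class 2
```

Using string names such as 'sum' and 'mean' instead of the Python built-ins lets pandas pick its own optimized implementations.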
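c. margins=True
Setting margins=True adds an All row and an All column holding the aggregate across each column and row, computed with the same aggfunc (so they are totals when aggfunc='sum' and overall means with the default 'mean').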
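A small sketch, again on made-up data rather than the Titanic dataset, illustrates how the margins follow the chosen aggfunc. The names `totals` and `rates` are illustrative.

```python
import pandas as pd

# Toy data (hypothetical values, not the Titanic dataset)
df = pd.DataFrame({
    "sex":      ["female", "female", "female", "male", "male", "male"],
    "pclass":   [1, 1, 2, 1, 2, 2],
    "survived": [1, 1, 0, 0, 1, 0],
})

# With aggfunc='sum', the All row and column hold totals
totals = df.pivot_table("survived", index="sex", columns="pclass",
                        aggfunc="sum", margins=True)
print(totals)

# With the default aggfunc ('mean'), the margins hold overall means instead
rates = df.pivot_table("survived", index="sex", columns="pclass",
                       margins=True)
print(rates)
```

The label of the extra row and column can be changed with margins_name, whose default is 'All'.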
# using matplotlib to draw figure of
# sum of births in each month, across each gender
# magic function (%matplotlib) to make the plot appear and store in notebook
%matplotlib inline
import matplotlib.pyplot as plt
sns.set() # set seaborn styles
births.pivot_table('births', index='month', columns='gender', aggfunc='sum').plot()
plt.ylabel('total births in each month');
2️⃣ Finding sum of births in each decade, across each gender
# adding a decade column
births['decade'] = 10 * (births['year'] // 10)  # year // 10 drops the last digit; multiplying by 10 gives the decade
# creating pivot table for total births, in each decade, along each gender type
print(births.pivot_table('births', index='decade', columns='gender', aggfunc='sum', margins=True))
gender         F         M        All
decade
1960     1753634   1846572    3600206
1970    16263075  17121550   33384625
1980    18310351  19243452   37553803
1990    19479454  20420553   39900007
2000    18229309  19106428   37335737
All     74035823  77738555  151774378
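The decade column used as the index above comes from integer division. A minimal sketch with a few made-up years shows how each year is rounded down to the start of its decade.

```python
import pandas as pd

# Hypothetical years, chosen to sit at the edges of decades
years = pd.Series([1969, 1970, 1979, 1988, 2008])

# Floor division by 10 drops the last digit; multiplying by 10 restores the scale
decade = 10 * (years // 10)

print(decade.tolist())  # [1960, 1970, 1970, 1980, 2000]
```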
Let's plot the trend as a figure, this time using total births per year:
# using matplotlib to draw a figure of
# the sum of births in each year, for each gender
# magic function (%matplotlib) to make the plot appear and store in notebook
%matplotlib inline
sns.set() # set seaborn styles
births.pivot_table('births', index='year', columns='gender', aggfunc='sum').plot()
plt.ylabel('total births per year');