cohort analysis grid visualization with Python, Pandas, matplotlib and Seaborn

I want to share some code I put together to make a cohort grid using Python, Pandas, matplotlib and Seaborn in case it's useful to others.

# SQL output is imported as a pandas dataframe variable called "df"
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

# Cast the numeric columns explicitly in case they arrive from SQL as strings
numeric_cols = ['visitors_cohort_period', 'visitors_retained_period',
                'retention_rate', 'months_since']
for col in numeric_cols:
    df[col] = pd.to_numeric(df[col])

# Index by (months_since, cohort_period) so the retention rates can be pivoted
df.reset_index(inplace=True)
df.set_index(['months_since', 'cohort_period'], inplace=True)

# Rows: cohort_period, columns: months_since, values: retention_rate
test = df['retention_rate'].unstack(0)

sns.set(style='white')

plt.figure(figsize=(20, 8))
plt.title('Cohorts: User Retention')
# Mask the cohort/month combinations that have no data
sns.heatmap(test, mask=test.isnull(), annot=True, fmt='.0%')

# Use Periscope to visualize a dataframe or an image by passing data to periscope.output()
periscope.output(plt)

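The visualization ended up looking like this:

[Image: heatmap titled "Cohorts: User Retention", one row per cohort_period and one column per months_since, each cell annotated with its retention rate]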

Also thanks to Neha for her help debugging a few data type issues!
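For anyone who wants to try the reshaping step outside Periscope, here is a minimal sketch with made-up numbers. The rows and values are invented for illustration, and the numeric columns are deliberately loaded as strings to mimic the kind of data type issue mentioned above; the real query output may look different.

```python
import pandas as pd

# Made-up rows in the same shape as the query output (illustrative values only).
# The numeric columns are strings on purpose, to mimic a data type issue.
df = pd.DataFrame({
    "cohort_period": ["2019-01-01", "2019-01-01", "2019-01-01",
                      "2019-02-01", "2019-02-01", "2019-03-01"],
    "months_since": ["0", "1", "2", "0", "1", "0"],
    "retention_rate": ["1.0", "0.5", "0.25", "1.0", "0.4", "1.0"],
})

# Cast to numbers so the columns sort numerically and the heatmap can format them
for col in ["months_since", "retention_rate"]:
    df[col] = pd.to_numeric(df[col])

# Same reshaping as in the post: months_since becomes the columns,
# cohort_period stays as the rows
grid = df.set_index(["months_since", "cohort_period"])["retention_rate"].unstack(0)
print(grid)
```

Cohort/month combinations with no row in the input come out as NaN in the grid, which is what the mask argument in the heatmap call hides.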

Sources:

http://www.gregreda.com/2015/08/23/cohort-analysis-with-python/

  • Awesome visualization, Chris Eldredge!

    Could you share how the data was formatted when output by the SQL code and loaded into Python? Specifically, what columns did you have and what did they represent? Again, really nice chart! 

  • Thanks! Sure, these are the columns output by my SQL query:

    cohort_period -- the base month the user was active
    retained_period -- the month the user was retained; the dataset contains rows for each month after the cohort period/month
    months_since -- date difference in months between cohort_period (first month the user was active) and retained_period (month user returned)
    visitors_cohort_period -- count of visitors in the base month cohort period
    visitors_retained_period -- count of visitors from the cohort period that were retained that month
    retention_rate -- visitors_retained_period divided by visitors_cohort_period

    The exact SQL is specific to our logging, but the general structure was:

    SELECT
      (all the columns/counts outlined above)
    FROM
      (subquery outputting a list of distinct users active each month, grouping by month using the FORMAT_DATE function to change date stamps to YYYYMM or YYYY-MM-01 format)
    LEFT JOIN
      (same subquery)
      ON user_id
    WHERE
      retained_period >= cohort_period
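If it helps to see that structure without the SQL, below is a rough pandas sketch of the same self-join. The events table, its user_id and event_date columns, and all the values are hypothetical stand-ins for a real activity log. The sketch treats every month a user was active as a base month, which is what the self-join above produces; restricting cohorts to each user's first active month would need an extra step.

```python
import pandas as pd

# Hypothetical activity log: one row per visit (names and values are assumptions)
events = pd.DataFrame({
    "user_id": [1, 1, 1, 2, 2, 3, 3, 4],
    "event_date": pd.to_datetime([
        "2019-01-05", "2019-02-10", "2019-03-02",
        "2019-01-20", "2019-03-15",
        "2019-02-03", "2019-03-09",
        "2019-03-21",
    ]),
})

# The subquery: distinct users active each month, months stored as YYYY-MM-01
events["month"] = events["event_date"].dt.to_period("M").dt.to_timestamp()
active = events[["user_id", "month"]].drop_duplicates()

# Self-join on user_id, then keep retained_period >= cohort_period
pairs = active.rename(columns={"month": "cohort_period"}).merge(
    active.rename(columns={"month": "retained_period"}),
    on="user_id", how="left")
pairs = pairs[pairs["retained_period"] >= pairs["cohort_period"]].copy()

# Whole months between the base month and the retained month
pairs["months_since"] = (
    (pairs["retained_period"].dt.year - pairs["cohort_period"].dt.year) * 12
    + (pairs["retained_period"].dt.month - pairs["cohort_period"].dt.month))

# Count retained visitors per (cohort, retained month)
df = (pairs.groupby(["cohort_period", "retained_period", "months_since"])["user_id"]
      .nunique().rename("visitors_retained_period").reset_index())

# Cohort size is the number of distinct users active in the base month
cohort_sizes = active.groupby("month")["user_id"].nunique()
df["visitors_cohort_period"] = df["cohort_period"].map(cohort_sizes)
df["retention_rate"] = df["visitors_retained_period"] / df["visitors_cohort_period"]
print(df)
```

The resulting df has the six columns listed above, so it can be fed to the same set_index/unstack step used for the heatmap.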
