Predicting Churn rates for Sparkify

Richard Needham, August 2021.

Link to GitHub repository containing additional information and the code workbook

1. Introduction

1.1 Overview

This is my Capstone Project for the Udacity Data Scientist course, written to demonstrate various data analysis, coding and machine-learning techniques for solving a Big Data analytics problem.

It concerns a fictitious online music streaming company called Sparkify, whose users are able to log in and listen to music. The service offers two levels of membership, paid or free; members subscribed to the free level must listen to advertisements between songs, whereas paid-level members, on the whole, don’t.

Each time a user logs in, plays a song, hears an advert, gives a song a thumbs up for approval or a thumbs down for disapproval, adds a song to their playlist, recommends a song to a friend, receives an error message, changes settings, upgrades, downgrades or visits the home or help pages, a log event is recorded together with a timestamp. This also includes the event that Sparkify wishes to avoid, namely the cancellation of the user’s membership, also known as a churn.

Millions of these log events are recorded in a huge dataset, and they can offer an insight into customer satisfaction. The ultimate measure of dissatisfaction is the churn event, but hearing too many adverts, being offered songs that get more thumbs down than thumbs up, experiencing difficulties with the platform through error messages, or having to regularly change settings or visit the help pages are also indicators that a customer is not happy.

To that end, it would be helpful if the log events, particularly those pertaining to dissatisfaction, could be used to identify dissatisfied users before they churn, so that the churn can be prevented by offering suitable incentives, such as discounted paid-level membership, or by improving the content or functionality of the site, thereby enhancing the user’s experience of the service.

Such a solution could be achieved using techniques such as Data Analytics and Machine Learning Prediction models, given a large data set such as this.

1.2 Problem Description

Customer-facing businesses are continually faced with the challenge of predicting so-called churn rates, that is to say the rate at which they lose customers through the cancellation of subscriptions or contracts.

The motivation of this project is to develop a machine learning model to predict if customers will churn.

Using Spark, it is possible to work efficiently with large datasets in order to produce such models, and the project sets out to study a customer dataset, manipulate and prepare it for machine learning, and then identify effective machine learning models for predicting churn rates using this prepared data.

If accurate, such predictions are incredibly valuable in helping businesses retain customers and protect revenues.

Sparkify Logo, Source Udacity

For this project, the customer data for a fictitious online music streaming platform, Sparkify, has been created by Udacity.

Like any other business Sparkify suffers from churning customers (which is reflected in their customer data), so this project will attempt to predict which customers may churn, based on their activity history.

To do this the customer data will be analysed and transformed, useful numeric features will be identified and created, and then applied to various machine learning models to make predictions.

The Sparkify dataset is some 12GB, so a subset of around 128 MB has been provided to enable a local analysis, which will study, clean and transform the data so that it can be used with machine learning models to predict which customers will churn.

The aim of the project is to identify effective Machine Learning models for predicting customer churning. For this each model will be assessed, cross validated and tuned.

Predictions will then be made using the tuned models, and their accuracy will be calculated. Based on these values, a recommendation will be made as to which model should be used for deployment on the 12GB dataset.

2. Analysis

The data was loaded into a Jupyter Notebook running with the PySpark, pandas, NumPy, Matplotlib and seaborn libraries.

2a. Data Exploration

The subset of the fictitious Sparkify Customer data was loaded and explored.

Extract from loaded dataframe showing the contents of columns

Using PySpark’s filter, show and describe commands, the content of each individual column was studied to get a feel for the data’s structure and content.

A summary of statistics for the values in each column was obtained using the describe command:

Extract showing statistical summary of dataset columns

A dictionary was generated showing the percentage of rows containing nulls in each column:

Dictionary showing the percentage of null rows for each column
  • 2% of values in registration
  • 20.4% of values in the columns related to songs: artist, length and song
  • 2.9% of values in the columns related to users: firstName, gender, lastName, location, userAgent and reg_date
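The null percentages above can be reproduced with a few lines of code. In PySpark this is typically done by counting isNull() per column; the plain-Python sketch below (a hypothetical helper, not the notebook’s code) shows the same calculation on a list of row dictionaries.

```python
from typing import Any, Dict, List


def null_percentages(rows: List[Dict[str, Any]], columns: List[str]) -> Dict[str, float]:
    """Return {column: percentage of rows where the value is None}.

    Hypothetical helper: rows are plain dicts standing in for DataFrame rows.
    """
    total = len(rows)
    result: Dict[str, float] = {}
    for col in columns:
        # A missing key is treated the same as an explicit None (a null)
        nulls = sum(1 for row in rows if row.get(col) is None)
        result[col] = round(100.0 * nulls / total, 1) if total else 0.0
    return result
```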
Extract clarifying the relationship between registration, ts and userId

Relationships between columns were also examined, and some summary statistical information was obtained.

This information was then used to compile the following summary of the DataFrame’s content.

2b. Summary of Data Content


Checking the earliest and latest timestamps shows that data was being collected between 01/10/2018 and 03/12/2018, just over 2 months.

The earliest registration date was 18/03/2018.
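Because ts and registration are millisecond epoch values, they must be divided by 1000 before conversion to dates (in PySpark, for example, with from_unixtime(col("ts") / 1000)). The minimal sketch below, with hypothetical helper names, shows the conversion and the length of the collection window in plain Python.

```python
from datetime import datetime, timezone
from typing import Iterable, Tuple


def ms_to_utc(ts_ms: int) -> datetime:
    """Convert a millisecond epoch timestamp (like the ts column) to a UTC datetime."""
    return datetime.fromtimestamp(ts_ms / 1000, tz=timezone.utc)


def collection_window(timestamps: Iterable[int]) -> Tuple[datetime, datetime, float]:
    """Return the earliest and latest event times and the span between them in days."""
    ts = list(timestamps)
    first, last = ms_to_utc(min(ts)), ms_to_utc(max(ts))
    span_days = (last - first).total_seconds() / 86400
    return first, last, span_days
```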


The Dataframe contains the following columns:

I: Data relating to users:

  • userId : a unique numeric identifier assigned to each user who is registered with the platform.
  • firstName : the user’s first name.
  • lastName : the user’s surname.
  • gender : the user’s gender.
  • location : the location of the user’s registered address.
  • level : the tier (paid or free) to which the user is subscribed.

II: Data relating to sessions:

Sessions are the periods when a user is logged in to the platform.

  • sessionId : the id number for the session. Each session for a particular user has its own id; the same sessionId may appear in another user’s account.
  • auth : the authorisation level for the session, one of four possibilities: Cancelled, Guest, Logged In or Logged Out. No userIds are associated with Logged Out or Guest.
  • userAgent : the browser / operating system used to access the platform.
  • registration : the timestamp associated with the time when the user registered. Events triggered by individuals who are not logged in return no registration timestamp.

III. Event logs:

These are records of individual interactions between a user and the platform.

  • page : the type of event. There are 22 different page types, such as NextSong for playing a song, Add to Playlist for adding a song to a playlist, Home for visiting the platform’s home page, etc.
  • method : the user’s interaction with the platform: Put for input coming from the user, Get for output received by the user.
  • itemInSession : the sequential number, within a session, of a logged event.
  • ts : an integer timestamp recording the time of an event in milliseconds, where ts=0 is defined as 01.01.1970 00:00:00.000.

IV. Data relating to songs played

The event type (page) NextSong records that the user is playing a song. For this event song details are recorded:

  • song : the song’s title.
  • length : the length of time for which the song was listened to.
  • artist : the song’s performer.

2c. Data Quality and Cleaning

As well as giving an insight into the DataFrame’s structure, this exploration was also used to check the quality of the data, and the following issues were identified:

  • The minimum value for userId is an empty string, whereas it should be a number.
  • There are missing values in the columns artist, firstName, gender, lastName, length, location, registration, song and userAgent.

The following cleaning steps were applied:

  • Create a new DataFrame, event_log_valid.
  • Drop nulls and empty strings from userId.
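Dropping both nulls and empty strings matters because an empty userId is not null and would survive a plain null filter. A PySpark filter along the lines of col("userId").isNotNull() & (col("userId") != "") is one way to do this; the plain-Python sketch below, with a hypothetical function name, applies the same rule to row dictionaries.

```python
from typing import Any, Dict, List


def drop_invalid_user_ids(rows: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Keep only rows whose userId is neither null (None/missing) nor an empty string."""
    return [
        row for row in rows
        if row.get("userId") is not None and row["userId"] != ""
    ]
```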

State of Data after Cleaning

  • all columns referring to song information (`song`, `artist`, `length`) contain 228108 records.
  • all other columns contain 278154 records.

This is plausible, because song details are only recorded when the event (page) NextSong is logged; for all other log events (pages) there is no song information.

3. Exploratory Data Analysis

Exploratory data analysis (EDA) was performed on the cleaned subset of the data, and by performing some basic aggregations within Spark the following insights can be provided.

3a. Pareto analysis of `Page` types

The page column is one of the most useful for our prediction, as it contains information about the type of log event that occurred. This information can be transformed to provide data about each user’s activity, which in turn can be used to predict whether a user is going to churn.

  • it can be seen that the page with by far the most event logs is NextSong, which is to be expected; after all, the purpose of Sparkify is to play songs.
  • however, NextSong is so dominant that it is hard to see the results at the tail end of the Pareto chart, so it will be re-plotted without NextSong to achieve this.
  • it is interesting to see that Cancel and Cancellation Confirmation, the two events that we wish to predict, are virtually negligible compared to the others, and that in this sample dataset the numbers of Downgrade and Submit Downgrade events do not match, nor do the numbers of Upgrade and Submit Upgrade events.
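A Pareto table simply sorts the event counts in descending order and adds a cumulative percentage. The sketch below is a plain-Python illustration (the function name and the exclude option are assumptions) of both the full table and the re-plot without NextSong.

```python
from collections import Counter
from typing import Iterable, List, Tuple


def pareto_table(pages: Iterable[str], exclude: Tuple[str, ...] = ()) -> List[Tuple[str, int, float]]:
    """Return (page, count, cumulative %) rows sorted by descending count.

    Pages listed in `exclude` (e.g. "NextSong") are dropped before the
    cumulative percentages are computed, so the tail becomes visible.
    """
    counts = Counter(p for p in pages if p not in exclude)
    total = sum(counts.values())
    table, running = [], 0
    for page, n in counts.most_common():
        running += n
        table.append((page, n, round(100.0 * running / total, 1)))
    return table
```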

3b. Analysis of Gender difference for Number of Users and Webpage Usage

Another factor that was studied in the exploratory analysis is the difference between genders, in terms of number of users and amount of usage per user.

Comparison of Number of Users and Rate of Usage by Gender
  • whilst there are clearly more male users than female, it is interesting to see that female users spend more time using Sparkify.
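The comparison above needs two numbers per gender: distinct users and usage per user. The sketch below assumes usage is measured as logged events per user, which is only one possible proxy (online time would be another), and uses a hypothetical helper name.

```python
from collections import defaultdict
from typing import Dict, Iterable, Set, Tuple


def usage_by_gender(events: Iterable[Tuple[str, str]]) -> Dict[str, Tuple[int, float]]:
    """Given (userId, gender) pairs, one per logged event, return
    {gender: (number of distinct users, mean events per user)}."""
    users: Dict[str, Set[str]] = defaultdict(set)
    event_counts: Dict[str, int] = defaultdict(int)
    for user_id, gender in events:
        users[gender].add(user_id)
        event_counts[gender] += 1
    return {g: (len(users[g]), event_counts[g] / len(users[g])) for g in users}
```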

3c. Analysis of Gender difference in terms of Paid Subscription Level

Finally, the difference between the genders in the proportion of users who had at some point paid for a subscription was measured (including those who subsequently downgraded).

  • a marginally higher proportion of female users have at some point had a paid subscription, compared to male users.
  • we can conclude from this that gender is a factor worth considering when looking to predict user behaviour.
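Counting a user as paid if any of their events has level paid is what makes downgraded users still count. The sketch below assumes the level values are the strings "paid" and "free"; the helper name is hypothetical.

```python
from collections import defaultdict
from typing import Dict, Iterable, Set, Tuple


def paid_share_by_gender(events: Iterable[Tuple[str, str, str]]) -> Dict[str, float]:
    """events: (userId, gender, level), one per log row.

    Returns, per gender, the fraction of users who had level "paid" on at
    least one event, so users who later downgraded still count as paid.
    """
    users: Dict[str, Set[str]] = defaultdict(set)
    paid: Dict[str, Set[str]] = defaultdict(set)
    for user_id, gender, level in events:
        users[gender].add(user_id)
        if level == "paid":
            paid[gender].add(user_id)
    return {g: len(paid[g]) / len(users[g]) for g in users}
```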

4. Aggregating the data to create a new dataframe containing one row per userId

  • the data as supplied is essentially a table of event logs where each row represents a single user event.
  • the purpose of this exercise is to predict user behaviour, so these single user events were aggregated to create a dataframe where each row refers to an individual userId.
  • the columns in this aggregated dataframe represent various features, such as page event type, membership level, gender, number of sessions etc, and contain values for these features for each userId.
  • these values would then be used to create feature vectors for each userId for use in the prediction models.

The aggregated dataframe was built up in several steps as follows (the code can be found here):

  1. Obtain the count of each page type for each userId/gender grouping, and pivot this to obtain a table with userId/gender as the row index and the counts of page events as columns. To this table an additional column, gender_v, is added, containing the gender data as a binary variable.
  2. Obtain the earliest (minimum) and latest (maximum) timestamp values and a count of logged events, grouping by userId/sessionId. Use the difference between the timestamp values to get the duration online_time for each userId/sessionId. Then, grouping by userId, get the very earliest and very latest timestamps, and sum the online_times and the counts of logged events for each user. The earliest and latest timestamps were then used to determine the value half_time, which would be used for trend calculations.
  3. In this step the data is manipulated to provide values for a trend calculation. The earliest timestamp and the count of event logs for each userId/sessionId are aggregated into a dataframe. The half_time value for each userId was joined in from the agg_sessions dataframe created in the previous step. If the userId/sessionId earliest timestamp is before (less than) half_time, the event log count is assigned to a new column logs_at_start; otherwise it is assigned to a new column logs_at_end.
  4. The dataframes created in steps 1 to 3 are then joined on userId to create the completed aggregated dataframe df_aggregated. As they will no longer be used they can be deleted to free up memory.
  5. A column churn was added to the table df_aggregated to use as the label for the prediction models. This was achieved using the Cancellation Confirmation events to define churn, which applies to both paid and free users. Additionally, a column downgraded was added, using the Downgrade events as a definition.
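Steps 2, 3 and 5 can be sketched in plain Python to make the logic explicit. This is a hypothetical re-implementation on tuples rather than the notebook’s Spark code, and it assumes half_time is the midpoint between a user’s first and last timestamps, which the text does not state explicitly.

```python
from collections import defaultdict
from typing import Dict, List, Tuple

# One log event: (userId, sessionId, ts in milliseconds, page)
LogRow = Tuple[str, int, int, str]


def aggregate_users(rows: List[LogRow]) -> Dict[str, dict]:
    """Hypothetical plain-Python version of steps 2, 3 and 5: one dict per userId."""
    # Step 2: collect the timestamps of each userId/sessionId
    session_ts: Dict[Tuple[str, int], List[int]] = defaultdict(list)
    for user, session, ts, _page in rows:
        session_ts[(user, session)].append(ts)

    users: Dict[str, dict] = {}
    for (user, _session), stamps in session_ts.items():
        start, end = min(stamps), max(stamps)
        u = users.setdefault(user, {"first_ts": start, "last_ts": end,
                                    "online_time": 0, "total_logs": 0,
                                    "logs_at_start": 0, "logs_at_end": 0,
                                    "churn": 0, "downgraded": 0})
        u["first_ts"] = min(u["first_ts"], start)
        u["last_ts"] = max(u["last_ts"], end)
        u["online_time"] += end - start  # session duration in ms
        u["total_logs"] += len(stamps)

    # ASSUMPTION: half_time is the midpoint of the user's active period
    for u in users.values():
        u["half_time"] = (u["first_ts"] + u["last_ts"]) / 2

    # Step 3: a session's log count goes to the half in which the session started
    for (user, _session), stamps in session_ts.items():
        u = users[user]
        column = "logs_at_start" if min(stamps) < u["half_time"] else "logs_at_end"
        u[column] += len(stamps)

    # Step 5: churn label and downgraded flag from the page events
    for user, _session, _ts, page in rows:
        if page == "Cancellation Confirmation":
            users[user]["churn"] = 1
        elif page == "Downgrade":
            users[user]["downgraded"] = 1
    return users
```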

5. Exploring the Aggregated Data

5a. Preparation

Having defined churn, some exploratory data analysis was performed to observe the behaviour for users who stayed vs users who churned.

Aggregated values for these two groups of users were explored, comparing the incidence rate that they experienced for specific events in a given time, in this case per hour:

  1. A dictionary, rates_hour, was created with the events as keys and the new event-rate column names (ending with `_h`) as values.
  2. The dictionary was then looped through to create new feature columns each containing values for the relevant event rate per hour.
  3. Additional aggregations were calculated for
  • average session time per user
  • positivity : the ratio of positive enjoyment indicators / total enjoyment indicators*
  • negativity : the ratio of negative enjoyment indicators / total enjoyment indicators*
  • trend : comparing log_events/hour in the first half of the period the User was active with log_events/hour in the second half

(* An enjoyment indicator is a log event that either reflects or would affect the user’s level of enjoyment: Thumbs Up, Add to Playlist, Add Friend, Upgrade, Thumbs Down, Roll Advert, Error, Help, Save Settings and Downgrade.)
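The per-hour rates and ratios can be sketched as follows. Only a few entries of a hypothetical rates_hour dictionary are shown; online time is assumed to be in milliseconds, the split of enjoyment indicators into positive (the first four) and negative (the remaining six) is inferred from the order of the list above, and trend is assumed to be the ratio of second-half to first-half logs. If half_time splits the active period into two equal halves, that count ratio approximately equals the ratio of the two per-hour rates.

```python
from typing import Dict

MS_PER_HOUR = 3_600_000

# Hypothetical subset of rates_hour: page event -> new per-hour column name
RATES_HOUR = {
    "Thumbs Down": "thumbs_down_h",
    "Thumbs Up": "thumbs_up_h",
    "Roll Advert": "adverts_h",
    "NextSong": "songs_h",
}

# ASSUMPTION: split of the enjoyment indicators into positive and negative
POSITIVE = ["Thumbs Up", "Add to Playlist", "Add Friend", "Upgrade"]
NEGATIVE = ["Thumbs Down", "Roll Advert", "Error", "Help", "Save Settings", "Downgrade"]


def user_features(page_counts: Dict[str, int], online_time_ms: int,
                  logs_at_start: int, logs_at_end: int) -> Dict[str, float]:
    """Per-hour rates, positivity, negativity and trend for a single user."""
    hours = online_time_ms / MS_PER_HOUR
    features = {
        column: (page_counts.get(page, 0) / hours if hours else 0.0)
        for page, column in RATES_HOUR.items()
    }
    positive = sum(page_counts.get(p, 0) for p in POSITIVE)
    negative = sum(page_counts.get(p, 0) for p in NEGATIVE)
    total = positive + negative
    features["positivity"] = positive / total if total else 0.0
    features["negativity"] = negative / total if total else 0.0
    # ASSUMPTION: trend > 1 means activity rose in the second half of the period
    features["trend"] = logs_at_end / logs_at_start if logs_at_start else 0.0
    return features
```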

5b. Visualisation

Box plots were chosen to compare the incidence rate of these events between the churned users and the users who stayed:

Box plots comparing the distributions of features between users who churned and users who stayed

A heatmap was drawn to study the correlation between the incidence rates:

A bar chart was also created comparing the number of users who churned with the number who stayed:

Data imbalance owing to the significant difference between the number of churned users and the number who stayed.

5c. Insights

  • reference to the box plots shows that the median and quartile values for several of these incidence rates differ between users who stayed and those who churned, particularly for thumbs down / hour, thumbs up / hour, adverts / hour, mean session time, positivity and negativity.
  • however, it should be noted that there is also a lot of overlap between the distributions, which could lead to uncertainty when making a prediction.
  • regarding the heatmap, some of the features have a very strong correlation between them, which means that they could be considered duplicates, so one feature from each such pair can be discarded:
    - positivity and negativity unsurprisingly have a very strong negative correlation, so negativity can be discarded.
    - total_logs_h and songs_h are also strongly positively correlated, as are total_logs_h and adverts_h, and total_logs_h and home_page_h. Seeing as Sparkify is about playing songs, songs_h is kept, while total_logs_h and home_page_h could be dropped.
  • regarding the ratio of churned to stayed users, the data is not balanced, so Balanced Accuracy should be used when evaluating the prediction models.
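The heatmap values are Pearson correlation coefficients (the default of pandas DataFrame.corr(), assuming the heatmap was drawn from a pandas correlation matrix). The sketch below flags feature pairs whose absolute correlation reaches a threshold; the 0.8 cut-off and the function names are assumptions for illustration.

```python
import math
from itertools import combinations
from typing import Dict, List, Tuple


def pearson(x: List[float], y: List[float]) -> float:
    """Pearson correlation coefficient of two equal-length sequences."""
    mx, my = sum(x) / len(x), sum(y) / len(y)
    cov = sum((a - mx) * (b - my) for a, b in zip(x, y))
    sx = math.sqrt(sum((a - mx) ** 2 for a in x))
    sy = math.sqrt(sum((b - my) ** 2 for b in y))
    return cov / (sx * sy)


def correlated_pairs(columns: Dict[str, List[float]],
                     threshold: float = 0.8) -> List[Tuple[str, str, float]]:
    """Return feature pairs whose |Pearson r| is at least `threshold` (hypothetical cut-off)."""
    pairs = []
    for a, b in combinations(columns, 2):
        r = pearson(columns[a], columns[b])
        if abs(r) >= threshold:
            pairs.append((a, b, round(r, 3)))
    return pairs
```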

6. Feature Engineering

6a. Choosing Features

The features used are those studied in the box plot and heatmap analyses, minus the “correlated duplicates”:

  • thumbs_down_h: thumbs down / hour
  • adverts_h: adverts / hour
  • errors_h : system errors / hour
  • help_h : visits to help pages / hour
  • settings_h : number of changes to user settings / hour
  • downgrades_h : downgrades / hour
  • thumbs_up_h : thumbs up / hour
  • playlists_h : songs added to playlists / hour
  • freinds_h : recommendations to friends / hour
  • upgrades_h : upgrades / hour
  • songs_h : songs played / hour
  • mean_session_time : average session duration
  • positivity
  • trend: trend for rate of logs / hour

Added to these are the two binary variables:

  • gender_v
  • downgraded

6b. Defining Vectors

The following code was used to assemble these features into a vector called features, using the VectorAssembler and Normalizer classes from PySpark’s pyspark.ml.feature library:

Code for vector assembly

The numeric features were first assembled and then normalised. The two binary features were then added to the vector; because these have values of 0 or 1, a second normalisation was not needed on the final assembled vector.
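Note that Normalizer rescales each row vector (one per user) to unit length, rather than scaling each feature column; with its default p=2.0 this is the Euclidean norm. The plain-Python sketch below mimics that step and then appends the two binary flags unnormalised, as described above; the function names are hypothetical.

```python
import math
from typing import List


def l2_normalize(values: List[float]) -> List[float]:
    """Scale one row vector to unit Euclidean length; an all-zero vector is returned unchanged."""
    norm = math.sqrt(sum(v * v for v in values))
    return [v / norm for v in values] if norm else list(values)


def build_feature_vector(numeric: List[float], gender_v: int, downgraded: int) -> List[float]:
    """Normalise the numeric features, then append the two 0/1 flags as they are."""
    return l2_normalize(numeric) + [float(gender_v), float(downgraded)]
```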

7. Modelling

7a. Defining Pipelines, Default Parameter Values

The aggregated dataset, including the vectorised features column, was copied and split into training and testing sets at a ratio of 90% to 10%.
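A 90/10 split like this is commonly done with DataFrame.randomSplit([0.9, 0.1], seed=...), which assigns rows at random, so the resulting sizes are only approximately 90% and 10%. The plain-Python sketch below (a hypothetical helper, not Spark’s implementation) illustrates that behaviour; with a dataset of a few hundred users, the test set is small and its exact size varies with the seed.

```python
import random
from typing import List, Sequence, Tuple, TypeVar

T = TypeVar("T")


def random_split(rows: Sequence[T], train_frac: float = 0.9,
                 seed: int = 42) -> Tuple[List[T], List[T]]:
    """Send each row to train with probability train_frac, otherwise to test."""
    rng = random.Random(seed)
    train: List[T] = []
    test: List[T] = []
    for row in rows:
        (train if rng.random() < train_frac else test).append(row)
    return train, test
```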

Three machine learning methods, LogisticRegression, RandomForestClassifier and GBTClassifier, were tried.

Pipelines were built using the vector assemblers described in step 6b:

Code for Pipeline Building, together with Pipelines using default Parameter Values

The three pipeline models were trained using the training data, and the trained models were then used to transform the testing data, using this function:

Code Prediction Function
Applying prediction function to pipeline models

These prediction models were evaluated for accuracy using PySpark’s BinaryClassificationEvaluator(), together with some coded calculations to obtain a confusion matrix, which could be used to evaluate the F1 score and Balanced Accuracy (BinaryClassificationEvaluator() itself doesn’t give a value for F1). The definitions for the coded calculations can be found in the Precision and Recall article on Wikipedia.

The coding for the evaluation was written in the form of a function which returns a dictionary containing the results for a given prediction model:

Evaluation Function
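A minimal sketch of such an evaluation function, written in plain Python on (label, prediction) pairs rather than on Spark DataFrames, is shown below; the function name and dictionary keys are assumptions. It also shows why a model with no true positives scores a Balanced Accuracy of 0.5 at best: its recall is 0, so only the specificity term contributes.

```python
from typing import Dict, Iterable, Tuple


def evaluate(pairs: Iterable[Tuple[int, int]]) -> Dict[str, float]:
    """pairs: (label, prediction) with 1 = churned. Returns confusion counts and scores."""
    tp = fp = tn = fn = 0
    for label, pred in pairs:
        if label == 1 and pred == 1:
            tp += 1
        elif label == 0 and pred == 1:
            fp += 1
        elif label == 0 and pred == 0:
            tn += 1
        else:
            fn += 1
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0        # true positive rate
    specificity = tn / (tn + fp) if tn + fp else 0.0   # true negative rate
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return {"tp": tp, "fp": fp, "tn": tn, "fn": fn,
            "precision": precision, "recall": recall, "f1": f1,
            "balanced_accuracy": (recall + specificity) / 2}
```

For illustration, a model that catches 2 of 4 churners without any false positives would have precision 1.0, recall 0.5, F1 ≈ 0.67 and Balanced Accuracy 0.75, whatever the number of true negatives.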

The results from each model were then collected into a table for comparison:

Accuracy results for models with default parameter values


  • The areas under the curves for all models are relatively high, thanks to their 100% success rate when predicting true negatives; however, this may well be due to the significant imbalance in the data between positives and negatives.
  • Balanced Accuracy and F1 score are not so optimistic, owing to the fact that true positives are not so well predicted. In fact, the RandomForestClassifier model failed to predict any true positives at all.
  • Based on the initial assessment, GBTClassifier has an edge, with a higher F1 score and Balanced Accuracy, since it was able to identify 2 of the 4 positives.

7b. Optimising Pipelines using Cross-Validation

A function for executing 3-fold cross-validation on a given pipeline model was written. It takes the base pipeline model and a parameter grid, and cross-validates it on the trainData split using CrossValidator from the Spark tuning library to find the most accurate model:

Code for Cross-Validation Function
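What CrossValidator does can be sketched in plain Python: every parameter combination is trained on k-1 folds and scored on the remaining fold, the k scores are averaged (these averages correspond to avgMetrics), and the combination with the highest average wins. In the sketch below, fit_and_score is a hypothetical callback standing in for fitting a pipeline and running an evaluator, and folds are formed round-robin for simplicity, whereas Spark assigns rows to folds randomly.

```python
from itertools import product
from statistics import mean
from typing import Any, Callable, Dict, List, Sequence, Tuple

Params = Dict[str, Any]


def param_grid(**options: Sequence[Any]) -> List[Params]:
    """Cartesian product of parameter values, one dict per combination."""
    names = list(options)
    return [dict(zip(names, combo)) for combo in product(*options.values())]


def cross_validate(rows: Sequence[Any], grid: List[Params],
                   fit_and_score: Callable[[Params, List[Any], List[Any]], float],
                   num_folds: int = 3) -> Tuple[Params, List[float]]:
    """Return the best parameter map and the average metric for every map.

    fit_and_score (hypothetical) trains on its first list and returns a
    metric, where higher is better, on its second list.
    """
    folds = [list(rows[i::num_folds]) for i in range(num_folds)]
    avg_metrics = []
    for params in grid:
        scores = []
        for k in range(num_folds):
            train = [r for j, fold in enumerate(folds) if j != k for r in fold]
            scores.append(fit_and_score(params, train, folds[k]))
        avg_metrics.append(mean(scores))
    best = max(range(len(grid)), key=avg_metrics.__getitem__)
    return grid[best], avg_metrics
```

The cost grows as (number of parameter combinations) × k: a 3 × 3 grid with 3 folds already means 27 model fits before the final refit on all the training data, which helps explain why cross-validating GBTClassifier was so slow.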

In order to use this function, parameter grids were defined. First, the default and allowed parameter values for the three pipeline models were obtained from the Spark documentation for each of the three models.

Based on this information, the following parameter grids were built for the cross-validation:

Table showing Default and Cross Validation Parameter Grid Values for LogisticRegression, RandomForestClassifier and GBTClassifier Models

The function was applied to the LogisticRegression and RandomForestClassifier models (see the note below regarding GBTClassifier), and the .avgMetrics for the cross-validations were obtained, taking the maximum resulting value to identify the best parameter combination:


Cross Validation of Logistic Regression

Selected Parameter Values for tuned LogisticRegression model:

  • maxIter=100, regParam=20.0, elasticNetParam=0.0


Cross Validation of RandomForestClassifier

Selected Parameter Values for tuned RandomForestClassifier model:

  • maxDepth=5, numTrees=10.0


Note: during the analysis it was found that cross-validation for GBTClassifier was very resource-intensive, with several hours passing without a result when running a cross-validation of 3 values for each of 2 parameters, even when setting the parallelism parameter higher than its default of 1.

To expedite tuning, it was decided to abandon the cross-validator method and instead find the best parameter combination manually, by working through the grid-search parameter combinations for maxIter and maxDepth.

Tuning Trial Results for GBTClassifier
Validation of GBTClassifier

Selected Parameter Values for tuned GBTClassifier model:

  • maxDepth=3, maxIter=40.0

7c. Assessment of Tuned Models

The pipelines were then reconstructed with these parameters, and the predictions re-run:

Example Code for tuned model (LogisticRegression)

The resulting predictions were then assessed using prediction_test:

Accuracy results for models with tuned parameter values


  • The areas under the curves for LogisticRegression improved when using the best parameter set from the cross-validation; however, the F1 score and Balanced Accuracy decreased.
  • The areas under the curves for RandomForestClassifier decreased when using the best parameter set from the cross-validation, with no change to the F1 score or Balanced Accuracy.
  • The areas under the curves for GBTClassifier both increased when using the best parameter set from the manual tuning analysis, with no change to the F1 score or Balanced Accuracy.


The results of the cross-validation are counter-intuitive, in that the so-called best models perform less well on some indicators than the models with default values, which were also included in the parameter grid. However, with the manual process used on GBTClassifier, it was possible to find some parameters that led to an all-round improvement.

It should be remembered, however, that the .avgMetrics values given are averages measured over the 3 folds of the cross-validation.

Each fold uses a different subset of the training data. The data has already been shown to be biased, and depending on how these subsets are selected, the bias could be exaggerated, leading to one or more subsets with an unrepresentative proportion of churn : stayed users, or a difference in this proportion between the subsets. Either way the model will be learning differently for each fold.

If one of the folds then delivers an abnormally low value because of this mismatch, it will affect the average in such a way that a misleading .avgMetrics value is given, which in turn leads to an incorrect selection of parameters when tuning.
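One commonly used remedy for the fold imbalance described above is stratified k-fold, which deals churned and stayed users out to the folds separately so that every fold keeps the same churn ratio. This is a swapped-in technique, not what the analysis used; the plain-Python sketch below compares it with purely random fold assignment, and all names in it are hypothetical. In Spark, recent versions of CrossValidator accept a foldCol parameter for user-specified fold numbers, which could hold such a precomputed assignment; check the documentation for the Spark version in use.

```python
import random
from typing import List, Sequence, Tuple

Row = Tuple[str, int]  # (userId, churn label)


def random_folds(rows: Sequence[Row], k: int, seed: int = 0) -> List[List[Row]]:
    """Assign each row to a fold uniformly at random (churners may cluster)."""
    rng = random.Random(seed)
    folds: List[List[Row]] = [[] for _ in range(k)]
    for row in rows:
        folds[rng.randrange(k)].append(row)
    return folds


def stratified_folds(rows: Sequence[Row], k: int, seed: int = 0) -> List[List[Row]]:
    """Deal stayed (0) and churned (1) users out separately, round-robin after shuffling."""
    rng = random.Random(seed)
    folds: List[List[Row]] = [[] for _ in range(k)]
    for label in (0, 1):
        group = [r for r in rows if r[1] == label]
        rng.shuffle(group)
        for i, row in enumerate(group):
            folds[i % k].append(row)
    return folds


def churn_share(fold: List[Row]) -> float:
    """Fraction of churned users in one fold."""
    return sum(label for _, label in fold) / len(fold) if fold else 0.0
```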

The training data used in the cross-validation consists of 90% of the aggregated dataframe (approximately 200 rows), but the testing data is much smaller at 10%, or only about 25 rows. This is a very small sample with very low statistical significance, which could lead to different outputs each time the prediction is run: with only 4 churned users in the test set, for example, each additional true positive changes recall by 25 percentage points.

The metrics also only apply to the evaluator’s definitions; in this case BinaryClassificationEvaluator only considers two metrics, area under PR and area under ROC. The F1 score was measured “manually”.

This could be a reason why the tuned models are performing less well for some indicators but better for others, when compared to the default models.

With the manual process used to optimise GBTClassifier, no such subsetting took place, which is equivalent to having only one fold. Whilst this delivers a desirable result with improved accuracy values for the prediction, it does not guarantee that the result is repeatable, particularly when applying it to a prediction based on the full 12GB dataset. This would need to be verified by running the prediction on a Spark cluster in the cloud, using AWS or IBM Cloud.

An additional point to consider is that we are checking our prediction of userIds who will churn against a list of userIds who did churn. It is possible that some of the false positives in the prediction aren’t actually false. The data in the sample was collected over just over two months, which is a relatively short period of time. The model may indeed have identified a user who will churn but who, according to the sample data, has not done so yet because of the short collection period. A false positive may not necessarily be a bad thing.

8. Recommended Model and Further Work

Based on the results, such as they are, the recommendation is to use the GBTClassifier model, with hyperparameters maxDepth and maxIter tuned to values of 3 and 40 respectively.
It outperformed both other models with default settings and appears to improve with tuning.
However, for the reasons discussed above in section 7c, before deploying the model into production it should first be verified with the full dataset using a Spark cluster in the cloud to prove its reliability, if necessary taking steps to tune it further.
Nevertheless, it has the potential to deliver a solution for predicting customer churn.

9. Reflection

As ever with Udacity, the project has been interesting and challenging to execute, presenting me with a steep learning curve to climb.

The end result hasn’t provided me with a totally ready solution to the problem, but does provide a basis for further analysis.

One very important learning point for me is the use of some of the machine learning tools, their advantages and disadvantages when it comes to using them with a restricted dataset, particularly on a local machine. The results from using the Cross-Validator for example were particularly disappointing, although understandable considering factors such as small sample size.

A second important learning point is that running these functions, even with a relatively small dataset, consumes a lot of computing resource and takes time. This should be remembered when planning these projects to allow plenty of time for the modelling and optimisation process.

Together with all the previous projects in the Programming for Data Science, Data Analytics and Data Scientist programs, this project has been good preparation, providing both positive and negative experiences, for embarking on a journey towards data science.

Richard Needham

August 2021

Data Analyst in the Automotive Industry, focusing on Product Quality