Using Yandex’s CatBoost

Many people won’t have heard of Yandex, but the company is a major player in the search space in Russia and the former Soviet Union. Yandex have launched several open source projects, one of the most interesting being CatBoost.

CatBoost is a gradient boosting library from Yandex which is particularly targeted at classification tasks that deal with categorical data. Many datasets contain lots of information which is categorical in nature, and CatBoost allows you to build models without first having to encode this data into one-hot arrays and the like. The library can also be used alongside other machine learning libraries such as Keras and TensorFlow. I am going to focus on how the library can be used to build classification models on categorical data.

I highly recommend watching the talk by one of the creators of the library, in which she goes into greater detail about the library and how it can be used in a variety of different contexts.

Building A Generic Model

The example in this post uses one of the demo datasets included with the CatBoost library: the Titanic dataset, which contains information about passengers on the Titanic and allows us to predict whether someone would have survived based on a number of different features. While the example code uses this demo dataset, it should be generic enough to use with your own dataset with only minor modifications.

We begin by initialising our CatTrainer class. We simply pass in the Pandas data frame we want to use to train our model. We also initialise several other variables, which for the time being we set to None. These will be used later when preparing and training our model.
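The original code isn't reproduced in this post, so here is a minimal sketch of what such a constructor might look like. The attribute names (`X`, `y`, `model`, `categorical_features_indices`) are assumptions for illustration and may differ from the author's code.

```python
import pandas as pd


class CatTrainer:
    """Sketch of the trainer's constructor; attribute names are hypothetical."""

    def __init__(self, df: pd.DataFrame):
        # The data frame we want to train on
        self.df = df
        # Placeholders that the preparation and training steps fill in later
        self.X = None
        self.y = None
        self.model = None
        self.categorical_features_indices = None
```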

Next, our protected method for replacing null values is a simple helper function that replaces any null values with -999. This value can be overridden should the user have a more appropriate default in mind.

We then write our preparation method, which prepares our X and Y values. We pass in the name of our label column and, optionally, a default null value. This creates our X and Y values without much overhead on our part.
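As a rough sketch, assuming the label is a single column in the data frame, the preparation step can fill the nulls and then split the label off from the features. `prepare_data` is a hypothetical standalone name for what is a method in the post.

```python
import pandas as pd


def prepare_data(df: pd.DataFrame, label: str, null_value=-999):
    """Return (X, y): the feature columns and the label column.

    Hypothetical standalone version of the preparation step.
    """
    # Replace nulls first so both X and y come from the same cleaned frame
    filled = df.fillna(null_value)
    X = filled.drop(columns=[label])
    y = filled[label]
    return X, y
```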

We then come to the task of creating our model, for which we write two functions. The first is a simple function that returns some sensible default parameters, overriding any that the user chooses to specify. The second simply uses these values to create a model and assigns it to the self.model variable.

We can then write the function that trains our model. Again, this function takes several arguments with relatively sensible defaults. We split our X and Y values into training and test data. If the user has not specified the indices of the categorical columns, we automatically try to determine them by checking the data type of each column. If we have not already created a model or loaded an external one, we call the create_model function. Finally, we call the fit method with all the relevant information.

We also write a quick cross-validation function, which lets us estimate how accurate our model actually is. This allows us to quickly benchmark the performance of the model we wish to train.

Two additional functions handle saving and loading models; they simply wrap functions contained in the CatBoost library.

Finally, we have a simple method that predicts the labels of a passed-in data frame. We replace the missing values in the same way as before, then call the model's predict function on the data frame and return the predictions.

Using the code

The class can be used with the Titanic dataset included with CatBoost, and with minor changes it should be possible to use it with other datasets. This shows just how easy it is to produce powerful models with relatively little code with the help of CatBoost. You can find the full code on GitHub, and feel free to ask any questions in the comments.
