TFDF: speeding up training time

I find TensorFlow Decision Forests has nearly 2x slower training time than alternative libraries when I call .fit from within a Colab notebook. It also seems to hang for a long time after the Keras progress bar reaches 100%.

In this example I get the following results on 250k rows and 200 features:

  • tfdf: 5min 49s
  • xgboost (tree_method='approx'): 3min 11s
  • lightgbm: 46.1s
  • xgboost (tree_method='hist'): 30.4s
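
Roughly, the setup looks like this (a simplified sketch rather than the exact Colab code; the synthetic data is sampled from a uniform float distribution and fed to TFDF as one tensor per feature):

import numpy as np
import tensorflow as tf
import tensorflow_decision_forests as tfdf

# Synthetic dataset: 250k rows, 200 uniform float features, binary label.
X = np.random.uniform(size=(250_000, 200)).astype(np.float32)
y = np.random.randint(0, 2, size=250_000)

# One tensor per feature column.
features = {f"f{i}": X[:, i] for i in range(X.shape[1])}
ds = tf.data.Dataset.from_tensor_slices((features, y)).batch(64)

model = tfdf.keras.GradientBoostedTreesModel(shrinkage=0.02, num_trees=50)
model.fit(ds)  # ~5-6 minutes on a public Colab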

(1) Does anyone here understand why tfdf is significantly slower than xgboost on the same dataset with similar parameters, and have any tips for debugging or speeding it up? I’m posting the Colab example because it’s easiest to share, but I see differences of similar magnitude internally on larger datasets.

(2) Is there any plan to offer histogram-based binning options for TFDF? It seems to dramatically reduce training time, which would be very helpful to users with large datasets.

@Mathieu might know this one


Hi Olivier,

Thanks for the detailed colab :).

To answer your first question: the training of TF-DF models in your Colab is slower than XGBoost or LightGBM mainly because the hyper-parameters and logic of the learning algorithms are configured differently. See the details below.

About your second question: histogram splits are already supported by Yggdrasil Decision Forests (and yes, they are very fast, often >=5x faster, but they can also lead to worse models), and we are planning to surface this API in TF-DF soon. Note that for distributed training this is already done and will be available in the next release of TF-DF.

Let me go into details for your first question.

I re-ran your colab in a public colab instance. The training of TF-DF took 6min 40s.

The Keras progress bar that you see when training a TF-DF model actually tracks the TF dataset reading; no training is done during this time. In your case, reading the dataset takes ~1 minute (on a public Colab).
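
One way to see this is to time a bare pass over the dataset, separately from training. A minimal sketch (ds is the tf.data.Dataset you pass to fit):

import time

start = time.time()
for _ in ds:  # iterate once without training: this is pure tf.data reading
    pass
print(f"Dataset reading time: {time.time() - start:.1f}s")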

There are two optimizations in your case:

  • Instead of feeding a tensor for each feature, feed a single multi-dimensional feature, e.g.:
ds = tf.data.Dataset.from_tensor_slices((X, y)).batch(64)

This will decrease the reading time from 1m to 10s.

  • Increase the batch size, e.g.:
ds = tf.data.Dataset.from_tensor_slices((X, y)).batch(1000)

This will decrease the reading time from 10s to 1s.

After those two optimizations, the training took 5min 25s.
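
Putting the two changes together, a minimal sketch of the faster input pipeline (assuming X is a single 2D float array and y the labels):

# Single multi-dimensional feature + large batch size.
ds = tf.data.Dataset.from_tensor_slices((X, y)).batch(1000)

tfdf_clf = tfdf.keras.GradientBoostedTreesModel(shrinkage=0.02, num_trees=50)
tfdf_clf.fit(ds)  # the dataset reading is now ~1s instead of ~1min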

Note: I’ll update the documentation accordingly.

The number of threads used for training is controlled by the "num_threads" argument, which defaults to 6. Public Colab instances run on 2 processors (see "!cat /proc/cpuinfo"). Therefore, the training threads compete with each other, making the training slow.

tfdf_clf = tfdf.keras.GradientBoostedTreesModel(shrinkage=0.02, num_trees=50, num_threads=2)
tfdf_clf.fit(ds)

The training now takes 4min 20s.

Note: I’ll make the number of threads auto by default.

  1. Difference with XGBoost

XGBoost takes 3min 41s to train (in my run).

XGBoost is configured differently than TF-DF.
TF-DF is essentially equivalent to tree_method='exact' in XGBoost.
Also, by default, XGBoost trains with a maximum depth of 3, while TF-DF trains with a maximum depth of 6.

Training XGBoost with a configuration similar to TF-DF takes 7min 48s:

xgboost.XGBClassifier(learning_rate=0.02, n_estimators=50, tree_method='exact', max_depth=6)
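
A usage sketch, assuming X and y are the same arrays used to build the TF-DF dataset:

import time
import xgboost

xgb_clf = xgboost.XGBClassifier(
    learning_rate=0.02, n_estimators=50, tree_method='exact', max_depth=6)

start = time.time()
xgb_clf.fit(X, y)
print(f"XGBoost (exact splits, depth 6): {time.time() - start:.0f}s")  # ~7min 48s here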

  2. Difference with LightGBM

LGBM takes 36.5 s to train (in my run).

LGBM is configured differently than TF-DF.
LGBM has a default maximum number of leaves of 31, while TF-DF has a default maximum number of leaves of 2^6 = 64.
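
If you want a more apples-to-apples comparison, you can bring LightGBM closer to TF-DF's defaults. A sketch (these parameter values are my approximation of an equivalent configuration, not what your Colab used):

import lightgbm

lgbm_clf = lightgbm.LGBMClassifier(
    learning_rate=0.02,  # same shrinkage as the TF-DF model
    n_estimators=50,     # same number of trees
    max_depth=6,         # TF-DF's default maximum depth
    num_leaves=64,       # up to 2^6 leaves, matching a depth-6 tree
    max_bin=255,         # LightGBM's default number of global histogram bins
)
lgbm_clf.fit(X, y)

Even then, LightGBM will remain much faster because of the histogram splits discussed next.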

However, the bigger difference is the histogram splits. LGBM (the only one of the three, I think) supports global histogram splits, with max_bin=255 bins by default. Histogram splits are much cheaper to compute than exact splits, and this is the main reason for the difference in training time.

In addition, in the dataset you generated, the feature values are mostly unique (because you sampled from a uniform float distribution). Having a lot of unique values is more expensive for exact splits than for histogram (global or local) splits. However, I think this is rarely the case in practice: for large real-world datasets, it is rare for every numerical value to be unique.

If you reduce the number of unique values, you can observe a large speed-up.

# Round the features to 2 decimals to reduce the number of unique values.
X_bin = np.round(X, decimals=2)

ds_bin = tf.data.Dataset.from_tensor_slices((X_bin, y))
ds_bin = ds_bin.batch(1000)

TF-DF will now train in 3min 25s.

Thanks again for the colab :)


Thanks very much for the detailed advice, this is brilliant and will save me many hours of waiting for models to fit!

I will keep an eye out for histogram splits in TFDF; they seem so much faster that they would be very practical for some use cases. One day having this within Keras/TensorFlow would be amazing, but I do appreciate that TFDF is still a relatively new project! Often float values will be distinct, so the coarsening tip you gave could be useful to many potential users.

When I add the batch size to the dataset

ds = tf.data.Dataset.from_tensor_slices((X, y)).batch(1000)

I get this error


Can not squeeze dim[1], expected a dimension of 1, got 1000
         [[{{node Squeeze}}]] [Op:__inference__consumes_training_examples_until_eof_1011]

This is my code

train_ds = tfdf.keras.pd_dataframe_to_tf_dataset(train_df, label=feature_spec.get_label_col()).batch(BATCH_SIZE)
valid_ds = tfdf.keras.pd_dataframe_to_tf_dataset(valid_df, label=feature_spec.get_label_col()).batch(BATCH_SIZE)

model = tfdf.keras.GradientBoostedTreesModel(
    task=tfdf.keras.Task.CLASSIFICATION,
    **model_config.gradient_boosted_trees_config.model_dump(),
    features=[
        f for f in feature_spec.get_all_tffeature() if f.name not in feature_spec.get_blacklist_feature_names()
    ],
    exclude_non_specified_features=True,
    num_threads=12,
)

tensorboard_callback = keras.callbacks.TensorBoard(log_dir=model_config.tensorboard_log_dir + "/{}")
model.fit(train_ds, validation_data=valid_ds, callbacks=[tensorboard_callback])