Tutorial: Differentiable Information Imbalance (JAX implementation)

The Differentiable Information Imbalance (DII) is a tool to automatically learn the optimal distance function in an input space A for predicting pairs of close points in a target distance space B.

This notebook is a brief tutorial on the JAX implementation of the DII, available in the ‘DiffImbalance’ class. For more information, please refer to the diff_imbalance module in the DADApy documentation: https://dadapy.readthedocs.io/en/latest/modules.html

[1]:
import os

os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"  # avoid jax memory preallocation;
                                                       # set it before jax is imported
import jax
import matplotlib.pyplot as plt
import numpy as np
from dadapy import DiffImbalance

jax.config.update("jax_platform_name", "cpu")  # can run on 'cpu' or 'gpu'; restart the
                                               # notebook kernel to make this change effective
[2]:
%load_ext autoreload
%autoreload 2

Optimization on a 5D anisotropic Gaussian dataset

The Differentiable Information Imbalance is computed here as \begin{equation} \mathrm{DII}(d^A(\boldsymbol{w})\rightarrow B) = \frac{2}{N^2} \sum_{i,j=1}^N\, c_{ij}^A\, r_{ij}^B\, \hspace{1cm} \Bigg( c_{ij}^A = \frac{\exp(-d_{ij}^A(\boldsymbol{w})^2/\lambda)}{\sum_{k(\neq i)}{\exp(-d_{ik}^A(\boldsymbol{w})^2/\lambda)}}\Bigg), \end{equation} where terms with \(j = i\) are excluded and \(r_{ij}^B\) is the rank of point \(j\) among the neighbors of point \(i\) in space B (\(r_{ij}^B = 1\) if \(j\) is the nearest neighbor of \(i\), \(2\) if it is the second nearest, and so on). Qualitatively, the parameter \(\lambda\) decides how many neighbors are considered: for very small \(\lambda\), only the nearest neighbor receives a weight \(c_{ij}^A \approx 1\), and all other neighbors receive weights close to \(0\).
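To make the formula concrete, here is a minimal NumPy sketch of the same quantity with a single global \(\lambda\). It is not DADApy's code: the tutorial sets \(\lambda\) adaptively (arguments ‘point_adapt_lambda’, ‘k_init’, ‘k_final’), which this sketch does not reproduce, and the helper names ‘pairwise_distances’, ‘ranks_from_distances’ and ‘dii_numpy’ are hypothetical. Since each row of \(c_{ij}^A\) sums to 1 and every rank is at least 1, the DII is bounded below by \(2/N\), reached when A predicts B perfectly; when A carries no information on B it is close to 1.

```python
import numpy as np


def pairwise_distances(X):
    """Euclidean distance matrix between the rows of X, shape (N, N)."""
    diff = X[:, None, :] - X[None, :, :]
    return np.sqrt(np.sum(diff**2, axis=-1))


def ranks_from_distances(dist):
    """r[i, j] = 1 if j is the nearest neighbor of i, 2 if the second nearest, ...
    Each point gets rank 0 with respect to itself (its term is multiplied by c_ii = 0)."""
    d = np.array(dist, dtype=float)  # copy, so the input is not modified
    np.fill_diagonal(d, -np.inf)  # put i first in its own neighbor ordering
    order = np.argsort(d, axis=1)
    n = d.shape[0]
    ranks = np.empty((n, n), dtype=int)
    ranks[np.arange(n)[:, None], order] = np.arange(n)[None, :]
    return ranks


def dii_numpy(dist_A, dist_B, lam):
    """DII(A -> B) with one global lambda (hypothetical helper, not DADApy's code)."""
    n = dist_A.shape[0]
    logits = -dist_A**2 / lam
    np.fill_diagonal(logits, -np.inf)  # the normalization sums over k != i
    logits -= logits.max(axis=1, keepdims=True)  # numerically stable softmax
    c = np.exp(logits)
    c /= c.sum(axis=1, keepdims=True)  # each row of c_ij^A sums to 1
    r = ranks_from_distances(dist_B)
    return 2.0 / n**2 * np.sum(c * r)


rng = np.random.default_rng(0)
X = rng.normal(size=(200, 3))
Y = rng.normal(size=(200, 3))
print(dii_numpy(pairwise_distances(X), pairwise_distances(X), lam=1e-6))  # ~2/N = 0.01
print(dii_numpy(pairwise_distances(Y), pairwise_distances(X), lam=1e-6))  # close to 1
```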

In this example we show how to assign the optimal weights \(\boldsymbol{w} = \{w_\alpha\}\) to the features \(\{X_\alpha\}\) \((\alpha=1,...,D)\), which define space A, in order to optimize the prediction of distances in a target space B. We will construct space A using a 5-dimensional isotropic Gaussian distribution, and space B by reweighting its coordinates, resulting in a 5-dimensional anisotropic Gaussian.

The weights appear in the distance function as \begin{equation} d_{ij}^A(\boldsymbol{w}) = \Bigg[\sum_{\alpha=1}^D\,(w_{\alpha}X_\alpha^i - w_{\alpha}X_\alpha^j)^2 \Bigg]^{1/2}, \end{equation} and they are optimized by gradient descent.
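Because each weight multiplies its feature before the Euclidean norm is taken, \(d_{ij}^A(\boldsymbol{w})\) is simply the Euclidean distance between the rescaled points \(w_\alpha X_\alpha\), and a weight equal to zero removes the corresponding feature. The NumPy sketch below (hypothetical helper ‘weighted_distances’) checks that, with the ground-truth weights used in the next cell, distances in A coincide with those in B.

```python
import numpy as np


def weighted_distances(X, w):
    """d_ij^A(w): Euclidean distances after multiplying feature alpha by w_alpha."""
    Xw = X * w[None, :]  # rescale each column (feature) by its weight
    diff = Xw[:, None, :] - Xw[None, :, :]
    return np.sqrt(np.sum(diff**2, axis=-1))


rng = np.random.default_rng(0)
X_A = rng.normal(size=(50, 5))  # isotropic Gaussian features
w_true = np.array([10, 3, 1, 30, 7.3])
X_B = w_true[None, :] * X_A  # space B: reweighted coordinates

# with w = w_true, distances in A coincide with Euclidean distances in B
print(np.allclose(weighted_distances(X_A, w_true), weighted_distances(X_B, np.ones(5))))  # True

# a zero weight removes a feature: only feature 3 contributes here
d_only_3 = weighted_distances(X_A, np.array([0.0, 0.0, 0.0, 1.0, 0.0]))
print(np.allclose(d_only_3, np.abs(X_A[:, 3][:, None] - X_A[:, 3][None, :])))  # True
```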

[3]:
# generate test data
weights_ground_truth = np.array([10, 3, 1, 30, 7.3])
np.random.seed(0)
data_A = np.random.normal(loc=0, scale=1.0, size=(500, 5))  # sample 500 points
data_B = weights_ground_truth[np.newaxis, :] * data_A

# train the DII to recover ground-truth metric
dii = DiffImbalance(
    data_A=data_A,  # matrix of shape (N,D_A)
    data_B=data_B,  # matrix of shape (N,D_B)
    periods_A=None,
    periods_B=None,
    seed=0,
    num_epochs=500,
    batches_per_epoch=1,  # no mini-batches
    l1_strength=0.0,  # no l1 regularization
    point_adapt_lambda=True,
    k_init=1,
    k_final=1,
    params_init=None,  # automatically set to [0.1,0.1,0.1,0.1,0.1]
    optimizer_name="sgd",  # possible choices: "sgd", "adam", "adamw"
    learning_rate=1e-2,
    learning_rate_decay=None,  # possible choices: None, "cos", "exp"
    num_points_rows=None,
)
weights, imbs = dii.train()

print(f"Ground truth weights = {weights_ground_truth}\n")
# rescale the learnt weights so that the largest one equals the largest ground-truth weight
print(
    f"Learnt weights: {np.abs(weights[-1]) / (np.max(np.abs(weights[-1])) / np.max(weights_ground_truth))}"
)
Ground truth weights = [10.   3.   1.  30.   7.3]

Learnt weights: [ 9.99  2.95  0.89 30.    7.31]
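The learnt weights are rescaled before being compared with the ground truth because multiplying all weights by a common factor scales every distance in A by that factor, which leaves the neighbor ordering of each point unchanged: only the ratios between the weights determine which points are close. A minimal NumPy check of this invariance, using hypothetical weight values and a hypothetical helper ‘neighbor_order’:

```python
import numpy as np


def neighbor_order(X, w):
    """For each point, the indices of the other points sorted by d^A(w)."""
    Xw = X * w[None, :]
    d = np.sqrt(np.sum((Xw[:, None, :] - Xw[None, :, :]) ** 2, axis=-1))
    np.fill_diagonal(d, np.inf)  # a point is not its own neighbor
    return np.argsort(d, axis=1)


rng = np.random.default_rng(0)
X = rng.normal(size=(100, 5))
w = np.array([0.07, 0.02, 0.005, 0.21, 0.05])  # hypothetical learnt weights

# a common factor leaves every neighbor ordering unchanged
# (a power-of-two factor keeps the floating-point comparison exact)
print(np.array_equal(neighbor_order(X, w), neighbor_order(X, 8.0 * w)))  # True

# rescaling used in the tutorial: the largest |w| is mapped to the largest ground-truth weight
w_true = np.array([10, 3, 1, 30, 7.3])
rescaled = np.abs(w) / (np.max(np.abs(w)) / np.max(w_true))
print(rescaled)
```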
[4]:
# plot the DII and the weights during the training
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))
ax1.plot(imbs)
ax2.plot(weights[:, 0], label="$w_1$")
ax2.plot(weights[:, 1], label="$w_2$")
ax2.plot(weights[:, 2], label="$w_3$")
ax2.plot(weights[:, 3], label="$w_4$")
ax2.plot(weights[:, 4], label="$w_5$")

ax1.set(ylabel="Differentiable Information Imbalance", xlabel="Epoch number")
ax2.set(ylabel="Weights", xlabel="Epoch number")
ax2.legend()
plt.show()
[Figure: DII (left) and feature weights \(w_1,\dots,w_5\) (right) as a function of the epoch number]

The plots above show the convergence of the DII (left panel) and of the feature weights (right panel) as a function of the epoch number.

In this simple example we employed no decay schedule for the learning rate (argument ‘learning_rate_decay’); the available options are ‘cos’ (cosine decay) and ‘exp’ (the learning rate is halved every 10 epochs). Although these schedules help avoid “overshooting” the minimum and can improve the optimization of the DII, we suggest always performing a first optimization without any learning rate decay, to verify that the number of epochs (argument ‘num_epochs’) is sufficient to reach convergence.
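The two schedules can be sketched as follows. This is an illustration under assumptions: the functions are hypothetical, and DADApy's exact formulas (for instance, whether the ‘exp’ halving is applied stepwise or smoothly) may differ. Optax, the JAX optimization library, offers comparable schedules such as ‘optax.cosine_decay_schedule’ and ‘optax.exponential_decay’.

```python
import numpy as np


def cosine_decay(lr0, epoch, num_epochs):
    """Hypothetical 'cos' schedule: from lr0 at epoch 0 down to 0 at the last epoch."""
    return lr0 * 0.5 * (1.0 + np.cos(np.pi * epoch / num_epochs))


def halving_decay(lr0, epoch, every=10):
    """Hypothetical 'exp' schedule, stepwise reading: lr0 halved every `every` epochs."""
    return lr0 * 0.5 ** (epoch // every)


for epoch in [0, 10, 100, 250, 500]:
    print(epoch, cosine_decay(1e-2, epoch, 500), halving_decay(1e-2, epoch))
```

Under the stepwise reading, after 100 epochs the ‘exp’ schedule has already reduced the learning rate by a factor \(2^{10} = 1024\), so most of the optimization happens in the first epochs; a run without decay shows whether that is enough to converge.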

In particularly noisy data sets, strategies to speed up and improve the DII optimization include the use of mini-batches (argument ‘batches_per_epoch’) coupled with more sophisticated optimizers (e.g. ‘adam’), and the use of a larger neighborhood size for setting the parameter \(\lambda\) adaptively (arguments ‘k_init’ and ‘k_final’). A reasonable starting point is to set ‘batches_per_epoch’ so that each mini-batch contains ~100 points, and to set ‘k_init’ and ‘k_final’ so that ~5% of the points of each mini-batch are included in each neighborhood. For example, if the original data set contains \(N=500\) points, setting ‘batches_per_epoch’ to 5 results in mini-batches of \(N'=100\) points each, and setting ‘k_init’ and ‘k_final’ to 5 includes 5% of the points of each mini-batch in each neighborhood:
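The rule of thumb above can be written as a small hypothetical helper (not part of DADApy) that returns values for ‘batches_per_epoch’ and for ‘k_init’/‘k_final’:

```python
def suggest_minibatch_settings(n_points, batch_size=100, neighbor_fraction=0.05):
    """Hypothetical helper applying the rules of thumb in the text:
    ~batch_size points per mini-batch, ~neighbor_fraction of each mini-batch per neighborhood.
    Returns (batches_per_epoch, k), with k to be used for both k_init and k_final."""
    batches_per_epoch = max(1, round(n_points / batch_size))
    points_per_batch = n_points // batches_per_epoch
    k = max(1, round(neighbor_fraction * points_per_batch))
    return batches_per_epoch, k


print(suggest_minibatch_settings(500))  # (5, 5), the values used in cell [5]
print(suggest_minibatch_settings(2000))  # (20, 5)
```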

[5]:
# train the DII to recover ground-truth metric
dii = DiffImbalance(
    data_A=data_A,  # matrix of shape (N,D_A)
    data_B=data_B,  # matrix of shape (N,D_B)
    periods_A=None,
    periods_B=None,
    seed=0,
    num_epochs=500,
    batches_per_epoch=5,  # 5 mini-batches of 100 points each
    l1_strength=0.0,  # no l1 regularization
    point_adapt_lambda=True,
    k_init=5,
    k_final=5,
    params_init=None,  # automatically set to [0.1,0.1,0.1,0.1,0.1]
    optimizer_name="adam",  # possible choices: "sgd", "adam", "adamw"
    learning_rate=1e-2,
    learning_rate_decay="cos",  # possible choices: None, "cos", "exp"
    num_points_rows=None,
)
weights, imbs = dii.train()  # the outputs can also be accessed after training
# with dii.params_training and dii.imbs_training

print(f"Ground truth weights = {weights_ground_truth}\n")

# rescale the learnt weights so that the largest one equals the largest ground-truth weight
print(
    f"Learnt weights: {np.abs(weights[-1]) / (np.max(np.abs(weights[-1])) / np.max(weights_ground_truth))}"
)
Ground truth weights = [10.   3.   1.  30.   7.3]

Learnt weights: [10.04  2.86  0.57 30.    7.3 ]
[6]:
# plot the DII and the weights during the training
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))
ax1.plot(imbs)
ax2.plot(weights[:, 0], label="$w_1$")
ax2.plot(weights[:, 1], label="$w_2$")
ax2.plot(weights[:, 2], label="$w_3$")
ax2.plot(weights[:, 3], label="$w_4$")
ax2.plot(weights[:, 4], label="$w_5$")

ax1.set(ylabel="Differentiable Information Imbalance", xlabel="Epoch number")
ax2.set(ylabel="Weights", xlabel="Epoch number")
ax2.legend()
plt.show()
[Figure: DII (left) and feature weights \(w_1,\dots,w_5\) (right) as a function of the epoch number]

In this case, the left panel shows the DII computed at each training epoch on the last mini-batch of that epoch. After convergence, the DII on the full data set can be computed with the method ‘return_final_dii’. For details on its arguments refer to the DADApy documentation: https://dadapy.readthedocs.io/en/latest/modules.html.
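One way to picture mini-batch training is to reshuffle the points at each epoch and split them into ‘batches_per_epoch’ disjoint groups; the sketch below assumes this scheme (hypothetical helper ‘minibatch_indices’), which may not match DADApy's internals exactly. It also shows why the left-panel values are not directly comparable with a full-data-set DII: each value is estimated on only \(N'=100\) points, and the lower bound \(2/N\) of the DII depends on the number of points.

```python
import numpy as np


def minibatch_indices(n_points, batches_per_epoch, rng):
    """Hypothetical scheme: reshuffle the points and split them into
    batches_per_epoch disjoint mini-batches of (nearly) equal size."""
    perm = rng.permutation(n_points)
    return np.array_split(perm, batches_per_epoch)


rng = np.random.default_rng(0)
for epoch in range(2):
    batches = minibatch_indices(500, 5, rng)  # new partition at every epoch
    print(epoch, [len(b) for b in batches])  # [100, 100, 100, 100, 100]

# lower bound of the DII (every nearest neighbor in A has rank 1 in B) is 2/N
print(2 / 100, 2 / 500)  # 0.02 on a mini-batch of 100 points vs 0.004 on all 500
```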

[7]:
imb_final, _ = dii.return_final_dii(
    compute_error=False, ratio_rows_columns=None, seed=0, discard_close_ind=0
)
print(f"Optimal DII over full data set: {imb_final:.6f}")  # can also be accessed
# through dii.imb_final
Optimal DII over full data set: 0.011225

Greedy Feature Search: Forward Search Implementation

Feature selection can be achieved either by optimizing the DII with an L1 regularization, or with a greedy search approach. We show in this section how to run a “forward” greedy search, where we first identify the optimal single feature and then add features one by one to get the optimal n-tuple feature sets. The process is called “greedy” because it does not explore every possible combination of features, since that would be computationally untenable for high-dimensional data sets.

The parameter n_best controls the greediness of the algorithm. If n_best is set to 1, the candidate n-tuples are obtained by adding to the optimal (n-1)-tuple, one at a time, each feature it does not already contain. The weights of each candidate n-tuple are optimized by training the DII, and the candidate with the lowest DII is selected for the next step. If n_best is larger than 1, the candidate n-tuples are obtained by adding one feature at a time to each of the n_best optimal tuples of size (n-1), resulting in a broader exploration of the search space at the price of an increased runtime.
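The search can be sketched as a beam search in plain Python. In this hypothetical sketch (not DADApy's implementation), ‘score’ stands in for training the DII on a feature subset and returning its final value, and a toy score built from the ground-truth weights keeps the example fast. With only 5 features and n_best=10, as in cell [9], the beam retains every tuple and the search is effectively exhaustive; indeed the log of cell [9] evaluates all 10 pairs and all 10 triples.

```python
import numpy as np


def forward_greedy(score, n_features, n_features_max, n_best):
    """Hypothetical forward greedy (beam) search; score(subset) plays the role of
    'train the DII on this feature subset and return its value' (lower is better)."""
    beam = [()]  # start from the empty feature set
    best_per_size = []
    for _ in range(n_features_max):
        # extend every retained tuple by each feature it does not contain yet
        candidates = {
            tuple(sorted(s + (f,))) for s in beam for f in range(n_features) if f not in s
        }
        ranked = sorted(candidates, key=score)  # lowest score first
        best_per_size.append(list(ranked[0]))
        beam = ranked[:n_best]  # keep the n_best tuples for the next size
    return best_per_size


w_true = np.array([10, 3, 1, 30, 7.3])


def toy_score(subset):
    """Toy stand-in for the DII: share of ground-truth weight left out of the subset."""
    return 1.0 - w_true[list(subset)].sum() / w_true.sum()


print(forward_greedy(toy_score, n_features=5, n_features_max=5, n_best=1))
# [[3], [0, 3], [0, 3, 4], [0, 1, 3, 4], [0, 1, 2, 3, 4]]
```

With a real DII in place of ‘toy_score’, every call costs one training run, which is why the number of candidates per step matters.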

In the following example, we use the same 5-dimensional Gaussian as in the previous tests, with ground-truth weights [10, 3, 1, 30, 7.3]. Hence, the optimal n-tuples (with features indexed from 0) for each value of n should be:

  • n=1 -> [3]

  • n=2 -> [0, 3]

  • n=3 -> [0, 3, 4]

  • n=4 -> [0, 1, 3, 4]

  • n=5 -> [0, 1, 2, 3, 4]
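These expected tuples follow from the ground-truth weights: the optimal n-tuple collects the n features with the largest weights, since these are the coordinates that stretch space B the most. A short NumPy check:

```python
import numpy as np

w_true = np.array([10, 3, 1, 30, 7.3])
order = np.argsort(-w_true)  # feature indices by decreasing weight: [3, 0, 4, 1, 2]
expected = [sorted(order[:n].tolist()) for n in range(1, len(w_true) + 1)]
print(expected)  # [[3], [0, 3], [0, 3, 4], [0, 1, 3, 4], [0, 1, 2, 3, 4]]
```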

The “forward” greedy search is the recommended approach when one is looking to select a few relevant features (compared to the total number of features).

[8]:
# initialize DII object
dii = DiffImbalance(
    data_A,
    data_B,
    periods_A=None,
    periods_B=None,
    seed=0,
    num_epochs=50,
    batches_per_epoch=5,
    l1_strength=0.0,
    point_adapt_lambda=True,
    k_init=5,
    k_final=5,
    lambda_factor=1e-1,
    params_init=None,
    optimizer_name="sgd",
    learning_rate=1e-1,
    learning_rate_decay="cos",
    num_points_rows=None,
)
[9]:
# forward greedy search
(
    feature_sets_fw,
    diis_fw,
    errors_fw,
    weights_fw,
) = dii.forward_greedy_feature_selection(
    n_features_max=5,
    n_best=10,
    compute_error=True,
    ratio_rows_columns=1,
    seed=0,
    discard_close_ind=0,
)

diis_fw_array = np.array(diis_fw)
Feature set = [0], DII = 0.8999369144439697

Feature set = [1], DII = 0.9869815111160278

Feature set = [2], DII = 0.9980540871620178

Feature set = [3], DII = 0.24966953694820404

Feature set = [4], DII = 0.9708154201507568

------------------------------------------------
Best single feature: [3]
        DII: 0.24966953694820404
        Optimal weights: [0.  0.  0.  0.1 0. ]
Selected 5 best candidates for next iteration
------------------------------------------------
Feature set = [0, 3], DII = 0.09266918152570724

Feature set = [1, 3], DII = 0.23958951234817505

Feature set = [2, 3], DII = 0.246477410197258

Feature set = [3, 4], DII = 0.1784021258354187

Feature set = [0, 1], DII = 0.8793059587478638

Feature set = [0, 2], DII = 0.8912973999977112

Feature set = [0, 4], DII = 0.8359236121177673

Feature set = [1, 4], DII = 0.9284022450447083

Feature set = [2, 4], DII = 0.965333878993988

Feature set = [1, 2], DII = 0.9652037620544434


Optimal weights for feature set [0, 3]: [0.04 0.   0.   0.13 0.  ]

------------------------------------------------
Best 2-tuple: [0, 3]
        DII: 0.09266918152570724
        Optimal weights: [0.04 0.   0.   0.13 0.  ]
Selected 10 best candidates for next iteration
------------------------------------------------
Feature set = [0, 1, 3], DII = 0.07907145470380783

Feature set = [0, 2, 3], DII = 0.09184076637029648

Feature set = [0, 3, 4], DII = 0.025932027027010918

Feature set = [1, 3, 4], DII = 0.1665874570608139

Feature set = [2, 3, 4], DII = 0.18107758462429047

Feature set = [1, 2, 3], DII = 0.24252423644065857

Feature set = [0, 1, 4], DII = 0.8256341218948364

Feature set = [0, 2, 4], DII = 0.8317562937736511

Feature set = [0, 1, 2], DII = 0.8771141767501831

Feature set = [1, 2, 4], DII = 0.9303593635559082


Optimal weights for feature set [0, 3, 4]: [0.05 0.   0.   0.16 0.04]

------------------------------------------------
Best 3-tuple: [0, 3, 4]
        DII: 0.025932027027010918
        Optimal weights: [0.05 0.   0.   0.16 0.04]
Selected 10 best candidates for next iteration
------------------------------------------------
Feature set = [0, 1, 3, 4], DII = 0.02279951423406601

Feature set = [0, 2, 3, 4], DII = 0.02594083920121193

Feature set = [0, 1, 2, 3], DII = 0.07834427058696747

Feature set = [1, 2, 3, 4], DII = 0.1662881225347519

Feature set = [0, 1, 2, 4], DII = 0.8106766939163208


Optimal weights for feature set [0, 1, 3, 4]: [0.06 0.02 0.   0.18 0.05]

------------------------------------------------
Best 4-tuple: [0, 1, 3, 4]
        DII: 0.02279951423406601
        Optimal weights: [0.06 0.02 0.   0.18 0.05]
Selected 5 best candidates for next iteration
------------------------------------------------
Feature set = [0, 1, 2, 3, 4], DII = 0.02285277657210827


Optimal weights for feature set [0, 1, 2, 3, 4]: [0.07 0.02 0.   0.21 0.05]

------------------------------------------------
Best 5-tuple: [0, 1, 2, 3, 4]
        DII: 0.02285277657210827
        Optimal weights: [0.07 0.02 0.   0.21 0.05]
Selected 1 best candidates for next iteration
------------------------------------------------
[10]:
print(f"\nSelected feature subsets:\n {feature_sets_fw}\n")
print("Forward DIIs:", diis_fw)
print("Forward Errors:", errors_fw)
print("Forward Weights:", weights_fw)

print("\nOriginal weights:", weights_ground_truth)
print(
    f"\nLearnt weights for the full feature space: {np.abs(weights_fw[-1]) / (np.max(np.abs(weights_fw[-1])) / np.max(weights_ground_truth))}"
)
print(
    f"\nAll weights along the feature selection:\n {np.abs(weights_fw) / (np.max(np.abs(weights_fw)) / np.max(weights_ground_truth))}"
)

Selected feature subsets:
 [[3], [0, 3], [0, 3, 4], [0, 1, 3, 4], [0, 1, 2, 3, 4]]

Forward DIIs: [0.24966953694820404, 0.09266918152570724, 0.025932027027010918, 0.02279951423406601, 0.02285277657210827]
Forward Errors: [0.007513741031289101, 0.005023834761232138, 0.000543437316082418, 0.00034482675255276263, 0.0003450679068919271]
Forward Weights: [array([0. , 0. , 0. , 0.1, 0. ]), array([0.04, 0.  , 0.  , 0.13, 0.  ], dtype=float32), array([0.05, 0.  , 0.  , 0.16, 0.04], dtype=float32), array([0.06, 0.02, 0.  , 0.18, 0.05], dtype=float32), array([0.07, 0.02, 0.  , 0.21, 0.05], dtype=float32)]

Original weights: [10.   3.   1.  30.   7.3]

Learnt weights for the full feature space: [10.17  2.78  0.6  30.    7.39]

All weights along the feature selection:
 [[ 0.    0.    0.   14.6   0.  ]
 [ 6.27  0.    0.   19.68  0.  ]
 [ 7.88  0.    0.   23.32  5.82]
 [ 9.09  2.48  0.   26.84  6.63]
 [10.17  2.78  0.6  30.    7.39]]

Greedy Feature Search: Backward Search Implementation

Here we implement the “backward” greedy search, which starts from the full feature set and successively removes features one by one to find the optimal n-tuple feature sets. As in the “forward” greedy search, the parameter n_best controls the greediness of the algorithm. If n_best is set to 1, the candidate n-tuples are obtained by removing one feature at a time from the optimal (n+1)-tuple. The weights of each candidate are optimized by training the DII, and the candidate with the lowest DII is selected for the next step. If n_best is larger than 1, the candidate n-tuples are obtained by removing one feature at a time from each of the n_best optimal tuples of size (n+1).
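The backward search can be sketched as a beam search in plain Python. In this hypothetical sketch (not DADApy's implementation), ‘score’ stands in for training the DII on a feature subset and returning its final value, and a toy score built from the ground-truth weights replaces the actual training:

```python
import numpy as np


def backward_greedy(score, n_features, n_features_min, n_best):
    """Hypothetical backward greedy (beam) search; score(subset) plays the role of
    'train the DII on this feature subset and return its value' (lower is better)."""
    beam = [tuple(range(n_features))]  # start from the full feature set
    best_per_size = [list(beam[0])]
    for _ in range(n_features - n_features_min):
        # drop, in turn, each feature of every retained tuple
        candidates = {tuple(f for f in s if f != dropped) for s in beam for dropped in s}
        ranked = sorted(candidates, key=score)  # lowest score first
        best_per_size.append(list(ranked[0]))
        beam = ranked[:n_best]  # keep the n_best tuples for the next size
    return best_per_size


w_true = np.array([10, 3, 1, 30, 7.3])


def toy_score(subset):
    """Toy stand-in for the DII: share of ground-truth weight left out of the subset."""
    return 1.0 - w_true[list(subset)].sum() / w_true.sum()


print(backward_greedy(toy_score, n_features=5, n_features_min=1, n_best=1))
# [[0, 1, 2, 3, 4], [0, 1, 3, 4], [0, 3, 4], [0, 3], [3]]
```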

In the following example, the optimal n-tuples for each value of n should be:

  • n=5 -> [0, 1, 2, 3, 4]

  • n=4 -> [0, 1, 3, 4]

  • n=3 -> [0, 3, 4]

  • n=2 -> [0, 3]

  • n=1 -> [3]

The “backward” greedy search is the recommended approach when one is looking to select a large number of relevant features (compared to the total number of features), or when one just wants to eliminate a few uninformative features.
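A rough way to see why the two directions suit different goals is to count how many DII trainings each one needs with n_best=1: the forward search stops early when only a few features are kept, the backward search when only a few are dropped. The counting functions below are hypothetical; they ignore that trainings on larger subsets may cost more and that n_best larger than 1 increases the number of candidates.

```python
def n_trainings_forward(n_features, n_features_max):
    """DII trainings in a forward search with n_best = 1 (hypothetical count)."""
    # at size n, each of the (D - n + 1) unused features is added to the best (n-1)-tuple
    return sum(n_features - n + 1 for n in range(1, n_features_max + 1))


def n_trainings_backward(n_features, n_features_min):
    """DII trainings in a backward search with n_best = 1, including the
    initial training on the full feature set (hypothetical count)."""
    # at size n, each of the (n + 1) features of the best (n+1)-tuple is dropped in turn
    return 1 + sum(n + 1 for n in range(n_features_min, n_features))


D = 100
print(n_trainings_forward(D, 5), n_trainings_backward(D, 5))  # 490 5036: keep 5 of 100
print(n_trainings_forward(D, 95), n_trainings_backward(D, 95))  # 5035 491: keep 95 of 100
```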

[11]:
# initialize DII object and train over full feature space
dii = DiffImbalance(
    data_A,
    data_B,
    periods_A=None,
    periods_B=None,
    seed=0,
    num_epochs=50,
    batches_per_epoch=5,
    l1_strength=0.0,
    point_adapt_lambda=True,
    k_init=5,
    k_final=5,
    lambda_factor=1e-1,
    params_init=None,
    optimizer_name="sgd",
    learning_rate=1e-1,
    learning_rate_decay="cos",
    num_points_rows=None,
)
weights, imbs = dii.train()
[12]:
# backward greedy selection
(
    feature_sets_bw,
    diis_bw,
    errors_bw,
    weights_bw,
) = dii.backward_greedy_feature_selection(
    n_features_min=1,
    n_best=10,
    compute_error=True,
    ratio_rows_columns=1,
    seed=0,
    discard_close_ind=0,
)

diis_bw_array = np.array(diis_bw)
------------------------------------------------
All features: [[0, 1, 2, 3, 4]]
        DII: 0.022820621728897095
        Optimal weights: [0.07 0.02 0.   0.21 0.05]
------------------------------------------------
Feature set = [1, 2, 3, 4], DII = 0.16653390228748322

Feature set = [0, 2, 3, 4], DII = 0.02594083920121193

Feature set = [0, 1, 3, 4], DII = 0.022799985483288765

Feature set = [0, 1, 2, 4], DII = 0.8058949112892151

Feature set = [0, 1, 2, 3], DII = 0.0787024274468422

------------------------------------------------
Best 4-tuple: [0, 1, 3, 4]
        DII: 0.022799985483288765
        Optimal weights: [0.06 0.02 0.   0.18 0.05]
Selected 5 best candidates for next iteration
------------------------------------------------
Feature set = [1, 3, 4], DII = 0.1665170043706894

Feature set = [0, 3, 4], DII = 0.0259428508579731

Feature set = [0, 1, 4], DII = 0.8102166056632996

Feature set = [0, 1, 3], DII = 0.07857366651296616

Feature set = [2, 3, 4], DII = 0.18107758462429047

Feature set = [0, 2, 4], DII = 0.8303636908531189

Feature set = [0, 2, 3], DII = 0.09187900274991989

Feature set = [1, 2, 3], DII = 0.23959513008594513

Feature set = [0, 1, 2], DII = 0.8771141767501831

Feature set = [1, 2, 4], DII = 0.9303593635559082

------------------------------------------------
Best 3-tuple: [0, 3, 4]
        DII: 0.0259428508579731
        Optimal weights: [0.05 0.   0.   0.16 0.04]
Selected 10 best candidates for next iteration
------------------------------------------------
Feature set = [3, 4], DII = 0.17846915125846863

Feature set = [0, 4], DII = 0.8343183994293213

Feature set = [0, 3], DII = 0.0921807810664177

Feature set = [1, 3], DII = 0.2390105277299881

Feature set = [0, 1], DII = 0.8793059587478638

Feature set = [2, 3], DII = 0.24573972821235657

Feature set = [0, 2], DII = 0.898990273475647

Feature set = [1, 4], DII = 0.9284022450447083

Feature set = [2, 4], DII = 0.965333878993988

Feature set = [1, 2], DII = 0.9652037620544434

------------------------------------------------
Best 2-tuple: [0, 3]
        DII: 0.0921807810664177
        Optimal weights: [0.04 0.   0.   0.13 0.  ]
Selected 10 best candidates for next iteration
------------------------------------------------
Feature set = [3], DII = 0.24966953694820404

Feature set = [0], DII = 0.8999369144439697

Feature set = [4], DII = 0.9708154201507568

Feature set = [1], DII = 0.9869815111160278

Feature set = [2], DII = 0.9980540871620178

------------------------------------------------
Best 1-tuple: [3]
        DII: 0.24966953694820404
        Optimal weights: [0.  0.  0.  0.1 0. ]
Selected 5 best candidates for next iteration
------------------------------------------------
[13]:
print(f"\nSelected feature subsets:\n {feature_sets_bw}\n")

print("Backward DIIs:", diis_bw)
print("Backward Errors:", errors_bw)
print("Backward Weights:", weights_bw)

print("\nOriginal weights:", weights_ground_truth)
print(
    f"\nLearnt weights for the full feature space: {np.abs(weights_bw[0]) / (np.max(np.abs(weights_bw[0])) / np.max(weights_ground_truth))}"
)
print(
    f"\nAll weights along the feature selection:\n {np.abs(weights_bw) / (np.max(np.abs(weights_bw)) / np.max(weights_ground_truth))}"
)

Selected feature subsets:
 [[0, 1, 2, 3, 4], [0, 1, 3, 4], [0, 3, 4], [0, 3], [3]]

Backward DIIs: [0.022820621728897095, 0.022799985, 0.02594285, 0.09218078, 0.24966954]
Backward Errors: [0.0003439213614910841, 0.00034475446, 0.0005474837, 0.004991478, 0.007513741]
Backward Weights: [Array([0.07, 0.02, 0.  , 0.21, 0.05], dtype=float32), Array([0.06, 0.02, 0.  , 0.18, 0.05], dtype=float32), Array([0.05, 0.  , 0.  , 0.16, 0.04], dtype=float32), Array([0.04, 0.  , 0.  , 0.13, 0.  ], dtype=float32), Array([0. , 0. , 0. , 0.1, 0. ], dtype=float32)]

Original weights: [10.   3.   1.  30.   7.3]

Learnt weights for the full feature space: [10.17  2.78  0.6  30.    7.39]

All weights along the feature selection:
 [[10.17  2.78  0.6  30.    7.39]
 [ 9.09  2.48  0.   26.84  6.63]
 [ 7.88  0.    0.   23.32  5.82]
 [ 6.27  0.    0.   19.68  0.  ]
 [ 0.    0.    0.   14.6   0.  ]]