Source code for diff_imbalance

# Copyright 2021-2025 The DADApy Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""
The *diff_imbalance* module contains the *DiffImbalance* class, implemented with JAX.

The only method intended to be called by the user is 'train', which carries out the automatic optimization of the
Differentiable Information Imbalance (DII) as a function of the weights of the features in the first distance space.
The code can be run on GPU using the command
    jax.config.update('jax_platform_name', 'gpu') # set 'cpu' or 'gpu'
"""

import warnings
from functools import partial

import jax
import jax.numpy as jnp
import numpy as np
import optax
from flax.training import train_state
from tqdm.auto import tqdm

# OPTIMIZABLE DISTANCE FUNCTIONS
# (here new functions may be added for purposes beyond feature selection)
# ----------------------------------------------------------------------------------------------


# for feature selection
@partial(jax.jit, static_argnames="params_groups")
def _compute_dist2_matrix_scaling(
    params, batch_rows, batch_columns, periods=None, params_groups=None
):
    """Computes the (squared) Euclidean distance matrix between points in 'batch_rows' and points in 'batch_columns'.

    The features of the points are scaled by the weights in 'params', such that the distance between
    point i in batch_rows and point j in batch_columns is computed as
        dist2_matrix[i,j] = ((batch_rows[i,:] - batch_columns[j,:])**2).sum(axis=-1)

    Args:
        params (jnp.array(float)): array of shape (n_params,). If parmas_groups is None, n_params == n_features.
        batch_rows (jnp.array(float)): matrix of shape (n_points_rows, n_features).
        batch_columns (jnp.array(float)): matrix of shape (n_points_columns, n_features).
        periods (jnp.array(float)): array of shape (n_features,) for computing distances between periodic
            features by applying PBCs. If only a subset of features is periodic, the entries of 'periods' for the
            nonperiodic features should be set to zero. Default is None, for which PBCs are not applied.
        params_groups (jnp.array(int)): array of shape (n_params,) containing at position i the number of features
            that share the same weight params[i], using the same order of the columns in batch_rows and batch_columns.
            If params_groups is None, no weight sharing is enforced.
    Returns:
        dist2_matrix (jnp.array(float)): array of shape (n_points_rows, n_features) containing the square Euclidean
            distances between all points in 'batch_rows' and all points in 'batch_columns'.
    """
    diffs = batch_rows[:, jnp.newaxis, :] - batch_columns[jnp.newaxis, :, :]
    if periods is not None:
        periodic_mask = periods > 0  # only shift periodic features
        periodic_shifts = (
            jnp.round(diffs / jnp.where(periodic_mask, periods, 1.0)) * periods
        )
        diffs -= jnp.where(periodic_mask, periodic_shifts, 0.0)

    params_repeated = +params
    if params_groups is not None:
        params_repeated = jnp.repeat(params, np.array(params_groups))

    diffs *= params_repeated[jnp.newaxis, jnp.newaxis, :]
    dist2_matrix = jnp.sum(diffs * diffs, axis=-1)
    return dist2_matrix


# CLASS TO OPTIMIZE THE DIFFERENTIAL INFORMATION IMBALANCE
# ----------------------------------------------------------------------------------------------


[docs] class DiffImbalance: """Carries out the optimization of the DII(A(w)->B) with respect to the weights in the first distance space. The class 'DiffImbalance' supports two schemes for setting the smoothing parameter lambda, which tunes the size of neighborhoods in space A. In both schemes lambda can be epoch-dependent, i.e. decreased during the training according to a cosine decay between 'init' and 'final' values. The schemes are: 1. Adaptive: lambda is equal for all the points and is set to a fraction (given by lambda_factor, default is 1/10) of the *average* square distance of k-th neighbors. Example: point_adapt_lambda: False k_init: 10 k_final: 1 lambda_factor=1/10 2. Point-adaptive: lambda is different for each point. For point i, it is set to a fraction of the square distance between i and its k-th neighbor. Example: point_adapt_lambda: True k_init: 10 k_final: 1 lambda_factor=1/10 As a rule of thumb, we suggest to set k_init and k_final to ~5% of the points in the data set, if mini- batches are not employed, or to ~5% of the points within each mini-batch, if they are employed. Attributes: data_A (np.array(float), jnp.array(float)): feature space A, matrix of shape (n_points, n_features_A). data_B (np.array(float), jnp.array(float)): feature space B, matrix of shape (n_points, n_features_B). distances_B (np.array(float), jnp.array(float)): distance matrix in space B, of shape (n_points, n_points). Default is None, for which distances are computed from the features in data_B. periods_A (np.array(float), jnp.array(float)): array of shape (n_features_A,), periods of features A. Default is None, which means that features A are treated as nonperiodic. If not all features are periodic, the entries of the nonperiodic ones should be set to 0. periods_B (np.array(float), jnp.array(float)): array of shape (n_features_B,), periods of features B. Default is None, which means that features B are treated as nonperiodic. 
If not all features are periodic, the entries of the nonperiodic ones should be set to 0. num_epochs (int): number of training epochs. Default is 200. batches_per_epoch (int): number of minibatches; must be a divisor of n_points. Each weight update is carried out by computing the DII gradient over n_points / batches_per_epoch points. Default is 1, which means that the gradient is computed over all the available points (batch GD). seed (int): seed of JAX random generator, default is 0. Different seeds determine different mini-batch partitions. l1_strength (float): strength of the L1 regularization (LASSO) term. Default is 0. point_adapt_lambda (bool): whether to use a global smoothing parameter lambda for the c_ij coefficients in the DII (if False), or a different parameter for each point (if True). Default is True. k_init (int): initial rank of neighbors used to set lambda. Ranks are defined starting from 1. If batches_per_epoch > 1, neighbors are recomputed within each mini-batch. Default is 1. k_final (int): final rank of neighbors used to set lambda. If batches_per_epoch > 1, neighbors are recomputed within each mini-batch. Default is 1. lambda_factor (float): factor defining the scale of lambda. Default is 0.1. params_init (np.array(float), jnp.array(float)): array of shape (n_params,) containing the initial values of the scaling weights to be optimized. If params_groups is set to None, each feature is scaled by an independent optimization parameter, so n_params == n_features_A. If params_init is None, the initial scaling parameters are set to [0.1, 0.1, ..., 0.1]. params_groups (np.array(int), jnp.array(int)): array of shape (n_params,) containing at position i the number of features that share the same weight in params_init[i], using the same order of the columns in data_A. 
If params_groups = [3, 2, 4], for example, the first 3 features in space A will share a common weight, the following 2 features will share a second common weight, and the last 4 features will also be scaled by a common optimization parameter. params_groups should satisfy the constraint sum(params_groups) == n_features_A. If params_groups is None, no weight sharing is enforced. optimizer_name (str): name of the optimizer, calling the Optax library. Possible choices are 'sgd' (default), 'adam' and 'adamw'. See https://optax.readthedocs.io/en/latest/api/optimizers.html for additional details. learning_rate (float): value of the learning rate. Default is 1e-2. learning_rate_decay (str): schedule to damp the learning rate to zero starting from the value provided with the attribute learning_rate. The available schedules are: cosine decay ("cos"), exponential decay ("exp"; the initial learning rate is halved every 10 steps), or constant learning rate (None). Default is None (constant learning rate). num_points_rows (int): number of points sampled from the rows of rank and distance matrices. In case of large datasets, choosing num_points_rows < n_points can significantly speed up the training. The default is None, for which num_points_rows == n_points. """ def __init__( self, data_A, data_B, distances_B=None, periods_A=None, periods_B=None, num_epochs=200, batches_per_epoch=1, seed=0, l1_strength=0.0, point_adapt_lambda=True, k_init=1, k_final=1, lambda_factor=0.1, params_init=None, params_groups=None, optimizer_name="sgd", learning_rate=1e-2, learning_rate_decay=None, num_points_rows=None, ): """Initialise the DiffImbalance class.""" self.nfeatures_A = data_A.shape[1] if distances_B is None: # space B provided as features self.nfeatures_B = data_B.shape[1] assert data_A.shape[0] == data_B.shape[0], ( f"Space A has {data_A.shape[0]} samples " + f"while space B has {data_B.shape[0]} samples." 
) else: # space B provided as distances if data_B is not None: warnings.warn( f"Argument distances_B is not None; data_B will be ignored." ) # self.distances_B = jnp.array(distances_B) assert ( distances_B.shape[0] == distances_B.shape[1] ), f"Argument distances_B should be a square matrix, while it has shape {distances_B.shape}" assert data_A.shape[0] == distances_B.shape[0], ( f"Number of points in data_A ({data_A.shape[0]}) and distances_B ({distances_B.shape[0]})" + f" do not match." ) self.nparams = self.nfeatures_A if params_groups is None else len(params_groups) # initialize jax random generator self.key = jax.random.PRNGKey(seed) self.key, subkey = jax.random.split(self.key, num=2) # initialize spaces A and B self.data_A = data_A self.data_B = data_B self.distances_B = distances_B # option to speed up DII calculation by decimating rows (rectangular distance matrices) if num_points_rows is not None: assert num_points_rows < self.data_A.shape[0], ( f"num_points_rows ({num_points_rows}) should be smaller than the number " + f"of points in the data set ({self.data_A.shape[0]}) or set to None." 
) # decimate rows but not columns, and keep same indices in upper left square matrix indices_rows = jax.random.choice( subkey, jnp.arange(data_A.shape[0]), shape=(num_points_rows,), replace=False, ) indices_columns = jnp.delete(jnp.arange(data_A.shape[0]), indices_rows) indices_columns = jnp.concatenate((indices_rows, indices_columns)) else: indices_rows = jnp.arange(data_A.shape[0]) indices_columns = +indices_rows self.data_A_rows = data_A[indices_rows] self.data_A_columns = data_A[indices_columns] if self.distances_B is None: # space B provided as features self.data_B_rows = data_B[indices_rows] self.data_B_columns = data_B[indices_columns] else: # space B provided as distances self.distances_B_subsampled = self.distances_B[indices_rows][ :, indices_columns ] self.nrows = self.data_A_rows.shape[0] self.ncolumns = self.data_A_columns.shape[0] self.periods_A = ( jnp.ones(self.nfeatures_A) * jnp.array(periods_A) if periods_A is not None else periods_A ) if self.distances_B is None: # space B provided as features self.periods_B = ( jnp.ones(self.nfeatures_B) * jnp.array(periods_B) if periods_B is not None else periods_B ) self.num_epochs = num_epochs self.batches_per_epoch = batches_per_epoch self.l1_strength = l1_strength self.point_adapt_lambda = point_adapt_lambda self.k_init = k_init self.k_final = k_final self.lambda_factor = lambda_factor if params_init is not None: self.params_init = jnp.array(params_init, dtype=float) else: self.params_init = 0.1 * jnp.ones(self.nparams) self.params_groups = params_groups if params_groups is not None: self.params_groups = tuple(params_groups) self.params_final = None self.params_training = None self.imb_final = None self.imbs_training = None self.error_final = None self.optimizer_name = optimizer_name self.learning_rate = learning_rate self.learning_rate_decay = learning_rate_decay self.num_points_rows = num_points_rows self.mask = None self.state = None self._distance_A = _compute_dist2_matrix_scaling # TODO: assign other 
functions if other distances d_A are chosen # generic checks and warnings assert self.nrows >= batches_per_epoch, ( f"Cannot extract {batches_per_epoch} minibatches " + f"from {self.nrows} samples." ) assert ( self.k_init is not None and self.k_final is not None ), f"Provide values of 'k_init' and 'k_final' to compute lambda adaptively." if self.k_init > 100: warnings.warn( f"For efficiency reasons the maximum value for 'k_init' is 100, while you set it to {self.k_init}.\n" + f"The run will continue with 'k_init = 100'" ) self.k_init = 100 assert ( self.k_init >= self.k_final and self.k_final > 0 ), f"'k_init' and 'k_final' must satisfy: k_init >= k_final >= 1." assert isinstance(k_init, int) and isinstance( k_final, int ), f"'k_init' and 'k_final' must be positive integers." if self.params_groups is not None: n_vars = np.sum(self.params_groups) assert n_vars == self.nfeatures_A, ( f"Number of elements in 'params_groups' ({n_vars}) does not match the number " + f"of features in space A ({self.nfeatures_A})." ) assert self.params_init.shape[0] == self.nparams, ( f"With your inputs ('data_A' and 'params_groups'), 'params_init' should contain {self.nparams} weights, " + f"while it contains {self.params_init.shape[0]} weights." ) # create jitted functions self._create_functions() # pre-compute ranks B to speed up training if self.distances_B is None: # input B provided as features self.ranks_B = self._compute_rank_matrix( batch_rows=self.data_B_rows, batch_columns=self.data_B_columns, periods=self.periods_B, ) else: # input B provided as features self.ranks_B = self.distances_B_subsampled.argsort(axis=1).argsort(axis=1) # set method to compute lambda (adaptive or point-adaptive) if point_adapt_lambda: self.lambda_method = self._compute_point_adapt_lambdas else: self.lambda_method = self._compute_adapt_lambda def _create_functions(self): def _compute_rank_matrix(batch_rows, batch_columns, periods): """Computes the matrix of ranks for the target space B. 
Args: batch_rows (jnp.array(float)): matrix of shape (n_points_rows, n_features_B), containing points labelling the rank matrix rows. batch_columns (jnp.array(float)): matrix of shape (n_points_columns, n_features_B), containing points labelling the rank matrix columns. periods (jnp.array(float)): array of shape (n_features_B,), containing the periods of features in space B. PBCs are not applied for feature i if periods[i] == 0, or if periods == None. Returns: rank_matrix (jnp.array(float)): matrix of shape (n_points_rows, n_points_columns), defining the target distance ranks in space B. Ranks start from 1, and are 0 only for a point with respect to itself (when a point appears both in batch_rows and batch_columns). """ diffs = batch_rows[:, jnp.newaxis, :] - batch_columns[jnp.newaxis, :, :] if periods is not None: periodic_mask = periods > 0 # only shift periodic features periodic_shifts = ( jnp.round(diffs / jnp.where(periodic_mask, periods, 1.0)) * periods ) diffs -= jnp.where(periodic_mask, periodic_shifts, 0.0) dist2_matrix = jnp.sum(diffs * diffs, axis=-1) rank_matrix = dist2_matrix.argsort(axis=1).argsort(axis=1) return rank_matrix def _cosine_decay_func(start_value, final_value, step): """Implements a cosine decay during the training. The arguments start_value and final_value can be values of the learning rate or of the neighbor order used to compute lambda in the adaptive and point-adaptive schemes. Args: start_value (float): initial value. final_value (float): final value. step (int): number of current gradient descent step. Returns: cosine (float): value of the cosine interpolating start_value at step 0 and final_value at the last training step. 
""" x = jnp.pi / (self.num_epochs * self.batches_per_epoch) * step cosine = (start_value - final_value) * (jnp.cos(x) + 1) / 2 + final_value return cosine def _compute_point_adapt_lambdas(dist2_matrix, step=None, k=None): """Computes lambda parameters with the point-adaptive scheme, according to the current value of k. Args: dist2_matrix (jnp.array(float)): matrix of shape (n_points_rows, n_points_columns), containing the squared distances in space A at the current training step. step (int): number of current gradient descent step, from which the current value of 'k' to compute lambda adaptively is obtained. k (int): neighbor order to set lambda adaptively (alternative to step). Returns: current_lambdas (jnp.array(float)): array of shape (n_points_rows,), containing a value of lambda computed adaptively for each point, as the fraction ('lambda_factor', default: 1/10) of the squared distance of the neighbor of order k. """ if step is not None: current_k = jnp.rint( _cosine_decay_func( start_value=self.k_init, final_value=self.k_final, step=step ) ).astype(int) elif k is not None: current_k = k # take the k_max_allowed smallest distances with negative sign k_max_allowed = 100 if dist2_matrix.shape[1] < k_max_allowed: k_max_allowed = dist2_matrix.shape[1] smallest_dist2, _ = jax.lax.top_k(-dist2_matrix, k_max_allowed) current_lambdas = -smallest_dist2[:, current_k - 1] * self.lambda_factor # DON'T DELETE: Adaptive scheme of cython code # diffs_dists_2nd_1st = -smallest_dist2[:, 1] + smallest_dist2[:, 0] # current_lambdas = 0.5*(diffs_dists_2nd_1st.min() + diffs_dists_2nd_1st.mean()) return current_lambdas def _compute_adapt_lambda(dist2_matrix, step=None, k=None): """Computes smoothing parameter lambda with adaptive scheme. Args: dist2_matrix (jnp.array(float)): matrix of shape (n_points_rows, n_points_columns), containing the squared distances in space A at the current training step. 
step (int): number of current gradient descent step, from which the current value of 'k' to compute lambda adaptively is obtained. k (int): neighbor order to set lambda adaptively (alternative to step). Returns: current_lambda (jnp.array(float)): array of shape (n_points_rows,), containing the same value of lambda for all points, computed as the fraction (1/10) of the *average* squared distance of the neighbor of order k. """ current_lambda = _compute_point_adapt_lambdas( dist2_matrix, step, k ).mean() * jnp.ones(dist2_matrix.shape[0]) return current_lambda def _compute_training_diff_imbalance( params, batch_A_rows, batch_A_columns, batch_B_ranks, step ): """Computes the Differentiable Information Imbalance (DII) at the current step of the training. Args: params (jnp.array(float)): array of shape (n_features_A,) of the current feature weights. batch_A_rows (jnp.array(float)): matrix of shape (n_points_rows, n_features_A), containing points labelling the distance matrix rows. batch_A_columns (jnp.array(float)): matrix of shape (n_points_columns, n_features_A), containing points labelling the distance matrix columns. batch_B_ranks (jnp.array(float)): matrix of shape (n_points_rows, n_points_columns), containing the pre-computed target ranks in space B. step (int): number of current gradient descent step. Returns: diff_imbalance (float): current value of the DII. """ dist2_matrix_A = self._distance_A( # compute distance matrix A params=params, batch_rows=batch_A_rows, batch_columns=batch_A_columns, periods=self.periods_A, params_groups=self.params_groups, ) N = dist2_matrix_A.shape[0] max_rank = dist2_matrix_A.shape[1] - 1 # set distance of a point with itself to large number dist2_matrix_A = dist2_matrix_A.at[jnp.arange(N), jnp.arange(N)].set( jnp.max(dist2_matrix_A) + 1e6 ) lambdas = self.lambda_method( dist2_matrix=dist2_matrix_A, step=step ) # compute lambda values c_matrix = ( jax.nn.softmax( # N.B. 
diagonal elements already numerically zero -dist2_matrix_A / lambdas[ :, jnp.newaxis ], # jax.lax.stop_gradient(lambdas[:, jnp.newaxis]) axis=1, ) ) # DON'T DELETE: Alternative definition of c_ij coefficients (sigmoid instead of softmax) # c_matrix = jax.nn.sigmoid( # (lambdas[:, jnp.newaxis] - dist2_matrix_A)/(self.lambda_factor * lambdas[:, jnp.newaxis]) # ) # compute DII conditional_ranks = jnp.sum(batch_B_ranks * c_matrix, axis=1) diff_imbalance = 2.0 / (max_rank + 1) * jnp.sum(conditional_ranks) / N # DON'T DELETE: analytical gradient of the DII (without differentiating lambda) # diffs_squared = ((batch_A_rows[:,jnp.newaxis,:] - batch_A_columns[jnp.newaxis,:,:]) # *(batch_A_rows[:,jnp.newaxis,:] - batch_A_columns[jnp.newaxis,:,:])) # shape (nrows, ncols, D) # second_term = (c_matrix[:,:,jnp.newaxis] * diffs_squared).sum(axis=1, keepdims=True) # grad_imbalance = ( # 4.0 * params / (N * (self.max_rank + 1)) # * jnp.sum((batch_B_ranks * c_matrix)[:,:,jnp.newaxis] / lambdas[:,jnp.newaxis,jnp.newaxis] # * (-diffs_squared + second_term), axis=(0,1)) # ) return diff_imbalance def _compute_final_diff_imbalance_and_error( params, batch_A_rows, batch_A_columns, batch_B_ranks, k ): """Computes the Differentiable Information Imbalance (DII) and its error. Args: params (jnp.array(float)): array of shape (n_features_A,) of the current feature weights. batch_A_rows (jnp.array(float)): matrix of shape (n_points_rows, n_features_A), containing points labelling the distance matrix rows. batch_A_columns (jnp.array(float)): matrix of shape (n_points_columns, n_features_A), containing points labelling the distance matrix columns. batch_B_ranks (jnp.array(float)): matrix of shape (n_points_rows, n_points_columns), containing the pre-computed target ranks in space B. k (int): neighbor order to set lambda adaptively. Returns: diff_imbalance (float): value of the DII. error_imbalance (float): error associated to the DII. 
""" dist2_matrix_A = self._distance_A( # compute distance matrix A params=params, batch_rows=batch_A_rows, batch_columns=batch_A_columns, periods=self.periods_A, params_groups=self.params_groups, ) N = dist2_matrix_A.shape[0] max_rank = dist2_matrix_A.shape[1] lambdas = self.lambda_method( dist2_matrix=dist2_matrix_A, k=k ) # compute lambda values c_matrix = jax.nn.softmax( -dist2_matrix_A / lambdas[ :, jnp.newaxis ], # jax.lax.stop_gradient(lambdas[:,jnp.newaxis]) axis=1, ) # DON'T DELETE: compute standard Information Imbalance # batch_A_ranks = dist2_matrix_A.argsort(axis=1).argsort(axis=1) + 1 # mask_A = (batch_A_ranks <= k) # conditional_ranks = jnp.where(mask_A, 1.0, 0.0) * batch_B_ranks # conditional_ranks = conditional_ranks.sum(axis=-1) / k # values_average = 2.0 / (max_rank + 1) * conditional_ranks # compute DII and error values_average = ( 2.0 / (max_rank + 1) * jnp.sum(batch_B_ranks * c_matrix, axis=1) ) diff_imbalance = jnp.mean(values_average) error_imbalance = jnp.std(values_average, ddof=1) / jnp.sqrt(N) return diff_imbalance, error_imbalance def _compute_final_diff_imbalance( params, batch_A_rows, batch_A_columns, batch_B_ranks, k ): """Computes the Differentiable Information Imbalance (DII) without providing the error. Args: params (jnp.array(float)): array of shape (n_features_A,) of the current feature weights. batch_A_rows (jnp.array(float)): matrix of shape (n_points_rows, n_features_A), containing the points labelling the distance matrix rows. batch_A_columns (jnp.array(float)): matrix of shape (n_points_columns, n_features_A), containing the points labelling the distance matrix columns. batch_B_ranks (jnp.array(float)): matrix of shape (n_points_rows, n_points_columns), containing the pre-computed target ranks in space B. k (int): neighbor order to set lambda adaptively. Returns: diff_imbalance (float): value of the DII. 
""" dist2_matrix_A = self._distance_A( # compute distance matrix A params=params, batch_rows=batch_A_rows, batch_columns=batch_A_columns, periods=self.periods_A, params_groups=self.params_groups, ) N = dist2_matrix_A.shape[0] max_rank = dist2_matrix_A.shape[1] - 1 # set distance of a point with itself to large number dist2_matrix_A = dist2_matrix_A.at[jnp.arange(N), jnp.arange(N)].set( jnp.max(dist2_matrix_A) + 1e6 ) # apply mask to column indices around the row index if self.mask is not None: dist2_matrix_A = dist2_matrix_A[self.mask].reshape( (dist2_matrix_A.shape[0], -1) ) max_rank = dist2_matrix_A.shape[1] lambdas = self.lambda_method( dist2_matrix=dist2_matrix_A, k=k ) # compute lambda values c_matrix = jax.nn.softmax( # N.B. diagonal elements already numerically zero if mask is None -dist2_matrix_A / lambdas[:, jnp.newaxis], axis=1, ) # compute DII conditional_ranks = jnp.sum(batch_B_ranks * c_matrix, axis=1) diff_imbalance = 2.0 / (max_rank + 1) * jnp.sum(conditional_ranks) / N return diff_imbalance def _train_step(state, batch_A_rows, batch_A_columns, batch_B_ranks): """Performs a single gradient descent step in the optimization of the DII. Args: state (flax.training.train_state.TrainState object): current training state. batch_A_rows (jnp.array(float)): matrix of shape (n_points_rows, n_features_A), containing the points labelling the distance matrix rows. batch_A_columns (jnp.array(float)): matrix of shape (n_points_columns, n_features_A), containing the points labelling the distance matrix columns. batch_B_ranks (jnp.array(float)): matrix of shape (n_points_rows, n_points_columns), containing the pre-computed target ranks in space B. Returns: state_new (flax.training.train_state.TrainState object): new training state after optimizer step imb (flat): new value of the DII after optimizer step. 
""" loss_fn = lambda params: _compute_training_diff_imbalance( params=params, batch_A_rows=batch_A_rows, batch_A_columns=batch_A_columns, batch_B_ranks=batch_B_ranks, step=state.step, ) # Get loss and gradient imb, grads = jax.value_and_grad(loss_fn)(state.params) # Update parameters state = state.apply_gradients(grads=grads) norm_init = jnp.sqrt((self.params_init**2).sum()) norm_now = jnp.sqrt((state.params**2).sum()) # Scale weight vector to original norm state = state.replace(params=norm_init / norm_now * state.params) # Apply L1 penalty if self.l1_strength != 0: current_lr = self.lr_schedule(state.step) # (GD clipping, B. Carpenter et al, 2008) state = state.replace( params=jnp.where(state.params > 0, 1.0, 0.0) * jnp.maximum(0, state.params - current_lr * self.l1_strength) + jnp.where(state.params < 0, 1.0, 0.0) * jnp.minimum(0, state.params + current_lr * self.l1_strength) ) # DON'T DELETE: Soft version of GD clipping # candidate_params = ( # state.params # - jnp.sign(state.params) * current_lr * self.l1_strength # ) # state = state.replace( # params=state.params # * (1.0 - jnp.where(state.params * candidate_params < 0, 1.0, 0.0)) # ) # Scale weight vector to original norm norm_now = jnp.sqrt((state.params**2).sum()) state = state.replace(params=norm_init / norm_now * state.params) return state, imb # jit compilation of functions self._compute_rank_matrix = jax.jit(_compute_rank_matrix) self._cosine_decay_func = jax.jit(_cosine_decay_func) self._compute_point_adapt_lambdas = jax.jit(_compute_point_adapt_lambdas) self._compute_adapt_lambda = jax.jit(_compute_adapt_lambda) self._compute_training_diff_imbalance = jax.jit( _compute_training_diff_imbalance ) self._compute_final_diff_imbalance_and_error = jax.jit( _compute_final_diff_imbalance_and_error ) self._compute_final_diff_imbalance = jax.jit(_compute_final_diff_imbalance) self._train_step = jax.jit(_train_step) def _return_mask(self, npoints, discard_close_ind): """Returns a square boolean mask with False on 
the diagonals, and True elsewhere. Args: npoints (int): number of rows and columns of the mask matrix. discard_close_ind (int): defines the diagonals filled with False, with offset between -discard_close_ind (below the main diagonal) and +discard_close_ind (above the main diagonal). Returns: mask (jnp.array(float)): square boolean matrix of shape (npoints, npoints). """ mask = jnp.abs( jnp.arange(npoints)[:, jnp.newaxis] - jnp.arange(npoints)[jnp.newaxis, :] ) mask = (mask > discard_close_ind).astype(jnp.bool) # more columns than necessary discarded for starting and final rows, for shape compatibility first_rows = jnp.concatenate( ( jnp.zeros(2 * discard_close_ind + 1), jnp.ones(npoints - 2 * discard_close_ind - 1), ), dtype=bool, ) last_rows = jnp.concatenate( ( jnp.ones(npoints - 2 * discard_close_ind - 1), jnp.zeros(2 * discard_close_ind + 1), ), dtype=bool, ) mask = mask.at[:discard_close_ind].set(first_rows) mask = mask.at[-discard_close_ind:].set(last_rows) return mask def _return_nn_indices(self, discard_close_ind=0): """ Returns indices of the nearest neighbor of each point. Args: discard_close_ind (int): given any point i, defines the "close" points (following the labelling order along axis=0 of data_A and data_B) that are known to be significantly correlated with i. For example, this may occur when the data set is a time series, and axis=0 is the time dimension. For each point i, distances between i and points within the time window [i-discard_close_ind, i+discard_close_ind] are discarded. Default is 0, for which no distances between "time-correlated" points are discarded. Returns: nn_indices (np.array(float)): array of the nearest neighbors indices: nn_indices[i] is the index of the column with value 1 in the rank matrix. 
""" rank_matrix = self._compute_rank_matrix( batch_rows=self.data_A_rows, batch_columns=self.data_A_columns, periods=self.periods_A, ) npoints = rank_matrix.shape[0] # discard diagonal elements rank_matrix = rank_matrix.at[jnp.arange(npoints), jnp.arange(npoints)].set( npoints + 1 ) # construct and apply mask to discard distances between "close" points if discard_close_ind > 0: mask = self._return_mask( npoints=rank_matrix.shape[0], discard_close_ind=discard_close_ind ) rank_matrix = rank_matrix[mask].reshape((rank_matrix.shape[0], -1)) rank_matrix = rank_matrix.argsort(axis=1).argsort(axis=1) + 1 nn_indices = jnp.argmin(rank_matrix, axis=1) return nn_indices

def _train_epoch(self, key):
    """Performs the training for a single epoch.

    When 'batches_per_epoch' is larger than 1, a random permutation of the 'nrows' row indices is
    split into 'batches_per_epoch' chunks and one optimization step is performed per chunk; the
    same indices are used to subsample rows and columns. Otherwise a single step is performed
    over the full matrices. The attribute 'state' is overwritten after every step.

    Args:
        key (jax.random.PRNGKey): key for the JAX pseudo-random number generator (PRNG), only
            used to shuffle the points when mini-batches are employed.

    Returns:
        params (jnp.array(float)): array of shape (n_features_A,) containing the weights at the last step of
            the current training epoch. The single mini-batch updates are not returned.
        imb (float): value of the DII at the last step of the current training epoch.

    Raises:
        AssertionError: if any parameter is NaN after the last step of the epoch.
    """
    # ----------------------------MINI-BATCH GD----------------------------
    if self.batches_per_epoch > 1:
        # NOTE(review): jnp.split with an integer count presumably requires nrows to be
        # divisible by batches_per_epoch -- confirm this is guaranteed where nrows is set
        all_batch_indices = jnp.split(
            jax.random.permutation(key, self.nrows), self.batches_per_epoch
        )
        # mini-batch GD (subsample both rows and columns)
        for batch_indices in all_batch_indices:
            # the double argsort recomputes the ranks in space B within the mini-batch
            self.state, imb = self._train_step(
                self.state,
                self.data_A_rows[batch_indices],
                self.data_A_columns[batch_indices],
                self.ranks_B[batch_indices][:, batch_indices]
                .argsort(axis=1)
                .argsort(axis=1),
            )
        # DON'T DELETE: Alternative method for mini-batch GD (only subsample rows)
        # for i_batch, batch_indices in enumerate(all_batch_indices):
        #     ordered_column_indices = np.ravel(
        #         np.delete(all_batch_indices, i_batch, axis=0)
        #     )
        #     ordered_column_indices = np.append(
        #         batch_indices, ordered_column_indices
        #     )
        #     self.state, imb = self._train_step(
        #         self.state,
        #         self.data_A_rows[batch_indices],
        #         self.data_A_columns[ordered_column_indices],
        #         self.ranks_B[batch_indices][:, ordered_column_indices],
        #     )
    # -----------------------------BATCH GD----------------------------
    else:
        self.state, imb = self._train_step(
            self.state,
            self.data_A_rows,
            self.data_A_columns,
            self.ranks_B,
        )
    # the check is on NaN parameters; the message attributes them to all weights having been
    # driven to zero by the L1 penalty (presumably making the loss ill-defined -- TODO confirm)
    assert not jnp.isnan(self.state.params).any(), (
        "All weights were set to zero during the optimization. "
        + "Reduce the value of l1_strength."
    )
    return self.state.params, imb

def _init_optimizer(self):
    """Initializes the optimizer and the training state using the Optax library.

    The function uses the attribute optimizer_name of the DiffImbalance object (case-insensitive), which can be
    set to one of the following options: "sgd", "adam", "adamw". For more information on these optimizers, see
    https://optax.readthedocs.io/en/latest/api/optimizers.html.

    The learning rate schedule is selected by the attribute learning_rate_decay: "cos" (cosine decay
    over num_epochs * batches_per_epoch steps), "exp" (exponential decay with transition_steps=10 and
    decay_rate=0.5) or None (constant learning rate). The schedule is stored in the attribute
    'lr_schedule' and the new training state in the attribute 'state'. If a training state already
    exists, its current parameters are kept; otherwise the parameters start from 'params_init'.

    Raises:
        ValueError: if optimizer_name or learning_rate_decay is not one of the supported options.
    """
    if self.optimizer_name.lower() == "adam":
        opt_class = optax.adam
    elif self.optimizer_name.lower() == "adamw":
        opt_class = optax.adamw
    elif self.optimizer_name.lower() == "sgd":
        opt_class = optax.sgd
    else:
        raise ValueError(
            f'Unknown optimizer "{self.optimizer_name.lower()}". Choose among "sgd", "adam" and "adamw".'
        )
    # set the learning rate schedule (cosine decay, exp decay or constant)
    if self.learning_rate_decay == "cos":
        self.lr_schedule = optax.cosine_decay_schedule(
            init_value=self.learning_rate,
            decay_steps=self.num_epochs * self.batches_per_epoch,
        )
    elif self.learning_rate_decay == "exp":
        self.lr_schedule = optax.exponential_decay(
            init_value=self.learning_rate,
            transition_steps=10,
            decay_rate=0.5,
        )
    elif self.learning_rate_decay is None:
        self.lr_schedule = optax.constant_schedule(value=self.learning_rate)
    else:
        raise ValueError(
            f'Unknown learning rate decay schedule "{self.learning_rate_decay}". Choose among None, "cos" and "exp".'
        )
    optimizer = opt_class(self.lr_schedule)
    # Initialize training state
    self.state = train_state.TrainState.create(
        apply_fn=self._distance_A,
        # reuse the current weights when a state already exists, else start from params_init
        params=self.params_init if self.state is None else self.state.params,
        tx=optimizer,
    )
def train(self, bar_label=None):
    """Runs the full optimization of the DII with the settings stored in this object.

    With mini-batches the DII is *not* re-evaluated on the full data set at every epoch (for
    efficiency): the value recorded for an epoch is the one returned by the last optimization
    step of that epoch. Call 'return_final_dii' after training to obtain the DII over the full
    data set.

    Args:
        bar_label (str): optional label appended to the description of the tqdm progress bar,
            useful when several trainings are performed.

    Returns:
        params_training (np.array(float)): matrix of shape (num_epochs+1, nparams) with the
            absolute values of the weights during the training; row 0 holds the initial weights.
            The same history is stored (as a JAX array) in the attribute 'params_training', and
            its last row in the attribute 'params_final'.
        imbs_training (np.array(float)): array of shape (num_epochs+1,) with the DII during the
            training; element 0 is computed with the initial weights on the first
            nrows // batches_per_epoch points, element i > 0 is the DII returned by epoch i.
            The same history is stored (as a JAX array) in the attribute 'imbs_training'.
    """
    # Optimizer and training state must exist before any step is taken
    self._init_optimizer()

    n_records = self.num_epochs + 1  # one slot for the initial state plus one per epoch

    # DII at initialization, evaluated on a block of the same size as a training mini-batch
    first_block = jnp.arange(self.nrows // self.batches_per_epoch)
    ranks_B_block = self.ranks_B[first_block][:, first_block]
    initial_dii = self._compute_training_diff_imbalance(
        params=self.params_init,
        batch_A_rows=self.data_A_rows[first_block],
        batch_A_columns=self.data_A_columns[first_block],
        batch_B_ranks=ranks_B_block.argsort(axis=1).argsort(axis=1),
        step=0,
    )
    # DON'T DELETE: Alternative method for mini-batching (only sample rows)
    # imb_start = self._compute_training_diff_imbalance(
    #     params=self.params_init,
    #     batch_A_rows=self.data_A_rows[batch_indices],
    #     batch_A_columns=self.data_A_columns,
    #     batch_B_ranks=self.ranks_B[batch_indices],
    #     step=0,
    # )

    # Output buffers, with the initial state written in slot 0
    weights_history = jnp.empty(shape=(n_records, self.nparams))
    weights_history = weights_history.at[0].set(jnp.abs(self.params_init))
    dii_history = jnp.empty(shape=(n_records,))
    dii_history = dii_history.at[0].set(initial_dii)

    # Train over the epochs, drawing a fresh PRNG subkey for each of them
    progress_label = "Training" if bar_label is None else f"Training ({bar_label})"
    for epoch in tqdm(range(1, n_records), desc=progress_label):
        self.key, epoch_key = jax.random.split(self.key, num=2)
        epoch_weights, epoch_dii = self._train_epoch(epoch_key)
        weights_history = weights_history.at[epoch].set(jnp.abs(epoch_weights))
        dii_history = dii_history.at[epoch].set(epoch_dii)

    # Expose the results as attributes
    self.params_training = weights_history
    self.imbs_training = dii_history
    self.params_final = weights_history[-1]

    return np.array(weights_history), np.array(dii_history)
def return_final_dii(
    self, compute_error=True, ratio_rows_columns=1, seed=0, discard_close_ind=0
):
    """Returns final DII computed over the full data set using the optimal weights.

    If the training was carried out with mini-batches of small size, this method allows computing a better
    estimate of the DII than the final DII value produced by 'train'. When 'compute_error=False' and
    'discard_close_ind=0', the final DII produced by 'train' is the same computed by 'return_final_dii'
    if the training was performed without mini-batches (batches_per_epoch=1) and without row subsampling
    ('num_points_rows=None').
    The value of k for computing the smoothing parameter lambda is set in order to keep the same ratio k/N
    used in the training phase (if batches_per_epoch > 1, N is the size of mini-batches used during the
    training).

    Args:
        compute_error (bool): whether to compute the final DII and its error by sampling different points along
            rows and columns of the distance matrix. If False, the final DII is computed using the same points
            along rows and columns, which does not allow for an error estimation. The flag is evaluated by
            truthiness, so NumPy booleans are accepted as well. Default is True.
        ratio_rows_columns (float): only read when compute_error is True; defines the ratio between the number
            of points along rows (nrows) and along columns (ncolumns) of distance and rank matrices, in two
            groups randomly sampled.
            In general, nrows and ncolumns are determined by solving the equations
                nrows / ncolumns = ratio_rows_columns,
                nrows + ncolumns = n_total_points.
            Default is 1, which means that both groups have n_points / 2 elements.
        seed (int): seed of the JAX random generator used to split the points in two groups; only
            read when compute_error is True, in which case the attribute 'key' of the object is
            re-initialized from it. Default is 0.
        discard_close_ind (int): given any point i, defines the "close" points (following the labelling order
            along axis=0 of data_A and data_B) that are known to be significantly correlated with i. For
            example, this may occur when the data set is a time series, and axis=0 is the time dimension.
            If compute_error is True, "time-correlated" points are excluded by subsampling the data along
            axis=0 with stride discard_close_ind + 1. If compute_error is False, distances between each
            point i and points within the time window [i-discard_close_ind, i+discard_close_ind] are
            discarded. Default is 0, for which no distances between points close in time are discarded.

    Returns:
        imb_final (float): final DII, also stored in the attribute 'imb_final' of the object.
        error_final (float): error associated to final DII, also stored in the attribute 'error_final'
            of the object. If compute_error is False, error_final is set to None.

    Raises:
        AssertionError: if 'train' has not been called yet (params_final is None).
        ValueError: if compute_error is True and ratio_rows_columns is None.
    """
    assert self.params_final is not None, "First call the train() method!"

    # Validation and branch selection both rely on the truthiness of compute_error, so that
    # they can never disagree (e.g. for numpy booleans).
    if compute_error and ratio_rows_columns is None:
        raise ValueError(
            "Option 'compute_error==True' requires a value for the argument 'ratio_rows_columns'."
        )
    if not compute_error and ratio_rows_columns is not None:
        warnings.warn(
            "You set 'compute_error' to False; argument 'ratio_rows_columns' will be ignored.\n"
            "To suppress this warning set it to None."
        )

    # case 1: compute final DII and its error, using different points for rows and columns
    if compute_error:
        # subsample data to remove neighbor correlations, with stride discard_close_ind+1
        data_A = self.data_A
        data_B = self.data_B
        distances_B = self.distances_B
        if discard_close_ind != 0:
            subsamples = jnp.arange(
                0, self.data_A.shape[0], discard_close_ind + 1, dtype=int
            )
            data_A = data_A[subsamples]
            if self.distances_B is None:
                data_B = data_B[subsamples]
            else:
                distances_B = distances_B[subsamples][:, subsamples]

        # Split points in two groups, labelling rows and columns. The number of rows 'nrows'
        # comes from equations nrows / ncols = ratio_rows_columns and nrows + ncols = npoints.
        npoints = data_A.shape[0]
        nrows = int(ratio_rows_columns / (ratio_rows_columns + 1) * npoints)
        self.key = jax.random.PRNGKey(seed)  # initialize jax random generator
        self.key, subkey = jax.random.split(self.key, num=2)
        indices_rows = jax.random.choice(
            subkey, jnp.arange(npoints), shape=(nrows,), replace=False
        )
        indices_columns = jnp.delete(jnp.arange(npoints), indices_rows)

        # ranks in space B between row points and column points; the +1 makes the nearest
        # column point have rank 1 (rows and columns are disjoint, so there is no self-distance)
        if self.distances_B is None:  # space B provided as features
            ranks_B = (
                self._compute_rank_matrix(
                    batch_rows=data_B[indices_rows],
                    batch_columns=data_B[indices_columns],
                    periods=self.periods_B,
                )
                + 1
            )
        else:  # space B provided as distances
            ranks_B = (
                (distances_B[indices_rows][:, indices_columns])
                .argsort(axis=1)
                .argsort(axis=1)
            ) + 1

        # set k to keep same ratio k/N used during DII training
        k = int(
            jnp.ceil(
                self.k_final * self.batches_per_epoch / (discard_close_ind + 1)
            )
        )
        imb_final, error_final = self._compute_final_diff_imbalance_and_error(
            params=self.params_final,
            batch_A_rows=data_A[indices_rows],
            batch_A_columns=data_A[indices_columns],
            batch_B_ranks=ranks_B,
            k=k,
        )

    # case 2: compute final DII only (square distance matrices)
    else:
        # construct mask to discard distances d[i, i-discard_close_ind:i+discard_close_ind+1], for each i
        mask = None
        if discard_close_ind != 0:
            mask = self._return_mask(
                npoints=self.data_A.shape[0], discard_close_ind=discard_close_ind
            )
        self.mask = mask

        # compute final DII
        if self.distances_B is None:  # space B provided as features
            ranks_B = self._compute_rank_matrix(
                batch_rows=self.data_B,
                batch_columns=self.data_B,
                periods=self.periods_B,
            )
        else:  # space B provided as distances
            ranks_B = self.distances_B.argsort(axis=1).argsort(axis=1)
        if mask is not None:
            # drop the masked entries and recompute the ranks among the remaining points
            ranks_B = ranks_B[mask].reshape((ranks_B.shape[0], -1))
            ranks_B = ranks_B.argsort(axis=1).argsort(axis=1) + 1

        # set k to keep same ratio k/N used during DII training
        k = int(
            jnp.ceil(
                self.k_final
                * self.batches_per_epoch
                * (1 - 2 * discard_close_ind / self.ncolumns)
            )
        )
        imb_final = self._compute_final_diff_imbalance(
            params=self.params_final,
            batch_A_rows=self.data_A,
            batch_A_columns=self.data_A,
            batch_B_ranks=ranks_B,
            k=k,
        )
        error_final = None

    self.imb_final = imb_final
    self.error_final = error_final

    return imb_final, error_final
def forward_greedy_feature_selection(
    self,
    n_features_max=None,
    n_best=10,
    compute_error=False,
    ratio_rows_columns=1,
    seed=0,
    discard_close_ind=0,
):
    """Performs forward greedy feature selection using the Differentiable Information Imbalance.

    Starting with all individual features, the algorithm evaluates which single feature has the lowest DII.
    Then it combines the best n_best single features with each of the remaining features to find the best
    2-feature combination. This process continues until n_features_max features are selected or
    all features are included.

    For each candidate feature set a new DiffImbalance object is trained with the weights of the
    excluded features initialized to zero and l1_strength set to 0. All single-feature trainings
    use 'seed'; in the following iterations the j-th candidate evaluated (counting from 1) is
    trained with 'seed + j'. The final DII of every candidate is always computed with 'seed', which
    fixes the split of points along rows and columns when compute_error is True.
    The weights reported for the best set of each iteration are those produced by the very same
    training that yielded its DII (no additional training is run to recover them).

    Args:
        n_features_max (int): maximum number of features to select. If None, will select up to all features.
        n_best (int): number of best feature tuples to consider at each iteration. Default is 10.
        compute_error (bool): whether to compute error estimates for the DII. Default is False.
        ratio_rows_columns (float): ratio between the number of points along rows and columns when computing
            the DII. Only used when compute_error is True. Default is 1.
        seed (int): seed for random number generation. Default is 0.
        discard_close_ind (int): index to discard close points when computing the DII. Default is 0.

    Returns:
        best_feature_sets (list): list of lists, where each sublist contains the indices of the selected
            features at each iteration.
        best_diis (list): list of DII values corresponding to each set of selected features.
        best_errors (list): list of error estimates for each DII value. Entries are None if
            compute_error is False.
        best_weights_list (list): list of arrays of length n_features containing the weights for each
            set of selected features. For the best single feature the array holds the corresponding
            entry of 'params_init' (and zeros elsewhere).

    Raises:
        AssertionError: if the object was created with the option 'params_groups'.
    """
    if self.l1_strength != 0.0:
        warnings.warn("The greedy search will run with l1 strength equal to 0.")
    assert (
        self.params_groups is None
    ), "This method is not yet compatible with option 'params_groups'."

    n_features = self.nfeatures_A
    if n_features_max is None:
        n_features_max = n_features
    want_error = bool(compute_error)

    def _fit_subset(feature_set, training_seed):
        """Trains a fresh DiffImbalance on 'feature_set' and returns (dii, error, weights)."""
        # Only the features of the set start from their value in self.params_init
        mask = jnp.zeros(n_features, dtype=bool)
        mask = mask.at[jnp.array(feature_set)].set(True)
        params_init = jnp.where(mask, self.params_init, 0.0)

        dii_copy = DiffImbalance(
            data_A=self.data_A,
            data_B=self.data_B,
            distances_B=self.distances_B,
            periods_A=self.periods_A,
            periods_B=self.periods_B,
            seed=training_seed,
            num_epochs=self.num_epochs,
            batches_per_epoch=self.batches_per_epoch,
            l1_strength=0.0,
            point_adapt_lambda=self.point_adapt_lambda,
            k_init=self.k_init,
            k_final=self.k_final,
            lambda_factor=self.lambda_factor,
            params_init=params_init,
            optimizer_name=self.optimizer_name,
            learning_rate=self.learning_rate,
            learning_rate_decay=self.learning_rate_decay,
            num_points_rows=self.num_points_rows,
        )
        dii_copy.train()

        # Compute DII on the full dataset (ratio_rows_columns must be None without error)
        dii_copy.return_final_dii(
            compute_error=want_error,
            ratio_rows_columns=ratio_rows_columns if want_error else None,
            seed=seed,
            discard_close_ind=discard_close_ind,
        )
        print(f"Feature set = {feature_set}, DII = {dii_copy.imb_final}\n")

        error = float(dii_copy.error_final) if want_error else None
        # Only the small weight vector is kept; the trained copy itself is discarded
        return float(dii_copy.imb_final), error, np.array(dii_copy.params_final)

    # Initialize lists to store results
    best_feature_sets = []
    best_diis = []
    best_errors = []
    best_weights_list = []

    ############################ First evaluate all single features ############################
    single_feature_diis = []
    single_feature_errors = []
    for feature in range(n_features):
        dii, error, _ = _fit_subset([feature], seed)
        single_feature_diis.append(dii)
        single_feature_errors.append(error)
    single_feature_diis = np.array(single_feature_diis)

    # Select the best n_best single features (plain ints keep the feature sets readable)
    n_best_actual = min(n_best, n_features)
    selected_indices = np.argsort(single_feature_diis)[:n_best_actual]
    selected_features = [[int(idx)] for idx in selected_indices]

    # Add the best single feature to results
    best_feature = selected_features[0]
    best_feature_sets.append(best_feature)
    best_diis.append(single_feature_diis[selected_indices[0]])
    best_errors.append(single_feature_errors[selected_indices[0]])

    # Weights of the best single feature: its initial weight, inherited from this object
    best_weights = np.zeros(n_features)
    best_weights[best_feature[0]] = self.params_init[best_feature[0]]
    best_weights_list.append(best_weights)

    # Print the best single feature information
    print("------------------------------------------------")
    print(f"Best single feature: [{best_feature[0]}]")
    print(f"\tDII: {single_feature_diis[selected_indices[0]]}")
    print(f"\tOptimal weights: {best_weights}")
    print(f"Selected {n_best_actual} best candidates for next iteration")
    print("------------------------------------------------")

    ############################ Greedy loop over n-tuples (n>1) ############################
    target_size = min(n_features_max, n_features)
    while len(best_feature_sets[-1]) < target_size:
        candidate_features = []
        candidate_diis = []
        candidate_errors = []
        candidate_weights = []

        # Generate candidate feature sets by combining selected features with remaining features
        for selected_set in selected_features:
            for feature in range(n_features):
                if feature in selected_set:
                    continue
                candidate_set = sorted(selected_set + [feature])
                # Skip if this set has already been evaluated
                if candidate_set in candidate_features:
                    continue
                candidate_features.append(candidate_set)

                dii, error, weights = _fit_subset(
                    candidate_set, seed + len(candidate_features)
                )
                candidate_diis.append(dii)
                candidate_errors.append(error)
                candidate_weights.append(weights)

        if not candidate_features:
            # No more features to add
            break
        candidate_diis = np.array(candidate_diis)

        # Select the best n_best candidates for the next iteration
        n_best_actual = min(n_best, len(candidate_features))
        best_indices = np.argsort(candidate_diis)[:n_best_actual]
        selected_features = [candidate_features[i] for i in best_indices]

        # Add the best new set to results, with the weights of the training that produced its DII
        best_idx = int(best_indices[0])
        best_set = candidate_features[best_idx]
        best_weights = candidate_weights[best_idx]
        best_feature_sets.append(best_set)
        best_diis.append(candidate_diis[best_idx])
        best_errors.append(candidate_errors[best_idx])
        best_weights_list.append(best_weights)

        print(f"\nOptimal weights for feature set {best_set}: {best_weights}\n")

        # Print the best n-tuple information
        print("------------------------------------------------")
        print(f"Best {len(best_set)}-tuple: {best_set}")
        print(f"\tDII: {candidate_diis[best_idx]}")
        print(f"\tOptimal weights: {best_weights}")
        print(f"Selected {n_best_actual} best candidates for next iteration")
        print("------------------------------------------------")

    return best_feature_sets, best_diis, best_errors, best_weights_list
def backward_greedy_feature_selection(
    self,
    n_features_min=1,
    n_best=10,
    compute_error=False,
    ratio_rows_columns=1,
    seed=0,
    discard_close_ind=0,
):
    """Performs backward greedy feature selection using the Differentiable Information Imbalance.

    Starting with all features, the algorithm progressively removes the least informative features
    one at a time, until n_features_min features are left. For each iteration, the algorithm selects
    the n_best feature sets with the lowest DII values for consideration in the next round.
    The method should be called after calling the train() method, which performs the first optimization.

    For each candidate feature set a new DiffImbalance object is trained with the weights of the
    excluded features initialized to zero and l1_strength set to 0. Within an iteration the j-th
    candidate evaluated (counting from 1) is trained with 'seed + j'. The final DII is always
    computed with 'seed', which fixes the split of points along rows and columns when compute_error
    is True. The weights reported for the best set of each iteration are those produced by the very
    same training that yielded its DII (no additional training is run to recover them).

    Args:
        n_features_min (int): minimum number of features to select. Values smaller than 1 are treated
            as 1, since a candidate needs at least one feature to be trained. Default is 1.
        n_best (int): number of best feature tuples to consider at each iteration. Default is 10.
        compute_error (bool): whether to compute error estimates for the DII. Default is False.
        ratio_rows_columns (float): ratio between the number of points along rows and columns when computing
            the DII. Only used when compute_error is True. Default is 1.
        seed (int): seed for random number generation. Default is 0.
        discard_close_ind (int): index to discard close points when computing the DII. Default is 0.

    Returns:
        feature_sets (list): list of lists, where each sublist contains the indices of the selected
            features at each iteration; the first entry is the full feature set.
        diis (list): list of DII values corresponding to each set of selected features.
        errors (list): list of error estimates for each DII value. Entries are None if
            compute_error is False.
        best_weights_list (list): list of arrays containing the weights for each set of selected
            features; the first entry is the attribute 'params_final' of this object.

    Raises:
        AssertionError: if the object was created with the option 'params_groups', or if train()
            has not been called yet.
    """
    if self.l1_strength != 0.0:
        warnings.warn("The greedy search will run with l1 strength equal to 0.")
    assert (
        self.params_groups is None
    ), "This method is not yet compatible with option 'params_groups'."
    assert self.params_final is not None, "First call the train() method!"

    n_features = self.nfeatures_A
    want_error = bool(compute_error)
    # ratio_rows_columns must be None when no error is requested
    ratio = ratio_rows_columns if want_error else None
    # An empty feature set cannot be turned into a mask nor trained
    n_features_min = max(n_features_min, 1)

    def _fit_subset(feature_set, training_seed):
        """Trains a fresh DiffImbalance on 'feature_set' and returns (dii, error, weights)."""
        # Only the features of the set start from their value in self.params_init
        mask = jnp.zeros(n_features, dtype=bool)
        mask = mask.at[jnp.array(feature_set)].set(True)
        params_init = jnp.where(mask, self.params_init, 0.0)

        dii_copy = DiffImbalance(
            data_A=self.data_A,
            data_B=self.data_B,
            distances_B=self.distances_B,
            periods_A=self.periods_A,
            periods_B=self.periods_B,
            seed=training_seed,
            num_epochs=self.num_epochs,
            batches_per_epoch=self.batches_per_epoch,
            l1_strength=0.0,
            point_adapt_lambda=self.point_adapt_lambda,
            k_init=self.k_init,
            k_final=self.k_final,
            lambda_factor=self.lambda_factor,
            params_init=params_init,
            params_groups=None,
            optimizer_name=self.optimizer_name,
            learning_rate=self.learning_rate,
            learning_rate_decay=self.learning_rate_decay,
            num_points_rows=self.num_points_rows,
        )
        dii_copy.train()

        # Compute DII on the full dataset
        dii_copy.return_final_dii(
            compute_error=want_error,
            ratio_rows_columns=ratio,
            seed=seed,
            discard_close_ind=discard_close_ind,
        )
        print(f"Feature set = {feature_set}, DII = {dii_copy.imb_final}\n")

        error = float(dii_copy.error_final) if want_error else None
        # Only the small weight vector is kept; the trained copy itself is discarded
        return float(dii_copy.imb_final), error, dii_copy.params_final

    # Initialize lists to store results
    feature_sets = []
    diis = []
    errors = []
    best_weights_list = []

    # Start with all features and use the original trained weights
    current_features = [list(range(n_features))]

    ############################ First evaluate all features together ############################
    self.return_final_dii(
        compute_error=want_error,
        ratio_rows_columns=ratio,
        seed=seed,
        discard_close_ind=discard_close_ind,
    )
    diis.append(float(self.imb_final))
    errors.append(float(self.error_final) if want_error else None)

    # Print all-feature information
    print("------------------------------------------------")
    print(f"All features: {current_features}")
    print(f"\tDII: {self.imb_final}")
    print(f"\tOptimal weights: {self.params_final}")
    print("------------------------------------------------")

    feature_sets.append(current_features[0].copy())
    best_weights_list.append(self.params_final)

    ############################ Greedy loop over n-tuples (n<D) ############################
    while len(feature_sets[-1]) > n_features_min:
        candidate_features = []
        candidate_diis = []
        candidate_errors = []
        candidate_weights = []

        # Generate candidates by removing one feature from each of the current best feature sets
        for selected_set in current_features:
            if len(selected_set) <= n_features_min:
                # Skip sets that are already at minimum size
                continue
            for removed in selected_set:
                # Feature indices are unique within a set, so this drops exactly one element
                candidate_set = sorted(f for f in selected_set if f != removed)
                # Skip if this set has already been evaluated
                if candidate_set in candidate_features:
                    continue
                candidate_features.append(candidate_set)

                dii, error, weights = _fit_subset(
                    candidate_set, seed + len(candidate_features)
                )
                candidate_diis.append(dii)
                candidate_errors.append(error)
                candidate_weights.append(weights)

        # Make sure we have candidates before proceeding
        if not candidate_features:
            print("No more candidates to evaluate, exiting backward search")
            break
        candidate_diis = np.array(candidate_diis)

        # Select the best n_best candidates and use them in the next iteration
        n_best_actual = min(n_best, len(candidate_features))
        best_indices = np.argsort(candidate_diis)[:n_best_actual]
        current_features = [candidate_features[i] for i in best_indices]

        # Best candidate (lowest DII), with the weights of the training that produced its DII
        best_idx = int(best_indices[0])
        best_feature_set = candidate_features[best_idx]
        best_weights = candidate_weights[best_idx]

        # Store results
        best_weights_list.append(best_weights)
        feature_sets.append(best_feature_set.copy())
        diis.append(candidate_diis[best_idx])
        errors.append(candidate_errors[best_idx])

        # Print the best n-tuple information
        print("------------------------------------------------")
        print(f"Best {len(best_feature_set)}-tuple: {best_feature_set}")
        print(f"\tDII: {candidate_diis[best_idx]}")
        print(f"\tOptimal weights: {best_weights}")
        print(f"Selected {n_best_actual} best candidates for next iteration")
        print("------------------------------------------------")

    return feature_sets, diis, errors, best_weights_list