# Source code for hamming

import os
from time import time

import numpy as np
from jax import config
from jax import devices as jdevices
from jax import jit, lax
from jax import numpy as jnp
from jax import random
from jax.scipy.special import gammaln
from jax.tree_util import register_pytree_node

# from jax.experimental.host_callback import call

# NOTE(review): this statement looks up the first CPU device and discards the result;
# on its own it does not make JAX place computations on the CPU (that would need e.g.
# config.update("jax_platform_name", "cpu")). TODO confirm the intended behavior.
jdevices("cpu")[0]
eps = 1e-7  # small tolerance; also used to make a jnp.arange upper bound inclusive
config.update("jax_enable_x64", True)  # enable jnp.float64 dtype


class Hamming:
    """
    Binary-spin dataset together with its Hamming-distance statistics.

    The methods fill the following attributes (all None until computed):
    - distances: distance matrix, from compute_distances (or passed in).
    - D_values, D_counts, D_probs: empirical distance histogram, from D_histogram.
    - r, r_idx: a quantile of D_probs and its index in D_values, from set_r_quantile.
    - D_mu_emp, D_var_emp: empirical mean and variance, from compute_moments.
    """

    def __init__(
        self,
        q=2,  # number of states: 2 for binary spins. In the future we can think of extending this to q>2.
        coordinates=None,  # spins: must be normalized to +-1 to compute distances.
        distances=None,  # precomputed distance matrix (optional)
        crossed_distances=0,  # 0 means we have one dataset with N samples and N(N-1)/2 (correlated) distances.
        verbose=True,  # read by callers (e.g. BID) to decide whether to print progress
    ):
        self.q = q
        self.coordinates = coordinates
        self.distances = distances
        self.crossed_distances = crossed_distances
        self.verbose = verbose

        self.r = None
        self.r_idx = None
        self.D_values = None
        self.D_counts = None
        self.D_probs = None
        self.D_mu_emp = None
        self.D_var_emp = None

    def compute_distances(
        self,
        sort=False,
        check_format=True,
    ):
        """
        Compute all-to-all distances in the dataset and store them in
        ``self.distances``, a matrix of shape (Ns, Ns) where Ns is the number of samples.

        Args:
            sort: forwarded to jcompute_distances.
            check_format: forwarded to jcompute_distances (checks the +-1 encoding).

        Raises:
            NotImplementedError: if ``self.q != 2``. Only binary spins are supported;
                previously this case silently left ``self.distances`` untouched.
        """
        if self.q != 2:
            raise NotImplementedError(
                f"distances are only implemented for binary spins (q=2), got q={self.q}"
            )
        self.distances = jcompute_distances(
            X1=self.coordinates,
            X2=self.coordinates,
            crossed_distances=self.crossed_distances,
            check_format=check_format,
            sort=sort,
        )

    # TODO: modify the histogram routine to discard the trivial zeros when crossed_distances = 1.

    def D_histogram(
        self,
        compute_flag=0,  # 1 to compute histogram (else it is loaded)
        save=False,  # 1 to save computed histogram
        resultsfolder="results/hist/",
        filename="counts.txt",
    ):
        """
        Compute (or load) the empirical histogram of distances (Pemp).

        It defines
        - self.D_values, a vector containing the sampled distances
        - self.D_counts, a vector containing how many times each distance was sampled
        - self.D_probs, self.D_counts normalized by the total number of counts observed.

        When computing, ``self.distances`` is assumed to be zero outside its strict
        upper triangle (the layout jcompute_distances produces with
        crossed_distances=0); those trivial zeros are removed from the counts.

        The histogram file is ``os.path.join(resultsfolder, filename)``, written as
        "value,count" rows when ``save`` is true and read back when ``compute_flag``
        is false.
        """
        assert self.crossed_distances == 0

        if save:
            os.makedirs(resultsfolder, exist_ok=True)
        # join (rather than concatenate) so a folder given without a trailing
        # separator still puts the file inside the folder created above
        _filename = os.path.join(resultsfolder, filename)

        if compute_flag:
            self.D_values, self.D_counts = np.unique(self.distances, return_counts=True)

            if self.crossed_distances == 0:
                Nsamples = self.distances.shape[0]
                assert self.D_values[0] == 0  # trivial zeros
                # diagonal (Ns) + lower triangle (Ns(Ns-1)/2) = Ns(Ns+1)/2 trivial zeros;
                # integer arithmetic avoids float rounding for large Ns
                Nzeros = Nsamples * (Nsamples + 1) // 2
                self.D_counts[0] -= Nzeros
                if self.D_counts[0] == 0:
                    self.D_values = self.D_values[1:]
                    self.D_counts = self.D_counts[1:]
            self.D_probs = self.D_counts / np.sum(self.D_counts)

            if save:
                np.savetxt(
                    fname=_filename,
                    X=np.transpose([self.D_values, self.D_counts]),
                    fmt="%d,%d",
                )
        else:
            # ndmin=2 keeps a single-row file as shape (1, 2) so column slicing works
            f = np.loadtxt(_filename, delimiter=",", dtype=int, ndmin=2)
            self.D_values = f[:, 0]
            self.D_counts = f[:, 1]
            self.D_probs = self.D_counts / np.sum(self.D_counts)

    def set_r_quantile(self, alpha, round=True, precision=10):
        """
        Define

        - self.r_idx as the last index whose cumulative probability
          (np.cumsum(self.D_probs)) is <= alpha, or 0 if there is none;
        - self.r = int(self.D_values[self.r_idx]).

        The result can be used as rmax or rmin, to discard distances larger than
        rmax or smaller than rmin, respectively.

        When ``round`` is true, both alpha and self.D_probs are rounded to
        ``precision`` decimals; note that self.D_probs is overwritten in place.
        """
        if round:
            alpha = np.round(alpha, precision)
            self.D_probs = np.round(self.D_probs, precision)

        indices = np.where(np.cumsum(self.D_probs) <= alpha)[0]
        self.r_idx = indices[-1] if len(indices) > 0 else 0
        self.r = int(self.D_values[self.r_idx])

    def compute_moments(self):
        """
        Compute the empirical mean (D_mu_emp) and variance (D_var_emp) of the
        distance distribution described by D_values / D_probs.
        """
        self.D_mu_emp = np.dot(self.D_probs, self.D_values)
        self.D_var_emp = np.dot(self.D_probs, self.D_values**2) - self.D_mu_emp**2


def check_data_format(X):
    """
    Check that the spin configurations in ``X`` are encoded as -1/+1.

    Any array whose entries all belong to {-1, +1} passes, including one where
    only one of the two values occurs (e.g. all spins up). The old implementation
    unpacked ``np.unique(X)`` into exactly two names and therefore failed with an
    unrelated unpacking error whenever X had one or more than two distinct values.

    Raises:
        AssertionError: if X contains any value other than -1 or +1. The check is
            an explicit ``raise`` so it still runs under ``python -O``; the
            exception type is kept for callers that already catch AssertionError.
    """
    values = np.unique(X)
    if not np.all(np.isin(values, (-1, 1))):
        raise AssertionError(f"spins have to be formatted to -+1, but {np.unique(X)=}")


def jcompute_distances(
    X1,
    X2,
    crossed_distances,
    check_format=True,
    sort=False,
):
    """This routine works for Ising spins variables defined as +-1 (this is faster than scipy)

    For s, t in {-1, +1}^N the Hamming distance is (N - s.t) / 2, so the whole
    distance matrix comes from one matrix product instead of one scatter per pair.

    Args:
        X1: array of shape (Ns1, N).
        X2: array of shape (Ns2, N); the samples must also have N spins.
        crossed_distances: if truthy, X1 and X2 are distinct datasets and all
            Ns1 * Ns2 entries are filled. Otherwise only the strict upper triangle
            (column > row) is filled and every other entry is 0, so each pair of a
            single dataset is counted once.
        check_format: if True, validate the +-1 encoding of both inputs.
        sort: if True, sort each row of the result (last axis).

    Returns:
        numpy int32 array of shape (Ns1, Ns2).
    """
    X1 = jnp.array(X1).astype(jnp.int32)
    X2 = jnp.array(X2).astype(jnp.int32)
    if check_format:
        check_data_format(X1)
        check_data_format(X2)

    _, N = X1.shape  # the samples in the other dataset must have also N spins...
    # lax.div on integers rounds toward zero, matching the previous
    # int32((N - dot) / 2) conversion exactly while avoiding a float temporary
    distances = lax.div(N - X1 @ X2.T, jnp.int32(2))
    if not crossed_distances:
        # one dataset: keep only the Ns(Ns-1)/2 upper-triangular distances
        distances = jnp.triu(distances, k=1)

    if sort:
        distances = jnp.sort(distances)
    return np.array(distances)
@jit
def _jcompute_distances(idx, pytree):
    """
    Outer-loop body used by jcompute_distances: for the data sample at row
    ``idx`` (stored in pytree["sample_idx"]), fill that row of pytree["D"] by
    running compute_row_distances over the columns lower_idx .. Ns2 - 1.

    Args:
        idx: row index supplied by lax.fori_loop.
        pytree: loop-carry dict; this function reads "crossed_distances",
            "lower_idx" and "Ns2" and writes "sample_idx" / "lower_idx".

    Returns:
        The updated pytree.
    """
    pytree["sample_idx"] = idx
    # first column to visit: 0 for two datasets, idx + 1 for a single dataset
    pytree = lax.cond(
        pytree["crossed_distances"], _set_lower_idx_true, _set_lower_idx_false, pytree
    )
    pytree = lax.fori_loop(
        lower=pytree["lower_idx"],
        upper=pytree["Ns2"],
        body_fun=compute_row_distances,
        init_val=pytree,
    )
    return pytree


def _set_lower_idx_true(pytree):
    """
    if we have two datasets, we have Ns1 * Ns2 distances to compute, so every
    column of the current row is visited (lower_idx = 0).
    """
    pytree["lower_idx"] = 0
    return pytree


def _set_lower_idx_false(pytree):
    """
    if we have one dataset, we have Ns(Ns-1)/2 distances to compute (the strict
    upper triangular part of "distances"), so start at the column right after
    the current row.
    """
    pytree["lower_idx"] = pytree["sample_idx"] + 1
    return pytree
@jit
def compute_row_distances(_idx, pytree):
    """
    Inner-loop body: store in pytree["D"][sample_idx, _idx] the Hamming distance
    between row sample_idx of pytree["X1"] and row _idx of pytree["X2"],
    computed as int32((N - overlap) / 2) with overlap the dot product of the two
    +-1 spin vectors.

    Returns:
        The updated pytree.
    """
    row = pytree["sample_idx"]
    overlap = jnp.dot(pytree["X1"][row, :], pytree["X2"][_idx, :])
    hamming = jnp.int32((pytree["N"] - overlap) / 2)
    pytree["D"] = pytree["D"].at[row, _idx].set(hamming)
    return pytree
class Optimizer:
    """
    Stochastic optimization

    State of the random-perturbation minimization of the KL divergence between
    the empirical distribution Pemp(r) and the model Pmodel(r). Every field
    except ``Nsteps_max`` is flattened as a dynamic pytree leaf; ``Nsteps_max``
    is static auxiliary data (it is used as an array shape by the optimizer).
    """

    # Dynamic (leaf) fields, listed in the exact positional order of __init__,
    # so that ``cls(*children)`` rebuilds the instance.
    _CHILD_FIELDS = (
        "key",
        "d0",
        "d0_r",
        "d1",
        "d1_r",
        "delta",
        "KL",
        "KL_aux",
        "remp",
        "Pemp",
        "Pmodel",
        "Nsteps",
        "accepted",
        "acc_ratio",
        "save_logKLs_flag",
        "logKLs",
        "idx",
        "mod_divisor",
    )

    def __init__(
        self,
        key=0,
        d0=0.0,  # BID
        d0_r=0.0,  # BID + random perturbation (*** used by compute_Pmodel instead of d0...)
        d1=0.0,  # slope
        d1_r=0.0,  # slope + random pertubation (*** used by compute_Pmodel instead of d1...)
        delta=0.0,  # optimization step size
        KL=jnp.double(0.0),  # KL divergence between Pemp(r) and Pmodel(r)
        KL_aux=jnp.inf,  # auxiliary variable to check when KL decreases
        remp=None,  # vector with empirical Hamming distances
        Pemp=None,  # vector with empirical probabilities
        Pmodel=None,  # vector with model probabilities
        Nsteps=0,  # Number of steps for the optimization
        accepted=0,  # Accepted moves
        acc_ratio=jnp.double(1.0),  # Acceptance ratio
        save_logKLs_flag=0,  # Flag to save logKLs during optimization
        logKLs=None,  # Vector with logKLs during optimization
        idx=0,  # Auxiliary index
        mod_divisor=0,  # Auxiliary variable to export the log KL
        Nsteps_max=1000,  # Total number of saved steps (subsample of the total number of steps)
    ):
        # random state and model parameters (current and perturbed)
        self.key = key
        self.d0, self.d0_r = d0, d0_r
        self.d1, self.d1_r = d1, d1_r
        self.delta = delta
        # objective
        self.KL, self.KL_aux = KL, KL_aux
        # data and model distributions
        self.remp, self.Pemp, self.Pmodel = remp, Pemp, Pmodel
        # bookkeeping of the run
        self.Nsteps = Nsteps
        self.accepted, self.acc_ratio = accepted, acc_ratio
        # logging of log(KL) during the run
        self.save_logKLs_flag = save_logKLs_flag
        self.logKLs = logKLs
        self.idx = idx
        self.mod_divisor = mod_divisor
        self.Nsteps_max = Nsteps_max

    def _tree_flatten(self):
        """Split the instance into (dynamic leaves, static auxiliary data)."""
        leaves = tuple(getattr(self, field) for field in self._CHILD_FIELDS)
        static = {"Nsteps_max": self.Nsteps_max}
        return leaves, static

    @classmethod
    def _tree_unflatten(cls, aux_data, children):
        """Rebuild an instance from the output of ``_tree_flatten``."""
        return cls(*children, **aux_data)
register_pytree_node(Optimizer, Optimizer._tree_flatten, Optimizer._tree_unflatten)


def _compute_Pmodel(idx, Op):
    """
    Loop body of compute_Pmodel: set the unnormalized model probability at
    position ``idx``.

    Note that this routine uses Op.d0_r and Op.d1_r (the perturbed parameters)
    to compute Pmodel. With ID = d0_r + d1_r * r and r = Op.remp[idx], the value is
        Gamma(ID + 1) / (Gamma(r + 1) * Gamma(ID - r + 1)) / 2**ID,
    evaluated through log-gamma functions.
    """
    r = Op.remp[idx]
    ID = Op.d0_r + Op.d1_r * r
    Op.Pmodel = Op.Pmodel.at[idx].set(
        jnp.exp(
            gammaln(ID + jnp.double(1))
            - gammaln(r + 1)
            - gammaln(ID - r + 1)
            - ID * jnp.log(jnp.double(2))
        )
    )
    return Op


@jit
def compute_Pmodel(Op):
    """Fill Op.Pmodel from Op.d0_r / Op.d1_r and normalize it to sum to one."""
    Op = lax.fori_loop(
        lower=0, upper=Op.Pmodel.shape[0], body_fun=_compute_Pmodel, init_val=Op
    )
    Op.Pmodel /= jnp.sum(Op.Pmodel)
    return Op


@jit
def step(idx, Op):
    """
    One optimization step.

    d0 and d1 are each multiplied by an independent random factor drawn
    uniformly from [1 - delta/2, 1 + delta/2); the model and KL are recomputed
    with the perturbed values, and the move is accepted if the KL did not increase.
    """
    Op.key, subkey = random.split(Op.key, num=2)
    r = random.uniform(subkey, dtype=jnp.float64)
    Op.d0_r = Op.d0 * (1 + Op.delta * (r - jnp.double(0.5)))

    Op.key, subkey = random.split(Op.key, num=2)
    rr = random.uniform(subkey, dtype=jnp.float64)
    Op.d1_r = Op.d1 * (1 + Op.delta * (rr - jnp.double(0.5)))

    Op = compute_Pmodel(Op)
    Op = compute_KLd(Op)
    Op = lax.cond(Op.KL <= Op.KL_aux, update_state, do_nothing, Op)

    # every mod_divisor steps (if enabled) record log(KL) of the proposed move
    logical_condition = jnp.logical_and(
        Op.save_logKLs_flag, jnp.mod(idx, Op.mod_divisor) == 0
    )
    Op = lax.cond(logical_condition, save_logKL, do_nothing, Op)
    return Op


@jit
def compute_KLd(Op):
    """
    KL divergence sum(Pemp * log(Pemp / Pmodel)), using the convention
    0 * log(0) = 0 so that empty empirical bins do not turn the result into NaN.
    """
    terms = Op.Pemp * jnp.log(Op.Pemp / Op.Pmodel)
    Op.KL = jnp.sum(jnp.where(Op.Pemp > 0, terms, 0.0))
    return Op


@jit
def update_state(Op):
    """Accept the proposed move: adopt the perturbed parameters and its KL."""
    Op.d0 = Op.d0_r
    Op.d1 = Op.d1_r
    Op.KL_aux = Op.KL
    Op.accepted += 1
    return Op


@jit
def do_nothing(Op):
    """Identity branch for lax.cond."""
    return Op


@jit
def save_logKL(Op):
    """Store log(Op.KL) at position Op.idx of Op.logKLs and advance Op.idx."""
    Op.logKLs = Op.logKLs.at[Op.idx].set(jnp.log(Op.KL))
    Op.idx += 1
    return Op


@jit
def minimize_KL(Op):
    """
    Run Op.Nsteps optimization steps, then recompute Pmodel and KL at the last
    accepted parameters and set the acceptance ratio.
    """
    Op.logKLs = jnp.empty(shape=(Op.Nsteps_max), dtype=jnp.double)
    # Save log(KL) every mod_divisor steps; clamp to 1 so that Nsteps < Nsteps_max
    # does not lead to a modulo by zero inside step().
    Op.mod_divisor = jnp.maximum(Op.Nsteps // Op.Nsteps_max, 1)
    Op = lax.fori_loop(lower=0, upper=Op.Nsteps, body_fun=step, init_val=Op)
    # This is necessary to keep the last *accepted* move
    Op.d0_r = Op.d0
    Op.d1_r = Op.d1
    Op = compute_Pmodel(Op)
    Op = compute_KLd(Op)
    Op.acc_ratio = jnp.double(Op.accepted) / jnp.double(Op.Nsteps)
    return Op


class BID:
    """
    Fit of the intrinsic dimension d0 (stored in ``intrinsic_dim``) and slope d1.

    The empirical distance histogram held by a Hamming instance is truncated to
    the quantile range [alphamin, alphamax] and compared with the model built by
    compute_Pmodel; d0 and d1 are then optimized by minimize_KL.
    """

    def __init__(
        self,
        H=None,  # instance of Hamming class defined above (a new Hamming() when None)
        Op=None,  # instance of Optimizer class defined above
        alphamin=0.0,
        alphamax=0.2,
        seed=1,
        d0=jnp.double(0),  # BID \equiv d(r=0) (see paper)
        d1=jnp.double(0),  # slope of d(r) at r=0
        d00=jnp.double(0),  # initial value of d0
        d10=jnp.double(0),  # initial value of d1
        delta=5e-4,
        Nsteps=1e6,
        optfolder0="results/opt/",
        load_initial_condition_flag=False,
        optimization_elapsed_time=None,
        export_results=1,
        export_logKLs=0,  # To export the curve of logKLs during optimization
        L=0,  # Number of bits / Ising spins
    ):
        # A fresh Hamming per BID: a default of Hamming() would be created once at
        # import time and then shared (and mutated) by every BID built without H.
        self.H = Hamming() if H is None else H
        self.Op = Op
        self.alphamin = alphamin
        self.alphamax = alphamax
        self.seed = seed
        self.d0 = d0
        self.d1 = d1
        self.d00 = d00
        self.d10 = d10
        self.delta = delta
        self.Nsteps = Nsteps
        self.optimization_elapsed_time = optimization_elapsed_time  # in minutes
        self.export_results = export_results
        self.export_logKLs = export_logKLs
        self.L = L
        self.intrinsic_dim = self.d0
        self.key0 = random.PRNGKey(self.seed)
        self.optfolder0 = optfolder0
        # NOTE(review): this flag is stored but never read in this class;
        # computeBID always calls set_initial_condition. TODO confirm intent.
        self.load_initial_condition_flag = load_initial_condition_flag
        # alphamin == 0 means no lower cut on the histogram (a plain Python bool)
        self.regularize = not np.isclose(alphamin, 0)

    def load_initial_condition(
        self,
    ):
        """Load d00 and d10 from a previously exported opt.txt."""
        self.set_filepaths()
        _, self.d00, self.d10, _ = self.load_results()

    def set_filepaths(
        self,
    ):
        """Build the output folder and file names from the run parameters."""
        self.optfolder = self.optfolder0
        self.optfolder += f"alphamin{self.alphamin:.5f}/"
        self.optfolder += f"alphamax{self.alphamax:.5f}/"
        self.optfolder += f"Nsteps{self.Nsteps}/"
        self.optfolder += f"delta{self.delta:.5f}/"
        self.optfolder += f"seed{self.seed}/"
        self.optfile = self.optfolder + "opt.txt"
        self.valfile = self.optfolder + "model_validation.txt"
        self.KLfile = self.optfolder + "logKLs_opt.txt"

    def set_idmin(
        self,
    ):
        """Set the lower cut (rmin, idmin): the first bin, or the alphamin quantile."""
        if not self.regularize:
            self.idmin = 0
            self.rmin = self.H.D_values[0]
        else:
            self.H.set_r_quantile(self.alphamin)
            self.rmin = self.H.r
            self.idmin = self.H.r_idx
        # clear the scratch values that set_r_quantile leaves on H
        self.H.r = None
        self.H.r_idx = None

    def set_idmax(
        self,
    ):
        """Set the upper cut (rmax, idmax) as the alphamax quantile."""
        self.H.set_r_quantile(self.alphamax)
        self.rmax = self.H.r
        self.idmax = self.H.r_idx
        self.H.r = None
        self.H.r_idx = None

    def truncate_hist(self):
        """Keep histogram bins idmin..idmax (inclusive) and renormalize Pemp."""
        self.remp = jnp.array(
            self.H.D_values[self.idmin : self.idmax + 1], dtype=jnp.float64
        )
        self.Pemp = jnp.array(
            self.H.D_probs[self.idmin : self.idmax + 1], dtype=jnp.float64
        )
        self.Pemp /= jnp.sum(self.Pemp)
        self.Pmodel = jnp.zeros(shape=self.Pemp.shape, dtype=jnp.float64)

    def test_initial_condition(self, d0, d1):
        """Return log(KL) for the candidate parameters (d0, d1)."""
        self.Op.d0_r = jnp.double(d0)
        self.Op.d1_r = jnp.double(d1)
        self.Op = compute_Pmodel(self.Op)
        self.Op = compute_KLd(self.Op)
        return np.log(self.Op.KL)

    def set_initial_condition(self, d00min=0.05, d00max=0.95, d00step=0.05):
        """
        Pick the starting point (d00, d10) with the smallest log(KL) (NaNs ignored)
        among a list of candidates, and seed the optimizer with it.
        """
        # our home-made guess: d0 = largest retained distance, d1 = 1
        d00_guess_list = jnp.array([jnp.double(self.Op.remp[-1])])
        d10_guess_list = jnp.array([jnp.double(1)])
        if self.L != 0:
            # Inspired by Cristopher Erazo:
            self.H.compute_moments()
            _d00_guess_list = self.L * jnp.arange(
                d00min, d00max + eps, d00step, dtype=jnp.double
            )
            _d10_guess_list = jnp.double(2) - _d00_guess_list / jnp.double(
                self.H.D_mu_emp
            )
            d00_guess_list = jnp.concatenate((d00_guess_list, _d00_guess_list))
            d10_guess_list = jnp.concatenate((d10_guess_list, _d10_guess_list))

        logKLs0 = jnp.empty(shape=(len(d00_guess_list)), dtype=jnp.double)
        for i, (d00, d10) in enumerate(zip(d00_guess_list, d10_guess_list)):
            logKLs0 = logKLs0.at[i].set(self.test_initial_condition(d00, d10))
        i0 = jnp.nanargmin(logKLs0)
        self.d00 = d00_guess_list[i0]
        self.d10 = d10_guess_list[i0]
        self.Op.d0 = self.d00
        self.Op.d1 = self.d10
        self.Op.KL_aux = jnp.exp(logKLs0[i0])

    def computeBID(
        self,
    ):
        """Truncate the histogram, optimize d0/d1 and optionally export results."""
        self.set_idmin()
        self.set_idmax()
        self.truncate_hist()
        self.set_filepaths()
        if self.export_results:
            os.makedirs(self.optfolder, exist_ok=True)

        self.Op = Optimizer(
            key=self.key0,
            d0=jnp.double(self.d00),
            d1=jnp.double(self.d10),
            delta=jnp.double(self.delta),
            remp=self.remp,
            Pemp=self.Pemp,
            Pmodel=self.Pmodel,
            Nsteps=self.Nsteps,
            save_logKLs_flag=self.export_logKLs,
        )
        self.set_initial_condition()

        if self.H.verbose == 1:
            print("starting optimization")
        starting_time = time()
        self.Op = minimize_KL(self.Op)
        self.optimization_elapsed_time = (time() - starting_time) / 60.0
        if self.H.verbose == 1:
            print(f"optimization took {self.optimization_elapsed_time:.1f} minutes")
            print(
                f"d_0={self.Op.d0:.3f},d_1={self.Op.d1:.3f},logKL={jnp.log(self.Op.KL):.2f}"
            )

        if self.export_results:
            # "w" truncates any previous result; the file is closed deterministically
            with open(self.optfile, "w") as optfile:
                print(
                    f"{self.rmax:d},{self.Op.d0:.8f},{self.Op.d1:.8f},{np.log(self.Op.KL):.8f}",
                    file=optfile,
                )
            np.savetxt(
                fname=self.valfile,
                X=np.transpose([self.remp, self.Pemp, self.Op.Pmodel]),
            )
        if self.export_logKLs:
            np.savetxt(fname=self.KLfile, X=self.Op.logKLs)

        self.d0 = self.Op.d0.item()
        self.d1 = self.Op.d1.item()
        self.logKL = jnp.log(self.Op.KL).item()
        self.Pmodel = np.array(self.Op.Pmodel)
        self.intrinsic_dim = self.d0

    def load_results(
        self,
    ):
        """Load (rmax, d0, d1, logKL) from opt.txt."""
        self.set_filepaths()
        return np.loadtxt(self.optfile, unpack=True, delimiter=",")

    def load_fit(
        self,
    ):
        """Load the (remp, Pemp, Pmodel) columns from model_validation.txt."""
        self.set_filepaths()
        return np.loadtxt(self.valfile, unpack=True)

    def load_logKLs_opt(
        self,
    ):
        """Load the log(KL) trace saved during optimization."""
        self.set_filepaths()
        return np.loadtxt(self.KLfile)