# Source code for hic3defdr.util.balancing

import numpy as np
import scipy.sparse as sparse


def kr_balance(array, tol=1e-6, x0=None, delta=0.1, ddelta=3, fl=1,
               max_iter=3000):
    """
    Balances a matrix via Knight-Ruiz matrix balancing, using sparse matrix
    operations.

    Parameters
    ----------
    array : np.ndarray or scipy.sparse.spmatrix
        The matrix to balance. Must be 2 dimensional and square. Will be
        symmetrized using its upper triangular entries.
    tol : float
        The error tolerance.
    x0 : np.ndarray, optional
        The initial guess for the bias vector. Pass None to use a vector of
        all 1's.
    delta, ddelta : float
        How close/far the balancing vectors can get to/from the positive cone,
        in terms of a relative measure on the size of the elements.
    fl : int
        Determines whether or not internal convergence statistics will be
        printed to standard output.
    max_iter : int
        The maximum number of iterations to perform.

    Returns
    -------
    sparse.csr_matrix, np.ndarray, np.ndarray
        The CSR matrix is the balanced matrix. It will be upper triangular if
        ``array`` was upper triangular, otherwise it will by symmetric. The
        first np.ndarray is the bias vector. The second np.ndarray is a list
        of residuals after each iteration.

    Raises
    ------
    ValueError
        If ``array`` is not a square 2-dimensional matrix.

    Notes
    -----
    The core logic of this implementation is transcribed (by Dan Gillis) from
    the original Knight and Ruiz IMA J. Numer. Anal. 2013 paper and differs
    only slightly from the implementation in Juicer/Juicebox (see comments).

    It then uses the "sum factor" approach from the Juicer/Juicebox code to
    scale the bias vectors up to match the scale of the original matrix (see
    comments).

    Overall, this function is not nan-safe, but you may pass matrices that
    contain empty rows (the matrix will be shrunken before balancing, but all
    outputs will match the shape of the original matrix).

    This function does not perform any of the logic used in Juicer/Juicebox to
    filter out sparse rows if the balancing fails to converge.

    If the max_iter limit is reached, this function will return the current
    best balanced matrix and bias vector - no exception will be thrown.
    Callers can compare the length of the list of residuals to max_iter - if
    they are equal, the algorithm has not actually converged, and the caller
    may choose to perform further filtering on the matrix and try again. The
    caller is also free to decide if the matrix is "balanced enough" using any
    other criteria (e.g., variance of nonzero row sums).
    """
    # check to see if input is upper triangular
    triu = sparse.tril(array, k=-1).nnz == 0

    # convert to CSR and symmetrize
    array = sparse.triu(sparse.csr_matrix(array))
    array += array.transpose() - sparse.diags([array.diagonal()], [0])

    # shrink matrix, storing copy of original full matrix
    nonzero = array.getnnz(1) > 0
    full_array = array.copy()
    array = array[nonzero, :][:, nonzero]

    # this block was copied from lib5c.algorithms.knight_ruiz.kr_balance()
    # which was transcribed by Dan Gillis from the Matlab source provided in
    # the supplement of Knight and Ruiz IMA J. Numer. Anal. 2013
    # the only changes made in this block (relative to the lib5c version) are
    # that the dot products have been rewritten to work with sparse matrices
    # the Matlab/Python implementation is nearly equivalent to the Java
    # function compteKRNormVector() at
    # https://github.com/theaidenlab/Juicebox/blob/
    # 5a56089c63957cb15401ea7906ab77e242dfd755/src/juicebox/tools/utils/
    # original/NormalizationCalculations.java#L138
    # except that
    # 1) the Matlab/Python implementation has a concept of Delta/ddelta, which
    #    serves as an upper limit analogue to delta
    # 2) the Java version stops infinite iteration using a not_changing
    #    counter, which terminates the iteration when rho_km1 hasn't changed
    #    significantly in 100 iterations and returns a failing result; the
    #    Python version uses a max_iter counter, which simply sets a maximum
    #    number of iterations to perform but does not return a failing result
    #    if the maximum number of iterations is exceeded; the Matlab
    #    implementation does neither and will instead loop forever
    dims = array.shape
    # check len(dims) first so we never index dims[1] on a non-2D input
    if len(dims) != 2 or dims[0] != dims[1]:
        raise ValueError('array must be a square 2-dimensional matrix')
    n = dims[0]
    e = np.ones((n, 1))
    res = np.array([])
    if x0 is None:
        x0 = e.copy()
    # np.float_ was removed in NumPy 2.0; np.float64 is the equivalent alias
    g = np.float64(0.9)
    eta_max = 0.1
    eta = eta_max
    stop_tol = tol * 0.5
    # x0 is not used again, so do not need to copy object
    x = x0  # n x 1 array
    rt = tol ** 2
    v = x0 * array.dot(x)  # n x 1 array
    rk = 1 - v  # n x 1 array
    rho_km1 = np.dot(np.transpose(rk), rk)[0, 0]  # scalar value
    # no need to copy next two objects - are scalar values
    rout = rho_km1
    rold = rout
    mvp = 0  # matrix-vector product counter (diagnostic only)
    i = 0
    if fl == 1:
        print('it in. it res')
    # outer iteration: Newton-like update of the bias vector x
    while rout > rt:
        if max_iter is not None and i > max_iter:
            break
        i += 1
        k = 0
        y = np.ones((n, 1))
        innertol = max(eta ** 2 * rout, rt)
        rho_km2 = None
        # inner iteration: conjugate gradient solve for the update y
        while rho_km1 > innertol:
            k += 1
            if k == 1:
                z = rk / v
                p = z.copy()
                rho_km1 = np.dot(np.transpose(rk), z)
            else:
                beta = rho_km1 / rho_km2
                p = z + beta * p
            w = x * array.dot(x * p) + v * p
            alpha = rho_km1 / np.dot(np.transpose(p), w)
            ap = alpha * p
            ynew = y + ap
            # if the update would leave the positive cone (lower bound),
            # take the largest safe step along ap and restart the outer loop
            if np.min(ynew) <= delta:
                if delta == 0:
                    break
                ind = np.where(ap < 0)
                gamma = np.min((delta - y[ind]) / ap[ind])
                y = y + gamma * ap
                break
            # symmetric upper-bound check against ddelta
            if np.max(ynew) >= ddelta:
                ind = np.where(ynew > ddelta)
                gamma = np.min((ddelta - y[ind]) / ap[ind])
                y = y + gamma * ap
                break
            y = ynew
            rk = rk - alpha * w
            rho_km2 = rho_km1
            z = rk / v
            rho_km1 = np.dot(np.transpose(rk), z)
        x = x * y
        v = x * array.dot(x)
        rk = 1 - v
        rho_km1 = np.dot(np.transpose(rk), rk)[0, 0]  # scalar value
        rout = rho_km1
        mvp = mvp + k + 1
        rat = rout / rold
        rold = rout
        res_norm = np.sqrt(rout)
        eta_0 = eta
        eta = g * rat
        if g * eta_0 ** 2 > 0.1:
            eta = max(eta, g * eta_0 ** 2)
        eta = max(min(eta, eta_max), stop_tol / res_norm)
        if fl == 1:
            print('{} {} {:.3e}'.format(i, k, res_norm))
        res = np.append(res, res_norm)
    del array

    # expand bias vector back to the shape of the original matrix
    bias = np.zeros_like(nonzero, dtype=float)
    bias[nonzero] = np.squeeze(x)

    # compute balanced matrix
    bias_csr = sparse.diags([bias], [0])
    balanced = bias_csr.dot(full_array)
    balanced = balanced.dot(bias_csr)

    # compute "sum factor" and recompute balanced matrix
    # this is equivalent to the Java function getSumFactor() at
    # https://github.com/theaidenlab/Juicebox/blob/
    # 5a56089c63957cb15401ea7906ab77e242dfd755/src/juicebox/tools/utils/
    # original/NormalizationCalculations.java#L333
    sum_factor = np.sqrt(full_array.sum() / balanced.sum())
    bias *= sum_factor
    bias_csr = sparse.diags([bias], [0])
    balanced = bias_csr.dot(full_array)
    balanced = balanced.dot(bias_csr)
    del full_array

    # invert bias vector at nonzero positions
    bias[bias != 0] = 1 / bias[bias != 0]

    # if input was upper triangular, return upper triangular balanced matrix
    if triu:
        balanced = sparse.triu(balanced)

    balanced = balanced.tocsr()
    return balanced, bias, res