# Source code for magni.cs.reconstruction.gamp._algorithm

"""
..
    Copyright (c) 2015-2017, Magni developers.
    All rights reserved.
    See LICENSE.rst for further information.

Module providing the core Generalised Approximate Message Passing (GAMP)
algorithm.

Routine listings
----------------
run(y, A, A_asq=None)
    Run the GAMP reconstruction algorithm.

See Also
--------
magni.cs.reconstruction.gamp._config : Configuration options.
magni.cs.reconstruction.gamp.input_channel : Available input channels.
magni.cs.reconstruction.gamp.output_channel : Available output channels.
magni.cs.reconstruction.gamp.stop_criterion : Available stop criteria.

Notes
-----
The default configuration of the GAMP algorithm provides the s-GB AMP algorithm
from [1]_ using an MSE convergence based stop criterion. Both the input
channel, the output channel, and the stop criterion may be changed.

This implementation allows for the use of sum approximations of the squared
system matrix as detailed in [1]_ and [2]_. Furthermore, a simple damping
option is available based on the description in [3]_ (see also [4]_ for more
details on damping in GAMP).

References
----------
.. [1] F. Krzakala, M. Mezard, F. Sausset, Y. Sun, and L. Zdeborova,
   "Probabilistic reconstruction in compressed sensing: algorithms, phase
   diagrams, and threshold achieving matrices", *Journal of Statistical
   Mechanics: Theory and Experiment*, vol. P08009, pp. 1-57, Aug. 2012.
.. [2] S. Rangan, "Generalized Approximate Message Passing for Estimation
   with Random Linear Mixing", arXiv:1010.5141v2, pp. 1-22, Aug. 2012.
.. [3] S. Rangan, P. Schniter, and A. Fletcher. "On the Convergence of
   Approximate Message Passing with Arbitrary Matrices", *in IEEE International
   Symposium on Information Theory (ISIT)*, pp. 236-240, Honolulu, Hawaii, USA,
   Jun. 29 - Jul. 4, 2014.
.. [4] J. Vila, P. Schniter, S. Rangan, F. Krzakala, L. Zdeborova, "Adaptive
   Damping and Mean Removal for the Generalized Approximate Message Passing
   Algorithm", *in IEEE International Conference on Acoustics, Speech, and
   Signal Processing (ICASSP)*, South Brisbane, Queensland, Australia, Apr.
   19-24, 2015, pp. 2021-2025.

"""

from __future__ import division
import copy

import numpy as np

from magni.cs.reconstruction.gamp import config as _conf
from magni.utils.config import Configger as _Configger
from magni.utils.validation import decorate_validation as _decorate_validation
from magni.utils.validation import validate_generic as _generic
from magni.utils.validation import validate_numeric as _numeric
from magni.utils.validation.types import MatrixBase as _MatrixBase
from magni.utils.matrices import norm as _norm
from magni.utils.matrices import (SumApproximationMatrix as
                                  _SumApproximationMatrix)


def run(y, A, A_asq=None):
    """
    Run the GAMP reconstruction algorithm.

    Parameters
    ----------
    y : ndarray
        The m x 1 measurement vector.
    A : ndarray or magni.utils.matrices.{Matrix, MatrixCollection}
        The m x n matrix which is the product of the measurement matrix and
        the dictionary matrix.
    A_asq : ndarray or magni.utils.matrices.{Matrix, MatrixCollection} or None
        The m x n matrix which is the entrywise absolute value squared product
        of the measurement matrix and the dictionary matrix (the default is
        None, which implies that a sum approximation is used).

    Returns
    -------
    alpha : ndarray
        The n x 1 reconstructed coefficient vector.
    history : dict, optional
        The dictionary of various measures tracked in the GAMP iterations.

    See Also
    --------
    magni.cs.reconstruction.gamp._config : Configuration options.
    magni.cs.reconstruction.gamp.input_channel : Input channels.
    magni.cs.reconstruction.gamp.output_channel : Output channels.
    magni.cs.reconstruction.gamp.stop_criterion : Stop criteria.

    Notes
    -----
    Optionally, the algorithm may be configured to save and return the
    iteration history along with the reconstruction result. The returned
    history contains the following:

    * alpha_bar : Mean coefficient estimates (the reconstruction
      coefficients).
    * alpha_tilde : Variance coefficient estimates.
    * MSE : solution Mean squared error (if the true solution is known).
    * input_channel_parameters : The state of the input channel.
    * output_channel_parameters : The state of the output channel.
    * stop_criterion : The currently used stop criterion.
    * stop_criterion_value : The value of the stop criterion.
    * stop_iteration : The iteration at which the algorithm stopped.
    * stop_reason : The reason for termination of the algorithm.

    Examples
    --------
    For example, recovering a vector from AWGN noisy measurements using GAMP

    >>> import numpy as np
    >>> from magni.cs.reconstruction.gamp import run, config
    >>> np.random.seed(seed=6028)
    >>> k, m, n = 10, 200, 400
    >>> tau = float(k) / n
    >>> A = 1 / np.sqrt(m) * np.random.randn(m, n)
    >>> A_asq = np.abs(A)**2
    >>> alpha = np.zeros((n, 1))
    >>> alpha[:k] = np.random.normal(scale=2, size=(k, 1))
    >>> np_printoptions = np.get_printoptions()
    >>> np.set_printoptions(suppress=True, threshold=k+2)
    >>> alpha[:k + 2]
    array([[ 1.92709461],
           [ 0.74378508],
           [-3.2418159 ],
           [-1.32277347],
           [ 0.90118   ],
           [-0.19157262],
           [ 0.82855712],
           [ 0.24817994],
           [-1.43034777],
           [-0.21232344],
           [ 0.        ],
           [ 0.        ]])
    >>> sigma = 0.15
    >>> y = A.dot(alpha) + np.random.normal(scale=sigma, size=(A.shape[0], 1))
    >>> input_channel_params = {'tau': tau, 'theta_bar': 0, 'theta_tilde': 4,
    ...                         'use_em': False}
    >>> config['input_channel_parameters'] = input_channel_params
    >>> output_channel_params = {'sigma_sq': sigma**2,
    ...                          'noise_level_estimation': 'fixed'}
    >>> config['output_channel_parameters'] = output_channel_params
    >>> alpha_hat = run(y, A, A_asq)
    >>> alpha_hat[:k + 2]
    array([[ 1.93810961],
           [ 0.6955502 ],
           [-3.39759349],
           [-1.35533562],
           [ 1.10524227],
           [-0.00594848],
           [ 0.79274671],
           [ 0.04895264],
           [-1.08726071],
           [-0.00142911],
           [ 0.00022861],
           [-0.00004272]])
    >>> np.sum(np.abs(alpha - alpha_hat) > sigma * 3)
    0

    or recover the same vector returning a history comparing the pr.
    iteration solution to the true vector and printing the A_asq details

    >>> config['report_A_asq_setup'] = True
    >>> config['report_history'] = True
    >>> config['true_solution'] = alpha
    >>> alpha_hat, history = run(y, A, A_asq)  # doctest: +NORMALIZE_WHITESPACE
    GAMP is using the A_asq: [[ 0.024 ...,  0.002]
     ...,
     [ 0.    ...,  0.014]]
    The sum approximation method is: None
    >>> alpha_hat[:k + 2]
    array([[ 1.93810961],
           [ 0.6955502 ],
           [-3.39759349],
           [-1.35533562],
           [ 1.10524227],
           [-0.00594848],
           [ 0.79274671],
           [ 0.04895264],
           [-1.08726071],
           [-0.00142911],
           [ 0.00022861],
           [-0.00004272]])
    >>> np.array(history['MSE']).reshape(-1, 1)[1:11]
    array([[ 0.04562729],
           [ 0.01328304],
           [ 0.00112098],
           [ 0.00074968],
           [ 0.00080175],
           [ 0.00076615],
           [ 0.00077043]])

    or recover the same vector using sample variance AWGN noise level
    estimation

    >>> config['report_A_asq_setup'] = False
    >>> config['report_history'] = False
    >>> output_channel_params['noise_level_estimation'] = 'sample_variance'
    >>> config['output_channel_parameters'] = output_channel_params
    >>> alpha_hat = run(y, A, A_asq)
    >>> alpha_hat[:k + 2]
    array([[ 1.94820622],
           [ 0.72162206],
           [-3.39978431],
           [-1.35357001],
           [ 1.10701779],
           [-0.00834467],
           [ 0.79790879],
           [ 0.08441384],
           [-1.08946306],
           [-0.0015894 ],
           [ 0.00020561],
           [-0.00003623]])
    >>> np.sum(np.abs(alpha - alpha_hat) > sigma * 3)
    0

    or recover the same vector using median AWGN noise level estimation

    >>> output_channel_params['noise_level_estimation'] = 'median'
    >>> config['output_channel_parameters'] = output_channel_params
    >>> alpha_hat = run(y, A, A_asq)
    >>> alpha_hat[:k + 2]
    array([[ 1.93356483],
           [ 0.65232347],
           [-3.39440429],
           [-1.35437724],
           [ 1.10312573],
           [-0.0050555 ],
           [ 0.78743162],
           [ 0.03616397],
           [-1.08589927],
           [-0.00136802],
           [ 0.00024121],
           [-0.00004498]])
    >>> np.sum(np.abs(alpha - alpha_hat) > sigma * 3)
    0

    or recover the same vector learning the AWGN noise level using
    expectation maximization (EM)

    >>> output_channel_params['noise_level_estimation'] = 'em'
    >>> config['output_channel_parameters'] = output_channel_params
    >>> alpha_hat = run(y, A, A_asq)
    >>> alpha_hat[:k + 2]
    array([[ 1.94118089],
           [ 0.71553983],
           [-3.40076165],
           [-1.35662005],
           [ 1.1099417 ],
           [-0.00688125],
           [ 0.79442879],
           [ 0.06258856],
           [-1.08792606],
           [-0.00148811],
           [ 0.00022266],
           [-0.00003785]])
    >>> np.sum(np.abs(alpha - alpha_hat) > sigma * 3)
    0
    >>> np.set_printoptions(**np_printoptions)

    """
    @_decorate_validation
    def validate_input():
        _numeric('y', ('integer', 'floating', 'complex'), shape=(-1, 1))
        _numeric('A', ('integer', 'floating', 'complex'), shape=(
            y.shape[0],
            _conf['true_solution'].shape[0]
            if _conf['true_solution'] is not None else -1))

        if isinstance(A_asq, _MatrixBase):
            # It is not possible to validate the range of an implicit matrix.
            # Thus allow all possible values.
            range_ = '[-inf;inf]'
        else:
            range_ = '[0;inf)'

        _numeric('A_asq', ('integer', 'floating', 'complex'), range_=range_,
                 shape=A.shape, ignore_none=True)

    @_decorate_validation
    def validate_output():
        # complex128 is two float64 (real and imaginary part) each taking 8*8
        # bits. Thus, in total 2*8*8=128 bits. However, we only consider it to
        # be "64 bit precision" since that is what each part is.
        bits_pr_nbytes = 4 if np.iscomplexobj(convert(0)) else 8
        _numeric('alpha', ('integer', 'floating', 'complex'),
                 shape=(A.shape[1], 1),
                 precision=convert(0).nbytes * bits_pr_nbytes)
        _generic('history', 'mapping',
                 keys_in=('alpha_bar', 'alpha_tilde', 'MSE',
                          'input_channel_parameters',
                          'output_channel_parameters', 'stop_criterion',
                          'stop_criterion_value', 'stop_iteration',
                          'stop_reason'))

    validate_input()

    # Initialisation
    # NOTE: several of the locals unpacked below (e.g. o, r, s, tolerance,
    # alpha_bar_prev) look unused but are read by the input/output channels
    # and the stop criterion through the locals() dict passed to compute().
    init = _get_gamp_initialisation(y, A, A_asq)
    AH = init['AH']
    A_asq = init['A_asq']
    AT_asq = init['AT_asq']
    o = init['o']
    q = init['q']
    s = init['s']
    r = init['r']
    m = init['m']
    n = init['n']
    alpha_bar = init['alpha_bar']
    alpha_tilde = init['alpha_tilde']
    alpha_breve = init['alpha_breve']
    A_dot_alpha_bar = init['A_dot_alpha_bar']
    z_bar = init['z_bar']
    damping = init['damping']
    sum_approximation_method = init['sum_approximation_method']
    A_frob_sq = init['A_frob_sq']
    output_channel = init['output_channel']
    output_channel_parameters = init['output_channel_parameters']
    input_channel = init['input_channel']
    input_channel_parameters = init['input_channel_parameters']
    stop_criterion = init['stop_criterion']
    stop_criterion_name = init['stop_criterion_name']
    iterations = init['iterations']
    tolerance = init['tolerance']
    convert = init['convert']
    report_history = init['report_history']
    history = init['history']
    true_solution = init['true_solution']

    # GAMP iterations
    for it in range(iterations):
        # Save previous state
        alpha_bar_prev = alpha_bar  # Used in stop criterion

        # GAMP state updates
        if sum_approximation_method == 'rangan':
            # Rangan's scalar variance sum approximation.

            # Factor side updates
            v = 1.0 / m * A_frob_sq * alpha_breve
            o = A_dot_alpha_bar - v * q
            z_bar, z_tilde = output_channel.compute(locals())
            q = (z_bar - o) / v
            u = 1.0 / m * np.sum((v - z_tilde) / v**2)

            # Variable side updates (including damping)
            s_full = 1 / (1.0 / n * A_frob_sq * u)
            s = (1 - damping) * s_full + damping * s
            r = alpha_bar + s * AH.dot(q)
            alpha_bar_full, alpha_tilde = input_channel.compute(locals())
            alpha_bar = (1 - damping) * alpha_bar_full + damping * alpha_bar
            alpha_breve = 1.0 / n * np.sum(alpha_tilde)

        else:
            # Either "full" GAMP or Krzakala's sum approximation.
            # For "full" GAMP, A_asq is the entrywise absolute value squared
            # matrix.
            # For Krzakala's sum approximation, A_asq is a
            # magni.utils.matrices.SumApproximationMatrix that implements the
            # scaled sum approximation to the full A_asq matrix.

            # Factor side updates
            v = A_asq.dot(alpha_tilde)
            o = A_dot_alpha_bar - v * q
            z_bar, z_tilde = output_channel.compute(locals())
            q = (z_bar - o) / v
            u = (v - z_tilde) / v**2

            # Variable side updates (including damping)
            s_full = 1 / (AT_asq.dot(u))
            s = (1 - damping) * s_full + damping * s
            r = alpha_bar + s * AH.dot(q)
            alpha_bar_full, alpha_tilde = input_channel.compute(locals())
            alpha_bar = (1 - damping) * alpha_bar_full + damping * alpha_bar

        # Stop criterion
        A_dot_alpha_bar = A.dot(alpha_bar)  # Used in residual stop criteria
        stop, stop_criterion_value = stop_criterion.compute(locals())

        # History reporting
        if report_history:
            history['alpha_bar'].append(alpha_bar)
            history['alpha_tilde'].append(alpha_tilde)
            history['input_channel_parameters'].append(
                _clean_vars(copy.deepcopy(vars(input_channel))))
            history['output_channel_parameters'].append(
                _clean_vars(copy.deepcopy(vars(output_channel))))
            history['stop_criterion_value'].append(stop_criterion_value)
            history['stop_iteration'] = it
            if true_solution is not None:
                history['MSE'].append(
                    1/n * np.linalg.norm(true_solution - alpha_bar)**2)

        if stop:
            history['stop_reason'] = stop_criterion_name.upper()
            break

    alpha = alpha_bar

    validate_output()

    if report_history:
        return alpha, history
    else:
        return alpha
def _clean_vars(vars_dict):
    """
    Return a cleaned-up vars dict.

    All private variables are removed and magni configgers are converted to
    dictionaries.

    Parameters
    ----------
    vars_dict : dict
        The dictionary of variables to clean.

    Returns
    -------
    cleaned_vars : dict
        The cleaned dictionary of variables.

    """
    cleaned_vars = dict()

    for key, val in vars_dict.items():
        # Keys with a leading underscore are considered private and dropped.
        if not key.startswith('_'):
            if isinstance(val, _Configger):
                # Snapshot the configger state as a plain dict so the history
                # does not hold a live, mutable configuration object.
                cleaned_vars[key] = dict(val)
            else:
                cleaned_vars[key] = val

    return cleaned_vars
def _get_gamp_initialisation(y, A, A_asq):
    """
    Return an initialisation of the GAMP algorithm.

    Parameters
    ----------
    y : ndarray
        The m x 1 measurement vector.
    A : ndarray or magni.utils.matrices.{Matrix, MatrixCollection}
        The m x n matrix which is the product of the measurement matrix and
        the dictionary matrix.
    A_asq : ndarray or magni.utils.matrices.{Matrix, MatrixCollection} or None
        The m x n matrix which is the entrywise absolute value squared product
        of the measurement matrix and the dictionary matrix (the default is
        None, which implies that a sum approximation is used).

    Returns
    -------
    init : dict
        The initialisation of the GAMP algorithm.

    """
    init = dict()
    param = dict(_conf.items())

    # Configured setup
    init['convert'] = param['precision_float']
    init['damping'] = init['convert'](param['damping'])
    init['report_history'] = param['report_history']
    init['tolerance'] = param['tolerance']
    init['iterations'] = param['iterations']
    init['true_solution'] = param['true_solution']
    init['output_channel_parameters'] = param['output_channel_parameters']
    init['input_channel_parameters'] = param['input_channel_parameters']
    init['stop_criterion_name'] = param['stop_criterion'].__name__

    # Arguments based configuration
    init['AH'] = A.conj().T
    init['m'], init['n'] = A.shape
    init['A_frob_sq'] = None
    init['sum_approximation_method'] = None

    if A_asq is None:
        # No explicit |A|^2 given: fall back to a sum approximation. The
        # configuration holds a single {method: constant} mapping.
        init['sum_approximation_method'], sum_approx_const = tuple(
            param['sum_approximation_constant'].items())[0]

        # Krzakala's sum approx
        init['A_asq'] = _SumApproximationMatrix(sum_approx_const)
        init['AT_asq'] = _SumApproximationMatrix(sum_approx_const)

        # Rangan's sum approx
        if init['sum_approximation_method'] == 'rangan':
            # _norm generally forms A explicitly to compute the Frobenius norm
            # which may be infeasible memory-wise
            init['A_frob_sq'] = _norm(A, ord='fro')**2
    else:
        # Full GAMP
        init['A_asq'] = A_asq
        init['AT_asq'] = A_asq.T

    # Channel and stop criterion configuration
    channel_init = copy.copy(init)
    channel_init['A'] = A
    channel_init['y'] = y
    init['input_channel'] = param['input_channel'](channel_init)
    init['output_channel'] = param['output_channel'](channel_init)
    init['stop_criterion'] = param['stop_criterion'](channel_init)

    # Configuration of remaining states
    if param['warm_start'] is not None:
        init['alpha_bar'] = init['convert'](param['warm_start'][0])
        init['alpha_tilde'] = init['convert'](param['warm_start'][1])
    else:
        init['alpha_bar'] = np.zeros((init['n'], 1), dtype=init['convert'])
        init['alpha_tilde'] = np.ones_like(init['alpha_bar'])
    init['alpha_breve'] = 1.0 / init['n'] * np.sum(init['alpha_tilde'])
    init['history'] = {'alpha_bar': [init['alpha_bar']],
                       'alpha_tilde': [init['alpha_tilde']],
                       'MSE': [np.nan],
                       'input_channel_parameters': [
                           init['input_channel_parameters']],
                       'output_channel_parameters': [
                           init['output_channel_parameters']],
                       'stop_criterion': init['stop_criterion_name'].upper(),
                       'stop_criterion_value': [np.nan],
                       'stop_iteration': 0,
                       'stop_reason': 'MAX_ITERATIONS'}
    init['A_dot_alpha_bar'] = A.dot(init['alpha_bar'])
    init['q'] = np.zeros_like(y)
    init['o'] = y.copy()
    init['s'] = np.ones_like(init['alpha_bar'])
    init['r'] = np.zeros_like(init['alpha_bar'])
    init['z_bar'] = y

    # Report on A_asq setup
    if param['report_A_asq_setup']:
        # Temporarily compact numpy's print options so large matrices are
        # summarised; restore the caller's options afterwards.
        np_printoptions = np.get_printoptions()
        np.set_printoptions(precision=3, threshold=5, edgeitems=1)
        print('GAMP is using the A_asq: {}'.format(init['A_asq']))
        print('The sum approximation method is: {}'.format(
            init['sum_approximation_method']))
        np.set_printoptions(**np_printoptions)

    return init