"""
..
    Copyright (c) 2016-2017, Magni developers.
    All rights reserved.
    See LICENSE.rst for further information.

Module providing threshold functions for the Approximate Message Passing (AMP)
algorithm.

Routine listings
----------------
ValidatedThresholdOperator(magni.utils.validation.types.ThresholdOperator)
    A base class for validated `magni.cs.reconstruction.amp` threshold
    operators.
SoftThreshold(ValidatedThresholdOperator)
    A soft threshold operator.

"""
from __future__ import division
import numpy as np
import scipy.stats
from magni.utils.validation import decorate_validation as _decorate_validation
from magni.utils.validation import validate_generic as _generic
from magni.utils.validation import validate_numeric as _numeric
from magni.utils.validation import validate_once as _validate_once
from magni.utils.validation.types import (
ThresholdOperator as _ThresholdOperator)
class ValidatedThresholdOperator(_ThresholdOperator):
    """
    A base class for validated `magni.cs.reconstruction.amp` threshold
    operators.

    Every method of this class only validates that its `var` argument is a
    mapping; it contains no threshold logic of its own. Concrete threshold
    operators subclass it, override the methods, and call the base versions
    through `super()` to get the input validation.

    Parameters
    ----------
    var : dict
        The threshold operator state variables.

    Notes
    -----
    The validation helpers are given the *name* ``'var'`` rather than the
    object itself. Presumably the magni validation machinery resolves that
    name in the calling method's scope, so the parameter must keep the name
    `var` in all methods (TODO: confirm against `magni.utils.validation`).

    """

    def __init__(self, var):
        super(ValidatedThresholdOperator, self).__init__(var)

        @_decorate_validation
        def validate_input():
            _generic('var', 'mapping')

        validate_input()

    @_validate_once
    def compute_deriv_threshold(self, var):
        """
        Compute the entrywise derivative threshold.

        This base implementation only validates `var`. Subclasses override
        it, call it via `super()`, and return the actual derivative.

        Parameters
        ----------
        var : dict
            The variables used in computing the derivative threshold.

        Returns
        -------
        eta_deriv : ndarray
            The computed entrywise derivative threshold. This is returned by
            the overriding subclass method; the body of this base
            implementation has no return statement.

        Notes
        -----
        This method honors `magni.utils.validation.enable_allow_validate_once`.

        """

        @_decorate_validation
        def validate_input():
            _generic('var', 'mapping')

        validate_input()

    @_validate_once
    def compute_threshold(self, var):
        """
        Compute the entrywise threshold.

        This base implementation only validates `var`. Subclasses override
        it, call it via `super()`, and return the actual thresholded values.

        Parameters
        ----------
        var : dict
            The variables used in computing the threshold.

        Returns
        -------
        eta : ndarray
            The computed entrywise threshold. This is returned by the
            overriding subclass method; the body of this base implementation
            has no return statement.

        Notes
        -----
        This method honors `magni.utils.validation.enable_allow_validate_once`.

        """

        @_decorate_validation
        def validate_input():
            _generic('var', 'mapping')

        validate_input()

    @_validate_once
    def update_threshold_level(self, var):
        """
        Update the threshold level state.

        This base implementation only validates `var`. Subclasses override
        it, call it via `super()`, and update their own threshold level
        state.

        Parameters
        ----------
        var : dict
            The variables used in computing the threshold level update.

        Notes
        -----
        This method honors `magni.utils.validation.enable_allow_validate_once`.

        """

        @_decorate_validation
        def validate_input():
            _generic('var', 'mapping')

        validate_input()
class SoftThreshold(ValidatedThresholdOperator):
    """
    A soft threshold operator.

    This soft threshold operator is based on the description of it and its use
    in AMP as given in [1]_ with corrections from [2]_.

    Parameters
    ----------
    threshold_level_update_method : {'residual', 'median'}
        The method to use for updating the threshold level.
    theta : float
        The tunable regularisation parameter in the threshold level.
    tau_hat_sq : float
        The mean squared error of the (approximated) un-thresholded
        estimate used to determine the threshold level.

    Notes
    -----
    The above Parameters are the threshold parameters that must be passed
    in a `var` dict to the threshold constructor. They go under the key
    ``'threshold_parameters'``. The constructor also reads ``var['y']``
    (only its first dimension, `m`, is used) and ``var['convert']`` (a
    callable applied to stored state values).

    With ``t = theta * sqrt(tau_hat_sq)``, the soft threshold of a value
    `op` is ``op - t`` if ``op > t``, ``op + t`` if ``op < -t`` and ``0``
    otherwise.

    References
    ----------
    .. [1] A. Montanari, "Graphical models concepts in compressed sensing" *in
       Compressed Sensing: Theory and Applications*, Y. C. Eldar and
       G. Kutyniok (Ed.), Cambridge University Press, ch. 9, pp. 394-438, 2012.
    .. [2] J. T. Parker, "Approximate Message Passing Algorithms for
       Generalized Bilinear Inference", PhD Thesis, Graduate School of The Ohio
       State University, 2014

    """

    def __init__(self, var):
        super(SoftThreshold, self).__init__(var)

        @_decorate_validation
        def validate_threshold_parameters():
            _generic(('var', 'threshold_parameters'), 'mapping')
            _generic(('var', 'threshold_parameters',
                      'threshold_level_update_method'), 'string',
                     value_in=['residual', 'median'])
            _numeric(('var', 'threshold_parameters', 'theta'),
                     ('integer', 'floating'), range_='(0;inf)')
            _numeric(('var', 'threshold_parameters', 'tau_hat_sq'),
                     ('integer', 'floating'), range_='[0;inf)')

        validate_threshold_parameters()

        t_params = var['threshold_parameters']
        self.theta = t_params['theta']  # alpha in Eq. (9.44) in [1].
        self.tau_hat_sq = var['convert'](t_params['tau_hat_sq'])
        self.update_method = t_params['threshold_level_update_method']
        # Number of measurements; normaliser of the 'residual' update.
        self.m = var['y'].shape[0]
        # Phi^{-1}(3/4) ~= 0.6745, the upper quartile of the standard normal.
        # median(|x|) / Phi^{-1}(3/4) is the classic median-absolute-value
        # estimate of the standard deviation of zero-mean Gaussian samples.
        self.stdQ1 = var['convert'](scipy.stats.norm.ppf(1 - 0.25))

    def _threshold_level(self):
        """Return the current threshold level ``theta * sqrt(tau_hat_sq)``."""
        return self.theta * np.sqrt(self.tau_hat_sq)

    def compute_deriv_threshold(self, var):
        """
        Compute the entrywise derivative soft threshold.

        The derivative of the soft threshold is 1 where ``|op| > t`` and 0
        elsewhere. It is evaluated at ``op = alpha_bar_prev + AH_dot_chi``,
        i.e. at the *previous* estimate rather than the current one.

        Parameters
        ----------
        var : dict
            The variables used in computing the derivative threshold. The
            keys ``'alpha_bar_prev'``, ``'AH_dot_chi'`` and ``'convert'`` are
            read.

        Returns
        -------
        eta_deriv : ndarray
            The computed entrywise derivative threshold.

        """

        super(SoftThreshold, self).compute_deriv_threshold(var)

        op = var['alpha_bar_prev'] + var['AH_dot_chi']
        thres = self._threshold_level()
        # Boolean '+' acts as a logical OR of the two tail masks; 'convert'
        # then casts the mask to the working numeric type.
        eta_deriv = var['convert']((op > thres) + (op < -thres))

        return eta_deriv

    def compute_threshold(self, var):
        """
        Compute the entrywise soft threshold.

        The soft threshold is applied to ``op = alpha_bar + AH_dot_chi``.
        Entries with ``op > t`` become ``op - t``, entries with ``op < -t``
        become ``op + t``, and all remaining entries become 0.

        Parameters
        ----------
        var : dict
            The variables used in computing the threshold. The keys
            ``'alpha_bar'``, ``'AH_dot_chi'`` and ``'convert'`` are read.

        Returns
        -------
        eta : ndarray
            The computed entrywise threshold.

        """

        super(SoftThreshold, self).compute_threshold(var)

        op = var['alpha_bar'] + var['AH_dot_chi']
        thres = self._threshold_level()
        # Multiplying by the boolean masks zeroes every entry outside the
        # corresponding tail, so at most one term is non-zero per entry.
        eta = var['convert'](
            ((op - thres) * (op > thres) + (op + thres) * (op < -thres)))

        return eta

    def update_threshold_level(self, var):
        """
        Update the threshold level state.

        Sets `tau_hat_sq` from the residual ``var['chi']`` using the method
        selected at construction:

        * 'residual': ``||chi||_2^2 / m``
        * 'median': ``(median(|chi|) / Phi^{-1}(3/4))^2``

        Parameters
        ----------
        var : dict
            The variables used in computing the threshold level update. The
            key ``'chi'`` is read.

        """

        super(SoftThreshold, self).update_threshold_level(var)

        chi = var['chi']

        # 'update_method' is validated in __init__ to be one of these two.
        if self.update_method == 'residual':
            # Eq. (9.44) in [1]
            self.tau_hat_sq = 1.0 / self.m * np.linalg.norm(chi)**2
        elif self.update_method == 'median':
            # Eq. (9.45) in [1] corrected according to [2]
            self.tau_hat_sq = (
                1.0 / self.stdQ1 * np.median(np.abs(chi)))**2