Source code for magni.imaging.visualisation

"""
..
    Copyright (c) 2014-2017, Magni developers.
    All rights reserved.
    See LICENSE.rst for further information.

Module providing functionality for visualising images.

The module provides functionality for adjusting the intensity of an image. It
provides a wrapper of the `matplotlib.pyplot.imshow` function that may exploit
the provided functions for adjusting the image intensity. Also it include a
function may be used to display a set of related images using a common
colormapping.

Routine listings
----------------
imshow(X, ax=None, intensity_func=None, intensity_args=(), \*\*kwargs)
    Function that may be used to display an image.
imsubplot(imgs, rows, titles=None, x_labels=None, y_labels=None,
    x_ticklabels=None, y_ticklabels=None, cbar_label=None,
    normalise=True, fixed_clim=None, \*\*kwargs)
    Function that may be used to display a set of related images.
mask_img_from_coords(img, coords)
    Function for masking certain parts of an image based on coordinates.
shift_mean(x_mod, x_org)
    Function for shifting mean intensity of an image based on another image.
stretch_image(img, max_val, min_val=0)
    Function for stretching the intensity of an image.

"""

from __future__ import division

import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt

from magni.utils import plotting as _plotting
from magni.utils.validation import decorate_validation as _decorate_validation
from magni.utils.validation import validate_generic as _generic
from magni.utils.validation import validate_levels as _levels
from magni.utils.validation import validate_numeric as _numeric


[docs]def imshow(X, ax=None, intensity_func=None, intensity_args=(), show_axis='frame', **kwargs): """ Display an image. Wrap `matplotlib.pyplot.imshow` to display a possibly intensity manipulated version of the image `X`. Parameters ---------- X : ndarray The image to be displayed. ax : matplotlib.axes.Axes, optional The axes on which the image is displayed (the default is None, which implies that the current axes is used). intensity_func : FunctionType, optional The handle to the function used to manipulate the image intensity before the image is displayed (the default is None, which implies that no intensity manipulation is used). intensity_args : list or tuple, optional The arguments that are passed to the `intensity_func` (the default is (), which implies that no arguments are passed). show_axis : {'none', 'top', 'inherit', 'frame'} How the x- and y-axis are display. If 'none', no axis are displayed. If 'top', the x-axis is displayed at the top of the image. If 'inherit', the axis display is inherited from `matplotlib.pyplot.imshow`. If 'frame' only the frame is shown and not the ticks. Returns ------- im_out : matplotlib.image.AxesImage The AxesImage returned by matplotlibs imshow. See Also -------- matplotlib.pyplot.imshow : Matplotlib's imshow function. Examples -------- For example, >>> import numpy as np >>> from magni.imaging.visualisation import imshow >>> X = np.arange(4).reshape(2, 2) >>> add_k = lambda X, k: X + k >>> im_out = imshow(X, intensity_func=add_k, intensity_args=(2,)) """ @_decorate_validation def validate_input(): _numeric('X', ('boolean', 'integer', 'floating'), shape=(-1, -1)) _generic('ax', mpl.axes.Axes, ignore_none=True) _generic('intensity_func', 'function', ignore_none=True) _generic('intensity_args', 'explicit collection') _generic('show_axis', 'string', value_in=('none', 'top', 'inherit', 'frame')) validate_input() _plotting.setup_matplotlib() # Handle ax keyword argument ca = plt.gca() if ax is not None: plt.sca(ax) # Intensity manipulation if intensity_func is not None: im_out = plt.imshow(intensity_func(X, *intensity_args), **kwargs) else: im_out = plt.imshow(X, **kwargs) # Display of axis axes = plt.gca() if show_axis == 'none': axes.axis('off') elif show_axis == 'top': axes.xaxis.tick_top() axes.xaxis.set_label_position('top') elif show_axis == 'frame': xlabels = axes.get_xticklabels() ylabels = axes.get_yticklabels() empty_xlabels = ['']*len(xlabels) empty_ylabels = ['']*len(ylabels) axes.set_xticklabels(empty_xlabels) axes.set_yticklabels(empty_ylabels) axes.tick_params(length=0) plt.sca(ca) return im_out
[docs]def imsubplot(imgs, rows, titles=None, x_labels=None, y_labels=None, x_ticklabels=None, y_ticklabels=None, cbar_label=None, normalise=True, fixed_clim=None, **kwargs): """ Display a set of related images as subplots in a figure. The images `imgs` are shown in a figure with a subplot layout based on the number of `rows`. The `titles`, `x_labels`, `y_labels`, `x_ticklabels`, and `y_ticklabels` are shown in the subplots. If `normalise` is True, all the images will share the same normalised colorbar/colormapping i.e. a particular colour will correspond to the same value across all images. Parameters ---------- imgs : list or tuple The images (as ndarrays) that is to be displayed. rows : int The number of rows to use in the subplot layout. titles : list or tuple The titles (as strings) to use for each of the subplots (the default is None, which implies that no titles are displayed). x_labels : list or tuple The x_labels (as strings) to use for each of the subplots (the default is None, which implies that no x_labels are displayed). y_labels : list or tuple The y_labels (as strings) to use for each of the subplots (the default is None, which implies that no y_labels are displayed). x_ticklabels : list or tuple The x_ticklabels (as strings or lists of strings) to use for the subplots (the default is None, which implies that no x_ticklabels are displayed). y_ticklabels : list or tuple The y_ticklabels (as strings or lists of strings) to use for the subplots (the default is None, which implies that no y_ticklabels are displayed). cbar_label : str The colorbar label to use with a normalised colormapping (the default is None, which implies that no colorbar label is displayed). fixed_clim : list or tuple The colorbar limits as a (min, max) sequence (the default is None, which implies that the colorbar limits are inferred from the data). normalise : bool The flag that indicates whether to use a normalised colormapping. Returns ------- fig : matplotlib.figure.Figure The resulting figure instance. See Also -------- matplotlib.pyplot.subplots : Underlying subplot function. Notes ----- The `x_ticklabels` and `y_ticklabels` may be either a collection of strings or a collections of collections of strings depending on wherther the labels should be shared across all subplots or different labels are to be used for each subplot. Additional kwargs given to the function will be passed to the underlying suplot instantiation function `matplotlib.pyplot.subplots`. If `normalise` is True, the common colorbar is shown below the subplots. The implementation of the normalisation feature is based on the Pylab example: http://matplotlib.org/examples/pylab_examples/multi_image.html. Examples -------- For example, show two images next to each other with a common colormapping: >>> import numpy as np >>> from magni.imaging.visualisation import imsubplot >>> img1 = np.arange(4).reshape(2, 2) >>> img2 = np.ones((4, 4)) >>> fig = imsubplot([img1, img2], 1, titles=['arange', 'ones'], ... x_labels=['x1', 'x2'], y_labels=['y1', 'y2'], cbar_label='Example', ... normalise=True) or show the same images with shared ticklabels and fixed colorbar limits: >>> common_x_ticklabels = ['a', 'b'] >>> common_y_ticklabels = ['c', 'd'] >>> fig = imsubplot([img1, img2], 1, x_ticklabels=common_x_ticklabels, ... y_ticklabels=common_y_ticklabels, fixed_clim=(0, 2), normalise=True) or show the same images with different colormappings and ticklabels: >>> x_ticklabels = [['a', 'b'], ['aa', 'bb']] >>> y_ticklabels = [['c', 'd'], ['cc', 'dd']] >>> fig = imsubplot([img1, img2], 1, x_ticklabels=x_ticklabels, ... y_ticklabels=y_ticklabels, normalise=False) """ @_decorate_validation def validate_input(): _levels('imgs', (_generic(None, 'explicit collection'), _numeric(None, ('boolean', 'integer', 'floating'), shape=(-1, -1)))) _numeric('rows', 'integer', range_='[1;inf)') _levels('titles', (_generic(None, 'explicit collection', len_=len(imgs), ignore_none=True), _generic(None, 'string'))) _levels('x_labels', (_generic(None, 'explicit collection', len_=len(imgs), ignore_none=True), _generic(None, 'string'))) _levels('y_labels', (_generic(None, 'explicit collection', len_=len(imgs), ignore_none=True), _generic(None, 'string'))) _levels('x_ticklabels', (_generic(None, 'explicit collection', ignore_none=True), _generic(None, ('explicit collection', 'string')), _generic(None, 'string'))) _levels('y_ticklabels', (_generic(None, 'explicit collection', ignore_none=True), _generic(None, ('explicit collection', 'string')), _generic(None, 'string'))) _generic('cbar_label', 'string', ignore_none=True) _levels('fixed_clim', (_generic(None, 'explicit collection', len_=2, ignore_none=True), _numeric(None, ('integer', 'floating')))) _numeric('normalise', 'boolean') validate_input() cols = max(1, np.int(np.ceil(len(imgs) / rows))) fig, axes = plt.subplots(rows, cols, squeeze=False, **kwargs) axs = axes.ravel() vmin = 1e40 vmax = -1e40 ims = [] for k, img in enumerate(imgs): # Show image ims.append(imshow(img, ax=axs[k])) # Handle titles, labels, etc. fig_strings = (titles, x_labels, y_labels) handles = (axs[k].set_title, axs[k].set_xlabel, axs[k].set_ylabel) for fig_string, handle in zip(fig_strings, handles): if fig_string is not None: handle(fig_string[k]) _handle_ticklabels(axs[k], k, x_ticklabels, y_ticklabels) # Track minimum and maximum data values vmin = min(vmin, np.amin(img)) # Find common minimum vmax = max(vmax, np.amax(img)) # Find common maximum if normalise: # Connect a Tracker to each image in order to update colormap limits # This way all images share common colormap limits common_norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax) ims[0].set_norm(common_norm) for k in range(1, len(ims)): ims[k].set_norm(common_norm) ims[0].callbacksSM.connect('changed', _ImageColourTracker(ims[k])) # Add common colorbar plt.subplots_adjust(bottom=0.2) c_bar_ax = fig.add_axes([0.2, 0.08, 0.6, 0.04]) c_bar = fig.colorbar(ims[0], c_bar_ax, orientation='horizontal') c_bar.solids.set_edgecolor("face") if fixed_clim is not None: # Force colorbar limits c_bar.set_clim(fixed_clim) if cbar_label is not None: # Set colorbar label c_bar.set_label(cbar_label) return fig
[docs]def mask_img_from_coords(img, coords): """ Mask coordinates in an image. The coordinates `coords` in the image `img` are masked such that only the cooordinates are shown. Parameters ---------- img : ndarray The image to mask. coords : ndarray The coordinates arranged into a 2D array, such that each row is a coordinate pair (x, y). Returns ------- masked_img : numpy.ma.MaskedArray The masked image. See Also -------- magni.imaging.measurements : Further description of the coordinate format. Examples -------- For example, display only center pixel in a 3-by-3 image >>> import numpy as np >>> from magni.imaging.visualisation import mask_img_from_coords >>> img = np.arange(9).reshape(3, 3) >>> coords = np.array([[1, 1]]) >>> mask_img_from_coords(img, coords) masked_array(data = [[-- -- --] [-- 4 --] [-- -- --]], mask = [[ True True True] [ True False True] [ True True True]], fill_value = 999999) <BLANKLINE> """ @_decorate_validation def validate_input(): _numeric('img', ('integer', 'floating', 'complex'), shape=(-1, -1)) _numeric('coords', 'integer', shape=(-1, 2), range_='[0;inf)') validate_input() mask = np.ones_like(img, dtype=np.bool_) mask[coords[:, 1], coords[:, 0]] = False masked_img = np.ma.array(img, mask=mask) return masked_img
[docs]def shift_mean(x_mod, x_org): """ Shift the mean value of `x_mod` such that it equals the mean of `x_org`. Parameters ---------- x_org : ndarray The array which hold the "true" mean value. x_mod : ndarray The modified copy of `x_org` which must have its mean value shifted. Returns ------- shifted_x_mod : ndarray A copy of `x_mod` with the same mean value as `x_org`. Examples -------- For example, >>> import numpy as np >>> from magni.imaging.visualisation import shift_mean >>> x_org = np.arange(4).reshape(2, 2) >>> x_mod = np.ones((2, 2)) >>> print('{:.1f}'.format(x_org.mean())) 1.5 >>> print('{:.1f}'.format(x_mod.mean())) 1.0 >>> shifted_x_mod = shift_mean(x_mod, x_org) >>> print('{:.1f}'.format(shifted_x_mod.mean())) 1.5 >>> np.set_printoptions(suppress=True) >>> shifted_x_mod array([[ 1.5, 1.5], [ 1.5, 1.5]]) """ @_decorate_validation def validate_input(): _numeric('x_mod', ('integer', 'floating', 'complex'), shape=(-1, -1)) _numeric('x_org', ('integer', 'floating', 'complex'), shape=x_mod.shape) validate_input() return x_mod + (x_org.mean() - x_mod.mean())
[docs]def stretch_image(img, max_val, min_val=0): """ Stretch image such that pixels values are in [`min_val`, `max_val`]. Parameters ---------- img : ndarray The (float) image that is to be stretched. max_val : int or float The maximum value in the stretched image. min_val : int or float The minimum value in the stretched image. Returns ------- stretched_img : ndarray A stretched copy of the input image. Notes ----- The pixel values in the input image are scaled to lie in the interval [`min_val`, `max_val`] using a linear stretch. Examples -------- For example, stretch an image between 0 and 1 >>> import numpy as np >>> from magni.imaging.visualisation import stretch_image >>> img = np.arange(4, dtype=np.float).reshape(2, 2) >>> stretched_img = stretch_image(img, 1) >>> np.set_printoptions(suppress=True) >>> stretched_img array([[ 0. , 0.33333333], [ 0.66666667, 1. ]]) or stretch the image between -1 and 1 >>> stretched_img = stretch_image(img, 1.0, min_val=-1.0) >>> stretched_img array([[-1. , -0.33333333], [ 0.33333333, 1. ]]) or re-stretch the strecthed image between -3.0 and -1.5 >>> stretched_img = stretch_image(stretched_img, -1.5, min_val=-3.0) >>> stretched_img array([[-3. , -2.5], [-2. , -1.5]]) or re-stretch that image between 1.25 and 8.00 >>> stretched_img = stretch_image(stretched_img, 8.00, min_val=1.25) >>> stretched_img array([[ 1.25, 3.5 ], [ 5.75, 8. ]]) """ @_decorate_validation def validate_input(): _numeric('img', 'floating', shape=(-1, -1)) _numeric('max_val', ('integer', 'floating')) _numeric('min_val', ('integer', 'floating')) if not max_val > min_val: msg = 'max_val ({!r}) must be larger than min_val ({!r})' raise ValueError(msg.format(max_val, min_val)) validate_input() min_ = img.min() max_ = img.max() a = (max_val - min_val) / (max_ - min_) b = -a * min_ + min_val return a * img + b
[docs]def _handle_ticklabels(ax, k, x_ticklabels, y_ticklabels): """ Handle and format ticks and ticklabels for use in imsubplot. The `imsubplot` function creates a figure showing an abitrary number of subplots. The `imsubplot` function allows for custom ticklabels along the x- and y-axes. This function handles the formatting of the ticklabels for a given subplot. Parameters ---------- ax : matplotlib.axes.Axes The matplotlib axes (subplot) to format ticklabels for. k : int The axes index, i.e. the subplot number out of the total number of subplots. x_ticklabels : list or tuple The x_ticklabels (as strings or lists of strings) to use for the subplots (the default is None, which implies that no x_ticklabels are displayed). y_ticklabels : list or tuple The y_ticklabels (as strings or lists of strings) to use for the subplots (the default is None, which implies that no y_ticklabels are displayed). """ if x_ticklabels is not None: if isinstance(x_ticklabels[0], (list, tuple)): x_t = range(len(x_ticklabels[k])) x_tl = x_ticklabels[k] else: # Backwards compatibility # Allow for sharing a single set of x_ticklabels across all images. x_t = range(len(x_ticklabels)) x_tl = x_ticklabels # Set xticks and xticklabels ax.set_xticks(x_t) ax.set_xticklabels(x_tl, rotation=90) if y_ticklabels is not None: if isinstance(y_ticklabels[0], (list, tuple)): y_t = range(len(y_ticklabels[k])) y_tl = y_ticklabels[k] else: # Backwards compatibility # Allow for sharing a single set of y_ticklabels across all images. y_t = range(len(y_ticklabels)) y_tl = y_ticklabels # Set yticks and yticklabels ax.set_yticks(y_t) ax.set_yticklabels(y_tl)
[docs]class _ImageColourTracker(): """ Track a common 'clim' in a set of matplotlib image subplots. Parameters ---------- tracker : matplotlib.image.AxesImage The image instance that must track a given 'clim'. """ def __init__(self, tracker): self.tracker = tracker
[docs] def __call__(self, tracked): self.tracker.set_clim(tracked.get_clim())