from __future__ import absolute_import, division, unicode_literals
import re
import warnings
import numpy as np
import matplotlib
from matplotlib import units as munits
from matplotlib import ticker
from matplotlib.colors import cnames
from matplotlib.lines import Line2D
from matplotlib.markers import MarkerStyle
from matplotlib.patches import Path, PathPatch
from matplotlib.transforms import Bbox, TransformedBbox, Affine2D
from matplotlib.rcsetup import (
validate_capstyle, validate_fontsize, validate_fonttype, validate_hatch,
validate_joinstyle)
try:
from nc_time_axis import NetCDFTimeConverter, CalendarDateTime
nc_axis_available = True
except:
from matplotlib.dates import DateConverter
NetCDFTimeConverter = DateConverter
nc_axis_available = False
from ...core.util import (
LooseVersion, _getargspec, arraylike_types, basestring,
cftime_types, is_number,)
from ...element import Raster, RGB, Polygons
from ..util import COLOR_ALIASES, RGB_HEX_REGEX
mpl_version = LooseVersion(matplotlib.__version__)
[docs]def is_color(color):
"""
Checks if supplied object is a valid color spec.
"""
if not isinstance(color, basestring):
return False
elif RGB_HEX_REGEX.match(color):
return True
elif color in COLOR_ALIASES:
return True
elif color in cnames:
return True
return False
validators = {
'alpha': lambda x: is_number(x) and (0 <= x <= 1),
'capstyle': validate_capstyle,
'color': is_color,
'fontsize': validate_fontsize,
'fonttype': validate_fonttype,
'hatch': validate_hatch,
'joinstyle': validate_joinstyle,
'marker': lambda x: (x in Line2D.markers or isinstance(x, MarkerStyle)
or isinstance(x, Path) or
(isinstance(x, basestring) and x.startswith('$')
and x.endswith('$'))),
's': lambda x: is_number(x) and (x >= 0)
}
def get_old_rcparams():
deprecated_rcparams = [
'text.latex.unicode',
'examples.directory',
'savefig.frameon', # deprecated in MPL 3.1, to be removed in 3.3
'verbose.level', # deprecated in MPL 3.1, to be removed in 3.3
'verbose.fileo', # deprecated in MPL 3.1, to be removed in 3.3
'datapath', # deprecated in MPL 3.2.1, to be removed in 3.3
'text.latex.preview', # deprecated in MPL 3.3.1
'animation.avconv_args', # deprecated in MPL 3.3.1
'animation.avconv_path', # deprecated in MPL 3.3.1
'animation.html_args', # deprecated in MPL 3.3.1
'keymap.all_axes', # deprecated in MPL 3.3.1
'savefig.jpeg_quality' # deprecated in MPL 3.3.1
]
old_rcparams = {
k: v for k, v in matplotlib.rcParams.items()
if mpl_version < '3.0' or k not in deprecated_rcparams
}
return old_rcparams
def get_validator(style):
for k, v in validators.items():
if style.endswith(k) and (len(style) != 1 or style == k):
return v
[docs]def validate(style, value, vectorized=True):
"""
Validates a style and associated value.
Arguments
---------
style: str
The style to validate (e.g. 'color', 'size' or 'marker')
value:
The style value to validate
vectorized: bool
Whether validator should allow vectorized setting
Returns
-------
valid: boolean or None
If validation is supported returns boolean, otherwise None
"""
validator = get_validator(style)
if validator is None:
return None
if isinstance(value, arraylike_types+(list,)) and vectorized:
return all(validator(v) for v in value)
try:
valid = validator(value)
return False if valid == False else True
except:
return False
[docs]def filter_styles(style, group, other_groups, blacklist=[]):
"""
Filters styles which are specific to a particular artist, e.g.
for a GraphPlot this will filter options specific to the nodes and
edges.
Arguments
---------
style: dict
Dictionary of styles and values
group: str
Group within the styles to filter for
other_groups: list
Other groups to filter out
blacklist: list (optional)
List of options to filter out
Returns
-------
filtered: dict
Filtered dictionary of styles
"""
group = group+'_'
filtered = {}
for k, v in style.items():
if (any(k.startswith(p) for p in other_groups)
or k.startswith(group) or k in blacklist):
continue
filtered[k] = v
for k, v in style.items():
if not k.startswith(group) or k in blacklist:
continue
filtered[k[len(group):]] = v
return filtered
def unpack_adjoints(ratios):
new_ratios = {}
offset = 0
for k, (num, ratios) in sorted(ratios.items()):
unpacked = [[] for _ in range(num)]
for r in ratios:
nr = len(r)
for i in range(num):
unpacked[i].append(r[i] if i < nr else np.nan)
for i, r in enumerate(unpacked):
new_ratios[k+i+offset] = r
offset += num-1
return new_ratios
def normalize_ratios(ratios):
normalized = {}
for i, v in enumerate(zip(*ratios.values())):
arr = np.array(v)
normalized[i] = arr/float(np.nanmax(arr))
return normalized
def compute_ratios(ratios, normalized=True):
unpacked = unpack_adjoints(ratios)
with warnings.catch_warnings():
warnings.filterwarnings('ignore', r'All-NaN (slice|axis) encountered')
if normalized:
unpacked = normalize_ratios(unpacked)
sorted_ratios = sorted(unpacked.items())
return np.nanmax(np.vstack([v for _, v in sorted_ratios]), axis=0)
[docs]def axis_overlap(ax1, ax2):
"""
Tests whether two axes overlap vertically
"""
b1, t1 = ax1.get_position().intervaly
b2, t2 = ax2.get_position().intervaly
return t1 > b2 and b1 < t2
[docs]def resolve_rows(rows):
"""
Recursively iterate over lists of axes merging
them by their vertical overlap leaving a list
of rows.
"""
merged_rows = []
for row in rows:
overlap = False
for mrow in merged_rows:
if any(axis_overlap(ax1, ax2) for ax1 in row
for ax2 in mrow):
mrow += row
overlap = True
break
if not overlap:
merged_rows.append(row)
if rows == merged_rows:
return rows
else:
return resolve_rows(merged_rows)
[docs]def fix_aspect(fig, nrows, ncols, title=None, extra_artists=[],
vspace=0.2, hspace=0.2):
"""
Calculate heights and widths of axes and adjust
the size of the figure to match the aspect.
"""
fig.canvas.draw()
w, h = fig.get_size_inches()
# Compute maximum height and width of each row and columns
rows = resolve_rows([[ax] for ax in fig.axes])
rs, cs = len(rows), max([len(r) for r in rows])
heights = [[] for i in range(cs)]
widths = [[] for i in range(rs)]
for r, row in enumerate(rows):
for c, ax in enumerate(row):
bbox = ax.get_tightbbox(fig.canvas.get_renderer())
heights[c].append(bbox.height)
widths[r].append(bbox.width)
height = (max([sum(c) for c in heights])) + nrows*vspace*fig.dpi
width = (max([sum(r) for r in widths])) + ncols*hspace*fig.dpi
# Compute aspect and set new size (in inches)
aspect = height/width
offset = 0
if title and title.get_text():
offset = title.get_window_extent().height/fig.dpi
fig.set_size_inches(w, (w*aspect)+offset)
# Redraw and adjust title position if defined
fig.canvas.draw()
if title and title.get_text():
extra_artists = [a for a in extra_artists
if a is not title]
bbox = get_tight_bbox(fig, extra_artists)
top = bbox.intervaly[1]
if title and title.get_text():
title.set_y((top/(w*aspect)))
[docs]def get_tight_bbox(fig, bbox_extra_artists=[], pad=None):
"""
Compute a tight bounding box around all the artists in the figure.
"""
renderer = fig.canvas.get_renderer()
bbox_inches = fig.get_tightbbox(renderer)
bbox_artists = bbox_extra_artists[:]
bbox_artists += fig.get_default_bbox_extra_artists()
bbox_filtered = []
for a in bbox_artists:
bbox = a.get_window_extent(renderer)
if isinstance(bbox, tuple):
continue
if a.get_clip_on():
clip_box = a.get_clip_box()
if clip_box is not None:
bbox = Bbox.intersection(bbox, clip_box)
clip_path = a.get_clip_path()
if clip_path is not None and bbox is not None:
clip_path = clip_path.get_fully_transformed_path()
bbox = Bbox.intersection(bbox,
clip_path.get_extents())
if bbox is not None and (bbox.width != 0 or
bbox.height != 0):
bbox_filtered.append(bbox)
if bbox_filtered:
_bbox = Bbox.union(bbox_filtered)
trans = Affine2D().scale(1.0 / fig.dpi)
bbox_extra = TransformedBbox(_bbox, trans)
bbox_inches = Bbox.union([bbox_inches, bbox_extra])
return bbox_inches.padded(pad) if pad else bbox_inches
[docs]def get_raster_array(image):
"""
Return the array data from any Raster or Image type
"""
if isinstance(image, RGB):
rgb = image.rgb
data = np.dstack([np.flipud(rgb.dimension_values(d, flat=False))
for d in rgb.vdims])
else:
data = image.dimension_values(2, flat=False)
if type(image) is Raster:
data = data.T
else:
data = np.flipud(data)
return data
[docs]def ring_coding(array):
"""
Produces matplotlib Path codes for exterior and interior rings
of a polygon geometry.
"""
# The codes will be all "LINETO" commands, except for "MOVETO"s at the
# beginning of each subpath
n = len(array)
codes = np.ones(n, dtype=Path.code_type) * Path.LINETO
codes[0] = Path.MOVETO
codes[-1] = Path.CLOSEPOLY
return codes
[docs]def polygons_to_path_patches(element):
"""
Converts Polygons into list of lists of matplotlib.patches.PathPatch
objects including any specified holes. Each list represents one
(multi-)polygon.
"""
paths = element.split(datatype='array', dimensions=element.kdims)
has_holes = isinstance(element, Polygons) and element.interface.has_holes(element)
holes = element.interface.holes(element) if has_holes else None
mpl_paths = []
for i, path in enumerate(paths):
splits = np.where(np.isnan(path[:, :2].astype('float')).sum(axis=1))[0]
arrays = np.split(path, splits+1) if len(splits) else [path]
subpath = []
for j, array in enumerate(arrays):
if j != (len(arrays)-1):
array = array[:-1]
if (array[0] != array[-1]).any():
array = np.append(array, array[:1], axis=0)
interiors = []
for interior in (holes[i][j] if has_holes else []):
if (interior[0] != interior[-1]).any():
interior = np.append(interior, interior[:1], axis=0)
interiors.append(interior)
vertices = np.concatenate([array]+interiors)
codes = np.concatenate([ring_coding(array)]+
[ring_coding(h) for h in interiors])
subpath.append(PathPatch(Path(vertices, codes)))
mpl_paths.append(subpath)
return mpl_paths
[docs]class CFTimeConverter(NetCDFTimeConverter):
"""
Defines conversions for cftime types by extending nc_time_axis.
"""
[docs] @classmethod
def convert(cls, value, unit, axis):
if not nc_axis_available:
raise ValueError('In order to display cftime types with '
'matplotlib install the nc_time_axis '
'library using pip or from conda-forge '
'using:\n\tconda install -c conda-forge '
'nc_time_axis')
if isinstance(value, cftime_types):
value = CalendarDateTime(value.datetime, value.calendar)
elif isinstance(value, np.ndarray):
value = np.array([CalendarDateTime(v.datetime, v.calendar) for v in value])
return super(CFTimeConverter, cls).convert(value, unit, axis)
for cft in cftime_types:
munits.registry[cft] = CFTimeConverter()