from __future__ import absolute_import, division, unicode_literals
import os
from matplotlib import rc_params_from_file
from matplotlib.colors import ListedColormap, LinearSegmentedColormap
from matplotlib.cm import register_cmap
from param import concrete_descendents
from ...core import Layout, Collator, GridMatrix, config
from ...core.options import Cycle, Palette, Options
from ...core.overlay import NdOverlay, Overlay
from ...core.util import LooseVersion, pd
from ...element import * # noqa (API import)
from ..plot import PlotSelector
from ..util import fire_colors
from .annotation import * # noqa (API import)
from .chart import * # noqa (API import)
from .chart3d import * # noqa (API import)
from .element import ElementPlot
from .geometry import * # noqa (API import)
from .graphs import * # noqa (API import)
from .heatmap import * # noqa (API import)
from .hex_tiles import * # noqa (API import)
from .path import * # noqa (API import)
from .plot import * # noqa (API import)
from .raster import * # noqa (API import)
from .sankey import * # noqa (API import)
from .stats import * # noqa (API import)
from .tabular import * # noqa (API import)
from .renderer import MPLRenderer
# True when the installed matplotlib is 1.5.0 or newer; get_color_cycle uses
# this to decide whether rcParams['axes.prop_cycle'] can be consulted.
mpl_ge_150 = LooseVersion(mpl.__version__) >= LooseVersion('1.5.0')

# When pandas is available, register its matplotlib converters. Newer pandas
# exposes the hook as pandas.plotting.register_matplotlib_converters; older
# releases only provide pandas.tseries.converter.register.
if pd:
    try:
        from pandas.plotting import register_matplotlib_converters
        register_matplotlib_converters()
    except ImportError:
        from pandas.tseries import converter
        converter.register()
def set_style(key):
    """
    Select a style by name, e.g. set_style('default'). To revert to the
    previous style use the key 'unset' or False.

    Args:
        key: Name of an entry in the module-level ``styles`` mapping.
            ``None`` does nothing. Any falsy value, ``'unset'`` or
            ``'backup'`` restores the rcParams saved by the most recent
            successful ``set_style`` call.

    Raises:
        Exception: If a restore is requested but no backup has been saved.
        KeyError: If ``key`` is not a known style name.
    """
    if key is None:
        return
    elif not key or key in ('unset', 'backup'):
        if 'backup' in styles:
            plt.rcParams.update(styles['backup'])
        else:
            raise Exception('No style backed up to restore')
    elif key not in styles:
        # Interpolate the offending key. It is wrapped in a 1-tuple so tuple
        # keys are not unpacked by the % operator.
        raise KeyError('%r not in available styles.' % (key,))
    else:
        # Style files are stored as paths relative to this package.
        path = os.path.join(os.path.dirname(__file__), styles[key])
        new_style = rc_params_from_file(path, use_default_template=False)
        # Snapshot the current rcParams so 'unset' can restore them.
        styles['backup'] = dict(plt.rcParams)
        plt.rcParams.update(new_style)
# Define matplotlib based style cycles and Palettes
def get_color_cycle():
    """
    Return the list of colors matplotlib is configured to cycle through.

    On matplotlib >= 1.5 the colors come from the ``color`` entries of
    ``rcParams['axes.prop_cycle']``. If that cycler has no ``color`` key, or
    matplotlib is older than 1.5, ``rcParams['axes.color_cycle']`` is
    returned instead.
    """
    if mpl_ge_150:
        # matplotlib 1.5 verifies that axes.prop_cycle *is* a cycler, but a
        # user's custom rcParams may define one without a `color` key.
        prop_cycle = mpl.rcParams['axes.prop_cycle']
        try:
            colors = [entry['color'] for entry in prop_cycle]
        except KeyError:
            colors = None
        if colors is not None:
            return colors
    # NOTE(review): 'axes.color_cycle' is the pre-1.5 rcParam; confirm it
    # still exists on the matplotlib versions this fallback can reach.
    return mpl.rcParams['axes.color_cycle']
# Style sheets selectable through set_style, given as paths relative to
# this package's directory.
styles = {
    'default': './default.mplstyle',
    'default>1.5': './default1.5.mplstyle',
}
# Define Palettes and cycles from matplotlib colormaps, skipping any
# colormap whose name contains 'spectral' or 'Vega'.
Palette.colormaps.update(
    (name, plt.get_cmap(name)) for name in plt.cm.datad
    if 'spectral' not in name and 'Vega' not in name
)
# Colormaps backed by a discrete list of colors also become named Cycles.
listed_cmaps = [cm for cm in Palette.colormaps.values()
                if isinstance(cm, ListedColormap)]
Cycle.default_cycles.update(
    (cmap.name, list(cmap.colors)) for cmap in listed_cmaps
)
# Short matplotlib-style aliases accepted for each style option name.
# Note that 'ec' is listed for both 'edgecolor' and 'edgecolors', and
# 'edgecolor' is itself listed as an alias of 'edgecolors'.
style_aliases = {
    'edgecolor': ['ec', 'ecolor'],
    'facecolor': ['fc'],
    'linewidth': ['lw'],
    'edgecolors': ['ec', 'edgecolor'],
    'size': ['s'],
    'color': ['c'],
    'markeredgecolor': ['mec'],
    'markeredgewidth': ['mew'],
    'markerfacecolor': ['mfc'],
    'markersize': ['ms'],
}
# Register the matplotlib renderer. If it is the only renderer registered,
# make matplotlib the current backend.
Store.renderers['matplotlib'] = MPLRenderer.instance()
if list(Store.renderers) == ['matplotlib']:
    Store.set_current_backend('matplotlib')
# Defines a wrapper around GridPlot and RasterGridPlot, switching to
# RasterGridPlot if the plot only contains Raster Elements.
BasicGridPlot = GridPlot


def grid_selector(grid):
    """
    Return the name of the plot class used to render ``grid``.

    The grid is traversed for Element objects. If every one of them is a
    Raster, 'RasterGridPlot' is returned; otherwise 'GridPlot'. The names
    match the ``plot_classes`` keys of the PlotSelector built below.
    NOTE(review): a traversal that finds no Elements yields all([]) == True,
    i.e. 'RasterGridPlot'.
    """
    all_raster = all(grid.traverse(lambda el: isinstance(el, Raster), [Element]))
    return 'RasterGridPlot' if all_raster else 'GridPlot'


GridPlot = PlotSelector(grid_selector,
                        plot_classes=[('GridPlot', BasicGridPlot),
                                      ('RasterGridPlot', RasterGridPlot)])
# Register default Elements: map each Element/container type to the
# matplotlib plot class (or PlotSelector) that renders it, and pass the
# backend's style aliases along with the registration.
Store.register({Curve: CurvePlot,
                Scatter: PointPlot,
                Bars: BarPlot,
                Histogram: HistogramPlot,
                Points: PointPlot,
                VectorField: VectorFieldPlot,
                ErrorBars: ErrorPlot,
                Spread: SpreadPlot,
                Spikes: SpikesPlot,
                BoxWhisker: BoxPlot,
                Area: AreaPlot,

                # General plots
                # (GridPlot is the PlotSelector rebound above, which picks
                # between the basic and raster grid plots.)
                GridSpace: GridPlot,
                GridMatrix: GridPlot,
                NdLayout: LayoutPlot,
                Layout: LayoutPlot,
                AdjointLayout: AdjointLayoutPlot,

                # Overlay plots
                NdOverlay: OverlayPlot,
                Overlay: OverlayPlot,

                # Chart 3D
                Surface: SurfacePlot,
                TriSurface: TriSurfacePlot,
                Trisurface: TriSurfacePlot, # Alias, remove in 2.0
                Scatter3D: Scatter3DPlot,
                Path3D: Path3DPlot,

                # Tabular plots
                ItemTable: TablePlot,
                Table: TablePlot,
                Collator: TablePlot,

                # Raster plots
                QuadMesh: QuadMeshPlot,
                Raster: RasterPlot,
                # Presumably dispatches on HeatMapPlot.is_radial: True maps to
                # RadialHeatMapPlot, False to HeatMapPlot.
                # NOTE(review): the meaning of the trailing True argument to
                # PlotSelector is not visible here; confirm.
                HeatMap: PlotSelector(HeatMapPlot.is_radial,
                                      {True: RadialHeatMapPlot,
                                       False: HeatMapPlot},
                                      True),
                Image: RasterPlot,
                RGB: RGBPlot,
                HSV: RGBPlot,

                # Graph Elements
                Graph: GraphPlot,
                TriMesh: TriMeshPlot,
                Chord: ChordPlot,
                Nodes: PointPlot,
                EdgePaths: PathPlot,
                Sankey: SankeyPlot,

                # Annotation plots
                VLine: VLinePlot,
                HLine: HLinePlot,
                VSpan: VSpanPlot,
                HSpan: HSpanPlot,
                Slope: SlopePlot,
                Arrow: ArrowPlot,
                Spline: SplinePlot,
                Text: TextPlot,
                Labels: LabelsPlot,

                # Path plots
                Contours: ContourPlot,
                Path: PathPlot,
                Box: PathPlot,
                Bounds: PathPlot,
                Ellipse: PathPlot,
                Polygons: PolygonPlot,

                # Geometry plots
                Rectangles: RectanglesPlot,
                Segments: SegmentPlot,

                # Statistics elements
                Distribution: DistributionPlot,
                Bivariate: BivariatePlot,
                Violin: ViolinPlot,
                HexTiles: HexTilesPlot},
               'matplotlib', style_aliases=style_aliases)
# Plot classes registered in MPLPlot.sideplots, presumably used when these
# Elements are rendered as adjoined side plots (TODO confirm).
MPLPlot.sideplots.update([
    (Histogram, SideHistogramPlot),
    (Area, SideAreaPlot),
    (GridSpace, GridPlot),
    (Spikes, SideSpikesPlot),
    (BoxWhisker, SideBoxPlot),
])
# Optionally disable padding on every concrete matplotlib ElementPlot.
# The loop variable must not be called ``plot``. In this package namespace
# ``plot`` is bound to the ``.plot`` submodule (imported via
# ``from .plot import *``), and rebinding it would replace that submodule
# attribute with the last plot class visited.
if config.no_padding:
    for _plot_cls in concrete_descendents(ElementPlot).values():
        _plot_cls.padding = 0

# Raster types, Path types and VectorField should have frames
for framedcls in [VectorFieldPlot, ContourPlot, PathPlot, RasterPlot,
                  QuadMeshPlot, HeatMapPlot, PolygonPlot]:
    framedcls.show_frame = True
# Build the "fire" colormap (and its reversed variant) from fire_colors with
# one lookup entry per color, then register both with matplotlib by name.
fire_cmap = LinearSegmentedColormap.from_list(
    name="fire", colors=fire_colors, N=len(fire_colors))
fire_r_cmap = LinearSegmentedColormap.from_list(
    name="fire_r", colors=list(reversed(fire_colors)), N=len(fire_colors))
register_cmap(name="fire", cmap=fire_cmap)
register_cmap(name="fire_r", cmap=fire_r_cmap)
# Option tree for the matplotlib backend; the assignments below set the
# default plot/style options per Element type.
options = Store.options(backend='matplotlib')
# Default colormap name. It is used everywhere a default 'fire' cmap is
# wanted, so changing it here changes all of them.
dflt_cmap = 'fire'

# Default option definitions
# Note: *No* short aliases here! e.g. use 'facecolor' instead of 'fc'

# Charts
options.Curve = Options('style', color=Cycle(), linewidth=2)
options.Scatter = Options('style', color=Cycle(), marker='o', cmap=dflt_cmap)
options.Points = Options('plot', show_frame=True)
options.ErrorBars = Options('style', edgecolor='k')
options.Spread = Options('style', facecolor=Cycle(), alpha=0.6, edgecolor='k', linewidth=0.5)
options.Bars = Options('style', edgecolor='k', color=Cycle())
options.Histogram = Options('style', edgecolor='k', facecolor=Cycle())
options.Points = Options('style', color=Cycle(), marker='o', cmap=dflt_cmap)
# NOTE(review): 'c' is a short alias (of 'color' in style_aliases), contrary
# to the note above. Left unchanged until it is confirmed that the 3D
# scatter plot treats 'color' identically.
options.Scatter3D = Options('style', c=Cycle(), marker='o')
options.Scatter3D = Options('plot', fig_size=150)
options.Path3D = Options('plot', fig_size=150)
options.Surface = Options('plot', fig_size=150)
options.Surface = Options('style', cmap=dflt_cmap)
options.Spikes = Options('style', color='black', cmap=dflt_cmap)
options.Area = Options('style', facecolor=Cycle(), edgecolor='black')
options.BoxWhisker = Options('style', boxprops=dict(color='k', linewidth=1.5),
                             whiskerprops=dict(color='k', linewidth=1.5))
# Geometries
options.Rectangles = Options('style', edgecolor='black')

# Rasters: image-like Elements use the default colormap with
# nearest-neighbour interpolation.
options.Image = Options('style', cmap=dflt_cmap, interpolation='nearest')
options.Raster = Options('style', cmap=dflt_cmap, interpolation='nearest')
options.QuadMesh = Options('style', cmap=dflt_cmap)
options.HeatMap = Options(
    'style',
    cmap='RdYlBu_r',
    edgecolors='white',
    annular_edgecolors='white',
    annular_linewidth=0.5,
    xmarks_edgecolor='white',
    xmarks_linewidth=3,
    ymarks_edgecolor='white',
    ymarks_linewidth=3,
    linewidths=0,
)
options.HeatMap = Options('plot', show_values=True)
options.RGB = Options('style', interpolation='nearest')
# Composites
options.Layout = Options('plot', sublabel_format='{Alpha}')
options.GridMatrix = Options(
    'plot',
    fig_size=160,
    shared_xaxis=True,
    shared_yaxis=True,
    xaxis=None,
    yaxis=None,
)

# Annotations: lines, slopes and spans take their color from a Cycle.
options.VLine = Options('style', color=Cycle())
options.HLine = Options('style', color=Cycle())
options.Slope = Options('style', color=Cycle())
options.VSpan = Options('style', alpha=0.5, facecolor=Cycle())
options.HSpan = Options('style', alpha=0.5, facecolor=Cycle())
options.Spline = Options('style', edgecolor=Cycle())
options.Arrow = Options('style', color='k', linewidth=2, fontsize=13)

# Paths
options.Contours = Options('style', color=Cycle(), cmap='viridis')
options.Contours = Options('plot', show_legend=True)
options.Path = Options('style', color=Cycle(), cmap='viridis')
options.Polygons = Options(
    'style',
    facecolor=Cycle(),
    edgecolor='black',
    cmap='viridis',
)
options.Box = Options('style', color='black')
options.Bounds = Options('style', color='black')
options.Ellipse = Options('style', color='black')

# Interface
options.TimeSeries = Options('style', color=Cycle())
# Graphs
options.Graph = Options(
    'style',
    node_edgecolors='black',
    node_facecolors=Cycle(),
    edge_color='black',
    node_size=15,
)
options.TriMesh = Options(
    'style',
    node_edgecolors='black',
    node_facecolors='white',
    edge_color='black',
    node_size=5,
    edge_linewidth=1,
)
options.Chord = Options(
    'style',
    node_edgecolors='black',
    node_facecolors=Cycle(),
    edge_color='black',
    node_size=10,
    edge_linewidth=0.5,
)
options.Chord = Options('plot', xaxis=None, yaxis=None)
# NOTE(review): 's' is a short alias (of 'size' in style_aliases); left
# unchanged until it is confirmed the Nodes plot accepts 'size' identically.
options.Nodes = Options(
    'style',
    edgecolors='black',
    facecolors=Cycle(),
    marker='o',
    s=20**2,
)
options.EdgePaths = Options('style', color='black')
options.Sankey = Options(
    'plot',
    xaxis=None,
    yaxis=None,
    fig_size=400,
    aspect=1.6,
    show_frame=False,
)
options.Sankey = Options(
    'style',
    edge_color='grey',
    node_edgecolors='black',
    edge_alpha=0.6,
    node_size=6,
)

# Statistics
options.Distribution = Options(
    'style',
    facecolor=Cycle(),
    edgecolor='black',
    alpha=0.5,
)
options.Violin = Options('style', facecolors=Cycle(), showextrema=False, alpha=0.7)