# Source code for holoviews.plotting.mpl

from __future__ import absolute_import, division, unicode_literals

import os

from matplotlib import rc_params_from_file
from matplotlib.colors import ListedColormap, LinearSegmentedColormap
from matplotlib.cm import register_cmap
from param import concrete_descendents

from ...core import Layout, Collator, GridMatrix, config
from ...core.options import Cycle, Palette, Options
from ...core.overlay import NdOverlay, Overlay
from ...core.util import LooseVersion, pd
from ...element import * # noqa (API import)
from ..plot import PlotSelector
from ..util import fire_colors
from .annotation import * # noqa (API import)
from .chart import * # noqa (API import)
from .chart3d import * # noqa (API import)
from .element import ElementPlot
from .geometry import * # noqa (API import)
from .graphs import * # noqa (API import)
from .heatmap import * # noqa (API import)
from .hex_tiles import * # noqa (API import)
from .path import * # noqa (API import)
from .plot import * # noqa (API import)
from .raster import * # noqa (API import)
from .sankey import * # noqa (API import)
from .stats import * # noqa (API import)
from .tabular import * # noqa (API import)

from .renderer import MPLRenderer


# True when the installed matplotlib is version 1.5.0 or newer; consulted by
# get_color_cycle to decide which rcParams entry holds the colour cycle.
# NOTE(review): `mpl` is not imported explicitly in this module; it
# presumably arrives through one of the star imports above -- confirm.
mpl_ge_150 = LooseVersion(mpl.__version__) >= '1.5.0'

# Register pandas' matplotlib converters at import time whenever `pd` is
# truthy (presumably the pandas module, or a falsy placeholder when pandas
# is not installed -- confirm against core.util).
if pd:
    try:
        from pandas.plotting import register_matplotlib_converters
        register_matplotlib_converters()
    except ImportError:
        # Fall back to the converter module under pandas.tseries when the
        # import above fails (presumably older pandas releases -- TODO
        # confirm which versions need this path).
        from pandas.tseries import converter
        converter.register()


def set_style(key):
    """
    Select a style by name, e.g. set_style('default'). To revert to the
    previous style use the key 'unset' or False.

    Args:
        key: Name of an entry in the module-level ``styles`` mapping.
            ``None`` is a no-op. Any other falsy value (e.g. ``False``)
            or one of the strings ``'unset'`` / ``'backup'`` restores
            the rcParams saved by the previous successful call.

    Raises:
        Exception: If a restore is requested but no backup was stored.
        KeyError: If ``key`` is not one of the available styles.
    """
    if key is None:
        return

    if not key or key in ['unset', 'backup']:
        # Restore request: only possible once a style has been applied,
        # because that is what stores the 'backup' entry.
        if 'backup' not in styles:
            raise Exception('No style backed up to restore')
        plt.rcParams.update(styles['backup'])
        return

    if key not in styles:
        # Interpolate the offending key; the tuple wrapper keeps the
        # formatting correct even when `key` itself is a tuple.
        raise KeyError('%r not in available styles.' % (key,))

    # Style files are addressed relative to this module's directory.
    path = os.path.join(os.path.dirname(__file__), styles[key])
    new_style = rc_params_from_file(path, use_default_template=False)
    # Snapshot the current rcParams before overwriting them so that a
    # later 'unset' call can restore them.
    styles['backup'] = dict(plt.rcParams)
    plt.rcParams.update(new_style)
# Define matplotlib based style cycles and Palettes
def get_color_cycle():
    """
    Return the colors of matplotlib's currently configured color cycle.

    When ``mpl_ge_150`` is true the colors are read from the
    ``axes.prop_cycle`` rcParam; if that cycle has no ``color`` key, or
    ``mpl_ge_150`` is false, the ``axes.color_cycle`` rcParam is returned
    instead.
    """
    if mpl_ge_150:
        cyl = mpl.rcParams['axes.prop_cycle']
        # matplotlib 1.5 verifies that axes.prop_cycle *is* a cycler
        # but no guarantee that there's a `color` key.
        # so users could have a custom rcParams w/ no color...
        try:
            return [x['color'] for x in cyl]
        except KeyError:
            pass  # just return axes.color style below
    return mpl.rcParams['axes.color_cycle']


# Style names accepted by set_style, mapped to style files addressed
# relative to this module's directory.
styles = {'default': './default.mplstyle',
          'default>1.5': './default1.5.mplstyle'}

# Define Palettes and cycles from matplotlib colormaps
# (colormap names containing 'spectral' or 'Vega' are skipped)
Palette.colormaps.update({cm: plt.get_cmap(cm) for cm in plt.cm.datad
                          if not ('spectral' in cm or 'Vega' in cm)})
listed_cmaps = [cm for cm in Palette.colormaps.values()
                if isinstance(cm, ListedColormap)]
Cycle.default_cycles.update({cm.name: list(cm.colors) for cm in listed_cmaps})

# Short spellings accepted for each style option; handed to
# Store.register below.
style_aliases = {'edgecolor': ['ec', 'ecolor'], 'facecolor': ['fc'],
                 'linewidth': ['lw'], 'edgecolors': ['ec', 'edgecolor'],
                 'size': ['s'], 'color': ['c'], 'markeredgecolor': ['mec'],
                 'markeredgewidth': ['mew'], 'markerfacecolor': ['mfc'],
                 'markersize': ['ms']}

Store.renderers['matplotlib'] = MPLRenderer.instance()

# Make matplotlib the current backend only when it is the sole renderer
# registered so far.
if len(Store.renderers) == 1:
    Store.set_current_backend('matplotlib')

# Defines a wrapper around GridPlot and RasterGridPlot
# switching to RasterGridPlot if the plot only contains
# Raster Elements
BasicGridPlot = GridPlot


def grid_selector(grid):
    """
    Return the name of the plot class to use for ``grid``.

    Yields 'RasterGridPlot' when the predicate holds for everything
    returned by ``grid.traverse(..., [Element])`` and 'GridPlot'
    otherwise.
    """
    # isinstance already returns a bool, so no conditional expression is
    # needed around it.
    all_raster = all(grid.traverse(lambda x: isinstance(x, Raster), [Element]))
    return 'RasterGridPlot' if all_raster else 'GridPlot'


GridPlot = PlotSelector(grid_selector,
                        plot_classes=[('GridPlot', BasicGridPlot),
                                      ('RasterGridPlot', RasterGridPlot)])

# Register default Elements
Store.register({Curve: CurvePlot,
                Scatter: PointPlot,
                Bars: BarPlot,
                Histogram: HistogramPlot,
                Points: PointPlot,
                VectorField: VectorFieldPlot,
                ErrorBars: ErrorPlot,
                Spread: SpreadPlot,
                Spikes: SpikesPlot,
                BoxWhisker: BoxPlot,
                Area: AreaPlot,

                # General plots
                GridSpace: GridPlot,
                GridMatrix: GridPlot,
                NdLayout: LayoutPlot,
                Layout: LayoutPlot,
                AdjointLayout: AdjointLayoutPlot,

                # Element plots
                NdOverlay: OverlayPlot,
                Overlay: OverlayPlot,

                # Chart 3D
                Surface: SurfacePlot,
                TriSurface: TriSurfacePlot,
                Trisurface: TriSurfacePlot,  # Alias, remove in 2.0
                Scatter3D: Scatter3DPlot,
                Path3D: Path3DPlot,

                # Tabular plots
                ItemTable: TablePlot,
                Table: TablePlot,
                Collator: TablePlot,

                # Raster plots
                QuadMesh: QuadMeshPlot,
                Raster: RasterPlot,
                HeatMap: PlotSelector(HeatMapPlot.is_radial,
                                      {True: RadialHeatMapPlot,
                                       False: HeatMapPlot},
                                      True),
                Image: RasterPlot,
                RGB: RGBPlot,
                HSV: RGBPlot,

                # Graph Elements
                Graph: GraphPlot,
                TriMesh: TriMeshPlot,
                Chord: ChordPlot,
                Nodes: PointPlot,
                EdgePaths: PathPlot,
                Sankey: SankeyPlot,

                # Annotation plots
                VLine: VLinePlot,
                HLine: HLinePlot,
                VSpan: VSpanPlot,
                HSpan: HSpanPlot,
                Slope: SlopePlot,
                Arrow: ArrowPlot,
                Spline: SplinePlot,
                Text: TextPlot,
                Labels: LabelsPlot,

                # Path plots
                Contours: ContourPlot,
                Path: PathPlot,
                Box: PathPlot,
                Bounds: PathPlot,
                Ellipse: PathPlot,
                Polygons: PolygonPlot,

                # Geometry plots
                Rectangles: RectanglesPlot,
                Segments: SegmentPlot,

                # Statistics elements
                Distribution: DistributionPlot,
                Bivariate: BivariatePlot,
                Violin: ViolinPlot,
                HexTiles: HexTilesPlot},
               'matplotlib', style_aliases=style_aliases)


# Plot classes registered as side plots for these element types.
MPLPlot.sideplots.update({Histogram: SideHistogramPlot,
                          Area: SideAreaPlot,
                          GridSpace: GridPlot,
                          Spikes: SideSpikesPlot,
                          BoxWhisker: SideBoxPlot})

# With config.no_padding set, zero the padding on every concrete
# ElementPlot subclass.
if config.no_padding:
    for plot in concrete_descendents(ElementPlot).values():
        plot.padding = 0

# Raster types, Path types and VectorField should have frames
for framedcls in [VectorFieldPlot, ContourPlot, PathPlot, RasterPlot,
                  QuadMeshPlot, HeatMapPlot, PolygonPlot]:
    framedcls.show_frame = True

# Build the 'fire' colormap and its reverse from fire_colors and register
# both with matplotlib under those names.
fire_cmap = LinearSegmentedColormap.from_list("fire", fire_colors,
                                              N=len(fire_colors))
fire_r_cmap = LinearSegmentedColormap.from_list("fire_r",
                                                list(reversed(fire_colors)),
                                                N=len(fire_colors))
register_cmap("fire", cmap=fire_cmap)
register_cmap("fire_r", cmap=fire_r_cmap)

options = Store.options(backend='matplotlib')
dflt_cmap = 'fire'

# Default option definitions
# Note: *No* short aliases here! e.g. use 'facecolor' instead of 'fc'

# Charts
options.Curve = Options('style', color=Cycle(), linewidth=2)
options.Scatter = Options('style', color=Cycle(), marker='o', cmap=dflt_cmap)
options.Points = Options('plot', show_frame=True)
options.ErrorBars = Options('style', edgecolor='k')
options.Spread = Options('style', facecolor=Cycle(), alpha=0.6,
                         edgecolor='k', linewidth=0.5)
options.Bars = Options('style', edgecolor='k', color=Cycle())
options.Histogram = Options('style', edgecolor='k', facecolor=Cycle())
options.Points = Options('style', color=Cycle(), marker='o', cmap=dflt_cmap)
options.Scatter3D = Options('style', c=Cycle(), marker='o')
options.Scatter3D = Options('plot', fig_size=150)
options.Path3D = Options('plot', fig_size=150)
options.Surface = Options('plot', fig_size=150)
options.Surface = Options('style', cmap='fire')
options.Spikes = Options('style', color='black', cmap='fire')
options.Area = Options('style', facecolor=Cycle(), edgecolor='black')
options.BoxWhisker = Options('style',
                             boxprops=dict(color='k', linewidth=1.5),
                             whiskerprops=dict(color='k', linewidth=1.5))

# Geometries
options.Rectangles = Options('style', edgecolor='black')

# Rasters
options.Image = Options('style', cmap=dflt_cmap, interpolation='nearest')
options.Raster = Options('style', cmap=dflt_cmap, interpolation='nearest')
options.QuadMesh = Options('style', cmap=dflt_cmap)
options.HeatMap = Options('style', cmap='RdYlBu_r', edgecolors='white',
                          annular_edgecolors='white', annular_linewidth=0.5,
                          xmarks_edgecolor='white', xmarks_linewidth=3,
                          ymarks_edgecolor='white', ymarks_linewidth=3,
                          linewidths=0)
options.HeatMap = Options('plot', show_values=True)
options.RGB = Options('style', interpolation='nearest')

# Composites
options.Layout = Options('plot', sublabel_format='{Alpha}')
options.GridMatrix = Options('plot', fig_size=160, shared_xaxis=True,
                             shared_yaxis=True, xaxis=None, yaxis=None)

# Annotations
options.VLine = Options('style', color=Cycle())
options.HLine = Options('style', color=Cycle())
options.Slope = Options('style', color=Cycle())
options.VSpan = Options('style', alpha=0.5, facecolor=Cycle())
options.HSpan = Options('style', alpha=0.5, facecolor=Cycle())
options.Spline = Options('style', edgecolor=Cycle())
options.Arrow = Options('style', color='k', linewidth=2, fontsize=13)

# Paths
options.Contours = Options('style', color=Cycle(), cmap='viridis')
options.Contours = Options('plot', show_legend=True)
options.Path = Options('style', color=Cycle(), cmap='viridis')
options.Polygons = Options('style', facecolor=Cycle(), edgecolor='black',
                           cmap='viridis')
options.Box = Options('style', color='black')
options.Bounds = Options('style', color='black')
options.Ellipse = Options('style', color='black')

# Interface
options.TimeSeries = Options('style', color=Cycle())

# Graphs
options.Graph = Options('style', node_edgecolors='black',
                        node_facecolors=Cycle(), edge_color='black',
                        node_size=15)
options.TriMesh = Options('style', node_edgecolors='black',
                          node_facecolors='white', edge_color='black',
                          node_size=5, edge_linewidth=1)
options.Chord = Options('style', node_edgecolors='black',
                        node_facecolors=Cycle(), edge_color='black',
                        node_size=10, edge_linewidth=0.5)
options.Chord = Options('plot', xaxis=None, yaxis=None)
options.Nodes = Options('style', edgecolors='black', facecolors=Cycle(),
                        marker='o', s=20**2)
options.EdgePaths = Options('style', color='black')
options.Sankey = Options('plot', xaxis=None, yaxis=None, fig_size=400,
                         aspect=1.6, show_frame=False)
options.Sankey = Options('style', edge_color='grey', node_edgecolors='black',
                         edge_alpha=0.6, node_size=6)

# Statistics
options.Distribution = Options('style', facecolor=Cycle(), edgecolor='black',
                               alpha=0.5)
options.Violin = Options('style', facecolors=Cycle(), showextrema=False,
                         alpha=0.7)