from __future__ import division
import copy
import re
import numpy as np
from plotly import colors
from ...core.util import isfinite, max_range
from ..util import color_intervals, process_cmap
# Constants
# ---------
# Trace types that are individually positioned with their own domain.
# These are traces that don't overlay on top of each other in a shared subplot,
# so they are positioned individually.  All other trace types are associated
# with a layout subplot type (xaxis/yaxis, polar, scene etc.)
#
# Each of these trace types has a `domain` property with `x`/`y` properties,
# which `_scale_translate` rescales directly on the trace itself.
_domain_trace_types = {'parcoords', 'pie', 'table', 'sankey', 'parcats'}
# Subplot types that are each individually positioned with a domain
#
# Each of these subplot types has a `domain` property with `x`/`y` properties.
# Note that this set does not contain `xaxis`/`yaxis` because these behave a
# little differently: each cartesian axis carries its own one-dimensional
# `domain` list instead of a `domain` dict with `x`/`y` entries.
#
# Layout keys are matched against these names by prefix, so e.g. both
# `scene` and `scene2` are recognized as `scene` subplots.
_subplot_types = {'scene', 'geo', 'polar', 'ternary', 'mapbox'}
# For most subplot types, a trace is associated with a particular subplot
# using a trace property with a name that matches the subplot type. For
# example, a `scatter3d.scene` property set to `'scene2'` associates a
# scatter3d trace with the second `scene` subplot in the figure.
#
# There are a few subplot types that don't follow this pattern, and instead
# the trace property is just named `subplot`.  For example setting
# the `scatterpolar.subplot` property to `polar3` associates the scatterpolar
# trace with the third polar subplot in the figure
_subplot_prop_named_subplot = {'polar', 'ternary', 'mapbox'}
# Mapping from trace type to subplot type(s).
_trace_to_subplot = {
    # xaxis/yaxis
    'bar':                  ['xaxis', 'yaxis'],
    'box':                  ['xaxis', 'yaxis'],
    'candlestick':          ['xaxis', 'yaxis'],
    'carpet':               ['xaxis', 'yaxis'],
    'contour':              ['xaxis', 'yaxis'],
    'contourcarpet':        ['xaxis', 'yaxis'],
    'heatmap':              ['xaxis', 'yaxis'],
    'heatmapgl':            ['xaxis', 'yaxis'],
    'histogram':            ['xaxis', 'yaxis'],
    'histogram2d':          ['xaxis', 'yaxis'],
    'histogram2dcontour':   ['xaxis', 'yaxis'],
    'ohlc':                 ['xaxis', 'yaxis'],
    'pointcloud':           ['xaxis', 'yaxis'],
    'scatter':              ['xaxis', 'yaxis'],
    'scattercarpet':        ['xaxis', 'yaxis'],
    'scattergl':            ['xaxis', 'yaxis'],
    'violin':               ['xaxis', 'yaxis'],
    # scene
    'cone':         ['scene'],
    'mesh3d':       ['scene'],
    'scatter3d':    ['scene'],
    'streamtube':   ['scene'],
    'surface':      ['scene'],
    # geo
    'choropleth': ['geo'],
    'scattergeo': ['geo'],
    # polar
    'barpolar':         ['polar'],
    'scatterpolar':     ['polar'],
    'scatterpolargl':   ['polar'],
    # ternary
    'scatterternary': ['ternary'],
    # mapbox
    'scattermapbox': ['mapbox']
}
# Trace types that support legends.
#
# NOTE(review): this set is not referenced by the helpers in this module;
# presumably the plotting code elsewhere consults it to decide whether a
# trace should get a legend entry -- confirm against callers.
legend_trace_types = {
    'scatter',
    'bar',
    'box',
    'histogram',
    'histogram2dcontour',
    'contour',
    'scatterternary',
    'violin',
    'waterfall',
    'pie',
    'scatter3d',
    'scattergeo',
    'scattergl',
    'splom',
    'pointcloud',
    'scattermapbox',
    'scattercarpet',
    'contourcarpet',
    'ohlc',
    'candlestick',
    'scatterpolar',
    'scatterpolargl',
    'barpolar',
    'area',
}
# Aliases - map common style options to more common names
STYLE_ALIASES = {'alpha': 'opacity',
                 'cell_height': 'height', 'marker': 'symbol'}
# Regular expression to extract any trailing digits from a subplot-style
# string.
_subplot_re = re.compile('\D*(\d+)')
def _get_subplot_number(subplot_val):
    """
    Return the number embedded in a subplot reference string.

    'x3' -> 3
    'polar2' -> 2
    'scene' -> 1
    'y' -> 1

    A reference without a number (e.g. 'y') denotes the first subplot of its
    type, which plotly treats as subplot number 1.

    Parameters
    ----------
    subplot_val: str
        Subplot reference string (e.g. 'scene4')

    Returns
    -------
    int
    """
    found = _subplot_re.match(subplot_val)
    return int(found.group(1)) if found else 1
def _get_subplot_val_prefix(subplot_type):
    """
    Get the subplot value prefix for a subplot type. For most subplot types
    this is equal to the subplot type string itself. For example, a
    `scatter3d.scene` value of `scene2` is used to associate the scatter3d
    trace with the `layout.scene2` subplot.
    However, the `xaxis`/`yaxis` subplot types are exceptions to this pattern.
    For example, a `scatter.xaxis` value of `x2` is used to associate the
    scatter trace with the `layout.xaxis2` subplot.
    Parameters
    ----------
    subplot_type: str
        Subplot string value (e.g. 'scene4')
    Returns
    -------
    str
    """
    if subplot_type == 'xaxis':
        subplot_val_prefix = 'x'
    elif subplot_type == 'yaxis':
        subplot_val_prefix = 'y'
    else:
        subplot_val_prefix = subplot_type
    return subplot_val_prefix
def _get_subplot_prop_name(subplot_type):
    """
    Return the name of the trace property that links a trace to a subplot
    of the given type.

    Usually the property is named after the subplot type: `scatter3d.scene`
    links a scatter3d trace to a `scene` subplot. Subplot types listed in
    `_subplot_prop_named_subplot` use a property called `subplot` instead,
    e.g. `scatterpolar.subplot` links a scatterpolar trace to a `polar`
    subplot.

    Parameters
    ----------
    subplot_type: str
        Subplot type string (e.g. 'scene', 'polar')

    Returns
    -------
    str
    """
    uses_generic_name = subplot_type in _subplot_prop_named_subplot
    return 'subplot' if uses_generic_name else subplot_type
def _normalize_subplot_ids(fig):
    """
    Make sure a layout subplot property is initialized for every subplot that
    is referenced by a trace in the figure.

    For example, if a figure contains a `scatterpolar` trace with the `subplot`
    property set to `polar3`, this function will make sure the figure's layout
    has a `polar3` property, and will initialize it to an empty dict if it
    does not.

    Note: This function mutates the input figure dict

    Parameters
    ----------
    fig: dict
        A plotly figure dict. A figure whose `data` entry is missing or None
        is treated as having no traces.
    """
    layout = fig.setdefault('layout', {})
    # Fall back to an empty list so a figure without traces is a no-op
    # instead of raising a TypeError when iterating None.
    for trace in fig.get('data') or []:
        trace_type = trace.get('type', 'scatter')
        for subplot_type in _trace_to_subplot.get(trace_type, []):
            subplot_prop_name = _get_subplot_prop_name(subplot_type)
            subplot_val_prefix = _get_subplot_val_prefix(subplot_type)
            # An unset reference denotes the first subplot of this type
            subplot_val = trace.get(subplot_prop_name, subplot_val_prefix)
            # extract trailing number (if any)
            subplot_number = _get_subplot_number(subplot_val)
            # The first subplot's layout key carries no numeric suffix
            # (e.g. 'xaxis', not 'xaxis1')
            if subplot_number > 1:
                layout_prop_name = subplot_type + str(subplot_number)
            else:
                layout_prop_name = subplot_type
            layout.setdefault(layout_prop_name, {})
def _get_max_subplot_ids(fig):
    """
    Find the largest subplot number used for each subplot type in a figure.

    Both the subplot references of the traces and the `xref`/`yref` of the
    layout annotations, shapes and images are taken into account; references
    to 'paper' are ignored.

    Parameters
    ----------
    fig: dict
        A plotly figure dict

    Returns
    -------
    dict
        Mapping from each subplot type (every entry of `_subplot_types` plus
        'xaxis' and 'yaxis') to the largest subplot number of that type found
        in the figure, or 0 when the type is not used at all
    """
    max_ids = dict.fromkeys(_subplot_types, 0)
    for axis_type in ('xaxis', 'yaxis'):
        max_ids[axis_type] = 0

    def record(subplot_type, ref):
        # Keep the running maximum for this subplot type
        number = _get_subplot_number(ref)
        if number > max_ids[subplot_type]:
            max_ids[subplot_type] = number

    # Trace subplot references
    for trace in fig.get('data', []):
        trace_type = trace.get('type', 'scatter')
        for subplot_type in _trace_to_subplot.get(trace_type, []):
            prop_name = _get_subplot_prop_name(subplot_type)
            default_ref = _get_subplot_val_prefix(subplot_type)
            record(subplot_type, trace.get(prop_name, default_ref))

    # Axis references of annotations/shapes/images
    axis_refs = (('xaxis', 'xref', 'x'), ('yaxis', 'yref', 'y'))
    layout = fig.get('layout', {})
    for collection in ('annotations', 'shapes', 'images'):
        for item in layout.get(collection, []):
            for axis_type, ref_key, default_ref in axis_refs:
                ref = item.get(ref_key, default_ref)
                if ref != 'paper':
                    record(axis_type, ref)
    return max_ids
def _offset_subplot_ids(fig, offsets):
    """
    Apply offsets to the subplot id numbers in a figure.

    Trace subplot references, layout subplot keys, axis `anchor`/`matches`
    references and the `xref`/`yref` of annotations/shapes/images are all
    shifted by the offset of the corresponding subplot type.

    Note: This function mutates the input figure dict

    Note: This function assumes that the normalize_subplot_ids function has
    already been run on the figure, so that all layout subplot properties in
    use are explicitly present in the figure's layout.

    Parameters
    ----------
    fig: dict
        A plotly figure dict. A figure whose `data` entry is missing or None
        is treated as having no traces.
    offsets: dict
        A dict from subplot types to the offset to be applied for each subplot
        type.  This dict matches the form of the dict returned by
        get_max_subplot_ids
    """
    # Offset traces (an absent/None `data` entry means there are no traces)
    for trace in fig.get('data') or []:
        trace_type = trace.get('type', 'scatter')
        subplot_types = _trace_to_subplot.get(trace_type, [])
        for subplot_type in subplot_types:
            subplot_prop_name = _get_subplot_prop_name(subplot_type)
            # Compute subplot value prefix
            subplot_val_prefix = _get_subplot_val_prefix(subplot_type)
            subplot_val = trace.get(subplot_prop_name, subplot_val_prefix)
            subplot_number = _get_subplot_number(subplot_val)
            offset_subplot_number = (
                subplot_number + offsets.get(subplot_type, 0))
            # The first subplot is referenced without a numeric suffix
            if offset_subplot_number > 1:
                trace[subplot_prop_name] = (
                    subplot_val_prefix + str(offset_subplot_number))
            else:
                trace[subplot_prop_name] = subplot_val_prefix

    # layout subplots: renamed entries are collected first and inserted
    # afterwards so a new key cannot be picked up (and renamed again) or
    # clobber a key that has not been processed yet.
    layout = fig.setdefault('layout', {})
    new_subplots = {}
    for subplot_type in offsets:
        offset = offsets[subplot_type]
        if offset < 1:
            continue
        for layout_prop in list(layout.keys()):
            if layout_prop.startswith(subplot_type):
                subplot_number = _get_subplot_number(layout_prop)
                new_subplot_number = subplot_number + offset
                new_layout_prop = subplot_type + str(new_subplot_number)
                new_subplots[new_layout_prop] = layout.pop(layout_prop)
    layout.update(new_subplots)

    # xaxis/yaxis anchors: an x axis anchors to a y axis and vice versa
    x_offset = offsets.get('xaxis', 0)
    y_offset = offsets.get('yaxis', 0)
    for layout_prop in list(layout.keys()):
        if layout_prop.startswith('xaxis'):
            anchor_letter, anchor_offset = 'y', y_offset
        elif layout_prop.startswith('yaxis'):
            anchor_letter, anchor_offset = 'x', x_offset
        else:
            continue
        axis = layout[layout_prop]
        anchor = axis.get('anchor', anchor_letter)
        if anchor == 'free':
            # A 'free' axis is positioned via `position` rather than bound
            # to another axis; it has no id to renumber.
            continue
        anchor_number = _get_subplot_number(anchor) + anchor_offset
        if anchor_number > 1:
            axis['anchor'] = anchor_letter + str(anchor_number)
        else:
            axis['anchor'] = anchor_letter

    # Axis matches references (e.g. 'x2', 'y')
    for layout_prop in list(layout.keys()):
        if layout_prop[1:5] == 'axis':
            axis = layout[layout_prop]
            matches_val = axis.get('matches', None)
            if matches_val:
                if matches_val[0] == 'x':
                    matches_number = _get_subplot_number(matches_val) + x_offset
                elif matches_val[0] == 'y':
                    matches_number = _get_subplot_number(matches_val) + y_offset
                else:
                    continue
                suffix = str(matches_number) if matches_number > 1 else ""
                axis['matches'] = matches_val[0] + suffix

    # annotations/shapes/images
    for layout_prop in ['annotations', 'shapes', 'images']:
        for obj in layout.get(layout_prop, []):
            if x_offset:
                xref = obj.get('xref', 'x')
                if xref != 'paper':
                    xref_number = _get_subplot_number(xref)
                    obj['xref'] = 'x' + str(xref_number + x_offset)
            if y_offset:
                yref = obj.get('yref', 'y')
                if yref != 'paper':
                    yref_number = _get_subplot_number(yref)
                    obj['yref'] = 'y' + str(yref_number + y_offset)
def _scale_translate(fig, scale_x, scale_y, translate_x, translate_y):
    """
    Scale a figure and translate it to sub-region of the original
    figure canvas.

    Note: If the input figure has a title, this title is converted into an
    annotation and scaled along with the rest of the figure. The title may be
    either a string (with an optional `layout.titlefont` dict) or a dict with
    a `text` entry and an optional `font` entry.

    Note: This function mutates the input fig dict

    Note: This function assumes that the normalize_subplot_ids function has
    already been run on the figure, so that all layout subplot properties in
    use are explicitly present in the figure's layout.

    Parameters
    ----------
    fig: dict
        A plotly figure dict
    scale_x: float
        Factor by which to scale the figure in the x-direction. This will
        typically be a value < 1.  E.g. a value of 0.5 will cause the
        resulting figure to be half as wide as the original.
    scale_y: float
        Factor by which to scale the figure in the y-direction. This will
        typically be a value < 1
    translate_x: float
        Amount by which to translate the scaled figure in the x-direction in
        normalized coordinates.
    translate_y: float
        Amount by which to translate the scaled figure in the y-direction in
        normalized coordinates.
    """
    data = fig.setdefault('data', [])
    layout = fig.setdefault('layout', {})

    # Domain endpoints are clipped at 1 so a translated domain never extends
    # past the right/top edge of the canvas.
    def scale_translate_x(x):
        return [min(x[0] * scale_x + translate_x, 1),
                min(x[1] * scale_x + translate_x, 1)]

    def scale_translate_y(y):
        return [min(y[0] * scale_y + translate_y, 1),
                min(y[1] * scale_y + translate_y, 1)]

    def perform_scale_translate(obj):
        # Objects with a {'x': [..], 'y': [..]} domain; a missing domain
        # (or axis range within it) is treated as the full canvas.
        domain = obj.setdefault('domain', {})
        x = domain.get('x', [0, 1])
        y = domain.get('y', [0, 1])
        domain['x'] = scale_translate_x(x)
        domain['y'] = scale_translate_y(y)

    # Scale/translate traces that are positioned by their own domain
    for trace in data:
        trace_type = trace.get('type', 'scatter')
        if trace_type in _domain_trace_types:
            perform_scale_translate(trace)

    # Scale/translate subplot containers (scene, scene2, polar, ...)
    for prop in layout:
        for subplot_type in _subplot_types:
            if prop.startswith(subplot_type):
                perform_scale_translate(layout[prop])

    # Cartesian axes carry a one-dimensional domain list per axis
    for prop in layout:
        if prop.startswith('xaxis'):
            xaxis = layout[prop]
            x_domain = xaxis.get('domain', [0, 1])
            xaxis['domain'] = scale_translate_x(x_domain)
        elif prop.startswith('yaxis'):
            yaxis = layout[prop]
            y_domain = yaxis.get('domain', [0, 1])
            yaxis['domain'] = scale_translate_y(y_domain)

    # convert title to annotation
    # This way the title will be scaled with the rest of the figure
    annotations = layout.get('annotations', [])
    title = layout.pop('title', None)
    if title:
        titlefont = layout.pop('titlefont', {})
        if isinstance(title, dict):
            title_text = title.get('text', None)
            # A font nested in the title dict takes precedence over the
            # legacy top-level `titlefont`. Copy it so the caller's title
            # dict is not mutated when the size is adjusted below.
            titlefont = dict(title.get('font', None) or titlefont)
        else:
            title_text = title
        if title_text:
            title_fontsize = titlefont.get('size', 17)
            min_fontsize = 12
            # Interpolate between min_fontsize and the original size
            # according to the horizontal scale factor
            titlefont['size'] = round(min_fontsize +
                                      (title_fontsize - min_fontsize) * scale_x)
            annotations.append({
                'text': title_text,
                'showarrow': False,
                'xref': 'paper',
                'yref': 'paper',
                'x': 0.5,
                'y': 1.01,
                'xanchor': 'center',
                'yanchor': 'bottom',
                'font': titlefont
            })
            layout['annotations'] = annotations

    # annotations
    for obj in layout.get('annotations', []):
        if obj.get('xref', None) == 'paper':
            obj['x'] = obj.get('x', 0.5) * scale_x + translate_x
        if obj.get('yref', None) == 'paper':
            obj['y'] = obj.get('y', 0.5) * scale_y + translate_y

    # shapes
    for obj in layout.get('shapes', []):
        if obj.get('xref', None) == 'paper':
            obj['x0'] = obj.get('x0', 0.25) * scale_x + translate_x
            obj['x1'] = obj.get('x1', 0.75) * scale_x + translate_x
        if obj.get('yref', None) == 'paper':
            obj['y0'] = obj.get('y0', 0.25) * scale_y + translate_y
            obj['y1'] = obj.get('y1', 0.75) * scale_y + translate_y

    # images
    for obj in layout.get('images', []):
        if obj.get('xref', None) == 'paper':
            obj['x'] = obj.get('x', 0.5) * scale_x + translate_x
            obj['sizex'] = obj.get('sizex', 0) * scale_x
        if obj.get('yref', None) == 'paper':
            obj['y'] = obj.get('y', 0.5) * scale_y + translate_y
            obj['sizey'] = obj.get('sizey', 0) * scale_y
def _merge_layout_objs(obj, subobj):
    """
    Merge layout objects recursively
    Note: This function mutates the input obj dict, but it does not mutate
    the subobj dict
    Parameters
    ----------
    obj: dict
        dict into which the sub-figure dict will be merged
    subobj: dict
        dict that sill be copied and merged into `obj`
    """
    for prop, val in subobj.items():
        if isinstance(val, dict) and prop in obj:
            # recursion
            _merge_layout_objs(obj[prop], val)
        elif (isinstance(val, list) and
              obj.get(prop, None) and
              isinstance(obj[prop][0], dict)):
            # append
            obj[prop].extend(val)
        else:
            # init/overwrite
            obj[prop] = copy.deepcopy(val)
def _compute_subplot_domains(widths, spacing):
    """
    Compute normalized domain tuples for a list of widths and a subplot
    spacing value
    Parameters
    ----------
    widths: list of float
        List of the desired withs of each subplot. The length of this list
        is also the specification of the number of desired subplots
    spacing: float
        Spacing between subplots in normalized coordinates
    Returns
    -------
    list of tuple of float
    """
    # normalize widths
    widths_sum = float(sum(widths))
    total_spacing = (len(widths) - 1) * spacing
    total_width = widths_sum + total_spacing
    relative_spacing = spacing / (widths_sum + total_spacing)
    relative_widths = [(w / total_width) for w in widths]
    domains = []
    for c in range(len(widths)):
        domain_start = c * relative_spacing + sum(relative_widths[:c])
        domain_stop = min(1, domain_start + relative_widths[c])
        domains.append((domain_start, domain_stop))
    return domains
def get_colorscale(cmap, levels=None, cmin=None, cmax=None):
    """Converts a cmap spec to a plotly colorscale

    Args:
        cmap: A recognized colormap by name or list of colors
        levels: A list or integer declaring the color-levels
        cmin: The lower bound of the color range
        cmax: The upper bound of the color range

    Returns:
        A valid plotly colorscale

    Raises:
        ValueError: If `levels` is a list and `cmap` is a list whose length
            does not equal the number of intervals (len(levels) - 1).
        Any error raised by `process_cmap` is re-raised when `cmap` is also
        not one of plotly's built-in named scales.
    """
    ncolors = levels if isinstance(levels, int) else None
    if isinstance(levels, list):
        # N level boundaries define N - 1 color intervals
        ncolors = len(levels) - 1
        if isinstance(cmap, list) and len(cmap) != ncolors:
            raise ValueError('The number of colors in the colormap '
                             'must match the intervals defined in the '
                             'color_levels, expected %d colors found %d.'
                             % (ncolors, len(cmap)))
    try:
        palette = process_cmap(cmap, ncolors)
    except Exception as e:
        # Fall back to plotly's built-in named colorscales. An unhashable
        # cmap (e.g. a list of colors) cannot be a scale name, so report the
        # original error instead of a TypeError from the dict lookup.
        try:
            colorscale = colors.PLOTLY_SCALES.get(cmap)
        except TypeError:
            colorscale = None
        if colorscale is None:
            raise e
        return colorscale
    if isinstance(levels, int):
        # Build a stepped colorscale: each of the `levels` equal-width
        # intervals gets a single color, and every interior boundary has two
        # entries at the same position to produce a hard color transition.
        colorscale = []
        scale = np.linspace(0, 1, levels+1)
        for i in range(levels+1):
            if i == 0:
                colorscale.append((scale[0], palette[i]))
            elif i == levels:
                colorscale.append((scale[-1], palette[-1]))
            else:
                colorscale.append((scale[i], palette[i-1]))
                colorscale.append((scale[i], palette[i]))
        return colorscale
    elif isinstance(levels, list):
        palette, (cmin, cmax) = color_intervals(
            palette, levels, clip=(cmin, cmax))
    return colors.make_colorscale(palette)