from __future__ import division
from collections import Counter
from functools import cmp_to_key
from itertools import cycle
import param
import numpy as np
from ..core.dimension import Dimension
from ..core.data import Dataset
from ..core.operation import Operation
from ..core.util import OrderedDict, unique_array, RecursionError, get_param_values
from .graphs import Graph, Nodes, EdgePaths
from .util import quadratic_bezier
class _layout_sankey(Operation):
    """
    Computes a Sankey diagram from a Graph element for internal use in
    the Sankey element constructor.

    The layout works on a plain dictionary ``graph`` with two lists,
    ``graph['nodes']`` and ``graph['links']``, whose entries are dicts
    that are progressively annotated by the ``compute*`` methods:

    * nodes gain ``value``, ``depth``, ``height``, ``x0``, ``x1``,
      ``y0`` and ``y1``
    * links gain ``width``, ``y0`` and ``y1``

    Adapted from d3-sankey under BSD-3 license.
    """

    bounds = param.NumericTuple(default=(0, 0, 1000, 500))

    node_width = param.Number(default=15, doc="\nWidth of the nodes.")

    node_padding = param.Integer(default=None, allow_None=True, doc=(
        "\nNumber of pixels of padding relative to the bounds."))

    iterations = param.Integer(default=32, doc=(
        "\nNumber of iterations to run the layout algorithm."))

    node_sort = param.Boolean(default=True, doc=(
        "\nSort nodes in ascending breadth."))

    def _process(self, element, key=None):
        """
        Lay out ``element`` and return a new Sankey built from the
        element's data, the computed nodes and edge paths, and the
        element's own parameter values.
        """
        nodes, edges, graph = self.layout(element, **self.p)
        params = get_param_values(element)
        return Sankey((element.data, nodes, edges), sankey=graph, **params)

    def layout(self, element, **params):
        """
        Run the full layout pipeline on ``element``.

        Returns a ``(nodes, edges, graph)`` tuple where ``nodes`` is an
        ``element.node_type`` holding the node centers, ``edges`` is an
        ``element.edge_type`` holding one outline per link and ``graph``
        is the annotated layout dictionary.
        """
        self.p = param.ParamOverrides(self, params)
        graph = {'nodes': [], 'links': []}
        self.computeNodeLinks(element, graph)
        self.computeNodeValues(graph)
        self.computeNodeDepths(graph)
        self.computeNodeBreadths(graph)
        self.computeLinkBreadths(graph)
        paths = self.computePaths(graph)

        # Each node row is (x-center, y-center, index, *value dimensions).
        node_data = []
        for node in graph['nodes']:
            node_data.append((np.mean([node['x0'], node['x1']]),
                              np.mean([node['y0'], node['y1']]),
                              node['index'])+tuple(node['values']))

        if element.nodes.ndims == 3:
            kdims = element.nodes.kdims
        elif element.nodes.ndims:
            # Keep the user's index dimension, use default x/y dimensions.
            kdims = element.node_type.kdims[:2] + element.nodes.kdims[-1:]
        else:
            kdims = element.node_type.kdims
        nodes = element.node_type(node_data, kdims=kdims, vdims=element.nodes.vdims)
        edges = element.edge_type(paths)
        return nodes, edges, graph

    def computePaths(self, graph):
        """
        Build the closed outline of every link.

        Each outline runs from the source node's right edge along the
        link's lower-y curve to the target node's left edge, then back
        along the curve offset by the link width.
        """
        paths = []
        for link in graph['links']:
            source, target = link['source'], link['target']
            width = link['width']
            x0, y0 = source['x1'], link['y0']
            x1, y1 = target['x0'], link['y1']
            xmid = (x0+x1)/2.

            start = np.array([(x0, width+y0),
                              (x0, y0)])
            bottom = quadratic_bezier((x0, y0), (x1, y1),
                                      (xmid, y0), (xmid, y1))
            mid = np.array([(x1, y1),
                            (x1, y1+width)])

            # Return trip, offset by the link width and traversed from
            # the target back to the source.
            y0_off = y0+width
            y1_off = y1+width
            top = quadratic_bezier((x1, y1_off), (x0, y0_off),
                                   (xmid, y1_off), (xmid, y0_off))

            paths.append(np.concatenate([start, bottom, mid, top]))
        return paths

    @classmethod
    def weightedSource(cls, link):
        """Center of the link's source node weighted by the link value."""
        return cls.nodeCenter(link['source']) * link['value']

    @classmethod
    def weightedTarget(cls, link):
        """Center of the link's target node weighted by the link value."""
        return cls.nodeCenter(link['target']) * link['value']

    @classmethod
    def nodeCenter(cls, node):
        """Midpoint of the node between ``y0`` and ``y1``."""
        return (node['y0'] + node['y1']) / 2

    @classmethod
    def ascendingBreadth(cls, a, b):
        """
        Comparator ordering nodes by ascending ``y0``.

        Returns -1, 0 or 1. The sign of the difference is used rather
        than a truncated integer so that nodes less than one unit apart
        are still ordered correctly.
        """
        delta = a['y0'] - b['y0']
        return int(delta > 0) - int(delta < 0)

    @classmethod
    def ascendingSourceBreadth(cls, a, b):
        """
        Comparator ordering links by the breadth of their source node,
        falling back to the link index when the sources compare equal.
        """
        # Logical ``or`` (not bitwise ``|``): the index difference is only
        # a tie-break and must not alter the sign of a non-zero result.
        return cls.ascendingBreadth(a['source'], b['source']) or a['index'] - b['index']

    @classmethod
    def ascendingTargetBreadth(cls, a, b):
        """
        Comparator ordering links by the breadth of their target node,
        falling back to the link index when the targets compare equal.
        """
        # Logical ``or`` (not bitwise ``|``), see ascendingSourceBreadth.
        return cls.ascendingBreadth(a['target'], b['target']) or a['index'] - b['index']

    @classmethod
    def computeNodeLinks(cls, element, graph):
        """
        Populate ``graph['nodes']`` and ``graph['links']`` from the element.

        Every node dict gets its ``index``, its value-dimension ``values``
        and empty ``sourceLinks`` (outgoing) and ``targetLinks`` (incoming)
        lists; every link dict references its source and target node dicts
        and is appended to the matching lists on those nodes.
        """
        index_dim = element.nodes.kdims[-1]
        node_map = {}
        if element.nodes.vdims:
            values = zip(*(element.nodes.dimension_values(d)
                           for d in element.nodes.vdims))
        else:
            # No value dimensions: pair every node with an empty tuple.
            values = cycle([tuple()])
        for index, vals in zip(element.nodes.dimension_values(index_dim), values):
            node = {'index': index, 'sourceLinks': [], 'targetLinks': [], 'values': vals}
            graph['nodes'].append(node)
            node_map[index] = node

        # The first three dimensions of the element are source, target, value.
        links = [element.dimension_values(d) for d in element.dimensions()[:3]]
        for i, (src, tgt, value) in enumerate(zip(*links)):
            source, target = node_map[src], node_map[tgt]
            link = dict(index=i, source=source, target=target, value=value)
            graph['links'].append(link)
            source['sourceLinks'].append(link)
            target['targetLinks'].append(link)

    @classmethod
    def computeNodeValues(cls, graph):
        """
        Compute the value (size) of each node as the larger of the summed
        outgoing and summed incoming link values.
        """
        for node in graph['nodes']:
            source_val = np.sum([l['value'] for l in node['sourceLinks']])
            target_val = np.sum([l['value'] for l in node['targetLinks']])
            node['value'] = max([source_val, target_val])

    @staticmethod
    def _assignLevels(nodes, level_key, links_key, neighbor_key):
        """
        Breadth-first sweep that stores the sweep number under
        ``level_key`` on every node it reaches and returns the number of
        sweeps performed.

        The sweep starts from all ``nodes`` and repeatedly moves to the
        ``neighbor_key`` end of each node's ``links_key`` links, so a node
        ends up with the largest level at which it was reached. Raises
        RecursionError after 10000 sweeps, which only happens on cycles.
        """
        level = 0
        while nodes:
            next_nodes, seen = [], set()
            for node in nodes:
                node[level_key] = level
                for link in node[links_key]:
                    neighbor = link[neighbor_key]
                    # Node dicts are unhashable; de-duplicate by identity.
                    if id(neighbor) not in seen:
                        seen.add(id(neighbor))
                        next_nodes.append(neighbor)
            nodes = next_nodes
            level += 1
            if level > 10000:
                raise RecursionError('Sankey diagrams only support acyclic graphs.')
        return level

    def computeNodeDepths(self, graph):
        """
        Iteratively assign the depth (x-position) for each node.

        Nodes are assigned the maximum depth of incoming neighbors plus one;
        nodes with no incoming links are assigned depth zero, while
        nodes with no outgoing links are assigned the maximum depth.
        ``height`` is the same measure computed along incoming links.
        """
        self._assignLevels(graph['nodes'], 'depth', 'sourceLinks', 'target')
        depth = self._assignLevels(graph['nodes'], 'height', 'targetLinks', 'source')

        x0, _, x1, _ = self.p.bounds
        dx = self.p.node_width
        # With a single level (nodes but no links) there is no horizontal
        # spacing to compute; avoid dividing by zero.
        kx = (x1 - x0 - dx) / (depth - 1) if depth > 1 else 0
        for node in graph['nodes']:
            # Nodes without outgoing links are placed in the last column.
            d = node['depth'] if node['sourceLinks'] else depth - 1
            node['x0'] = x0 + max([0, min([depth-1, np.floor(d)]) * kx])
            node['x1'] = node['x0'] + dx

    def computeNodeBreadths(self, graph):
        """
        Assign the vertical extent (``y0``, ``y1``) of every node and the
        ``width`` of every link.

        Nodes are grouped into columns by ``x0``, stacked using a common
        value-to-pixel scale, and then iteratively relaxed towards the
        weighted centers of their neighbors, resolving overlaps after
        every pass.
        """
        # NOTE(review): columns are visited in first-seen order of x0,
        # not sorted by x0 -- confirm this matches the intended
        # left-to-right / right-to-left sweeps.
        node_map = OrderedDict()
        depths = Counter()
        for n in graph['nodes']:
            if n['x0'] not in node_map:
                node_map[n['x0']] = []
            node_map[n['x0']].append(n)
            depths[n['depth']] += 1

        _, y0, _, y1 = self.p.bounds
        py = self.p.node_padding
        if py is None:
            # Default padding: share 10% of the height between the gaps of
            # the most populated depth, capped at 20.
            max_depth = max(depths.values()) - 1 if depths else 1
            height = self.p.bounds[3] - self.p.bounds[1]
            py = min((height * 0.1) / max_depth, 20) if max_depth else 20

        columns = list(node_map.values())

        def initializeNodeBreadth():
            # The scale is the tightest one over all columns. Columns whose
            # values sum to zero place no constraint on it and are skipped,
            # which also avoids an infinite or NaN scale.
            kys = []
            for nodes in columns:
                nsum = np.sum([node['value'] for node in nodes])
                if nsum:
                    kys.append((y1 - y0 - (len(nodes)-1) * py) / nsum)
            ky = np.min(kys) if kys else 0

            for nodes in columns:
                for i, node in enumerate(nodes):
                    node['y0'] = i
                    node['y1'] = i + node['value'] * ky

            for link in graph['links']:
                link['width'] = link['value'] * ky

        def relax(ordered_columns, links_key, weigh, alpha):
            # Shift each node by ``alpha`` times the distance to the
            # value-weighted mean center of its neighbors.
            for nodes in ordered_columns:
                for node in nodes:
                    links = node[links_key]
                    total = sum([l['value'] for l in links])
                    if not total:
                        # No links, or only zero-valued ones: no target.
                        continue
                    weighted = sum([weigh(l) for l in links])
                    dy = (weighted/total - self.nodeCenter(node))*alpha
                    node['y0'] += dy
                    node['y1'] += dy

        def relaxLeftToRight(alpha):
            relax(columns, 'targetLinks', self.weightedSource, alpha)

        def relaxRightToLeft(alpha):
            relax(columns[::-1], 'sourceLinks', self.weightedTarget, alpha)

        def resolveCollisions():
            for nodes in columns:
                if self.p.node_sort:
                    nodes.sort(key=cmp_to_key(self.ascendingBreadth))

                # Push any overlapping nodes down.
                y = y0
                for node in nodes:
                    dy = y-node['y0']
                    if dy > 0:
                        node['y0'] += dy
                        node['y1'] += dy
                    y = node['y1'] + py

                # If the last node goes outside the bounds, push it back
                # up and then push any overlapping nodes back up as well.
                dy = y-py-y1
                if dy > 0:
                    last = nodes[-1]
                    last['y0'] -= dy
                    last['y1'] -= dy
                    y = last['y0']
                    for node in nodes[:-1][::-1]:
                        dy = node['y1'] + py - y
                        if dy > 0:
                            node['y0'] -= dy
                            node['y1'] -= dy
                        y = node['y0']

        initializeNodeBreadth()
        resolveCollisions()
        alpha = 1
        for _ in range(self.p.iterations):
            alpha = alpha*0.99
            relaxRightToLeft(alpha)
            resolveCollisions()
            relaxLeftToRight(alpha)
            resolveCollisions()

    @classmethod
    def computeLinkBreadths(cls, graph):
        """
        Assign the vertical offset of every link on both of its nodes.

        Outgoing and incoming links of each node are sorted by the
        breadth of the node at their other end and then stacked from the
        node's ``y0``; ``link['y0']`` is the offset on the source node and
        ``link['y1']`` the offset on the target node.
        """
        for node in graph['nodes']:
            node['sourceLinks'].sort(key=cmp_to_key(cls.ascendingTargetBreadth))
            node['targetLinks'].sort(key=cmp_to_key(cls.ascendingSourceBreadth))

        for node in graph['nodes']:
            y0 = y1 = node['y0']
            for link in node['sourceLinks']:
                link['y0'] = y0
                y0 += link['width']
            for link in node['targetLinks']:
                link['y1'] = y1
                y1 += link['width']
class Sankey(Graph):
    """
    Sankey is an acyclic, directed Graph type that represents the flow
    of some quantity between its nodes.

    The data may be the edges alone or a tuple of up to three items
    ``(edges, nodes, edgepaths)``. Unless a precomputed layout is
    supplied (the ``sankey`` keyword together with Nodes and EdgePaths
    objects), the node positions and edge paths are computed on
    construction.
    """

    group = param.String(default='Sankey', constant=True)

    vdims = param.List(default=[Dimension('Value')])

    def __init__(self, data, kdims=None, vdims=None, **params):
        if data is None:
            data = []
        if isinstance(data, tuple):
            # Pad shorter tuples so that missing items unpack as None.
            data = data + (None,)*(3-len(data))
            edges, nodes, edgepaths = data
        else:
            edges, nodes, edgepaths = data, None, None

        # A layout is only reused when the layout graph, the positioned
        # nodes and the edge paths are all supplied.
        sankey_graph = params.pop('sankey', None)
        compute = not (sankey_graph and isinstance(nodes, Nodes) and isinstance(edgepaths, EdgePaths))

        # Deliberately starts the lookup after Graph in the MRO, so
        # Graph.__init__ itself is not run here.
        super(Graph, self).__init__(edges, kdims=kdims, vdims=vdims, **params)

        if compute:
            if nodes is None:
                # Derive the nodes from the unique edge sources and targets.
                src = self.dimension_values(0, expanded=False)
                tgt = self.dimension_values(1, expanded=False)
                values = unique_array(np.concatenate([src, tgt]))
                nodes = Dataset(values, 'index')
            elif not isinstance(nodes, Dataset):
                try:
                    nodes = Dataset(nodes)
                except Exception:
                    # Retry with an explicit 'index' key dimension.
                    nodes = Dataset(nodes, 'index')
            if not nodes.kdims:
                raise ValueError('Could not determine index in supplied node data. '
                                 'Ensure data has at least one key dimension, '
                                 'which matches the node ids on the edges.')
            # The layout reads the nodes from the element, so they have to
            # be set before it runs and are replaced by the result after.
            self._nodes = nodes
            nodes, edgepaths, graph = _layout_sankey.instance().layout(self)
            self._nodes = nodes
            self._edgepaths = edgepaths
            self._sankey = graph
        else:
            if not isinstance(nodes, self.node_type):
                raise TypeError("Expected Nodes object in data, found %s."
                                % type(nodes))
            self._nodes = nodes
            if not isinstance(edgepaths, self.edge_type):
                raise TypeError("Expected EdgePaths object in data, found %s."
                                % type(edgepaths))
            self._edgepaths = edgepaths
            self._sankey = sankey_graph
        self._validate()

    def clone(self, data=None, shared_data=True, new_type=None, link=True,
              *args, **overrides):
        """
        Clone the element, forwarding all arguments to the parent
        implementation. When no new data is given the existing layout
        graph is passed along as the ``sankey`` override.
        """
        if data is None:
            overrides['sankey'] = self._sankey
        return super(Sankey, self).clone(data, shared_data, new_type, link,
                                         *args, **overrides)