```
# -*- coding: utf-8 -*-
# Copyright (C) 2012 Yahoo! Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import collections
import os
import networkx as nx
from networkx.drawing import nx_pydot
import six
def _common_format(g, edge_notation):
    """Build the lines shared by the graph ``pformat()`` implementations.

    :param g: the graph to describe; it is iterated with the networkx 1.x
              ``nodes_iter(data=True)`` / ``edges_iter(data=True)`` APIs.
    :param edge_notation: text placed between the two endpoints of every
                          edge line (for example ``"->"`` or ``"<->"``).
    :returns: a list of strings (one per output line) describing the graph
              name, type, frozen state, density, nodes and edges.
    """
    lines = []
    # Single values are wrapped in a 1-tuple before %-formatting so that a
    # tuple-valued name/node is rendered as-is instead of being unpacked as
    # the format arguments (which raises TypeError or drops information).
    lines.append("Name: %s" % (g.name,))
    lines.append("Type: %s" % type(g).__name__)
    lines.append("Frozen: %s" % nx.is_frozen(g))
    lines.append("Density: %0.3f" % nx.density(g))
    lines.append("Nodes: %s" % g.number_of_nodes())
    for n, n_data in g.nodes_iter(data=True):
        if n_data:
            lines.append(" - %s (%s)" % (n, n_data))
        else:
            lines.append(" - %s" % (n,))
    lines.append("Edges: %s" % g.number_of_edges())
    for (u, v, e_data) in g.edges_iter(data=True):
        if e_data:
            lines.append(" %s %s %s (%s)" % (u, edge_notation, v, e_data))
        else:
            lines.append(" %s %s %s" % (u, edge_notation, v))
    return lines
class Graph(nx.Graph):
    """Undirected networkx graph extended with a few convenience helpers."""

    def __init__(self, data=None, name=''):
        super(Graph, self).__init__(data=data, name=name)
        # Flag consulted by freeze() to avoid freezing the graph twice.
        self.frozen = False

    def freeze(self):
        """Prevent any further mutation of this graph.

        Calling this on an already frozen graph does nothing.

        :returns: this same graph object (so calls can be chained).
        """
        if self.frozen:
            return self
        nx.freeze(self)
        return self

    def export_to_dot(self):
        """Render this graph in DOT format (the pydot library is required)."""
        dot_graph = nx_pydot.to_pydot(self)
        return dot_graph.to_string()

    def pformat(self):
        """Return a multi-line, human-readable description of this graph."""
        summary_lines = _common_format(self, "<->")
        return os.linesep.join(summary_lines)
class DiGraph(nx.DiGraph):
    """A directed graph subclass with useful utility functions."""

    def __init__(self, data=None, name=''):
        super(DiGraph, self).__init__(name=name, data=data)
        # Flag consulted by freeze() to avoid freezing the graph twice.
        self.frozen = False

    def freeze(self):
        """Freezes the graph so that no more mutations can occur.

        :returns: this same graph object (so calls can be chained).
        """
        if not self.frozen:
            nx.freeze(self)
        return self

    def get_edge_data(self, u, v, default=None):
        """Returns a *copy* of the edge attribute dictionary between (u, v).

        NOTE(harlowja): this differs from the networkx get_edge_data() as that
        function does not return a copy (but returns a reference to the actual
        edge data).

        :returns: a new ``dict`` holding the ``u -> v`` edge attributes, or
                  ``default`` when that lookup raises ``KeyError`` (no such
                  edge, or ``u`` is not a node).
        """
        try:
            return dict(self.adj[u][v])
        except KeyError:
            return default

    def topological_sort(self):
        """Return a list of nodes in this graph in topological sort order."""
        return nx.topological_sort(self)

    def pformat(self):
        """Pretty formats your graph into a string.

        This pretty formatted string representation includes many useful
        details about your graph, including: name, type, frozenness, node
        count, nodes, edge count, edges, graph density and graph cycles (if
        any).
        """
        lines = _common_format(self, "->")
        cycles = list(nx.cycles.recursive_simple_cycles(self))
        lines.append("Cycles: %s" % len(cycles))
        for cycle in cycles:
            # Close the loop by repeating the first node at the end. Each
            # node is wrapped in a 1-tuple so tuple-valued nodes are rendered
            # as-is rather than being unpacked as %-format arguments.
            path = list(cycle) + [cycle[0]]
            lines.append(" %s" % " --> ".join("%s" % (node,) for node in path))
        return os.linesep.join(lines)

    def export_to_dot(self):
        """Exports the graph to a dot format (requires pydot library)."""
        return nx_pydot.to_pydot(self).to_string()

    def is_directed_acyclic(self):
        """Returns if this graph is a DAG or not."""
        return nx.is_directed_acyclic_graph(self)

    def no_successors_iter(self):
        """Returns an iterator for all nodes with no successors."""
        for n in self.nodes_iter():
            # Test the successor adjacency mapping directly instead of
            # materializing a list of successors only to take its length.
            if not self.succ[n]:
                yield n

    def no_predecessors_iter(self):
        """Returns an iterator for all nodes with no predecessors."""
        for n in self.nodes_iter():
            # Test the predecessor adjacency mapping directly instead of
            # materializing a list of predecessors only to take its length.
            if not self.pred[n]:
                yield n

    def bfs_predecessors_iter(self, n):
        """Iterates breadth first over *all* predecessors of a given node.

        This will go through the nodes predecessors, then the predecessor nodes
        predecessors and so on until no more predecessors are found.

        NOTE(harlowja): predecessor cycles (if they exist) will not be iterated
        over more than once (this prevents infinite iteration).
        """
        # The start node counts as visited so it is never yielded, even if it
        # is reachable through a predecessor cycle.
        visited = set([n])
        queue = collections.deque(self.predecessors_iter(n))
        while queue:
            pred = queue.popleft()
            # A node may be enqueued more than once (via different
            # successors); the visited check keeps it from being yielded twice.
            if pred not in visited:
                yield pred
                visited.add(pred)
                for pred_pred in self.predecessors_iter(pred):
                    if pred_pred not in visited:
                        queue.append(pred_pred)
class OrderedDiGraph(DiGraph):
    """A directed graph subclass with useful utility functions.

    This variant keeps nodes, adjacency entries and edge attributes in
    insertion order, so iterating over the graph yields items in the same
    order they were added.
    """

    # Every networkx storage factory is replaced with an insertion-ordered
    # mapping type.
    edge_attr_dict_factory = collections.OrderedDict
    adjlist_dict_factory = collections.OrderedDict
    node_dict_factory = collections.OrderedDict
class OrderedGraph(Graph):
    """A graph subclass with useful utility functions.

    This variant keeps nodes, adjacency entries and edge attributes in
    insertion order, so iterating over the graph yields items in the same
    order they were added.
    """

    # Every networkx storage factory is replaced with an insertion-ordered
    # mapping type.
    edge_attr_dict_factory = collections.OrderedDict
    adjlist_dict_factory = collections.OrderedDict
    node_dict_factory = collections.OrderedDict
def merge_graphs(graph, *graphs, **kwargs):
    """Merges a bunch of graphs into a new graph.

    If no additional graphs are provided the first graph is
    returned unmodified otherwise the merged graph is returned.

    :param graph: the graph to merge into; its ``name`` is kept on the result.
    :param graphs: additional graphs composed, in order, into ``graph``.
    :keyword allow_overlaps: when false (the default) each graph to merge is
        checked against the graph built so far and a ``ValueError`` is raised
        if the overlap detector reports any overlap.
    :keyword overlap_detector: optional callable
        ``(to_graph, from_graph) -> overlaps`` whose truthy result means the
        merge is refused; by default it counts the nodes of ``from_graph``
        that are already present in ``to_graph``.
    :raises ValueError: if ``overlap_detector`` is given but not callable, or
        if overlaps are detected while ``allow_overlaps`` is false.
    """
    first_graph = graph
    allow_overlaps = kwargs.get('allow_overlaps', False)
    overlap_detector = kwargs.get('overlap_detector')
    if overlap_detector is not None and not six.callable(overlap_detector):
        raise ValueError("Overlap detection callback expected to be callable")
    elif overlap_detector is None:
        # Count the shared nodes directly; this gives the same number as the
        # size of the induced subgraph without building (and copying edges
        # into) a throwaway subgraph just to measure it.
        overlap_detector = (lambda to_graph, from_graph:
                            sum(1 for node in from_graph
                                if node in to_graph))
    for g in graphs:
        # This should ensure that the nodes to be merged do not already exist
        # in the graph that is to be merged into (duplicate nodes would be
        # silently combined by compose, which is not supported).
        if not allow_overlaps:
            overlaps = overlap_detector(graph, g)
            if overlaps:
                raise ValueError("Can not merge graph %s into %s since there "
                                 "are %s overlapping nodes (and we do not "
                                 "support merging nodes)" % (g, graph,
                                                             overlaps))
        graph = nx.algorithms.compose(graph, g)
    # Keep the first graphs name (compose does not preserve it as-is).
    if graphs:
        graph.name = first_graph.name
    return graph
```