# Source code for taskflow.types.graph

# -*- coding: utf-8 -*-

#    Copyright (C) 2012 Yahoo! Inc. All Rights Reserved.
#
#    Licensed under the Apache License, Version 2.0 (the "License"); you may
#    not use this file except in compliance with the License. You may obtain
#    a copy of the License at
#
#         http://www.apache.org/licenses/LICENSE-2.0
#
#    Unless required by applicable law or agreed to in writing, software
#    distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
#    WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
#    License for the specific language governing permissions and limitations
#    under the License.

import collections
import os

import networkx as nx
from networkx.drawing import nx_pydot
import six


def _common_format(g, edge_notation):
    """Builds the report lines shared by the graph pretty formatters.

    :param g: graph to describe; its ``name``, node/edge counts, nodes and
              edges (with any attached data) are read
    :param edge_notation: string placed between the two endpoints of every
                          edge line (for example ``"->"``)
    :returns: list of strings, one per report line (not yet joined)
    """
    lines = []
    lines.append("Name: %s" % g.name)
    lines.append("Type: %s" % type(g).__name__)
    lines.append("Frozen: %s" % nx.is_frozen(g))
    lines.append("Density: %0.3f" % nx.density(g))
    lines.append("Nodes: %s" % g.number_of_nodes())
    for n, n_data in g.nodes_iter(data=True):
        if n_data:
            lines.append("  - %s (%s)" % (n, n_data))
        else:
            # Wrap the node in a 1-tuple; a bare tuple node on the right of
            # '%' would be unpacked as the format arguments and either
            # raise TypeError or print only its first element.
            lines.append("  - %s" % (n,))
    lines.append("Edges: %s" % g.number_of_edges())
    for (u, v, e_data) in g.edges_iter(data=True):
        if e_data:
            lines.append("  %s %s %s (%s)" % (u, edge_notation, v, e_data))
        else:
            lines.append("  %s %s %s" % (u, edge_notation, v))
    return lines


class Graph(nx.Graph):
    """An undirected graph with a few extra convenience helpers."""

    def __init__(self, data=None, name=''):
        super(Graph, self).__init__(name=name, data=data)
        self.frozen = False

    def freeze(self):
        """Freeze this graph so it can no longer be mutated.

        Freezing is skipped when the graph is already marked as frozen; the
        graph itself is returned either way.
        """
        if self.frozen:
            return self
        nx.freeze(self)
        return self

    def export_to_dot(self):
        """Render this graph as a dot format string (needs pydot)."""
        dot_graph = nx_pydot.to_pydot(self)
        return dot_graph.to_string()

    def pformat(self):
        """Return a multi-line, human readable description of this graph."""
        report_lines = _common_format(self, "<->")
        return os.linesep.join(report_lines)
class DiGraph(nx.DiGraph):
    """A directed graph subclass with useful utility functions."""

    def __init__(self, data=None, name=''):
        super(DiGraph, self).__init__(name=name, data=data)
        self.frozen = False

    def freeze(self):
        """Freezes the graph so that no more mutations can occur.

        :returns: this graph (whether or not it was already frozen)
        """
        if not self.frozen:
            nx.freeze(self)
        return self

    def get_edge_data(self, u, v, default=None):
        """Returns a *copy* of the edge attribute dictionary between (u, v).

        NOTE(harlowja): this differs from the networkx get_edge_data() as that
        function does not return a copy (but returns a reference to the actual
        edge data).

        :param u: source node of the edge
        :param v: target node of the edge
        :param default: value returned when no (u, v) edge exists
        """
        try:
            return dict(self.adj[u][v])
        except KeyError:
            return default

    def topological_sort(self):
        """Return a list of nodes in this graph in topological sort order."""
        return nx.topological_sort(self)

    def pformat(self):
        """Pretty formats your graph into a string.

        This pretty formatted string representation includes many useful
        details about your graph, including; name, type, frozeness, node count,
        nodes, edge count, edges, graph density and graph cycles (if any).
        """
        lines = _common_format(self, "->")
        cycles = list(nx.cycles.recursive_simple_cycles(self))
        lines.append("Cycles: %s" % len(cycles))
        for cycle in cycles:
            # Repeat the first node at the end so the rendered path visibly
            # closes. Each node is wrapped in a 1-tuple so that a node which
            # is itself a tuple is formatted whole instead of being unpacked
            # as the '%' arguments.
            path = list(cycle) + [cycle[0]]
            rendered = " --> ".join("%s" % (node,) for node in path)
            lines.append("  %s" % rendered)
        return os.linesep.join(lines)

    def export_to_dot(self):
        """Exports the graph to a dot format (requires pydot library)."""
        return nx_pydot.to_pydot(self).to_string()

    def is_directed_acyclic(self):
        """Returns if this graph is a DAG or not."""
        return nx.is_directed_acyclic_graph(self)

    def no_successors_iter(self):
        """Returns an iterator for all nodes with no successors."""
        for n in self.nodes_iter():
            # Test the successor adjacency mapping directly instead of
            # building a throwaway list only to check that it is empty.
            if not self.succ[n]:
                yield n

    def no_predecessors_iter(self):
        """Returns an iterator for all nodes with no predecessors."""
        for n in self.nodes_iter():
            # Same idea as no_successors_iter(), using the predecessor
            # adjacency mapping.
            if not self.pred[n]:
                yield n

    def bfs_predecessors_iter(self, n):
        """Iterates breadth first over *all* predecessors of a given node.

        This will go through the nodes predecessors, then the predecessor nodes
        predecessors and so on until no more predecessors are found.

        NOTE(harlowja): predecessor cycles (if they exist) will not be iterated
        over more than once (this prevents infinite iteration).

        :param n: node whose predecessors are walked; it is never yielded
        """
        visited = set([n])
        queue = collections.deque(self.predecessors_iter(n))
        while queue:
            pred = queue.popleft()
            # A node can be queued more than once (it may be reachable via
            # several successors), so membership is re-checked on removal.
            if pred not in visited:
                yield pred
                visited.add(pred)
                for pred_pred in self.predecessors_iter(pred):
                    if pred_pred not in visited:
                        queue.append(pred_pred)
class OrderedDiGraph(DiGraph):
    """An insertion-ordered variant of the directed graph helper class.

    Node, edge and attribute storage all use ordered dictionaries, so that
    iterating over the graph follows the order in which things were inserted.
    """

    # All three storage factories are swapped for OrderedDict.
    edge_attr_dict_factory = collections.OrderedDict
    adjlist_dict_factory = collections.OrderedDict
    node_dict_factory = collections.OrderedDict
class OrderedGraph(Graph):
    """An insertion-ordered variant of the undirected graph helper class.

    Node, edge and attribute storage all use ordered dictionaries, so that
    iterating over the graph follows the order in which things were inserted.
    """

    # All three storage factories are swapped for OrderedDict.
    edge_attr_dict_factory = collections.OrderedDict
    adjlist_dict_factory = collections.OrderedDict
    node_dict_factory = collections.OrderedDict
def merge_graphs(graph, *graphs, **kwargs):
    """Merges a bunch of graphs into a new graph.

    If no additional graphs are provided the first graph is
    returned unmodified otherwise the merged graph is returned.

    :param graph: first graph; the merged result keeps this graphs name
    :param graphs: additional graphs, composed into the result in order
    :param allow_overlaps: (keyword only) when false, which is the default,
                           overlap detection runs before every merge
    :param overlap_detector: (keyword only) callable invoked as
                             ``overlap_detector(to_graph, from_graph)``; a
                             truthy result aborts the merge. When omitted the
                             size of the subgraph of ``to_graph`` induced by
                             the nodes of ``from_graph`` is used.
    :raises ValueError: if ``overlap_detector`` is provided but not callable
                        or if an overlap is detected while overlaps are
                        not allowed
    """
    tmp_graph = graph
    allow_overlaps = kwargs.get('allow_overlaps', False)
    overlap_detector = kwargs.get('overlap_detector')
    # The builtin callable() exists on python 2 and on python 3.2+, so no
    # compatibility shim is required for this check.
    if overlap_detector is not None and not callable(overlap_detector):
        raise ValueError("Overlap detection callback expected to be callable")
    elif overlap_detector is None:
        overlap_detector = (lambda to_graph, from_graph:
                            len(to_graph.subgraph(from_graph.nodes_iter())))
    for g in graphs:
        # This should ensure that the nodes to be merged do not already exist
        # in the graph that is to be merged into. This could be problematic if
        # there are duplicates.
        if not allow_overlaps:
            # Attempt to induce a subgraph using the to be merged graphs nodes
            # and see if any graph results.
            overlaps = overlap_detector(graph, g)
            if overlaps:
                raise ValueError("Can not merge graph %s into %s since there "
                                 "are %s overlapping nodes (and we do not "
                                 "support merging nodes)" % (g, graph,
                                                             overlaps))
        graph = nx.algorithms.compose(graph, g)
    # Keep the first graphs name.
    if graphs:
        graph.name = tmp_graph.name
    return graph