# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Pipeline that runs arbitrary pipelines composed in a graph.

Some terminology used in the code.

dag: Directed acyclic graph.
unit: A Pipeline which is run inside DAGPipeline.
connection: A key value pair in the DAG dictionary.
dependency: The right hand side (value in key value dictionary pair) of a DAG
    connection. Can be a Pipeline, Input, Key, or dictionary mapping names to
    one of those.
subordinate: Any Input, Pipeline, or Key object that appears in a dependency.
shorthand: Invalid things that can be put in the DAG which get converted to
    valid things before parsing. These things are for convenience.
type signature: Something that can be returned from Pipeline's `output_type`
    or `input_type`. A python class, or dictionary mapping names to classes.
"""

import itertools

# internal imports
from magenta.pipelines import pipeline


class Output(object):
  """Represents an output destination for a `DAGPipeline`.

  Each `Output(name)` instance given to DAGPipeline will be a final output
  bucket with the same name. If writing output buckets to disk, the names
  become dataset names.

  The name can be omitted if connecting `Output()` to a dictionary mapping
  names to pipelines.
  """

  def __init__(self, name=None):
    """Create an `Output` with the given name.

    Args:
      name: If given, a string name which defines the name of this output.
          If not given, the names in the dictionary this is connected to
          will be used as output names.
    """
    self.name = name
    # `output_type` and `input_type` are set by DAGPipeline. Since `Output` is
    # not given its type, the type must be inferred from what it is connected
    # to in the DAG. Having `output_type` and `input_type` makes `Output` act
    # like a `Pipeline` in some cases.
    self.output_type = None
    self.input_type = None

  def __eq__(self, other):
    # Two Output instances are interchangeable iff they share a name; this is
    # what lets `Output(name)` keys created by shorthand expansion match.
    return isinstance(other, Output) and other.name == self.name

  def __hash__(self):
    # Hash must agree with __eq__, which only considers the name.
    return hash(self.name)

  def __repr__(self):
    return 'Output(%s)' % self.name


class Input(object):
  """Represents an input source for a `DAGPipeline`.

  Give an `Input` instance to `DAGPipeline` by connecting `Pipeline` objects
  to it in the DAG.

  When `DAGPipeline.transform` is called, the input object will be fed to any
  Pipeline instances connected to an `Input` given in the DAG.

  The type given to `Input` will be the `DAGPipeline`'s `input_type`.
  """

  def __init__(self, type_):
    """Create an `Input` with the given type.

    Args:
      type_: The Python class which inputs to `DAGPipeline` should be
          instances of. `DAGPipeline.input_type` will be this type.
    """
    self.output_type = type_

  def __eq__(self, other):
    # Inputs with the same type are treated as the same DAG vertex.
    return isinstance(other, Input) and other.output_type == self.output_type

  def __hash__(self):
    return hash(self.output_type)

  def __repr__(self):
    return 'Input(%s)' % self.output_type


def _all_are_type(elements, target_type):
  """Checks that all the given elements are the target type.

  Args:
    elements: A list of objects.
    target_type: The Python class which all elements need to be an instance
        of.

  Returns:
    True if every object in `elements` is an instance of `target_type`, and
    False otherwise.
  """
  return all(isinstance(elem, target_type) for elem in elements)


class InvalidDAGException(Exception):
  """Thrown when the DAG dictionary is not well formatted.

  This can be because a `destination: dependency` pair is not in the form
  `Pipeline: Pipeline` or `Pipeline: {'name_1': Pipeline, ...}`
  (Note that Pipeline or Key objects both are allowed in the dependency).

  It is also thrown when `Input` is given as a destination, or `Output` is
  given as a dependency.
  """
  pass


class DuplicateNameException(Exception):
  """Thrown when two `Pipeline` instances in the DAG have the same name.

  Pipeline names will be used as name spaces for the statistics they produce
  and we don't want any conflicts.
  """
  pass


class BadTopologyException(Exception):
  """Thrown when there is a directed cycle."""
  pass


class NotConnectedException(Exception):
  """Thrown when the DAG is disconnected somewhere.

  Either because a `Pipeline` used in a dependency has nothing feeding into
  it, or because a `Pipeline` given as a destination does not feed anywhere.
  """
  pass


class TypeMismatchException(Exception):
  """Thrown when type signatures in a connection don't match.

  In the DAG's `destination: dependency` pairs, type signatures must match.
  """
  pass


class BadInputOrOutputException(Exception):
  """Thrown when `Input` or `Output` are not used in the graph correctly.

  Specifically when there are no `Input` objects, more than one `Input` with
  different types, or there is no `Output` object.
  """
  pass


class InvalidDictionaryOutput(Exception):
  """Thrown when `Output` and dictionaries are not used correctly.

  Specifically when `Output()` is used without a dictionary dependency, or
  `Output(name)` is used with a `name` and with a dictionary dependency.
  """
  pass


class InvalidTransformOutputException(Exception):
  """Thrown when a Pipeline does not output types matching its `output_type`.
  """
  pass


class DAGPipeline(pipeline.Pipeline):
  """A directed acyclic graph pipeline.

  This Pipeline can be given an arbitrary graph composed of Pipeline instances
  and will run all of those pipelines feeding outputs to inputs. See README.md
  for details.
Use DAGPipeline to compose multiple smaller pipelines together. """ def __init__(self, dag, pipeline_name='DAGPipeline'): """Constructs a DAGPipeline. A DAG (direct acyclic graph) is given which fully specifies what the DAGPipeline runs. Args: dag: A dictionary mapping `Pipeline` or `Output` instances to any of `Pipeline`, `Key`, `Input`. `dag` defines a directed acyclic graph. pipeline_name: String name of this Pipeline object. Raises: InvalidDAGException: If each key value pair in the `dag` dictionary is not of the form (Pipeline or Output): (Pipeline, Key, or Input). TypeMismatchException: The type signature of each key and value in `dag` must match, otherwise this will be thrown. DuplicateNameException: If two `Pipeline` instances in `dag` have the same string name. BadInputOrOutputException: If there are no `Output` instaces in `dag` or not exactly one `Input` plus type combination in `dag`. InvalidDictionaryOutput: If `Output()` is not connected to a dictionary, or `Output(name)` is not connected to a Pipeline, Key, or Input instance. NotConnectedException: If a `Pipeline` used in a dependency has nothing feeding into it, or a `Pipeline` used as a destination does not feed anywhere. BadTopologyException: If there there is a directed cycle in `dag`. Exception: Misc. exceptions. """ # Expand DAG shorthand. self.dag = dict(self._expand_dag_shorthands(dag)) # Make sure DAG is valid. # Input types match output types. Nothing depends on outputs. # Things that require input get input. DAG is composed of correct types. for unit, dependency in self.dag.items(): if not isinstance(unit, (pipeline.Pipeline, Output)): raise InvalidDAGException( 'Dependency {%s: %s} is invalid. Left hand side value %s must ' 'either be a Pipeline or Output object' % (unit, dependency, unit)) if isinstance(dependency, dict): if not all([isinstance(name, basestring) for name in dependency]): raise InvalidDAGException( 'Dependency {%s: %s} is invalid. 
Right hand side keys %s must be ' 'strings' % (unit, dependency, dependency.keys())) values = dependency.values() else: values = [dependency] for v in values: if not (isinstance(v, pipeline.Pipeline) or (isinstance(v, pipeline.Key) and isinstance(v.unit, pipeline.Pipeline)) or isinstance(v, Input)): raise InvalidDAGException( 'Dependency {%s: %s} is invalid. Right hand side value %s must ' 'be either a Pipeline, Key, or Input object' % (unit, dependency, v)) # Check that all input types match output types. if isinstance(unit, Output): # Output objects don't know their types. continue if unit.input_type != self._get_type_signature_for_dependency(dependency): raise TypeMismatchException( 'Invalid dependency {%s: %s}. Required `input_type` of left hand ' 'side is %s. Output type of right hand side is %s.' % (unit, dependency, unit.input_type, self._get_type_signature_for_dependency(dependency))) # Make sure all Pipeline names are unique, so that Statistic objects don't # clash. sorted_unit_names = sorted( [(unit, unit.name) for unit in self.dag], key=lambda t: t[1]) for index, (unit, name) in enumerate(sorted_unit_names[:-1]): if name == sorted_unit_names[index + 1][1]: other_unit = sorted_unit_names[index + 1][0] raise DuplicateNameException( 'Pipelines %s and %s both have name "%s". Each Pipeline must have ' 'a unique name.' % (unit, other_unit, name)) # Find Input and Output objects and make sure they are being used correctly. self.outputs = [unit for unit in self.dag if isinstance(unit, Output)] self.output_names = dict([(output.name, output) for output in self.outputs]) for output in self.outputs: output.input_type = output.output_type = ( self._get_type_signature_for_dependency(self.dag[output])) inputs = set() for deps in self.dag.values(): units = self._get_units(deps) for unit in units: if isinstance(unit, Input): inputs.add(unit) if len(inputs) != 1: if not inputs: raise BadInputOrOutputException( 'No Input object found. 
Input is the start of the pipeline.') else: raise BadInputOrOutputException( 'Multiple Input objects found. Only one input is supported.') if not self.outputs: raise BadInputOrOutputException( 'No Output objects found. Output is the end of the pipeline.') self.input = inputs.pop() # Compute output_type for self and call super constructor. output_signature = dict([(output.name, output.output_type) for output in self.outputs]) super(DAGPipeline, self).__init__( input_type=self.input.output_type, output_type=output_signature, name=pipeline_name) # Make sure all Pipeline objects have DAG vertices that feed into them, # and feed their output into other DAG vertices. all_subordinates = ( set([dep_unit for unit in self.dag for dep_unit in self._get_units(self.dag[unit])]) .difference(set([self.input]))) all_destinations = set(self.dag.keys()).difference(set(self.outputs)) if all_subordinates != all_destinations: units_with_no_input = all_subordinates.difference(all_destinations) units_with_no_output = all_destinations.difference(all_subordinates) if units_with_no_input: raise NotConnectedException( '%s is given as a dependency in the DAG but has nothing connected ' 'to it. Nothing in the DAG feeds into it.' % units_with_no_input.pop()) else: raise NotConnectedException( '%s is given as a destination in the DAG but does not output ' 'anywhere. It is a deadend.' % units_with_no_output.pop()) # Construct topological ordering to determine the execution order of the # pipelines. # https://en.wikipedia.org/wiki/Topological_sorting#Kahn.27s_algorithm # `graph` maps a pipeline to the pipelines it depends on. Each dict value # is a list with the dependency pipelines in the 0th position, and a count # of forward connections to the key pipeline (how many pipelines use this # pipeline as a dependency). 
graph = dict([(unit, [self._get_units(self.dag[unit]), 0]) for unit in self.dag]) graph[self.input] = [[], 0] for unit, (forward_connections, _) in graph.items(): for to_unit in forward_connections: graph[to_unit][1] += 1 self.call_list = call_list = [] # Topologically sorted elements go here. nodes = set(self.outputs) while nodes: n = nodes.pop() call_list.append(n) for m in graph[n][0]: graph[m][1] -= 1 if graph[m][1] == 0: nodes.add(m) elif graph[m][1] < 0: raise Exception( 'Congradulations, you found a bug! Please report this issue at ' 'https://github.com/tensorflow/magenta/issues and copy/paste the ' 'following: dag=%s, graph=%s, call_list=%s' % (self.dag, graph, call_list)) # Check for cycles by checking if any edges remain. for unit in graph: if graph[unit][1] != 0: raise BadTopologyException('Dependency loop found on %s' % unit) # Note: this exception should never be raised. Disconnected graphs will be # caught where NotConnectedException is raised. If this exception goes off # there is likely a bug. if set(call_list) != set( list(all_subordinates) + self.outputs + [self.input]): raise BadTopologyException('Not all pipelines feed into an output or ' 'there is a dependency loop.') call_list.reverse() assert call_list[0] == self.input def _expand_dag_shorthands(self, dag): """Expand DAG shorthand. Currently the only shorthand is "direct connection". A direct connection is a connection {a: b} where a.input_type is a dict, b.output_type is a dict, and a.input_type == b.output_type. This is not actually valid, but we can convert it to a valid connection. {a: b} is expanded to {a: {"name_1": b["name_1"], "name_2": b["name_2"], ...}}. {Output(): {"name_1", obj1, "name_2": obj2, ...} is expanded to {Output("name_1"): obj1, Output("name_2"): obj2, ...}. Args: dag: A dictionary encoding the DAG. Yields: Key, value pairs for a new dag dictionary. Raises: InvalidDictionaryOutput: If `Output` is not used correctly. """ for key, val in dag.items(): # Direct connection. 
if (isinstance(key, pipeline.Pipeline) and isinstance(val, pipeline.Pipeline) and isinstance(key.input_type, dict) and key.input_type == val.output_type): yield key, dict([(name, val[name]) for name in val.output_type]) elif key == Output(): if (isinstance(val, pipeline.Pipeline) and isinstance(val.output_type, dict)): dependency = [(name, val[name]) for name in val.output_type] elif isinstance(val, dict): dependency = val.items() else: raise InvalidDictionaryOutput( 'Output() with no name can only be connected to a dictionary or ' 'a Pipeline whose output_type is a dictionary. Found Output() ' 'connected to %s' % val) for name, subordinate in dependency: yield Output(name), subordinate elif isinstance(key, Output): if isinstance(val, dict): raise InvalidDictionaryOutput( 'Output("%s") which has name "%s" can only be connectd to a ' 'single input, not dictionary %s. Use Output() without name ' 'instead.' % (key.name, key.name, val)) if (isinstance(val, pipeline.Pipeline) and isinstance(val.output_type, dict)): raise InvalidDictionaryOutput( 'Output("%s") which has name "%s" can only be connectd to a ' 'single input, not pipeline %s which has dictionary ' 'output_type %s. Use Output() without name instead.' 
% (key.name, key.name, val, val.output_type)) yield key, val else: yield key, val def _get_units(self, dependency): """Gets list of units from a dependency.""" dep_list = [] if isinstance(dependency, dict): dep_list.extend(dependency.values()) else: dep_list.append(dependency) return [self._validate_subordinate(sub) for sub in dep_list] def _validate_subordinate(self, subordinate): """Verifies that subordinate is Input, Key, or Pipeline.""" if isinstance(subordinate, pipeline.Pipeline): return subordinate if isinstance(subordinate, pipeline.Key): if not isinstance(subordinate.unit, pipeline.Pipeline): raise InvalidDAGException( 'Key object %s does not have a valid Pipeline' % subordinate) return subordinate.unit if isinstance(subordinate, Input): return subordinate raise InvalidDAGException( 'Looking for Pipeline, Key, or Input object, but got %s' % type(subordinate)) def _get_type_signature_for_dependency(self, dependency): """Gets the type signature of the dependency output.""" if isinstance(dependency, (pipeline.Pipeline, pipeline.Key, Input)): return dependency.output_type return dict([(name, sub_dep.output_type) for name, sub_dep in dependency.items()]) def transform(self, input_object): """Runs the DAG on the given input. All pipelines in the DAG will run. Args: input_object: Any object. The required type depends on implementation. Returns: A dictionary mapping output names to lists of objects. The object types depend on implementation. Each output name corresponds to an output collection. See get_output_names method. """ def stats_accumulator(unit, unit_inputs, cumulative_stats): for single_input in unit_inputs: results_ = unit.transform(single_input) stats = unit.get_stats() cumulative_stats.extend(stats) yield results_ stats = [] results = {self.input: [input_object]} for unit in self.call_list[1:]: # Compute transformation. 
if isinstance(unit, Output): unit_outputs = self._get_outputs_as_signature(self.dag[unit], results) else: unit_inputs = self._get_inputs_for_unit(unit, results) if not unit_inputs: # If this unit has no inputs don't run it. results[unit] = [] continue unjoined_outputs = list( stats_accumulator(unit, unit_inputs, stats)) unit_outputs = self._join_lists_or_dicts(unjoined_outputs, unit) results[unit] = unit_outputs self._set_stats(stats) return dict([(output.name, results[output]) for output in self.outputs]) def _get_outputs_as_signature(self, dependency, outputs): """Returns a list or dict which matches the type signature of dependency. Args: dependency: Input, Key, Pipeline instance, or dictionary mapping names to those values. outputs: A database of computed unit outputs. A dictionary mapping Pipeline to list of objects. Returns: A list or dictionary of computed unit outputs which matches the type signature of the given dependency. """ def _get_outputs_for_key(unit_or_key, outputs): if isinstance(unit_or_key, pipeline.Key): if not outputs[unit_or_key.unit]: # If there are no outputs, just return nothing. return outputs[unit_or_key.unit] assert isinstance(outputs[unit_or_key.unit], dict) return outputs[unit_or_key.unit][unit_or_key.key] assert isinstance(unit_or_key, (pipeline.Pipeline, Input)) return outputs[unit_or_key] if isinstance(dependency, dict): return dict([(name, _get_outputs_for_key(unit_or_key, outputs)) for name, unit_or_key in dependency.items()]) return _get_outputs_for_key(dependency, outputs) def _get_inputs_for_unit(self, unit, results, list_operation=itertools.product): """Creates valid inputs for the given unit from the outputs in `results`. Args: unit: The `Pipeline` to create inputs for. results: A database of computed unit outputs. A dictionary mapping Pipeline to list of objects. list_operation: A function that maps lists of inputs to a single list of tuples, where each tuple is an input. 
This is used when `unit` takes a dictionary as input. Each tuple is used as the values for a dictionary input. This can be thought of as taking a sort of transpose of a ragged 2D array. The default is `itertools.product` which takes the cartesian product of the input lists. Returns: If `unit` takes a single input, a list of objects. If `unit` takes a dictionary input, a list of dictionaries each mapping string name to object. """ previous_outputs = self._get_outputs_as_signature(self.dag[unit], results) if isinstance(previous_outputs, dict): names = list(previous_outputs.keys()) lists = [previous_outputs[name] for name in names] stack = list_operation(*lists) return [dict(zip(names, values)) for values in stack] else: return previous_outputs def _join_lists_or_dicts(self, outputs, unit): """Joins many lists or dicts of outputs into a single list or dict. This function also validates that the outputs are correct for the given Pipeline. If `outputs` is a list of lists, the lists are concated and the type of each object must match `unit.output_type`. If `output` is a list of dicts (mapping string names to lists), each key has its lists concated across all the dicts. The keys and types are validated against `unit.output_type`. Args: outputs: A list of lists, or list of dicts which map string names to lists. unit: A Pipeline which every output in `outputs` will be validated against. `unit` must produce the outputs it says it will produce. Returns: If `outputs` is a list of lists, a single list of outputs. If `outputs` is a list of dicts, a single dictionary mapping string names to lists of outputs. Raises: InvalidTransformOutputException: If anything in `outputs` does not match the type signature given by `unit.output_type`. 
""" if not outputs: return [] if isinstance(unit.output_type, dict): concated = dict([(key, list()) for key in unit.output_type.keys()]) for d in outputs: if not isinstance(d, dict): raise InvalidTransformOutputException( 'Expected dictionary output for %s with output type %s but ' 'instead got type %s' % (unit, unit.output_type, type(d))) if set(d.keys()) != set(unit.output_type.keys()): raise InvalidTransformOutputException( 'Got dictionary output with incorrect keys for %s. Got %s. ' 'Expected %s' % (unit, d.keys(), unit.output_type.keys())) for k, val in d.items(): if not isinstance(val, list): raise InvalidTransformOutputException( 'Output from %s for key %s is not a list.' % (unit, k)) if not _all_are_type(val, unit.output_type[k]): raise InvalidTransformOutputException( 'Some outputs from %s for key %s are not of expected type %s. ' 'Got types %s' % (unit, k, unit.output_type[k], [type(inst) for inst in val])) concated[k] += val else: concated = [] for l in outputs: if not isinstance(l, list): raise InvalidTransformOutputException( 'Expected list output for %s with outpu type %s but instead got ' 'type %s' % (unit, unit.output_type, type(l))) if not _all_are_type(l, unit.output_type): raise InvalidTransformOutputException( 'Some outputs from %s are not of expected type %s. Got types %s' % (unit, unit.output_type, [type(inst) for inst in l])) concated += l return concated