426 lines
15 KiB
Python
426 lines
15 KiB
Python
|
# Copyright 2016 Google Inc. All Rights Reserved.
|
||
|
#
|
||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
# you may not use this file except in compliance with the License.
|
||
|
# You may obtain a copy of the License at
|
||
|
#
|
||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||
|
#
|
||
|
# Unless required by applicable law or agreed to in writing, software
|
||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
# See the License for the specific language governing permissions and
|
||
|
# limitations under the License.
|
||
|
"""For running data processing pipelines."""
|
||
|
|
||
|
import abc
|
||
|
import inspect
|
||
|
import os.path
|
||
|
|
||
|
# internal imports
|
||
|
import tensorflow as tf
|
||
|
|
||
|
from magenta.pipelines import statistics
|
||
|
|
||
|
|
||
|
class InvalidTypeSignatureException(Exception):
  """Signals an invalid `Pipeline.input_type` or `Pipeline.output_type`.

  A valid type signature is either a single Python class or a dictionary
  mapping string names to Python classes.
  """
|
||
|
|
||
|
|
||
|
class InvalidStatisticsException(Exception):
  """Signals that the stats produced by a `Pipeline` are not valid."""
|
||
|
|
||
|
|
||
|
class Key(object):
  """A reference to one named output of a `Pipeline`.

  Suppose a pipeline instance `my_pipeline` declares the `output_type`
  {'key_1': Type1, 'key_2': Type2}. Then Key(my_pipeline, 'key_1') stands for
  the output whose type is Type1, and Key(my_pipeline, 'key_2') stands for the
  output whose type is Type2.

  Indexing a pipeline produces a Key: my_pipeline['key_1'] is the same as
  Key(my_pipeline, 'key_1').

  Key objects are used for assembling a directed acyclic graph of Pipeline
  instances. See dag_pipeline.py.

  Attributes:
    unit: The `Pipeline` instance the key was taken from.
    key: The name of the selected output.
    output_type: The class registered under `key` in `unit.output_type`.
  """

  def __init__(self, unit, key):
    """Validates and stores a (pipeline, output name) pair.

    Args:
      unit: A `Pipeline` instance whose `output_type` is a dictionary.
      key: A name present in `unit.output_type`.

    Raises:
      ValueError: If `unit` is not a `Pipeline`.
      KeyError: If `unit.output_type` is not a dictionary, or if `key` is not
          one of its keys.
    """
    if not isinstance(unit, Pipeline):
      raise ValueError('Cannot take key of non Pipeline %s' % unit)

    signature = unit.output_type
    if not isinstance(signature, dict):
      raise KeyError(
          'Cannot take key %s of %s because output type %s is not a dictionary'
          % (key, unit, signature))
    elif key not in signature:
      raise KeyError('Key %s is not valid for %s with output type %s'
                     % (key, unit, signature))

    self.unit = unit
    self.key = key
    self.output_type = signature[key]

  def __repr__(self):
    description = (self.unit, self.key)
    return 'Key(%s, %s)' % description
|
||
|
|
||
|
|
||
|
def _guarantee_dict(given, default_name):
|
||
|
if not isinstance(given, dict):
|
||
|
return {default_name: dict}
|
||
|
return given
|
||
|
|
||
|
|
||
|
def _assert_valid_type_signature(type_sig, type_sig_name):
|
||
|
"""Checks that the given type signature is valid.
|
||
|
|
||
|
Valid type signatures are either a single Python class, or a dictionary
|
||
|
mapping string names to Python classes.
|
||
|
|
||
|
Throws a well formatted exception when invalid.
|
||
|
|
||
|
Args:
|
||
|
type_sig: Type signature to validate.
|
||
|
type_sig_name: Variable name of the type signature. This is used in
|
||
|
exception descriptions.
|
||
|
|
||
|
Raises:
|
||
|
InvalidTypeSignatureException: If `type_sig` is not valid.
|
||
|
"""
|
||
|
if isinstance(type_sig, dict):
|
||
|
for k, val in type_sig.items():
|
||
|
if not isinstance(k, basestring):
|
||
|
raise InvalidTypeSignatureException(
|
||
|
'%s key %s must be a string.' % (type_sig_name, k))
|
||
|
if not inspect.isclass(val):
|
||
|
raise InvalidTypeSignatureException(
|
||
|
'%s %s at key %s must be a Python class.' % (type_sig_name, val, k))
|
||
|
else:
|
||
|
if not inspect.isclass(type_sig):
|
||
|
raise InvalidTypeSignatureException(
|
||
|
'%s %s must be a Python class.' % (type_sig_name, type_sig))
|
||
|
|
||
|
|
||
|
class Pipeline(object):
  """An abstract class for data processing pipelines that transform datasets.

  A Pipeline can transform one or many inputs to one or many outputs. When there
  are many inputs or outputs, each input/output is assigned a string name.

  The `transform` method converts a given input or dictionary of inputs to
  a list of transformed outputs, or a dictionary mapping names to lists of
  transformed outputs for each name.

  The `get_stats` method returns any statistics that were collected during the
  last call to `transform`. These statistics can give feedback about why any
  data was discarded and what the input data is like.

  `Pipeline` implementers should call `_set_stats` from within `transform` to
  set the statistics that will be returned by the next call to `get_stats`.
  """

  # Python 2 metaclass declaration.
  # NOTE(review): Python 3 ignores `__metaclass__`, so `abstractmethod` is only
  # enforced when this module runs under Python 2.
  __metaclass__ = abc.ABCMeta

  def __init__(self, input_type, output_type, name=None):
    """Constructs a `Pipeline` object.

    Subclass constructors are expected to call this constructor.

    A type signature is a Python class or primitive collection containing
    classes. Valid type signatures for `Pipeline` inputs and outputs are either
    a Python class, or a dictionary mapping string names to classes. An object
    matches a type signature if its type equals the type signature
    (i.e. type('hello') == str) or, if it's a collection, the types in the
    collection match (i.e. {'hello': 'world', 'number': 1234} matches type
    signature {'hello': str, 'number': int})

    `Pipeline` instances have (preferably unique) string names. These names act
    as name spaces for the statistics produced by them. The `get_stats` method
    will automatically prepend `name` to all of the statistics names before
    returning them.

    Args:
      input_type: The type signature this pipeline expects for its inputs.
      output_type: The type signature this pipeline promises its outputs will
          have.
      name: The string name for this instance. This name is accessible through
          the `name` property. Names should be unique across `Pipeline`
          instances. If None (default), the string name of the implementing
          subclass is used.

    Raises:
      InvalidTypeSignatureException: If `input_type` or `output_type` is not a
          valid type signature.
    """
    if name is None:
      # This will get the name of the subclass, not "Pipeline".
      self._name = type(self).__name__
    else:
      # `basestring` only exists on Python 2; fall back to `str` so passing a
      # name does not raise NameError on Python 3.
      try:
        string_types = basestring
      except NameError:
        string_types = str
      assert isinstance(name, string_types)
      self._name = name

    # Make sure `input_type` and `output_type` are valid.
    _assert_valid_type_signature(input_type, 'input_type')
    _assert_valid_type_signature(output_type, 'output_type')
    self._input_type = input_type
    self._output_type = output_type
    self._stats = []

  def __getitem__(self, key):
    """Returns a `Key` referring to the output named `key`."""
    return Key(self, key)

  @property
  def input_type(self):
    """What type or types does this pipeline take as input.

    Returns:
      A class, or a dictionary mapping names to classes.
    """
    return self._input_type

  @property
  def output_type(self):
    """What type or types does this pipeline output.

    Returns:
      A class, or a dictionary mapping names to classes.
    """
    return self._output_type

  @property
  def output_type_as_dict(self):
    """Returns a dictionary mapping names to classes.

    If `output_type` is a single class, then a default name will be created
    for the output and a dictionary containing `output_type` will be returned.

    Returns:
      Dictionary mapping names to output types.
    """
    return _guarantee_dict(self._output_type, 'dataset')

  @property
  def name(self):
    """The string name of this pipeline."""
    return self._name

  @abc.abstractmethod
  def transform(self, input_object):
    """Runs the pipeline on the given input.

    Args:
      input_object: An object or dictionary mapping names to objects.
          The object types must match `input_type`.

    Returns:
      If `output_type` is a class, `transform` returns a list of objects
      which are all that type. If `output_type` is a dictionary mapping
      names to classes, `transform` returns a dictionary mapping those
      same names to lists of objects that are the type mapped to each name.
    """
    pass

  def _set_stats(self, stats):
    """Overwrites the current statistics returned by `get_stats`.

    Implementers of Pipeline should call `_set_stats` from within `transform`.

    Args:
      stats: An iterable of Statistic objects.

    Raises:
      InvalidStatisticsException: If `stats` is not iterable, or if each
          statistic is not a `Statistic` instance.
    """
    if not hasattr(stats, '__iter__'):
      raise InvalidStatisticsException(
          'Expecting iterable, got type %s' % type(stats))
    # The new list is fully built before assignment, so a bad entry leaves the
    # previously stored statistics untouched.
    self._stats = [self._prepend_name(stat) for stat in stats]

  def _prepend_name(self, stat):
    """Returns a copy of `stat` with `self.name` prepended to `stat.name`.

    Args:
      stat: A `statistics.Statistic` instance.

    Returns:
      A copy of `stat` whose name is '<pipeline name>_<original name>'.

    Raises:
      InvalidStatisticsException: If `stat` is not a `Statistic` instance.
    """
    if not isinstance(stat, statistics.Statistic):
      raise InvalidStatisticsException(
          'Expecting Statistic object, got %s' % stat)
    stat_copy = stat.copy()
    stat_copy.name = self._name + '_' + stat_copy.name
    return stat_copy

  def get_stats(self):
    """Returns statistics about pipeline runs.

    Call `get_stats` after each call to `transform`.
    `transform` computes statistics which will be returned here.

    Returns:
      A list of `Statistic` objects.
    """
    # Hand back a shallow copy so callers cannot mutate the stored list.
    return list(self._stats)
|
||
|
|
||
|
|
||
|
def file_iterator(root_dir, extension=None, recurse=True):
  """Generator yielding the contents of every file under a directory.

  Sub-directories are descended into only when `recurse` is True.

  Args:
    root_dir: Path to root directory to search for files in.
    extension: If given, only files with the given extension are opened. The
        comparison is case-insensitive and a leading '.' is optional.
    recurse: If True, subdirectories will be traversed. Otherwise, only files
        in `root_dir` are opened.

  Yields:
    Raw bytes (as a string) of each file opened.

  Raises:
    ValueError: When extension is an empty string. Leave as None to omit.
  """
  if extension is not None:
    if not extension:
      raise ValueError('File extension cannot be an empty string.')
    extension = extension.lower()
    if not extension.startswith('.'):
      extension = '.' + extension

  def _children(directory):
    # Full paths of every entry directly inside `directory`.
    return [os.path.join(directory, entry)
            for entry in tf.gfile.ListDirectory(directory)]

  # Stack of paths (files and directories) still waiting to be visited.
  pending = _children(root_dir)
  while pending:
    path = pending.pop()
    if tf.gfile.IsDirectory(path):
      if recurse:
        pending.extend(_children(path))
      continue
    if extension is not None and not path.lower().endswith(extension):
      continue
    with open(path, 'rb') as f:
      yield f.read()
|
||
|
|
||
|
|
||
|
def tf_record_iterator(tfrecord_file, proto):
  """Generator yielding deserialized protocol buffers from a TFRecord file.

  Args:
    tfrecord_file: Path to a TFRecord file containing protocol buffers.
    proto: A protocol buffer class. Its `FromString` method is used to
        deserialize each record, so this is also the type of yielded items.

  Yields:
    Instances of the given `proto` class from the TFRecord file.
  """
  records = tf.python_io.tf_record_iterator(tfrecord_file)
  for serialized in records:
    message = proto.FromString(serialized)
    yield message
|
||
|
|
||
|
|
||
|
def run_pipeline_serial(pipeline,
                        input_iterator,
                        output_dir,
                        output_file_base=None):
  """Runs a pipeline on a data source and writes to a directory.

  Run the pipeline on each input from the iterator one at a time.
  A file will be written to `output_dir` for each dataset name specified
  by the pipeline. pipeline.transform is called on each input and the
  results are aggregated into their correct datasets.

  The output type or types given by `pipeline.output_type` must be protocol
  buffers or objects that have a SerializeToString method.

  Args:
    pipeline: A Pipeline instance. `pipeline.output_type` must be a protocol
        buffer or a dictionary mapping names to protocol buffers.
    input_iterator: Iterates over the input data. Items returned by it are fed
        directly into the pipeline's `transform` method.
    output_dir: Path to directory where datasets will be written. Each dataset
        is a file whose name contains the pipeline's dataset name. If the
        directory does not exist, it will be created.
    output_file_base: An optional string prefix for all datasets output by this
        run. The prefix will also be followed by an underscore.

  Raises:
    ValueError: If any of `pipeline`'s output types do not have a
        SerializeToString method.
  """
  if isinstance(pipeline.output_type, dict):
    for name, type_ in pipeline.output_type.items():
      if not hasattr(type_, 'SerializeToString'):
        raise ValueError(
            'Pipeline output "%s" does not have method SerializeToString. '
            'Output type = %s' % (name, pipeline.output_type))
  else:
    if not hasattr(pipeline.output_type, 'SerializeToString'):
      raise ValueError(
          'Pipeline output type %s does not have method SerializeToString.'
          % pipeline.output_type)

  if not tf.gfile.Exists(output_dir):
    tf.gfile.MakeDirs(output_dir)

  # Materialize the names as a list: it is indexed below, and dictionary key
  # views are not indexable on Python 3.
  output_names = list(pipeline.output_type_as_dict)

  if output_file_base is None:
    output_paths = [os.path.join(output_dir, name + '.tfrecord')
                    for name in output_names]
  else:
    output_paths = [os.path.join(output_dir,
                                 '%s_%s.tfrecord' % (output_file_base, name))
                    for name in output_names]

  writers = {}
  total_inputs = 0
  total_outputs = 0
  stats = []
  try:
    # Writers are opened inside the `try` so that a failure part-way through
    # still closes the ones that were already created.
    for name, path in zip(output_names, output_paths):
      writers[name] = tf.python_io.TFRecordWriter(path)

    for input_ in input_iterator:
      total_inputs += 1
      # A pipeline with a single (non-dict) output returns a bare list; file it
      # under the only dataset name.
      for name, outputs in _guarantee_dict(pipeline.transform(input_),
                                           output_names[0]).items():
        for output in outputs:
          writers[name].write(output.SerializeToString())
          total_outputs += 1
      stats = statistics.merge_statistics(stats + pipeline.get_stats())
      if total_inputs % 500 == 0:
        tf.logging.info('Processed %d inputs so far. Produced %d outputs.',
                        total_inputs, total_outputs)
        statistics.log_statistics_list(stats, tf.logging.info)
  finally:
    # Always release the record files, even when `transform` or a write raises.
    for writer in writers.values():
      writer.close()

  tf.logging.info('\n\nCompleted.\n')
  tf.logging.info('Processed %d inputs total. Produced %d outputs.',
                  total_inputs, total_outputs)
  statistics.log_statistics_list(stats, tf.logging.info)
|
||
|
|
||
|
|
||
|
def load_pipeline(pipeline, input_iterator):
  """Runs a pipeline saving the output into memory.

  Use this instead of `run_pipeline_serial` to build a dataset on the fly
  without saving it to disk.

  Args:
    pipeline: A Pipeline instance.
    input_iterator: Iterates over the input data. Items returned by it are fed
        directly into the pipeline's `transform` method.

  Returns:
    The aggregated return values of pipeline.transform. Specifically a
    dictionary mapping dataset names to lists of objects. Each name acts
    as a bucket where outputs are aggregated.
  """
  aggregated_outputs = dict(
      [(name, []) for name in pipeline.output_type_as_dict])
  # Snapshot the dataset names once: the key set never changes inside the loop,
  # and dictionary key views are not indexable on Python 3.
  output_names = list(aggregated_outputs)
  total_inputs = 0
  total_outputs = 0
  stats = []
  for input_object in input_iterator:
    total_inputs += 1
    # A pipeline with a single (non-dict) output returns a bare list; file it
    # under the only dataset name.
    outputs = _guarantee_dict(pipeline.transform(input_object),
                              output_names[0])
    for name, output_list in outputs.items():
      aggregated_outputs[name].extend(output_list)
      total_outputs += len(output_list)
    stats = statistics.merge_statistics(stats + pipeline.get_stats())
    if total_inputs % 500 == 0:
      tf.logging.info('Processed %d inputs so far. Produced %d outputs.',
                      total_inputs, total_outputs)
      statistics.log_statistics_list(stats, tf.logging.info)
  tf.logging.info('\n\nCompleted.\n')
  tf.logging.info('Processed %d inputs total. Produced %d outputs.',
                  total_inputs, total_outputs)
  statistics.log_statistics_list(stats, tf.logging.info)
  return aggregated_outputs
|