# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Abstract class for sequence generators.

Provides a uniform interface for interacting with generators for any model.
"""

import abc
import os
import tempfile

# internal imports

import tensorflow as tf

from magenta.protobuf import generator_pb2


class SequenceGeneratorException(Exception):
  """Raised when a sequence generator is misconfigured or fails to run.

  Carries no extra state beyond the standard Exception message/args.
  """


class BaseSequenceGenerator(object):
  """Abstract class for generators.

  A generator is backed by exactly one of a checkpoint path or a
  generator_pb2.GeneratorBundle. Subclasses provide the model-specific hooks
  (`_initialize_with_checkpoint`, `_initialize_with_checkpoint_and_metagraph`,
  `_close`, `_generate` and `_write_checkpoint_with_metagraph`); this class
  takes care of argument validation, lazy initialization, temporary file
  handling and bundle file creation.

  Instances can be used as context managers: entering initializes the
  generator and exiting closes it.
  """

  # NOTE: `__metaclass__` is only honored by Python 2. Python 3 ignores this
  # attribute, so the abstract methods below are not enforced there.
  __metaclass__ = abc.ABCMeta

  def __init__(self, details, checkpoint, bundle):
    """Constructs a BaseSequenceGenerator.

    Args:
      details: A generator_pb2.GeneratorDetails for this generator.
      checkpoint: Where to look for the most recent model checkpoint. Either a
          directory to be used with tf.train.latest_checkpoint or the path to a
          single checkpoint file. Or None if a bundle should be used.
      bundle: A generator_pb2.GeneratorBundle object that contains both a
          checkpoint and a metagraph. Or None if a checkpoint should be used.

    Raises:
      SequenceGeneratorException: if neither checkpoint nor bundle is set, if
          both are set, or if the generator id stored in the bundle does not
          match the id in `details`.
    """
    self._details = details
    self._checkpoint = checkpoint
    self._bundle = bundle

    if self._checkpoint is None and self._bundle is None:
      raise SequenceGeneratorException(
          'Either checkpoint or bundle must be set')
    if self._checkpoint is not None and self._bundle is not None:
      raise SequenceGeneratorException(
          'Checkpoint and bundle cannot both be set')

    # Compare against None (like the checks above) instead of relying on the
    # bundle's truthiness, so the id check is never skipped for a bundle object
    # that happens to evaluate as false.
    if self._bundle is not None:
      if self._bundle.generator_details.id != self._details.id:
        raise SequenceGeneratorException(
            'Generator id in bundle (%s) does not match this generator\'s id '
            '(%s)' % (self._bundle.generator_details.id, self._details.id))

    # Set by initialize() and cleared again by close().
    self._initialized = False

  @property
  def details(self):
    """Returns a GeneratorDetails description of this generator."""
    return self._details

  @property
  def bundle_details(self):
    """Returns the BundleDetails or None if checkpoint was used."""
    if self._bundle is None:
      return None
    return self._bundle.bundle_details

  @abc.abstractmethod
  def _initialize_with_checkpoint(self, checkpoint_file):
    """Implementation for building the TF graph given a checkpoint file.

    Called by initialize() when this generator was constructed with a
    checkpoint.

    Args:
      checkpoint_file: The path to the checkpoint file that should be used.
    """
    pass

  @abc.abstractmethod
  def _initialize_with_checkpoint_and_metagraph(self, checkpoint_file,
                                                metagraph_file):
    """Implementation for building the TF graph with a checkpoint and metagraph.

    Called by initialize() when this generator was constructed with a bundle.

    The implementation should not expect the checkpoint_file and metagraph_file
    to be available after the method returns: initialize() writes them to a
    temporary directory that is deleted right after this call.

    Args:
      checkpoint_file: The path to the checkpoint file that should be used.
      metagraph_file: The path to the metagraph file that should be used.
    """
    pass

  @abc.abstractmethod
  def _close(self):
    """Implementation for closing the TF session.

    Only called by close() while the generator is marked as initialized.
    """
    pass

  @abc.abstractmethod
  def _generate(self, generate_sequence_request):
    """Implementation for sequence generation based on request.

    The implementation can assume that _initialize has been called before this
    method is called.

    Args:
      generate_sequence_request: The request for generating a sequence

    Returns:
      A GenerateSequenceResponse proto.
    """
    pass

  @abc.abstractmethod
  def _write_checkpoint_with_metagraph(self, checkpoint_filename):
    """Implementation for writing the checkpoint and metagraph.

    Saver should be initialized with sharded=False, and save should be called
    with: meta_graph_suffix='meta', write_meta_graph=True.

    create_bundle_file() expects this call to produce both
    `checkpoint_filename` and `checkpoint_filename + '.meta'`.

    Args:
      checkpoint_filename: Path to the checkpoint file. Should be passed as the
          save_path argument to Saver.save.
    """
    pass

  def initialize(self):
    """Builds the TF graph and loads the checkpoint.

    If the graph has already been initialized, this is a no-op.

    When a checkpoint path was supplied, it is resolved to a single checkpoint
    file (using tf.train.latest_checkpoint for directories). When a bundle was
    supplied, its checkpoint and metagraph are written to a temporary directory
    that is removed again once the model has been loaded.

    Raises:
      SequenceGeneratorException: If the checkpoint cannot be found, or if the
          bundle does not contain any checkpoint file.
    """
    if self._initialized:
      return

    # Either self._checkpoint or self._bundle should be set.
    # This is enforced by the constructor.
    if self._checkpoint is not None:
      if not tf.gfile.Exists(self._checkpoint):
        raise SequenceGeneratorException(
            'Checkpoint path does not exist: %s' % (self._checkpoint))
      checkpoint_file = self._checkpoint
      # If this is a directory, try to determine the latest checkpoint in it.
      if tf.gfile.IsDirectory(checkpoint_file):
        checkpoint_file = tf.train.latest_checkpoint(checkpoint_file)
      if checkpoint_file is None:
        raise SequenceGeneratorException(
            'No checkpoint file found in directory: %s' % self._checkpoint)
      if (not tf.gfile.Exists(checkpoint_file) or
          tf.gfile.IsDirectory(checkpoint_file)):
        raise SequenceGeneratorException(
            'Checkpoint path is not a file: %s (supplied path: %s)' % (
                checkpoint_file, self._checkpoint))
      self._initialize_with_checkpoint(checkpoint_file)
    else:
      # Fail with a descriptive error (instead of an IndexError further down)
      # when the bundle carries no checkpoint data at all.
      if not self._bundle.checkpoint_file:
        raise SequenceGeneratorException(
            'Bundle does not contain a checkpoint file.')

      # Write checkpoint and metagraph files to a temp dir.
      tempdir = None
      try:
        tempdir = tempfile.mkdtemp()
        checkpoint_filename = os.path.join(tempdir, 'model.ckpt')
        with tf.gfile.Open(checkpoint_filename, 'wb') as f:
          # For now, we support only 1 checkpoint file.
          # If needed, we can later change this to support sharded checkpoints.
          f.write(self._bundle.checkpoint_file[0])
        metagraph_filename = os.path.join(tempdir, 'model.ckpt.meta')
        with tf.gfile.Open(metagraph_filename, 'wb') as f:
          f.write(self._bundle.metagraph_file)

        self._initialize_with_checkpoint_and_metagraph(
            checkpoint_filename, metagraph_filename)
      finally:
        # Clean up the temp dir, whether or not initialization succeeded.
        if tempdir is not None:
          tf.gfile.DeleteRecursively(tempdir)
    # Only reached when one of the initialization paths above succeeded.
    self._initialized = True

  def close(self):
    """Closes the TF session.

    If the session was already closed (or never initialized), this is a no-op.
    """
    if self._initialized:
      self._close()
      self._initialized = False

  def __enter__(self):
    """When used as a context manager, initializes the TF session."""
    self.initialize()
    return self

  def __exit__(self, *args):
    """When used as a context manager, closes the TF session."""
    self.close()

  def generate(self, generate_sequence_request):
    """Generates a sequence from the model based on the request.

    Also initializes the TF graph if not yet initialized.

    Args:
      generate_sequence_request: The request for generating a sequence

    Returns:
      A GenerateSequenceResponse proto.
    """
    self.initialize()
    return self._generate(generate_sequence_request)

  def create_bundle_file(self, bundle_file, description=None):
    """Writes a generator_pb2.GeneratorBundle file in the specified location.

    Saves the checkpoint, metagraph, and generator id in one file. The
    generator is initialized first if necessary.

    Args:
      bundle_file: Location to write the bundle file.
      description: A short, human-readable text description of the bundle (e.g.,
          training data, hyper parameters, etc.).

    Raises:
      SequenceGeneratorException: if there is an error creating the bundle file.
    """
    if not bundle_file:
      raise SequenceGeneratorException('Bundle file location not specified.')

    self.initialize()

    tempdir = None
    try:
      # The checkpoint and metagraph are first written to a scratch directory
      # and then read back into the bundle proto.
      tempdir = tempfile.mkdtemp()
      checkpoint_filename = os.path.join(tempdir, 'model.ckpt')

      self._write_checkpoint_with_metagraph(checkpoint_filename)

      if not os.path.isfile(checkpoint_filename):
        raise SequenceGeneratorException(
            'Could not read checkpoint file: %s' % (checkpoint_filename))
      metagraph_filename = checkpoint_filename + '.meta'
      if not os.path.isfile(metagraph_filename):
        raise SequenceGeneratorException(
            'Could not read metagraph file: %s' % (metagraph_filename))

      bundle = generator_pb2.GeneratorBundle()
      bundle.generator_details.CopyFrom(self.details)
      if description is not None:
        bundle.bundle_details.description = description
      with tf.gfile.Open(checkpoint_filename, 'rb') as f:
        bundle.checkpoint_file.append(f.read())
      with tf.gfile.Open(metagraph_filename, 'rb') as f:
        bundle.metagraph_file = f.read()

      with tf.gfile.Open(bundle_file, 'wb') as f:
        f.write(bundle.SerializeToString())
    finally:
      # Remove the scratch directory whether or not the bundle was written.
      if tempdir is not None:
        tf.gfile.DeleteRecursively(tempdir)