aiexperiments-ai-duet/server/third_party/magenta/models/shared/melody_rnn_generate.py

260 lines
9.5 KiB
Python
Raw Normal View History

2016-11-11 18:53:51 +00:00
# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Generate melodies from a trained checkpoint of a melody RNN model.
Uses flags to define operation.
"""
import ast
import os
import time
# internal imports
from six.moves import range # pylint: disable=redefined-builtin
import tensorflow as tf
from magenta.music import constants
from magenta.music import melodies_lib
from magenta.music import midi_io
from magenta.music import sequence_generator
from magenta.music import sequence_generator_bundle
from magenta.protobuf import generator_pb2
# Handle to the flag values; every DEFINE_* call below registers a flag that
# the functions in this module later read through this object.
FLAGS = tf.app.flags.FLAGS

# --- Model source flags: where the trained model is loaded from. ---
tf.app.flags.DEFINE_string(
    'run_dir', None,
    'Path to the directory where the latest checkpoint will be loaded from.')
tf.app.flags.DEFINE_string(
    'checkpoint_file', None,
    'Path to the checkpoint file. run_dir will take priority over this flag.')
tf.app.flags.DEFINE_string(
    'bundle_file', None,
    'Path to the bundle file. If specified, this will take priority over '
    'run_dir and checkpoint_file, unless save_generator_bundle is True, in '
    'which case both this flag and either run_dir or checkpoint_file are '
    'required')
tf.app.flags.DEFINE_boolean(
    'save_generator_bundle', False,
    'If true, instead of generating a sequence, will save this generator as a '
    'bundle file in the location specified by the bundle_file flag')

# --- Model configuration flags. ---
# The hparams value is parsed with ast.literal_eval in get_hparams().
tf.app.flags.DEFINE_string(
    'hparams', '{}',
    'String representation of a Python dictionary containing hyperparameter '
    'to value mapping. This mapping is merged with the default '
    'hyperparameters.')

# --- Output flags: where results go and how many/how long they are. ---
tf.app.flags.DEFINE_string(
    'output_dir', '/tmp/melody_rnn/generated',
    'The directory where MIDI files will be saved to.')
tf.app.flags.DEFINE_integer(
    'num_outputs', 10,
    'The number of melodies to generate. One MIDI file will be created for '
    'each.')
tf.app.flags.DEFINE_integer(
    'num_steps', 128,
    'The total number of steps the generated melodies should be, priming '
    'melody length + generated steps. Each step is a 16th of a bar.')

# --- Priming flags: run_with_flags() checks primer_melody first and only
# falls back to primer_midi when primer_melody is empty. ---
tf.app.flags.DEFINE_string(
    'primer_melody', '',
    'A string representation of a Python list of melodies_lib.MonophonicMelody '
    'event values. For example: "[60, -2, 60, -2, 67, -2, 67, -2]". If '
    'specified, this melody will be used as the priming melody. If a priming '
    'melody is not specified, melodies will be generated from scratch.')
tf.app.flags.DEFINE_string(
    'primer_midi', '',
    'The path to a MIDI file containing a melody that will be used as a '
    'priming melody. If a primer melody is not specified, melodies will be '
    'generated from scratch.')

# --- Generation behaviour flags. ---
tf.app.flags.DEFINE_float(
    'qpm', None,
    'The quarters per minute to play generated output at. If a primer MIDI is '
    'given, the qpm from that will override this flag. If qpm is None, qpm '
    'will default to 120.')
# get_hparams() copies this value into the hparams dictionary under the
# 'temperature' key.
tf.app.flags.DEFINE_float(
    'temperature', 1.0,
    'The randomness of the generated melodies. 1.0 uses the unaltered softmax '
    'probabilities, greater than 1.0 makes melodies more random, less than '
    '1.0 makes melodies less random.')
# Read through get_steps_per_quarter(); also used by _steps_to_seconds() to
# convert step counts into seconds.
tf.app.flags.DEFINE_integer(
    'steps_per_quarter', 4, 'What precision to use when quantizing the melody.')

# --- Logging flag: passed to tf.logging.set_verbosity by setup_logs(). ---
tf.app.flags.DEFINE_string(
    'log', 'INFO',
    'The threshold for what messages will be logged DEBUG, INFO, WARN, ERROR, '
    'or FATAL.')
def get_hparams():
  """Build the hyperparameter overrides from the command-line flags.

  The `hparams` flag is parsed as a Python literal; an empty or unset flag is
  treated as '{}'. The `temperature` flag is then stored under the
  'temperature' key, replacing any value supplied through `hparams`.

  Returns:
    The parsed hparams object with its 'temperature' entry set from the
    `temperature` flag.
  """
  hparams_literal = FLAGS.hparams or '{}'
  overrides = ast.literal_eval(hparams_literal)
  # The dedicated flag always wins over a 'temperature' entry in --hparams.
  overrides['temperature'] = FLAGS.temperature
  return overrides
def get_checkpoint():
  """Resolve the checkpoint location requested through the flags.

  Returns:
    The 'train' subdirectory of the expanded `run_dir` flag when `run_dir` is
    set; otherwise the expanded `checkpoint_file` flag when that is set;
    otherwise None.

  Raises:
    sequence_generator.SequenceGeneratorException: If `bundle_file` is given
      together with `run_dir` or `checkpoint_file` while the generator is not
      being saved as a bundle.
  """
  has_checkpoint_source = bool(FLAGS.run_dir or FLAGS.checkpoint_file)
  if has_checkpoint_source and FLAGS.bundle_file:
    # Supplying both sources is only legal when writing a new bundle.
    if not should_save_generator_bundle():
      raise sequence_generator.SequenceGeneratorException(
          'Cannot specify both bundle_file and run_dir or checkpoint_file')

  # run_dir takes priority over checkpoint_file.
  if FLAGS.run_dir:
    expanded_run_dir = os.path.expanduser(FLAGS.run_dir)
    return os.path.join(expanded_run_dir, 'train')
  if FLAGS.checkpoint_file:
    return os.path.expanduser(FLAGS.checkpoint_file)
  return None
def get_bundle_file():
  """Return the user-expanded path from the `bundle_file` flag.

  The bundle file is the one holding both a checkpoint and a metagraph.

  Returns:
    The `bundle_file` flag with '~' expanded, or None when the flag is None.
  """
  bundle_path = FLAGS.bundle_file
  if bundle_path is not None:
    return os.path.expanduser(bundle_path)
  return None
def get_bundle():
  """Read the generator bundle named by the `bundle_file` flag.

  Returns:
    The result of sequence_generator_bundle.read_bundle_file for the bundle
    path (a generator_pb2.GeneratorBundle), or None if the bundle_file flag is
    not set or the save_generator_bundle flag is set.
  """
  # When saving a bundle, bundle_file is an output path, so nothing is read.
  if should_save_generator_bundle():
    return None
  bundle_path = get_bundle_file()
  return (None if bundle_path is None
          else sequence_generator_bundle.read_bundle_file(bundle_path))
def should_save_generator_bundle():
  """Report whether a bundle should be saved instead of generating.

  When this is true, the generator is expected to write its checkpoint and
  metagraph into the bundle file given by get_bundle_file rather than produce
  a sequence.

  Returns:
    The value of the `save_generator_bundle` flag.
  """
  save_requested = FLAGS.save_generator_bundle
  return save_requested
def get_steps_per_quarter():
  """Return the value of the `steps_per_quarter` flag.

  Returns:
    The number of steps per quarter note.
  """
  steps_per_quarter = FLAGS.steps_per_quarter
  return steps_per_quarter
def _steps_to_seconds(steps, qpm):
  """Convert a step count into a duration in seconds.

  The steps-per-quarter resolution comes from get_steps_per_quarter(), i.e.
  the current `steps_per_quarter` flag value.

  Args:
    steps: Number of steps.
    qpm: Current quarters per minute.

  Returns:
    The number of seconds the given steps represent.
  """
  steps_per_quarter = get_steps_per_quarter()
  # Kept as a single left-to-right expression so the floating-point result is
  # the same as evaluating steps * 60.0, then / qpm, then / steps_per_quarter.
  return steps * 60.0 / qpm / steps_per_quarter
def setup_logs():
  """Apply the `log` flag as the tf.logging verbosity threshold."""
  log_threshold = FLAGS.log
  tf.logging.set_verbosity(log_threshold)
def run_with_flags(melody_rnn_sequence_generator):
  """Generate melodies according to the flags and write them as MIDI files.

  Every option comes from the flags defined in this module. Meant to be called
  from the main function of one of the melody generator modules.

  The function logs a fatal message and returns without generating when the
  `output_dir` flag is empty, or when the priming sequence already reaches or
  exceeds the requested total length.

  Side effects: the `output_dir` and `primer_midi` flags are overwritten with
  their '~'-expanded values, the output directory is created if it does not
  exist, and `num_outputs` MIDI files are written into it.

  Args:
    melody_rnn_sequence_generator: A MelodyRnnSequenceGenerator object specific
        to your model. Its `generate` method is called once per output.
  """
  if not FLAGS.output_dir:
    tf.logging.fatal('--output_dir required')
    return

  # Normalise user-supplied paths in place so later reads see expanded values.
  FLAGS.output_dir = os.path.expanduser(FLAGS.output_dir)
  if FLAGS.primer_midi:
    FLAGS.primer_midi = os.path.expanduser(FLAGS.primer_midi)
  if not os.path.exists(FLAGS.output_dir):
    os.makedirs(FLAGS.output_dir)

  # Start from the flag (or the library default); a primer MIDI file that
  # carries a tempo overrides this below.
  qpm = FLAGS.qpm or constants.DEFAULT_QUARTERS_PER_MINUTE

  # A literal primer melody takes precedence over a primer MIDI file.
  primer_sequence = None
  if FLAGS.primer_melody:
    primer_melody = melodies_lib.MonophonicMelody()
    melody_events = ast.literal_eval(FLAGS.primer_melody)
    primer_melody.from_event_list(melody_events)
    primer_sequence = primer_melody.to_sequence(qpm=qpm)
  elif FLAGS.primer_midi:
    primer_sequence = midi_io.midi_file_to_sequence_proto(FLAGS.primer_midi)
    primer_tempos = primer_sequence.tempos
    if primer_tempos and primer_tempos[0].qpm:
      qpm = primer_tempos[0].qpm

  # Total length in seconds of primer plus generated material, derived from
  # the num_steps flag at the final qpm.
  total_seconds = _steps_to_seconds(FLAGS.num_steps, qpm)

  # The request always carries exactly one generate section that ends at
  # total_seconds; only its start time and the input sequence depend on
  # whether a primer is present.
  generate_request = generator_pb2.GenerateSequenceRequest()
  generate_section = generate_request.generator_options.generate_sections.add()
  generate_section.end_time_seconds = total_seconds

  if primer_sequence:
    generate_request.input_sequence.CopyFrom(primer_sequence)
    # Generation begins one step after the latest note end in the primer, or
    # one step after time 0 when the primer has no notes.
    end_times = sorted(note.end_time for note in primer_sequence.notes)
    last_end_time = end_times[-1] if end_times else 0
    one_step_seconds = _steps_to_seconds(1, qpm)
    generate_section.start_time_seconds = last_end_time + one_step_seconds
    if generate_section.start_time_seconds >= generate_section.end_time_seconds:
      tf.logging.fatal(
          'Priming sequence is longer than the total number of steps '
          'requested: Priming sequence length: %s, Generation length '
          'requested: %s',
          generate_section.start_time_seconds, total_seconds)
      return
  else:
    # No primer: generate from time 0 and record the tempo on the otherwise
    # empty input sequence.
    generate_section.start_time_seconds = 0
    generate_request.input_sequence.tempos.add().qpm = qpm

  tf.logging.debug('generate_request: %s', generate_request)

  # Issue the same request num_outputs times, writing each response to its
  # own MIDI file. File names share one timestamp and carry a 1-based index
  # zero-padded to the width of num_outputs.
  date_and_time = time.strftime('%Y-%m-%d_%H%M%S')
  digits = len(str(FLAGS.num_outputs))
  for output_number in range(1, FLAGS.num_outputs + 1):
    generate_response = melody_rnn_sequence_generator.generate(
        generate_request)
    padded_number = str(output_number).zfill(digits)
    midi_filename = '%s_%s.mid' % (date_and_time, padded_number)
    midi_path = os.path.join(FLAGS.output_dir, midi_filename)
    midi_io.sequence_proto_to_midi_file(
        generate_response.generated_sequence, midi_path)

  tf.logging.info('Wrote %d MIDI files to %s',
                  FLAGS.num_outputs, FLAGS.output_dir)