188 lines
7.7 KiB
Python
188 lines
7.7 KiB
Python
|
# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
|
||
|
"""Shared Melody RNN generation code as a SequenceGenerator interface."""
|
||
|
|
||
|
import random
|
||
|
|
||
|
# internal imports
|
||
|
from six.moves import range # pylint: disable=redefined-builtin
|
||
|
import tensorflow as tf
|
||
|
|
||
|
from magenta.music import constants
|
||
|
from magenta.music import melodies_lib
|
||
|
from magenta.music import sequence_generator
|
||
|
from magenta.music import sequences_lib
|
||
|
from magenta.protobuf import generator_pb2
|
||
|
|
||
|
|
||
|
class MelodyRnnSequenceGenerator(sequence_generator.BaseSequenceGenerator):
  """Shared Melody RNN generation code as a SequenceGenerator interface.

  Holds a TensorFlow session for a melody model and extends a primer
  NoteSequence one step at a time by repeatedly running the model's softmax
  and feeding the result back through the melody encoder/decoder.
  """

  def __init__(self, details, checkpoint, bundle, melody_encoder_decoder,
               build_graph, steps_per_quarter, hparams):
    """Creates a MelodyRnnSequenceGenerator.

    Args:
      details: A generator_pb2.GeneratorDetails for this generator.
      checkpoint: Where to search for the most recent model checkpoint.
      bundle: A generator_pb2.GeneratorBundle object that includes both the
          model checkpoint and metagraph.
      melody_encoder_decoder: A melodies_lib.MelodyEncoderDecoder object
          specific to your model.
      build_graph: A function that when called, returns the tf.Graph object for
          your model. This class calls it as
          build_graph('generate', hparams_string, melody_encoder_decoder),
          where hparams_string is the repr() of the merged hparams dict.
          For an example usage, see models/basic_rnn/basic_rnn_graph.py.
      steps_per_quarter: What precision to use when quantizing the melody. How
          many steps per quarter note.
      hparams: A dict of hparams. May be None, in which case only the defaults
          and the generation-time overrides below are used.
    """
    super(MelodyRnnSequenceGenerator, self).__init__(
        details, checkpoint, bundle)
    self._melody_encoder_decoder = melody_encoder_decoder
    self._build_graph = build_graph
    # Created lazily by one of the _initialize_* methods.
    self._session = None
    self._steps_per_quarter = steps_per_quarter

    # Start with some defaults
    self._hparams = {
        'temperature': 1.0,
    }
    # Update with whatever was supplied.
    if hparams is not None:
      self._hparams.update(hparams)
    # These two are always forced, regardless of what the caller supplied:
    # generation runs without dropout and on a single melody at a time.
    self._hparams['dropout_keep_prob'] = 1.0
    self._hparams['batch_size'] = 1

  def _initialize_with_checkpoint(self, checkpoint_file):
    """Builds the model graph and restores its variables from a checkpoint.

    Args:
      checkpoint_file: Path of the checkpoint to restore into the new session.
    """
    graph = self._build_graph('generate',
                              repr(self._hparams),
                              self._melody_encoder_decoder)
    # The saver and the session are both created while `graph` is the default
    # graph, so they are bound to the model graph and not the global one.
    with graph.as_default():
      saver = tf.train.Saver()
      self._session = tf.Session()
      tf.logging.info('Checkpoint used: %s', checkpoint_file)
      saver.restore(self._session, checkpoint_file)

  def _initialize_with_checkpoint_and_metagraph(self, checkpoint_filename,
                                                metagraph_filename):
    """Restores the model from a serialized metagraph plus a checkpoint.

    Args:
      checkpoint_filename: Path of the checkpoint holding the variable values.
      metagraph_filename: Path of the metagraph file describing the graph.
    """
    # Import into a fresh graph rather than the process-wide default graph so
    # that several generators (or other graph-building code) can coexist
    # without their ops and collections colliding.
    with tf.Graph().as_default():
      self._session = tf.Session()
      new_saver = tf.train.import_meta_graph(metagraph_filename)
      new_saver.restore(self._session, checkpoint_filename)

  def _write_checkpoint_with_metagraph(self, checkpoint_filename):
    """Saves the current session as a checkpoint together with its metagraph.

    Args:
      checkpoint_filename: Path to write the checkpoint to; the metagraph is
          written alongside it with the 'meta' suffix.
    """
    with self._session.graph.as_default():
      saver = tf.train.Saver(sharded=False)
      saver.save(self._session, checkpoint_filename, meta_graph_suffix='meta',
                 write_meta_graph=True)

  def _close(self):
    """Closes the TensorFlow session, if one is open."""
    # Tolerate being called before initialization or more than once.
    if self._session is None:
      return
    self._session.close()
    self._session = None

  def _seconds_to_steps(self, seconds, qpm):
    """Converts seconds to steps.

    Uses the generator's steps_per_quarter setting and the specified qpm.

    Args:
      seconds: number of seconds.
      qpm: current qpm.

    Returns:
      Number of steps the seconds represent, truncated to an int.
    """
    return int(seconds * (qpm / 60.0) * self._steps_per_quarter)

  def _generate(self, generate_sequence_request):
    """Extends the primer sequence in the request with generated melody steps.

    Args:
      generate_sequence_request: The request message. Its input_sequence is
          used as the primer and its generator_options.generate_sections must
          contain exactly one section giving the start and end times (in
          seconds) of the region to generate.

    Returns:
      A generator_pb2.GenerateSequenceResponse whose generated_sequence holds
      the primer melody followed by the generated continuation.

    Raises:
      sequence_generator.SequenceGeneratorException: If the request does not
          contain exactly one generate section, or if the section starts
          before the last primer note ends.
    """
    generate_sections = (
        generate_sequence_request.generator_options.generate_sections)
    if len(generate_sections) != 1:
      raise sequence_generator.SequenceGeneratorException(
          'This model supports only 1 generate_sections message, but got %s' %
          (len(generate_sections)))

    generate_section = generate_sections[0]
    primer_sequence = generate_sequence_request.input_sequence

    notes_by_end_time = sorted(primer_sequence.notes, key=lambda n: n.end_time)
    last_end_time = notes_by_end_time[-1].end_time if notes_by_end_time else 0
    if last_end_time > generate_section.start_time_seconds:
      # Report last_end_time rather than indexing notes_by_end_time again: the
      # list can be empty here (e.g. a negative requested start time), and
      # indexing it would mask this error with an IndexError.
      raise sequence_generator.SequenceGeneratorException(
          'Got GenerateSection request for section that is before the end of '
          'the NoteSequence. This model can only extend sequences. '
          'Requested start time: %s, Final note end time: %s' %
          (generate_section.start_time_seconds, last_end_time))

    # Quantize the priming sequence.
    quantized_sequence = sequences_lib.QuantizedSequence()
    quantized_sequence.from_note_sequence(
        primer_sequence, self._steps_per_quarter)
    # Setting gap_bars to infinite ensures that the entire input will be used.
    extracted_melodies, _ = melodies_lib.extract_melodies(
        quantized_sequence, min_bars=0, min_unique_pitches=1,
        gap_bars=float('inf'), ignore_polyphonic_notes=True)
    assert len(extracted_melodies) <= 1

    # Use the primer's first tempo when it has one, else the library default.
    qpm = (primer_sequence.tempos[0].qpm if primer_sequence
           and primer_sequence.tempos
           else constants.DEFAULT_QUARTERS_PER_MINUTE)
    start_step = self._seconds_to_steps(
        generate_section.start_time_seconds, qpm)
    end_step = self._seconds_to_steps(generate_section.end_time_seconds, qpm)

    if extracted_melodies and extracted_melodies[0]:
      melody = extracted_melodies[0]
    else:
      tf.logging.warn('No melodies were extracted from the priming sequence. '
                      'Melodies will be generated from scratch.')
      # Seed the model with a single random pitch from the encoder's range.
      melody = melodies_lib.MonophonicMelody()
      melody.from_event_list([
          random.randint(self._melody_encoder_decoder.min_note,
                         self._melody_encoder_decoder.max_note)])
      # The seed event takes up one step, so generation starts one step later.
      start_step += 1

    # squash() returns a transpose amount; it is negated and applied again
    # after generation (presumably restoring the primer's original key).
    transpose_amount = melody.squash(
        self._melody_encoder_decoder.min_note,
        self._melody_encoder_decoder.max_note,
        self._melody_encoder_decoder.transpose_to_key)

    # Ensure that the melody extends up to the step we want to start generating.
    melody.set_length(start_step)

    # Look up the model's tensors through the collections stored on the graph.
    graph = self._session.graph
    inputs = graph.get_collection('inputs')[0]
    initial_state = graph.get_collection('initial_state')[0]
    final_state = graph.get_collection('final_state')[0]
    softmax = graph.get_collection('softmax')[0]

    final_state_ = None
    for i in range(end_step - len(melody)):
      if i == 0:
        # First step: feed the whole primer and start from the graph's own
        # initial state.
        inputs_ = self._melody_encoder_decoder.get_inputs_batch(
            [melody], full_length=True)
        initial_state_ = self._session.run(initial_state)
      else:
        # Later steps: request the default (non-full-length) input batch and
        # carry the RNN state over from the previous step.
        inputs_ = self._melody_encoder_decoder.get_inputs_batch([melody])
        initial_state_ = final_state_

      feed_dict = {inputs: inputs_, initial_state: initial_state_}
      final_state_, softmax_ = self._session.run(
          [final_state, softmax], feed_dict)
      # Extends `melody` in place using the softmax output.
      self._melody_encoder_decoder.extend_event_sequences([melody], softmax_)

    melody.transpose(-transpose_amount)

    generate_response = generator_pb2.GenerateSequenceResponse()
    generate_response.generated_sequence.CopyFrom(melody.to_sequence(qpm=qpm))
    return generate_response
|