aiexperiments-ai-duet/server/magenta/models/shared/melody_rnn_sequence_generator.py
Yotam Mann ff837cec16 server
2016-11-11 13:53:51 -05:00

188 lines
7.7 KiB
Python

# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Shared Melody RNN generation code as a SequenceGenerator interface."""
import random
# internal imports
from six.moves import range # pylint: disable=redefined-builtin
import tensorflow as tf
from magenta.music import constants
from magenta.music import melodies_lib
from magenta.music import sequence_generator
from magenta.music import sequences_lib
from magenta.protobuf import generator_pb2
class MelodyRnnSequenceGenerator(sequence_generator.BaseSequenceGenerator):
  """Shared Melody RNN generation code as a SequenceGenerator interface."""

  def __init__(self, details, checkpoint, bundle, melody_encoder_decoder,
               build_graph, steps_per_quarter, hparams):
    """Creates a MelodyRnnSequenceGenerator.

    Args:
      details: A generator_pb2.GeneratorDetails for this generator.
      checkpoint: Where to search for the most recent model checkpoint.
      bundle: A generator_pb2.GeneratorBundle object that includes both the
          model checkpoint and metagraph.
      melody_encoder_decoder: A melodies_lib.MelodyEncoderDecoder object
          specific to your model.
      build_graph: A function that, when called, returns the tf.Graph object
          for your model. This class calls it as
          `build_graph(mode, hparams_string, melody_encoder_decoder)`, with
          mode 'generate' and hparams_string set to repr() of the merged
          hparams dict. For an example usage, see
          models/basic_rnn/basic_rnn_graph.py.
      steps_per_quarter: What precision to use when quantizing the melody. How
          many steps per quarter note.
      hparams: A dict of hparams merged over the defaults ('temperature': 1.0).
          May be None or empty to use only the defaults. 'dropout_keep_prob'
          and 'batch_size' are always forced to 1.0 and 1 for generation.
    """
    super(MelodyRnnSequenceGenerator, self).__init__(
        details, checkpoint, bundle)
    self._melody_encoder_decoder = melody_encoder_decoder
    self._build_graph = build_graph
    self._session = None
    self._steps_per_quarter = steps_per_quarter

    # Start with some defaults.
    self._hparams = {
        'temperature': 1.0,
    }
    # Update with whatever was supplied (tolerating None / empty).
    if hparams:
      self._hparams.update(hparams)
    # Generation runs one sequence at a time with dropout disabled, regardless
    # of what the caller supplied.
    self._hparams['dropout_keep_prob'] = 1.0
    self._hparams['batch_size'] = 1

  def _initialize_with_checkpoint(self, checkpoint_file):
    """Builds the generation graph and restores its variables.

    Args:
      checkpoint_file: Path of the checkpoint to restore into the new session.
    """
    graph = self._build_graph('generate',
                              repr(self._hparams),
                              self._melody_encoder_decoder)
    with graph.as_default():
      saver = tf.train.Saver()
      self._session = tf.Session()
      tf.logging.info('Checkpoint used: %s', checkpoint_file)
      saver.restore(self._session, checkpoint_file)

  def _initialize_with_checkpoint_and_metagraph(self, checkpoint_filename,
                                                metagraph_filename):
    """Imports a saved metagraph and restores its variables.

    The metagraph is imported into a fresh tf.Graph so that repeated
    initialization does not accumulate ops in the process-global default
    graph. This mirrors _initialize_with_checkpoint, which also gets its own
    graph.

    Args:
      checkpoint_filename: Path of the checkpoint to restore.
      metagraph_filename: Path of the exported metagraph to import.
    """
    with tf.Graph().as_default():
      self._session = tf.Session()
      new_saver = tf.train.import_meta_graph(metagraph_filename)
      new_saver.restore(self._session, checkpoint_filename)

  def _write_checkpoint_with_metagraph(self, checkpoint_filename):
    """Saves the current session's variables plus a '.meta' metagraph file.

    Args:
      checkpoint_filename: Path prefix to save the checkpoint under.
    """
    with self._session.graph.as_default():
      saver = tf.train.Saver(sharded=False)
      saver.save(self._session, checkpoint_filename, meta_graph_suffix='meta',
                 write_meta_graph=True)

  def _close(self):
    """Closes the TensorFlow session, if one is open."""
    if self._session is not None:
      self._session.close()
      self._session = None

  def _seconds_to_steps(self, seconds, qpm):
    """Converts seconds to steps.

    Uses the generator's steps_per_quarter setting and the specified qpm.

    Args:
      seconds: number of seconds.
      qpm: current qpm.

    Returns:
      Number of steps the seconds represent, truncated toward zero.
    """
    return int(seconds * (qpm / 60.0) * self._steps_per_quarter)

  def _generate(self, generate_sequence_request):
    """Extends the request's input sequence with a generated melody.

    Args:
      generate_sequence_request: A generator_pb2.GenerateSequenceRequest whose
          generator_options contain exactly one generate_sections entry.

    Returns:
      A generator_pb2.GenerateSequenceResponse whose generated_sequence holds
      the (possibly seeded) primer melody followed by the generated steps.

    Raises:
      sequence_generator.SequenceGeneratorException: If the request does not
          contain exactly one generate section, or if that section starts
          before the end of the last note in the input sequence.
    """
    generate_sections = (
        generate_sequence_request.generator_options.generate_sections)
    if len(generate_sections) != 1:
      raise sequence_generator.SequenceGeneratorException(
          'This model supports only 1 generate_sections message, but got %s' %
          (len(generate_sections)))
    generate_section = generate_sections[0]
    primer_sequence = generate_sequence_request.input_sequence

    # Only the latest note end time matters; max() avoids a full sort.
    last_end_time = (max(note.end_time for note in primer_sequence.notes)
                     if primer_sequence.notes else 0)
    if last_end_time > generate_section.start_time_seconds:
      raise sequence_generator.SequenceGeneratorException(
          'Got GenerateSection request for section that is before the end of '
          'the NoteSequence. This model can only extend sequences. '
          'Requested start time: %s, Final note end time: %s' %
          (generate_section.start_time_seconds, last_end_time))

    # Quantize the priming sequence.
    quantized_sequence = sequences_lib.QuantizedSequence()
    quantized_sequence.from_note_sequence(
        primer_sequence, self._steps_per_quarter)
    # Setting gap_bars to infinite ensures that the entire input will be used.
    extracted_melodies, _ = melodies_lib.extract_melodies(
        quantized_sequence, min_bars=0, min_unique_pitches=1,
        gap_bars=float('inf'), ignore_polyphonic_notes=True)
    assert len(extracted_melodies) <= 1

    # Use the primer's first tempo if it has one, else the library default.
    qpm = (primer_sequence.tempos[0].qpm if primer_sequence.tempos
           else constants.DEFAULT_QUARTERS_PER_MINUTE)
    start_step = self._seconds_to_steps(
        generate_section.start_time_seconds, qpm)
    end_step = self._seconds_to_steps(generate_section.end_time_seconds, qpm)

    if extracted_melodies and extracted_melodies[0]:
      melody = extracted_melodies[0]
    else:
      tf.logging.warn('No melodies were extracted from the priming sequence. '
                      'Melodies will be generated from scratch.')
      # Seed with a single random pitch within the encoder's range.
      melody = melodies_lib.MonophonicMelody()
      melody.from_event_list([
          random.randint(self._melody_encoder_decoder.min_note,
                         self._melody_encoder_decoder.max_note)])
      # Account for the seed event added above.
      start_step += 1

    # squash() returns the transposition it applied; it is undone at the end.
    transpose_amount = melody.squash(
        self._melody_encoder_decoder.min_note,
        self._melody_encoder_decoder.max_note,
        self._melody_encoder_decoder.transpose_to_key)

    # Ensure that the melody extends up to the step we want to start generating.
    melody.set_length(start_step)

    graph = self._session.graph
    inputs = graph.get_collection('inputs')[0]
    initial_state = graph.get_collection('initial_state')[0]
    final_state = graph.get_collection('final_state')[0]
    softmax = graph.get_collection('softmax')[0]

    # Generate one step per iteration until the melody reaches end_step.
    final_state_ = None
    for step in range(end_step - len(melody)):
      if step == 0:
        # First step: feed the whole melody, starting from the model's
        # initial state.
        inputs_ = self._melody_encoder_decoder.get_inputs_batch(
            [melody], full_length=True)
        initial_state_ = self._session.run(initial_state)
      else:
        # Later steps: carry the RNN state forward (presumably the inputs
        # are then only the newest event; see get_inputs_batch).
        inputs_ = self._melody_encoder_decoder.get_inputs_batch([melody])
        initial_state_ = final_state_

      feed_dict = {inputs: inputs_, initial_state: initial_state_}
      final_state_, softmax_ = self._session.run(
          [final_state, softmax], feed_dict)
      self._melody_encoder_decoder.extend_event_sequences([melody], softmax_)

    melody.transpose(-transpose_amount)

    generate_response = generator_pb2.GenerateSequenceResponse()
    generate_response.generated_sequence.CopyFrom(melody.to_sequence(qpm=qpm))
    return generate_response