# Copyright 2016 Google Inc. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """A MIDI interface to the sequence generators. Captures monophonic input MIDI sequences and plays back responses from the sequence generator. """ import ast import functools from sys import stdout import threading import time # internal imports import mido import tensorflow as tf from magenta.models.attention_rnn import attention_rnn_generator from magenta.models.basic_rnn import basic_rnn_generator from magenta.models.lookback_rnn import lookback_rnn_generator from magenta.music import sequence_generator_bundle from magenta.protobuf import generator_pb2 from magenta.protobuf import music_pb2 FLAGS = tf.app.flags.FLAGS tf.app.flags.DEFINE_bool( 'list', False, 'Only list available MIDI ports.') tf.app.flags.DEFINE_string( 'input_port', None, 'The name of the input MIDI port.') tf.app.flags.DEFINE_string( 'output_port', None, 'The name of the output MIDI port.') tf.app.flags.DEFINE_integer( 'start_capture_control_number', 1, 'The control change number to use as a signal to start ' 'capturing. Defaults to modulation wheel.') tf.app.flags.DEFINE_integer( 'start_capture_control_value', 127, 'The control change value to use as a signal to start ' 'capturing. If None, any control change with ' 'start_capture_control_number will start capture.') tf.app.flags.DEFINE_integer( 'stop_capture_control_number', 1, 'The control change number to use as a signal to stop ' 'capturing and generate. Defaults to the modulation ' 'wheel.') tf.app.flags.DEFINE_integer( 'stop_capture_control_value', 0, 'The control change value to use as a signal to stop ' 'capturing and generate. If None, any control change with' 'stop_capture_control_number will stop capture.') # TODO(adarob): Make the qpm adjustable by a control change signal. tf.app.flags.DEFINE_integer( 'qpm', 90, 'The quarters per minute to use for the metronome and generated sequence.') # TODO(adarob): Make the number of bars to generate adjustable. tf.app.flags.DEFINE_integer( 'num_bars_to_generate', 5, 'The number of bars to generate each time.') tf.app.flags.DEFINE_integer( 'metronome_channel', 0, 'The MIDI channel on which to send the metronome click.') tf.app.flags.DEFINE_integer( 'metronome_playback_velocity', 0, 'The velocity of the generated playback metronome ' 'expressed as an integer between 0 and 127.') tf.app.flags.DEFINE_string( 'bundle_file', None, 'The location of the bundle file to use. If specified, generator_name, ' 'checkpoint, and hparams cannot be specified.') tf.app.flags.DEFINE_string( 'generator_name', None, 'The name of the SequenceGenerator being used.') tf.app.flags.DEFINE_string( 'checkpoint', None, 'The training directory with checkpoint files or the path to a single ' 'checkpoint file for the model being used.') tf.app.flags.DEFINE_string( 'hparams', '{}', 'String representation of a Python dictionary containing hyperparameter to ' 'value mappings. This mapping is merged with the default hyperparameters.') # A map from a string generator name to its factory class. _GENERATOR_FACTORY_MAP = { 'attention_rnn': attention_rnn_generator, 'basic_rnn': basic_rnn_generator, 'lookback_rnn': lookback_rnn_generator, } _METRONOME_TICK_DURATION = 0.05 _METRONOME_PITCH = 95 # TODO(hanzorama): Make velocity adjustable by a control change signal. _METRONOME_CAPTURE_VELOCITY = 64 def serialized(func): """Decorator to provide mutual exclusion for method using _lock attribute.""" @functools.wraps(func) def serialized_method(self, *args, **kwargs): lock = getattr(self, '_lock') with lock: return func(self, *args, **kwargs) return serialized_method def stdout_write_and_flush(s): stdout.write(s) stdout.flush() class GeneratorException(Exception): """An exception raised by the Generator class.""" pass class Generator(object): """A class wrapping a SequenceGenerator. Args: generator_name: The name of the generator to wrap. Must be present in _GENERATOR_FACTORY_MAP. num_bars_to_generate: The number of bars to generate on each call. Assumes 4/4 time. hparams: A Python dictionary containing hyperparameter to value mappings to be merged with the default hyperparameters. checkpoint: The training directory with checkpoint files or the path to a single checkpoint file for the model being used. Raises: GeneratorException: If an invalid generator name is given or no training directory is given. """ def __init__( self, generator_name, num_bars_to_generate, hparams, checkpoint=None, bundle_file=None): self._num_bars_to_generate = num_bars_to_generate if not checkpoint and not bundle_file: raise GeneratorException( 'No generator checkpoint or bundle location supplied.') if (checkpoint or generator_name or hparams) and bundle_file: raise GeneratorException( 'Cannot specify both bundle file and checkpoint, generator_name, ' 'or hparams.') bundle = None if bundle_file: bundle = sequence_generator_bundle.read_bundle_file(bundle_file) generator_name = bundle.generator_details.id if generator_name not in _GENERATOR_FACTORY_MAP: raise GeneratorException('Invalid generator name given: %s', generator_name) generator = _GENERATOR_FACTORY_MAP[generator_name].create_generator( checkpoint=checkpoint, bundle=bundle, hparams=hparams) generator.initialize() self._generator = generator def generate_melody(self, input_sequence): """Calls the SequenceGenerator and returns the generated NoteSequence.""" # TODO(fjord): Align generation time on a measure boundary. notes_by_end_time = sorted(input_sequence.notes, key=lambda n: n.end_time) last_end_time = notes_by_end_time[-1].end_time if notes_by_end_time else 0 # Assume 4/4 time signature and a single tempo. qpm = input_sequence.tempos[0].qpm seconds_to_generate = (60.0 / qpm) * 4 * self._num_bars_to_generate request = generator_pb2.GenerateSequenceRequest() request.input_sequence.CopyFrom(input_sequence) section = request.generator_options.generate_sections.add() # Start generating 1 quarter note after the sequence ends. section.start_time_seconds = last_end_time + (60.0 / qpm) section.end_time_seconds = section.start_time_seconds + seconds_to_generate response = self._generator.generate(request) return response.generated_sequence class Metronome(threading.Thread): """A thread implementing a MIDI metronome. Attributes: _outport: The Mido port for sending messages. _qpm: The integer quarters per minute to signal on. _stop_metronome: A boolean specifying whether the metronome should stop. _velocity: The velocity of the metronome's MIDI note_on message. Args: outport: The Mido port for sending messages. qpm: The integer quarters per minute to signal on. velocity: The velocity of the metronome's MIDI note_on message. """ daemon = True def __init__(self, outport, qpm, clock_start_time, velocity): self._outport = outport self._qpm = qpm self._stop_metronome = False self._clock_start_time = clock_start_time self._velocity = velocity super(Metronome, self).__init__() def run(self): """Outputs metronome tone on the qpm interval until stop signal received.""" period = 60. / self._qpm sleep_offset = 0 while not self._stop_metronome: now = time.time() next_tick_time = now + period - ((now - self._clock_start_time) % period) delta = next_tick_time - time.time() if delta > 0: time.sleep(delta + sleep_offset) # The sleep function tends to return a little early or a little late. # Gradually modify an offset based on whether it returned early or late, # but prefer returning a little bit early. # If it returned early, spin until the correct time occurs. tick_late = time.time() - next_tick_time if tick_late > 0: sleep_offset -= .0005 elif tick_late < -.001: sleep_offset += .0005 if tick_late < 0: while time.time() < next_tick_time: pass self._outport.send(mido.Message(type='note_on', note=_METRONOME_PITCH, channel=FLAGS.metronome_channel, velocity=self._velocity)) time.sleep(_METRONOME_TICK_DURATION) self._outport.send(mido.Message(type='note_off', note=_METRONOME_PITCH, channel=FLAGS.metronome_channel)) def stop(self): """Signals for the metronome to stop and joins thread.""" self._stop_metronome = True self.join() class MonoMidiPlayer(threading.Thread): """A thread for playing back a monophonic, sorted NoteSequence via MIDI. Attributes: _outport: The Mido port for sending messages. _sequence: The monohponic, chronologically sorted NoteSequence to play. _stop_playback: A boolean specifying whether the playback should stop. Args: outport: The Mido port for sending messages. sequence: The monohponic, chronologically sorted NoteSequence to play. metronome_velocity: The velocity of the metronome's MIDI note_on message. Raises: ValueError: The NoteSequence contains multiple tempos. """ daemon = True def __init__(self, outport, sequence, metronome_velocity): self._outport = outport self._sequence = sequence self._stop_playback = False if len(sequence.tempos) != 1: raise ValueError('The NoteSequence contains multiple tempos.') self._metronome = Metronome(self._outport, sequence.tempos[0].qpm, time.time(), metronome_velocity) super(MonoMidiPlayer, self).__init__() def run(self): """Plays back the NoteSequence until it ends or stop signal is received. Raises: ValueError: The NoteSequence is not monophonic and chronologically sorted. """ stdout_write_and_flush('Playing sequence...') self._metronome.start() # Wall start time. play_start = time.time() # Time relative to start of NoteSequence. playhead = 0 for note in self._sequence.notes: if self._stop_playback: self._outport.panic() return stdout_write_and_flush('.') if note.start_time < playhead: raise ValueError( 'The NoteSequence is not monophonic and chronologically sorted.') playhead = note.start_time delta = playhead - (time.time() - play_start) if delta > 0: time.sleep(delta) self._outport.send( mido.Message( 'note_on', note=note.pitch, velocity=note.velocity)) if self._stop_playback: self._outport.panic() return if note.end_time < playhead: raise ValueError( 'The NoteSequence is not monophonic and chronologically sorted.') playhead = note.end_time delta = playhead - (time.time() - play_start) if delta > 0: time.sleep(delta) self._outport.send(mido.Message('note_off', note=note.pitch)) self._metronome.stop() stdout_write_and_flush('Done\n') def stop(self): """Signals for the playback and metronome to stop and joins thread.""" self._stop_playback = True self._metronome.stop() self.join() class MonoMidiHub(object): """A MIDI interface for capturing and playing monophonic NoteSequences. Attributes: _inport: The Mido port for receiving messages. _outport: The Mido port for sending messages. _lock: An RLock used for thread-safety. _capture_sequence: The NoteSequence being built from MIDI messages currently being captured or having been captured in the previous session. _control_cvs: A dictionary mapping (,) and (, ) to a condition variable that will be notified when a matching control change messsage is received. _player: A thread for playing back NoteSequences via the MIDI output port. Args: input_midi_port: The string MIDI port name to use for input. output_midi_port: The string MIDI port name to use for output. """ def __init__(self, input_midi_port, output_midi_port): self._inport = mido.open_input(input_midi_port) self._outport = mido.open_output(output_midi_port) # This lock is used by the serialized decorator. self._lock = threading.RLock() self._control_cvs = dict() self._player = None self._capture_start_time = None self._sequence_start_time = None def _timestamp_and_capture_message(self, msg): """Stamps message with current time and passes it to the capture handler.""" msg.time = time.time() self._capture_message(msg) @serialized def _capture_message(self, msg): """Handles a single incoming MIDI message during capture. Used as callback. If the message is a control change, notifies threads waiting on the appropriate condition variable. If the message is a note_on event, ends the previous note (if applicable) and opens a new note in the capture sequence. Also forwards the message to the output MIDI port. Ignores repeated note_on events. If the message is a note_off event matching the current open note in the capture sequence, ends that note and forwards the message to the output MIDI port. Args: msg: The mido.Message MIDI message to handle. """ if msg.type == 'control_change': control_tuples = [(msg.control,), (msg.control, msg.value)] for control_tuple in control_tuples: if control_tuple in self._control_cvs: self._control_cvs[control_tuple].notify_all() return last_note = (self.captured_sequence.notes[-1] if self.captured_sequence.notes else None) if msg.type == 'note_off' or (msg.type == 'note_on' and msg.velocity == 0): if (last_note is None or last_note.pitch != msg.note or last_note.end_time > 0): # This is not the note we're looking for. Drop it. return last_note.end_time = msg.time - self._sequence_start_time self._outport.send(msg) stdout_write_and_flush('.') elif msg.type == 'note_on': if self._sequence_start_time is None: # This is the first note. # Find the sequence start time based on the start of the most recent # quarter note. This ensures that the sequence start time lines up with # a metronome tick. period = 60. / self.captured_sequence.tempos[0].qpm self._sequence_start_time = msg.time - ( (msg.time - self._capture_start_time) % period) elif last_note.end_time == 0: if last_note.pitch == msg.note: # This is just a repeat of the previous message. return # End the previous note. last_note.end_time = msg.time - self._sequence_start_time self._outport.send(mido.Message('note_off', note=last_note.pitch)) self._outport.send(msg) new_note = self.captured_sequence.notes.add() new_note.start_time = msg.time - self._sequence_start_time new_note.pitch = msg.note new_note.velocity = msg.velocity stdout_write_and_flush('.') @serialized def start_capture(self, qpm): """Starts a capture session. Initializes a new capture sequence, sets the capture callback on the input port, and starts the metronome. Args: qpm: The integer quarters per minute to use for the metronome and captured sequence. Raises: RuntimeError: Already in a capture session. """ if self._inport.callback is not None: raise RuntimeError('Already in a capture session.') self.captured_sequence = music_pb2.NoteSequence() self.captured_sequence.tempos.add().qpm = qpm self._sequence_start_time = None self._capture_start_time = time.time() self._inport.callback = self._timestamp_and_capture_message self._metronome = Metronome(self._outport, qpm, self._capture_start_time, _METRONOME_CAPTURE_VELOCITY) self._metronome.start() @serialized def stop_capture(self): """Stops the capture session and returns the captured sequence. Resets the capture callback on the input port, closes the final open note (if applicable), stops the metronome, and returns the captured sequence. Returns: The captured NoteSequence. Raises: RuntimeError: Not in a capture session. """ if self._inport.callback is None: raise RuntimeError('Not in a capture session.') self._inport.callback = None self._metronome.stop() last_note = (self.captured_sequence.notes[-1] if self.captured_sequence.notes else None) if last_note is not None and last_note.end_time == 0: last_note.end_time = time.time() - self._sequence_start_time stdout_write_and_flush('Done\n') return self.captured_sequence @serialized def wait_for_control_signal(self, control_number, control_value=None): """Blocks until a specific control signal arrives. Args: control_number: The integer control change number. control_value: The integer control change value or None if any is acceptable. """ if self._inport.callback is None: # Not in a capture session. for msg in self._inport: if (msg.type == 'control_change' and msg.control == control_number and (control_value is None or msg.value == control_value)): return else: # In a capture session. control_tuple = ((control_number,) if control_value is None else (control_number, control_value)) if control_tuple not in self._control_cvs: self._control_cvs[control_tuple] = threading.Condition(self._lock) self._control_cvs[control_tuple].wait() def start_playback(self, sequence, metronome_velocity): """Plays the monophonic, sorted NoteSequence through the MIDI output port. Stops any previously playing sequences. Args: sequence: The monohponic, chronologically sorted NoteSequence to play. metronome_velocity: The velocity of the metronome's MIDI note_on message. """ self.stop_playback() self._player = MonoMidiPlayer(self._outport, sequence, metronome_velocity) self._player.start() def stop_playback(self): """Stops any active sequence playback.""" if self._player is not None and self._player.is_alive(): self._player.stop() stdout_write_and_flush('Stopped\n') def main(unused_argv): if FLAGS.list: print "Input ports: '" + "', '".join(mido.get_input_names()) + "'" print "Output ports: '" + "', '".join(mido.get_output_names()) + "'" return if FLAGS.input_port is None or FLAGS.output_port is None: print '--inport_port and --output_port must be specified.' return if (FLAGS.start_capture_control_number == FLAGS.stop_capture_control_number and (FLAGS.start_capture_control_value == FLAGS.stop_capture_control_value or FLAGS.start_capture_control_value is None or FLAGS.stop_capture_control_value is None)): print('If using the same number for --start_capture_control_number and ' '--stop_capture_control_number, --start_capture_control_value and ' '--stop_capture_control_value must both be defined and unique.') return if not 0 <= FLAGS.metronome_playback_velocity <= 127: print 'The metronome_playback_velocity must be an integer between 0 and 127' return generator = Generator( FLAGS.generator_name, FLAGS.num_bars_to_generate, ast.literal_eval(FLAGS.hparams if FLAGS.hparams else '{}'), FLAGS.checkpoint, FLAGS.bundle_file) hub = MonoMidiHub(FLAGS.input_port, FLAGS.output_port) stdout_write_and_flush('Waiting for start control signal...\n') while True: hub.wait_for_control_signal(FLAGS.start_capture_control_number, FLAGS.start_capture_control_value) hub.stop_playback() hub.start_capture(FLAGS.qpm) stdout_write_and_flush('Capturing notes until stop control signal...') hub.wait_for_control_signal(FLAGS.stop_capture_control_number, FLAGS.stop_capture_control_value) captured_sequence = hub.stop_capture() stdout_write_and_flush('Generating response...') generated_sequence = generator.generate_melody(captured_sequence) stdout_write_and_flush('Done\n') hub.start_playback(generated_sequence, FLAGS.metronome_playback_velocity) def console_entry_point(): tf.app.run(main) if __name__ == '__main__': console_entry_point()