diff --git a/server/basic_rnn.mag b/server/basic_rnn.mag new file mode 100644 index 0000000..8d5f39b Binary files /dev/null and b/server/basic_rnn.mag differ diff --git a/server/predict.py b/server/predict.py index c17d685..640fe86 100644 --- a/server/predict.py +++ b/server/predict.py @@ -14,23 +14,40 @@ # limitations under the License. # -import magenta.models.basic_rnn.basic_rnn_generator as basic_rnn_generator -from magenta.music import sequence_generator_bundle + +import magenta + +from magenta.models.melody_rnn import melody_rnn_config_flags +from magenta.models.melody_rnn import melody_rnn_model +from magenta.models.melody_rnn import melody_rnn_sequence_generator from magenta.protobuf import generator_pb2 -from magenta.music import midi_io -from magenta.models.shared.melody_rnn_generate import _steps_to_seconds +from magenta.protobuf import music_pb2 + + import os +import time import tempfile +import pretty_midi +BUNDLE_NAME = 'basic_rnn' -basic_generator = basic_rnn_generator.create_generator( - None, - sequence_generator_bundle.read_bundle_file(os.path.abspath('./magenta/basic_rnn.mag')), - 4) +config = magenta.models.melody_rnn.melody_rnn_model.default_configs[BUNDLE_NAME] +bundle_file = magenta.music.read_bundle_file(os.path.abspath(BUNDLE_NAME+'.mag')) +steps_per_quarter = 4 + +generator = melody_rnn_sequence_generator.MelodyRnnSequenceGenerator( + model=melody_rnn_model.MelodyRnnModel(config), + details=config.details, + steps_per_quarter=steps_per_quarter, + bundle=bundle_file) + +def _steps_to_seconds(steps, qpm): + return steps * 60.0 / qpm / steps_per_quarter def generate_midi(midi_data, total_seconds=10): - primer_sequence = midi_io.midi_to_sequence_proto(midi_data) - generate_request = generator_pb2.GenerateSequenceRequest() + primer_sequence = magenta.music.midi_io.midi_to_sequence_proto(midi_data) + + # predict the tempo if len(primer_sequence.notes) > 4: estimated_tempo = midi_data.estimate_tempo() if estimated_tempo > 240: @@ -40,18 +57,19 @@ def generate_midi(midi_data, total_seconds=10): else: qpm = 120 primer_sequence.tempos[0].qpm = qpm - generate_request.input_sequence.CopyFrom(primer_sequence) - generate_section = (generate_request.generator_options.generate_sections.add()) + + generator_options = generator_pb2.GeneratorOptions() # Set the start time to begin on the next step after the last note ends. - notes_by_end_time = sorted(primer_sequence.notes, key=lambda n: n.end_time) - last_end_time = notes_by_end_time[-1].end_time if notes_by_end_time else 0 - generate_section.start_time_seconds = last_end_time + _steps_to_seconds( - 1, qpm) - generate_section.end_time_seconds = total_seconds - # generate_response = generator_map[generator_name].generate(generate_request) - generate_response = basic_generator.generate(generate_request) + last_end_time = (max(n.end_time for n in primer_sequence.notes) + if primer_sequence.notes else 0) + generator_options.generate_sections.add( + start_time=last_end_time + _steps_to_seconds(1, qpm), + end_time=total_seconds) + + # generate the output sequence + generated_sequence = generator.generate(primer_sequence, generator_options) + output = tempfile.NamedTemporaryFile() - midi_io.sequence_proto_to_midi_file( - generate_response.generated_sequence, output.name) + magenta.music.midi_io.sequence_proto_to_midi_file(generated_sequence, output.name) output.seek(0) return output diff --git a/server/requirements.txt b/server/requirements.txt index 427b8f7..27b1c93 100644 --- a/server/requirements.txt +++ b/server/requirements.txt @@ -1,28 +1,5 @@ -Flask==0.11.1 -Jinja2==2.8 -MarkupSafe==0.23 -Werkzeug==0.11.11 -argparse==1.2.1 -boto==2.34.0 -chardet==2.3.0 -click==6.6 -colorama==0.3.2 -crcmod==1.7 -funcsigs==1.0.2 +tensorflow==0.12.1 +magenta==0.1.8 +Flask==0.12 gunicorn==19.6.0 -html5lib==0.999 -itsdangerous==0.24 -meld3==1.0.0 -mido==1.1.17 -mock==2.0.0 -numpy==1.11.2 -pbr==1.10.0 -pretty-midi==0.2.6 -protobuf==3.0.0 -requests==2.4.3 -six==1.10.0 -supervisor==3.0 -urllib3==1.9.1 -virtualenv==1.11.6 -wheel==0.29.0 -wsgiref==0.1.2 \ No newline at end of file +ipython==5.1.0 \ No newline at end of file