2016-11-11 18:53:51 +00:00
|
|
|
#
|
|
|
|
# Copyright 2016 Google Inc.
|
|
|
|
#
|
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
# you may not use this file except in compliance with the License.
|
|
|
|
# You may obtain a copy of the License at
|
|
|
|
#
|
|
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
#
|
|
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
# limitations under the License.
|
|
|
|
#
|
|
|
|
|
2017-01-13 21:50:56 +00:00
|
|
|
|
|
|
|
import magenta
|
|
|
|
|
|
|
|
from magenta.models.melody_rnn import melody_rnn_config_flags
|
|
|
|
from magenta.models.melody_rnn import melody_rnn_model
|
|
|
|
from magenta.models.melody_rnn import melody_rnn_sequence_generator
|
2016-11-17 04:33:15 +00:00
|
|
|
from magenta.protobuf import generator_pb2
|
2017-01-13 21:50:56 +00:00
|
|
|
from magenta.protobuf import music_pb2
|
2017-04-22 15:49:04 +00:00
|
|
|
from usingMusicNN import predictmood
|
2017-01-13 21:50:56 +00:00
|
|
|
|
2016-11-11 18:53:51 +00:00
|
|
|
import os
|
2017-01-13 21:50:56 +00:00
|
|
|
import time
|
2016-11-11 18:53:51 +00:00
|
|
|
import tempfile
|
2017-01-13 21:50:56 +00:00
|
|
|
import pretty_midi
|
|
|
|
|
2017-02-13 22:29:42 +00:00
|
|
|
# Name of the pre-trained Magenta bundle; it also selects the matching
# default model configuration below.
BUNDLE_NAME = 'attention_rnn'

# Step resolution (steps per quarter note) handed to the sequence generator.
steps_per_quarter = 4

config = melody_rnn_model.default_configs[BUNDLE_NAME]

# NOTE(review): os.path.abspath resolves this relative file name against the
# process's current working directory, so '<BUNDLE_NAME>.mag' must be present
# there when this module is imported.
bundle_file = magenta.music.read_bundle_file(
    os.path.abspath(BUNDLE_NAME + '.mag'))

# Constructed once at import time; generate_midi() reuses this instance.
generator = melody_rnn_sequence_generator.MelodyRnnSequenceGenerator(
    model=melody_rnn_model.MelodyRnnModel(config),
    details=config.details,
    steps_per_quarter=steps_per_quarter,
    bundle=bundle_file,
)
|
|
|
|
|
|
|
|
def _steps_to_seconds(steps, qpm):
|
|
|
|
return steps * 60.0 / qpm / steps_per_quarter
|
2016-11-11 18:53:51 +00:00
|
|
|
|
|
|
|
def generate_midi(midi_data, total_seconds=10):
|
2017-01-13 21:50:56 +00:00
|
|
|
primer_sequence = magenta.music.midi_io.midi_to_sequence_proto(midi_data)
|
|
|
|
|
|
|
|
# predict the tempo
|
2016-11-11 18:53:51 +00:00
|
|
|
if len(primer_sequence.notes) > 4:
|
|
|
|
estimated_tempo = midi_data.estimate_tempo()
|
|
|
|
if estimated_tempo > 240:
|
|
|
|
qpm = estimated_tempo / 2
|
|
|
|
else:
|
|
|
|
qpm = estimated_tempo
|
|
|
|
else:
|
|
|
|
qpm = 120
|
|
|
|
primer_sequence.tempos[0].qpm = qpm
|
2017-01-13 21:50:56 +00:00
|
|
|
|
|
|
|
generator_options = generator_pb2.GeneratorOptions()
|
2016-11-11 18:53:51 +00:00
|
|
|
# Set the start time to begin on the next step after the last note ends.
|
2017-01-13 21:50:56 +00:00
|
|
|
last_end_time = (max(n.end_time for n in primer_sequence.notes)
|
|
|
|
if primer_sequence.notes else 0)
|
|
|
|
generator_options.generate_sections.add(
|
|
|
|
start_time=last_end_time + _steps_to_seconds(1, qpm),
|
|
|
|
end_time=total_seconds)
|
|
|
|
|
|
|
|
# generate the output sequence
|
|
|
|
generated_sequence = generator.generate(primer_sequence, generator_options)
|
2017-05-05 02:58:17 +00:00
|
|
|
|
2016-11-11 18:53:51 +00:00
|
|
|
output = tempfile.NamedTemporaryFile()
|
2017-01-13 21:50:56 +00:00
|
|
|
magenta.music.midi_io.sequence_proto_to_midi_file(generated_sequence, output.name)
|
2016-11-11 18:53:51 +00:00
|
|
|
output.seek(0)
|
2016-11-16 21:58:24 +00:00
|
|
|
return output
|
2017-05-05 02:58:17 +00:00
|
|
|
|