updating to latest magenta and tf

This commit is contained in:
Yotam Mann 2017-01-13 16:50:56 -05:00
parent 9c4445abf9
commit c94efe36b1
3 changed files with 43 additions and 48 deletions

BIN
server/basic_rnn.mag Normal file

Binary file not shown.

View File

@ -14,23 +14,40 @@
# limitations under the License. # limitations under the License.
# #
import magenta.models.basic_rnn.basic_rnn_generator as basic_rnn_generator
from magenta.music import sequence_generator_bundle import magenta
from magenta.models.melody_rnn import melody_rnn_config_flags
from magenta.models.melody_rnn import melody_rnn_model
from magenta.models.melody_rnn import melody_rnn_sequence_generator
from magenta.protobuf import generator_pb2 from magenta.protobuf import generator_pb2
from magenta.music import midi_io from magenta.protobuf import music_pb2
from magenta.models.shared.melody_rnn_generate import _steps_to_seconds
import os import os
import time
import tempfile import tempfile
import pretty_midi
BUNDLE_NAME = 'basic_rnn'
basic_generator = basic_rnn_generator.create_generator( config = magenta.models.melody_rnn.melody_rnn_model.default_configs[BUNDLE_NAME]
None, bundle_file = magenta.music.read_bundle_file(os.path.abspath(BUNDLE_NAME+'.mag'))
sequence_generator_bundle.read_bundle_file(os.path.abspath('./magenta/basic_rnn.mag')), steps_per_quarter = 4
4)
generator = melody_rnn_sequence_generator.MelodyRnnSequenceGenerator(
model=melody_rnn_model.MelodyRnnModel(config),
details=config.details,
steps_per_quarter=steps_per_quarter,
bundle=bundle_file)
def _steps_to_seconds(steps, qpm):
return steps * 60.0 / qpm / steps_per_quarter
def generate_midi(midi_data, total_seconds=10): def generate_midi(midi_data, total_seconds=10):
primer_sequence = midi_io.midi_to_sequence_proto(midi_data) primer_sequence = magenta.music.midi_io.midi_to_sequence_proto(midi_data)
generate_request = generator_pb2.GenerateSequenceRequest()
# predict the tempo
if len(primer_sequence.notes) > 4: if len(primer_sequence.notes) > 4:
estimated_tempo = midi_data.estimate_tempo() estimated_tempo = midi_data.estimate_tempo()
if estimated_tempo > 240: if estimated_tempo > 240:
@ -40,18 +57,19 @@ def generate_midi(midi_data, total_seconds=10):
else: else:
qpm = 120 qpm = 120
primer_sequence.tempos[0].qpm = qpm primer_sequence.tempos[0].qpm = qpm
generate_request.input_sequence.CopyFrom(primer_sequence)
generate_section = (generate_request.generator_options.generate_sections.add()) generator_options = generator_pb2.GeneratorOptions()
# Set the start time to begin on the next step after the last note ends. # Set the start time to begin on the next step after the last note ends.
notes_by_end_time = sorted(primer_sequence.notes, key=lambda n: n.end_time) last_end_time = (max(n.end_time for n in primer_sequence.notes)
last_end_time = notes_by_end_time[-1].end_time if notes_by_end_time else 0 if primer_sequence.notes else 0)
generate_section.start_time_seconds = last_end_time + _steps_to_seconds( generator_options.generate_sections.add(
1, qpm) start_time=last_end_time + _steps_to_seconds(1, qpm),
generate_section.end_time_seconds = total_seconds end_time=total_seconds)
# generate_response = generator_map[generator_name].generate(generate_request)
generate_response = basic_generator.generate(generate_request) # generate the output sequence
generated_sequence = generator.generate(primer_sequence, generator_options)
output = tempfile.NamedTemporaryFile() output = tempfile.NamedTemporaryFile()
midi_io.sequence_proto_to_midi_file( magenta.music.midi_io.sequence_proto_to_midi_file(generated_sequence, output.name)
generate_response.generated_sequence, output.name)
output.seek(0) output.seek(0)
return output return output

View File

@ -1,28 +1,5 @@
Flask==0.11.1 tensorflow==0.12.1
Jinja2==2.8 magenta==0.1.8
MarkupSafe==0.23 Flask==0.12
Werkzeug==0.11.11
argparse==1.2.1
boto==2.34.0
chardet==2.3.0
click==6.6
colorama==0.3.2
crcmod==1.7
funcsigs==1.0.2
gunicorn==19.6.0 gunicorn==19.6.0
html5lib==0.999 ipython==5.1.0
itsdangerous==0.24
meld3==1.0.0
mido==1.1.17
mock==2.0.0
numpy==1.11.2
pbr==1.10.0
pretty-midi==0.2.6
protobuf==3.0.0
requests==2.4.3
six==1.10.0
supervisor==3.0
urllib3==1.9.1
virtualenv==1.11.6
wheel==0.29.0
wsgiref==0.1.2