added sys path to third_party folder

Fixes #1
Fixes #2
This commit is contained in:
Yotam Mann 2016-11-16 13:58:24 -08:00
parent 6a573d1b00
commit f33312634a

View File

@ -14,6 +14,8 @@
# limitations under the License.
#
import sys
sys.path.append('./third_party')
import third_party.magenta.models.basic_rnn.basic_rnn_generator as basic_rnn_generator
from third_party.magenta.music import sequence_generator_bundle
from third_party.magenta.protobuf import generator_pb2
@ -49,9 +51,9 @@ def generate_midi(midi_data, total_seconds=10):
1, qpm)
generate_section.end_time_seconds = total_seconds
# generate_response = generator_map[generator_name].generate(generate_request)
generate_response = basic_generator.generate(generate_request.input_sequence, generate_request.generator_options)
generate_response = basic_generator.generate(generate_request)
output = tempfile.NamedTemporaryFile()
midi_io.sequence_proto_to_midi_file(
generate_response, output.name)
generate_response.generated_sequence, output.name)
output.seek(0)
return output