In [1]:
# Ignore a bunch of deprecation warnings
import sys
sys.path.append('../../..')
sys.path.append('../.')
import warnings
warnings.filterwarnings("ignore")

import copy
import os
import time
from tqdm import tqdm
import math

import ddsp
import ddsp.training

from data_handling.ddspdataset import DDSPDataset
from utils.training_utils import print_hparams, set_seed, save_results, str2bool
from hparams_midiae_interp_cond import hparams as hp
from midiae_interp_cond.get_model import get_model, get_fake_data

import librosa
import matplotlib.pyplot as plt
import numpy as np
import tensorflow.compat.v2 as tf
import tensorflow_datasets as tfds
import pandas as pd

from notebook_utils import *

# Global seed for the run. `set_seed` is a project helper; presumably it seeds the
# libraries used here so the shuffled evaluation batch is reproducible — TODO confirm.
SEED = 1234
set_seed(SEED)

# Audio sample rate in Hz, passed to `plot_spec` when plotting model outputs.
# NOTE(review): assumed to match the rate the model/dataset were built with — confirm against `hp`.
sample_rate = 16000

print('Done!')
Done!
In [2]:
# Checkpoint prefix to restore; the run's `train.log` sits in the same directory.
model_path = r'/data/ddsp-experiment/logs/logs/logs_interp_cond_6.12_autoreg_rnn_harm_amp_mdn/2021-06-12-09-19-35/20000'
train_log_path = os.path.join(os.path.dirname(model_path), 'train.log')

# `get_hp` (from notebook_utils) presumably parses the hyperparameters logged for this run.
# Copy each one onto the shared `hp` module so the model built next uses the same settings.
hp_dict = get_hp(train_log_path)
for hp_name, hp_value in hp_dict.items():
    setattr(hp, hp_name, hp_value)
In [3]:
# Construct the model from the run's hyperparameters and call it once on placeholder data.
# NOTE(review): `_build` is a private project method; presumably it creates the model's
# variables so `load_weights` has something to restore into — confirm.
model = get_model(hp)
_ = model._build(get_fake_data(hp))

# Restore the trained weights and verify the restore actually covered the model: if the
# rebuilt architecture differs from the checkpoint (e.g. an hparam missing from train.log),
# unmatched variables would otherwise silently keep their fresh initial values.
load_status = model.load_weights(model_path)
load_status.assert_existing_objects_matched();
Out[3]:
<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x7f5a6006ab90>
In [4]:
from data_handling.urmp_tfrecord_dataloader import UrmpMidi
from data_handling.get_tfrecord_length import get_tfrecord_length
from midiae_interp_cond.recon_loss import ReconLossHelper
In [5]:
from data_handling.get_dataset import get_batch
In [6]:
from data_handling.instrument_name_utils import INST_ABB_TO_NAME_DICT, INST_NAME_TO_ABB_DICT, INST_ABB_LIST
In [7]:
# Output directory for this experiment; created eagerly (no error if it already exists).
# NOTE(review): `log_dir` is not referenced in the cells visible here.
log_dir = '/data/ddsp-experiment/urmp_single_instrument_recon'
os.makedirs(name=log_dir, mode=0o777, exist_ok=True)
In [8]:
# Test split of the solo-instrument URMP TFRecords, restricted to a single instrument.
# NOTE(review): 'vn' is presumably the violin abbreviation; valid keys are presumably
# those in INST_ABB_LIST / INST_ABB_TO_NAME_DICT — confirm.
data_dir = r'/data/music_dataset/urmp_dataset/tfrecord_ddsp/batched/solo_instrument'
instrument_key = 'vn'
test_data_loader = UrmpMidi(data_dir, instrument_key=instrument_key, split='test')

# Keep the batched dataset under its own name instead of overwriting it with its iterator,
# so a fresh pass can be started later with `iter(evaluation_dataset)` without rebuilding.
# `evaluation_data` stays the iterator, as later cells expect.
evaluation_dataset = get_batch(test_data_loader, batch_size=1, shuffle=True, repeats=1, drop_remainder=False)
evaluation_data = iter(evaluation_dataset)
In [9]:
# Draw the next batch from the evaluation iterator. Each re-run of this cell advances the
# iterator, so it returns a different batch every time.
try:
    data = next(evaluation_data)
except StopIteration:
    # The data cell builds a single pass (repeats=1); once it is used up, a bare
    # StopIteration escaping a notebook cell is cryptic, so explain how to recover.
    raise RuntimeError(
        'evaluation_data is exhausted; re-run the cell that builds it to start a new pass'
    ) from None
In [10]:
# Single forward pass on the sampled batch; training=False selects the model's inference path.
is_training = False
outputs = model(data, training=is_training)
In [11]:
# Spectrogram of the `midi_audio` output for the first item in the batch.
midi_audio_np = outputs['midi_audio'][0].numpy()
plot_spec(midi_audio_np, sample_rate)
In [12]:
# Spectrogram of the `synth_audio` output for the first item in the batch.
synth_audio_np = outputs['synth_audio'][0].numpy()
plot_spec(synth_audio_np, sample_rate)