Hi everyone,
I’ve been coding a wavenet model from scratch in pytorch, but for some reason, I just can’t get it to properly train. Every epoch in my code seems to have nearly the same loss, and I can’t seem to figure out why. I was hoping someone here would be able to take a look at my code and help me debug the situation.
With regards to the files:
[train.py]: this is where the training logic happens
[wavenet.py]: this is the pytorch model
[dataloader.py]: this is where I load and create the data to be used for training
[utils.py]: I have some generation functions here.
A lot of this code was adapted from sources I found online, the main difference being that I’m padding the input into the dilated convolution to ensure the input and output remain the same dimension.
Would really appreciate some help on this!
GITHUB: GitHub - mahtanir/Wavenet: Wavnet pytorch implementation
My model structure is below in case you don’t want to check out the github repo:
#x is typically one channel where the timestemps depends on the frequency rate. Can be two though. Typically a 1D array
from torch import nn
import torch
import numpy as np
import torch.nn.functional as F
class CasualDilatedConv1D(nn.Module):
def __init__(self, res_channels, out_channels, kernel_size, dilation):
super().__init__()
self.dilation = dilation
self.kernel_size = kernel_size
self.conv1D = nn.Conv1d(res_channels, out_channels, kernel_size, dilation=dilation, bias=False)
self.ignoreOutIndex = (kernel_size - 1)*dilation #i.e don't have to consider right part because of k - 1 padding on either side.
def forward(self, x):
# Apply padding
x = nn.functional.pad(x, ((self.kernel_size - 1)*self.dilation, 0)) #IF we don't need this need to add to input.
#if we do this without padding we lose (k - 1)*dim everytime.
#padding same is (k - 1) / 2 each side. So instead we do (k - 1) on both sides but now 2k-2 vs k - 1 so k - 1 extra. Remove right.
# x = x.double()
# print('pre shape', x.shape)
# print(self.conv1D(x)[..., :-self.ignoreOutIndex].shape)
return self.conv1D(x)
# [..., :-self.ignoreOutIndex] #https://chat.openai.com/c/0598fb53-ddb1-43e9-9572-8fc80498ca28 cause padding = same
#why do we do this? Only if we add padding right
#ALT
# return self.conv1D(x)
class ResBlock(nn.Module): #using the same kernel weights for all
def __init__(self, res_channels, skip_channels, kernel_size, dilation):
super().__init__()
self.dilatedConv1D = CasualDilatedConv1D(res_channels, res_channels, kernel_size, dilation = dilation)
self.resConv1D = nn.Conv1d(res_channels, res_channels, kernel_size=1, dilation=1) #i.e same input output dims (see diagram)
self.skipConv1D = nn.Conv1d(res_channels, skip_channels, kernel_size=1, dilation=1)
self.tanh = nn.Tanh()
self.sigmoid = nn.Sigmoid()
def forward(self, input):
x = self.dilatedConv1D(input)
x_tan = self.tanh(x)
x_sigmoid = self.sigmoid(x)
x = x_tan * x_sigmoid
residual_output = self.resConv1D(x) #shape = n,c,sample
residual_output = residual_output + input
# ALT residual_output = residual_output + input[..., -residual_output.size(2):] due to causality and dilated conv affecting dimensions.
skip_output = self.skipConv1D(x) #this is for the skip connection output to the right in diagram
return residual_output, skip_output
class stackOfResBlocks(nn.Module):
def __init__(self, stack_size, layer_size, res_channels, skip_channels, kernel_size):
super().__init__()
dilations = self.buildDilations(stack_size, layer_size)
self.resBlockArr = []
for stack in dilations:
for dilation in stack:
self.resBlockArr.append(ResBlock(res_channels, skip_channels, kernel_size, dilation))
def buildDilations(self, stack_size, layer_size):
dilations_arr_all = []
for stack in range(stack_size): #stack is not actually a stack of resblocks but rather to 512.Could just do 1 array I feel but good logic
dilation_arr = []
for j in range(layer_size):
dilation = 2**j if 2**j <= 520 else 520 #assuming doesn't go beyond 512 otherwise impose a cap ie 2**layer_size
dilation_arr.append(dilation)
dilations_arr_all.append(dilation_arr)
return np.array(dilations_arr_all)
def forward(self, x):
residual_outputs = []
for resBlock in self.resBlockArr:
x, residual = resBlock(x)
residual_outputs.append(residual)
return x, torch.stack(residual_outputs) #creates new dim at = 0 . so it is #layers, (n), samples, channels
class DenseLayer(nn.Module): #WHAT IS GOING ON HERE!
def __init__(self, res_channels, out_channels):
super().__init__()
self.relu = nn.ReLU()
self.conv1D = nn.Conv1d(res_channels, res_channels, kernel_size=1, dilation=1, bias=False)
self.conv2nD = nn.Conv1d(res_channels, out_channels, kernel_size=1, dilation=1, bias=False)
self.softmax = nn.Softmax(dim=1)
def forward(self, skipConnections): #not sure about channel here
#we have skip connections of (batches, timesteps, channles) potentially channels is the timseteps and timestep is the song notes?
# based on medium article, it is (#layers, samples, channels)
out = torch.sum(skipConnections, dim=0) #sum across the layers --> should be 0?? he put dim=2
out = self.relu(out)
out = self.conv1D(out)
out = self.relu(out)
out = self.conv2nD(out)
return out
# return self.softmax(out) #outs dimensions after torch.sum become samples,channels in which case this would make sense.
class Wavenet(nn.Module):
def __init__(self, res_channels, out_channels, skip_channels, kernel_size, stack_size, layer_size): #stack sie and layer size depends how many we want to stack
super().__init__()
self.stack_size = stack_size
self.layer_size = layer_size
self.kernel_size = kernel_size
self.casualConv1D = CasualDilatedConv1D(256, res_channels, kernel_size, dilation=1) #what are channels here? Represent different features audio wise. Usually just one channel to represent amplitude.
self.resBlockStack = stackOfResBlocks(stack_size, layer_size, res_channels, skip_channels, kernel_size)
self.denseLayer = DenseLayer(skip_channels, out_channels)
def calculateReceptiveField(self):
sum_val = np.sum([(self.kernel_size - 1) * 2**self.layer_size for i in range(self.layer_size)] * self.stack_size)
#would need this if we were not doing padding, i.e see image above. At each step we're removing (kernel - 1) * 2**layer # from previous resblock output.
return sum_val
def forward(self, x):
x = one_hot(x, self.kernel_size)
x = self.casualConv1D(x)
# print('conv1d 2/ dilation post shape: ', x.shape, '\n')
final_res_output, skip_connections = self.resBlockStack(x) #final output is not necessary
skip_output = sum([skip[...,-final_res_output.shape[-1]:] for skip in skip_connections]) #ALT
return self.denseLayer(skip_connections)
# class WavenetClassifier(nn.Module):
# def __init__(self, ):
# super().__init__()
# self.Wavenet = Wavenet(32, 256, 512, 2, 10, 5) #if we want to one hot encode the notes may have to convert this to 256.
# #32 = 24 in image, 512 = 128 in image
def one_hot(x, kernel_size):
x = torch.tensor(np.array(x))
# print('shape here!', x.shape)
# x = nn.functional.pad(x, (kernel_size - 1, kernel_size - 1)) #IF we don't need this need to add to input.
# print('shape here!', x.shape)
one_hot = F.one_hot(x, num_classes=256)
# print(one_hot.shape, one_hot[0])
tf_shape = (1, -1, 256) #so rows actually are the points! I THINK! but then the way the conv channel works is weird... But image also shows like this
py_shape = (1, 256, -1)
one_hot = torch.reshape(one_hot, py_shape)
one_hot = torch.tensor(one_hot, dtype=torch.float32)
# print('one_hot vector input: \n', one_hot, '\n', one_hot.shape, '\n')
return one_hot
And the train file is as follows:
from wavenet import *
# from wavenet_copy import *
import datetime
from dataloader import *
import numpy as np
from mxnet import ndarray
from tqdm import tqdm
from scipy.io.wavfile import write
from IPython.display import Audio
import sounddevice as sd
class Train():
def __init__(self, config) -> None:
self.layer_size = config.layer_size
self.stack_size = config.stack_size #repeat
self.res_channels = config.res_channels
self.skip_channels = config.skip_channels
self.mu = 256
self.batch_size = config.batch_size
self.epochs = config.epochs
self.seq_length = config.seq_length
def save_params(self, model):
torch.save(model.state_dict(), 'models/best2/wavenet.pth')
# model.save_params('models/best_perf/' + datetime.datetime.now())
def train(self):
print('at train')
# Wavenet()
net = Wavenet(self.res_channels, self.mu, self.skip_channels, 2, self.stack_size, self.layer_size)
self.net = net
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(net.parameters(), lr=0.01)
n_steps = self.batch_size
print('loading music...')
fs, data = load_music('data_parametric-2')
minLoss = None
data_generator = data_generation(data, fs, self.seq_length, self.mu, None) #generate training data lazilly.
for i in tqdm(range(self.epochs)):
loss = 0
for j in tqdm(range(self.batch_size), leave=False): #assuming that the batch size is full training set, stochastic gradient descent
# print('epoch: ', i, '\n sample: ', j, '\n')
sample = next(data_generator)
#Forward Pass
x = sample[:-1] #one behind
# x = x.astype('float64')
# print('type x', type(x[0]))
y = sample[-x.shape[0]:] #normal (effectively one forward)
y = one_hot_utils(y)
y_hat = net(x) #converted to one_hot already but in the right format for conv
# print('model_output: ', y_hat, '\n model shape:', y_hat.shape,
# '\n test output: ', y, '\n test output shape', y.shape)
# print('shape check', y_hat.shape, y.shape)
tf_shape = (0, 2, 1) #so rows actually are the points! I THINK! but then the way the conv channel works is weird... But image also shows like this
y_hat = torch.permute(y_hat, tf_shape)
# print(y_hat.shape, y.shape)
# loss_criterion = criterion(y_hat, y)
#if alt
loss_criterion = criterion(y_hat, y) #ALT
loss = loss + loss_criterion.item()
# print('loss criterion: ', loss_criterion, '\nloss: ', loss)
loss_criterion.backward() #predicts loss across each step for all categories. Only true cat matters though.I.e loss is (1, sample) -> Actually loss criterion avg across samples
optimizer.step()
optimizer.zero_grad() #stochastic
with torch.no_grad():
agg_loss = loss / self.batch_size
print(f"loss for epoch {i} : {agg_loss} \n")
# ndarray.sum(loss).asscalar()
if (minLoss is None or agg_loss < minLoss): #stochastic volative, so look per batch which is best
minLoss = agg_loss
self.save_params(net)
return net
def bestModel(self): #load best model for given architecture
model = Wavenet(self.res_channels, self.mu, self.skip_channels, 2, self.stack_size, self.layer_size)
model.load_state_dict(torch.load('models/best2/wavenet.pth'))
self.net = model
return model
def generate_slow(self, x, model, n, dilation_depth, n_repeat):
dilations = [2*i for i in range(dilation_depth)] * n_repeat
reference_window = sum(dilations)
x_generated = x.copy()
np.save("initial_wav_alt.npy",decode_mu_law(x_generated.copy()))
for i in tqdm(range(n)):
y = model(x_generated[-reference_window - 1:])
# y = model(x_generated[-reference_window -1: ]) ALR
y_next = np.squeeze(y.argmax(1).numpy())[-1] #n, c, samples
# print(y.argmax(1).numpy(), np.squeeze(y.argmax(1).numpy())[-1])
x_generated = np.append(x_generated, y_next)
#similar to LSTM logic but now instead of reference window of 1, you have reference window of dilations
#still add the prev output!
# print(x_generated.shape)
return x_generated
def generator(self, model, n, dilation_depth, n_repeat):
print('generating now...')
fr, data = load_music('data_parametric-2')
data_sample = data_generation_sample(data, fr, self.seq_length, self.mu, None)
generated_song = self.generate_slow(data_sample, model, n, self.layer_size, self.stack_size)
gen_wav = np.array(generated_song)
decoded_wave = decode_mu_law(gen_wav, 256)
np.save("wav_long_alt.npy",decoded_wave)
# write('test.wav', fr, decoded_wave)
# sd.play(decoded_wave, fr)
# Audio(gen_wav, rate=fr)
# LOGIC: Note that for the last output, elements beyond the reference window don't affect the subsequent output.
# as such we simply need to consider the reference window only. Reference window is the sum of the dilations for last point.
#i.e always consider left most point contributing towards it! Always + dilation for reference window.
#But what is x too small ie starts from 0? Not sure about - 1 also but still works This is especially because of the
# 1by1 conv. Therefore since weight predecided we only really need to know of the last node, if 2 by 1 would need to know of the other one also i.e one before.
# def generate_slower(self, x, models, dilation_depth, n_repeat, ctx, n):
# dilations = [2**i for i in range(dilation_depth)] * n_repeat
# res = list(x.asnumpy())
# for _ in trange(n):
# x = nd.array(res[-sum(dilations)-1:],ctx=ctx) i.e losing (k - 1)*dilation every time. So we sum all --> here only looking at 1 skip conn output
# y = models(x)
# res.append(y.argmax(1).asnumpy()[-1])
# return res
As for the utils file:
#borrowed from https://medium.com/@kion.kim/wavenet-a-network-good-to-know-7caaae735435
import numpy as np
from torch import nn
import torch.nn.functional as F
import torch
def encode_mu_law(x, mu=256):
mu = mu-1
fx = np.sign(x)*np.log(1+mu*np.abs(x))/np.log(1+mu) #1
print('TYPE', type(x[0]))
return np.floor((fx+1)/2*mu+0.5).astype(np.long) #2
def decode_mu_law(y, mu=256):
mu = mu-1
fx = (y-0.5)/mu*2-1 #reverse of #1
x = np.sign(fx)/mu*((1+mu)**np.abs(fx)-1)
return x
def one_hot_utils(x):
x = torch.tensor(np.array(x))
one_hot = F.one_hot(x, num_classes=256)
tf_shape = (1, -1, 256) #so rows actually are the points! I THINK! but then the way the conv channel works is weird... But image also shows like this
one_hot = torch.reshape(one_hot, tf_shape)
one_hot = torch.tensor(one_hot, dtype=torch.float32)
return one_hot
The dataloader file is as follows:
import os
from scipy.io import wavfile
import numpy as np
from utils import *
def load_music(music_name):
fs, data = wavfile.read(os.path.join('data', music_name + '.wav'))
# print(fs, data)
return fs, data
def load_music_test():
music_name = 'data_parametric-2'
fs, data = load_music(music_name)
print(fs, data, data.shape)
def data_generation(data, frame_rate, seq_size, mu, ctx):
max_val = max(abs(min(data)), max(data))
data = data / max_val
while True: #forever?
sequence_sample_start = np.random.randint(0, data.shape[0] - seq_size)
subsequence = data[sequence_sample_start: sequence_sample_start + seq_size]
condensed_subsequence = encode_mu_law(subsequence, mu)
yield condensed_subsequence #yield returns a generator object that is an iterable that can be iterated on (i.e with for loop) only once
# preserves memory since it doesn't store it in memory vs other iterables like arrays or lists (function continues where left off
# ) Returns one value at a time, as long as it knows the next I think it's ok.
def data_generation_sample(data, frame_rate, seq_size, mu, ctx):
#same logic as before but now we only return one, not a generator.
max_val = max(max(data), abs(min(data)))
data = data / max_val
start = np.random.randint(0, data.shape[0] - seq_size)
subset = data[start: start+ seq_size]
return encode_mu_law(subset, mu)
load_music_test()