Report
In this report, we introduce an encoder-decoder network that solves the sorting problem in a general-purpose way. The report is separated into several parts:
The attention mechanism and the seq2seq model
The architecture and usage of the pointer network
The implementation of the code
The performance analysis of the results
Attention mechanism
Before we can talk about the attention mechanism, we must know what an RNN and the seq2seq model are.
In NLP we meet the translation problem quite often, and we usually use the seq2seq model to solve it. The following architecture is a standard seq2seq model containing both the encoder and the decoder.
We have omitted the loss function design of seq2seq, since it varies across implementations; in the pointer network section we will present the network in much more detail. From this architecture we can already see some shortcomings. For example, once the sentence length grows, the yellow cell cannot memorize enough information.
This motivated a new concept, attention, with which we can fully utilize the encoder's output information. The following picture shows the partial architecture of the attention mechanism.
In the attention mechanism, we add a weight matrix A to represent the encoder's output. The details of the formula are given in the implementation part.
Note that the attention mechanism varies across implementations and models; for more detailed work, you can refer to this survey: https://lilianweng.github.io/lil-log/2018/06/24/attention-attention.html
What is a Pointer Network?
The pointer network can be viewed as a seq2seq model equipped with a unique attention mechanism; the picture below describes the architecture of a pointer network.
At first glance this looks like a seq2seq model, but pay attention to the fact that the pointers marked in red differ slightly from the standard seq2seq: the output of the decoder at a certain timestamp is a softmaxed vector, where each element's value is the probability that the decoder points to that index of the input sequence. The formula below gives a more detailed picture of how the softmax layer is obtained.
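Using the notation explained below (as in the original pointer network paper), the attention scores and the pointer distribution can be written as:

$$u^i_j = v_t^{\top} \tanh(W_1 e_j + W_2 d_i), \qquad j = 1, \dots, n$$

$$p(\text{pointer at step } i = j) = \mathrm{softmax}(u^i)_j$$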
Here e_j and d_i denote the encoder's and the decoder's hidden states respectively, W1 and W2 are weight matrices to be trained, and v_t is a weight vector that transforms the result to the proper size. Applying the softmax function then yields, for each input position j, the probability that the output at this timestamp points to index j. The PyTorch implementation below gives a more detailed insight.
The following three lines define the W1, W2 and v_t from the formula above.
self.W1 = nn.Linear(hidden_size, weight_size, bias=False)  # blending encoder states (W1)
self.W2 = nn.Linear(hidden_size, weight_size, bias=False)  # blending decoder state (W2)
self.vt = nn.Linear(weight_size, 1, bias=False)            # v_t: maps the blended vector to a scalar score
What is the sorting task?
Here we need to clarify what the sorting task is, namely what the input and the output of our neural network are. The sorting task here is slightly different from normal sorting: instead of the sorted values themselves, the network outputs, for each position of the sorted order, the index of the corresponding element in the raw array. For example:
x = [0.54431329, 0.64373097, 0.9927333, 0.70941862, 0.10016056]
y = [4, 0, 1, 3, 2]
Here y[i] is the index in the raw array x of the i-th smallest value; in other words, y = argsort(x) and x[y[i]] is the i-th element of the sorted array.
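Such a training pair can be generated, for example, with numpy's argsort (a sketch of the data format only, not necessarily the report's data pipeline):

import numpy as np

x = np.random.uniform(0, 1, size=5)  # e.g. [0.544..., 0.643..., 0.992..., 0.709..., 0.100...]
y = np.argsort(x)                    # e.g. [4, 0, 1, 3, 2]; x[y[i]] is the i-th smallest value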
How does the pointer network solve the sorting task?
We use LSTM cells as the basic elements of the encoder and decoder. The following parts explain, step by step, how to feed our input data and how to construct the loss function.
Encoder Steps
At each time step of the encoder, we feed the number in directly (no embedding, unlike a bag-of-words model!). The LSTM cell then generates an output vector at each timestamp and passes its hidden state on to the next timestamp.
self.enc = nn.LSTM(input_size, hidden_size, batch_first=True)  # define the encoder LSTM
Here encoder_states stacks the encoder's hidden state at every timestep.
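In the forward pass this amounts to a single call (a sketch; the names inputs and encoder_states, and the transpose used later for blending, are assumptions):

# inputs: (batch_size, L) raw numbers, reshaped to (batch_size, L, input_size) with input_size = 1
encoder_states, (hidden, cell_state) = self.enc(inputs.unsqueeze(-1))
encoder_states = encoder_states.transpose(1, 0)  # (L, batch_size, hidden_size): one state per input position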
Decoder Steps
For the decoder steps we cannot use nn.LSTM directly, because we need to hook into each timestamp's calculation; so we choose nn.LSTMCell and implement the recurrence manually with a for loop.
Before implementing the algorithm, recall the formula mentioned in What is a Pointer Network?. For the implementation, we just need the corresponding logic in the forward function.
hidden, cell_state = self.dec(decoder_input, (hidden, cell_state))  # (batch_size, h), (batch_size, h)
The line above is the main logic of the for loop. Pay attention to the fact that decoder_input, the input of the current timestamp, is the index predicted at the previous timestamp (obtained by applying argmax to the probability matrix).
out = F.log_softmax(out.transpose(0, 1).contiguous(), -1)  # (bs, L)
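Putting the pieces together, the decoding loop can be sketched as follows (a sketch only; names such as encoder_states and probs, the initialization, and the way the predicted index is fed back are assumptions rather than the exact code):

# decoder_input (batch_size, 1), hidden and cell_state (batch_size, hidden_size)
# are assumed to be initialized before the loop (e.g. zeros / the encoder's final states)
probs = []
for _ in range(self.answer_seq_len):
    hidden, cell_state = self.dec(decoder_input, (hidden, cell_state))  # (batch_size, h), (batch_size, h)

    # u_j = v_t^T tanh(W1 e_j + W2 d_i) for every encoder position j
    blend1 = self.W1(encoder_states)         # (L, batch_size, weight_size)
    blend2 = self.W2(hidden)                 # (batch_size, weight_size)
    blend_sum = torch.tanh(blend1 + blend2)  # (L, batch_size, weight_size), broadcast over L
    out = self.vt(blend_sum).squeeze(-1)     # (L, batch_size)
    out = F.log_softmax(out.transpose(0, 1).contiguous(), -1)  # (batch_size, L)
    probs.append(out)

    # the argmax over input positions is the pointer predicted at this step;
    # as described above, it is fed back as the next decoder_input
    idx = out.argmax(dim=1)                   # (batch_size,)
    decoder_input = idx.float().unsqueeze(1)  # (batch_size, 1), since input_size = 1

probs = torch.stack(probs, dim=1)  # (batch_size, answer_seq_len, L) log-probabilities

Feeding back the value that the pointer selects, instead of the raw index, is a common alternative design.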
Network architecture overview
input_size: the input data size at each timestamp (1 in the sorting task)
answer_seq_len: the length of the array
weight_size: the output dimension of the blending matrices W1 and W2
hidden_size: the hidden size of the encoder's and decoder's LSTM
enc: the encoder network
dec: the decoder cell, run step by step
class PointerNetwork(nn.Module):
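Putting the components above together, the constructor can be sketched as follows (a sketch, not the exact code; the default sizes are assumptions):

import torch.nn as nn

class PointerNetwork(nn.Module):
    def __init__(self, input_size, answer_seq_len, weight_size=128, hidden_size=128):
        super().__init__()
        self.answer_seq_len = answer_seq_len
        self.weight_size = weight_size
        self.hidden_size = hidden_size

        self.enc = nn.LSTM(input_size, hidden_size, batch_first=True)  # encoder framework
        self.dec = nn.LSTMCell(input_size, hidden_size)                # decoder cell, stepped manually
        self.W1 = nn.Linear(hidden_size, weight_size, bias=False)      # blending encoder states
        self.W2 = nn.Linear(hidden_size, weight_size, bias=False)      # blending decoder state
        self.vt = nn.Linear(weight_size, 1, bias=False)                # scores each input position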
Loss Function Calculation
We use cross-entropy loss to train our neural network. Since the network's output already goes through a log-softmax:
out = F.log_softmax(out.transpose(0, 1).contiguous(), -1)  # (bs, L)
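Because the outputs are log-probabilities, the cross-entropy loss can be computed with nn.NLLLoss (a sketch; probs of shape (batch_size, answer_seq_len, L) and targets of shape (batch_size, answer_seq_len) are assumed names):

criterion = nn.NLLLoss()
loss = criterion(probs.view(-1, probs.size(-1)), targets.view(-1))  # averaged over all batch items and decoding steps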
Training Framework
In the training part, we run three different tasks with three different training models (of the same structure), and I refactored the test code.
model = PointerNetwork(1, input_seq_len)
This is the pipeline of training and testing. I use the DataLoader provided by PyTorch, which acts like a generator for a large dataset, and we log the loss and accuracy at certain timestamps. The next module shows some results of our experiments under different circumstances.
optimizer = optim.Adam(model.parameters())
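A minimal sketch of the training loop described above (train_loader, max_epoch and the use of F.nll_loss are assumed names, not the exact framework code):

for epoch in range(max_epoch):
    for x_batch, y_batch in train_loader:  # x_batch: (bs, L) numbers, y_batch: (bs, L) target indices
        probs = model(x_batch)             # (bs, answer_seq_len, L) log-probabilities
        loss = F.nll_loss(probs.view(-1, probs.size(-1)), y_batch.view(-1))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()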
Result and Conclusion
We use the same network architecture for all three tasks, with batch_size set to 250, max_epoch to 1000, and dataset_size to 2500. We analyze how the network architecture affects performance (training time and accuracy), and also how the different sorting tasks affect accuracy.
First, we analyze task 1 with different architectures.
| Network Size (weight_size/hidden_size) | Time Used (to stabilize) | Accuracy |
|---|---|---|
| 128/128 | 522 s | 92.3% |
| 512/512 | 1033 s | 90.1% |
| 10/10 | 213 s | 87.4% |
So 128/128 is the best architecture among those tested.
We also analyze the three tasks using the best architecture (128/128).
| Task | Time Used (to stabilize) | Accuracy |
|---|---|---|
| 1 (range 0~100, length 5) | 422 s | 91.4% |
| 2 (range 0~1, length 5) | 300 s | 67.3% |
| 3 (range 0~100, length 10) | 1400 s | 87.2% |
Comparing tasks 1 and 3, we see that as the sequence length grows, the network becomes harder to train. Task 2 is much more difficult than the previous two, because the floating-point numbers in 0~1 effectively come from a much larger range of distinct values.