CNN + LSTM in TensorFlow

There are quite a few examples of how to use LSTMs alone in TF, but I couldn't find any good examples of how to train a CNN + LSTM jointly. From what I see, it is not straightforward how to do such training, and I can think of only one option.

I believe the simplest (or most primitive) solution would be to train the CNN independently to learn features, and then to train the LSTM on the CNN features without updating the CNN part, since one would probably have to extract these features, save them with NumPy, and then feed them to the LSTM in TF. But in that scenario one would probably have to use a differently labeled dataset to pretrain the CNN, which eliminates the advantage of end-to-end training, i.e. learning features for the final objective targeted by the LSTM (besides the fact that one has to have these additional labels in the first place).

Is there any other way to do it? I haven't been able to find an example in TensorFlow.
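For comparison with the frozen-feature scheme described in the question, here is a minimal sketch (not from the original thread) using `tf.keras`. It assumes a hypothetical task of classifying sequences of small images; the sizes, `build_cnn`, `build_model`, and `cnn_changed_after_one_step` are illustrative names. The CNN is applied to every time step with `TimeDistributed`, and the only difference between the two schemes is the CNN's `trainable` flag, so neither needs the features to be exported to NumPy.

```python
import numpy as np
import tensorflow as tf

# Hypothetical setup: sequences of T small images, one label per sequence
T, H, W, C = 5, 16, 16, 3
num_classes = 3

def build_cnn():
    # Small per-time-step feature extractor (sizes are illustrative)
    return tf.keras.Sequential([
        tf.keras.Input(shape=(H, W, C)),
        tf.keras.layers.Conv2D(8, 3, activation="relu"),
        tf.keras.layers.GlobalMaxPooling2D(),  # -> 8 features per image
    ])

def build_model(cnn, freeze_cnn):
    # freeze_cnn=True: fixed CNN features; False: joint end-to-end training
    cnn.trainable = not freeze_cnn
    inputs = tf.keras.Input(shape=(T, H, W, C))
    feats = tf.keras.layers.TimeDistributed(cnn)(inputs)  # [batch, T, 8]
    hidden = tf.keras.layers.LSTM(16)(feats)
    outputs = tf.keras.layers.Dense(num_classes)(hidden)
    model = tf.keras.Model(inputs, outputs)
    model.compile(optimizer="sgd",
                  loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True))
    return model

rng = np.random.default_rng(0)
x = rng.normal(size=(4, T, H, W, C)).astype("float32")
y = rng.integers(0, num_classes, size=(4,))

def cnn_changed_after_one_step(freeze_cnn):
    # Train one batch and report whether any CNN weight moved
    cnn = build_cnn()
    model = build_model(cnn, freeze_cnn)
    before = [w.copy() for w in cnn.get_weights()]
    model.train_on_batch(x, y)
    return any(not np.array_equal(b, a) for b, a in zip(before, cnn.get_weights()))

frozen_changed = cnn_changed_after_one_step(freeze_cnn=True)  # CNN weights should stay fixed
joint_changed = cnn_changed_after_one_step(freeze_cnn=False)  # LSTM loss should also update the CNN
print(frozen_changed, joint_changed)
```

In the joint case the gradient of the LSTM's loss flows back through `TimeDistributed` into the convolution kernel, which is exactly the end-to-end feature learning the question is after.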

Best Answer

CNN + RNN is possible. To show how, here is some commented code: a CNN runs over the characters of each word in the sentences, and its output, concatenated with the word embeddings, is fed to an LSTM.
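To make the character-level step concrete, here is a small NumPy sketch (the helper name is hypothetical) of what the convolution plus max pooling computes for a single word. It assumes one filter bank that is `height` characters wide with "VALID" padding, mirroring a `tf.nn.conv2d` filter of shape `[1, height, dc, d]`; like `tf.nn.conv2d`, it computes a cross-correlation without flipping the kernel.

```python
import numpy as np

def char_cnn_word_feature(char_emb, filters, bias):
    """Character CNN feature for one word (hypothetical helper).

    char_emb: [W, dc]          embeddings of the word's W characters
    filters:  [height, dc, d]  one filter bank, `height` characters wide
    bias:     [d]
    Returns a [d] vector: ReLU(conv) max-pooled over character positions.
    """
    W, dc = char_emb.shape
    height, _, d = filters.shape
    # "VALID" padding: only windows that fit entirely inside the word
    windows = np.stack([char_emb[i:i + height] for i in range(W - height + 1)])
    # windows: [W - height + 1, height, dc]; contract over (height, dc)
    conv = np.einsum("phc,hcd->pd", windows, filters) + bias  # [W - height + 1, d]
    return np.maximum(conv, 0.0).max(axis=0)  # ReLU, then max over positions

char_emb = np.arange(8.0).reshape(4, 2)  # toy word: W=4 chars, dc=2
filters = np.stack([np.ones((2, 2)), -np.ones((2, 2))], axis=-1)  # height=2, d=2
out = char_cnn_word_feature(char_emb, filters, np.zeros(2))
print(out)  # channel 0: max of window sums 6, 14, 22 -> 22.0; channel 1: all negative -> 0.0
```

The max over positions is what turns a variable number of character windows into one fixed-size vector per word.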

N - batch size
M - number of sentences per example
L - sentence length (number of words)
W - max number of characters in any word
dc - character embedding size
coz - CNN char output size

Let x, of shape [N, M, L], hold the word ids (word level)
Let cnnx, of shape [N, M, L, W], hold the character ids (character level)
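The shapes in the code comments can be checked with plain arithmetic. A short sketch, assuming stride 1 and "VALID" padding as in the answer's `conv2d` call; `trace_shapes` and `dw` (word embedding size) are hypothetical names used only for illustration.

```python
def trace_shapes(N, M, L, W, dc, filter_sizes, heights, dw):
    """Shapes through the char-CNN + word-embedding pipeline.

    Assumes stride 1 and "VALID" padding, as in the answer's conv2d call.
    """
    shapes = {
        "char_embeddings": (N, M, L, W, dc),  # embedding lookup on cnnx
        "conv_input": (N * M, L, W, dc),      # N and M folded into one axis
    }
    for i, (f, h) in enumerate(zip(filter_sizes, heights)):
        # A [1, h] filter leaves L unchanged and yields W - h + 1 char positions
        shapes[f"conv_{i}"] = (N * M, L, W - h + 1, f)
        shapes[f"pooled_{i}"] = (N * M, L, f)  # reduce_max over the char axis
    coz = sum(filter_sizes)
    shapes["char_features"] = (N, M, L, coz)
    shapes["word_embeddings"] = (N, M, L, dw)
    shapes["lstm_input"] = (N, M, L, coz + dw)
    return shapes

shapes = trace_shapes(N=32, M=4, L=20, W=16, dc=8,
                      filter_sizes=[50, 100], heights=[3, 5], dw=50)
for name, shape in shapes.items():
    print(name, shape)
```

The two filter banks produce different widths (`W - 3 + 1` and `W - 5 + 1`), but after max pooling over the char axis both are `[N*M, L, f]`, which is why their outputs can be concatenated along the last axis.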

The aim is to use both character and word embeddings in the LSTM.

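import tensorflow as tf

# Toy sizes and random ids so the snippet runs; replace with real data
N, M, L, W = 2, 3, 7, 12      # see the size definitions above
dc, dw = 8, 50                # char and word embedding sizes
char_emb_mat = tf.Variable(tf.random.normal([70, dc]))     # 70 = char vocabulary size
word_emb_mat = tf.Variable(tf.random.normal([10000, dw]))  # 10000 = word vocabulary size
cnnx = tf.random.uniform([N, M, L, W], maxval=70, dtype=tf.int32)
x = tf.random.uniform([N, M, L], maxval=10000, dtype=tf.int32)

# Character level: embed chars, then fold N and M into one batch axis
Acx = tf.nn.embedding_lookup(char_emb_mat, cnnx)  # [N, M, L, W, dc]
Acx = tf.reshape(Acx, [-1, L, W, dc])              # [N*M, L, W, dc]

filter_sizes = [100]  # output channels per filter bank
heights = [5]         # filter width, in characters
outs = []
for filter_size, height in zip(filter_sizes, heights):
    # Trainable filter spanning 1 word and `height` chars over dc input channels
    filter_ = tf.Variable(tf.random.truncated_normal([1, height, dc, filter_size], stddev=0.1))
    bias = tf.Variable(tf.zeros([filter_size]))
    strides = [1, 1, 1, 1]
    xxc = tf.nn.conv2d(Acx, filter_, strides, "VALID") + bias  # [N*M, L, W-height+1, filter_size]
    out = tf.reduce_max(tf.nn.relu(xxc), axis=2)               # max over chars: [N*M, L, filter_size]
    outs.append(out)

coz = sum(filter_sizes)                       # CNN char output size
concat_out = tf.concat(outs, axis=2)          # [N*M, L, coz]
xx = tf.reshape(concat_out, [-1, M, L, coz])  # [N, M, L, coz]

# Word level: embed words and append the char features
Ax = tf.nn.embedding_lookup(word_emb_mat, x)  # [N, M, L, dw]
xx = tf.concat([xx, Ax], axis=3)              # [N, M, L, coz + dw]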

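This xx can be fed to an RNN with LSTM cells, after reshaping it to one sequence per sentence, i.e. [N*M, L, features]. As long as the CNN filters are TensorFlow variables in the same graph as the LSTM, the gradient of the LSTM's loss also reaches them, so the CNN and the LSTM are trained jointly, end to end.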
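A minimal TF 2 sketch of that last step (not from the original answer): it rebuilds the char-CNN + word-embedding pipeline with toy sizes, runs a `tf.keras.layers.LSTM` over each sentence, and checks that a loss computed after the LSTM produces a gradient for the CNN filter. The per-sentence classification task, `num_classes`, `labels`, `forward`, and `classifier` are hypothetical and exist only to have something to differentiate.

```python
import tensorflow as tf

tf.random.set_seed(0)

# Hypothetical toy sizes, only for illustration
N, M, L, W = 2, 3, 7, 10     # batch, sentences per example, words per sentence, chars per word
char_vocab, dc = 30, 8       # char vocabulary size, char embedding size
word_vocab, dw = 100, 16     # word vocabulary size, word embedding size
filter_size, height = 20, 5  # one filter bank, 5 characters wide
num_classes = 4              # hypothetical label set, one label per sentence

char_emb_mat = tf.Variable(tf.random.normal([char_vocab, dc]))
word_emb_mat = tf.Variable(tf.random.normal([word_vocab, dw]))
filter_ = tf.Variable(tf.random.truncated_normal([1, height, dc, filter_size], stddev=0.1))
bias = tf.Variable(tf.zeros([filter_size]))
lstm = tf.keras.layers.LSTM(32)  # returns the last hidden state
classifier = tf.keras.layers.Dense(num_classes)

def forward(x, cnnx):
    # Char CNN: [N*M, L, W, dc] -> [N*M, L, filter_size]
    Acx = tf.reshape(tf.nn.embedding_lookup(char_emb_mat, cnnx), [-1, L, W, dc])
    xxc = tf.nn.conv2d(Acx, filter_, [1, 1, 1, 1], "VALID") + bias
    char_feat = tf.reduce_max(tf.nn.relu(xxc), axis=2)
    # Word embeddings with N and M folded together: [N*M, L, dw]
    Ax = tf.reshape(tf.nn.embedding_lookup(word_emb_mat, x), [-1, L, dw])
    xx = tf.concat([char_feat, Ax], axis=2)  # [N*M, L, filter_size + dw]
    return classifier(lstm(xx))              # one LSTM pass per sentence -> [N*M, num_classes]

x = tf.random.uniform([N, M, L], maxval=word_vocab, dtype=tf.int32)
cnnx = tf.random.uniform([N, M, L, W], maxval=char_vocab, dtype=tf.int32)
labels = tf.random.uniform([N * M], maxval=num_classes, dtype=tf.int32)

with tf.GradientTape() as tape:
    logits = forward(x, cnnx)
    loss = tf.reduce_mean(
        tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels, logits=logits))

# The loss is computed after the LSTM, yet the CNN filter receives a gradient
grad_filter = tape.gradient(loss, filter_)
print("filter gradient norm:", float(tf.norm(grad_filter)))

# One plain SGD step on the CNN filter; a real loop would update every variable
filter_.assign_sub(0.1 * grad_filter)
```

In practice one would pass all trainable variables (embeddings, filters, LSTM and classifier weights) to an optimizer instead of updating only `filter_` by hand.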