Solved – How to train LSTM layer of deep-network

Tags: classification, deep learning, lstm, neural networks

I'm using an LSTM and a feed-forward network to classify text.

I convert the text into one-hot vectors and feed each one into the LSTM so that it summarises the text as a single representation. Then I feed that representation into the other network.

But how do I train the LSTM? I just want to classify the text as a sequence. Should I feed it in without training? I only want to represent the passage as a single vector that I can feed into the input layer of the classifier.

I would greatly appreciate any advice with this!

Update:

So I have an LSTM and a classifier. I take all the outputs of the LSTM, mean-pool them, and then feed that average into the classifier.

My issue is that I don't know how to train the LSTM or the classifier. I know what the input to the LSTM should be and what the output of the classifier should be for that input. But since they are two separate networks that are simply activated in sequence, I would need to know the ideal output of the LSTM (which is also the input to the classifier), and I don't. Is there a way to do this?

Best Answer

The best place to start with LSTMs is A. Karpathy's blog post http://karpathy.github.io/2015/05/21/rnn-effectiveness/. If you are using Torch7 (which I would strongly suggest), the source code is available on GitHub at https://github.com/karpathy/char-rnn.

I would also alter your model a bit. Use a many-to-one architecture: feed words in through a lookup table and append a special end-of-sequence token to each sequence, so that only when the end-of-sequence token is fed in do you read the classification output and compute the error against your training criterion. This way the LSTM and the classifier form a single model that you train end-to-end under direct supervision, so you never need to specify an "ideal output" for the LSTM. The error gradient flows back through the classifier into the LSTM automatically.
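To make the many-to-one idea concrete, here is a minimal sketch in PyTorch (my choice of framework here; the post itself suggests Torch7). The embedding table plays the role of the word lookup, the last hidden state is read only after the final token, and one backward pass trains the LSTM and the classifier jointly. All names and sizes below are illustrative.

```python
import torch
import torch.nn as nn

class ManyToOneClassifier(nn.Module):
    """Lookup table -> LSTM -> read final state -> linear classifier."""
    def __init__(self, vocab_size, embed_dim, hidden_dim, num_classes):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_dim)
        self.lstm = nn.LSTM(embed_dim, hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, num_classes)

    def forward(self, tokens):
        # tokens: (batch, seq_len) integer word ids; the last position
        # would hold the special end-of-sequence token in real data.
        emb = self.embed(tokens)
        _, (h_n, _) = self.lstm(emb)
        # h_n[-1] is the hidden state after the final time step,
        # i.e. the only point where we read a classification output.
        return self.fc(h_n[-1])

model = ManyToOneClassifier(vocab_size=100, embed_dim=16,
                            hidden_dim=32, num_classes=2)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Toy batch: 4 random sequences of length 10 with made-up labels.
batch = torch.randint(0, 100, (4, 10))
labels = torch.tensor([0, 1, 0, 1])

logits = model(batch)
loss = criterion(logits, labels)
loss.backward()   # gradients flow through the classifier AND the LSTM
optimizer.step()
```

Note that there is no separate target for the LSTM anywhere: the cross-entropy loss on the classifier output is the only training signal, and backpropagation distributes it to both parts of the model.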

On the other hand, a simpler approach would be to use paragraph2vec (https://radimrehurek.com/gensim/models/doc2vec.html) to extract features for your input text and then run a classifier on top of your features. Paragraph vector feature extraction is very simple and in python it would be:

# Note: this uses the old gensim Doc2Vec API (LabeledSentence with a
# `labels` argument); newer gensim versions use TaggedDocument(words=..., tags=...).
from gensim.models.doc2vec import Doc2Vec, LabeledSentence

class LabeledLineSentence(object):
    """Stream one labeled sentence per line of a text file."""
    def __init__(self, filename):
        self.filename = filename

    def __iter__(self):
        for uid, line in enumerate(open(self.filename)):
            yield LabeledSentence(words=line.split(), labels=['TXT_%s' % uid])

sentences = LabeledLineSentence('your_text.txt')

# dm=1 selects the distributed-memory (PV-DM) variant; size is the vector dimensionality
model = Doc2Vec(alpha=0.025, min_alpha=0.025, size=50, window=5,
                min_count=5, dm=1, workers=8, sample=1e-5)
model.build_vocab(sentences)

epochs = 10  # `epochs` was undefined in the original snippet
for epoch in range(epochs):
    try:
        model.train(sentences)
    except (KeyboardInterrupt, SystemExit):
        break
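Once training finishes, each document's paragraph vector can be looked up by its label (e.g. `model.docvecs['TXT_0']` in old gensim) and stacked into a feature matrix for any off-the-shelf classifier. Here is a sketch of the "classifier on top of your features" step using scikit-learn; the feature matrix and labels below are random placeholders standing in for your extracted vectors, since the real ones depend on your corpus.

```python
import numpy as np
from sklearn.linear_model import LogisticRegression

# Placeholder data: 200 documents x 50-dim paragraph vectors. In real use,
# build X by stacking the learned vectors, e.g.
#   X = np.array([model.docvecs['TXT_%d' % i] for i in range(n_docs)])
rng = np.random.default_rng(0)
X = rng.normal(size=(200, 50))
y = (X[:, 0] > 0).astype(int)  # placeholder labels (linearly separable)

clf = LogisticRegression(max_iter=1000).fit(X, y)
train_acc = clf.score(X, y)
```

In practice you would of course evaluate on a held-out split rather than the training data; the point is only that once the text is reduced to fixed-length vectors, any standard classifier applies.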