[TeX/LaTeX] How to draw a diagram of Long Short-Term Memory

Tags: diagrams, nodes, positioning, tikz-pgf

I'm new to TikZ and, frankly, I'm having difficulty learning it.
I've been able to draw some simple shapes, but I want to draw the image below, and all my attempts so far have been awful and discouraging. I have trouble getting text into shapes and adjusting the positions of objects relative to each other, and things seem to move around whenever I add a new part of the picture. Could someone help me with the image below, please?
I would find the following really helpful (in ascending order):

  1. How to get started? Which TikZ features could be useful? What should I draw first?

  2. A code skeleton

  3. The full image in TikZ

[Image: the LSTM diagram to reproduce]

Best Answer

Let's start:

First, define some styles for the different diagram elements: the central cell (ct), the gate nodes (ft), the S-shaped activation symbols (filter), ...

ct/.style={circle, draw, inner sep=5pt, ultra thick, minimum width=10mm},
ft/.style={circle, draw, minimum width=8mm, inner sep=1pt},
filter/.style={circle, draw, minimum width=7mm, inner sep=1pt, 
      path picture={%
           \draw[thick, rounded corners] 
              (path picture bounding box.center)--++(65:2mm)--++(0:1mm);
           \draw[thick, rounded corners] 
              (path picture bounding box.center)--++(245:2mm)--++(180:1mm);
      }},
...
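When a style misbehaves, it is easier to test it alone than inside the full diagram. The following minimal test file is only a sketch: it draws one filter node and one prod node side by side, with both style definitions copied from the complete code below (the node names f and p are placeholders chosen for this example).

\documentclass[tikz,border=2mm]{standalone}
\usetikzlibrary{positioning}

\begin{document}
\begin{tikzpicture}[
    % product node: small circle that holds \times
    prod/.style={circle, draw, inner sep=0pt},
    % filter node: circle with an S-shaped curve drawn inside it
    filter/.style={circle, draw, minimum width=7mm, inner sep=1pt,
      path picture={%
        \draw[thick, rounded corners]
          (path picture bounding box.center)--++(65:2mm)--++(0:1mm);
        \draw[thick, rounded corners]
          (path picture bounding box.center)--++(245:2mm)--++(180:1mm);
      }}
    ]
\node[filter] (f) {};
\node[prod, right=of f] (p) {$\times$};
\end{tikzpicture}
\end{document}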

Second, place the elements. With the positioning library, each node is placed relative to another one (right=of ct, below=5mm of x3, ...). At the same time, the label option attaches a text label to a node.

\node[ct, label={[mylabel]Cell}] (ct) {$c_t$};
\node[filter, right=of ct] (int1) {};
\node[prod, right=of int1] (x1) {$\times$}; 
...
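Since every node is positioned relative to a neighbour, the spacing can be tuned in one place instead of node by node. With the positioning library, the node distance key sets the default gap used by right=of, above=of and so on, and it accepts a vertical and a horizontal value separated by "and". A small sketch, where 6mm, 8mm and the node names a, b, c are arbitrary choices for illustration:

\documentclass[tikz,border=2mm]{standalone}
\usetikzlibrary{positioning}

\begin{document}
% default gaps: 6mm vertically, 8mm horizontally
\begin{tikzpicture}[node distance=6mm and 8mm]
\node[draw, circle] (ct) {$c_t$};
\node[draw, circle, right=of ct] (a) {};      % 8mm to the right of ct
\node[draw, circle, above=of a] (b) {};       % 6mm above a
\node[draw, circle, below=3mm of ct] (c) {};  % explicit distance overrides the default
\end{tikzpicture}
\end{document}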

Third, draw the connections between nodes. Here a \foreach loop saves some typing.

\foreach \i/\j in {int2/x2, x2/ct, ct/int1, int1/x1,
            x1/ht, it/x2, ct/it, ct/ot, ot/x1, ft/x3}
    \draw[->] (\i)--(\j);
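Each \i/\j pair in the list names a source node and a target node, so the loop is only shorthand for one \draw per pair. Written out by hand inside the same tikzpicture, the first three iterations would be:

\draw[->] (int2)--(x2);  % pair int2/x2
\draw[->] (x2)--(ct);    % pair x2/ct
\draw[->] (ct)--(int1);  % pair ct/int1
% ... and likewise for the remaining pairs in the list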

The curved connections cannot be drawn as straight lines, so they are drawn individually.
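In the complete code below, these individual connections are the curved ones and use to[bend right=...]. The bend right option curves the path to the right of its direction of travel, so drawing ct to x3 and x3 back to ct with the same bend gives two arcs on opposite sides rather than two overlapping lines. The relevant lines, annotated:

\draw[->] (ct) to[bend right=45] (ft);  % long arc from the cell down to the forget gate
\draw[->] (ct) to[bend right=30] (x3);  % cell down to the product node below it
\draw[->] (x3) to[bend right=30] (ct);  % return path, bent to the other side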

Fourth, draw a fit node (from the fit library) that encloses the whole diagram.

\node[fit=(int2) (it) (ot) (ft), draw, inner sep=0pt] (fit) {};
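The fit library computes a node large enough to cover every node listed in fit=...; here int2 is the leftmost node, it and ot are the two upper gates and ft is the lowest node, so together they span the whole cell. With inner sep=0pt the frame touches those nodes; a positive value adds padding. A variant with some padding, where 2mm is an arbitrary example value:

% requires \usetikzlibrary{fit} in the preamble
\node[fit=(int2) (it) (ot) (ft), draw, inner sep=2mm] (fit) {};
% anchors such as fit.west, fit.north and fit.south are now available
% as starting points for the external arrows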

Fifth, draw the external input arrows.

\draw[<-] (fit.west|-int2) coordinate (aux)--++(180:7mm) node[left]{$x_t$};
\draw[<-] ([yshift=1mm]aux)--++(135:7mm);
\draw[<-] ([yshift=-1mm]aux)--++(-135:7mm);
...
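The external arrows rely on TikZ's perpendicular coordinate syntax: (a|-b) is the point with the x-coordinate of a and the y-coordinate of b, while (a-|b) takes the x-coordinate of b and the y-coordinate of a. So (fit.west|-int2) lies on the left side of the frame, level with int2. Because the path uses <-, the arrow tip sits at that start point. The standalone sketch below isolates the technique; the node names box and g and the labels $x$ and $y$ are made up for illustration:

\documentclass[tikz,border=2mm]{standalone}

\begin{document}
\begin{tikzpicture}
\node[draw, minimum width=30mm, minimum height=20mm] (box) {};
\node[circle, draw] (g) at (0.8,0.4) {$g$};
% (box.west|-g): x from box.west, y from g -> left border, level with g
\draw[<-] (box.west|-g) coordinate (aux)--++(180:7mm) node[left]{$x$};
% (box.north-|g): x from g, y from box.north -> top border, above g
\draw[<-] (box.north-|g)--++(90:7mm) node[above]{$y$};
% shifted copies of the stored coordinate give two extra slanted arrows
\draw[<-] ([yshift=1mm]aux)--++(135:7mm);
\draw[<-] ([yshift=-1mm]aux)--++(-135:7mm);
\end{tikzpicture}
\end{document}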

That's all. The complete code looks like:

\documentclass[tikz,border=2mm]{standalone} 
\usetikzlibrary{positioning, fit, arrows.meta}

\begin{document}
\begin{tikzpicture}[
    prod/.style={circle, draw, inner sep=0pt},
    ct/.style={circle, draw, inner sep=5pt, ultra thick, minimum width=10mm},
    ft/.style={circle, draw, minimum width=8mm, inner sep=1pt},
    filter/.style={circle, draw, minimum width=7mm, inner sep=1pt,
      path picture={%
        \draw[thick, rounded corners]
          (path picture bounding box.center)--++(65:2mm)--++(0:1mm);
        \draw[thick, rounded corners]
          (path picture bounding box.center)--++(245:2mm)--++(180:1mm);
      }},
    mylabel/.style={font=\scriptsize\sffamily},
    >=LaTeX
    ]

% Nodes, each placed relative to one that already exists
\node[ct, label={[mylabel]Cell}] (ct) {$c_t$};
\node[filter, right=of ct] (int1) {};
\node[prod, right=of int1] (x1) {$\times$}; 
\node[right=of x1] (ht) {$h_t$};
\node[prod, left=of ct] (x2) {$\times$}; 
\node[filter, left=of x2] (int2) {};
\node[prod, below=5mm of ct] (x3) {$\times$}; 
\node[ft, below=5mm of x3, label={[mylabel]right:Forget Gate}] (ft) {$f_t$};
\node[ft, above=of x2, label={[mylabel]left:Input Gate}] (it) {$i_t$};
\node[ft, above=of x1, label={[mylabel]left:Output Gate}] (ot) {$o_t$};

% Straight arrows, one \draw per source/target pair
\foreach \i/\j in {int2/x2, x2/ct, ct/int1, int1/x1,
            x1/ht, it/x2, ct/it, ct/ot, ot/x1, ft/x3}
    \draw[->] (\i)--(\j);

% Curved arrows
\draw[->] (ct) to[bend right=45] (ft);

\draw[->] (ct) to[bend right=30] (x3);
\draw[->] (x3) to[bend right=30] (ct);

% Frame enclosing the whole cell
\node[fit=(int2) (it) (ot) (ft), draw, inner sep=0pt] (fit) {};

% External inputs: one labelled arrow plus two slanted arrows at each entry point
\draw[<-] (fit.west|-int2) coordinate (aux)--++(180:7mm) node[left]{$x_t$};
\draw[<-] ([yshift=1mm]aux)--++(135:7mm);
\draw[<-] ([yshift=-1mm]aux)--++(-135:7mm);

\draw[<-] (fit.north-|it) coordinate (aux)--++(90:7mm) node[above]{$x_t$};
\draw[<-] ([xshift=1mm]aux)--++(45:7mm);
\draw[<-] ([xshift=-1mm]aux)--++(135:7mm);

\draw[<-] (fit.north-|ot) coordinate (aux)--++(90:7mm) node[above]{$x_t$};
\draw[<-] ([xshift=1mm]aux)--++(45:7mm);
\draw[<-] ([xshift=-1mm]aux)--++(135:7mm);

\draw[<-] (fit.south-|ft) coordinate (aux)--++(-90:7mm) node[below]{$x_t$};
\draw[<-] ([xshift=1mm]aux)--++(-45:7mm);
\draw[<-] ([xshift=-1mm]aux)--++(-135:7mm);
\end{tikzpicture}
\end{document}

And the result:

[Image: the diagram produced by the code above]
