[TeX/LaTeX] Drawing a convolution with TikZ

Tags: asymptote, draw, tikz-pgf

I need to draw something like the figure below. Can anyone tell me how to do this with TikZ, or recommend a program for drawing diagrams like this one?

A 3x3 convolution with a filter kernel

Best Answer

[Image: output of the Asymptote code below, showing two steps of the convolution]

This is a start using Asymptote, with a module `convolution.asy` that automatically calculates the output matrix:

// convolutionTest.asy
//
// Demo of the convolution module: draws two steps of convolving the
// 7x7 input I with the 3x3 kernel K, the second picture below the first.
//
// Run
//   asy convolutionTest.asy
// to get convolutionTest.pdf
//
settings.tex="pdflatex";

import convolution;

real pagew=12cm;
real pageh=1.618*pagew;   // golden-ratio page height
size(pagew,pageh);

import fontsize;
defaultpen(fontsize(10pt));
texpreamble("\usepackage{lmodern}"
  +"\usepackage{amsmath}"
  +"\usepackage{amsfonts}"
  +"\usepackage{amssymb}");

// Input matrix.
Matrix I={
  {0,1,1,1,0,0,0},
  {0,0,1,1,1,0,0},
  {0,0,0,1,1,1,0},
  {0,0,0,1,1,0,0},
  {0,0,1,1,0,0,0},
  {0,1,1,0,0,0,0},
  {1,1,0,0,0,0,0}
};

// Kernel.
Matrix K={
  {1,0,1},
  {0,1,0},
  {1,0,1}
};

real cellLen=40bp;     // side length of one matrix cell
real rowGap=350bp;     // downward shift applied to the second picture

picture picTop, picBottom;

// Highlight output element at row 0, column 3.
Convolution convTop=Convolution(picTop,I,K,0,3,"I","K","I*K",cellSize=cellLen);
add(picTop);

// Highlight output element at row 3, column 2; placed below the first one.
Convolution convBottom=Convolution(picBottom,I,K,3,2,"I","K","I*K",cellSize=cellLen);
add(shift(0,-rowGap)*picBottom);

The code of the `convolution.asy` module:

// convolution.asy

// A matrix is an array of rows: M[i][j] is the element in row i, column j.
typedef real[][] Matrix;

// Return a copy of the height-by-width sub-matrix of A whose top-left
// element is A[row][col]. When height is omitted the block is square.
// Aborts (via assert) if the block does not fit inside A.
Matrix mblock(Matrix A, int row, int col, int width, int height=width){
  int nRows=A.length;
  int nCols=A[0].length;
  // The block origin must lie inside A ...
  assert(row>=0 && row<nRows);
  assert(col>=0 && col<nCols);
  // ... and the block must be non-empty and end inside A.
  assert(height>0 && row+height>=0 && row+height<=nRows);
  assert(width>0 && col+width>=0 && col+width<=nCols);

  Matrix result;
  int r=row;
  while(r<row+height){
    // Slice columns [col, col+width) of row r.
    result.push(A[r][col:col+width]);
    ++r;
  }
  return result;
}

// Element-wise inner product of two matrices of the same shape:
// the sum over all i,j of A[i][j]*B[i][j], accumulated row by row.
real dot(Matrix A, Matrix B){
  // Only the row count and the width of the first row are compared.
  assert(A.length==B.length && A[0].length==B[0].length );
  real total=0;
  int i=0;
  while(i<A.length){
    // Built-in dot(real[],real[]) for one pair of rows.
    total+=dot(A[i],B[i]);
    ++i;
  }
  return total;
}

// Pens used by Convolution when drawing. Index 1 refers to the input
// matrix A (and its highlighted window), index 2 to the kernel B, and
// index 3 to the result A*B (and its highlighted cell). linePen1 styles
// the connector lines from the window of A, linePen2 those from the
// result cell; xPen styles the small multiplier labels inside the window.
struct ConvolutionSkin{
  pen gridPen1;    // cell grid of A
  pen gridPen2;    // cell grid of B
  pen gridPen3;    // cell grid of A*B
  pen fillPen1;    // fill of the highlighted window of A
  pen fillPen2;    // fill of B
  pen fillPen3;    // fill of the highlighted cell of A*B
  pen framePen1;   // outline of the highlighted window of A
  pen framePen2;   // outline of B
  pen framePen3;   // outline of the highlighted cell of A*B
  pen xPen;        // multiplier labels drawn inside the window of A
  pen linePen1;    // connector lines: window of A -> B
  pen linePen2;    // connector lines: cell of A*B -> B

  void operator init(
    pen gridPen1=defaultpen, pen gridPen2=defaultpen, pen gridPen3=defaultpen,
    pen fillPen1, pen fillPen2, pen fillPen3,
    pen framePen1, pen framePen2, pen framePen3,
    pen xPen,
    pen linePen1, pen linePen2
  ){
    // Pens for the input matrix A.
    this.gridPen1=gridPen1;
    this.fillPen1=fillPen1;
    this.framePen1=framePen1;
    // Pens for the kernel B.
    this.gridPen2=gridPen2;
    this.fillPen2=fillPen2;
    this.framePen2=framePen2;
    // Pens for the result A*B.
    this.gridPen3=gridPen3;
    this.fillPen3=fillPen3;
    this.framePen3=framePen3;
    // Annotation pens.
    this.xPen=xPen;
    this.linePen1=linePen1;
    this.linePen2=linePen2;
  }
}

// Skin used by Convolution when no skin argument is given.
// Connector lines are dashed.
ConvolutionSkin defaultSkin=ConvolutionSkin(
  gridPen1=lightred,
  gridPen2=white,
  gridPen3=lightblue,
  fillPen1=rgb(0.9,0.7,0.7),
  fillPen2=rgb(0.9,0.7,0.7),
  fillPen3=rgb(0.8,1.0,0.7),
  framePen1=red+0.8bp,
  framePen2=blue+0.8bp,
  framePen3=deepgreen+0.8bp,
  xPen=rgb(0,0,0.9)+fontsize(6pt),
  linePen1=lightblue+0.7bp+linetype(new real[]{2,2})+linecap(0),
  linePen2=deepgreen+0.7bp+linetype(new real[]{2,2})+linecap(0)
);

// Draws one step of a 2-D discrete convolution into a picture.
// The drawing contains three parts:
//   * the input A, with the window that produces the highlighted output
//     element outlined;
//   * the kernel B;
//   * the full result A*B, with the output element (el_i, el_j) outlined.
// Connector lines join the window and the output cell to the kernel.
// Small "x k" labels show which kernel weight multiplies each window element.
//
// The result is computed in "valid" mode. AxB has (A.length-B.length+1)
// rows and (A[0].length-B[0].length+1) columns. AxB[i][j] is the
// element-wise inner product of B with the block of A whose top-left
// element is A[i][j]. The kernel is not flipped. A and B may be
// rectangular.
//
// The constructor computes and draws immediately.
struct Convolution{
  picture pic;              // picture everything is drawn into
  ConvolutionSkin skin;     // pens used for drawing
  real cellSize;            // side length of one matrix cell
  Matrix A;                 // input matrix
  Matrix B;                 // kernel
  Matrix AxB;               // result, filled by calcAxB()
  int el_i, el_j;           // row/column of the highlighted result element

  string nameA, nameB, nameAxB;   // TeX names placed below each matrix

  pair posA, posB, posAxB;        // lower-left corners of the three matrices
  guide Abox, Bbox, AxBbox;       // window in A, outline of B, cell in A*B

  // Fill AxB with the valid-mode result of sliding B over A.
  // Rows and columns are treated separately, so non-square A and B work.
  void calcAxB(){
    int outRows=A.length-B.length+1;
    int outCols=A[0].length-B[0].length+1;
    assert(outRows>0 && outCols>0,
           "Convolution: kernel B does not fit inside input A");
    AxB=new real[][];
    for(int i=0;i<outRows;++i){
      real[] outRow=new real[];
      for(int j=0;j<outCols;++j){
        // B.length rows by B[0].length columns, starting at A[i][j].
        outRow.push(dot(mblock(A,i,j,B[0].length,B.length),B));
      }
      AxB.push(outRow);
    }
  }


  // Draw the cell grid of M with its lower-left corner at pos, plus every
  // value centred in its cell. Row 0 is drawn at the top.
  void drawMatrix(Matrix M, pair pos, pen p=defaultpen){
    int m=M.length;
    int n=M[0].length;
    for(int i=0;i<=m;++i){
      draw(pic, (pos.x,pos.y+i*cellSize)--(pos.x+n*cellSize,pos.y+i*cellSize), p);
    }
    for(int j=0;j<=n;++j){
      draw(pic, (pos.x+j*cellSize,pos.y)--(pos.x+j*cellSize,pos.y+m*cellSize), p);
    }
    for(int i=0;i<m;++i){
      for(int j=0;j<n;++j){
        label(pic, "$"+string(M[i][j])+"$",pos+cellSize*(j+1/2,m-i-1/2));
      }
    }
  }

  // Put "*" half a cell left of B and "=" half a cell right of B,
  // both vertically centred on B.
  void drawStarEq(){
    label(pic, "$*$",posB+cellSize*(-1/2,B.length/2));
    label(pic, "$=$",posB+cellSize*(B[0].length+1/2,B.length/2));
  }

  // For each of the four corners of the bounding box of G, draw a line
  // towards the matching corner of the kernel box Bbox. Each line is cut
  // at its first intersection with Bbox.
  void drawLines(guide G, pen p){
    guide u,v;
    u=box(min(G),max(G));
    v=box(min(Bbox),max(Bbox));
    for(int i=0;i<4;++i){
      draw(pic,cut(point(u,i)--point(v,i),Bbox,0).before,p);
    }
  }

  // Label each matrix with its name, half a cell below it and centred
  // horizontally on its column count.
  void drawNames(){
    label(pic,"$"+nameA+"$",  posA+  cellSize/2*(A[0].length,-1));
    label(pic,"$"+nameB+"$",  posB+  cellSize/2*(B[0].length,-1));
    label(pic,"$"+nameAxB+"$",posAxB+cellSize/2*(AxB[0].length,-1));
  }

  // Inside every cell of the highlighted window of A, write
  // "x B[i][j]" for the kernel weight applied to that element. Each label
  // is anchored at the cell's lower-right corner and aligned NW, so it
  // sits in that corner of the cell.
  void drawHelper(){
    pair pos=(min(Abox).x,max(Abox).y)+cellSize*(1,-1);
    for(int i=0;i<B.length;++i){
      for(int j=0;j<B[0].length;++j){
        label(pic,"$\scriptsize{\times "+string(B[i][j])+"}$"
         ,pos+cellSize*(j,-i),plain.NW,skin.xPen);
      }
    }
  }

  // Compute the result, lay out the three matrices and draw everything.
  void doit(){
    calcAxB();
    // The highlighted element must exist in the result.
    assert(el_i>=0 && el_i<AxB.length, "Convolution: el_i out of range");
    assert(el_j>=0 && el_j<AxB[0].length, "Convolution: el_j out of range");

    // A sits at the origin. B starts one cell to the right of A.
    // A*B starts one cell to the right of B.
    // B and A*B are vertically centred on A.
    posA=(0,0);
    posB=cellSize*(A[0].length+1, (A.length-B.length)/2);
    posAxB=cellSize*(A[0].length+B[0].length+2, (A.length-AxB.length)/2);
    // The window of A that produces AxB[el_i][el_j], and that cell itself.
    Abox=box(posA+cellSize*(el_j,A.length-el_i),posA+cellSize*(el_j+B[0].length,A.length-(el_i+B.length)));
    Bbox=box(posB,posB+cellSize*(B[0].length,B.length));
    AxBbox=box(posAxB+cellSize*(el_j,AxB.length-el_i)
      ,posAxB+cellSize*(el_j+1,AxB.length-(el_i+1)));
    // Fills go first so grids, values and frames are drawn on top.
    fill(pic,Abox,skin.fillPen1);
    fill(pic,Bbox,skin.fillPen2);
    fill(pic,AxBbox,skin.fillPen3);
    drawMatrix(A,  posA,  skin.gridPen1);
    drawMatrix(B,  posB,  skin.gridPen2);
    drawMatrix(AxB,posAxB,skin.gridPen3);
    draw(pic,Abox,skin.framePen1);
    draw(pic,Bbox,skin.framePen2);
    draw(pic,AxBbox,skin.framePen3);

    drawStarEq();
    drawLines(Abox,  skin.linePen1);
    drawLines(AxBbox,skin.linePen2);
    drawNames();
    drawHelper();
  }


  // Store the arguments, then compute and draw immediately.
  //   pic      picture to draw into
  //   A, B     input matrix and kernel
  //   el_i/j   row/column of the result element to highlight
  //   name*    TeX names placed under the three matrices
  //   cellSize side length of one cell (required)
  //   skin     pens to use
  void operator init(
    picture pic=currentpicture,
    Matrix A, Matrix B,
    int el_i=0, int el_j=0,
    string nameA="A",
    string nameB="B",
    string nameAxB="A*B",
    real cellSize,
    ConvolutionSkin skin=defaultSkin
  ){
    this.pic=pic;
    this.A=A;
    this.B=B;
    this.el_i=el_i;
    this.el_j=el_j;
    this.nameA=nameA;
    this.nameB=nameB;
    this.nameAxB=nameAxB;
    this.cellSize=cellSize;
    this.skin=skin;
    doit();
  }
}
Related Question