3x3 Median Filter

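A 3x3 median filter is applied to a NumPy array on the GPU using PyCUDA; Matplotlib is used for display. The example creates a random, non-square matrix and filters it repeatedly (one initial pass followed by 20 more, each pass taking the previous pass's output as its input). The matrices before and after filtering are plotted with Matplotlib.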

License of this example:

Public Domain

Date:

2010-11-29

PyCUDA version:

0.94.2

   1 #
   2 # 3x3 Median Filter ported to PyCuda by Nick Hilton.
   3 #
   4 
   5 from matplotlib import pylab
   6 import numpy
   7 
   8 import pycuda.autoinit
   9 from pycuda.compiler import SourceModule
  10 
  11 kernel_cu = """
  12 
  13 #define BLOCK_X  16
  14 #define BLOCK_Y  16
  15 
  16 #define s2(a,b)            { float tmp = a; a = min(a,b); b = max(tmp,b); }
  17 #define mn3(a,b,c)         s2(a,b); s2(a,c);
  18 #define mx3(a,b,c)         s2(b,c); s2(a,c);
  19 
  20 #define mnmx3(a,b,c)       mx3(a,b,c); s2(a,b);                               // 3 exchanges
  21 #define mnmx4(a,b,c,d)     s2(a,b); s2(c,d); s2(a,c); s2(b,d);                // 4 exchanges
  22 #define mnmx5(a,b,c,d,e)   s2(a,b); s2(c,d); mn3(a,c,e); mx3(b,d,e);          // 6 exchanges
  23 #define mnmx6(a,b,c,d,e,f) s2(a,d); s2(b,e); s2(c,f); mn3(a,b,c); mx3(d,e,f); // 7 exchanges
  24 
  25 #define SMEM(x,y)  smem[(x)+1][(y)+1]
  26 
  27 #define  IN(x,y)    d_in[((y)-1) + ((x)-1) * NY]
  28 #define OUT(x,y)   d_out[((y)-1) + ((x)-1) * NY]
  29 
  30 //////////////////////////////////////////////////////////////////////////////
  31 __global__
  32 void
  33 medianFilter(
  34         float *       d_out,
  35         const float * d_in,
  36         const int     NX,       // Number of rows
  37         const int     NY)       // Number of cols
  38 {
  39     const int tx = threadIdx.x;
  40     const int ty = threadIdx.y;
  41 
  42     // Guards, at the boundary?
  43     bool is_x_top = (tx == 0);
  44     bool is_x_bot = (tx == BLOCK_X-1);
  45     bool is_y_top = (ty == 0);
  46     bool is_y_bot = (ty == BLOCK_Y-1);
  47 
  48     __shared__ float smem[BLOCK_X+2][BLOCK_Y+2];
  49 
  50     // Clear out shared memory (zero padding)
  51     if (is_x_top)           SMEM(tx-1, ty  ) = 0;
  52     else if (is_x_bot)      SMEM(tx+1, ty  ) = 0;
  53     if (is_y_top) {         SMEM(tx  , ty-1) = 0;
  54         if (is_x_top)       SMEM(tx-1, ty-1) = 0;
  55         else if (is_x_bot)  SMEM(tx+1, ty-1) = 0;
  56     } else if (is_y_bot) {  SMEM(tx  , ty+1) = 0;
  57         if (is_x_top)       SMEM(tx-1, ty+1) = 0;
  58         else if (is_x_bot)  SMEM(tx+1, ty+1) = 0;
  59     }
  60 
  61     // x,y are 1 based indicies, the macros IN, OUT subtract 1
  62     int x = blockIdx.x * blockDim.x + tx;
  63     int y = blockIdx.y * blockDim.y + ty;
  64 
  65     // Guards, at the boundary and still more image to process?
  66     is_x_top &= (x > 0);
  67     is_x_bot &= (x < NX);
  68     is_y_top &= (y > 0);
  69     is_y_bot &= (y < NY);
  70 
  71     // Each thread reads the input matrix.
  72 
  73                             SMEM(tx  , ty  ) = IN(x  , y  ); // self
  74     if (is_x_top)           SMEM(tx-1, ty  ) = IN(x-1, y  );
  75     else if (is_x_bot)      SMEM(tx+1, ty  ) = IN(x+1, y  );
  76     if (is_y_top) {         SMEM(tx  , ty-1) = IN(x  , y-1);
  77         if (is_x_top)       SMEM(tx-1, ty-1) = IN(x-1, y-1);
  78         else if (is_x_bot)  SMEM(tx+1, ty-1) = IN(x+1, y-1);
  79     } else if (is_y_bot) {  SMEM(tx  , ty+1) = IN(x  , y+1);
  80         if (is_x_top)       SMEM(tx-1, ty+1) = IN(x-1, y+1);
  81         else if (is_x_bot)  SMEM(tx+1, ty+1) = IN(x+1, y+1);
  82     }
  83     __syncthreads();
  84 
  85     // Pull top six values from shared memory
  86 
  87     float v[6] =
  88     {
  89         SMEM(tx-1, ty-1),    //  NW     (North West neighbor)
  90         SMEM(tx  , ty-1),    //   W
  91         SMEM(tx+1, ty-1),    //  SW
  92         SMEM(tx-1, ty  ),    //  N
  93         SMEM(tx  , ty  ),    //     self
  94         SMEM(tx+1, ty  )     //  S
  95     };
  96 
  97     // With each pass, remove min and max values and add new value
  98     mnmx6(v[0], v[1], v[2], v[3], v[4], v[5]);
  99 
 100     // Replace Max with new value.
 101 
 102     v[5] = SMEM(tx-1, ty+1);    // NE
 103 
 104     mnmx5(v[1], v[2], v[3], v[4], v[5]);
 105 
 106     v[5] = SMEM(tx  , ty+1);    //  E
 107 
 108     mnmx4(v[2], v[3], v[4], v[5]);
 109 
 110     v[5] = SMEM(tx+1, ty+1);    // SE
 111 
 112     mnmx3(v[3], v[4], v[5]);
 113 
 114     // v[4] now contains the middle value.
 115 
 116     // Guard against indicies out of range.
 117     if(x >= 1 && x <= NX && y >= 1 && y <= NY)
 118     {
 119         OUT(x,y) = v[4];
 120     }
 121 }
 122 
 123 """
 124 
 125 SIZE_M = 16*2-1
 126 SIZE_N = 16*2+1
 127 
 128 gpu = SourceModule(kernel_cu)
 129 
 130 medianFilter = gpu.get_function("medianFilter")
 131 
 132 x = numpy.random.random((SIZE_M,SIZE_N)).astype(numpy.float32)
 133 
 134 pylab.figure()
 135 pylab.imshow(x, interpolation = "nearest", cmap = pylab.cm.gray_r)
 136 pylab.title("before")
 137 pylab.axis("tight")
 138 
 139 y = numpy.zeros((SIZE_M,SIZE_N)).astype(numpy.float32)
 140 
 141 grid_m = int(round(SIZE_M / 16.0 + 0.5))
 142 grid_n = int(round(SIZE_N / 16.0 + 0.5))
 143 
 144 print "grid = %dx%d" %(grid_m, grid_n)
 145 
 146 medianFilter(
 147         pycuda.driver.InOut(y),
 148         pycuda.driver.In(x),
 149         numpy.int32(SIZE_M),
 150         numpy.int32(SIZE_N),
 151         block=(16,16,1),
 152         grid=(grid_m,grid_n))
 153 
 154 for i in range(20):
 155         x = numpy.array(y)
 156 
 157         medianFilter(
 158                 pycuda.driver.Out(y),
 159                 pycuda.driver.In(x),
 160                 block=(16,16,1),
 161                 grid=(grid_m,grid_n))
 162 
 163 pylab.figure()
 164 pylab.imshow(y, interpolation = "nearest", cmap = pylab.cm.gray_r)
 165 pylab.title("after")
 166 pylab.axis("tight")
 167 
 168 pylab.show()

MedianFilter (last edited 2010-11-29 17:54:41 by 66-146-167-66)