3x3 Median Filter
Applies a 3x3 median filter to a NumPy array on the GPU with PyCUDA. The script creates a random, non-square matrix, filters it 20 times, and plots the matrix before and after filtering with Matplotlib.
License of this example: Public Domain
Date: 2010-11-29
PyCUDA version: 0.94.2
1 #
2 # 3x3 Median Filter ported to PyCuda by Nick Hilton.
3 #
4
5 from matplotlib import pylab
6 import numpy
7
8 import pycuda.autoinit
9 from pycuda.compiler import SourceModule
10
# CUDA source for a 3x3 median filter over a row-major float32 matrix.
#
# Each thread owns one element of a (BLOCK_X+2) x (BLOCK_Y+2) shared-memory
# tile; threads on a block edge also load the one-element halo around it.
# Image coordinates are 1-based: coordinates 0 and NX+1 / NY+1 form a virtual
# zero-padding border, which is why the launch grid must provide at least
# NX+1 x NY+1 threads. All global reads go through IN_PAD, which returns 0
# outside the matrix, so the kernel never reads outside d_in.
kernel_cu = """

#define BLOCK_X 16
#define BLOCK_Y 16

// Compare-exchange: afterwards a <= b.
#define s2(a,b)            { float tmp = a; a = min(a,b); b = max(tmp,b); }
// Afterwards a holds the minimum of a, b, c.
#define mn3(a,b,c)         s2(a,b); s2(a,c);
// Afterwards c holds the maximum of a, b, c.
#define mx3(a,b,c)         s2(b,c); s2(a,c);

// mnmxN: move the minimum of the N arguments into the first one and the
// maximum into the last one.
#define mnmx3(a,b,c)       mx3(a,b,c); s2(a,b);                               // 3 exchanges
#define mnmx4(a,b,c,d)     s2(a,b); s2(c,d); s2(a,c); s2(b,d);                // 4 exchanges
#define mnmx5(a,b,c,d,e)   s2(a,b); s2(c,d); mn3(a,c,e); mx3(b,d,e);          // 6 exchanges
#define mnmx6(a,b,c,d,e,f) s2(a,d); s2(b,e); s2(c,f); mn3(a,b,c); mx3(d,e,f); // 7 exchanges

// Shared-memory tile with a one-element halo on every side.
#define SMEM(x,y)  smem[(x)+1][(y)+1]

// x,y are 1-based (row, column) coordinates into row-major matrices with
// NY columns; the macros subtract 1.
#define IN(x,y)    d_in[((y)-1) + ((x)-1) * NY]
#define OUT(x,y)   d_out[((y)-1) + ((x)-1) * NY]

// True when the 1-based coordinate (x,y) lies inside the NX x NY matrix.
#define INSIDE(x,y) ((x) >= 1 && (x) <= NX && (y) >= 1 && (y) <= NY)

// Bounds-checked read: anything outside the matrix (including the
// 0 / NX+1 / NY+1 border) reads as 0. This is the zero padding, and it
// guarantees no out-of-range access to d_in.
#define IN_PAD(x,y) (INSIDE(x,y) ? IN(x,y) : 0.0f)

//////////////////////////////////////////////////////////////////////////////
__global__
void
medianFilter(
  float * d_out,
  const float * d_in,
  const int NX,     // Number of rows
  const int NY)     // Number of cols
{
  const int tx = threadIdx.x;
  const int ty = threadIdx.y;

  // Is this thread on an edge of its block? Edge threads also load the
  // neighbouring halo element(s). Assumes blockDim == (BLOCK_X, BLOCK_Y),
  // since the tile is sized and the edges are detected with BLOCK_X/BLOCK_Y.
  const bool is_x_top = (tx == 0);
  const bool is_x_bot = (tx == BLOCK_X-1);
  const bool is_y_top = (ty == 0);
  const bool is_y_bot = (ty == BLOCK_Y-1);

  __shared__ float smem[BLOCK_X+2][BLOCK_Y+2];

  // 1-based image coordinates. Threads at 0 or beyond NX / NY sit on the
  // padding border: they only supply (zero) neighbour values.
  const int x = blockIdx.x * blockDim.x + tx;
  const int y = blockIdx.y * blockDim.y + ty;

  // Fill the tile and its halo. IN_PAD handles the border, so no separate
  // zero-clearing pass is needed.
  SMEM(tx, ty) = IN_PAD(x, y);  // self
  if      (is_x_top) SMEM(tx-1, ty) = IN_PAD(x-1, y);
  else if (is_x_bot) SMEM(tx+1, ty) = IN_PAD(x+1, y);
  if (is_y_top) {
    SMEM(tx, ty-1) = IN_PAD(x, y-1);
    if      (is_x_top) SMEM(tx-1, ty-1) = IN_PAD(x-1, y-1);
    else if (is_x_bot) SMEM(tx+1, ty-1) = IN_PAD(x+1, y-1);
  } else if (is_y_bot) {
    SMEM(tx, ty+1) = IN_PAD(x, y+1);
    if      (is_x_top) SMEM(tx-1, ty+1) = IN_PAD(x-1, y+1);
    else if (is_x_bot) SMEM(tx+1, ty+1) = IN_PAD(x+1, y+1);
  }
  // Every thread reaches this barrier (no early return above), so the whole
  // tile is populated before anyone reads a neighbour.
  __syncthreads();

  // Pull the first six of the nine neighbourhood values from shared memory.
  float v[6] =
  {
    SMEM(tx-1, ty-1),  // NW (North West neighbor)
    SMEM(tx  , ty-1),  // W
    SMEM(tx+1, ty-1),  // SW
    SMEM(tx-1, ty  ),  // N
    SMEM(tx  , ty  ),  // self
    SMEM(tx+1, ty  )   // S
  };

  // With each pass, push the current min and max to the ends, drop both
  // (the min slot is not passed on; the max slot v[5] is overwritten with
  // the next unseen value) and repeat until three values remain.
  mnmx6(v[0], v[1], v[2], v[3], v[4], v[5]);

  v[5] = SMEM(tx-1, ty+1);  // NE

  mnmx5(v[1], v[2], v[3], v[4], v[5]);

  v[5] = SMEM(tx  , ty+1);  // E

  mnmx4(v[2], v[3], v[4], v[5]);

  v[5] = SMEM(tx+1, ty+1);  // SE

  mnmx3(v[3], v[4], v[5]);

  // v[4] now contains the middle value of the 3x3 neighbourhood.

  // Only threads mapped onto a real matrix element write a result.
  if (INSIDE(x, y))
  {
    OUT(x, y) = v[4];
  }
}

"""
124
# Thread-block edge length; must match BLOCK_X / BLOCK_Y in kernel_cu.
BLOCK = 16

# Matrix size: non-square and not a multiple of BLOCK, so the partially
# filled blocks at the matrix edges are exercised.
SIZE_M = BLOCK*2-1
SIZE_N = BLOCK*2+1

# Number of median-filter passes applied to the random input.
NUM_PASSES = 20

gpu = SourceModule(kernel_cu)

medianFilter = gpu.get_function("medianFilter")

x = numpy.random.random((SIZE_M,SIZE_N)).astype(numpy.float32)

pylab.figure()
pylab.imshow(x, interpolation = "nearest", cmap = pylab.cm.gray_r)
pylab.title("before")
pylab.axis("tight")

y = numpy.zeros((SIZE_M,SIZE_N)).astype(numpy.float32)

# The kernel uses 1-based coordinates and also needs threads at coordinate 0
# (the zero-padding border), so each axis needs at least SIZE+1 threads:
# ceil((SIZE+1)/BLOCK) == SIZE//BLOCK + 1. Integer division is exact, unlike
# float round(), whose result on exact multiples of BLOCK depends on the
# interpreter's half-rounding rule.
grid_m = SIZE_M // BLOCK + 1
grid_n = SIZE_N // BLOCK + 1

print("grid = %dx%d" % (grid_m, grid_n))

for i in range(NUM_PASSES):
    if i > 0:
        # Feed the previous pass's output back in as the next input.
        x = numpy.array(y)

    # NX and NY are required on every launch: the kernel uses them for
    # indexing and bounds checks. The kernel writes every output element,
    # so y only needs to be copied back (Out), not sent (InOut).
    medianFilter(
        pycuda.driver.Out(y),
        pycuda.driver.In(x),
        numpy.int32(SIZE_M),
        numpy.int32(SIZE_N),
        block=(BLOCK,BLOCK,1),
        grid=(grid_m,grid_n))

pylab.figure()
pylab.imshow(y, interpolation = "nearest", cmap = pylab.cm.gray_r)
pylab.title("after")
pylab.axis("tight")

pylab.show()
