Convolution

Convolution/Cross-correlation function
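With the default isConv = false this computes cross-correlation (cuDNN's CUDNN_CROSS_CORRELATION mode, which is also what PyTorch's conv functions compute); isConv = true flips the kernel, giving true convolution. As a sketch of the 1-d, single-group case (batch n, output channel f, input channel c, kernel tap k):

    y[n, f, i] = Σ_c Σ_k w[f, c, k] · x[n, c, i·stride + k·dilation − pad]

where out-of-range x entries are treated as zeros (zero padding).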

TODO: add cuDNN-wrapped functions

struct Convolution(T, size_t imDims, bool isConv = false, bool isNchw = true) {
    int[imDims] stride;   // stride per spatial dimension
    int[imDims] pad;      // zero-padding per spatial dimension
    int[imDims] dilation; // dilation per spatial dimension
    enum int nbDims;      // rank of the input/weight tensors (spatial dims plus batch and channel)
    enum int ngroup;      // number of convolution groups
    Variable!(T, nbDims, HostStorage) hx; // input saved on forward for backward (host)
    Variable!(T, nbDims, HostStorage) hw; // weight saved on forward for backward (host)
    cudnnConvolutionFwdAlgo_t forwardAlgo;      // cuDNN forward algorithm
    cudnnConvolutionBwdDataAlgo_t backwardAlgo; // cuDNN backward-data algorithm
    Variable!(T, nbDims, DeviceStorage) dx; // input saved on forward for backward (device)
    Variable!(T, nbDims, DeviceStorage) dw; // weight saved on forward for backward (device)
}
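As a minimal usage sketch (the hyper-parameter values are hypothetical; the forward/backward calls follow the examples below), the per-dimension fields are plain static arrays that can be assigned directly:

Convolution!(float, 2) conv; // 2-d convolution over 4-d NCHW tensors
conv.stride = [2, 2];        // hypothetical: stride 2 in both spatial dims
conv.pad = [1, 1];           // hypothetical: zero-padding of 1
conv.dilation = [1, 1];      // no dilation
// then, as in the examples below:
// auto y = conv.forward(x, w); // x: input, w: filter weights
// auto gs = conv.backward(gy); // gs[0] = grad x, gs[1] = grad w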

Members

Functions

outShape
auto outShape(uint[nbDims] inShape, uint[nbDims] weightShape)

Computes the output tensor shape from the input and weight (filter) shapes, following the PyTorch convolution convention:
https://pytorch.org/docs/master/nn.html#convolution-layers
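A sketch of the expected per-dimension arithmetic, taken from the linked PyTorch documentation: the batch size comes from inShape, the output channel count from weightShape, and for each spatial dimension d

    out[d] = (in[d] + 2·pad[d] − dilation[d]·(kernel[d] − 1) − 1) / stride[d] + 1   (floor division)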

Examples

Conv1d PyTorch equality test

>>> import torch
>>> iota = lambda s: torch.arange(torch.prod(torch.tensor(s)), dtype=torch.float32).view(s)
>>> y = torch.nn.functional.conv1d(iota([2, 3, 4]), iota([5, 3, 3]))
>>> y
tensor([[[  258.,   294.],
         [  663.,   780.],
         [ 1068.,  1266.],
         [ 1473.,  1752.],
         [ 1878.,  2238.]],

        [[  690.,   726.],
         [ 2067.,  2184.],
         [ 3444.,  3642.],
         [ 4821.,  5100.],
         [ 6198.,  6558.]]])
>>> y.shape
torch.Size([2, 5, 2])

>>> x = iota([2, 3, 4])
>>> x.requires_grad = True
>>> w = iota([5, 3, 3])
>>> w.requires_grad = True
>>> y = torch.nn.functional.conv1d(x, w)
>>> y.backward(torch.ones_like(y))
>>> x.grad
tensor([[[  90.,  185.,  195.,  100.],
         [ 105.,  215.,  225.,  115.],
         [ 120.,  245.,  255.,  130.]],

        [[  90.,  185.,  195.,  100.],
         [ 105.,  215.,  225.,  115.],
         [ 120.,  245.,  255.,  130.]]])
>>> w.grad
tensor([[[ 26.,  30.,  34.],
         [ 42.,  46.,  50.],
         [ 58.,  62.,  66.]],

        [[ 26.,  30.,  34.],
         [ 42.,  46.,  50.],
         [ 58.,  62.,  66.]],

        [[ 26.,  30.,  34.],
         [ 42.,  46.,  50.],
         [ 58.,  62.,  66.]],

        [[ 26.,  30.,  34.],
         [ 42.,  46.,  50.],
         [ 58.,  62.,  66.]],

        [[ 26.,  30.,  34.],
         [ 42.,  46.,  50.],
         [ 58.,  62.,  66.]]])
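The gradients above follow the standard cross-correlation backward rules; a sketch for the 1-d, stride-1, no-padding case, writing gy for the upstream gradient ∂L/∂y:

    ∂L/∂x[n, c, t] = Σ_f Σ_k gy[n, f, t − k] · w[f, c, k]
    ∂L/∂w[f, c, k] = Σ_n Σ_i gy[n, f, i] · x[n, c, i + k]

With gy all ones, e.g. ∂L/∂x[0, 0, 0] = Σ_f w[f, 0, 0] = 0 + 9 + 18 + 27 + 36 = 90, matching the first x.grad entry.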
import std.stdio;
import std.typecons : tuple;
import mir.ndslice;
import numir;

auto x = iota(2, 3, 4).as!float.slice.variable;
auto w = iota(5, 3, 3).as!float.slice.variable;
Convolution!(float, 1) conv;
auto y = conv.forward(x, w);
auto yx = [[[  258.,   294.],
            [  663.,   780.],
            [ 1068.,  1266.],
            [ 1473.,  1752.],
            [ 1878.,  2238.]],

           [[  690.,   726.],
            [ 2067.,  2184.],
            [ 3444.,  3642.],
            [ 4821.,  5100.],
            [ 6198.,  6558.]]];
assert(y.sliced == yx);

// test backward with an all-ones upstream gradient
auto gy = y.uninit;
gy.data[] = 1;
auto gs = conv.backward(gy);
auto gx = gs[0];
auto gw = gs[1];

auto gxx = [[[  90.,  185.,  195.,  100.],
             [ 105.,  215.,  225.,  115.],
             [ 120.,  245.,  255.,  130.]],

            [[  90.,  185.,  195.,  100.],
             [ 105.,  215.,  225.,  115.],
             [ 120.,  245.,  255.,  130.]]];
assert(gx.sliced == gxx);

auto gwx = [[[ 26.,  30.,  34.],
             [ 42.,  46.,  50.],
             [ 58.,  62.,  66.]],

            [[ 26.,  30.,  34.],
             [ 42.,  46.,  50.],
             [ 58.,  62.,  66.]],

            [[ 26.,  30.,  34.],
             [ 42.,  46.,  50.],
             [ 58.,  62.,  66.]],

            [[ 26.,  30.,  34.],
             [ 42.,  46.,  50.],
             [ 58.,  62.,  66.]],

            [[ 26.,  30.,  34.],
             [ 42.,  46.,  50.],
             [ 58.,  62.,  66.]]];
assert(gw.sliced == gwx);

// check analytic gradients against numerical differentiation
import grain.testing : gradCheck;
auto hx = uniform!float(x.shape.castArray!size_t).slice.variable;
auto hw = uniform!float(w.shape.castArray!size_t).slice.variable;
auto hgy = uniform!float(y.shape.castArray!size_t).slice.variable;
auto hy = conv.forward(hx, hw);
auto hgx = conv.backward(hgy);
gradCheck(conv, tuple(hx, hw), hgy, 1e-3, 1e-3, 1e-2);

// the CUDA (cuDNN) path should agree with the host implementation
version (grain_cuda) {
    auto dy = conv.forward(hx.to!DeviceStorage, hw.to!DeviceStorage);
    auto dgx = conv.backward(hgy.to!DeviceStorage);
    assert(approxEqual(dy.to!HostStorage.sliced, hy.sliced));
    assert(approxEqual(dgx[0].to!HostStorage.sliced, hgx[0].sliced));
    assert(approxEqual(dgx[1].to!HostStorage.sliced, hgx[1].sliced));
}
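As a hand check of the first expected value: x = iota(2, 3, 4) gives x[0, c, k] = 4c + k and w = iota(5, 3, 3) gives w[0, c, k] = 3c + k, so

    y[0, 0, 0] = Σ_c Σ_k w[0, c, k] · x[0, c, k]
               = (0·0 + 1·1 + 2·2) + (3·4 + 4·5 + 5·6) + (6·8 + 7·9 + 8·10)
               = 5 + 62 + 191 = 258

which matches the first entry of the PyTorch output above.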

Conv2d PyTorch equality test

>>> import torch
>>> iota = lambda s: torch.arange(torch.prod(torch.tensor(s)), dtype=torch.float32).view(s)
>>> x = iota([2, 3, 4, 4])
>>> x.requires_grad = True
>>> w = iota([2, 3, 3, 3])
>>> w.requires_grad = True
>>> y = torch.nn.functional.conv2d(x, w)
>>> y
tensor([[[[ 10197.,  10548.],
          [ 11601.,  11952.]],

         [[ 25506.,  26586.],
          [ 29826.,  30906.]]],


        [[[ 27045.,  27396.],
          [ 28449.,  28800.]],

         [[ 77346.,  78426.],
          [ 81666.,  82746.]]]])

>>> y.backward(torch.ones_like(y))
>>> x.grad
tensor([[[[  27.,   56.,   60.,   31.],
          [  60.,  124.,  132.,   68.],
          [  72.,  148.,  156.,   80.],
          [  39.,   80.,   84.,   43.]],

         [[  45.,   92.,   96.,   49.],
          [  96.,  196.,  204.,  104.],
          [ 108.,  220.,  228.,  116.],
          [  57.,  116.,  120.,   61.]],

         [[  63.,  128.,  132.,   67.],
          [ 132.,  268.,  276.,  140.],
          [ 144.,  292.,  300.,  152.],
          [  75.,  152.,  156.,   79.]]],


        [[[  27.,   56.,   60.,   31.],
          [  60.,  124.,  132.,   68.],
          [  72.,  148.,  156.,   80.],
          [  39.,   80.,   84.,   43.]],

         [[  45.,   92.,   96.,   49.],
          [  96.,  196.,  204.,  104.],
          [ 108.,  220.,  228.,  116.],
          [  57.,  116.,  120.,   61.]],

         [[  63.,  128.,  132.,   67.],
          [ 132.,  268.,  276.,  140.],
          [ 144.,  292.,  300.,  152.],
          [  75.,  152.,  156.,   79.]]]])
>>> w.grad
tensor([[[[ 212.,  220.,  228.],
          [ 244.,  252.,  260.],
          [ 276.,  284.,  292.]],

         [[ 340.,  348.,  356.],
          [ 372.,  380.,  388.],
          [ 404.,  412.,  420.]],

         [[ 468.,  476.,  484.],
          [ 500.,  508.,  516.],
          [ 532.,  540.,  548.]]],


        [[[ 212.,  220.,  228.],
          [ 244.,  252.,  260.],
          [ 276.,  284.,  292.]],

         [[ 340.,  348.,  356.],
          [ 372.,  380.,  388.],
          [ 404.,  412.,  420.]],

         [[ 468.,  476.,  484.],
          [ 500.,  508.,  516.],
          [ 532.,  540.,  548.]]]])
import std.stdio;
import std.typecons : tuple;
import mir.ndslice;
import numir;

auto x = iota(2, 3, 4, 4).as!float.slice.variable;
auto w = iota(2, 3, 3, 3).as!float.slice.variable;
Convolution!(float, 2) conv;
auto y = conv.forward(x, w);
auto yx = [[[[ 10197.,  10548.],
             [ 11601.,  11952.]],
            [[ 25506.,  26586.],
             [ 29826.,  30906.]]],
           [[[ 27045.,  27396.],
             [ 28449.,  28800.]],
            [[ 77346.,  78426.],
             [ 81666.,  82746.]]]];
assert(y.sliced == yx);

// test backward with an all-ones upstream gradient
auto gy = y.uninit;
gy.data[] = 1;
auto gs = conv.backward(gy);
auto gx = gs[0];
auto gw = gs[1];

auto gxx = [[[[  27.,   56.,   60.,   31.],
              [  60.,  124.,  132.,   68.],
              [  72.,  148.,  156.,   80.],
              [  39.,   80.,   84.,   43.]],

             [[  45.,   92.,   96.,   49.],
              [  96.,  196.,  204.,  104.],
              [ 108.,  220.,  228.,  116.],
              [  57.,  116.,  120.,   61.]],

             [[  63.,  128.,  132.,   67.],
              [ 132.,  268.,  276.,  140.],
              [ 144.,  292.,  300.,  152.],
              [  75.,  152.,  156.,   79.]]],


            [[[  27.,   56.,   60.,   31.],
              [  60.,  124.,  132.,   68.],
              [  72.,  148.,  156.,   80.],
              [  39.,   80.,   84.,   43.]],

             [[  45.,   92.,   96.,   49.],
              [  96.,  196.,  204.,  104.],
              [ 108.,  220.,  228.,  116.],
              [  57.,  116.,  120.,   61.]],

             [[  63.,  128.,  132.,   67.],
              [ 132.,  268.,  276.,  140.],
              [ 144.,  292.,  300.,  152.],
              [  75.,  152.,  156.,   79.]]]];
assert(gx.sliced == gxx);

auto gwx = [[[[ 212.,  220.,  228.],
              [ 244.,  252.,  260.],
              [ 276.,  284.,  292.]],
             [[ 340.,  348.,  356.],
              [ 372.,  380.,  388.],
              [ 404.,  412.,  420.]],
             [[ 468.,  476.,  484.],
              [ 500.,  508.,  516.],
              [ 532.,  540.,  548.]]],
            [[[ 212.,  220.,  228.],
              [ 244.,  252.,  260.],
              [ 276.,  284.,  292.]],
             [[ 340.,  348.,  356.],
              [ 372.,  380.,  388.],
              [ 404.,  412.,  420.]],
             [[ 468.,  476.,  484.],
              [ 500.,  508.,  516.],
              [ 532.,  540.,  548.]]]];
assert(gw.sliced == gwx);

// check analytic gradients against numerical differentiation
import grain.testing : gradCheck;
auto hx = uniform!float(x.shape.castArray!size_t).slice.variable;
auto hw = uniform!float(w.shape.castArray!size_t).slice.variable;
auto hgy = uniform!float(y.shape.castArray!size_t).slice.variable;
auto hy = conv.forward(hx, hw);
auto hgx = conv.backward(hgy);
gradCheck(conv, tuple(hx, hw), hgy, 1e-3, 1e-3, 1e-2);

// the CUDA (cuDNN) path should agree with the host implementation
version (grain_cuda) {
    auto dy = conv.forward(hx.to!DeviceStorage, hw.to!DeviceStorage);
    auto dgx = conv.backward(hgy.to!DeviceStorage);
    assert(approxEqual(dy.to!HostStorage.sliced, hy.sliced));
    assert(approxEqual(dgx[0].to!HostStorage.sliced, hgx[0].sliced));
    assert(approxEqual(dgx[1].to!HostStorage.sliced, hgx[1].sliced));
}
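For the expected shape in this example: a 4×4 input with a 3×3 kernel, stride 1, no padding, and no dilation gives (4 + 2·0 − 1·(3 − 1) − 1)/1 + 1 = 2 per spatial dimension, so y has shape [2, 2, 2, 2] (batch 2, 2 output channels), exactly as the outShape formula above predicts.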

Meta