pytorch - Is partial 3D convolution or transpose+2D convolution faster?

I have data of shape B*C*T*H*W and want to apply the same 2D convolution to every frame, i.e. over the H*W dimensions.

There are two options (that I see):

  1. Apply a partial 3D convolution with kernel size (1, 3, 3). 3D convolution accepts data of shape B*C*T*H*W, which is exactly what I have. However, this is a pseudo-3D convolution that might be under-optimized (despite its heavy use in P3D networks).

  2. Transpose the data, apply a 2D convolution, and transpose the result back. This adds the overhead of reshaping the data, but it makes use of the heavily optimized 2D convolution kernels.

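# (B, C, T, H, W) -> (B, T, C, H, W) -> (B*T, C, H, W): fold time into the batch
data = raw_data.transpose(1, 2).reshape(b*t, c, h, w).detach()
out = conv2d(data)  # conv2d: e.g. nn.Conv2d(c, c, kernel_size=3, padding=1)
# (B*T, C, H, W) -> (B, T, C, H, W) -> (B, C, T, H, W): unfold back
out = out.view(b, t, c, h, w).transpose(1, 2).contiguous()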
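Both options compute the same function, so comparing their speed is fair. One way to check this is to copy the weights of a (1, 3, 3) `Conv3d` into a `Conv2d` and compare the outputs. The following is a minimal CPU sketch with small, arbitrary shapes; the helper name `frame_wise_conv2d` is hypothetical.

```python
import torch
import torch.nn as nn


def frame_wise_conv2d(conv2d, x):
    """Hypothetical helper: apply a 2D conv to every frame of x (B, C, T, H, W)."""
    b, c, t, h, w = x.shape
    # (B, C, T, H, W) -> (B, T, C, H, W) -> (B*T, C, H, W)
    frames = x.transpose(1, 2).reshape(b * t, c, h, w)
    out = conv2d(frames)
    # (B*T, C_out, H', W') -> (B, T, C_out, H', W') -> (B, C_out, T, H', W')
    _, c_out, h_out, w_out = out.shape
    return out.view(b, t, c_out, h_out, w_out).transpose(1, 2).contiguous()


torch.manual_seed(0)
b, c, t, h, w = 2, 8, 3, 16, 16
x = torch.randn(b, c, t, h, w)

conv3d = nn.Conv3d(c, c, kernel_size=(1, 3, 3), padding=(0, 1, 1))
conv2d = nn.Conv2d(c, c, kernel_size=3, padding=1)

# Share parameters: the Conv3d weight is (C_out, C_in, 1, 3, 3); drop the temporal dim.
with torch.no_grad():
    conv2d.weight.copy_(conv3d.weight.squeeze(2))
    conv2d.bias.copy_(conv3d.bias)

out3d = conv3d(x)
out2d = frame_wise_conv2d(conv2d, x)
print(out3d.shape, torch.allclose(out3d, out2d, atol=1e-5))
```

Any remaining differences should only come from floating-point summation order in the two kernels, which is why the comparison uses a small tolerance.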

Which one is faster?

(Note: I have posted a self-answer below. This is meant as a quick note for people who are googling this, i.e. me 20 minutes ago.)

Question from: https://stackoverflow.com/questions/65930768/is-partial-3d-convolution-or-transpose2d-convolution-faster


1 Answer


Environment: PyTorch 1.7.1, CUDA 11.0, RTX 2080 TI.
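The results depend heavily on this environment, so it helps to report the same details when comparing numbers. This is a small sketch using standard PyTorch attributes. The dictionary keys are just labels chosen for illustration.

```python
import torch

# Collect the same details as the "Environment" line above.
env = {
    "PyTorch": torch.__version__,
    "CUDA (build)": torch.version.cuda,        # None on CPU-only builds
    "cuDNN": torch.backends.cudnn.version(),   # None if cuDNN is unavailable
    "GPU": torch.cuda.get_device_name(0) if torch.cuda.is_available() else "none",
}
for key, value in env.items():
    print(f"{key}: {value}")
```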

TL;DR: Transpose + 2D conv is faster (in this environment, and for the tested data shapes).

Code (modified from here):

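import torch
import torch.nn as nn
import time

# Input shape B*C*T*H*W (values for the first experiment below)
b = 4
c = 64
t = 4
h = 256
w = 256
raw_data = torch.randn(b, c, t, h, w).cuda()

def time2D():
    # Option 2: fold T into the batch, run a 2D conv, unfold back
    conv2d = nn.Conv2d(c, c, kernel_size=3, padding=1).cuda()
    torch.cuda.synchronize()  # make sure no earlier GPU work is timed
    start = time.time()
    for _ in range(100):
        data = raw_data.transpose(1, 2).reshape(b*t, c, h, w).detach()
        out = conv2d(data)
        out = out.view(b, t, c, h, w).transpose(1, 2).contiguous()
        out.mean().backward()
    torch.cuda.synchronize()  # wait for all queued kernels before stopping the clock
    end = time.time()
    print("  --- %s ---  " % (end - start))  # total seconds for 100 iterations

def time3D():
    # Option 1: "partial" 3D conv with a (1, 3, 3) kernel applied directly to B*C*T*H*W
    conv3d = nn.Conv3d(c, c, kernel_size=(1, 3, 3), padding=(0, 1, 1)).cuda()
    torch.cuda.synchronize()
    start = time.time()
    for _ in range(100):
        out = conv3d(raw_data.detach())
        out.mean().backward()
    torch.cuda.synchronize()
    end = time.time()
    print("  --- %s ---  " % (end - start))  # total seconds for 100 iterations

# Warm-up run to initialize the CUDA context (note: it only exercises the 2D path)
print("Initializing cuda state")
time2D()

print("going to time2D")
time2D()
print("going to time3D")
time3D()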
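This is a variant of the benchmark above that uses `torch.utils.benchmark.Timer`, which is available in more recent PyTorch releases than the 1.7.1 used here. Its default timer synchronizes CUDA before reading the clock when a GPU is present. The sketch also warms up both variants, not just the 2D one. It assumes the same forward + backward workload as above. If no GPU is found, it falls back to the CPU with a smaller shape, so those CPU numbers are not comparable to the GPU results. The names `step_2d` and `step_3d` are hypothetical.

```python
import torch
import torch.nn as nn
import torch.utils.benchmark as benchmark

device = "cuda" if torch.cuda.is_available() else "cpu"
# First shape from the answer on GPU; a small shape keeps a CPU run short.
b, c, t, h, w = (4, 64, 4, 256, 256) if device == "cuda" else (2, 8, 4, 32, 32)
raw_data = torch.randn(b, c, t, h, w, device=device)

# Optionally, torch.backends.cudnn.benchmark = True lets cuDNN pick algorithms per shape.
conv2d = nn.Conv2d(c, c, kernel_size=3, padding=1).to(device)
conv3d = nn.Conv3d(c, c, kernel_size=(1, 3, 3), padding=(0, 1, 1)).to(device)


def step_2d():
    # Transpose + 2D conv + transpose back, then backward
    data = raw_data.transpose(1, 2).reshape(b * t, c, h, w)
    out = conv2d(data)
    out = out.view(b, t, c, h, w).transpose(1, 2).contiguous()
    out.mean().backward()


def step_3d():
    # (1, 3, 3) 3D conv on the original layout, then backward
    out = conv3d(raw_data)
    out.mean().backward()


results = {}
for name, fn in [("2D", step_2d), ("3D", step_3d)]:
    for _ in range(3):  # explicit warm-up so both variants start from the same state
        fn()
    timer = benchmark.Timer(stmt="fn()", globals={"fn": fn})
    results[name] = timer.timeit(20).mean  # mean seconds per iteration
    print(f"{name}: {results[name] * 1e3:.2f} ms / iteration")
```

Reporting time per iteration, rather than a total over a fixed loop, also makes it easier to compare runs that use different iteration counts.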

Results (total seconds for 100 forward + backward iterations, as printed by the script above):

For shape = 4*64*4*256*256:

2D: 1.8675172328948975
3D: 4.384545087814331

For shape = 8*512*16*64*64:

2D: 37.95961904525757
3D: 49.730860471725464

For shape = 4*128*128*16*16:

2D: 0.6455907821655273
3D: 1.8380646705627441
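To put these numbers in perspective, the speedups can be computed directly from the timings above. Each timing is the total for 100 forward + backward iterations, so multiplying by 10 gives milliseconds per iteration. The dictionary layout below is only for illustration.

```python
# Total seconds for 100 iterations, copied from the results above: (2D, 3D)
timings = {
    "4*64*4*256*256": (1.8675172328948975, 4.384545087814331),
    "8*512*16*64*64": (37.95961904525757, 49.730860471725464),
    "4*128*128*16*16": (0.6455907821655273, 1.8380646705627441),
}

speedups = {}
for shape, (t2d, t3d) in timings.items():
    speedups[shape] = t3d / t2d  # how many times faster the 2D path is
    print(f"{shape}: 2D {t2d * 10:.1f} ms/iter, 3D {t3d * 10:.1f} ms/iter, "
          f"2D is {speedups[shape]:.2f}x faster")
```

With these numbers, the transpose + 2D path is roughly 2.3x, 1.3x, and 2.8x faster for the three shapes. The gap is smallest for the 512-channel shape.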
