Building SearchTransfer in PyTorch
SearchTransfer comes from the code of the paper Learning Texture Transformer Network for Image Super-Resolution (TTSR).
This post records some of the functions and usages I ran into while reproducing the transformer module.
Key functions
torch.nn.functional.unfold
torch.nn.functional.fold
torch.Tensor.expand
torch.gather
unfold expands the input into blocks so that attention can be computed between blocks; the resulting similarity map gives the indices used to extract information from ref_unfold, and fold finally reassembles the result.
1. unfold
unfold splits the input into blocks using the same sliding window as nn.Conv2d.
import torch
import torch.nn.functional as F

x = torch.rand((1, 3, 5, 5))
x_unfold = F.unfold(x, kernel_size=3, padding=1, stride=1)
print(x.shape)         # torch.Size([1, 3, 5, 5])
print(x_unfold.shape)  # torch.Size([1, 27, 25])
x has shape (batch, channel, H, W), and x_unfold comes out with shape (batch, k x k x channel, number_blocks).
k is kernel_size, so k x k x channel is the number of values inside one block.
number_blocks is how many blocks the window can slide out given kernel_size, padding and stride.
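The block count follows the same formula as a convolution output size. A minimal check, reusing the sizes from the snippet above:

import torch
import torch.nn.functional as F

k, p, s = 3, 1, 1
H, W = 5, 5
x = torch.rand((1, 3, H, W))
x_unfold = F.unfold(x, kernel_size=k, padding=p, stride=s)

# same formula as the output size of a convolution
blocks_h = (H + 2 * p - k) // s + 1
blocks_w = (W + 2 * p - k) // s + 1
print(x_unfold.shape[-1])   # 25
print(blocks_h * blocks_w)  # 25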
2. fold
fold is the inverse of unfold: it reassembles the blocks back into a (batch, channel, H, W) tensor.
k = 6
s = 2
p = (k - s) // 2
H, W = 100, 100
x = torch.rand((1, 3, H, W))
x_unfold = F.unfold(x, kernel_size=k, stride=s, padding=p)
x_fold = F.fold(x_unfold, output_size=(H, W), kernel_size=k, stride=s, padding=p)
print(x_unfold.shape)  # torch.Size([1, 108, 2500])
print(x_fold.shape)    # torch.Size([1, 3, 100, 100])
print(x.mean())        # tensor(0.5012)
print(x_fold.mean())   # tensor(4.3924)
Although the shape is restored, the values of x and x_fold no longer match. During unfold a single position (1 x 1 x channel) can appear in several blocks, so fold sums those overlapping copies. To get back to the original value range, x_fold still has to be divided by the overlap count. With k=6 and s=2 a position falls into 3 x 3 = 9 blocks (the window passes over it 3 times horizontally and 3 times vertically), which is why the mean jumps from roughly 0.5 to roughly 0.5 x 9 = 4.5 (slightly less because border positions overlap fewer blocks).
x = torch.rand((1, 3, H, W))
x_unfold = F.unfold(x, kernel_size=k, stride=s, padding=p)
x_fold = F.fold(x_unfold, output_size=(H, W), kernel_size=k, stride=s, padding=p) / (3. * 3.)
print(x_unfold.shape)
print(x_fold.shape)
print(x.mean())       # tensor(0.4998)
print(x_fold.mean())  # tensor(0.4866)
print((x[:, :, 30:40, 30:40] == x_fold[:, :, 30:40, 30:40]).sum())  # tensor(189)
The sum() shows that only part of the data is recovered exactly with a constant divisor. A more accurate way to compute the divisor (instead of hard-coding something like 3. x 3.) is to push a torch.ones tensor through the same unfold/fold.
k = 5
s = 3
p = (k - s) // 2
H, W = 100, 100
x = torch.rand((1, 3, H, W))
x_unfold = F.unfold(x, kernel_size=k, stride=s, padding=p)
x_fold = F.fold(x_unfold, output_size=(H, W), kernel_size=k, stride=s, padding=p)
ones = torch.ones((1, 3, H, W))
ones_unfold = F.unfold(ones, kernel_size=k, stride=s, padding=p)
ones_fold = F.fold(ones_unfold, output_size=(H, W), kernel_size=k, stride=s, padding=p)
x_fold = x_fold / ones_fold
print(x.mean())             # tensor(0.5001)
print(x_fold.mean())        # tensor(0.5001)
print((x == x_fold).sum())  # tensor(30000), every element is recovered
3. expand
Usage: Tensor.expand(*sizes); a -1 in sizes means that dimension keeps its original size.
x = torch.rand((1, 4))  # x = torch.rand(4) gives the same result
x_expand1 = x.expand((3, 4))
x_expand2 = x.expand((3, -1))
print(x)
# tensor([[0.1745, 0.2331, 0.5449, 0.1914]])
print(x_expand1)
# tensor([[0.1745, 0.2331, 0.5449, 0.1914],
#         [0.1745, 0.2331, 0.5449, 0.1914],
#         [0.1745, 0.2331, 0.5449, 0.1914]])
print(x_expand2)
# tensor([[0.1745, 0.2331, 0.5449, 0.1914],
#         [0.1745, 0.2331, 0.5449, 0.1914],
#         [0.1745, 0.2331, 0.5449, 0.1914]])
4. gather
Usage: torch.gather(input, dim, index, *, sparse_grad=False, out=None). For a 3-D input the effect is:
for i in range(dim0):
    for j in range(dim1):
        for k in range(dim2):
            out[i, j, k] = input[index[i][j][k], j, k]  # if dim == 0
            out[i, j, k] = input[i, index[i][j][k], k]  # if dim == 1
            out[i, j, k] = input[i, j, index[i][j][k]]  # if dim == 2
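A tiny runnable check of the dim=2 case (the tensors below are made up just for illustration):

import torch

inp = torch.arange(24).view(1, 2, 12)            # [1, 2, 12]
index = torch.tensor([[[3, 0, 7], [3, 0, 7]]])   # [1, 2, 3], matches inp in every dim except dim 2
out = torch.gather(inp, 2, index)
print(out)
# tensor([[[ 3,  0,  7],
#          [15, 12, 19]]])
# i.e. out[i, j, k] = inp[i, j, index[i][j][k]]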
When using gather here, the index is first expanded so that its size matches the input.
For example, if index.shape == [B, blocks], expand turns it into [B, k x k x C, blocks]; then index[i, :, k] is a 1-D tensor whose elements all equal the pre-expand index[i, k].
As a result index[i][j][k] does not change as j varies, so the loop out[i, j, k] = input[i, j, index[i][j][k]] matches every point of the k-th block of out with the corresponding point of the index[i][j][k]-th block of input (by iterating over j). A minimal sketch of this pattern follows.
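Here is a minimal sketch of the expand + gather pattern on toy shapes (all names and sizes below are made up for illustration; the bis method in section 5 does the same thing in a shape-agnostic way):

import torch

B, kkC, n_ref, n_lr = 1, 4, 6, 3
ref_unfold = torch.rand((B, kkC, n_ref))   # [B, k*k*C, Hr*Wr]
index = torch.tensor([[5, 0, 2]])          # [B, H*W], best-matching ref block for each LR block

idx = index.view(B, 1, n_lr).expand(-1, kkC, -1)  # [B, H*W] -> [B, 1, H*W] -> [B, k*k*C, H*W]
selected = torch.gather(ref_unfold, 2, idx)       # pick whole blocks along dim 2
print(selected.shape)                                       # torch.Size([1, 4, 3])
print(torch.equal(selected[:, :, 0], ref_unfold[:, :, 5]))  # True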
5. Building the Features Transfer
import torch
import torch.nn as nn
import torch.nn.functional as F


class Transfer(nn.Module):
    def __init__(self):
        super(Transfer, self).__init__()

    def bis(self, unfold, dim, index):
        """block index select
        args:
            unfold: [B, k*k*C, Hr*Wr]
            dim: which dimension holds the blocks
            index: [B, H*W], value range is [0, Hr*Wr - 1]
        return:
            [B, k*k*C, H*W]
        """
        views = [unfold.size(0)] + [-1 if i == dim else 1 for i in range(1, len(unfold.size()))]  # [B, 1, -1] i.e. [B, 1, H*W]
        expanse = list(unfold.size())
        expanse[0] = -1
        expanse[dim] = -1  # [-1, k*k*C, -1]
        index = index.view(views).expand(expanse)  # [B, H*W] -> [B, 1, H*W] -> [B, k*k*C, H*W]
        return torch.gather(unfold, dim, index)  # return[i][j][k] = unfold[i][j][index[i][j][k]]

    def forward(self, lrsr_lv3, refsr_lv3, ref_lv1, ref_lv2, ref_lv3):
        """
        args:
            lrsr_lv3: [B, C, H, W]
            refsr_lv3: [B, C, Hr, Wr]
            ref_lv1: [B, C, Hr*4, Wr*4]
            ref_lv2: [B, C, Hr*2, Wr*2]
            ref_lv3: [B, C, Hr, Wr]
        """
        H, W = lrsr_lv3.size()[-2:]

        lrsr_lv3_unfold = F.unfold(lrsr_lv3, kernel_size=3, padding=1, stride=1)  # [B, k*k*C, H*W]
        refsr_lv3_unfold = F.unfold(refsr_lv3, kernel_size=3, padding=1, stride=1).transpose(1, 2)  # [B, Hr*Wr, k*k*C]
        lrsr_lv3_unfold = F.normalize(lrsr_lv3_unfold, dim=1)
        refsr_lv3_unfold = F.normalize(refsr_lv3_unfold, dim=2)

        R = torch.bmm(refsr_lv3_unfold, lrsr_lv3_unfold)  # [B, Hr*Wr, H*W]
        score, index = torch.max(R, dim=1)  # [B, H*W]

        ref_lv3_unfold = F.unfold(ref_lv3, kernel_size=3, padding=1, stride=1)    # vgg19
        ref_lv2_unfold = F.unfold(ref_lv2, kernel_size=6, padding=2, stride=2)    # one maxpooling between lv1->lv2 and lv2->lv3
        ref_lv1_unfold = F.unfold(ref_lv1, kernel_size=12, padding=4, stride=4)   # kernel_size is not computed from the true receptive field

        # divisors that record the overlap introduced by fold(unfold(...))
        divisor_lv3 = F.unfold(torch.ones_like(ref_lv3), kernel_size=3, padding=1, stride=1)
        divisor_lv2 = F.unfold(torch.ones_like(ref_lv2), kernel_size=6, padding=2, stride=2)
        divisor_lv1 = F.unfold(torch.ones_like(ref_lv1), kernel_size=12, padding=4, stride=4)

        T_lv3_unfold = self.bis(ref_lv3_unfold, 2, index)  # [B, k*k*C, H*W]
        T_lv2_unfold = self.bis(ref_lv2_unfold, 2, index)
        T_lv1_unfold = self.bis(ref_lv1_unfold, 2, index)

        divisor_lv3 = self.bis(divisor_lv3, 2, index)  # [B, k*k*C, H*W]
        divisor_lv2 = self.bis(divisor_lv2, 2, index)
        divisor_lv1 = self.bis(divisor_lv1, 2, index)

        divisor_lv3 = F.fold(divisor_lv3, (H, W), kernel_size=3, padding=1, stride=1)
        divisor_lv2 = F.fold(divisor_lv2, (2 * H, 2 * W), kernel_size=6, padding=2, stride=2)
        divisor_lv1 = F.fold(divisor_lv1, (4 * H, 4 * W), kernel_size=12, padding=4, stride=4)

        T_lv3 = F.fold(T_lv3_unfold, (H, W), kernel_size=3, padding=1, stride=1) / divisor_lv3
        T_lv2 = F.fold(T_lv2_unfold, (2 * H, 2 * W), kernel_size=6, padding=2, stride=2) / divisor_lv2
        T_lv1 = F.fold(T_lv1_unfold, (4 * H, 4 * W), kernel_size=12, padding=4, stride=4) / divisor_lv1

        score = score.view(lrsr_lv3.size(0), 1, H, W)  # [B, 1, H, W]

        return score, T_lv1, T_lv2, T_lv3
The gather inside bis is exactly the expand + gather pattern explained in section 4: index is expanded from [B, H*W] to [B, k*k*C, H*W] so that gather along dim=2 copies whole blocks out of the reference unfold.
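As a quick sanity check, the module can be run on random tensors. The feature sizes below are arbitrary and only for illustration (in TTSR these levels would come from a VGG feature extractor):

if __name__ == '__main__':
    B, C, H, W = 1, 8, 10, 12   # LR feature map size (level 3)
    Hr, Wr = 14, 16             # reference feature map size (level 3)
    transfer = Transfer()
    lrsr_lv3 = torch.rand(B, C, H, W)
    refsr_lv3 = torch.rand(B, C, Hr, Wr)
    ref_lv3 = torch.rand(B, C, Hr, Wr)
    ref_lv2 = torch.rand(B, C, Hr * 2, Wr * 2)
    ref_lv1 = torch.rand(B, C, Hr * 4, Wr * 4)
    score, T_lv1, T_lv2, T_lv3 = transfer(lrsr_lv3, refsr_lv3, ref_lv1, ref_lv2, ref_lv3)
    print(score.shape)  # torch.Size([1, 1, 10, 12])
    print(T_lv3.shape)  # torch.Size([1, 8, 10, 12])
    print(T_lv2.shape)  # torch.Size([1, 8, 20, 24])
    print(T_lv1.shape)  # torch.Size([1, 8, 40, 48])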
References
https://pytorch.org/docs/stable/index.html
https://github.com/researchmm/TTSR