先看今天要分析的代码:
def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor: """Compute loss between x and target. The model outputs and data labels tensors are flatten to (batch*seqlen, class) shape and a mask is applied to the padding part which should not be calculated for loss. Args: x (torch.Tensor): prediction (batch, seqlen, class) target (torch.Tensor): target signal masked with self.padding_id (batch, seqlen) Returns: loss (torch.Tensor) : The KL loss, scalar float value """ assert x.size(2) == self.size batch_size = x.size(0) x = x.view(-1, self.size) target = target.view(-1) # use zeros_like instead of torch.no_grad() for true_dist, # since no_grad() can not be exported by JIT true_dist = torch.zeros_like(x) true_dist.fill_(self.smoothing / (self.size - 1)) ignore = target == self.padding_idx # (B,) total = len(target) - ignore.sum().item() target = target.masked_fill(ignore, 0) # avoid -1 index true_dist.scatter_(1, target.unsqueeze(1), self.confidence) kl = self.criterion(torch.log_softmax(x, dim=1), true_dist) denom = total if self.normalize_length else batch_size return kl.masked_fill(ignore.unsqueeze(1), 0).sum() / denom
这段代码乍一看很吓人,其实分析后你会觉得豁然开朗,其实WeNet的代码写的真的非常好,又干净又易懂(得仔细看才会体会到易懂,我初次毛毛略略的看确实感觉很难,坚持,干就完了!)
逐行分析
(以下所有变量维度都是假设,为了方便读者理解,实际运行中不一定是这个值,本文后面不再赘述)
首先看传入的参数,一个是x,一个是target,这个很好理解,就是要求编码器的输出和最终label的差异嘛!x是经过编码器的输出,维度假设是[16, 13, 4233],其中16是batch size,13是解码帧的序列长度,一段语音时间越长,这个值就会越大,因为每段语音长短不一,所以这个值不固定,13是在当前batch中最长的值。当前batch的长度假设是[11,13,12,9,5,2,3,1,4,11,12,11,12,7,6,7],对于不足13的部分进行padding补-1.下图是一段语音的示例图,当然x是16个这么大的二维向量,这里sequence length就是例子中的13,图中黄色二维向量的维度是[13, 4233]
target是标签向量,维度是[16, 13],其中保存了对应词典的索引。
下面开始看详细的代码:
assert x.size(2) == self.size
这是为了确保x第二个维度等于词表的个数,也就是4233(对于Aishell-1数据集来说)
batch_size = x.size(0)
取出batch size的长度,这里默认是16
x = x.view(-1, self.size)
这个是将x的维度从[16, 13, 4233]变成 [1613,4233], 目的当然是为了方便计算,可以理解为一次处理batch size sequence length这么多帧数据
target = target.view(-1)
与上同理
true_dist = torch.zeros_like(x)
这里是要一个和x同维度的label向量,别忘了参数中x的维度是[16, 13, 4233],而target的维度是[16, 13],所以后续的步骤肯定是将target变成独热编码为了x同纬度(这一步只是全部设置为0),如下图:
到这一步,true_dist全都是0,且和x同纬度
true_dist.fill_(self.smoothing / (self.size - 1))
独热编码是将label对应的位置设置为1,其他所有位置都设置为0,如上图target部分所示,而WeNet这里使用了label smoothing,将0替换为了$\frac{smoothing}{size-1}$,其中smoothing是0.1,size-1=4232.简单说一下什么是label smoothing,因为WeNet注解写的太好,我就直接copy了:
Label-smoothing loss.在标准交叉熵损失中, 标签的数据分布为:[0,1,2] ->[ [1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0],]而在smoothing版本的交叉熵损失中,一些概率取自真实标签probb(1.0),并在其他标签之间分配。e.g.smoothing=0.1[0,1,2] ->[ [0.9, 0.05, 0.05], [0.05, 0.9, 0.05], [0.05, 0.05, 0.9],]
这里填充的值是$\frac{0.1}{4232}=2.36e-5$这是一个非常小的值,所以下图用$\epsilon$表示$2.36e-5$
ignore = target == self.padding_idx # (B,)
前面我们说了数据集中的句子长短不一,而我们取了当前batch中序列长度最长的那个值,其他短于这个值的帧用-1进行padding填充,所以这些-1的部分我们要忽略,不能加入求损失的过程。这里padding_idx是-1,判断target是否有值为-1的部分,如果有就忽略掉,将是否忽略保存到ignore变量中,ignore的维度这里是[16*13]
total = len(target) - ignore.sum().item()
这里是求刨开被占位的部分,真正需要计算损失的token总共有多少个,用target的总长度减去需要忽略的个数,这个例子就是146 = 16*13 - 62,也就是当前batch真正参与计算损失的字符也就146个,剩下62个都是被padding过的部分
target = target.masked_fill(ignore, 0) # avoid -1 index
这行是将忽略掉的地方(-1)全部用0来替换,因为要开始计算损失了,-1会影响结果
true_dist.scatter_(1, target.unsqueeze(1), self.confidence)
这一步是将true_dist中target对应的索引用置信度0.9替换,其中1 = confidence + smoothing
kl = self.criterion(torch.log_softmax(x, dim=1), true_dist)
这一步是计算kl损失,具体我写在了https://juejin.cn/post/7092778296043110413
denom = total if self.normalize_length else batch_size
这一步如果normalize_length为True则按照layer_norm的方式进行正则化,则denom为4233,如果normalize_length为False,则按照batch_norm的方式进行正则化,则denom为16.
return kl.masked_fill(ignore.unsqueeze(1), 0).sum() / denom
208(16*13)个位置都计算了损失,将208个kl损失进行求和,然后除以denom,就得到了最终的损失