首页>>人工智能->图解WeNet求Attention损失(逐行分析)

图解WeNet求Attention损失(逐行分析)

时间:2023-11-29 本站 点击:1

先看今天要分析的代码:

def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:    """Compute loss between x and target.    The model outputs and data labels tensors are flatten to    (batch*seqlen, class) shape and a mask is applied to the    padding part which should not be calculated for loss.    Args:        x (torch.Tensor): prediction (batch, seqlen, class)        target (torch.Tensor):            target signal masked with self.padding_id (batch, seqlen)    Returns:        loss (torch.Tensor) : The KL loss, scalar float value    """    assert x.size(2) == self.size    batch_size = x.size(0)    x = x.view(-1, self.size)    target = target.view(-1)    # use zeros_like instead of torch.no_grad() for true_dist,    # since no_grad() can not be exported by JIT    true_dist = torch.zeros_like(x)    true_dist.fill_(self.smoothing / (self.size - 1))    ignore = target == self.padding_idx  # (B,)    total = len(target) - ignore.sum().item()    target = target.masked_fill(ignore, 0)  # avoid -1 index    true_dist.scatter_(1, target.unsqueeze(1), self.confidence)    kl = self.criterion(torch.log_softmax(x, dim=1), true_dist)    denom = total if self.normalize_length else batch_size    return kl.masked_fill(ignore.unsqueeze(1), 0).sum() / denom

这段代码乍一看很吓人,其实分析后你会觉得豁然开朗,其实WeNet的代码写的真的非常好,又干净又易懂(得仔细看才会体会到易懂,我初次毛毛略略的看确实感觉很难,坚持,干就完了!)

逐行分析

(以下所有变量维度都是假设,为了方便读者理解,实际运行中不一定是这个值,本文后面不再赘述)

首先看传入的参数,一个是x,一个是target,这个很好理解,就是要求编码器的输出和最终label的差异嘛!x是经过编码器的输出,维度假设是[16, 13, 4233],其中16是batch size,13是解码帧的序列长度,一段语音时间越长,这个值就会越大,因为每段语音长短不一,所以这个值不固定,13是在当前batch中最长的值。当前batch的长度假设是[11,13,12,9,5,2,3,1,4,11,12,11,12,7,6,7],对于不足13的部分进行padding补-1.下图是一段语音的示例图,当然x是16个这么大的二维向量,这里sequence length就是例子中的13,图中黄色二维向量的维度是[13, 4233]

target是标签向量,维度是[16, 13],其中保存了对应词典的索引。

下面开始看详细的代码:

assert x.size(2) == self.size

这是为了确保x第二个维度等于词表的个数,也就是4233(对于Aishell-1数据集来说)

batch_size = x.size(0)

取出batch size的长度,这里默认是16

x = x.view(-1, self.size)

这个是将x的维度从[16, 13, 4233]变成 [1613,4233], 目的当然是为了方便计算,可以理解为一次处理batch size sequence length这么多帧数据

target = target.view(-1)

与上同理

true_dist = torch.zeros_like(x)

这里是要一个和x同维度的label向量,别忘了参数中x的维度是[16, 13, 4233],而target的维度是[16, 13],所以后续的步骤肯定是将target变成独热编码为了x同纬度(这一步只是全部设置为0),如下图:

到这一步,true_dist全都是0,且和x同纬度

true_dist.fill_(self.smoothing / (self.size - 1))

独热编码是将label对应的位置设置为1,其他所有位置都设置为0,如上图target部分所示,而WeNet这里使用了label smoothing,将0替换为了$\frac{smoothing}{size-1}$,其中smoothing是0.1,size-1=4232.简单说一下什么是label smoothing,因为WeNet注解写的太好,我就直接copy了:

Label-smoothing loss.在标准交叉熵损失中, 标签的数据分布为:[0,1,2] ->[    [1.0, 0.0, 0.0],    [0.0, 1.0, 0.0],    [0.0, 0.0, 1.0],]而在smoothing版本的交叉熵损失中,一些概率取自真实标签probb(1.0),并在其他标签之间分配。e.g.smoothing=0.1[0,1,2] ->[    [0.9, 0.05, 0.05],    [0.05, 0.9, 0.05],    [0.05, 0.05, 0.9],]

这里填充的值是$\frac{0.1}{4232}=2.36e-5$这是一个非常小的值,所以下图用$\epsilon$表示$2.36e-5$

ignore = target == self.padding_idx  # (B,)

前面我们说了数据集中的句子长短不一,而我们取了当前batch中序列长度最长的那个值,其他短于这个值的帧用-1进行padding填充,所以这些-1的部分我们要忽略,不能加入求损失的过程。这里padding_idx是-1,判断target是否有值为-1的部分,如果有就忽略掉,将是否忽略保存到ignore变量中,ignore的维度这里是[16*13]

total = len(target) - ignore.sum().item()

这里是求刨开被占位的部分,真正需要计算损失的token总共有多少个,用target的总长度减去需要忽略的个数,这个例子就是146 = 16*13 - 62,也就是当前batch真正参与计算损失的字符也就146个,剩下62个都是被padding过的部分

target = target.masked_fill(ignore, 0)  # avoid -1 index

这行是将忽略掉的地方(-1)全部用0来替换,因为要开始计算损失了,-1会影响结果

true_dist.scatter_(1, target.unsqueeze(1), self.confidence)

这一步是将true_dist中target对应的索引用置信度0.9替换,其中1 = confidence + smoothing

kl = self.criterion(torch.log_softmax(x, dim=1), true_dist)

这一步是计算kl损失,具体我写在了https://juejin.cn/post/7092778296043110413

denom = total if self.normalize_length else batch_size

这一步如果normalize_length为True则按照layer_norm的方式进行正则化,则denom为4233,如果normalize_length为False,则按照batch_norm的方式进行正则化,则denom为16.

return kl.masked_fill(ignore.unsqueeze(1), 0).sum() / denom

208(16*13)个位置都计算了损失,将208个kl损失进行求和,然后除以denom,就得到了最终的损失


本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:/AI/1112.html