• 首页 首页 icon
  • 工具库 工具库 icon
    • IP查询 IP查询 icon
  • 内容库 内容库 icon
    • 快讯库 快讯库 icon
    • 精品库 精品库 icon
    • 问答库 问答库 icon
  • 更多 更多 icon
    • 服务条款 服务条款 icon

GAT, Self Attention, Cross Attention对比以和在自动驾驶轨迹预测任务的pytorch应用

武飞扬头像
Cameron Chen
帮助1

Do not blindly trust anything I say, try to make your own judgement.


目录

1.GAT, Self Attention, Cross Attention对比

1.1 Self Attention

1.2 Graph Attention Network(GAT)

1.3 Cross Attention

2. 自动驾驶轨迹预测任务定义

3. 在轨迹预测场景中的pytorch应用

3.1 GAT

3.2 Self Attention

3.3 Cross Attention

4. Reference


1.GAT, Self Attention, Cross Attention对比

1.1 Self Attention

Self Attention在2017年Google机器翻译团队发表的《Attention is All You Need》中被提出,它完全抛弃了RNN和CNN等网络结构,而仅采用新提出的Self Attention机制来处理机器翻译任务,并且取得了很好的效果。

在Encoder-Decoder框架下,广义的attention机制中的输入Source和输出Target内容是不一样的,以英-中机器翻译为例,Source是英文句子,Target是对应的翻译出的中文句子,Attention机制发生在Target的元素和Source中的所有元素之间。此时Query来自Target, Key和Value来自Source

Self Attention顾名思义,指不是Target和Source之间做Attend,而是Source内部元素之间或者Target内部元素之间发生的Attention机制,也可以理解为Target=Source这种特殊情况下的注意力计算机制。此时Query、Key和Value都来自Target或Source

Attention和Self Attention的本质上都是为序列中每个元素都分配一个权重系数,这也可以理解为软寻址。如果序列中每一个元素都以(K,V)形式存储,那么attention则通过计算Q和K的相似度来完成寻址。Q和K计算出来的相似度反映了取出来的V值的重要程度,即权重,然后加权求和就得到了attention值(实际上就是一种learnable weighted sum)。

那么self-attention是如何做到weighted sum的呢,我们可以对计算公式进行分析:

学新通

首先,向量点积可以表征向量之间的相关程度,点积越大越相关。学新通就是自身与自身求了相关性,得到一个所有数值为0-1的mask矩阵,即weights;学新通是query和key的输入的第一维大小,也就是multi-head的每个head上的feature dims,对学新通做scaling可以避免其值特别大的时输入到softmax之后会得到极小的gradient的情况;同时学到合适的V以保留输入的特征,用weights对V做加权求和得到attention values。

Self Attention这样做的优点主要有以下几点:

  • 与RNN相比,对长期依赖关系有着更强的捕捉能力;
  • 与CNN相比,矩阵计算得到了文本序列中任意两个元素的相似度,观察的范围为全局,而CNN只将一个固定的窗口作为感受野;
  • 可以处理非固定长度的输入,只需要做layer normalization来消除长度的影响;
  • multi-head机制可以实现高并行计算。

1.2 Graph Attention Network(GAT)

GAT在2018年论文“Graph Attention Networks”中提出, GAT和Self Attention本质都在做同样的事情,Self Attention利用 attention 机制将输入中的每个单词用其上下文所有单词的加权来表示,而 GAT 是利用 attention 机制对每个节点只用其邻接节点的加权来表示。

学新通

学新通

由公式可以看到,学新通是所有节点共享的weighted matrix,学新通是第i个节点的feature,学新通是attention网络层,求出的是节点j对于节点i的重要程度。这里学新通和Self Attention中的学新通是在做同一件事。

GAT和Self Attention的区别在于,GAT只对邻接节点做attention,Self Attention对全局信息做attention;此外,GAT的图中各节点都有邻接信息来表示位置关系,而Transformer 将文本隐式的建图过程中丢失了单词之间的位置关系,因此为了补偿这种建图损失的位置关系,Transformer也额外用了位置编码来表征位置信息。但Transformer被验证能达到比GAT更好的效果,因此如果追求精度的话比较建议用Transformer来建模交互信息。

1.3 Cross Attention

Cross Attention只对Self Attention的输入做了改动, 它最初使用于Transformer的decoder。Transformer的decoder如下图右侧模块所示,它有三个输入分别在图中标为input1~3。decoder先递归地输入input1:过去时刻decoder的output(第一次的输入为<bos>,表示句子的开始),与表示位置信息的input2:position encoding相加,经过masked multi-head attention后与input3:encoder的output做cross attention。因此cross attention通常作为decoder模块,与Self Attention作为encoder共同使用。

学新通

Cross Attention输入的Query来自encoder(Self Attention)的输出,而Key和Value则来自初始的input,即encoder的输入。意思是将encoder输出作为预测句子学新通的一种表示,然后其来查询与原始输入句子中每个单词的相似性。直觉上来说,Cross Attention做的事情是用key/value的信息表示query的信息,或者说将query condition在key/value条件上,也可以说将key/value信息引入到query信息中(因为有residual层会与原query信息相加),求得的是query对于key的相关性(query attending to key, e.g., vehicle attending to lanes, vice versa)。

Cross Attention在decoder中应用非常多,一般在encoder使用了Self Attention后,在decoder中先用Cross Attention网络层获取attention value,然后接一个MLP层或LSTM层预测目标如句子或车辆轨迹,比直接用MLP或LSTM作为decoder的效果会好很多。Cross Attention还可以将先前任意一层网络层的信息再次引入,类似于residual的功能,但更加灵活。

除此之外,Cross Attention的Query和Key也可以来自两个不同模态的输入,例如一个是图像,一个是对应的文本,用来求两者的相关性,即图像-文本任务,这也是该模块设计出来的初衷之一。


2. 自动驾驶轨迹预测任务定义

在轨迹预测场景中,考虑一个目标车辆ego_car和M个(如M=8)邻近车辆neighbor_car,这M 1辆车就在整个场景的局部范围内形成了一个图,每辆邻近车辆都与目标车辆存在双向的弧,目标车辆与自身也有一个自环弧。每辆车对应的结点包含该车辆的状态信息如历史轨迹、历史速度、目标类型(Vehicle,bicycle, pedestrian)、与目标车辆的方位关系等等。

这里借用论文“Graph and Recurrent Neural Network-based Vehicle Trajectory Prediction For Highway Driving ”中的图表示该场景,如下图所示。 图中位于中心的0号车是目标车辆,其他1~8号车是8个不同方向上的邻近车辆。图中包含的弧为学新通,每条弧都表示一个相关性权重,做完attention之后只需要取出目标车辆的那一份attention value,以表示目标车辆与其他车辆交互的信息。

学新通

在该场景中,模型以目标车辆的历史轨迹信息邻近车辆的历史轨迹信息目标车辆附近的车道信息这三者作为输入,然后选用合适的网络分别提取目标车自身信息车与车交互信息车与路交互信息,输出未来轨迹,最后根据任务指标设置对应的loss function(通常为多种loss相加)来进行梯度下降。Self Attention或GAT通常是为了计算目标车辆与邻近车辆或与车道信息,亦或是两者都考虑在内的交互信息,输入的数据是目标车辆历史轨迹的信息、邻近车辆历史轨迹以及车道信息;Cross Attention通常为了计算Encoder(如Self Attention)的输出与三个输入之间的相关性(更为常用),也可以是将车辆信息与车道信息之间计算相关性(可作为Encoder中的一部分,也很常用)称为side information。具体设计根据所需功能而定。


3. 在轨迹预测场景中的pytorch应用

3.1 GAT

GAT本身就是用来处理图中多个节点之间的交互关系的模块,每条弧就对应的表征相关程度的权值。与其他attention机制相同,它包含两个步骤:计算注意力系数(attention coefficients),加权求和(aggregation)。GAT对应的pytorch模块是torch_geometric.nn库中的GATconv,不过torch_geometric库的配置有点复杂,需要对CUDA,pytorch的版本要求。

1)GATconv使用方法

GATConv(in_size, out_size, heads_num, concat_flag, dropout=0.0)

该算子只在输入数据的最后一维进行计算。若concat_flag=True,最终输出的维度为out_size*heads_num;若concat_flag=False,则采用求average来处理每个head的输出结果,输出维度为out_size。dropout是对求得的结果做dropout操作。一般In_size对应input_size,out_size对应encoder_size。

GATConv需要输入edge_index用来表示有向弧的头和尾,即节点邻接关系,在该场景下是一个2*18的矩阵。而GATConv每次只能处理一个图(即batch中的一条数据),即输入只有两个维度,如果需要输入batchsize个图,就要先将这batchsize个图的数据合并,即将三维减少成二维的tensor,此时edge_index如果仍然是取值0~8就取的不是对应图内的0~8了,而只对应了batch中第一个图。所以GAT的使用需要加个循环,把batch中每条数据取出来逐个输入,再将最后每个输出结果拼接。

edge_index的具体取值如下,从左往右前八列对应 学新通 ,后八列对应学新通。这里假设batch中每条数据的edge_index矩阵都相同。

       edge_index= [[0, 1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0, 0, 0, 0],  # in_edges_idx

                            [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8]]  # out_edges_idx

target_idx 通常取0,表示输入的node_matrix中,索引0对应的数据为target车的历史轨迹,其他索引对应的是邻近车辆的轨迹。

node_matrix是一个[batch, cars_num, history_length, state_length]的tensor。state_length可以包括每时刻xy坐标、velocity、heading等。

以下代码只是展示使用所需的步骤,不代表一个完整的类。

  1.  
    from torch_geometric.nn import GATConv
  2.  
     
  3.  
    class GAT_Encoder(nn.Module):
  4.  
    def __init__(self, args):
  5.  
    self.gat_conv1 = GATConv(self.args.num_gat_heads * self.args.encoder_size, self.args.encoder_size, heads=self.args.num_gat_heads, concat=True, dropout=0.0)
  6.  
    self.gat_conv2 = GATConv(self.args.num_gat_heads * self.args.encoder_size, self.args.encoder_size, heads=self.args.num_gat_heads, concat=True, dropout=0.0)
  7.  
    self.fc = torch.nn.Linear(self.args.num_gat_heads * self.args.encoder_size, self.args.encoder_size)
  8.  
     
  9.  
    def GAT_Interation(self, node_matrix, edge_idx, target_idx):
  10.  
    # 以下写法为了表示输入数据经过两层gat网络处理。
  11.  
    # target_idx通常取0,表示输入的node_matrix中,索引0对应的数据为target车的历史
  12.  
    # 轨迹,其他索引对应的是邻近车辆的轨迹。
  13.  
     
  14.  
    """以下依次遍历batchsize个子图的写法并不推荐,训练速度会非常慢"""
  15.  
    # batch_size = node_matrix.shape[0]
  16.  
    # x = []
  17.  
    # for i in range(batch_size):
  18.  
    # x_tmp = self.gat_conv1(node_matrix[i], edge_idx)
  19.  
    # x.append(x_tmp.unsqueeze(0))
  20.  
    # gat_feature = torch.cat(x, dim=0)
  21.  
     
  22.  
     
  23.  
    """更常用的做法是将一个batch的图合并为一个大图做训练,可以大大加快训练速度"""
  24.  
    def _flat_edges(edges, target_idx, batch_size):
  25.  
    edge_num = edges.shape[-1]
  26.  
    offset = target_idx.reshape(-1, 1)
  27.  
    edges = edges.reshape(batch_size, 2 * edge_num) offset.repeat(2 * edge_num, axis=-1)
  28.  
    flattened_edges = edges.reshape(batch_size, 2, edge_num)
  29.  
    flattened_edges = np.hstack([b_edges for b_edges in flattened_edges])
  30.  
     
  31.  
    flattened_target_idx = np.hstack([b_target for b_target in target_idx])
  32.  
    return flattened_edges, flattened_target_idx
  33.  
     
  34.  
    #这里默认一个batchsize中每个sample的n_nodes数相同
  35.  
    batchsize, n_nodes = node_matrix.shape[0], node_matrix.shape[1]
  36.  
     
  37.  
    # 将batchsize个子图合并成一个图
  38.  
    node_matrix = node_matrix.reshape(batchsize * n_nodes, -1)
  39.  
    # 将edge_idx和target_idx展平
  40.  
    flattened_edge_idx, flattened_target_idx = _flat_edges(edge_idx, target_idx, batch_size)
  41.  
     
  42.  
    node_matrix = torch.tensor(node_matrix)
  43.  
    flattened_edge_idx = torch.tensor(flattened_edge_idx)
  44.  
    flattened_target_idx = torch.tensor(flattened_target_idx)
  45.  
     
  46.  
    x = self.gat_conv1(node_matrix, flattened_edge_idx)
  47.  
    x = self.gat_conv2(x, flattened_edge_idx)
  48.  
     
  49.  
    target_gat_feature = gat_feature[flattened_target_idx]
  50.  
    GAT_Enc = self.leaky_relu(self.fc(target_gat_feature))
  51.  
    return GAT_Enc
学新通

2)参考Transformer对GAT做改进:在论文“Is Graph Structure Necessary for Multi-hop Question Answering?" 中作者提出,GAT的模型结构可以完全用Transformers来替代,那我们也可以尝试引入residual 和 layer norm等操作将GAT改进成Transformer的Encoder Module,并且可以实现多个Module叠加。实际测试确实会有效果提升,residual可以加强对target自身的信息,layer norm可以消除输入长度的影响,对所有attention value做归一化,能明显加快收敛速度,但对于精度上一般只会有较小提升。

学新通

 Tranformer的Encoder Module有两种模式如上图所示,这里采用POST_LN 模式,并用GAT layer替换attention layer。

  1.  
    class GAT_Layer(nn.Module):
  2.  
    def __init__(self, args):
  3.  
    super(TransGAT_Layer, self).__init__()
  4.  
    self.args = args
  5.  
    self.gat_conv = GATv2Conv(self.args.num_gat_heads * self.args.encoder_size, self.args.encoder_size, heads=self.args.num_gat_heads, concat=True, dropout=0.0) # (96,3*32,heads=3)
  6.  
    self.dropout1 = torch.nn.Dropout(0.1)
  7.  
    self.dropout2 = torch.nn.Dropout(0.1)
  8.  
    self.dropout3 = torch.nn.Dropout(0.1)
  9.  
    layer_norm_eps = 1e-5
  10.  
    self.norm1 = torch.nn.LayerNorm(self.args.num_gat_heads * self.args.encoder_size, eps=layer_norm_eps)
  11.  
    self.norm2 = torch.nn.LayerNorm(self.args.num_gat_heads * self.args.encoder_size, eps=layer_norm_eps)
  12.  
    self.leaky_relu = torch.nn.LeakyReLU(0.3)
  13.  
    self.linear1 = torch.nn.Linear(self.args.num_gat_heads * self.args.encoder_size, self.args.num_gat_heads * self.args.encoder_size)
  14.  
    self.linear2 = torch.nn.Linear(self.args.num_gat_heads * self.args.encoder_size, self.args.num_gat_heads * self.args.encoder_size)
  15.  
     
  16.  
    def sa_block(self, gat_feature, edge_idx):
  17.  
    batch_size = gat_feature.shape[0]
  18.  
    x = []
  19.  
    for i in range(batch_size):
  20.  
    x_tmp = self.gat_conv(gat_feature[i], edge_idx)
  21.  
    x.append(x_tmp.unsqueeze(0))
  22.  
    x = torch.cat(x, dim=0)
  23.  
    return self.dropout1(x)
  24.  
     
  25.  
    def ff_block(self, x):
  26.  
    x = self.linear2(self.dropout2(self.leaky_relu(self.linear1(x))))
  27.  
    return self.dropout3(x)
  28.  
     
  29.  
    def forward(self, x, edge_idx):
  30.  
    # POST_LN Transformer
  31.  
    x = self.norm1(x self.sa_block(x, edge_idx))
  32.  
    x = self.norm2(x self.ff_block(x))
  33.  
    return x
  34.  
     
  35.  
    class GAT_Encoder(nn.Module):
  36.  
    def __init__(self, args):
  37.  
    # initialize layer
  38.  
    self.MultiGAT = nn.ModuleList([GAT_Layer(args) for _ in range(num_layers)])
  39.  
     
  40.  
    def GAT_Interaction(self, node_matrix, edge_idx, target_idx):
  41.  
    # target_idx 通常取0,表示输入的node_matrix中,索引0对应的数据为target车的历史
  42.  
    # 轨迹,其他索引对应的是邻近车辆的轨迹。
  43.  
    x = node_matrix
  44.  
    # repeat x for 3 times for doing residual. 3 equals to the number of heads.
  45.  
    x = x.repeat(1, 3)
  46.  
    for layer in self.MultiGAT:
  47.  
    x = layer(x, edge_idx)
  48.  
    target_gat_feature = x[target_idx]
  49.  
    GAT_Enc = self.leaky_relu(self.fc(target_gat_feature))
  50.  
    return GAT_Enc
学新通

3)GATconv也可以直接替换为GATv2Conv,其他参数和原来一样,在论文  “How Attentive are Graph Attention Networks?” 中被证明可以提高性能。

The GATv2 operator from the “How Attentive are Graph Attention Networks?” paper, which fixes the static attention problem of the standard GATConv layer: since the linear layers in the standard GAT are applied right after each other, the ranking of attended nodes is unconditioned on the query node.

3.2 Self Attention

在轨迹预测任务中,我们不需要像Transformer那样依照seq-to-seq的思路recursive地输出单词,因此我们只需要用到Transformer的encoder,也就是self-attention同时提取信息又输出轨迹。这就和BERT的思想比较类似,因此可以把轨迹预测看成一个MLM(Masked-Language Modeling) 任务,Scene Transformer这篇论文就是从这个角度建模的。

如果用Self Attention对多车交互进行建模,那么每个车辆结点相当于一个单词,没有位置信息,对每个节点都用其他所有节点算出一个attention值,最后再取出属于目标车辆的那一份attention值作为多车交互中提取出的特征信息,因此Self Attention又可以称为Global Graph Attention;Self Attention也可以将每个车的轨迹当作一个句子,做时序上的信息提取(encoder)和预测(decoder)。

Self Attention可以处理不同长度的输入,一般将每个batch内的数据都padding到该batch中数据的最大长度,然后输入到attention网络后传入layer norm层做处理,对attention值做标准化,消除长度的影响,并且通过attention mask告诉网络每条数据的实际长度,这样只需要每个batch内部的维度大小保持一致就能输入不同长度的数据了。该处理由merge_tensors函数和attention_mask的定义实现。

使用方法如下, 输出维度:[batch_size, max_poly_num, hidden_size]。

  1.  
    class GlobalGraph_Encoder(nn.Module):
  2.  
    def __init__(self, args, device):
  3.  
    self.global_graph = GlobalGraph(args.hidden_size)
  4.  
     
  5.  
    def forward(self, x):
  6.  
    # 将每条数据按照当前batch中的最大长度做padding,并用inputs_length记录每条数据实际长度
  7.  
    inputs, inputs_lengths = merge_tensors(x, device, args.hidden_size)
  8.  
     
  9.  
    # 用attentin_mask来让网络知道每条数据实际长度
  10.  
    max_poly_num = max(inputs_lengths)
  11.  
    attention_mask = torch.zeros([batch_size, max_poly_num, max_poly_num], device=device)
  12.  
    for i, length in enumerate(inputs_lengths):
  13.  
    attention_mask[i][:length][:length].fill_(1)
  14.  
     
  15.  
    hidden_states = self.global_graph(inputs, attention_mask)
  16.  
     
  17.  
    return hidden_states
  18.  
     
  19.  
    def merge_tensors(tensors: List[torch.Tensor], device, hidden_size):
  20.  
    lengths = []
  21.  
    for tensor in tensors:
  22.  
    lengths.append(tensor.shape[0] if tensor is not None else 0)
  23.  
    max_length = max(lengths)
  24.  
    res = torch.zeros([len(tensors), max_length, hidden_size], device=device)
  25.  
    for i, tensor in enumerate(tensors):
  26.  
    if tensor is not None:
  27.  
    res[i][:tensor.shape[0]] = tensor
  28.  
    return res, lengths
学新通

Global Graph(self attention):

  1.  
    class GlobalGraph(nn.Module):
  2.  
    def __init__(self, hidden_size, attention_head_size=None, num_attention_heads=1):
  3.  
    super(GlobalGraph, self).__init__()
  4.  
    self.num_attention_heads = num_attention_heads
  5.  
    self.attention_head_size = hidden_size // num_attention_heads if attention_head_size is None else attention_head_size
  6.  
    self.all_head_size = self.num_attention_heads * self.attention_head_size
  7.  
     
  8.  
    self.num_qkv = 1
  9.  
     
  10.  
    self.query = nn.Linear(hidden_size, self.all_head_size * self.num_qkv)
  11.  
    self.key = nn.Linear(hidden_size, self.all_head_size * self.num_qkv)
  12.  
    self.value = nn.Linear(hidden_size, self.all_head_size * self.num_qkv)
  13.  
     
  14.  
    def get_extended_attention_mask(self, attention_mask):
  15.  
    """
  16.  
    1 in attention_mask stands for doing attention, 0 for not doing attention.
  17.  
    After this function, 1 turns to 0, 0 turns to -10000.0
  18.  
    Because the -10000.0 will be fed into softmax and -10000.0 can be thought as 0 in softmax.
  19.  
    """
  20.  
    extended_attention_mask = attention_mask.unsqueeze(1)
  21.  
    extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
  22.  
    return extended_attention_mask
  23.  
     
  24.  
    def transpose_for_scores(self, x):
  25.  
    sz = x.size()[:-1] (self.num_attention_heads,
  26.  
    self.attention_head_size)
  27.  
    # (batch, max_vector_num, head, head_size)
  28.  
    x = x.view(*sz)
  29.  
    # (batch, head, max_vector_num, head_size)
  30.  
    return x.permute(0, 2, 1, 3)
  31.  
     
  32.  
    def forward(self, hidden_states, attention_mask=None, return_scores=False):
  33.  
    mixed_query_layer = self.query(hidden_states)
  34.  
    mixed_key_layer = nn.functional.linear(hidden_states, self.key.weight)
  35.  
    mixed_value_layer = self.value(hidden_states)
  36.  
     
  37.  
    query_layer = self.transpose_for_scores(mixed_query_layer)
  38.  
    key_layer = self.transpose_for_scores(mixed_key_layer)
  39.  
    value_layer = self.transpose_for_scores(mixed_value_layer)
  40.  
     
  41.  
    attention_scores = torch.matmul(
  42.  
    query_layer / math.sqrt(self.attention_head_size), key_layer.transpose(-1, -2))
  43.  
    if attention_mask is not None:
  44.  
    attention_scores = attention_scores self.get_extended_attention_mask(attention_mask)
  45.  
    attention_probs = nn.Softmax(dim=-1)(attention_scores)
  46.  
    context_layer = torch.matmul(attention_probs, value_layer)
  47.  
    context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
  48.  
    new_context_layer_shape = context_layer.size()[
  49.  
    :-2] (self.all_head_size,)
  50.  
    context_layer = context_layer.view(*new_context_layer_shape)
  51.  
    if return_scores:
  52.  
    attention_probs = torch.squeeze(attention_probs, dim=1)
  53.  
    return context_layer, attention_probs
  54.  
    return context_layer
学新通

3.3 Cross Attention

与Self Attention相比只换了函数的输入,forward输入参数中区分了query和key的对象,而在Self Attention中这两者是同一个对象。Cross Attention的query来自Encoder的输出,Key和Value的对象来自Encoder的输入,尤其两者维度可能不同,所以当维度不同时需要额外指定query_hidden_sze和key_hidden_size。

使用方法:

  1.  
    self.cross_attention = CrossAttention(args.hidden_size)
  2.  
     
  3.  
    hidden_attention = self.cross_attention(hidden_state, inputs, attention_mask)

 Cross Attention网络:

  1.  
    class CrossAttention(GlobalGraph):
  2.  
    def __init__(self, hidden_size, attention_head_size=None, num_attention_heads=1, key_hidden_size=None, query_hidden_size=None):
  3.  
    super(CrossAttention, self).__init__(hidden_size, attention_head_size, num_attention_heads)
  4.  
    if query_hidden_size is not None:
  5.  
    self.query = nn.Linear(query_hidden_size, self.all_head_size * self.num_qkv)
  6.  
    if key_hidden_size is not None:
  7.  
    self.key = nn.Linear(key_hidden_size, self.all_head_size * self.num_qkv)
  8.  
    self.value = nn.Linear(key_hidden_size, self.all_head_size * self.num_qkv)
  9.  
     
  10.  
    def forward(self, hidden_states_query, hidden_states_key=None, attention_mask=None, return_scores=False):
  11.  
    mixed_query_layer = self.query(hidden_states_query)
  12.  
    mixed_key_layer = self.key(hidden_states_key)
  13.  
    mixed_value_layer = self.value(hidden_states_key)
  14.  
     
  15.  
    query_layer = self.transpose_for_scores(mixed_query_layer)
  16.  
    key_layer = self.transpose_for_scores(mixed_key_layer)
  17.  
    value_layer = self.transpose_for_scores(mixed_value_layer)
  18.  
     
  19.  
    attention_scores = torch.matmul(
  20.  
    query_layer / math.sqrt(self.attention_head_size), key_layer.transpose(-1, -2))
  21.  
    if attention_mask is not None:
  22.  
    assert hidden_states_query.shape[1] == attention_mask.shape[1] \
  23.  
    and hidden_states_key.shape[1] == attention_mask.shape[2]
  24.  
    attention_scores = attention_scores self.get_extended_attention_mask(attention_mask)
  25.  
    attention_probs = nn.Softmax(dim=-1)(attention_scores)
  26.  
    context_layer = torch.matmul(attention_probs, value_layer)
  27.  
    context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
  28.  
    new_context_layer_shape = context_layer.size()[
  29.  
    :-2] (self.all_head_size,)
  30.  
    context_layer = context_layer.view(*new_context_layer_shape)
  31.  
    if return_scores:
  32.  
    return context_layer, torch.squeeze(attention_probs, dim=1)
  33.  
    return context_layer
学新通

4. Reference

"Attention Is All You Need"

"Graph Attention Networks"

Self Attention 自注意力机制 - 云 社区 - 腾讯云

简析Transformer和GAT在自注意力运用上的相似性 - 知乎

深入理解图注意力机制(Graph Attention Network) – 闪念基因 – 个人技术分享

详解Self-Attention和Multi-Head Attention - 张浩在路上

如何理解attention中的Q,K,V? - 知乎

Attention and the Transformer · Deep Learning

2020 On Layer Normalization in the Transformer Architecture

2020 ECCV "Learning Lane Graph Representations for Motion Forecasting"

这篇好文章是转载于:学新通技术网

  • 版权申明: 本站部分内容来自互联网,仅供学习及演示用,请勿用于商业和其他非法用途。如果侵犯了您的权益请与我们联系,请提供相关证据及您的身份证明,我们将在收到邮件后48小时内删除。
  • 本站站名: 学新通技术网
  • 本文地址: /boutique/detail/tanhgabiic
系列文章
更多 icon
同类精品
更多 icon
继续加载