详细讲一下PYG 里面的torch_geometric.nn.conv.transformer_conv函数

1.首先先讲一下代码

这是官方给的代码:torch_geometric.nn.conv.transformer_conv — pytorch_geometric documentation

import math
import typing
from typing import Optional, Tuple, Union

import torch
import torch.nn.functional as F
from torch import Tensor

from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.dense.linear import Linear
from torch_geometric.typing import (
    Adj,
    NoneType,
    OptTensor,
    PairTensor,
    SparseTensor,
)
from torch_geometric.utils import softmax

if typing.TYPE_CHECKING:
    from typing import overload
else:
    from torch.jit import _overload_method as overload


[docs]class TransformerConv(MessagePassing):
    r"""The graph transformer operator from the `"Masked Label Prediction:
    Unified Message Passing Model for Semi-Supervised Classification"
    <https://arxiv.org/abs/2009.03509>`_ paper.

    .. math::
        \mathbf{x}^{\prime}_i = \mathbf{W}_1 \mathbf{x}_i +
        \sum_{j \in \mathcal{N}(i)} \alpha_{i,j} \mathbf{W}_2 \mathbf{x}_{j},

    where the attention coefficients :math:`\alpha_{i,j}` are computed via
    multi-head dot product attention:

    .. math::
        \alpha_{i,j} = \textrm{softmax} \left(
        \frac{(\mathbf{W}_3\mathbf{x}_i)^{\top} (\mathbf{W}_4\mathbf{x}_j)}
        {\sqrt{d}} \right)

    Args:
        in_channels (int or tuple): Size of each input sample, or :obj:`-1` to
            derive the size from the first input(s) to the forward method.
            A tuple corresponds to the sizes of source and target
            dimensionalities.
        out_channels (int): Size of each output sample.
        heads (int, optional): Number of multi-head-attentions.
            (default: :obj:`1`)
        concat (bool, optional): If set to :obj:`False`, the multi-head
            attentions are averaged instead of concatenated.
            (default: :obj:`True`)
        beta (bool, optional): If set, will combine aggregation and
            skip information via

            .. math::
                \mathbf{x}^{\prime}_i = \beta_i \mathbf{W}_1 \mathbf{x}_i +
                (1 - \beta_i) \underbrace{\left(\sum_{j \in \mathcal{N}(i)}
                \alpha_{i,j} \mathbf{W}_2 \vec{x}_j \right)}_{=\mathbf{m}_i}

            with :math:`\beta_i = \textrm{sigmoid}(\mathbf{w}_5^{\top}
            [ \mathbf{W}_1 \mathbf{x}_i, \mathbf{m}_i, \mathbf{W}_1
            \mathbf{x}_i - \mathbf{m}_i ])` (default: :obj:`False`)
        dropout (float, optional): Dropout probability of the normalized
            attention coefficients which exposes each node to a stochastically
            sampled neighborhood during training. (default: :obj:`0`)
        edge_dim (int, optional): Edge feature dimensionality (in case
            there are any). Edge features are added to the keys after
            linear transformation, that is, prior to computing the
            attention dot product. They are also added to final values
            after the same linear transformation. The model is:

            .. math::
                \mathbf{x}^{\prime}_i = \mathbf{W}_1 \mathbf{x}_i +
                \sum_{j \in \mathcal{N}(i)} \alpha_{i,j} \left(
                \mathbf{W}_2 \mathbf{x}_{j} + \mathbf{W}_6 \mathbf{e}_{ij}
                \right),

            where the attention coefficients :math:`\alpha_{i,j}` are now
            computed via:

            .. math::
                \alpha_{i,j} = \textrm{softmax} \left(
                \frac{(\mathbf{W}_3\mathbf{x}_i)^{\top}
                (\mathbf{W}_4\mathbf{x}_j + \mathbf{W}_6 \mathbf{e}_{ij})}
                {\sqrt{d}} \right)

            (default :obj:`None`)
        bias (bool, optional): If set to :obj:`False`, the layer will not learn
            an additive bias. (default: :obj:`True`)
        root_weight (bool, optional): If set to :obj:`False`, the layer will
            not add the transformed root node features to the output and the
            option  :attr:`beta` is set to :obj:`False`. (default: :obj:`True`)
        **kwargs (optional): Additional arguments of
            :class:`torch_geometric.nn.conv.MessagePassing`.
    """
    _alpha: OptTensor

    def __init__(
        self,
        in_channels: Union[int, Tuple[int, int]],
        out_channels: int,
        heads: int = 1,
        concat: bool = True,
        beta: bool = False,
        dropout: float = 0.,
        edge_dim: Optional[int] = None,
        bias: bool = True,
        root_weight: bool = True,
        **kwargs,
    ):
        kwargs.setdefault('aggr', 'add')
        super().__init__(node_dim=0, **kwargs)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.heads = heads
        self.beta = beta and root_weight
        self.root_weight = root_weight
        self.concat = concat
        self.dropout = dropout
        self.edge_dim = edge_dim
        self._alpha = None

        if isinstance(in_channels, int):
            in_channels = (in_channels, in_channels)

        self.lin_key = Linear(in_channels[0], heads * out_channels)
        self.lin_query = Linear(in_channels[1], heads * out_channels)
        self.lin_value = Linear(in_channels[0], heads * out_channels)
        if edge_dim is not None:
            self.lin_edge = Linear(edge_dim, heads * out_channels, bias=False)
        else:
            self.lin_edge = self.register_parameter('lin_edge', None)

        if concat:
            self.lin_skip = Linear(in_channels[1], heads * out_channels,
                                   bias=bias)
            if self.beta:
                self.lin_beta = Linear(3 * heads * out_channels, 1, bias=False)
            else:
                self.lin_beta = self.register_parameter('lin_beta', None)
        else:
            self.lin_skip = Linear(in_channels[1], out_channels, bias=bias)
            if self.beta:
                self.lin_beta = Linear(3 * out_channels, 1, bias=False)
            else:
                self.lin_beta = self.register_parameter('lin_beta', None)

        self.reset_parameters()

[docs]    def reset_parameters(self):
        super().reset_parameters()
        self.lin_key.reset_parameters()
        self.lin_query.reset_parameters()
        self.lin_value.reset_parameters()
        if self.edge_dim:
            self.lin_edge.reset_parameters()
        self.lin_skip.reset_parameters()
        if self.beta:
            self.lin_beta.reset_parameters()

    @overload
    def forward(
        self,
        x: Union[Tensor, PairTensor],
        edge_index: Adj,
        edge_attr: OptTensor = None,
        return_attention_weights: NoneType = None,
    ) -> Tensor:
        pass

    @overload
    def forward(  # noqa: F811
        self,
        x: Union[Tensor, PairTensor],
        edge_index: Tensor,
        edge_attr: OptTensor = None,
        return_attention_weights: bool = None,
    ) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
        pass

    @overload
    def forward(  # noqa: F811
        self,
        x: Union[Tensor, PairTensor],
        edge_index: SparseTensor,
        edge_attr: OptTensor = None,
        return_attention_weights: bool = None,
    ) -> Tuple[Tensor, SparseTensor]:
        pass

[docs]    def forward(  # noqa: F811
        self,
        x: Union[Tensor, PairTensor],
        edge_index: Adj,
        edge_attr: OptTensor = None,
        return_attention_weights: Optional[bool] = None,
    ) -> Union[
            Tensor,
            Tuple[Tensor, Tuple[Tensor, Tensor]],
            Tuple[Tensor, SparseTensor],
    ]:
        r"""Runs the forward pass of the module.

        Args:
            x (torch.Tensor or (torch.Tensor, torch.Tensor)): The input node
                features.
            edge_index (torch.Tensor or SparseTensor): The edge indices.
            edge_attr (torch.Tensor, optional): The edge features.
                (default: :obj:`None`)
            return_attention_weights (bool, optional): If set to :obj:`True`,
                will additionally return the tuple
                :obj:`(edge_index, attention_weights)`, holding the computed
                attention weights for each edge. (default: :obj:`None`)
        """
        H, C = self.heads, self.out_channels

        if isinstance(x, Tensor):
            x = (x, x)

        query = self.lin_query(x[1]).view(-1, H, C)
        key = self.lin_key(x[0]).view(-1, H, C)
        value = self.lin_value(x[0]).view(-1, H, C)

        # propagate_type: (query: Tensor, key:Tensor, value: Tensor,
        #                  edge_attr: OptTensor)
        out = self.propagate(edge_index, query=query, key=key, value=value,
                             edge_attr=edge_attr)

        alpha = self._alpha
        self._alpha = None

        if self.concat:
            out = out.view(-1, self.heads * self.out_channels)
        else:
            out = out.mean(dim=1)

        if self.root_weight:
            x_r = self.lin_skip(x[1])
            if self.lin_beta is not None:
                beta = self.lin_beta(torch.cat([out, x_r, out - x_r], dim=-1))
                beta = beta.sigmoid()
                out = beta * x_r + (1 - beta) * out
            else:
                out = out + x_r

        if isinstance(return_attention_weights, bool):
            assert alpha is not None
            if isinstance(edge_index, Tensor):
                return out, (edge_index, alpha)
            elif isinstance(edge_index, SparseTensor):
                return out, edge_index.set_value(alpha, layout='coo')
        else:
            return out

    def message(self, query_i: Tensor, key_j: Tensor, value_j: Tensor,
                edge_attr: OptTensor, index: Tensor, ptr: OptTensor,
                size_i: Optional[int]) -> Tensor:

        if self.lin_edge is not None:
            assert edge_attr is not None
            edge_attr = self.lin_edge(edge_attr).view(-1, self.heads,
                                                      self.out_channels)
            key_j = key_j + edge_attr

        alpha = (query_i * key_j).sum(dim=-1) / math.sqrt(self.out_channels)
        alpha = softmax(alpha, index, ptr, size_i)
        self._alpha = alpha
        alpha = F.dropout(alpha, p=self.dropout, training=self.training)

        out = value_j
        if edge_attr is not None:
            out = out + edge_attr

        out = out * alpha.view(-1, self.heads, 1)
        return out

    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}({self.in_channels}, '
                f'{self.out_channels}, heads={self.heads})')

2.详细解释一下

几个重要的参数

in_channels (int or tuple): Size of each input sample, or :obj:`-1` to derive the size from the first input(s) to the forward method. A tuple corresponds to the sizes of source and target dimensionalities.

out_channels (int): Size of each output sample.

heads (int, optional): Number of multi-head-attentions. (default: :obj:`1`)

怎么理解这几个参数?

 

  • in_channels 表示每个输入样本的大小。如果设置为整数,则表示所有输入样本的大小相同;如果设置为 -1,则表示输入样本的大小将从 forward 方法的第一个输入中推导出来;如果设置为元组,则表示输入样本的大小对应于源维度和目标维度的大小。

  • out_channels 表示每个输出样本的大小,即经过卷积操作后产生的特征向量的维度大小。

 

当使用 tg.nn.TransformerConv 时,可以通过以下方式理解 in_channelsout_channels

假设我们有一个图数据集,每个节点都有一个 10 维的特征向量表示。那么在这种情况下:

  • 如果我们想将每个节点的特征向量作为输入,然后使用 tg.nn.TransformerConv 进行卷积操作,那么 in_channels 应该设置为 10,表示每个输入样本的大小为 10。

  • 假设我们想将节点的特征向量转换为一个 16 维的特征向量,那么 out_channels 应该设置为 16,表示每个输出样本的大小为 16,即经过卷积操作后每个节点的特征向量将变为 16 维。

  • tg.nn.TransformerConv 中,heads 参数表示多头注意力的数量。举个例子,如果 heads 参数设置为 4,那么模型将学习 4 组注意力权重,每组权重都用于计算输入的不同子空间的注意力,然后将这些头的输出进行合并以产生最终的输出。

 举个整体的例子

我们有一个输入张量 x,它的形状是 (batch_size, seq_length, input_dim),其中:

  • batch_size 表示批量大小;
  • seq_length 表示序列长度;
  • input_dim 表示输入特征的维度。

现在假设我们使用了 tg.nn.TransformerConv,并设置 heads=2,那么模型将学习两组注意力权重,每组用于计算不同的注意力。输出张量的形状将取决于 out_channels 参数,我们假设 out_channels=64

import torch
import torch_geometric.nn as tg

# 假设输入张量的形状是 (batch_size, seq_length, input_dim)
x = torch.randn(32, 10, 128)  # 32 个样本,每个样本有 10 个时间步,每个时间步有 128 个特征

# 创建 TransformerConv 模型,设置 heads=2,out_channels=64
conv_layer = tg.nn.TransformerConv(in_channels=128, out_channels=64, heads=2)

# 使用模型进行前向传播
output = conv_layer(x)

print("输出张量的形状:", output.shape)

 2.1将特征映射到键值对中

在这里,通过线性变换层 Linear,输入特征被转换成了键(key)、查询(query)和数值(value)的表示形式,以便用于多头自注意力机制。

具体来说:

  • self.lin_key 用于将输入特征(in_channels[0])映射到键的表示形式。
  • self.lin_query 用于将输入特征(in_channels[1])映射到查询的表示形式。
  • self.lin_value 用于将输入特征(in_channels[0])映射到数值的表示形式。

 具体地,假设输入特征的维度是 (batch_size, num_nodes, in_channels),其中 batch_size 是批量大小,num_nodes 是节点数,in_channels 是输入特征的通道数。在映射到键的过程中,线性变换层的权重矩阵将是一个维度为 (in_channels, heads * out_channels) 的矩阵,其中 heads 是注意力头的数量,out_channels 是输出特征的通道数。因此,通过矩阵乘法运算,输入特征将被映射到一个新的特征空间,其维度为 (batch_size, num_nodes, heads, out_channels)。在这个新的特征空间中,每个节点的每个头都有一个键表示。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/606590.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

nestjs 全栈进阶--Module和Provider的循环依赖

视频教程 21_nest中的循环依赖_哔哩哔哩_bilibili 1. 循环依赖 当两个类相互依赖时&#xff0c;就会发生循环依赖。比如 A 类需要 B 类&#xff0c;B 类也需要 A 类。Nest中 模块之间和 提供器之间也可能会出现循环依赖。 nest new dependency -p pnpm nest g res aaa --n…

【Java EE】网络原理——UDP

目录 1.应用层 2.传输层 2.1端口号 2.1.1端口号的范围划分 2.1.2一个端口号可以被多个进程绑定吗&#xff1f; 2.1.3一个进程可以绑定多个端口号吗&#xff1f; 3.UDP协议 3.1UDP的格式 3.1.1 UDP的源端口号 3.1.2 UDP的目的端口号 3.1.3 UDP长度 3.1.4UDP校验和 3…

springboot项目中前端页面无法加载怎么办

在springboot前后端分离的项目中&#xff0c;经常会出现前端页面无法加载的情况&#xff08;比如&#xff1a;前端页面为空白页&#xff0c;或者出现404&#xff09;&#xff0c;该怎么办&#xff1f;&#xff1f;&#xff1f; 一个简单有效的方法&#xff1a;&#xff1a; 第…

24 | MySQL是怎么保证主备一致的?

MySQL 主备的基本原理 内部流程 备库 B 跟主库 A 之间维持了一个长连接。主库 A 内部有一个线程,专门用于服务备库 B 的这个长连接。一个事务日志同步的完整过程是这样的: 在备库 B 上通过 change master 命令,设置主库 A 的 IP、端口、用户名、密码,以及要从哪个位置开始…

钉钉群定时发送消息1.0软件【附源码】

内容目录 一、详细介绍二、效果展示1.部分代码2.效果图展示 三、学习资料下载 一、详细介绍 有时候需要在钉钉群里提醒一些消息。要通知的群成员又不方便用定时钉的功能&#xff0c;所以写了这么一个每日定时推送群消息的工具。 易语言程序&#xff0c;附上源码与模块&#x…

【记录42】centos 7.6安装nginx教程详细教程

环境&#xff1a;腾讯云centos7.6 需求&#xff1a;安装nginx-1.24.0 1. 切入home文件 cd home 2. 创建nginx文件 mkdir nginx 3. 切入nginx文件 cd nginx 4. 下载nginx安装包 wget https://nginx.org/download/nginx-1.24.0.tar.gz 5. 解压安装包 tar -zxvf nginx-1.24.0.…

ESD静电问题 | 选型TVS单向还是双向?

【转自微信公众号&#xff1a;Amazing晶炎科技】

Mysql进阶-索引篇

Mysql进阶 存储引擎前言特点对比 索引介绍常见的索引结构索引分类索引语法sql分析索引使用原则索引失效的几种情况sql提示覆盖索引前缀索引索引设计原则 存储引擎 前言 Mysql的体系结构&#xff1a; 连接层 最上层是一些客户端和链接服务&#xff0c;主要完成一些类似于连接…

C语言例题38、有n个人围成一圈,顺序排号。从第一个人开始报数(从1到3报数),凡报到3的人退出圈子,最后留下来的是原来第几号人员?

#include <stdio.h> #define MAX_CALLER 3void main() {int j 0;int p_total;//人数int p_caller 0;//每3人循环计数&#xff1a;1,2,3int p_exit 0; //退出游戏的人数int people[255] {0};//参与游戏人员名单printf("请输入参与游戏人数&#xff1a;");s…

CCF-Csp算法能力认证,202206-1归一化处理(C++)含解析

前言 推荐书目&#xff0c;在这里推荐那一本《算法笔记》&#xff08;胡明&#xff09;&#xff0c;需要PDF的话&#xff0c;链接如下 「链接&#xff1a;https://pan.xunlei.com/s/VNvz4BUFYqnx8kJ4BI4v1ywPA1?pwd6vdq# 提取码&#xff1a;6vdq”复制这段内容后打开手机迅雷…

Macbook pnpm 安装 node-sass 报错(node-gyp)

换了 Macbook M3 Pro 后安装项目依赖时报错&#xff0c;提示 node-sass 安装出错。 &#xff08;此外&#xff0c;ValueError: invalid mode: rU while trying to load binding.gyp 也是类似原因。只需要确保 node-gyp 运行条件就可以&#xff09; 原因是 node-gyp 运行环境缺…

手写SpringBoot核心功能流程

本文通过手写模拟实现一个简易版的Spring Boot 程序&#xff0c;让大家能以非常简单的方式知道Spring Boot大概的工作流程。 工程依赖 创建maven工程&#xff0c;并创建两个module springboot模块&#xff1a;手写模拟springboot框架的源码实现 test模块&#xff1a;业务系统…

提升工作效率,用ONLYOFFICE打造高效团队协作环境

作为一名深耕技术领域已有六七年的开发者&#xff0c;同时又是断断续续进行技术创作将近六年的一个小小作者&#xff0c;我在工作和日常生活中&#xff0c;使用过各色各样的软件。 而在最近几年&#xff0c;一款名为ONLYOFFICE的开源办公套件逐渐走进并融入我的工作与生活&…

使用Vue连接Mqtt实现主题的订阅及消息发布

效果如下&#xff1a; 直接贴代码&#xff0c;本地创建一个html文件将以下内容贴入即可 <!DOCTYPE html> <html lang"zh-CN"> <head><meta charset"UTF-8"><meta name"viewport" content"widthdevice-width, …

为什么职场关系越来越冷漠?

不知道从什么时候开始&#xff0c;我们的职场关系变得越来越冷漠了。 早上上班打卡的时候&#xff0c;一个个都低着头&#xff0c;眼神紧紧盯着手机&#xff0c;生怕错过什么重要的信息&#xff1b; 下班后大家一哄而散&#xff0c;各自抱着手机“享受”生活&#xff0c;谁也…

如何添加、编辑、调整WordPress菜单

我们最近在使用WordPress建站建设公司网站。我们是使用的hostease的主机产品建设的WordPress网站。在建设网站使用遇到了一些WordPress菜单使用方面的问题。好在hostease提供了不少帮助。 下面把WordPress菜单使用心得分享一下。 本文将详细介绍WordPress菜单的各种功能&#x…

Total Store Orderand(TSO) the x86 MemoryModel

一种广泛实现的内存一致性模型是总store顺序 (total store order, TSO)。 TSO 最早由 SPARC 引入&#xff0c;更重要的是&#xff0c;它似乎与广泛使用的 x86 架构的内存一致性模型相匹配。RISC-V 还支持 TSO 扩展 RVTSO&#xff0c;部分是为了帮助移植最初为 x86 或 SPARC 架…

1-3ARM_GD32点亮LED灯

简介&#xff1a; 最多可支持 112 个通用 I/O 引脚(GPIO)&#xff0c;分别为 PA0 ~ PA15&#xff0c;PB0 ~ PB15&#xff0c;PC0 ~ PC15&#xff0c;PD0 ~ PD15&#xff0c;PE0 ~ PE15&#xff0c;PF0 ~ PF15 和 PG0 ~ PG15&#xff0c;各片上设备用其来实现逻辑输入/输出功能。…

使用DBeaver连接postgreSql提示缺少驱动

重新安装电脑之后用dbeaver链接数据库的时候&#xff0c;链接PG库一直提示缺少驱动&#xff0c;当选择下载驱动的时候又非常非常慢经常失败&#xff0c;尝试了一下更改源然后下载库驱动就非常快了&#xff0c;当然也包括dbeaver的自动更新。 方法&#xff1a;点击菜单栏【窗口…

霸榜!近期不容错过的3个AI开源项目,来了

在人工智能领域的迅速发展下&#xff0c;各种AI开源项目如雨后春笋般涌现&#xff0c;今天就来为大家介绍近期三个热门的AI开源项目&#xff0c;它们不仅技术前沿&#xff0c;而且非常实用&#xff0c;对于技术爱好者和业界专家来说&#xff0c;绝对不容错过。 一键创作漫画和视…
最新文章