from typing import Union, Callable

import torch
from torch import Tensor
from torch.nn import Parameter, Linear, RNNCell

from torch_geometric.typing import OptTensor, OptPairTensor, Adj, Size
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.inits import reset, uniform, zeros


class DoubleRNNConv(MessagePassing):
    """Message passing network for recurrent graph convolutions.

    Parameters:
    - channels: size of the hidden state embeddings; must equal the hidden
      size of the supplied RNN cell.
    - rnn: predefined recurrent cell to apply (an LSTM or GRU cell works
      as well).
    - aggr: aggregation scheme used to combine the messages from the
      neighbors of a node.
    - root_weight: if True, learn a weight matrix for the central node
      features (theta).
    - bias: if True, learn an additive bias for the output.
    - **kwargs: additional arguments of the MessagePassing base class.
    """
    def __init__(self, channels: int, rnn: Callable, aggr: str = 'mean',
                 root_weight: bool = True, bias: bool = True, **kwargs):
        super(DoubleRNNConv, self).__init__(aggr=aggr, **kwargs)
        self.channels = channels
        self.rnn = rnn
        self.aggr = aggr

        if root_weight:
            self.root = Parameter(torch.empty(channels, channels))
        else:
            self.register_parameter('root', None)

        if bias:
            self.bias = Parameter(torch.empty(channels))
        else:
            self.register_parameter('bias', None)

        self.reset_parameters()

    def reset_parameters(self):
        reset(self.rnn)
        if self.root is not None:
            uniform(self.root.size(0), self.root)
        zeros(self.bias)

    def forward(self, x: Union[Tensor, OptPairTensor], edge_index: Adj,
                edge_attr: OptTensor = None, size: Size = None) -> Tensor:
        if isinstance(x, Tensor):
            x: OptPairTensor = (x, x)

        out = self.propagate(edge_index, x=x, edge_attr=edge_attr, size=size)

        # Transform and add the central (root) node features.
        x_r = x[1]
        if x_r is not None and self.root is not None:
            out += torch.matmul(x_r, self.root)

        if self.bias is not None:
            out += self.bias

        return out

    def message(self, x_i: Tensor, x_j: Tensor, edge_attr: Tensor) -> Tensor:
        # Create a pair of messages from two RNN-cell evaluations: the edge
        # features act as the cell input, the node embeddings as the hidden
        # states; the message is their element-wise product.
        h_i = self.rnn(edge_attr, x_i)
        h_j = self.rnn(edge_attr, x_j)
        return h_i * h_j

    def __repr__(self):
        return '{}(In-Out: {})'.format(self.__class__.__name__, self.channels)
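

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original module): a minimal smoke test for
# DoubleRNNConv under assumed dimensions. The toy graph, the helper name
# `_demo_double_rnn_conv`, and the choice of edge features with dimension
# `channels` are all illustrative, not prescribed by the layer itself. The
# RNN cell's hidden size must equal `channels`, and its input size must
# match the edge feature dimension passed to `forward`.
# ---------------------------------------------------------------------------
def _demo_double_rnn_conv() -> Tensor:
    channels = 16
    x = torch.randn(4, channels)                 # 4 nodes with 16 features
    edge_index = torch.tensor([[0, 1, 2, 3],
                               [1, 2, 3, 0]])    # 4 directed edges (a cycle)
    edge_attr = torch.randn(4, channels)         # one feature row per edge
    conv = DoubleRNNConv(channels, rnn=RNNCell(channels, channels))
    return conv(x, edge_index, edge_attr)        # shape: [4, channels]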


class EdgeRNNConv(MessagePassing):
    """Edge-convolutional variant of the recurrent graph convolution.

    The constructor arguments mirror DoubleRNNConv. Note that the supplied
    RNN cell must have a hidden size of 2 * channels, because each message
    concatenates the central node features with the relative features
    [x_i, x_j - x_i] before they are fed to the cell.
    """
    def __init__(self, channels: int, rnn: Callable, aggr: str = 'mean',
                 root_weight: bool = True, bias: bool = True, **kwargs):
        super(EdgeRNNConv, self).__init__(aggr=aggr, **kwargs)
        self.channels = channels
        self.rnn = rnn
        self.nn = Linear(2 * channels, channels)
        self.aggr = aggr

        if root_weight:
            self.root = Parameter(torch.empty(channels, channels))
        else:
            self.register_parameter('root', None)

        if bias:
            self.bias = Parameter(torch.empty(channels))
        else:
            self.register_parameter('bias', None)

        self.reset_parameters()

    def reset_parameters(self):
        reset(self.rnn)
        reset(self.nn)
        if self.root is not None:
            uniform(self.root.size(0), self.root)
        zeros(self.bias)

    def forward(self, x: Union[Tensor, OptPairTensor], edge_index: Adj,
                edge_attr: OptTensor = None, size: Size = None) -> Tensor:
        if isinstance(x, Tensor):
            x: OptPairTensor = (x, x)

        out = self.propagate(edge_index, x=x, edge_attr=edge_attr, size=size)

        # Transform and add the central (root) node features.
        x_r = x[1]
        if x_r is not None and self.root is not None:
            out += torch.matmul(x_r, self.root)

        if self.bias is not None:
            out += self.bias

        return out

    def message(self, x_i: Tensor, x_j: Tensor, edge_attr: Tensor) -> Tensor:
        # Edge-convolutional RNN cell: the edge features are the cell input
        # and the concatenation [x_i, x_j - x_i] (size 2 * channels) is the
        # hidden state; a linear layer maps the result back to `channels`.
        h_i = self.rnn(edge_attr, torch.cat([x_i, x_j - x_i], dim=-1))
        return self.nn(h_i)

    def __repr__(self):
        return '{}(In-Out: {})'.format(self.__class__.__name__, self.channels)
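

# ---------------------------------------------------------------------------
# Usage sketch for EdgeRNNConv (again illustrative, not part of the original
# module). Because `message` concatenates [x_i, x_j - x_i], the hidden state
# handed to the cell has size 2 * channels, so the cell must be constructed
# with hidden_size == 2 * channels.
# ---------------------------------------------------------------------------
def _demo_edge_rnn_conv() -> Tensor:
    channels = 16
    x = torch.randn(4, channels)
    edge_index = torch.tensor([[0, 1, 2, 3],
                               [1, 2, 3, 0]])
    edge_attr = torch.randn(4, channels)
    conv = EdgeRNNConv(channels, rnn=RNNCell(channels, 2 * channels))
    return conv(x, edge_index, edge_attr)        # shape: [4, channels]


if __name__ == '__main__':
    print(_demo_double_rnn_conv().shape)  # torch.Size([4, 16])
    print(_demo_edge_rnn_conv().shape)    # torch.Size([4, 16])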