import torch
import torch.nn as nn
from torch_geometric.nn import GCNConv, GATConv, SAGEConv, GraphConv


class flexGCN(nn.Module):
"""
A Graph Neural Network (GNN) model using configurable convolution and activation layers.
This class defines a GNN that can utilize various graph convolution types and activation functions.
It supports a configurable number of convolutional layers with batch normalization and dropout
for regularization. The model aggregates node features into a single vector per graph using
a fully connected layer.
Attributes:
act (torch.nn.Module): Activation function applied after each convolution.
convs (nn.ModuleList): List of convolutional layers.
bns (nn.ModuleList): List of batch normalization layers applied after each convolution.
dropout (nn.Dropout): Dropout layer applied after activation to prevent overfitting.
fc (torch.nn.Linear): Fully connected layer that aggregates node features into a single vector.
Args:
node_count (int): The number of nodes in each graph.
node_feature_count (int): The number of features each node initially has.
node_embedding_dim (int): The size of the node embeddings (output dimension of the convolutions).
output_dim (int): The size of the output vector, which is the final feature vector for the whole graph.
num_convs (int, optional): Number of convolutional layers in the network. Defaults to 2.
dropout_rate (float, optional): The dropout probability used for regularization. Defaults to 0.2.
conv (str, optional): Type of convolution layer to use. Supported types include 'GCN' for Graph Convolution Network,
'GAT' for Graph Attention Network, 'SAGE' for GraphSAGE, and 'GC' for generic Graph Convolution.
Defaults to 'GC'.
act (str, optional): Type of activation function to use. Supported types include 'relu', 'sigmoid',
'leakyrelu', 'tanh', and 'gelu'. Defaults to 'relu'.
Raises:
ValueError: If an unsupported activation function or convolution type is specified.
Example:
>>> model = flexGCN(node_count=100, node_feature_count=5, node_embedding_dim=64, output_dim=10,
num_convs=3, dropout_rate=0.3, conv='GAT', act='relu')
>>> output = model(input_features, edge_index)
# Where `input_features` is a tensor of shape (batch_size, num_nodes, node_feature_count)
# and `edge_index` is a list of edges in the COO format (2, num_edges).
"""
    def __init__(self, node_count, node_feature_count, node_embedding_dim, output_dim,
                 num_convs=2, dropout_rate=0.2, conv='GC', act='relu'):
        super().__init__()
        act_options = {
            'relu': nn.ReLU(),
            'sigmoid': nn.Sigmoid(),
            'leakyrelu': nn.LeakyReLU(),
            'tanh': nn.Tanh(),
            'gelu': nn.GELU()
        }
        if act not in act_options:
            raise ValueError(f"Invalid activation function string. Choose from {list(act_options.keys())}")
        conv_options = {
            'GCN': GCNConv,
            'GAT': GATConv,
            'SAGE': SAGEConv,
            'GC': GraphConv
        }
        if conv not in conv_options:
            raise ValueError(f"Unknown convolution type. Choose one of: {list(conv_options.keys())}")
        self.act = act_options[act]
        self.convs = nn.ModuleList()
        self.bns = nn.ModuleList()
        self.dropout = nn.Dropout(dropout_rate)
        # The first convolution maps the raw node features to the embedding dimension
        self.convs.append(conv_options[conv](node_feature_count, node_embedding_dim))
        self.bns.append(nn.BatchNorm1d(node_embedding_dim))
        # Remaining convolution and batch-norm layers keep the embedding dimension
        for _ in range(1, num_convs):
            self.convs.append(conv_options[conv](node_embedding_dim, node_embedding_dim))
            self.bns.append(nn.BatchNorm1d(node_embedding_dim))
        # Final fully connected layer: flattened node embeddings -> graph-level output vector
        self.fc = nn.Linear(node_embedding_dim * node_count, output_dim)

    def forward(self, x, edge_index):
        for conv, bn in zip(self.convs, self.bns):
            x = conv(x, edge_index)
            # BatchNorm1d expects (batch * nodes, channels), so flatten, normalize, and restore the shape
            x = bn(x.reshape(-1, x.size(-1))).view_as(x)
            x = self.act(x)
            x = self.dropout(x)
        # Flatten the node embeddings of all nodes into a single vector per graph/sample
        x = x.reshape(x.size(0), -1)
        x = self.fc(x)
        return x
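

# A minimal usage sketch (not part of the original module): it instantiates flexGCN on a small
# random graph to illustrate the tensor shapes the docstring describes. The batch size, node
# count, embedding sizes, and ring-shaped edge list below are arbitrary illustrative values,
# and all graphs in the batch are assumed to share one static edge_index, which the chosen
# PyTorch Geometric convolution is assumed to support for batched (B, N, F) inputs.
if __name__ == "__main__":
    batch_size, num_nodes, num_features = 4, 6, 5
    model = flexGCN(node_count=num_nodes, node_feature_count=num_features,
                    node_embedding_dim=16, output_dim=3, num_convs=2, conv='GCN', act='relu')

    # Random node features: (batch_size, num_nodes, num_features)
    x = torch.randn(batch_size, num_nodes, num_features)
    # Edge index in COO format, shape (2, num_edges); here a simple directed ring over the 6 nodes
    edge_index = torch.tensor([[0, 1, 2, 3, 4, 5],
                               [1, 2, 3, 4, 5, 0]], dtype=torch.long)

    out = model(x, edge_index)
    print(out.shape)  # expected: torch.Size([4, 3])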