I’m working through the Pytorch-Geometric docs (here).
In the code below, we see data being passed to the model without applying train_mask. However, when passing the output and the labels to the loss function, train_mask is applied to both. Shouldn’t we also apply train_mask to data when passing it into the model? As I see it, that shouldn’t be a problem. Otherwise, it looks like we are wasting computation on outputs that are not used to train the model.
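```python
import torch.nn.functional as F

# model, data and optimizer are defined earlier in the tutorial
model.train()
for epoch in range(200):
    optimizer.zero_grad()
    out = model(data)  # forward pass over all nodes of the graph
    # only the nodes selected by train_mask contribute to the loss
    loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])
    loss.backward()
    optimizer.step()
```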
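To make the loss side of this concrete, here is a minimal pure-Python sketch (no PyTorch required) of what F.nll_loss(out[data.train_mask], data.y[data.train_mask]) computes with its default mean reduction: the negative log-likelihood averaged over the masked rows only. The helper masked_nll_loss and the three-node data are hypothetical, for illustration only.

```python
import math

def masked_nll_loss(log_probs, labels, mask):
    """Mean negative log-likelihood over the rows where mask is True.

    log_probs: one list of per-class log-probabilities per node
    labels:    one integer class label per node
    mask:      one boolean per node (plays the role of train_mask)
    """
    picked = [row[y] for row, y, m in zip(log_probs, labels, mask) if m]
    return -sum(picked) / len(picked)

# Three nodes, two classes; only nodes 0 and 2 are training nodes.
log_probs = [
    [math.log(0.9), math.log(0.1)],
    [math.log(0.5), math.log(0.5)],
    [math.log(0.2), math.log(0.8)],
]
labels = [0, 1, 1]
train_mask = [True, False, True]

loss = masked_nll_loss(log_probs, labels, train_mask)

# Changing the output of the non-training node 1 leaves the loss unchanged.
log_probs[1] = [math.log(0.99), math.log(0.01)]
print(masked_nll_loss(log_probs, labels, train_mask) == loss)  # True
```

So the final outputs of non-training nodes really do not enter the loss; the open question is whether they can be skipped in the forward pass.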
I think the main reason the PyTorch Geometric examples simply compute the output for all nodes is different from the "no slicing of data" issue raised in the other answer. You need the hidden representations (produced by the graph convolutions) of more nodes than train_mask contains. Hence, you cannot simply pass in only the features (or, more generally, the data) of those nodes. Some optimisation is possible, though, which I will discuss at the end.
I’ll assume your setting is node classification (as in the example code and the link in your question).
Let’s use a small toy example with five nodes and the following edges:
A<->B, B<->C, C<->D, D<->E
and let’s assume you use a 2-layer GNN with only node A as a training node. To calculate the GNN’s output for A, you need the first hidden representation of B, which in turn uses the input features of C. Hence, you need the 2-hop neighbourhood of A to calculate its output.
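The toy example can be checked with a small pure-Python sketch. As an assumption, each layer here simply averages a node's scalar feature with those of its neighbours (mean aggregation with a self-loop, no weights or nonlinearity); this stands in for a real GNN layer and is not PyG's implementation, and the helper names are hypothetical. Running the 2-layer model on the subgraph induced by A's 2-hop neighbourhood {A, B, C} gives the same output for A as the full graph, while feeding in only A, or only its 1-hop neighbourhood, changes it.

```python
# Toy graph from the text: A<->B, B<->C, C<->D, D<->E (undirected).
EDGES = [("A", "B"), ("B", "C"), ("C", "D"), ("D", "E")]

def build_adjacency(nodes, edges):
    """Adjacency restricted to `nodes`, with a self-loop for every node."""
    adj = {n: [n] for n in nodes}
    for u, v in edges:
        if u in adj and v in adj:  # keep only edges inside the subgraph
            adj[u].append(v)
            adj[v].append(u)
    return adj

def mean_layer(h, adj):
    """One message-passing layer: average over the node and its neighbours."""
    return {n: sum(h[m] for m in adj[n]) / len(adj[n]) for n in adj}

def two_layer_output(features, nodes):
    """Hypothetical 2-layer mean-aggregation 'GNN' on the subgraph induced by `nodes`."""
    adj = build_adjacency(nodes, EDGES)
    h = {n: features[n] for n in nodes}
    return mean_layer(mean_layer(h, adj), adj)

features = {"A": 1.0, "B": 2.0, "C": 4.0, "D": 8.0, "E": 16.0}

full = two_layer_output(features, ["A", "B", "C", "D", "E"])["A"]
two_hop = two_layer_output(features, ["A", "B", "C"])["A"]  # 2-hop neighbourhood of A
one_hop = two_layer_output(features, ["A", "B"])["A"]       # 1-hop neighbourhood of A
only_a = two_layer_output(features, ["A"])["A"]             # training node only

print(full == two_hop)                    # True
print(one_hop == full, only_a == full)    # False False
```

The exact match relies on mean aggregation, which normalises only by the degree of the receiving node. With symmetric degree normalisation as in a GCN, B's hidden representation also depends on the degree of C, which counts the edge C<->D outside the 2-hop subgraph, so those degrees would have to come from the full graph.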
If you have multiple training nodes (as you usually do) and a k-layer GNN, the model usually (though not always; see dilated GNNs for an example) operates on the k-hop neighbourhood. You can then compute the set of nodes that is actually needed as the union of the k-hop neighbourhoods of all training nodes, and run the model only on that subgraph. Since this is model-dependent and requires some extra code, I guess it was left out of an "introduction by example". In any case, you will probably only see an effect on larger graphs; for graphs like Cora the savings are negligible.
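Computing that union amounts to a breadth-first search started from all training nodes at once. Below is a minimal pure-Python sketch, assuming an undirected edge list and that every layer aggregates over direct neighbours only (no dilation); k_hop_union is a hypothetical helper name.

```python
from collections import deque

def k_hop_union(train_nodes, edges, k):
    """Union of the k-hop neighbourhoods (including the nodes themselves)
    of all training nodes, for an undirected edge list."""
    adj = {}
    for u, v in edges:
        adj.setdefault(u, set()).add(v)
        adj.setdefault(v, set()).add(u)
    seen = set(train_nodes)
    # Multi-source BFS: every training node starts at distance 0.
    frontier = deque((n, 0) for n in train_nodes)
    while frontier:
        node, dist = frontier.popleft()
        if dist == k:  # do not expand beyond k hops
            continue
        for nb in adj.get(node, ()):
            if nb not in seen:
                seen.add(nb)
                frontier.append((nb, dist + 1))
    return seen

EDGES = [("A", "B"), ("B", "C"), ("C", "D"), ("D", "E")]

print(sorted(k_hop_union(["A"], EDGES, 2)))       # ['A', 'B', 'C']
print(sorted(k_hop_union(["A", "E"], EDGES, 2)))  # ['A', 'B', 'C', 'D', 'E']
```

With only A as a training node, D and E could be skipped, but adding E as a second training node already pulls in the whole toy graph, which shows how quickly the union can grow. In PyG itself, torch_geometric.utils.k_hop_subgraph(node_idx, num_hops, edge_index, relabel_nodes=True) returns the node subset together with the induced edge_index (plus a mapping and an edge mask); check the documentation of your version for details.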
Answered By – Sparky05