import torch
from torch_geometric.nn import MessagePassing

class Net(MessagePassing):
    """Toy message-passing layer that prints every stage of ``propagate``.

    The layer has no learnable parameters.  ``forward`` collapses each row
    of the input to its sum, ``message`` forwards ``x_j`` untouched, the
    base class combines messages with the ``'add'`` aggregation chosen in
    ``__init__``, and ``update`` returns the aggregated result as-is.  The
    tensors recorded in the comments below are the values printed for the
    sample graph built at the bottom of this file.
    """

    def __init__(self):
        # 'add' selects sum-aggregation of the messages in the base class.
        super().__init__(aggr='add')

    def forward(self, x, edge_index):
        """Sum ``x`` along dim 1, reshape to one column, and propagate."""
        row_totals = torch.sum(x, dim=1)
        x = row_totals.view(-1, 1)  # column vector: shape (-1, 1)
        print('forward', x)
        # Recorded output for the sample input:
        #   forward tensor([[ 6.],
        #           [22.],
        #           [38.]])
        return self.propagate(edge_index, x=x)

    def message(self, x_i, x_j):
        """Print both per-edge feature tensors and return ``x_j``.

        The ``_i`` / ``_j`` suffixes are what the base class uses to decide
        which end of each edge to gather ``x`` from; in the recorded run
        ``x_j`` follows ``edge_index[0]`` and ``x_i`` follows
        ``edge_index[1]``.
        """
        for label, per_edge in (('x_i', x_i), ('x_j', x_j)):
            print(label, per_edge)
        # Recorded output for the sample input:
        #   x_i tensor([[22.],
        #           [ 6.],
        #           [38.],
        #           [22.]])
        #   x_j tensor([[ 6.],
        #           [22.],
        #           [22.],
        #           [38.]])
        return x_j

    def update(self, inputs):
        """Print the aggregated messages and hand them back unchanged."""
        aggregated = inputs
        print('inputs in update():', aggregated)
        # Recorded output for the sample input:
        #   inputs in update(): tensor([[22.],
        #           [44.],
        #           [22.]])
        return aggregated

# Build the demo layer.
model = Net()

# Feature matrix with 3 rows of 4 values each:
#   [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]
x = torch.arange(12, dtype=torch.float32).view(3, 4)

# Four directed edges written as (row-0, row-1) index pairs, then
# transposed into the 2 x 4 layout:
#   [[0, 1, 1, 2],
#    [1, 0, 2, 1]]
edge_index = torch.tensor(
    [(0, 1), (1, 0), (1, 2), (2, 1)],
    dtype=torch.long,
).t().contiguous()

print('output', model(x, edge_index))
# Recorded output:
#   output tensor([[22.],
#           [44.],
#           [22.]])
