from torch_geometric.data import Data

import torch
import torch.nn.functional as F
from torch.nn import Linear
from torch_geometric.nn import GCNConv

# Undirected star graph on 4 nodes: node 1 is the hub, linked to nodes 0, 2
# and 3.  Every undirected edge is stored as two directed edges.
edge_index = torch.tensor([[0, 1, 1, 2, 1, 3],
                           [1, 0, 2, 1, 3, 1]], dtype=torch.long)

# data1: every node carries the scalar feature 1.
# data2: only the hub (node 1) carries a non-zero feature.
data1 = Data(x=torch.ones(4, 1), edge_index=edge_index)
data2 = Data(x=torch.tensor([[0], [1], [0], [0]], dtype=torch.float32), edge_index=edge_index)
#data3 = Data( x=torch.arange(4).unsqueeze(dim=1).float(), edge_index=edge_index )


def _set_gcn_weight(conv, values):
    """Overwrite the linear map of a 1-input-channel GCNConv with *values*.

    ``values[k]`` becomes the weight from the single input channel to output
    channel ``k``.  Two parameter layouts are supported so the script runs on
    both old and new PyTorch Geometric releases:

    * newer releases keep the weight in ``conv.lin.weight`` with shape
      ``(out_channels, in_channels)``;
    * legacy releases keep it in ``conv.weight`` with shape
      ``(in_channels, out_channels)``.
    """
    row = torch.tensor(values, dtype=torch.float32)
    lin = getattr(conv, "lin", None)
    # no_grad: in-place writes to a leaf parameter must not be tracked by autograd.
    with torch.no_grad():
        if lin is not None:
            lin.weight.copy_(row.view(-1, 1))
        else:
            conv.weight.copy_(row.view(1, -1))


# --- Default layer ----------------------------------------------------------
# Recorded values match self-loops plus symmetric degree normalisation, e.g.
# node 0 on data1: 1/2 + 1/sqrt(2*4) = 0.8536.  The recorded outputs show no
# bias contribution.
conv1 = GCNConv(1, 2)
_set_gcn_weight(conv1, [1., -2.])
print(conv1(data1.x, data1.edge_index))
print(conv1(data2.x, data2.edge_index))
# Expected output:
"""
tensor([[ 0.8536, -1.7071],
        [ 1.3107, -2.6213],
        [ 0.8536, -1.7071],
        [ 0.8536, -1.7071]], grad_fn=<AddBackward0>)
tensor([[ 0.3536, -0.7071],
        [ 0.2500, -0.5000],
        [ 0.3536, -0.7071],
        [ 0.3536, -0.7071]], grad_fn=<AddBackward0>)
"""

# --- No self-loops ----------------------------------------------------------
# Still normalised, but a node no longer sees its own feature, e.g. node 0 on
# data1: 1/sqrt(1*3) = 0.5774, and the hub on data2 receives exactly 0.
conv2 = GCNConv(1, 2, add_self_loops=False)
_set_gcn_weight(conv2, [1., -2.])
print(conv2(data1.x, data1.edge_index))
print(conv2(data2.x, data2.edge_index))
# Expected output:
"""
tensor([[ 0.5774, -1.1547],
        [ 1.7321, -3.4641],
        [ 0.5774, -1.1547],
        [ 0.5774, -1.1547]], grad_fn=<AddBackward0>)
tensor([[ 0.5774, -1.1547],
        [ 0.0000,  0.0000],
        [ 0.5774, -1.1547],
        [ 0.5774, -1.1547]], grad_fn=<AddBackward0>)
"""
# --- No normalisation -------------------------------------------------------
# Recorded values are the plain sum over neighbours (no self contribution):
# the hub sums three neighbours on data1 and gets 0 on data2.
conv3 = GCNConv(1, 2, normalize=False)
_set_gcn_weight(conv3, [1., -2.])
print(conv3(data1.x, data1.edge_index))
print(conv3(data2.x, data2.edge_index))
# Expected output:
"""
tensor([[ 1., -2.],
        [ 3., -6.],
        [ 1., -2.],
        [ 1., -2.]], grad_fn=<AddBackward0>)
tensor([[ 1., -2.],
        [ 0.,  0.],
        [ 1., -2.],
        [ 1., -2.]], grad_fn=<AddBackward0>)
"""
