import warnings
warnings.filterwarnings("ignore", category=UserWarning) 

from torch_geometric.datasets import Planetoid
from torch_geometric.data import DataLoader, Data
from torch_geometric.utils import accuracy, train_test_split_edges, negative_sampling

import torch
import torch.nn.functional as F
from torch.nn import Linear
from torch_geometric.nn import GCNConv

# Pick the compute device once: CUDA when PyTorch reports it as available,
# otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

class Net(torch.nn.Module):
    """Two GCN layers followed by a two-layer MLP that scores node pairs.

    The GCN layers turn node features into per-node embeddings. For every
    requested pair the two endpoint embeddings are concatenated and passed
    through the MLP, which emits one raw score per pair (no sigmoid is
    applied here).

    Args:
        in_channels: Size of each input node feature vector. When ``None``
            (the default) it is read from the module-level
            ``dataset.num_node_features``, so a bare ``Net()`` keeps working.
        hidden_channels: Width of both GCN layers and of the hidden MLP
            layer. Defaults to 64.
    """

    def __init__(self, in_channels=None, hidden_channels=64):
        super().__init__()
        if in_channels is None:
            # Backward-compatible fallback to the script-level dataset.
            in_channels = dataset.num_node_features

        # Layers are created in a fixed order (conv1, conv2, linear1,
        # linear2) so parameter initialisation consumes the RNG the same way
        # regardless of how the sizes were supplied.
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        # The MLP input is two node embeddings side by side.
        self.linear1 = Linear(2 * hidden_channels, hidden_channels)
        self.linear2 = Linear(hidden_channels, 1)

    def forward(self, x, edge_index, node_pair):
        """Score each node pair in ``node_pair``.

        Args:
            x: Node feature matrix fed to the first GCN layer.
            edge_index: Graph connectivity used by both GCN layers.
            node_pair: Index tensor whose row 0 and row 1 hold the two
                endpoints of every pair to score.

        Returns:
            A 1-D tensor with one raw score per pair.
        """
        # Message passing: conv -> ReLU -> dropout, twice. Dropout is only
        # active while the module is in training mode.
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, training=self.training)

        x = self.conv2(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, training=self.training)

        # Gather both endpoint embeddings and join them feature-wise.
        x = torch.cat([x[node_pair[0]], x[node_pair[1]]], dim=1)

        x = self.linear1(x)
        x = F.relu(x)
        x = F.dropout(x, training=self.training)

        # Final layer has a single output unit; drop that trailing axis.
        x = self.linear2(x)
        x = x.squeeze(1)

        return x

# ---- Data preparation ------------------------------------------------------
dataset = Planetoid(root='.', name='Cora')
data = dataset[0]

# Called for its in-place effect: the code below reads the train/val/test
# positive edge sets and the val/test negative edge sets from `data`.
train_test_split_edges(data)

# Draw as many training negatives as there are training positives. The node
# count is passed explicitly so it is not inferred from the training edges
# alone (which could under-count if the highest-numbered nodes have none).
n_pos = data.train_pos_edge_index.shape[1]
data.train_neg_edge_index = negative_sampling(
    data.train_pos_edge_index,
    num_nodes=data.num_nodes,
    num_neg_samples=n_pos,
)
# Measured rather than assumed equal to n_pos.
n_neg = data.train_neg_edge_index.shape[1]


def _pairs_and_labels(pos_edge_index, neg_edge_index):
    """Return (node pairs, 0/1 labels) on `device`, positives first."""
    pairs = torch.cat([pos_edge_index, neg_edge_index], dim=1)
    labels = torch.cat([
        torch.ones(pos_edge_index.shape[1]),
        torch.zeros(neg_edge_index.shape[1]),
    ])
    return pairs.to(device), labels.to(device)


# Everything the model touches lives on `device`, so the script actually uses
# the GPU when one was detected.
x = data.x.to(device)
# Only training positives are used for message passing.
edge_index = data.train_pos_edge_index.to(device)
node_pair, y = _pairs_and_labels(
    data.train_pos_edge_index, data.train_neg_edge_index)

val_node_pair, val_y = _pairs_and_labels(
    data.val_pos_edge_index, data.val_neg_edge_index)

test_node_pair, test_y = _pairs_and_labels(
    data.test_pos_edge_index, data.test_neg_edge_index)


# ---- Model and optimiser ---------------------------------------------------
# The model is moved before the optimiser is built so the optimiser holds the
# on-device parameters.
model = Net().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)

def evaluate( pred, y ):
    """Return the fraction of pair scores that land on the correct side of 0.

    This is a pure metric: it does not touch the model. Switching the model
    between train and eval mode only matters *before* the forward pass that
    produces ``pred``, so it is the caller's job, not this function's.

    Args:
        pred: Tensor of raw scores; a score strictly greater than 0 counts
            as the positive class, anything else as the negative class.
        y: Tensor of 0/1 targets with the same shape as ``pred``.

    Returns:
        Accuracy as a Python float in [0, 1].

    Raises:
        ZeroDivisionError: If ``y`` has no elements.
    """
    predicted_label = (pred > 0).long()
    n_correct = (predicted_label == y).sum().item()
    return n_correct / y.numel()

def _report_accuracy():
    """Print train/val/test accuracy measured with dropout switched off.

    Each forward pass runs in eval mode and without gradient tracking, so
    the printed numbers are deterministic for a given set of weights. The
    model is put back into training mode before returning.
    """
    splits = (
        ('train acc:', node_pair, y),
        ('val   acc:', val_node_pair, val_y),
        ('test  acc:', test_node_pair, test_y),
    )
    for label, pairs, target in splits:
        # Enter eval mode before every forward pass instead of relying on
        # whatever mode an earlier call left the model in.
        model.eval()
        with torch.no_grad():
            pred = model(x, edge_index, pairs)
        print(label, evaluate(pred, target))
    model.train()


model.train()
for epoch in range(500):
    # Progress report every 100 epochs, taken before that epoch's update.
    if epoch % 100 == 0:
        print('EPOCH', epoch)
        _report_accuracy()

    # One full-batch step on every training pair.
    optimizer.zero_grad()
    out = model(x, edge_index, node_pair)
    loss = F.binary_cross_entropy_with_logits(out, y)
    loss.backward()
    optimizer.step()

print('FINAL')
_report_accuracy()


"""
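Sample output from one run. Exact numbers vary between runs because the
edge split, negative sampling and weight initialisation are not seeded.

EPOCH 0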
train acc: 0.49994429590017825
val   acc: 0.5
test  acc: 0.5
EPOCH 100
train acc: 0.8262032085561497
val   acc: 0.6806083650190115
test  acc: 0.6736242884250474
EPOCH 200
train acc: 0.8821858288770054
val   acc: 0.7034220532319392
test  acc: 0.7229601518026565
EPOCH 300
train acc: 0.9007352941176471
val   acc: 0.7319391634980988
test  acc: 0.7371916508538899
EPOCH 400
train acc: 0.9188948306595366
val   acc: 0.7756653992395437
test  acc: 0.7590132827324478
FINAL
train acc: 0.9255793226381461
val   acc: 0.7509505703422054
test  acc: 0.7476280834914611
"""
