import torch
from torch_geometric.datasets import Planetoid
from torch_geometric.nn import GraphUNet
import torch.nn.functional as F
from torch_geometric.utils import dropout_adj

from dgl import DGLGraph
from dgl.data import citation_graph as citegrh
import networkx as nx
import matplotlib.pyplot as plt

class Net(torch.nn.Module):
    """Node classifier built around a depth-3 ``GraphUNet``.

    NOTE(review): both ``__init__`` and ``forward`` read the module-level
    ``data`` and ``dataset`` globals instead of taking them as arguments, so
    those globals must exist before the model is constructed or called.
    """

    def __init__(self):
        super().__init__()
        # First pool ratio is 2000 / num_nodes (presumably targeting ~2000
        # retained nodes); the second is a fixed 0.5.
        ratios = [2000 / data.num_nodes, 0.5]
        self.unet = GraphUNet(
            dataset.num_features,
            32,
            dataset.num_classes,
            depth=3,
            pool_ratios=ratios,
        )

    def forward(self):
        """Return per-node log-probabilities over ``dataset.num_classes`` classes."""
        # Edge dropout (p=0.2, force_undirected) is active only while training.
        dropped_edges, _ = dropout_adj(
            data.edge_index,
            p=0.2,
            force_undirected=True,
            num_nodes=data.num_nodes,
            training=self.training,
        )
        # Very aggressive input-feature dropout, also training-only.
        features = F.dropout(data.x, p=0.92, training=self.training)
        out = self.unet(features, dropped_edges)
        return F.log_softmax(out, dim=1)


# Load the Cora citation dataset (stored under tmp/Cora) and take its graph.
dataset = Planetoid(root='tmp/Cora', name='Cora')
data = dataset[0]
device = 'cpu'
# Net() reads the global `data`/`dataset`; build it first, then move the graph.
model = Net().to(device)
data = data.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=0.001)
# Best validation accuracy so far, and the test accuracy at that epoch.
best_val_acc, test_acc = 0, 0
# Per-epoch histories consumed by the plotting code.
epoch_arr, loss_arr, acc_arr = [], [], []

# Train for 200 epochs, tracking the test accuracy at the best validation epoch.
for epoch in range(1, 201):
    # --- One optimization step on the training nodes (dropouts active). ---
    model.train()
    optimizer.zero_grad()
    loss = F.nll_loss(model()[data.train_mask], data.y[data.train_mask])
    loss.backward()
    optimizer.step()

    # --- Evaluation: eval() disables dropout; no_grad() skips building an
    # autograd graph that would never be used. ---
    model.eval()
    with torch.no_grad():
        logits = model()
    accs = []
    for _, mask in data('train_mask', 'val_mask', 'test_mask'):
        pred = logits[mask].argmax(dim=1)
        acc = pred.eq(data.y[mask]).sum().item() / mask.sum().item()
        accs.append(acc)

    train_acc, val_acc, tmp_test_acc = accs
    # acc_arr records this epoch's test accuracy (previously taken implicitly
    # from the last iteration of the mask loop above).
    acc_arr.append(tmp_test_acc)

    # Model selection: keep the test accuracy from the best-validation epoch.
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        test_acc = tmp_test_acc

    epoch_arr.append(epoch)
    loss_arr.append(loss.item())
    log = 'Epoch: {:03d}, Loss :{:.4f}, Accuracy :{:.4f}, Train: {:.4f}, Val: {:.4f}, Test: {:.4f}'
    # Arguments now line up with the labels in `log`.
    print(log.format(epoch, loss.item(), tmp_test_acc, train_acc,
                     best_val_acc, test_acc))
   
   
   

# Loss visualization: plot first, then save, so the file contains the curve.
plt.plot(epoch_arr, loss_arr)
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.savefig("loss_viz.png", dpi=300)
plt.clf()
# Accuracy visualization (per-epoch test accuracy recorded in acc_arr).
plt.plot(epoch_arr, acc_arr, 'r')
plt.xlabel("Epoch")
plt.ylabel("Accuracy")
plt.savefig("acc_viz.png", dpi=300)
plt.clf()

