'''

    Graph Attention Network Tutorial Code
    author : Sang Woon Park (ALDE, PNU)
    date : 2021. 08. 09.
    This code is based on the following tutorial:
    https://docs.dgl.ai/tutorials/models/1_gnn/9_gat.html
    
'''


# Module imports
import os
import time
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

from dgl import DGLGraph
from dgl.data import citation_graph as citegrh
import networkx as nx
import matplotlib.pyplot as plt

# Workaround for the OpenMP duplicate-library error
os.environ['KMP_DUPLICATE_LIB_OK']='True'


# GAT layer implementation
class GATLayer(nn.Module):
    """Single-head graph attention layer bound to a fixed graph ``g``.

    The layer projects node features with ``fc``, scores every edge with
    ``attn_fc`` (equation (1)), normalizes the scores over each node's
    incoming messages with a softmax (equation (2)) and returns the
    attention-weighted sum of the projected neighbor features
    (equation (3)).

    Args:
        g: Graph object providing ``ndata``, ``apply_edges`` and
            ``update_all``.
        in_dim: Size of each input node feature vector.
        out_dim: Size of each projected (output) node feature vector.
    """

    def __init__(self, g, in_dim, out_dim):
        super().__init__()
        self.g = g
        # Bias-free linear projection applied to the node features.
        self.fc = nn.Linear(in_dim, out_dim, bias=False)
        # equation (1): maps a concatenated (source, destination) pair of
        # projected features to one scalar attention score.
        self.attn_fc = nn.Linear(2 * out_dim, 1, bias=False)
        self.reset_parameters()

    def reset_parameters(self):
        """Reinitialize both weight matrices with Xavier-normal values."""
        relu_gain = nn.init.calculate_gain('relu')
        # Order matters for reproducibility under a fixed seed:
        # the projection is drawn first, then the attention weights.
        for linear in (self.fc, self.attn_fc):
            nn.init.xavier_normal_(linear.weight, gain=relu_gain)

    def edge_attention(self, edges):
        """Edge UDF for equation (1).

        Returns:
            dict: ``{'e': score}`` where ``score`` is the leaky-ReLU of the
            attention projection of ``[src z, dst z]`` concatenated on dim 1.
        """
        endpoint_pair = torch.cat((edges.src['z'], edges.dst['z']), dim=1)
        raw_score = self.attn_fc(endpoint_pair)
        return {'e': F.leaky_relu(raw_score)}

    def message_func(self, edges):
        """Message UDF for equations (2) & (3).

        Forwards the source node's projected feature together with the
        edge's attention score so the receiver can normalize and aggregate.
        """
        projected_src = edges.src['z']
        edge_score = edges.data['e']
        return {'z': projected_src, 'e': edge_score}

    def reduce_func(self, nodes):
        """Reduce UDF for equations (2) & (3).

        Returns:
            dict: ``{'h': h}`` where ``h`` is the softmax-weighted sum of
            the mailbox features, both taken along mailbox dim 1.
        """
        inbox = nodes.mailbox
        # equation (2): normalize the scores collected in the mailbox.
        weights = F.softmax(inbox['e'], dim=1)
        # equation (3): aggregate the messages with those weights.
        aggregated = (weights * inbox['z']).sum(dim=1)
        return {'h': aggregated}

    def forward(self, h):
        """Run one round of attention-based message passing.

        Args:
            h: Node feature tensor accepted by ``self.fc``.

        Returns:
            The aggregated node features, removed from ``g.ndata['h']``.
            The projected features remain stored in ``g.ndata['z']``.
        """
        graph = self.g
        # Store the projected features on the nodes so the UDFs can read them.
        graph.ndata['z'] = self.fc(h)
        # equation (1): refresh the per-edge attention scores.
        graph.apply_edges(self.edge_attention)
        # equation (2) & (3): propagate and aggregate with attention.
        graph.update_all(self.message_func, self.reduce_func)
        return graph.ndata.pop('h')


# Multi-head attention built on top of the GATLayer class
class MultiHeadGATLayer(nn.Module):
    """Run several independent ``GATLayer`` heads and merge their outputs.

    Args:
        g: Graph object handed unchanged to every ``GATLayer`` head.
        in_dim: Size of each input node feature vector.
        out_dim: Output feature size of each individual head.
        num_heads: Number of attention heads to create.
        merge: ``'cat'`` concatenates the head outputs along the feature
            dimension (result width ``out_dim * num_heads``); any other
            value averages the heads (result width ``out_dim``).
    """

    def __init__(self, g, in_dim, out_dim, num_heads, merge='cat'):
        super(MultiHeadGATLayer, self).__init__()
        self.heads = nn.ModuleList(
            [GATLayer(g, in_dim, out_dim) for _ in range(num_heads)]
        )
        self.merge = merge

    def forward(self, h):
        """Apply every head to ``h`` and merge the per-head results.

        Args:
            h: Node feature tensor passed to each head.

        Returns:
            The concatenated (``merge == 'cat'``) or head-averaged output.
        """
        head_outs = [attn_head(h) for attn_head in self.heads]
        if self.merge == 'cat':
            # concat on the output feature dimension (dim=1)
            return torch.cat(head_outs, dim=1)
        # Average across the head axis only (dim=0 of the stacked tensor).
        # Without ``dim`` the mean would collapse nodes and features too,
        # returning a single scalar instead of per-node features.
        return torch.mean(torch.stack(head_outs), dim=0)
            
class GAT(nn.Module):
    """Two-layer graph attention network.

    The first layer runs ``num_heads`` attention heads whose outputs are
    concatenated; the second layer uses a single head to produce the
    final ``out_dim`` features per node.

    Args:
        g: Graph object passed to both multi-head layers.
        in_dim: Size of each input node feature vector.
        hidden_dim: Output size of each head in the first layer.
        out_dim: Output size of the second (final) layer.
        num_heads: Number of attention heads in the first layer.
    """

    def __init__(self, g, in_dim, hidden_dim, out_dim, num_heads):
        super().__init__()
        self.layer1 = MultiHeadGATLayer(g, in_dim, hidden_dim, num_heads)
        # The first layer concatenates its heads, so the second layer
        # receives hidden_dim * num_heads features per node. The output
        # layer itself uses just one attention head.
        concat_width = hidden_dim * num_heads
        self.layer2 = MultiHeadGATLayer(g, concat_width, out_dim, 1)

    def forward(self, h):
        """Return the second-layer output for node features ``h``."""
        hidden = F.elu(self.layer1(h))
        return self.layer2(hidden)

# Function that loads the Cora citation dataset
def load_cora_data():
    """Load the Cora dataset through ``citegrh.load_cora``.

    Returns:
        tuple: ``(g, features, labels, mask)`` where ``g`` is a ``DGLGraph``
        built from the dataset's graph, ``features`` is a float tensor,
        ``labels`` is a long tensor and ``mask`` is the boolean training
        mask taken from the dataset.
    """
    dataset = citegrh.load_cora()
    node_features = torch.FloatTensor(dataset.features)
    node_labels = torch.LongTensor(dataset.labels)
    train_mask = torch.BoolTensor(dataset.train_mask)
    graph = DGLGraph(dataset.graph)

    # Optional visualization of the Cora graph with networkx (disabled):
    #
    #     nx_G = graph.to_networkx().to_undirected()
    #     pos = nx.kamada_kawai_layout(nx_G)
    #     nx.draw(nx_G, pos, with_labels=False, node_size = 0.01, node_color='#00b4d9', width=0.3)
    #     plt.savefig("data_viz.png",dpi=300)
    return graph, node_features, node_labels, train_mask

g, features, labels, mask = load_cora_data()

# Create the model: 2 attention heads in the first layer, each with hidden
# size 8. num_heads sets how many heads the first layer concatenates.
# out_dim=7 is presumably the number of Cora classes -- verify against the
# dataset if it changes.
net = GAT(g,
          in_dim=features.size()[1],
          hidden_dim=8,
          out_dim=7,
          num_heads=2)

# create optimizer
optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)

# main loop
dur = []        # per-epoch wall-clock durations (epochs >= 3 only)
epoch_arr = []  # epoch indices, kept for plotting
loss_arr = []   # training loss per epoch
acc_arr = []    # accuracy (%) on the masked nodes per epoch
for epoch in range(30):
    # The first three epochs are excluded from timing.
    if epoch >= 3:
        t0 = time.time()

    logits = net(features)
    logp = F.log_softmax(logits, 1)
    loss = F.nll_loss(logp[mask], labels[mask])

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if epoch >= 3:
        dur.append(time.time() - t0)

    # Compare predictions with the labels on the masked nodes to get the
    # accuracy; the comparison is vectorized instead of a Python loop.
    pred = np.argmax(logp[mask].detach().numpy(), axis=1)
    answ = labels[mask].numpy()
    acc = np.sum(pred == answ) / len(pred) * 100

    # Record the values for the plots below.
    epoch_arr.append(epoch)
    acc_arr.append(acc)
    loss_arr.append(loss.item())

    # np.mean([]) emits a "Mean of empty slice" RuntimeWarning during the
    # untimed epochs; report nan directly so the printed text is unchanged.
    mean_dur = np.mean(dur) if dur else float('nan')
    print("Epoch {:05d} | Loss {:.4f} | Accuracy {:.2f}% | Time(s) {:.4f}".format(epoch, loss.item(), acc, mean_dur))

# Epoch-Loss plot
plt.plot(epoch_arr, loss_arr)
plt.savefig("loss_viz.png", dpi=300)
plt.clf()
# Epoch-Accuracy plot
plt.plot(epoch_arr, acc_arr, 'r')
plt.savefig("acc_viz.png", dpi=300)
