# 01-ARGA-link-prediction.py
# ARGA(Adversarially Regularized Graph Autoencoder)의 기본적인 사용법

import random
import numpy as np
import torch

# Seed the Python, NumPy and PyTorch RNGs with fixed values so that
# repeated runs of this script start from the same random state.
random.seed(88731)
np.random.seed(27471)
torch.manual_seed(53811)

# cuDNN settings: disable benchmark mode and request deterministic algorithms.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

import torch.nn.functional as F
from torch.nn import Linear
from torch_geometric.nn import GCNConv
from torch_geometric.nn import ARGA
from torch_geometric.utils import *
from torch_geometric.datasets import Planetoid

# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')  # use this line instead to run on the GPU
# Device that the model and the data are moved to further down in this script.
device = 'cpu'

# Model
class Encoder(torch.nn.Module):
    """Two-layer GCN that maps node features to node embeddings.

    Args:
        InputDim: size of each input node feature vector.
        HiddenDim: output size of the first graph convolution.
        EmbeddingDim: output size of the second graph convolution, i.e.
            the size of the returned embeddings.
    """

    def __init__(self, InputDim, HiddenDim, EmbeddingDim):
        super().__init__()
        self.conv1 = GCNConv(in_channels=InputDim, out_channels=HiddenDim)
        self.conv2 = GCNConv(in_channels=HiddenDim, out_channels=EmbeddingDim)

    def forward(self, x, edge_index):
        """Return the embeddings for features ``x`` on graph ``edge_index``.

        ReLU is applied after the first convolution only; the output of the
        second convolution is returned without an activation.
        """
        hidden = self.conv1(x, edge_index)
        hidden = F.relu(hidden)
        return self.conv2(hidden, edge_index)

class Discriminator(torch.nn.Module):
    """Three-layer MLP that maps each embedding row to one scalar score.

    Args:
        EmbeddingDim: size of each input embedding vector.
        HiddenDim1: width of the first hidden layer.
        HiddenDim2: width of the second hidden layer.
    """

    def __init__(self, EmbeddingDim, HiddenDim1, HiddenDim2):
        super().__init__()
        # Layers are created in this fixed order so that seeded parameter
        # initialisation is reproducible.
        self.linear1 = Linear(EmbeddingDim, HiddenDim1)
        self.linear2 = Linear(HiddenDim1, HiddenDim2)
        self.linear3 = Linear(HiddenDim2, 1)

    def forward(self, x):
        """Return one raw (un-activated) score per row of ``x``.

        The final linear layer yields a trailing dimension of size 1, which
        is squeezed away before returning.
        """
        hidden = x
        for layer in (self.linear1, self.linear2):
            hidden = F.relu(layer(hidden))
        score = self.linear3(hidden)
        return score.squeeze(dim=1)

# Encoder: 1433 input features -> 32 hidden units -> 32-dimensional embeddings.
# NOTE(review): 1433 is presumably the Cora feature dimension — confirm
# against the loaded dataset.
encoder = Encoder(InputDim=1433, HiddenDim=32, EmbeddingDim=32)
# Discriminator: 32-dimensional embedding -> 64 -> 32 -> 1 score.
discriminator = Discriminator(EmbeddingDim=32, HiddenDim1=64, HiddenDim2=32)

# ARGA wraps the encoder and the discriminator into a single module.
model = ARGA(encoder, discriminator)
model = model.to(device)

# One Adam optimizer per sub-network so each can be stepped separately.
optimizerE = torch.optim.Adam(encoder.parameters(), lr=1e-3)
optimizerD = torch.optim.Adam(discriminator.parameters(), lr=1e-3)

# Load data
def load_data():
    """Load the Cora graph and prepare its edges for link prediction.

    The Planetoid 'Cora' dataset is read from (or stored under) the current
    directory. Existing self-loops are removed and self-loops are then added
    back, the edges are split with ``train_test_split_edges``, and negative
    training edges are sampled, one per positive training edge.

    Returns:
        The graph data object, carrying the edge-split attributes plus a
        ``train_neg_edge_index`` attribute set here.
    """
    cora = Planetoid(root='.', name='Cora')
    graph = cora[0]

    edges, _ = remove_self_loops(graph.edge_index)
    edges, _ = add_self_loops(edges)
    graph.edge_index = edges

    # Called for its in-place effect on `graph`; the return value is unused.
    train_test_split_edges(graph)

    train_pos = graph.train_pos_edge_index
    graph.train_neg_edge_index = negative_sampling(
        train_pos,
        num_nodes=graph.num_nodes,
        num_neg_samples=train_pos.size(1),
    )
    return graph


data = load_data().to(device)

# Evaluation module
def evaluate():
    """Score the current encoder on the train/val/test edge splits.

    Uses the module-level ``model``, ``encoder`` and ``data``. Embeddings are
    computed from the training positive edges only, then scored against each
    split's positive and negative edges with ``model.test``.

    The whole evaluation runs under ``torch.no_grad()`` so no autograd graph
    is built, and the model's previous train/eval mode is restored on exit
    (also if an exception is raised).

    Returns:
        A flat tuple
        ``(loss, 'train', *train_perf, 'val', *val_perf, 'test', *test_perf)``
        where ``loss`` is the reconstruction loss on the training positive
        edges as a Python float and each ``*_perf`` is whatever
        ``model.test`` returns for that split.
    """
    was_training = model.training
    model.eval()
    try:
        # Evaluation never calls backward(), so gradient tracking is wasted work.
        with torch.no_grad():
            Z = encoder(data.x, data.train_pos_edge_index)
            loss = model.recon_loss(Z, data.train_pos_edge_index)
            train_perf = model.test(Z, data.train_pos_edge_index, data.train_neg_edge_index)
            val_perf = model.test(Z, data.val_pos_edge_index, data.val_neg_edge_index)
            test_perf = model.test(Z, data.test_pos_edge_index, data.test_neg_edge_index)
    finally:
        # Put the model back into whichever mode the caller had it in.
        model.train(was_training)
    return (loss.item(), 'train', *train_perf, 'val', *val_perf, 'test', *test_perf)

# Training loop
model.train()
print('start training')
for i in range(200):
    # ARGA is used slightly differently from paper to paper and from
    # implementation to implementation; this script uses the simplest scheme.
    if i % 10 == 0:
        print(i, *evaluate())

    optimizerD.zero_grad()
    optimizerE.zero_grad()

    Z = encoder(data.x, data.train_pos_edge_index)

    # Step 1: train the discriminator.
    d_loss = model.discriminator_loss(Z)
    d_loss.backward()
    optimizerD.step()

    # Step 2: train the encoder on reconstruction loss plus the adversarial
    # regularizer. Dropping the reg_loss term removes the discriminator's
    # influence, which makes this equivalent to a plain GAE.
    loss = model.recon_loss(Z, data.train_pos_edge_index) + model.reg_loss(Z)
    loss.backward()
    optimizerE.step()
print('FINAL', *evaluate())

# Visualize
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt

# Embed every node with the trained encoder and move the result to NumPy.
Z = encoder(data.x, data.train_pos_edge_index).detach().cpu().numpy()

# Project the embeddings to 2-D with t-SNE.
Z_embedded = TSNE(
    n_components=2,
    learning_rate=500.0,
).fit_transform(Z)

c = data.y.cpu().numpy()  # cluster id, used as the colour of each point

# Scatter plot of the 2-D projection, written to embedcluster.png.
plt.clf()
plt.scatter(
    Z_embedded[:, 0],
    Z_embedded[:, 1],
    c=c,
    cmap='Accent',
    s=1,
)
plt.savefig('embedcluster.png', dpi=300)
