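# 02-ARGA-tree.py
# Embeds the nodes of a random tree into 2-D with ARGA (Adversarially Regularized Graph Autoencoder) and visualizes the embedding.
# To compare against training without adversarial regularization, comment out the `loss += model.reg_loss( Z )` line in the encoder-loss part of the training loop.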

import random
import numpy as np
import torch
import torch.nn.functional as F

# Seed every random number generator the script relies on so that runs are
# reproducible (tree generation uses NumPy's global RNG; model weights are
# initialised from PyTorch's RNG).
random.seed(88731)        # Python's built-in `random` module
np.random.seed(27471)     # NumPy global RNG
torch.manual_seed(53811)  # PyTorch RNG

# Restrict cuDNN to deterministic algorithms and switch off its autotuner,
# which could otherwise pick different kernels from run to run.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

from torch.nn import Linear
from torch_geometric.nn import GCNConv
from torch_geometric.nn import ARGA
from torch_geometric.utils import *
from torch_geometric.data import Data, DataLoader, Batch
from torch_geometric.datasets import Planetoid

# Everything runs on the CPU. To use a GPU when one is available, replace the
# assignment below with:
#     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device = 'cpu'

# Model
class Encoder(torch.nn.Module):
    """Two-layer GCN that maps node features to node embeddings.

    Args:
        InputDim: size of each input node feature vector.
        HiddenDim: width of the intermediate GCN layer.
        EmbeddingDim: size of the produced node embedding.
    """

    def __init__(self, InputDim, HiddenDim, EmbeddingDim):
        super().__init__()
        self.conv1 = GCNConv(InputDim, HiddenDim)
        self.conv2 = GCNConv(HiddenDim, EmbeddingDim)

    def forward(self, x, edge_index):
        """Return node embeddings: conv1 -> ReLU -> conv2 (no output activation)."""
        hidden = self.conv1(x, edge_index).relu()
        return self.conv2(hidden, edge_index)

class Discriminator(torch.nn.Module):
    """Three-layer MLP that assigns one real-valued score to each embedding.

    Args:
        EmbeddingDim: size of the embeddings being scored.
        HiddenDim1: width of the first hidden layer.
        HiddenDim2: width of the second hidden layer.
    """

    def __init__(self, EmbeddingDim, HiddenDim1, HiddenDim2):
        super().__init__()
        self.linear1 = Linear(EmbeddingDim, HiddenDim1)
        self.linear2 = Linear(HiddenDim1, HiddenDim2)
        self.linear3 = Linear(HiddenDim2, 1)

    def forward(self, x):
        """Score a batch of embeddings.

        Expects ``x`` of shape (N, EmbeddingDim) and returns shape (N,).
        No sigmoid is applied here; the output is the raw last-layer value.
        """
        for hidden_layer in (self.linear1, self.linear2):
            x = F.relu(hidden_layer(x))
        return self.linear3(x).squeeze(1)

# Encoder: 1-d constant node feature -> 8 hidden units -> 2-d embedding.
encoder = Encoder(InputDim=1, HiddenDim=8, EmbeddingDim=2)
# Discriminator: 2-d embedding -> 8 -> 4 -> 1 score.
discriminator = Discriminator(EmbeddingDim=2, HiddenDim1=8, HiddenDim2=4)

model = ARGA(encoder, discriminator).to(device)

# Generate Data
def generate_tree(n):
    """Return the edges of a random tree on ``n`` nodes with shuffled labels.

    Each node v in 1..n-1 is attached to a parent drawn uniformly from 0..v-1
    (NumPy's global RNG), then all node ids are relabelled through a random
    permutation. Every edge is returned as a (smaller id, larger id) tuple;
    the list has n-1 entries (empty for n <= 1).
    """
    parent_child = [(np.random.randint(child), child) for child in range(1, n)]
    relabel = np.random.permutation(n)
    shuffled = ((relabel[a], relabel[b]) for a, b in parent_child)
    return [(min(a, b), max(a, b)) for a, b in shuffled]

def to_torch_edge_index(edges):
    """Convert an undirected edge list into a symmetric PyG ``edge_index``.

    Args:
        edges: iterable of (u, v) integer pairs. Any iterable works, including
            a one-shot generator; it is materialised once.

    Returns:
        A contiguous ``torch.long`` tensor of shape (2, 2 * len(edges)). The
        first half of the columns holds the edges as given, the second half
        the reversed copies. An empty input yields shape (2, 0).
    """
    # Materialise once so generators are not exhausted before the reverse pass.
    pairs = list(edges)
    # reshape(-1, 2) keeps an empty list well-formed: (0,) -> (0, 2).
    forward = torch.tensor(pairs, dtype=torch.long).reshape(-1, 2)
    both_directions = torch.cat([forward, forward.flip(1)], dim=0)
    return both_directions.t().contiguous()

def load_data(n=10000):
    """Generate a random tree graph and split its edges for link prediction.

    Args:
        n: number of nodes in the tree. Defaults to 10000, the size this
            script has always used.

    Returns:
        A ``Data`` object holding:
        - ``x``: an ``(n, 1)`` all-ones feature matrix.
        - ``y``: 1 for leaf nodes, 0 otherwise. It is used only for the plot,
          never for training.
        - The edge-split fields added by ``train_test_split_edges`` (e.g.
          ``train_pos_edge_index``).
        - ``train_neg_edge_index``.
    """
    x = torch.ones((n, 1))  # no informative node features: constant 1 per node
    edge_index = to_torch_edge_index(generate_tree(n))
    # edge_index stores both directions, so this degree is the neighbour count;
    # in a tree a node is a leaf iff it has exactly one neighbour.
    y = (degree(edge_index[0], n) == 1).long()
    data = Data(x=x, edge_index=edge_index, y=y)

    # generate_tree never produces self-loops; this is only a safeguard.
    # (The former add_self_loops call was removed: train_test_split_edges keeps
    # only edges with row < col, so self-loops were discarded there anyway.)
    data.edge_index, _ = remove_self_loops(data.edge_index)

    # Split the edges into train/val/test sets. Use the returned object instead
    # of relying on in-place mutation.
    data = train_test_split_edges(data)

    # Request as many random negative edges as there are training positives.
    # NOTE: the training loop below calls recon_loss without these, so this
    # script never consumes them.
    data.train_neg_edge_index = negative_sampling(
        data.train_pos_edge_index,
        num_nodes=data.num_nodes,
        num_neg_samples=data.train_pos_edge_index.shape[1],
    )
    return data

data = load_data().to(device)  # build the graph once and move it to the chosen device

# One Adam optimizer per network (same learning rate), so the adversarial
# loop can step the encoder and the discriminator independently.
optimizerE = torch.optim.Adam(params=encoder.parameters(), lr=1e-3)
optimizerD = torch.optim.Adam(params=discriminator.parameters(), lr=1e-3)

# Training loop
# Each iteration does (1) one discriminator step, then (2) one encoder step on
# reconstruction loss + adversarial regularization. Both steps reuse the same
# embedding Z computed once per iteration.
model.train()
print('start training')
for i in range(200): 
    # How ARGA is used differs slightly between papers/implementations;
    # here it is used in the simplest possible form.
    if i%10==0: print(i)  # progress: print the iteration index every 10 iterations

    # Clear the gradients of both networks. Clearing the discriminator's grads
    # here also discards anything the previous iteration's encoder backward
    # (presumably via reg_loss, which scores Z with the discriminator) wrote
    # into them, so only the discriminator loss drives optimizerD.step().
    optimizerE.zero_grad() 
    optimizerD.zero_grad()

    # Embed every node using only the training (positive) edges.
    Z = encoder( data.x, data.train_pos_edge_index )

    # NOTE(review): PyG's ARGA.discriminator_loss is expected to detach Z
    # internally, so this backward should only produce discriminator gradients
    # and leave Z's encoder graph intact for the second backward below --
    # confirm for the installed PyG version.
    loss = model.discriminator_loss( Z )
    loss.backward()
    optimizerD.step()            ### train the discriminator

    # Encoder objective: edge reconstruction on the training edges
    # (+ adversarial regularization from the just-updated discriminator).
    loss = model.recon_loss( Z, data.train_pos_edge_index )
    loss += model.reg_loss( Z )  ### commenting out this line drops the discriminator-related loss -> equivalent to a plain GAE
    loss.backward()
    optimizerE.step()

# Visualize Embedded Vectors
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
# The encoder outputs 2-D embeddings (Encoder(1, 8, 2)), so they are plotted
# directly; no dimensionality reduction is applied.
encoder.eval()
with torch.no_grad():  # inference only: don't build an autograd graph
    Z = encoder(data.x, data.train_pos_edge_index).cpu().numpy()
c = data.y.cpu().numpy()  # 1 = leaf node (degree 1), 0 = internal node (not a cluster id)

fig, ax = plt.subplots()
ax.scatter(Z[c == 0, 0], Z[c == 0, 1], c='blue', s=1, label='internal node')
ax.scatter(Z[c == 1, 0], Z[c == 1, 1], c='red', s=1, label='leaf')
ax.legend(markerscale=8)  # enlarge the tiny s=1 markers in the legend only
fig.savefig('embedvec.png', dpi=300)
plt.close(fig)  # release the figure's memory once it has been written
