import os.path as osp
import os
import torch
from torch_geometric.nn import MetaPath2Vec
import torch_geometric
from torch_geometric.data import Data

# Node types visited, in order, along one metapath walk. Each consecutive pair
# becomes one ('source node type', 'edge type', 'destination node type') step.
# Every step shares the edge type '-' here; the middle element can be set to
# different labels to distinguish relations.
_NODE_TYPE_WALK = ('actor', 'movie', 'direc', 'movie', 'actor')
metapath = [
    (src_type, '-', dst_type)
    for src_type, dst_type in zip(_NODE_TYPE_WALK, _NODE_TYPE_WALK[1:])
]

# Running on a GPU under Windows 10 may raise errors, so the CPU is used.
device = 'cpu'

# Containers filled while parsing the movie data file.
# Edge maps each metapath step to [source ids, destination ids].
Edge = {mp: [[], []] for mp in metapath}
edge_index_dict = {}
num_nodes_dict = {}
y_dict = []        # genre label of every parsed row
y_index_dict = []  # movie id of every parsed row, aligned with y_dict

# Distinct ids seen for each node type.
A = set()  # actors
D = set()  # directors
M = set()  # movies

# Aliases for the edge lists, in the order the steps appear in `metapath`.
actor_movie_src, actor_movie_dst = Edge[metapath[0]]
movie_direc_src, movie_direc_dst = Edge[metapath[1]]
direc_movie_src, direc_movie_dst = Edge[metapath[2]]
movie_actor_src, movie_actor_dst = Edge[metapath[3]]

# The context manager closes the file even when a row fails to parse.
with open('movie_data2.txt', 'r', encoding='utf-8') as rf:
    for line_no, raw_line in enumerate(rf, start=1):
        # Row layout: movie, director, three actors, genre label (0-7).
        values = [int(token) for token in raw_line.split()]
        if not values:
            # Tolerate blank lines such as a trailing newline at end of file.
            continue
        if len(values) < 6:
            raise ValueError(
                f'movie_data2.txt line {line_no}: expected at least 6 '
                f'integer fields, got {len(values)}'
            )

        movie = values[0]
        direc = values[1]
        actor = values[2:5]
        genre = values[5]

        M.add(movie)
        D.add(direc)

        # Store the movie-director relation in both directions.
        movie_direc_src.append(movie)
        movie_direc_dst.append(direc)
        direc_movie_src.append(direc)
        direc_movie_dst.append(movie)

        # Store each actor-movie relation in both directions.
        for a in actor:
            A.add(a)
            actor_movie_src.append(a)
            actor_movie_dst.append(movie)
            movie_actor_src.append(movie)
            movie_actor_dst.append(a)

        y_dict.append(genre)
        y_index_dict.append(movie)

# Convert the parsed label and id lists to tensors, keyed by node type.
movie_labels = torch.tensor(y_dict)
movie_ids = torch.tensor(y_index_dict)
y_dict = {'movie': movie_labels}
y_index_dict = {'movie': movie_ids}

# One tensor per metapath step, built from its [sources, destinations] lists.
edge_index_dict.update(
    {relation: torch.tensor(endpoints) for relation, endpoints in Edge.items()}
)

# Node counts are the number of distinct ids seen for each type.
# NOTE(review): this matches max id + 1 only if ids are contiguous from 0;
# confirm against the data file before relying on these counts.
num_nodes_dict.update(actor=len(A), movie=len(M), direc=len(D))

# Attach everything to a single Data container.
data = Data()
for attr_name, attr_value in (
    ('edge_index_dict', edge_index_dict),
    ('num_nodes_dict', num_nodes_dict),
    ('y_dict', y_dict),
    ('y_index_dict', y_index_dict),
):
    setattr(data, attr_name, attr_value)

# Keyword arguments passed to MetaPath2Vec alongside the edge index tensors.
metapath2vec_options = dict(
    embedding_dim=32,
    metapath=metapath,
    walk_length=30,
    context_size=5,
    walks_per_node=5,
    num_negative_samples=3,
    sparse=True,
)
model = MetaPath2Vec(data.edge_index_dict, **metapath2vec_options).to(device)

# Batches are produced by the model's own loader.
loader = model.loader(batch_size=20, shuffle=True, num_workers=0)

# SparseAdam is the optimizer used together with the model's sparse=True option.
optimizer = torch.optim.SparseAdam(list(model.parameters()), lr=0.01)

def train(epoch, log_steps=100, eval_steps=2000):
    """Run one pass over ``loader``, printing loss and periodic accuracy.

    Args:
        epoch: Epoch number; used only in the printed log lines.
        log_steps: Every ``log_steps`` batches, print the mean loss of the
            batches seen since the previous loss report.
        eval_steps: Every ``eval_steps`` batches, call ``test()`` and print
            the value it returns.
    """
    model.train()

    total_loss = 0
    for step, (pos_rw, neg_rw) in enumerate(loader, start=1):
        optimizer.zero_grad()
        loss = model.loss(pos_rw.to(device), neg_rw.to(device))
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        if step % log_steps == 0:
            print((f'Epoch: {epoch}, Step: {step:05d}/{len(loader)}, '
                   f'Loss: {total_loss / log_steps:.4f}'))
            total_loss = 0

        if step % eval_steps == 0:
            acc = test()
            # test() puts the model in eval mode; switch back so the rest of
            # the epoch keeps running in training mode.
            model.train()
            print((f'Epoch: {epoch}, Step: {step:05d}/{len(loader)}, '
                   f'Acc: {acc:.4f}'))


@torch.no_grad()
def test(train_ratio=0.9):
    """Evaluate the movie embeddings on a random train/test split.

    Args:
        train_ratio: Fraction of the labelled movie rows used for training;
            the remaining rows are used for testing.

    Returns:
        The value returned by ``model.test`` for the split (callers print it
        as an accuracy).
    """
    model.eval()

    # Select the nodes to predict. The index tensor is moved to `device` so
    # that it sits on the same device as the model for the embedding lookup.
    z = model('movie', batch=data.y_index_dict['movie'].to(device))
    y = data.y_dict['movie']

    # Shuffle the row positions and cut them once at the train/test boundary.
    perm = torch.randperm(z.size(0))
    split = int(z.size(0) * train_ratio)
    train_perm = perm[:split]
    test_perm = perm[split:]

    return model.test(z[train_perm], y[train_perm], z[test_perm], y[test_perm],
                      max_iter=150)


# Train for a fixed number of epochs, reporting the test result after each.
num_epochs = 10
for epoch in range(1, num_epochs + 1):
    train(epoch)
    acc = test()
    print(f'Epoch: {epoch}, Accuracy: {acc:.4f}')