# -*- coding: utf-8 -*-
"""Transfomer_old.ipynb

Automatically generated by Colab.

Original file is located at
    https://colab.research.google.com/drive/1_JdluzBkJszICa3BricJP-WhiVib6VtN
"""

import numpy as np
import matplotlib.pyplot as plt
import re

import time
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam, lr_scheduler
import logging

torch.manual_seed(1)

# import urllib.request
# import pandas as pd

# def get_text_data():
# #   urllib.request.urlretrieve("https://raw.githubusercontent.com/songys/Chatbot_data/master/ChatbotData.csv", filename="ChatBotData.csv")
# #   train_data = pd.read_csv('ChatBotData.csv')
# #   # check for unwanted values such as NULLs
#   print(train_data.isnull().sum())
#   return train_data

# !pip install soynlp

import os
import torch
from torch.utils.data import DataLoader
from torchtext.vocab import build_vocab_from_iterator
# import torchtext.transforms as T


# import pickle
# import torch
# from torchtext.data.metrics import bleu_score

# def save_pkl(data, fname):
#     with open(fname, "wb") as f:
#         pickle.dump(data, f)

# def load_pkl(fname):
#     with open(fname, "rb") as f:
#         data = pickle.load(f)
#     return data

# def get_bleu_score(output, gt, vocab, specials, max_n=4):

#     def itos(x):
#         x = list(x.cpu().numpy())
#         tokens = vocab.lookup_tokens(x)
#         tokens = list(filter(lambda x: x not in {"", " ", "."} and x not in list(specials.keys()), tokens))
#         return tokens

#     pred = [out.max(dim=1)[1] for out in output]
#     pred_str = list(map(itos, pred))
#     gt_str = list(map(lambda x: [itos(x)], gt))

#     score = bleu_score(pred_str, gt_str, max_n=max_n) * 100
#     return  score

# def greedy_decode(model, src, max_len, start_symbol, end_symbol):
#     src = src.to(model.device)
#     src_mask = model.make_src_mask(src).to(model.device)
#     memory = model.encode(src, src_mask)

#     ys = torch.ones(1, 1).fill_(start_symbol).type(torch.long).to(model.device)
#     for i in range(max_len-1):
#         memory = memory.to(model.device)
#         tgt_mask = model.make_tgt_mask(ys).to(model.device)
#         src_tgt_mask = model.make_src_tgt_mask(src, ys).to(model.device)
#         out = model.decode(ys, memory, tgt_mask, src_tgt_mask)
#         prob = model.generator(out[:, -1])
#         _, next_word = torch.max(prob, dim=1)
#         next_word = next_word.item()

#         ys = torch.cat([ys,
#                         torch.ones(1, 1).type_as(src.data).fill_(next_word)], dim=1)
#         if next_word == end_symbol:
#             break
#     return

# class Multi30k():

#     def __init__(self,
#                  lang=("en", "de"),
#                  max_seq_len=256,
#                  unk_idx=0,
#                  pad_idx=1,
#                  sos_idx=2,
#                  eos_idx=3,
#                  vocab_min_freq=2):

#         self.dataset_name = "multi30k"
#         self.lang_src, self.lang_tgt = lang
#         self.max_seq_len = max_seq_len
#         self.unk_idx = unk_idx
#         self.pad_idx = pad_idx
#         self.sos_idx = sos_idx
#         self.eos_idx = eos_idx
#         self.unk = "<unk>"
#         self.pad = "<pad>"
#         self.sos = "<sos>"
#         self.eos = "<eos>"
#         self.specials={
#                 self.unk: self.unk_idx,
#                 self.pad: self.pad_idx,
#                 self.sos: self.sos_idx,
#                 self.eos: self.eos_idx
#                 }
#         self.vocab_min_freq = vocab_min_freq

#         self.tokenizer_src = self.build_tokenizer(self.lang_src)
#         self.tokenizer_tgt = self.build_tokenizer(self.lang_tgt)

#         self.train = None
#         self.valid = None
#         self.test = None
#         self.build_dataset()

#         self.vocab_src = None
#         self.vocab_tgt = None
#         self.build_vocab()

#         self.transform_src = None
#         self.transform_tgt = None
#         self.build_transform()

#     def build_dataset(self, raw_dir="raw", cache_dir=".data"):
#         cache_dir = os.path.join(cache_dir, self.dataset_name)
#         raw_dir = os.path.join(cache_dir, raw_dir)
#         os.makedirs(raw_dir, exist_ok=True)

#         train_file = os.path.join(cache_dir, "train.pkl")
#         valid_file = os.path.join(cache_dir, "valid.pkl")
#         test_file = os.path.join(cache_dir, "test.pkl")

#         if os.path.exists(train_file):
#             self.train = load_pkl(train_file)
#         else:
#             with open(os.path.join(raw_dir, "train.en"), "r") as f:
#                 train_en = [text.rstrip() for text in f]
#             with open(os.path.join(raw_dir, "train.de"), "r") as f:
#                 train_de = [text.rstrip() for text in f]
#             self.train = [(en, de) for en, de in zip(train_en, train_de)]
#             save_pkl(self.train , train_file)

#         if os.path.exists(valid_file):
#             self.valid = load_pkl(valid_file)
#         else:
#             with open(os.path.join(raw_dir, "val.en"), "r") as f:
#                 valid_en = [text.rstrip() for text in f]
#             with open(os.path.join(raw_dir, "val.de"), "r") as f:
#                 valid_de = [text.rstrip() for text in f]
#             self.valid = [(en, de) for en, de in zip(valid_en, valid_de)]
#             save_pkl(self.valid, valid_file)

#         if os.path.exists(test_file):
#             self.test = load_pkl(test_file)
#         else:
#             with open(os.path.join(raw_dir, "test_2016_flickr.en"), "r") as f:
#                 test_en = [text.rstrip() for text in f]
#             with open(os.path.join(raw_dir, "test_2016_flickr.de"), "r") as f:
#                 test_de = [text.rstrip() for text in f]
#             self.test = [(en, de) for en, de in zip(test_en, test_de)]
#             save_pkl(self.test, test_file)

#     def build_vocab(self, cache_dir=".data"):
#         assert self.train is not None
#         def yield_tokens(is_src=True):
#             for text_pair in self.train:
#                 if is_src:
#                     yield [str(token) for token in self.tokenizer_src(text_pair[0])]
#                 else:
#                     yield [str(token) for token in self.tokenizer_tgt(text_pair[1])]

#         cache_dir = os.path.join(cache_dir, self.dataset_name)
#         os.makedirs(cache_dir, exist_ok=True)

#         vocab_src_file = os.path.join(cache_dir, f"vocab_{self.lang_src}.pkl")
#         if os.path.exists(vocab_src_file):
#             vocab_src = load_pkl(vocab_src_file)
#         else:
#             vocab_src = build_vocab_from_iterator(yield_tokens(is_src=True), min_freq=self.vocab_min_freq, specials=self.specials.keys())
#             vocab_src.set_default_index(self.unk_idx)
#             save_pkl(vocab_src, vocab_src_file)

#         vocab_tgt_file = os.path.join(cache_dir, f"vocab_{self.lang_tgt}.pkl")
#         if os.path.exists(vocab_tgt_file):
#             vocab_tgt = load_pkl(vocab_tgt_file)
#         else:
#             vocab_tgt = build_vocab_from_iterator(yield_tokens(is_src=False), min_freq=self.vocab_min_freq, specials=self.specials.keys())
#             vocab_tgt.set_default_index(self.unk_idx)
#             save_pkl(vocab_tgt, vocab_tgt_file)

#         self.vocab_src = vocab_src
#         self.vocab_tgt = vocab_tgt

#     def build_tokenizer(self, lang):
#         from torchtext.data.utils import get_tokenizer
#         spacy_lang_dict = {
#                 'en': "en_core_web_sm",
#                 'de': "de_core_news_sm"
#                 }
#         assert lang in spacy_lang_dict.keys()
#         return get_tokenizer("spacy", spacy_lang_dict[lang])

#     def build_transform(self):
#         def get_transform(self, vocab):
#             return T.Sequential(
#                     T.VocabTransform(vocab),
#                     T.Truncate(self.max_seq_len-2),
#                     T.AddToken(token=self.sos_idx, begin=True),
#                     T.AddToken(token=self.eos_idx, begin=False),
#                     T.ToTensor(padding_value=self.pad_idx))

#         self.transform_src = get_transform(self, self.vocab_src)
#         self.transform_tgt = get_transform(self, self.vocab_tgt)

#     def collate_fn(self, pairs):
#         src = [self.tokenizer_src(pair[0]) for pair in pairs]
#         tgt = [self.tokenizer_tgt(pair[1]) for pair in pairs]
#         batch_src = self.transform_src(src)
#         batch_tgt = self.transform_tgt(tgt)
#         return (batch_src, batch_tgt)

#     def get_iter(self, **kwargs):
#         if self.transform_src is None:
#             self.build_transform()
#         train_iter = DataLoader(self.train, collate_fn=self.collate_fn, **kwargs)
#         valid_iter = DataLoader(self.valid, collate_fn=self.collate_fn, **kwargs)
#         test_iter = DataLoader(self.test, collate_fn=self.collate_fn, **kwargs)
#         return train_iter, valid_iter, test_iter

#     def translate(self, model, src_sentence: str, decode_func):
#         model.eval()
#         src = self.transform_src([self.tokenizer_src(src_sentence)]).view(1, -1)
#         num_tokens = src.shape[1]
#         tgt_tokens = decode_func(model,
#                                  src,
#                                  max_len=num_tokens+5,
#                                  start_symbol=self.sos_idx,
#                                  end_symbol=self.eos_idx).flatten().cpu().numpy()
#         tgt_sentence = " ".join(self.vocab_tgt.lookup_tokens(tgt_tokens))
#         return tgt_sentence

from soynlp.utils import DoublespaceLineCorpus
from soynlp.noun import LRNounExtractor_v2
from soynlp.tokenizer import LTokenizer

def preprocess_sentence_func(sentence):
  # add whitespace around punctuation so each mark becomes its own token for embedding
  sentence = re.sub(r"([?.!,@,#,$,*])", r" \1 ", sentence)
  sentence = sentence.strip()
  return sentence
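
# Illustrative sanity check (not part of the original pipeline): punctuation is
# separated into its own whitespace-delimited token.
print(preprocess_sentence_func("hello, world!"))  # -> "hello ,  world !"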

# merge the Q and A sentence columns of the csv data into one structure
def integete_data_sets(train_data, preprocess_function):
  sentences = []
  a_sentence = []
  for sentence in train_data['A']:
    a_sentence.append(preprocess_function(sentence))
    sentences.append(preprocess_function(sentence))

  q_sentence = []
  for sentence in train_data['Q']:
    q_sentence.append(preprocess_function(sentence))
    sentences.append(preprocess_function(sentence))
  return {"A":a_sentence, "Q":q_sentence, "Tot": sentences}

# take the sentences as a list, build a vocab, and produce soynlp-style cohesion scores
def soylnp_tokenizer(sentences):
  noun_extractor = LRNounExtractor_v2(verbose=True)
  nouns = noun_extractor.train_extract(sentences)

  cohesion_score = {word:score.score  for word, score in nouns.items()}
  tokenizer = LTokenizer(scores=cohesion_score)

  tot_sent = ""
  for sentence in sentences:
    words = tokenizer.tokenize(sentence)
    for word in words:
      tot_sent += word + " "
  # ['데이터마이닝', '을', '공부', '중이다']

  # build the vocabulary as a set of words with duplicates removed.
  word_set = set(tot_sent.split())

  # map each word in the vocabulary to a unique integer.
  vocab = {word: i+2 for i, word in enumerate(word_set)}
  vocab['<unk>'] = 0
  vocab['<pad>'] = 1

  return (vocab, cohesion_score)

# convert a tokenized sentence (a list of words) into its embedding vectors
def tokenize_sentence_to_embedding(word_arr, vocab, embedding_layer, for_decode=False):

  idxs = []
  for word in word_arr:
    try:
      idxs.append(vocab[word])
    except KeyError:
      # words not in the vocabulary are replaced with <unk>
      idxs.append(vocab['<unk>'])
  if for_decode:
    idxs.insert(0, vocab['<pad>'])
  idxs = torch.LongTensor(idxs)
  lookup_result = embedding_layer.weight[idxs, :]

  return lookup_result, idxs

# pad/truncate an embedding matrix to a fixed length the model can consume
def standardization_embedding_vector(embedding_vector, pad_tensor, dim=128):

  result = torch.FloatTensor([])
  # output = torch.stack(out_list, 0)
  for i in range(dim):
    if len(embedding_vector) > i :
      result = torch.cat([result, embedding_vector[i]], dim=0)
    else:
      result = torch.cat([result, pad_tensor], dim=0)

  result = torch.reshape(result, (dim, len(pad_tensor)))
  return result
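
# Illustrative check (toy sizes, not from the original notebook): an embedding
# matrix with fewer than `dim` rows is padded with copies of pad_tensor.
_pad_vec = torch.zeros(4)
print(standardization_embedding_vector(torch.randn(3, 4), _pad_vec, dim=6).shape)  # torch.Size([6, 4])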

# build an embedding table (lookup layer) from the integer-encoded vocabulary
def create_embedding_table_pytorch(vocab, dim=512):
  embedding_layer = nn.Embedding(num_embeddings=len(vocab),
                               embedding_dim=dim,
                               padding_idx=1)
  return embedding_layer
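
# Illustrative usage of the embedding helpers above with a hand-built toy vocab
# (assumption: tiny sizes purely for demonstration).
_toy_vocab = {'<unk>': 0, '<pad>': 1, 'hello': 2, 'world': 3}
_toy_layer = create_embedding_table_pytorch(_toy_vocab, dim=8)
_toy_emb, _toy_idxs = tokenize_sentence_to_embedding(['hello', 'world', 'unseen'], _toy_vocab, _toy_layer)
print(_toy_idxs)        # tensor([2, 3, 0])  -- 'unseen' falls back to <unk>
print(_toy_emb.shape)   # torch.Size([3, 8])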

class PositionalEncoding(nn.Module):

    def __init__(self, d_model, max_len, device):
        """
        sin/cos positional encoding

        parameters
        - d_model : model dimension
        - max_len : maximum sequence length
        - device : cuda or cpu
        """

        super(PositionalEncoding, self).__init__() # initialize nn.Module

        # create a tensor with the same size as the input matrix
        # (the embedding vectors in NLP), i.e. (max_len, d_model)
        self.encoding = torch.zeros(max_len, d_model, device=device)
        self.encoding.requires_grad = False # the encoding needs no gradient

        # position-indexing vector
        # pos indexes the positions up to max_len
        pos = torch.arange(0, max_len, device=device)
        # 1D (max_len,) -> 2D (max_len, 1) so the word position can broadcast

        pos = pos.float().unsqueeze(dim=1) # int64 -> float32 (optional)

        # i indexes the embedding dimension; _2i has size (d_model/2,)
        # e.g. with embedding size 512, 2i runs over the even indices 0..510
        _2i = torch.arange(0, d_model, step=2, device=device).float()

        # (max_len, 1) / (d_model/2,) -> (max_len, d_model/2)
        self.encoding[:, ::2] = torch.sin(pos / (10000 ** (_2i / d_model)))
        self.encoding[:, 1::2] = torch.cos(pos / (10000 ** (_2i / d_model)))

    def forward(self, x):
        # self.encoding: [max_len, d_model]

        # e.g. batch_size = 128, seq_len = 30
        batch_size, seq_len = x.size()

        # returns [seq_len, d_model]; this slice is added to the
        # [batch_size, seq_len, d_model] token embedding
        return self.encoding[:seq_len, :]
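
# Illustrative check (assumption: CPU device, toy sizes): the table has shape
# (max_len, d_model) and forward() slices it to the input sequence length.
_pe = PositionalEncoding(d_model=16, max_len=10, device='cpu')
_dummy_ids = torch.zeros(2, 7, dtype=torch.long)   # (batch_size, seq_len)
print(_pe(_dummy_ids).shape)                       # torch.Size([7, 16])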

def tensor_plot_graph(some_tensor):
  plt.pcolormesh(some_tensor.numpy(), cmap='RdBu')
  plt.xlabel('Depth')
  plt.ylabel('Position')
  plt.colorbar()
  plt.show()

def graphical_works(data):

  for datom in data:
    if isinstance(datom, PositionalEncoding):
      tensor_plot_graph(datom.encoding)
    elif isinstance(datom, torch.Tensor):
      tensor_plot_graph(datom.detach())
    else:
      print("typeError:", type(datom), datom)

import numpy as np
import math
def lookup_head_mask(arr, pad):
  size = len(arr)
  mask_arr = torch.zeros(size, size)
  for i in range(size):
    for j in range(size):
      if i < j:
        mask_arr[i][j] = 1

  for i in range(size):
    if arr[i] == pad:
      for j in range(size):
        mask_arr[j][i] = 1

  return mask_arr
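
# Illustrative check: the combined look-ahead / padding mask marks future
# positions (upper triangle) and <pad> columns with 1.
print(lookup_head_mask([5, 7, 1, 1], pad=1))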

def scaled_dot_product_attention(Q, K, V, d_k, mask=None):
  # attention(Q, K, V) = softmax(QK^T / sqrt(d_k)) V
  # print("QK", Q.size(), K.size())
  mul_step = torch.matmul(Q, K.t())
  scale_step = torch.div(mul_step, math.sqrt(d_k))
  if mask is not None:
    # masked positions (mask == 1) are pushed towards -inf before the softmax
    scale_step = scale_step.masked_fill(mask.bool(), -1e9)
  attention_weights = F.softmax(scale_step, dim=1)

  # print("weights V", attention_weights.size(), V.size())
  result = torch.matmul(attention_weights, V)

  return result
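
# Illustrative check (toy tensors): the attention output has the same shape as V.
_q, _k, _v = torch.randn(4, 8), torch.randn(4, 8), torch.randn(4, 8)
print(scaled_dot_product_attention(_q, _k, _v, d_k=8).shape)  # torch.Size([4, 8])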

def multi_head_attention_concat(heads):
  # concatenate the per-head outputs along the feature dimension
  output = heads[0]
  for head in heads[1:]:
    output = torch.cat((output, head), dim=1)

  return output

def multi_head_attention_i(embedding_layer, W_V, K, W_K, Q, W_Q):
  # unfinished draft of a single attention head; superseded by the MultiHeadAttention class below
  d_model = embedding_layer.embedding_dim
  return nn.Linear(d_model, d_model), nn.Linear(d_model, d_model), nn.Linear(d_model, d_model)

def copy_tensors_return_as_pylist(embedding_layer, h_dim):
  copy_and_paste_tensor_as_h_dim_nums = []
  for i in range(h_dim):
    copy_and_paste_tensor_as_h_dim_nums.append(embedding_layer.clone().detach())
  return copy_and_paste_tensor_as_h_dim_nums

def multi_head_attention(embedding_layer, d_model, W_O, scaled_dot_product_attention, h_dim=8):
  Q_s = copy_tensors_return_as_pylist(embedding_layer,h_dim)
  K_s = copy_tensors_return_as_pylist(embedding_layer,h_dim)
  V_s = copy_tensors_return_as_pylist(embedding_layer,h_dim)

  W_init = torch.randn(h_dim, len(embedding_layer), len(embedding_layer[0]))
  #W_O =  torch.tensor.randn(h_dim*len( len(embedding_layer), embedding_layer[0]))

  W_Q_s = copy_tensors_return_as_pylist(W_init, h_dim)
  W_K_s = copy_tensors_return_as_pylist(W_init, h_dim)
  W_V_s = copy_tensors_return_as_pylist(W_init, h_dim)

  heads = list()

  for i in range(h_dim):
    heads.append(scaled_dot_product_attention(Q_s[i],K_s[i], V_s[i], d_k=(d_model/h_dim)))
  result = torch.matmul(multi_head_attention_concat(heads), W_O.t())
  return result

def res_net_and_layer_normalization(layer, sub_layer, d_model=512):
  # residual connection followed by layer normalization
  la = sub_layer + layer
  lay = nn.LayerNorm(d_model)

  return lay(la)

def standardization_embedding_idxs(embedding_idxs, pad_tensor, dim=128):
  # copy the token indices over and fill the remaining positions with the pad index
  result = torch.full((dim,), float(pad_tensor))
  for i in range(min(dim, len(embedding_idxs))):
    result[i] = embedding_idxs[i]

  return result
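
# Illustrative check: indices are copied and the remaining slots are filled with the pad index.
print(standardization_embedding_idxs([4, 9, 2], pad_tensor=1, dim=6))  # tensor([4., 9., 2., 1., 1., 1.])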

class MultiHeadAttention(nn.Module):
  def __init__(self, scaled_dot_product_attention, d_model=512, embedding_langth=128, d_k = 64):
    super(MultiHeadAttention, self).__init__()
    self.d_k = int(d_k)
    self.sdpa = scaled_dot_product_attention

    self.L_Q = nn.Linear(d_model, self.d_k, bias=False )
    self.L_K = nn.Linear(d_model, self.d_k, bias=False )
    self.L_V = nn.Linear(d_model, self.d_k, bias=False )

  def forward(self, Q, K, V, mask = None):

     out = self.sdpa(self.L_Q(Q), self.L_K(K), self.L_V(V), self.d_k, mask)
     return out

class MultiHeadsAttention(nn.Module):

    def __init__(self, scaled_dot_product_attention, d_model=512, embedding_length=128, h_dim=8):
      super(MultiHeadsAttention, self).__init__()

      self.d_model=d_model
      self.embedding_length=embedding_length
      self.h_dim=h_dim

      # register the heads as submodules (nn.ModuleList) so their parameters are trained
      self.heads_model = nn.ModuleList()
      for head in range(h_dim):
        self.heads_model.append(MultiHeadAttention(scaled_dot_product_attention, d_model, embedding_length, (d_model/h_dim)))

    def multi_head_attention_concat(self):
      # concatenate the per-head outputs along the feature dimension
      output = self.heads[0]
      for head in self.heads[1:]:
        output = torch.cat((output, head), dim=1)

      return output

    def forward(self, Q, K ,V, mask = None):
      self.heads = list()
      self.mask = mask
      # potential source of errors: the exact handling of the per-head weights is not fully worked out
      for head in range(self.h_dim):
        self.heads.append(self.heads_model[head](Q, K, V, self.mask))
      out = self.multi_head_attention_concat()
      return out
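
# Illustrative check (assumption: toy 128x512 input): with 8 heads of width
# d_model/8, the concatenated multi-head output recovers (seq_len, d_model).
_x = torch.randn(128, 512)
_mha = MultiHeadsAttention(scaled_dot_product_attention, d_model=512, embedding_length=128, h_dim=8)
print(_mha(_x, _x, _x).shape)   # torch.Size([128, 512])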

# class MaskedMultiHeadAttention(nn.Module):
#   def __init__(self, scaled_dot_product_attention, d_model=512, embedding_langth=128, d_k = 64):
#     super(MaskedMultiHeadAttention, self).__init__()
#     self.d_k = int(d_k)
#     self.sdpa = scaled_dot_product_attention
#     self.L_Q = nn.Linear(d_model, self.d_k, bias=False )
#     self.L_K = nn.Linear(d_model, self.d_k, bias=False )
#     self.L_V = nn.Linear(d_model, self.d_k, bias=False )

#   def forward(self, Q, K, V):
#      out = self.sdpa(self.L_Q(Q), self.L_V(V), self.L_K(K), self.d_k)
#      return out

# class MaskedMultiHeadsAttention(nn.Module):

#     def __init__(self, scaled_dot_product_attention, d_model=512, embedding_length=128, h_dim=8, mask=None):
#       super(MaskedMultiHeadsAttention, self).__init__()

#       self.d_model=d_model
#       self.embedding_length=embedding_length
#       self.h_dim=h_dim
#       self.heads_model = list()
#       for head in range(h_dim):
#         self.heads_model.append(MultiHeadAttention(scaled_dot_product_attention, d_model, embedding_length, (d_model/h_dim)))

#     def multi_head_attention_concat(self):
#       idx = 0
#       output = torch.tensor(self.heads[0])
#       for head in self.heads:
#         if(idx != 0):
#           output = torch.cat((output, head), dim=1)
#         else:
#           idx += 1

#       return output

#     def forward(self, Q, K ,V):
#       self.heads= list()
#       # potential source of errors: the exact handling of the per-head weights is not fully worked out
#       for head in range(self.h_dim):
#         self.heads.append(self.heads_model[head](Q, K ,V))
#       out = self.multi_head_attention_concat()
#       return out

class ResNetLayerNormalization(nn.Module):
    def __init__(self, d_model=512, embedding_langth=128):
      super(ResNetLayerNormalization, self).__init__()
      self.LN = nn.LayerNorm(d_model)

    def forward(self, sub_model, x):
      out = x + sub_model
      out = self.LN (out)
      return out

class FeedForwardWiseLayer(nn.Module):
    def __init__(self, d_embed=512, d_ff= 2048):
        super(FeedForwardWiseLayer, self).__init__()
        self.fc1 = nn.Linear(d_embed, d_ff)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(d_ff, d_embed)

    def forward(self, x):
        out = x
        out = self.fc1(out)
        out = self.relu(out)
        out = self.fc2(out)
        return out
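
# Illustrative check: the position-wise feed-forward block maps (seq_len, d_embed)
# through an inner d_ff expansion and back to the same shape.
print(FeedForwardWiseLayer(d_embed=512, d_ff=2048)(torch.randn(128, 512)).shape)  # torch.Size([128, 512])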

class EncoderStack(nn.Module):
  def __init__(self, d_model, embedding_vector_dim, h_dim, d_ff, n_layer = 6):
    super(EncoderStack, self).__init__()
    self.d_model = d_model
    self.embedding_vector_dim = embedding_vector_dim
    self.h_dim = h_dim
    self.d_ff = d_ff
    self.n_layer = n_layer
    self.layers = nn.ModuleList()   # ModuleList so the per-layer parameters are registered
    self.recoder = list()
    self.m = nn.Dropout(p=0.2)

    for i in range(n_layer):
      self.layers.append(self.encoder_layer())

  def encoder_layer(self):
    self.MA = MultiHeadsAttention(scaled_dot_product_attention, d_model=self.d_model, embedding_length=self.embedding_vector_dim, h_dim=self.h_dim)
    self.RN1 = ResNetLayerNormalization(self.d_model, self.d_model)
    self.FFWL = FeedForwardWiseLayer(self.d_model, self.d_ff)
    self.RN2 = ResNetLayerNormalization(self.d_model, self.d_ff)

    return nn.ModuleDict({"MA": self.MA, "RN1": self.RN1, "FFWL": self.FFWL, "RN2": self.RN2})

  def get_record(self):
    return self.recoder

  def forward(self, x):
    for i in range(self.n_layer):
      out = self.m(self.layers[i]["MA"](x, x, x))    # self-attention + dropout
      self.recoder.append(out)
      out2 = self.layers[i]["RN1"](out, x)           # residual connection + layer norm
      self.recoder.append(out2)
      out3 = self.m(self.layers[i]["FFWL"](out2))    # position-wise feed forward + dropout
      out = self.layers[i]["RN2"](out3, out2)        # residual connection + layer norm
      x = out
      self.recoder.append(x)
    return out
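
# Illustrative check (2 layers instead of 6 to keep it light): the encoder
# stack preserves the (seq_len, d_model) shape of its input.
_enc = EncoderStack(d_model=512, embedding_vector_dim=128, h_dim=8, d_ff=2048, n_layer=2)
print(_enc(torch.randn(128, 512)).shape)   # torch.Size([128, 512])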

class DecoderStack(nn.Module):
  def __init__(self, d_model, embedding_vector_dim, h_dim, d_ff, n_layer = 6):
    super(DecoderStack, self).__init__()
    self.d_model = d_model
    self.embedding_vector_dim = embedding_vector_dim
    self.h_dim = h_dim
    self.d_ff = d_ff
    self.n_layer = n_layer
    self.layers = nn.ModuleList()   # ModuleList so the per-layer parameters are registered
    self.recoder = list()
    self.m = nn.Dropout(p=0.2)

    for i in range(n_layer):
      self.layers.append(self.decoder_layer())

  def decoder_layer(self):
    self.MMA = MultiHeadsAttention(scaled_dot_product_attention, d_model=self.d_model, embedding_length=self.embedding_vector_dim, h_dim=self.h_dim)
    self.RN1 = ResNetLayerNormalization(self.d_model, self.d_model)
    self.MA = MultiHeadsAttention(scaled_dot_product_attention, d_model=self.d_model, embedding_length=self.embedding_vector_dim, h_dim=self.h_dim)
    self.RN2 = ResNetLayerNormalization(self.d_model, self.d_model)
    self.FFWL = FeedForwardWiseLayer(self.d_model, self.d_ff)
    self.RN3 = ResNetLayerNormalization(self.d_model, self.d_ff)

    return nn.ModuleDict({"MMA": self.MMA, "MA": self.MA, "RN1": self.RN1, "RN2": self.RN2, "FFWL": self.FFWL, "RN3": self.RN3})

  def get_record(self):
    return self.recoder

  def forward(self, x, encoder_embedding, mask=None):
    self.mask = mask
    for i in range(self.n_layer):
      out = self.m(self.layers[i]["MMA"](x, x, x, mask=self.mask))   # masked self-attention
      out1 = self.layers[i]["RN1"](out, x)
      # cross-attention: queries from the decoder, keys/values from the encoder output
      out2 = self.m(self.layers[i]["MA"](out1, encoder_embedding, encoder_embedding))
      out3 = self.layers[i]["RN2"](out2, out1)
      out4 = self.m(self.layers[i]["FFWL"](out3))
      out5 = self.layers[i]["RN3"](out4, out3)

      x = out5

      self.recoder.append(x)
    out5 = F.softmax(out5, dim=1)

    return out5
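
# Illustrative check (2 layers, toy inputs): the decoder stack combines the
# target embedding, the encoder output, and the look-ahead/padding mask and
# preserves the (seq_len, d_model) shape.
_dec = DecoderStack(d_model=512, embedding_vector_dim=128, h_dim=8, d_ff=2048, n_layer=2)
_tgt_ids = [5, 7, 9] + [1] * 125                     # toy target ids, 1 == pad
_mask = lookup_head_mask(_tgt_ids, pad=1)            # (128, 128)
print(_dec(torch.randn(128, 512), torch.randn(128, 512), mask=_mask).shape)  # torch.Size([128, 512])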

class Transfomer(nn.Module):
  def __init__(self, d_model, h_dim, d_ff, pad_idx, embedding_vector_dim, num_tokens):
    super(Transfomer, self).__init__()

    self.pad_idx = pad_idx
    self.embedding_vector_dim = embedding_vector_dim
    # positional encoding shared by the encoder and decoder inputs
    self.sample_pos_encoding = PositionalEncoding(d_model, embedding_vector_dim, device='cpu')
    self.encoder = EncoderStack(d_model, embedding_vector_dim, h_dim, d_ff)
    self.decoder = DecoderStack(d_model, embedding_vector_dim, h_dim, d_ff)
    # projects the decoder output onto the vocabulary
    self.fc = nn.Linear(d_model, num_tokens)

  def get_attention_matrix(self):
    return torch.matmul(self.encode_data, self.encode_data.t())

  def get_recoder(self):
    return self.encoder.get_record() +  self.decoder.get_record()

  def forward(self, input_embedding, output_embedding, embedding_idxs):
    self.input_embedding = input_embedding + self.sample_pos_encoding.encoding
    self.output_embedding = output_embedding + self.sample_pos_encoding.encoding
    self.encode_data = self.encoder(self.input_embedding)

    self.embedding_idxs = standardization_embedding_idxs(embedding_idxs, self.pad_idx, dim=self.embedding_vector_dim)
    self.mask_arr = lookup_head_mask(self.embedding_idxs, self.pad_idx)
    self.decode_data = self.decoder(self.output_embedding, self.encode_data, self.mask_arr)
    result = self.fc(self.decode_data)
    return result
    # self.encoder_recordes = self.encoder.get_record()
    # self.decoder_recordes = self.decoder.get_record()

class TokenEmbedding(nn.Module):

    def __init__(self, d_embed, vocab_size):
        super(TokenEmbedding, self).__init__()
        self.embedding = nn.Embedding(vocab_size, d_embed)
        self.d_embed = d_embed

    def forward(self, x):
        out = self.embedding(x) * math.sqrt(self.d_embed)
        return out
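
# Illustrative check: token ids are mapped to d_embed-dimensional vectors and
# scaled by sqrt(d_embed), as in the original Transformer paper.
_tok = TokenEmbedding(d_embed=512, vocab_size=1000)
print(_tok(torch.tensor([[1, 2, 3]])).shape)   # torch.Size([1, 3, 512])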

def train(model, data_loader, optimizer, criterion, epoch, d_embed, checkpoint_dir):
    model.train()
    epoch_loss = 0
    # torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)  # stray call: these names are not defined here
    for idx, (src, tgt) in enumerate(data_loader):
        # src = src.to(model.device)
        # tgt = tgt.to(model.device)
        tgt_x = tgt[:, :-1]
        tgt_y = tgt[:, 1:]
        token_layer = TokenEmbedding(d_embed, len(tgt_x))
        embt = token_layer(tgt_x)
        print(tgt_x[0:5], tgt_x.shape)
        print(embt[0:5], embt.shape)
        optimizer.zero_grad()

        output, _ = model(src, tgt_x)

        y_hat = output.contiguous().view(-1, output.shape[-1])
        y_gt = tgt_y.contiguous().view(-1)
        loss = criterion(y_hat, y_gt)
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        epoch_loss += loss.item()
    num_samples = idx + 1

    if checkpoint_dir:
        os.makedirs(checkpoint_dir, exist_ok=True)
        checkpoint_file = os.path.join(checkpoint_dir, f"{epoch:04d}.pt")
        torch.save({
                    'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'loss': loss
                   }, checkpoint_file)

    return epoch_loss / num_samples

def evaluate(model, data_loader, criterion):
    model.eval()
    epoch_loss = 0

    total_bleu = []
    with torch.no_grad():
        for idx, (src, tgt) in enumerate(data_loader):
            src = src.to(model.device)
            tgt = tgt.to(model.device)
            tgt_x = tgt[:, :-1]
            tgt_y = tgt[:, 1:]

            output, _ = model(src, tgt_x)

            y_hat = output.contiguous().view(-1, output.shape[-1])
            y_gt = tgt_y.contiguous().view(-1)
            loss = criterion(y_hat, y_gt)

            epoch_loss += loss.item()
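            # NOTE: get_bleu_score and DATASET refer to the commented-out Multi30k
            # helper block near the top of this notebook.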
            score = get_bleu_score(output, tgt_y, DATASET.vocab_tgt, DATASET.specials)
            total_bleu.append(score)
        num_samples = idx + 1

    loss_avr = epoch_loss / num_samples
    bleu_score = sum(total_bleu) / len(total_bleu)
    return loss_avr, bleu_score

# config
CHECKPOINT_DIR = "./checkpoint"
N_EPOCH = 1000

BATCH_SIZE = 2048
NUM_WORKERS = 8

LEARNING_RATE = 1e-5
WEIGHT_DECAY = 5e-4
ADAM_EPS = 5e-9
SCHEDULER_FACTOR = 0.9
SCHEDULER_PATIENCE = 10

WARM_UP_STEP = 100

DROPOUT_RATE = 0.1

import os, sys
from google.colab import drive
drive.mount('/content/drive')

my_path = '/content/package'
save_path = '/content/drive/MyDrive/Colab Notebooks/package' ## path where the packages are stored
try:
    os.symlink(save_path, my_path)
except OSError:
    # the symlink may already exist on re-runs
    pass

sys.path.insert(0, my_path)

# helper that converts data between raw text and integer token ids (and back)
class EncodeDecode:

  def __init__(self, src_tokenizer, tgt_tokenizer, src_ttoi, src_itot, tgt_ttoi, tgt_itot):
    self.src_tokenizer = src_tokenizer
    self.tgt_tokenizer = tgt_tokenizer
    self.src_ttoi = src_ttoi
    self.src_itot = src_itot
    self.tgt_ttoi = tgt_ttoi
    self.tgt_itot = tgt_itot

    self.unk_idx = 0
    self.pad_idx = 3

  def src_encode(self, src_text):
    return list(map(lambda x: self.src_ttoi.get(x, self.unk_idx), self.src_tokenizer(src_text)))

  def tgt_encode(self, tgt_text):
    tokens = ['<sos>'] + self.tgt_tokenizer(tgt_text) + ['<eos>']
    return list(map(lambda x: self.tgt_ttoi.get(x, self.unk_idx), tokens))

  def src_decode(self, ids):
    sentence = list(map(lambda x: self.src_itot[x], ids))
    return " ".join(sentence)

  def tgt_decode(self, ids):
    sentence = list(map(lambda x: self.tgt_itot[x], ids))[1:-1]
    return " ".join(sentence)

  def batchfiy( self, token_arr, batch_size=128, pad_idx=3):
    # pad (or truncate) a token id list to a fixed batch_size length
    batch = [ pad_idx for i in range(batch_size) ]
    for i in range(len(token_arr)):
      if i >= batch_size: break
      batch[i] = token_arr[i]

    return batch
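
# Illustrative toy usage of EncodeDecode (whitespace tokenizers and hand-built
# vocabs, purely for demonstration; the real instance below uses the spacy
# tokenizers and torchtext vocabs):
_ttoi = {'<unk>': 0, '<sos>': 1, '<eos>': 2, '<pad>': 3, 'hello': 4, 'world': 5}
_itot = {v: k for k, v in _ttoi.items()}
_toy_codec = EncodeDecode(str.split, str.split, _ttoi, _itot, _ttoi, _itot)
print(_toy_codec.src_encode("hello world"))        # [4, 5]
print(_toy_codec.tgt_encode("hello world"))        # [1, 4, 5, 2]
print(_toy_codec.batchfiy([4, 5], batch_size=8))   # [4, 5, 3, 3, 3, 3, 3, 3]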

!pip install torchdata

!python -m spacy download en_core_web_sm
!python -m spacy download de_core_news_sm

from torchtext.datasets import Multi30k
import torchtext
from torchtext.data import get_tokenizer

multi_train, multi_valid, multi_test = Multi30k(language_pair=('en','de'))
for i, (eng, de) in enumerate(multi_train):
    if i == 5: break
    print(f"index:{i}, English: {eng}, das Deutsche: {de}")

from torchtext.vocab import build_vocab_from_iterator
# load the spacy tokenizers
en_tokenizer = get_tokenizer(tokenizer='spacy', language='en_core_web_sm')
de_tokenizer = get_tokenizer(tokenizer='spacy', language='de_core_news_sm')
# the original Transformer uses BPE at this step; BERT uses WordPiece
en_vocab = build_vocab_from_iterator(map(en_tokenizer, [english for english, _ in multi_train]), min_freq=2, specials=["<unk>", "<sos>", "<eos>", "<pad>"])
de_vocab = build_vocab_from_iterator(map(de_tokenizer, [de for _ , de in multi_train]), min_freq=2, specials=["<unk>", "<sos>", "<eos>", "<pad>"])

en_token2id = en_vocab.get_stoi()
de_token2id = de_vocab.get_stoi()

en_id2token = en_vocab.get_itos()
de_id2token = de_vocab.get_itos()

pre_process = EncodeDecode(en_tokenizer, de_tokenizer, en_token2id, en_id2token,  de_token2id, de_id2token)

# download the text data and preprocess it
train_data, valid_data, test_data = multi_train, multi_valid, multi_test

# model dimension settings
d_model, embedding_vector_dim, h_dim, d_ff, pad_idx = 512, 128, 8, 2048, 1
num_tokens = len(en_token2id)
# integer-encode the data and compute word scores (needed for tokenization)
# tokenizer score setup

# d_model = 512 -> the embedding table is built with embedding vectors of length 512

model = Transfomer(d_model, h_dim, d_ff, pad_idx, embedding_vector_dim, num_tokens)
criterion = nn.CrossEntropyLoss()
optimizer = Adam(params=model.parameters(), lr=0.01)

# import numpy as np
# import matplotlib.pyplot as plt

# grpah_data = model.get_attention_matrix()
# fig, ax = plt.subplots()

# min_val, max_val = 0, 1000
# intersection_matrix = grpah_data.detach().numpy()

# ax.matshow(intersection_matrix, cmap=plt.cm.Blues)

# for i in range(len(grpah_data)):
#     for j in range(len(grpah_data)):
#         c = intersection_matrix[j,i]
#         ax.text(i, j, str(c), va='center', ha='center')

# model training (train) function
def train(model, iterator, optimizer, criterion, clip):
    model.train() # training mode
    epoch_loss = 0

    # iterate over the whole training set
    for i, batch in enumerate(iterator):
        src = batch.src
        trg = batch.trg

        optimizer.zero_grad()

        # exclude the last target token (<eos>) from the decoder input
        # so that decoding starts from <sos>
        output, _ = model(src, trg[:,:-1])

        # output: [batch size, trg_len - 1, output_dim]
        # trg: [batch size, trg_len]

        output_dim = output.shape[-1]

        output = output.contiguous().view(-1, output_dim)
        # exclude index 0 (<sos>) from the target
        trg = trg[:,1:].contiguous().view(-1)

        # output: [batch size * (trg_len - 1), output_dim]
        # trg: [batch size * (trg_len - 1)]

        # compute the loss between the model output and the target sentence
        loss = criterion(output, trg)
        loss.backward() # compute gradients

        # gradient clipping
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)

        # update parameters
        optimizer.step()

        # accumulate the loss
        epoch_loss += loss.item()

    return epoch_loss / len(iterator)

import math
import time

def epoch_time(start_time, end_time):
    elapsed_time = end_time - start_time
    elapsed_mins = int(elapsed_time / 60)
    elapsed_secs = int(elapsed_time - (elapsed_mins * 60))
    return elapsed_mins, elapsed_secs

import torch.optim as optim

# optimize training with the Adam optimizer
LEARNING_RATE = 0.0005
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

# ignore padding tokens when computing the loss
TRG_PAD_IDX = 3
criterion = nn.CrossEntropyLoss(ignore_index = TRG_PAD_IDX)

import time
import math
import random

N_EPOCHS = 10
CLIP = 1
best_valid_loss = float('inf')

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# usable much like a standard data loader iterator
# NOTE: BucketIterator comes from the legacy torchtext API (torchtext.legacy.data in
# torchtext <= 0.11) and expects legacy Dataset objects, so the raw Multi30k pairs
# above would need to be wrapped accordingly.
from torchtext.legacy.data import BucketIterator

train_iterator, valid_iterator, test_iterator = BucketIterator.splits(
    (train_data, valid_data, test_data),
    batch_size=BATCH_SIZE,
    device=device)

for epoch in range(N_EPOCHS):
    start_time = time.time() # record the start time

    train_loss = train(model, train_iterator, optimizer, criterion, CLIP)
    valid_loss, valid_bleu = evaluate(model, valid_iterator, criterion)

    end_time = time.time() # record the end time
    epoch_mins, epoch_secs = epoch_time(start_time, end_time)

    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        torch.save(model.state_dict(), 'transformer_german_to_english.pt')

    print(f'Epoch: {epoch + 1:02} | Time: {epoch_mins}m {epoch_secs}s')
    print(f'\tTrain Loss: {train_loss:.3f} | Train PPL: {math.exp(train_loss):.3f}')
    print(f'\tValidation Loss: {valid_loss:.3f} | Validation PPL: {math.exp(valid_loss):.3f}')

!python -m spacy download de_core_news_sm

!pip install datasets

!git clone --recursive https://github.com/multi30k/dataset.git multi30k-dataset

from data import Multi30k