from logging import getLogger
import torch
from torch import nn
import torch.nn.functional as F
import numpy as np
from scipy.sparse.linalg import eigs
from libcity.model import loss
from libcity.model.abstract_traffic_state_model import AbstractTrafficStateModel
def scaled_laplacian(w):
    """Compute the rescaled Laplacian 2 * L / lambda_max - I from the adjacency matrix w."""
    w = w.astype(float)
    n = np.shape(w)[0]
    d = []
    # simple graph, W_{i,i} = 0
    lap = -w
    # get degree matrix d and Laplacian matrix L
    for i in range(n):
        d.append(np.sum(w[i, :]))
        lap[i, i] = d[i]
    # symmetric normalized Laplacian L
    for i in range(n):
        for j in range(n):
            if (d[i] > 0) and (d[j] > 0):
                lap[i, j] = lap[i, j] / np.sqrt(d[i] * d[j])
    lambda_max = eigs(lap, k=1, which='LR')[0][0].real
    # lambda_max is approximately 2.0, so the eigendecomposition above
    # can be skipped by simply setting lambda_max = 2
    return 2 * lap / lambda_max - np.identity(n) 
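# A minimal sanity check, not part of the original LibCity module: for a
# 3-node path graph the rescaled Laplacian should be symmetric with all
# eigenvalues in [-1, 1].
def _demo_scaled_laplacian():
    w = np.array([[0., 1., 0.],
                  [1., 0., 1.],
                  [0., 1., 0.]])
    lap = scaled_laplacian(w)
    print(np.linalg.eigvalsh(lap))  # all values lie in [-1, 1]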
def cheb_poly(lap, ks):
    """Concatenate the first ks Chebyshev polynomials of lap along the last axis."""
    n = lap.shape[0]
    lap_list = [np.eye(n), lap[:]]
    for i in range(2, ks):
        lap_list.append(np.matmul(2 * lap, lap_list[-1]) - lap_list[-2])
    # lap_list holds ks matrices of shape (n, n); the concatenation has shape (n, ks*n)
    return np.concatenate(lap_list, axis=-1) 
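# Hypothetical shape check (not in the original file): cheb_poly returns the
# ks Chebyshev matrices T_0(L), ..., T_{ks-1}(L) concatenated along the last
# axis, i.e. a (n, ks*n) array that graph_conv later multiplies against.
def _demo_cheb_poly():
    lap = scaled_laplacian(np.ones((4, 4)) - np.eye(4))  # complete graph K4
    supports = cheb_poly(lap, ks=3)
    assert supports.shape == (4, 12)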
class Align(nn.Module):
    def __init__(self, c_in, c_out):
        super(Align, self).__init__()
        self.c_in = c_in
        self.c_out = c_out
        if c_in > c_out:
            self.conv1x1 = nn.Conv2d(c_in, c_out, 1, stride=1, padding=0)  # filter=(1,1)
    def forward(self, x):  # x: (batch_size, feature_dim(c_in), input_length, num_nodes)
        if self.c_in > self.c_out:
            return self.conv1x1(x)
        if self.c_in < self.c_out:
            return F.pad(x, [0, 0, 0, 0, 0, self.c_out - self.c_in, 0, 0])
        return x  # return: (batch_size, c_out, input_length, num_nodes)
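# Minimal sketch, not part of the original module: Align matches the channel
# count of the residual branch, projecting with a 1x1 convolution when
# shrinking and zero-padding new channels when growing.
def _demo_align():
    x = torch.randn(2, 16, 12, 10)      # (B, c_in, T, N)
    y = Align(c_in=16, c_out=32)(x)     # zero-pads 16 extra channels
    assert y.shape == (2, 32, 12, 10)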
class ConvST(nn.Module):
    def __init__(self, supports, kt, ks, dim_in, dim_out, device):
        super(ConvST, self).__init__()
        self.supports = supports
        self.kt = kt
        self.ks = ks
        self.dim_in = dim_in
        self.dim_out = dim_out
        self.device = device
        self.align = Align(c_in=dim_in, c_out=dim_out)
        self.weights = nn.Parameter(torch.FloatTensor(
            2 * self.dim_out, self.ks * self.kt * self.dim_in).to(self.device))
        self.biases = nn.Parameter(torch.zeros(2 * self.dim_out).to(self.device))
        nn.init.xavier_uniform_(self.weights)
    def forward(self, x):
        """
        Args:
            x: torch.tensor, shape=[B, dim_in, T, num_nodes]
        Returns:
            torch.tensor: shape=[B, dim_out, T, num_nodes]
        """
        batch_size, len_time, num_nodes = x.shape[0], x.shape[2], x.shape[3]
        assert x.shape[1] == self.dim_in
        res_input = self.align(x)  # (B, dim_out, T, num_nodes)
        padding = torch.zeros(batch_size, self.dim_in, self.kt - 1, num_nodes).to(self.device)
        # extract spatial-temporal relationships at the same time
        x = torch.cat((x, padding), dim=2)
        # x.shape = [B, dim_in, len_time+kt-1, N]
        x = torch.stack([x[:, :, i:i + self.kt, :] for i in range(0, len_time)], dim=2)
        # x.shape = [B, dim_in, len_time, kt, N]
        x = torch.reshape(x, (-1, num_nodes, self.kt * self.dim_in))
        # x.shape = [B*len_time, N, kt*dim_in]
        conv_out = self.graph_conv(x, self.supports, self.kt * self.dim_in, 2 * self.dim_out)
        # conv_out: [B*len_time, N, 2*dim_out]
        conv_out = torch.reshape(conv_out, [-1, 2 * self.dim_out, len_time, num_nodes])
        # conv_out: [B, 2*dim_out, len_time, N]
        out = (conv_out[:, :self.dim_out, :, :] + res_input) * torch.sigmoid(conv_out[:, self.dim_out:, :, :])
        return out  # [B, dim_out, len_time, N] 
    def graph_conv(self, inputs, supports, dim_in, dim_out):
        """
        Args:
            inputs: a tensor of shape [batch, num_nodes, dim_in]
            supports: [num_nodes, num_nodes*ks], the Chebyshev polynomials, precomputed in advance to save time
            dim_in: input feature dimension
            dim_out: output feature dimension
        Returns:
            torch.tensor: shape = [batch, num_nodes, dim_out]
        """
        num_nodes = inputs.shape[1]
        assert num_nodes == supports.shape[0]
        assert dim_in == inputs.shape[2]
        # [batch, num_nodes, dim_in] -> [batch, dim_in, num_nodes] -> [batch * dim_in, num_nodes]
        x_new = torch.reshape(inputs.permute(0, 2, 1), (-1, num_nodes))
        # [batch * dim_in, num_nodes] * [num_nodes, num_nodes*ks]
        #       -> [batch * dim_in, num_nodes*ks] -> [batch, dim_in, ks, num_nodes]
        x_new = torch.reshape(torch.matmul(x_new, supports), (-1, dim_in, self.ks, num_nodes))
        # [batch, dim_in, ks, num_nodes] -> [batch, num_nodes, dim_in, ks]
        x_new = x_new.permute(0, 3, 1, 2)
        # [batch, num_nodes, dim_in, ks] -> [batch*num_nodes, dim_in*ks]
        x_new = torch.reshape(x_new, (-1, self.ks * dim_in))
        outputs = F.linear(x_new, self.weights, self.biases)  # [batch*num_nodes, dim_out]
        outputs = torch.reshape(outputs, [-1, num_nodes, dim_out])  # [batch, num_nodes, dim_out]
        return outputs  
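# Hypothetical shape walk-through (assumption: CPU device, toy sizes): a
# ConvST layer consumes the precomputed supports and changes only the channel
# count, preserving the (T, N) layout via the GLU-gated residual output.
def _demo_convst():
    n, ks = 4, 3
    lap = scaled_laplacian(np.ones((n, n)) - np.eye(n))
    supports = torch.tensor(cheb_poly(lap, ks), dtype=torch.float32)
    layer = ConvST(supports, kt=3, ks=ks, dim_in=2, dim_out=8,
                   device=torch.device('cpu'))
    x = torch.randn(2, 2, 12, n)        # (B, dim_in, T, N)
    assert layer(x).shape == (2, 8, 12, n)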
class AttentionT(nn.Module):
    def __init__(self, device, len_time, num_nodes, d_out, ext_dim):
        super(AttentionT, self).__init__()
        self.device = device
        self.len_time = len_time
        self.num_nodes = num_nodes
        self.d_out = d_out
        self.ext_dim = ext_dim
        self.weight1 = nn.Parameter(torch.FloatTensor(self.len_time, self.num_nodes * self.d_out, 1).to(self.device))
        self.weight2 = nn.Parameter(torch.FloatTensor(self.ext_dim, self.len_time).to(self.device))
        self.bias = nn.Parameter(torch.zeros(self.len_time).to(self.device))
        nn.init.xavier_uniform_(self.weight1)
        nn.init.xavier_uniform_(self.weight2)
    def forward(self, query, x):
        # query  # [B, ext_dim]
        # temporal attention: x.shape = [B, d_out, T, N]
        x_in = torch.reshape(x, (-1, self.num_nodes * self.d_out, self.len_time))
        # x_in.shape = [B, N*d_out, T]
        x = x_in.permute(2, 0, 1)
        # x.shape = [T, B, N*d_out]
        score = torch.reshape(torch.matmul(x, self.weight1), (-1, self.len_time)) + self.bias
        score = score + torch.matmul(query, self.weight2)
        score = torch.softmax(torch.tanh(score), dim=1)
        # score.shape = [B, T]
        x = torch.matmul(x_in, torch.unsqueeze(score, dim=-1))
        # x.shape = [B, N*d_out, 1]
        x = x.permute(0, 2, 1).reshape((-1, 1, self.num_nodes, self.d_out)).permute(0, 3, 1, 2)
        # x = torch.reshape(x, (-1, d_out, 1, N))
        # x.shape = [B, d_out, 1, N]
        return x  
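# Minimal sketch, not part of the original module: AttentionT collapses the
# time axis to length 1, weighting the T steps by scores conditioned on the
# external-feature query.
def _demo_attention_t():
    att = AttentionT(torch.device('cpu'), len_time=15, num_nodes=4,
                     d_out=8, ext_dim=5)
    x = torch.randn(2, 8, 15, 4)        # (B, d_out, T, N)
    query = torch.randn(2, 5)           # (B, ext_dim)
    assert att(query, x).shape == (2, 8, 1, 4)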
class AttentionC(nn.Module):
    def __init__(self, device, num_nodes, d_out, ext_dim):
        super(AttentionC, self).__init__()
        self.device = device
        self.num_nodes = num_nodes
        self.d_out = d_out
        self.ext_dim = ext_dim
        self.weight1 = nn.Parameter(torch.FloatTensor(self.d_out, self.num_nodes, 1).to(self.device))
        self.weight2 = nn.Parameter(torch.FloatTensor(self.ext_dim, self.d_out).to(self.device))
        self.bias = nn.Parameter(torch.zeros(self.d_out).to(self.device))
        nn.init.xavier_uniform_(self.weight1)
        nn.init.xavier_uniform_(self.weight2)
    def forward(self, query, x):
        # query  # [B, ext_dim]
        # channel attention: x.shape = [B, d_out, 1, N]
        x_in = torch.reshape(x, (-1, self.num_nodes, self.d_out))
        # x_in.shape = [B, N, d_out]
        x = x_in.permute(2, 0, 1)
        # x.shape = [d_out, B, N]
        score = torch.reshape(torch.matmul(x, self.weight1), (-1, self.d_out)) + self.bias
        score = score + torch.matmul(query, self.weight2)
        score = torch.softmax(torch.tanh(score), dim=1)
        # score.shape = [B, d_out]
        x = torch.matmul(x_in, torch.unsqueeze(score, dim=-1)).permute(0, 2, 1)
        # x.shape = [B, 1, N] (the singleton dim becomes the output channel)
        x = torch.unsqueeze(x, dim=2)  # [B, 1(channel), 1(T), N]
        return x  
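# Minimal sketch, not part of the original module: AttentionC collapses the
# d_out channels to a single output channel, again conditioned on the
# external-feature query; STG2Seq runs one AttentionC head per output_dim.
def _demo_attention_c():
    att = AttentionC(torch.device('cpu'), num_nodes=4, d_out=8, ext_dim=5)
    x = torch.randn(2, 8, 1, 4)         # (B, d_out, 1, N)
    query = torch.randn(2, 5)           # (B, ext_dim)
    assert att(query, x).shape == (2, 1, 1, 4)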
class STG2Seq(AbstractTrafficStateModel):
    def __init__(self, config, data_feature):
        super().__init__(config, data_feature)
        self.adj_mx = self.data_feature.get('adj_mx')
        self.num_nodes = self.data_feature.get('num_nodes', 1)
        self.feature_dim = self.data_feature.get('feature_dim', 2)
        self.output_dim = self.data_feature.get('output_dim', 2)
        self.ext_dim = self.data_feature.get('ext_dim', 1)
        # self.len_row = self.data_feature.get('len_row', 32)
        # self.len_column = self.data_feature.get('len_column', 32)
        self._scaler = self.data_feature.get('scaler')
        self._logger = getLogger()
        self.input_window = config.get('input_window', 1)
        self.output_window = config.get('output_window', 1)
        self.window = config.get('window', 3)
        self.dim_out = config.get('dim_out', 32)
        self.ks = config.get('ks', 3)
        self.device = config.get('device', torch.device('cpu'))
        self.supports = torch.tensor(cheb_poly(scaled_laplacian(self.adj_mx), self.ks),
                                     dtype=torch.float32).to(self.device)
        self.long_term_layer = nn.Sequential(
            ConvST(self.supports, kt=3, ks=self.ks, dim_in=self.output_dim, dim_out=self.dim_out, device=self.device),
            nn.BatchNorm2d(self.dim_out),
            ConvST(self.supports, kt=3, ks=self.ks, dim_in=self.dim_out, dim_out=self.dim_out, device=self.device),
            nn.BatchNorm2d(self.dim_out),
            ConvST(self.supports, kt=3, ks=self.ks, dim_in=self.dim_out, dim_out=self.dim_out, device=self.device),
            nn.BatchNorm2d(self.dim_out),
            ConvST(self.supports, kt=3, ks=self.ks, dim_in=self.dim_out, dim_out=self.dim_out, device=self.device),
            nn.BatchNorm2d(self.dim_out),
            ConvST(self.supports, kt=3, ks=self.ks, dim_in=self.dim_out, dim_out=self.dim_out, device=self.device),
            nn.BatchNorm2d(self.dim_out),
            ConvST(self.supports, kt=2, ks=self.ks, dim_in=self.dim_out, dim_out=self.dim_out, device=self.device),
            nn.BatchNorm2d(self.dim_out),
        )
        self.short_term_gcn = nn.Sequential(
            ConvST(self.supports, kt=3, ks=self.ks, dim_in=self.output_dim, dim_out=self.dim_out, device=self.device),
            nn.BatchNorm2d(self.dim_out),
            ConvST(self.supports, kt=3, ks=self.ks, dim_in=self.dim_out, dim_out=self.dim_out, device=self.device),
            nn.BatchNorm2d(self.dim_out),
            ConvST(self.supports, kt=3, ks=self.ks, dim_in=self.dim_out, dim_out=self.dim_out, device=self.device),
            nn.BatchNorm2d(self.dim_out),
        )
        self.attention_t = AttentionT(self.device, self.input_window + self.window,
                                      self.num_nodes, self.dim_out, self.ext_dim)
        self.attention_c_1 = AttentionC(self.device, self.num_nodes, self.dim_out, self.ext_dim)
        self.attention_c_2 = AttentionC(self.device, self.num_nodes, self.dim_out, self.ext_dim)
    def forward(self, batch):
        inputs = batch['X'][:, :, :, :self.output_dim].contiguous()  # (B, input_window, N, output_dim)
        inputs = inputs.permute(0, 3, 1, 2)  # (B, output_dim, input_window, N)
        # input_ext = batch['X'][:, :, 0, self.output_dim:].contiguous()  # (B, input_window, ext_dim)
        batch_size, input_dim, len_time, num_nodes = inputs.shape
        assert num_nodes == self.num_nodes
        assert len_time == self.input_window
        assert input_dim == self.output_dim
        labels = batch['y'][:, :, :, :self.output_dim].contiguous()  # (B, output_window, N, output_dim)
        labels = labels.permute(0, 3, 1, 2)  # (B, output_dim, output_window, N)
        labels_ext = batch['y'][:, :, 0, self.output_dim:].contiguous()  # (B, output_window, ext_dim)
        long_output = self.long_term_layer(inputs)  # (B, dim_out, input_window, N)
        preds = []
        if self.training:
            label_padding = inputs[:, :, -self.window:, :]  # (B, output_dim, window, N)
            padded_labels = torch.cat((label_padding, labels), dim=2)  # (B, output_dim, window+output_window, N)
            padded_labels = torch.stack([padded_labels[:, :, i:i + self.window, :]
                                         for i in range(0, self.output_window)], dim=2)
            # (B, output_dim, output_window, window, N)
            for i in range(0, self.output_window):
                s_inputs = padded_labels[:, :, i, :, :]  # (B, output_dim, window, N)
                ext_input = labels_ext[:, i, :]  # (B, ext_dim)
                short_output = self.short_term_gcn(s_inputs)  # (B, dim_out, window, N)
                ls_inputs = torch.cat((short_output, long_output), dim=2)
                # (B, dim_out, input_window + window, N)
                ls_inputs = self.attention_t(ext_input, ls_inputs)
                if self.output_dim == 1:
                    pred = self.attention_c_1(ext_input, ls_inputs)
                elif self.output_dim == 2:
                    pred = torch.cat((self.attention_c_1(ext_input, ls_inputs),
                                      self.attention_c_2(ext_input, ls_inputs)), dim=1)
                else:
                    raise ValueError('output_dim must be 1 or 2!')
                # pred: (B, output_dim, 1, N)
                label_padding = torch.cat((label_padding[:, :, 1:, :], pred), dim=2)
                preds.append(pred)
        else:
            label_padding = inputs[:, :, -self.window:, :]  # (B, output_dim, window, N)
            for i in range(0, self.output_window):
                s_inputs = label_padding
                ext_input = labels_ext[:, i, :]  # (B, ext_dim)
                short_output = self.short_term_gcn(s_inputs)  # (B, dim_out, window, N)
                ls_inputs = torch.cat((short_output, long_output), dim=2)
                # (B, dim_out, input_window + window, N)
                ls_inputs = self.attention_t(ext_input, ls_inputs)
                if self.output_dim == 1:
                    pred = self.attention_c_1(ext_input, ls_inputs)
                elif self.output_dim == 2:
                    pred = torch.cat((self.attention_c_1(ext_input, ls_inputs),
                                      self.attention_c_2(ext_input, ls_inputs)), dim=1)
                else:
                    raise ValueError('output_dim must be 1 or 2!')
                # pred: (B, output_dim, 1, N)
                label_padding = torch.cat((label_padding[:, :, 1:, :], pred), dim=2)
                preds.append(pred)
        return torch.cat(preds, dim=2).permute(0, 2, 3, 1) 
    def calculate_loss(self, batch):
        y_true = batch['y']
        y_predicted = self.predict(batch)
        y_true = self._scaler.inverse_transform(y_true[..., :self.output_dim])
        y_predicted = self._scaler.inverse_transform(y_predicted[..., :self.output_dim])
        return loss.masked_mse_torch(y_predicted, y_true) 
    def predict(self, batch):
        return self.forward(batch)
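# Hypothetical smoke test (assumption: in LibCity these dicts are produced by
# the dataset/executor pipeline; the stand-ins below only cover the keys this
# model reads, and 'scaler' is omitted because calculate_loss is not called).
def _demo_stg2seq():
    n = 4
    data_feature = {'adj_mx': np.ones((n, n)) - np.eye(n), 'num_nodes': n,
                    'feature_dim': 3, 'output_dim': 2, 'ext_dim': 1}
    config = {'input_window': 12, 'output_window': 3, 'window': 3,
              'dim_out': 32, 'ks': 3, 'device': torch.device('cpu')}
    model = STG2Seq(config, data_feature)
    model.eval()  # use the autoregressive inference branch of forward()
    batch = {'X': torch.randn(2, 12, n, 3),   # last channel is the ext feature
             'y': torch.randn(2, 3, n, 3)}    # ext part of y drives attention
    assert model.predict(batch).shape == (2, 3, n, 2)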