Source code for libcity.model.traffic_flow_prediction.DSAN

import numpy as np
import torch
import torch.nn as nn
from libcity.model.abstract_traffic_state_model import AbstractTrafficStateModel
from libcity.model.loss import masked_rmse_torch


[docs]def get_angles(pos, l, d): """ equ (5) Args: pos: row(r) / column(c) in equ (5) l: the l-th dimension, with shape (1, d) d: d dimension in total Returns: angles with shape (1, d) """ angle_rates = 1 / np.power(10000, (2 * (l // 2)) / np.float32(d)) return torch.tensor(pos * angle_rates)
[docs]def spatial_posenc(r, c, d, device): """ get SPE Args: r: row of the spatial position c: column of the spatial position d: d dimension in total Returns: """ angle_rads_r = get_angles(pos=r, l=np.arange(d)[np.newaxis, :], d=d) # l and ret with shape (1, d) angle_rads_c = get_angles(pos=c, l=np.arange(d)[np.newaxis, :], d=d) # l and ret with shape (1, d) pos_encoding = torch.zeros(size=angle_rads_r.shape, device=device) # shape (1, d) pos_encoding[:, 0::2] = torch.sin(angle_rads_r[:, 0::2]) # from 0 to d step 2 pos_encoding[:, 1::2] = torch.cos(angle_rads_c[:, 1::2]) # from 1 to d step 2 return pos_encoding[np.newaxis, ...] # (1, 1, d)
[docs]def cal_attention(Q, K, V, M, n_h): """ equ (3), calculate the attention mechanism performed by the i-th attention head Args: Q: query, shape (N, h, L_q, d) K: key, shape (N, h, L_k, d) V: value, shape (N, h, L_k, d) M: mask, shape (N, h, L_q, L_k) n_h: number of attention head Returns: Att: shape # (N, h, L_q, d) """ QK = torch.matmul(input=Q, other=K.transpose(-1, -2)) # (h, L_q, L_k) d = K.shape[-1] d_h = d / n_h # the split dimensionality in n_h attention heads QK_d_h = QK / np.sqrt(d_h) # (h, L_q, L_k) if M is not None: M = M.unsqueeze(2) M = M.repeat(QK_d_h.shape[0] // M.shape[0], 1, 1, 1, 1) QK_d_h += (M * -1e9) attention_weights = torch.softmax(input=QK_d_h, dim=-1) # (h, L_q, L_k) softmax along key axis output = torch.matmul(attention_weights, V) # (h, L_q, d) return output
[docs]def two_layer_ffn(d, num_hid, input_dim): """ implementation of two-layer feed-forward network Args: d: d-dimension representations num_hid: hidden layer size input_dim: input feature dimension Returns: """ return nn.Sequential( nn.Linear(input_dim, num_hid), nn.ReLU(), nn.Linear(num_hid, d) )
[docs]def ex_encoding(d, num_hid, input_dim): """ implementation of TPE Args: d: d-dimension representations num_hid: hidden layer size input_dim: input feature dimension Returns: """ return nn.Sequential( nn.Linear(in_features=input_dim, out_features=num_hid), nn.ReLU(), nn.Linear(in_features=num_hid, out_features=d), nn.Sigmoid() )
[docs]def create_look_ahead_mask(size): mask = 1 - torch.tril(torch.ones((size, size)), diagonal=-1) return mask.cuda()
[docs]def create_threshold_mask(inp): """ Args: inp: [batch_size, input_window, column, row, input_dim] Returns: """ oup = torch.sum(inp, dim=-1) shape = oup.shape oup = torch.reshape(oup, [shape[0], shape[1], -1]) mask = (oup == 0).float() return mask
[docs]def create_threshold_mask_tar(inp): oup = torch.sum(inp, dim=-1) mask = (oup == 0).float() return mask
[docs]def create_masks(inp_g, inp_l, tar): """ Args: inp_g: shape == [batch_size, input_window, column, row, input_dim] inp_l: shape == [batch_size, input_window, column, row, l_d, l_d, input_dim] torch.Size([64, 12, 192, 49, 2]) tar: shape == [batch_size, input_window, N, ext_dim] torch.Size([64, 12, 192, 8]) Returns: """ threshold_mask_g = create_threshold_mask(inp_g).unsqueeze(2) inp_l = inp_l.permute([0, 2, 3, 1, 4, 5, 6]) inp_l = torch.reshape(inp_l, [inp_l.shape[0] * inp_l.shape[1] * inp_l.shape[2], inp_l.shape[3], inp_l.shape[4], inp_l.shape[5], inp_l.shape[6]]) threshold_mask = create_threshold_mask(inp_l).unsqueeze(2) look_ahead_mask = create_look_ahead_mask(tar.shape[1]) tar = tar.permute(0, 2, 1, 3) tar = torch.reshape(tar, [tar.shape[0] * tar.shape[1], tar.shape[2], tar.shape[3]]) dec_target_threshold_mask = create_threshold_mask_tar( tar).unsqueeze(1).unsqueeze(2) combined_mask = torch.max(dec_target_threshold_mask, look_ahead_mask) return threshold_mask_g, threshold_mask, combined_mask
[docs]class Convs(nn.Module): """ Conv layers for input, to form a d-dimension representation """ def __init__(self, n_layer, n_filter, input_window, input_dim, r_d=0.1): """ Args: n_layer: num of conv layers n_filter: num of filters input_window: input window size input_dim: input dimension size r_d: dropout rate """ super(Convs, self).__init__() self.n_layer = n_layer self.input_window = input_window self.convs = nn.ModuleList( [nn.ModuleList([nn.Conv2d(in_channels=input_dim, out_channels=n_filter, kernel_size=(3, 3), padding=(1, 1)) for _ in range(input_window)])]) self.convs += nn.ModuleList( [nn.ModuleList([nn.Conv2d(in_channels=n_filter, out_channels=n_filter, kernel_size=(3, 3), padding=(1, 1)) for _ in range(input_window)]) for _ in range(n_layer - 1)]) self.dropouts = nn.ModuleList([nn.ModuleList( [nn.Dropout(r_d) for _ in range(input_window)]) for _ in range(n_layer)])
[docs] def forward(self, inps): """ Args: inps: with shape [batch_size, input_window, row, column, N_d, input_dim] or [batch_size, input_window, row, column, input_dim] Returns: """ outputs = list(torch.split(inps, 1, dim=1)) if len(inps.shape) == 6: for i in range(self.input_window): outputs[i] = outputs[i].permute([0, 1, 4, 5, 2, 3]) outputs[i] = torch.reshape(input=outputs[i], shape=[-1, outputs[i].shape[3], outputs[i].shape[4], outputs[i].shape[5]]) else: for i in range(self.input_window): outputs[i] = outputs[i].permute([0, 1, 4, 2, 3]) outputs[i] = torch.reshape(input=outputs[i], shape=[-1, outputs[i].shape[2], outputs[i].shape[3], outputs[i].shape[4]]) for i in range(self.n_layer): for j in range(self.input_window): outputs[j] = self.convs[i][j](outputs[j]) outputs[j] = torch.relu(outputs[j]) outputs[j] = self.dropouts[i][j](outputs[j]) output = torch.stack(outputs, dim=1) if len(inps.shape) == 6: output = torch.reshape(input=output, shape=[inps.shape[0], -1, output.shape[1], output.shape[2], output.shape[3], output.shape[4]]).permute([0, 2, 4, 5, 1, 3]) else: output = output.permute([0, 1, 3, 4, 2]) return output
[docs]class MSA(nn.Module): """ Multi-space attention """ def __init__(self, d, n_h, self_att=True): """ Args: d: d-dimension representations after B-layer CNN/FCN n_h: num of head self_att: whether use self attention """ super(MSA, self).__init__() self.d = d self.n_h = n_h self.d_h = d / n_h self.self_att = self_att assert d % n_h == 0 self.d_h = d // n_h if self_att: self.wx = nn.Linear(in_features=d, out_features=d * 3) else: self.wq = nn.Linear(in_features=d, out_features=d) self.wkv = nn.Linear(in_features=d, out_features=d * 2) self.wo = nn.Linear(in_features=d, out_features=d)
[docs] def split_heads(self, x): """ Args: x: shape == [batch_size, input_window, N, d] Returns: shape == [batch_size, input_window, n_h, N, d_h] """ shape = x.shape x = torch.reshape(x, [shape[0], shape[1], shape[2], self.n_h, self.d_h]) return x.permute([0, 1, 3, 2, 4])
[docs] def forward(self, V, K, Q, M): # linear if self.self_att: wx_o = self.wx(Q) Q, K, V = torch.split(tensor=wx_o, split_size_or_sections=wx_o.shape[-1] // 3, dim=-1) else: Q = self.wq(Q) wkv_o = self.wkv(K) K, V = torch.split(tensor=wkv_o, split_size_or_sections=wkv_o.shape[-1] // 2, dim=-1) # split head Q = self.split_heads(Q) K = self.split_heads(K) V = self.split_heads(V) scaled_attention = cal_attention(Q=Q, K=K, V=V, M=M, n_h=self.n_h) scaled_attention = scaled_attention.permute([0, 1, 3, 2, 4]) d_shape = scaled_attention.shape concat_attention = torch.reshape( scaled_attention, (d_shape[0], d_shape[1], d_shape[2], self.d)) output = self.wo(concat_attention) return output
[docs]class EncoderLayer(nn.Module): """ Enc-G implementation """ def __init__(self, d, n_h, num_hid, r_d=0.1): """ Args: d: d-dimension representations n_h: number of heads in Multi-space attention num_hid: hidden layer size r_d: drop out rate """ super(EncoderLayer, self).__init__() # msa self.msa = MSA(d=d, n_h=n_h) # ffn self.ffn = two_layer_ffn(d=d, num_hid=num_hid, input_dim=64) # normalization self.layernorm1 = nn.LayerNorm(normalized_shape=d, eps=1e-6) self.layernorm2 = nn.LayerNorm(normalized_shape=d, eps=1e-6) # dropout self.dropout1 = nn.Dropout(r_d) self.dropout2 = nn.Dropout(r_d)
[docs] def forward(self, x, mask): """ Args: x: shape == [batch_size, input_window, N, d] mask: Returns: """ # msa attn_output = self.msa(x, x, x, mask) attn_output = self.dropout1(attn_output) out1 = self.layernorm1(x + attn_output) # Residual # ffn ffn_output = self.ffn(out1) ffn_output = self.dropout2(ffn_output) out2 = self.layernorm2(out1 + ffn_output) # Residual return out2
[docs]class DecoderLayer(nn.Module): """ Enc-D / Dec-S / Dec-T implementation """ def __init__(self, d, n_h, num_hid, r_d=0.1, revert_q=False): """ Args: d: d-dimension representations n_h: number of heads in Multi-space attention num_hid: hidden layer size r_d: drop out rate revert_q: """ super(DecoderLayer, self).__init__() self.revert_q = revert_q self.msa1 = MSA(d=d, n_h=n_h) self.msa2 = MSA(d=d, n_h=n_h, self_att=False) self.ffn = two_layer_ffn(d=d, num_hid=num_hid, input_dim=d) self.layernorm1 = nn.LayerNorm(normalized_shape=[d], eps=1e-6) self.layernorm2 = nn.LayerNorm(normalized_shape=[d], eps=1e-6) self.layernorm3 = nn.LayerNorm(normalized_shape=[d], eps=1e-6) self.dropout1 = nn.Dropout(r_d) self.dropout2 = nn.Dropout(r_d) self.dropout3 = nn.Dropout(r_d)
[docs] def forward(self, x, kv, look_ahead_mask, threshold_mask): # first msa attn1 = self.msa1(x, x, x, look_ahead_mask) attn1 = self.dropout1(attn1) out1 = self.layernorm1(attn1 + x) if self.revert_q: out1_r = out1.permute([0, 2, 1, 3]) attn2 = self.msa2(kv, kv, out1_r, threshold_mask) attn2 = attn2.permute([0, 2, 1, 3]) else: kv = kv.repeat(out1.shape[0] // kv.shape[0], 1, 1, 1) attn2 = self.msa2(kv, kv, out1, threshold_mask) attn2 = self.dropout2(attn2) out2 = self.layernorm2(attn2 + out1) ffn_output = self.ffn(out2) ffn_output = self.dropout3(ffn_output) out3 = self.layernorm3(ffn_output + out2) return out3
[docs]class DAE(nn.Module): """ DAE Dynamic Attention Encoder """ def __init__(self, L, d, n_h, num_hid, conv_layer, input_window, input_dim, ext_dim, r_d=0.1): """ Dynamic Attention Encoder Args: L: num of Enc-G/Enc-D layers d: d-dimension representations n_h: number of heads in Multi-space attention num_hid: hidden layer size conv_layer: num of conv layers input_window: input window size input_dim: input dimension r_d: drop out rate """ super(DAE, self).__init__() self.d = d self.L = L # conv layers to get d-dimension representations self.convs_d = Convs(n_layer=conv_layer, n_filter=d, input_window=input_window, input_dim=input_dim, r_d=r_d) self.convs_g = Convs(n_layer=conv_layer, n_filter=d, input_window=input_window, input_dim=input_dim, r_d=r_d) # get TPE self.ex_encoder_d = ex_encoding(d=d, num_hid=num_hid, input_dim=ext_dim) self.ex_encoder_g = ex_encoding(d=d, num_hid=num_hid, input_dim=ext_dim) # dropout layer self.dropout_d = nn.Dropout(p=r_d) self.dropout_g = nn.Dropout(p=r_d) # self.Enc_G = nn.ModuleList( [EncoderLayer(d=d, n_h=n_h, num_hid=num_hid, r_d=r_d) for _ in range(L)]) self.Enc_D = nn.ModuleList( [DecoderLayer(d=d, n_h=n_h, num_hid=num_hid, r_d=r_d) for _ in range(L)])
[docs] def forward(self, x_d, x_g, ex, cors_d, cors_g, threshold_mask_d, threshold_mask_g): """ Args: x_d: a subset of 𝑿 that contains the closest neighbors that share strong correlations with v_i within a local block.(X_d in figure 4) x_g: all the training data (X in figure 4) ex: time-related features for Temporal Positional Encoding cors_d: Spatial Positional Encoding of x_d cors_g: Spatial Positional Encoding of x_g threshold_mask_d: threshold_mask_g: Returns: """ shape = x_d.shape TPE_d = self.ex_encoder_d(ex) TPE_g = self.ex_encoder_g(ex) SPE_d = cors_d SPE_g = cors_g x_d = self.convs_d(inps=x_d) x_g = self.convs_g(inps=x_g) x_d *= np.sqrt(self.d) x_g *= np.sqrt(self.d) x_d = x_d.reshape([shape[0], shape[1], -1, shape[4], self.d]) x_g = x_g.reshape([shape[0], shape[1], -1, self.d]) TPE_d = torch.reshape(input=TPE_d, shape=[TPE_d.shape[0], TPE_d.shape[1], TPE_d.shape[2] * TPE_d.shape[3], -1, TPE_d.shape[-1]]) TPE_g = torch.reshape(input=TPE_g, shape=[TPE_g.shape[0], TPE_g.shape[1], TPE_g.shape[2] * TPE_g.shape[3], TPE_g.shape[4]]) x_d = x_d + TPE_d + SPE_d x_g = x_g + TPE_g + SPE_g x_d = self.dropout_d(x_d) x_g = self.dropout_g(x_g) for i in range(self.L): x_g = self.Enc_G[i](x_g, threshold_mask_g) x_d_ = x_d.permute([0, 2, 1, 3, 4]) x_d_ = torch.reshape(x_d_, [x_d_.shape[0] * x_d_.shape[1], x_d_.shape[2], x_d_.shape[3], x_d_.shape[4]]) for i in range(self.L): x_d = self.Enc_D[i]( x_d_, x_g, threshold_mask_d, threshold_mask_g) return x_d
[docs]class SAD(nn.Module): def __init__(self, L, d, n_h, num_hid, conv_layer, ext_dim, input_window, output_window, device, r_d=0.1): """ Args: L: num of Enc-G/Enc-D layers d: d-dimension representations n_h: number of heads in Multi-space attention num_hid: hidden layer size conv_layer: num of conv layers ext_dim: external data dimension input_window: output_window: r_d: drop out rate """ super(SAD, self).__init__() self.d = d self.L = L self.pos_enc = spatial_posenc(0, 0, self.d, device) self.output_window = output_window self.ex_encoder = ex_encoding(d=d, num_hid=num_hid, input_dim=ext_dim) self.dropout = nn.Dropout(r_d) self.li_conv = nn.Sequential() self.li_conv.add_module("linear", nn.Linear(2, d)) self.li_conv.add_module("activation_relu", nn.ReLU()) for i in range(conv_layer - 1): self.li_conv.add_module( "linear{}".format(i), nn.Linear( d, d)) self.li_conv.add_module("activation_relu{}".format(i), nn.ReLU()) self.dec_s = nn.ModuleList( [DecoderLayer(d=d, n_h=n_h, num_hid=num_hid, r_d=r_d) for _ in range(L)]) self.linear = nn.Linear(in_features=input_window, out_features=output_window) self.dec_t = nn.ModuleList( [DecoderLayer(d=d, n_h=n_h, num_hid=num_hid, r_d=r_d, revert_q=True) for _ in range(L)])
[docs] def forward(self, x, ex, dae_output, look_ahead_mask): ex_enc = self.ex_encoder(ex) x = self.li_conv(x) x *= np.sqrt(self.d) ex_enc = torch.reshape(input=ex_enc, shape=[ex_enc.shape[0], ex_enc.shape[1], ex_enc.shape[2] * ex_enc.shape[3], ex_enc.shape[4]]) x = x + ex_enc + self.pos_enc x = self.dropout(x) x_s = x x_t = x x_s = x_s.unsqueeze(3).expand(-1, -1, -1, self.output_window, -1) x_s_ = x_s.permute([0, 2, 1, 3, 4]) x_s_ = torch.reshape(x_s_, [x_s_.shape[0] * x_s_.shape[1], x_s_.shape[2], x_s_.shape[3], x_s_.shape[4]]) # linear dae_output = dae_output.permute(0, 2, 3, 1) dae_output = self.linear(dae_output) dae_output = dae_output.permute(0, 3, 1, 2) for i in range(self.L): x_s = self.dec_s[i](x_s_, dae_output, look_ahead_mask, None) x_s_ = x_s.permute([0, 2, 1, 3]) x_t_ = x_t.permute([0, 2, 1, 3]) x_t_ = torch.reshape(x_t_, [x_t_.shape[0] * x_t_.shape[1], 1, x_t_.shape[2], x_t_.shape[3]]) for i in range(self.L): x_t = self.dec_t[i]( x_t_, x_s_, look_ahead_mask, None) output = x_t.squeeze(1) return output
[docs]class DsanUse(nn.Module): """ DSAN use """ def __init__(self, L, d, n_h, row, column, num_hid, conv_layer, input_window, output_window, input_dim, ext_dim, device, r_d=0.1): """ Args: L: num of layers in Enc-G/D / Dec-S/T d: d-dimension representations n_h: number of heads in Multi-space attention num_hid: conv_layer: num of conv layers input_window: input window size input_dim: input dimension r_d: dropout rate """ super(DsanUse, self).__init__() self.row = row self.column = column # DAE Dynamic Attention Encoder self.dae = DAE(L=L, d=d, n_h=n_h, num_hid=num_hid, conv_layer=conv_layer, input_window=input_window, input_dim=input_dim, ext_dim=ext_dim, r_d=r_d) # SAD Switch-Attention Decoder self.sad = SAD(L=L, d=d, n_h=n_h, num_hid=num_hid, conv_layer=conv_layer, ext_dim=ext_dim, input_window=input_window, output_window=output_window, device=device, r_d=r_d) # final layer self.final_layer = nn.Linear(d, input_dim)
[docs] def forward(self, dae_inp_g, dae_inp, dae_inp_ex, sad_inp, sad_inp_ex, cors, cors_g, threshold_mask, threshold_mask_g, look_ahead_mask): # DAE dae_output = self.dae( x_d=dae_inp, x_g=dae_inp_g, ex=dae_inp_ex, cors_d=cors, cors_g=cors_g, threshold_mask_d=threshold_mask, threshold_mask_g=threshold_mask_g ) # SAD sad_output = self.sad( x=sad_inp, ex=sad_inp_ex, dae_output=dae_output, look_ahead_mask=look_ahead_mask ) # final layer final_output = self.final_layer(sad_output) final_output = torch.tanh(final_output) final_output = torch.reshape(final_output, [-1, self.column, self.row, final_output.shape[-2], final_output.shape[-1]]) final_output = final_output.permute([0, 3, 2, 1, 4]) return final_output
[docs]class DSAN(AbstractTrafficStateModel): def __init__(self, config, data_feature): super().__init__(config, data_feature) # device self.device = config.get('device', torch.device('cpu')) # data_feature self._scaler = self.data_feature.get('scaler') # 用于数据归一化 # self.adj_mx = torch.tensor(self.data_feature.get('adj_mx'), device=self.device) self.len_row = self.data_feature.get('len_row', 16) # row self.len_column = self.data_feature.get('len_column', 12) # column self.num_nodes = self.data_feature.get('num_nodes', 1) # len_row * len_column self.feature_dim = self.data_feature.get('feature_dim', 1) # 输入维度 self.ext_dim = self.data_feature.get('ext_dim', 1) # 额外数据的维度 self.output_dim = self.data_feature.get('output_dim', 1) # b in paper # config self.input_window = config.get('input_window', 12) # l in paper self.output_window = config.get('output_window', 12) # F in paper self.L = config.get('L', 3) # num of layers in Enc-G/D / Dec-S/T self.d = config.get('d', 64) # d-dimension representations self.n_h = config.get('n_h', 8) # num of head in Multi-space Attention self.num_hid = 4 * self.d # hidden layer size self.B = config.get('B', 3) # num of layers in conv self.l_d = config.get('l_d', 3) self.r_d = config.get('r_d', 0.1) # dropout rate self.dsan = DsanUse(L=self.L, d=self.d, n_h=self.n_h, row=self.len_row, column=self.len_column, num_hid=self.num_hid, conv_layer=self.B, input_window=self.input_window, output_window=self.output_window, input_dim=self.output_dim, ext_dim=self.ext_dim, device=self.device, r_d=self.r_d)
[docs] def generate_x(self, batch): """ from batch['X'] to Args: batch: batch['X'].shape == [batch_size, input_window, row, column, feature_dim] batch['y'].shape == [batch_size, output_window, row, column, output_dim] Returns: dae_inp_g: X in figure(2) shape == [batch_size, input_window, row, column, output_dim] dae_inp: X_d in figure(2) shape == [batch_size, input_window, row, column, L_D, L_D output_dim] N_D = L_d * L_d ,L_d = 2 * l_d + 1 dae_inp_ex: external data for TPE shape == [batch_size, input_window, N, external_dim] sad_inp: x in figure(2) shape == [batch_size, output_window, N, output_dim] sad_inp_ex: external data for TPE shape == [batch_size, input_window, N, external_dim] cors: for SPE,shape == [1, 1, N_d, d] cors_g: for SPE, shape == [1, N, d] y: """ X = batch['X'][:, :, :, :, :self.output_dim] X_ext = batch['X'][:, :, :, :, self.output_dim:] X_shape = X.shape # [batch_size, input_window, row, column, feature_dim] l_d = self.l_d # dae_inp_g dae_inp_g = torch.reshape(input=X, shape=[X_shape[0], X_shape[1], X_shape[2], X_shape[3], X_shape[4]]) # dae_inp L_d = 2 * l_d + 1 # l_d: half length of the block (L_d = 2 * l_d + 1) dae_inp = torch.zeros(size=[X_shape[0], X_shape[1], X_shape[2], X_shape[3], L_d, L_d, X_shape[4]], device=self.device) for i in range(X_shape[2]): for j in range(X_shape[3]): dae_inp[:, :, i, j, max(0, l_d - i):min(L_d, X_shape[2] - i + l_d), max(0, l_d - j):min(L_d, X_shape[3] - j + l_d), :] = \ X[:, :, max(0, i - l_d):min(X_shape[2], i + l_d + 1), max(0, j - l_d):min(X_shape[3], j + l_d + 1), :] # dae_inp_ex dae_inp_ex = X_ext # sad_inp sad_inp = torch.reshape(input=X[:, -self.output_window:, :, :, :self.output_dim], shape=[X_shape[0], -1, X_shape[2] * X_shape[3], X_shape[4]]) # sad_inp_ex sad_inp_ex = X_ext[:, -self.output_window:, :, :, :] # cors cors = torch.zeros(size=[L_d, L_d, self.d], device=self.device) for i in range(L_d): for j in range(L_d): cors[i, j, :] = spatial_posenc(i - L_d // 2, j - L_d // 2, self.d, device=self.device) cors = torch.reshape(input=cors, shape=[1, 1, cors.shape[0] * cors.shape[1], cors.shape[2]]) # cors_g cors_g = torch.zeros(size=[self.len_row, self.len_column, self.d], device=self.device) for i in range(self.len_row): for j in range(self.len_column): cors_g[i, j, :] = spatial_posenc(i - self.len_row // 2, j - self.len_column // 2, self.d, device=self.device) cors_g = torch.reshape(input=cors_g, shape=[1, cors_g.shape[0] * cors_g.shape[1], cors_g.shape[2]]) # y y = batch['y'] return dae_inp_g, dae_inp, dae_inp_ex, sad_inp, sad_inp_ex, cors, cors_g, y
[docs] def predict(self, batch): # generate x dae_inp_g, dae_inp, dae_inp_ex, sad_inp, sad_inp_ex, cors, cors_g, y = \ self.generate_x(batch=batch) # generate mask threshold_mask_g, threshold_mask, combined_mask = create_masks( dae_inp_g[..., :self.output_dim], dae_inp[..., :self.output_dim], sad_inp) # reshape dae_inp = torch.reshape(input=dae_inp, shape=[dae_inp.shape[0], dae_inp.shape[1], dae_inp.shape[2], dae_inp.shape[3], dae_inp.shape[4] * dae_inp.shape[5], dae_inp.shape[6]]) res = self.dsan( dae_inp_g=dae_inp_g, dae_inp=dae_inp, dae_inp_ex=dae_inp_ex, sad_inp=sad_inp, sad_inp_ex=sad_inp_ex, cors=cors, cors_g=cors_g, threshold_mask=threshold_mask, threshold_mask_g=threshold_mask_g, look_ahead_mask=combined_mask ) return res
[docs] def calculate_loss(self, batch): y_true = batch['y'] y_pred = self.predict(batch) y_true = self._scaler.inverse_transform(y_true[..., :self.output_dim]) y_predicted = self._scaler.inverse_transform(y_pred[..., :self.output_dim]) return masked_rmse_torch(y_predicted, y_true, 0)