Source code for libcity.model.traffic_speed_prediction.ATDM

import torch
import torch.nn as nn
import torch.nn.functional as F
from logging import getLogger
from libcity.model import loss
from libcity.model.abstract_traffic_state_model import AbstractTrafficStateModel


class SANN(nn.Module):
    """Spatial-agnostic convolutional backbone used by ATDM.

    Pipeline: causal temporal padding -> full-width 2-D convolution that
    collapses the spatial axis -> transposed convolution that re-expands it ->
    2-D dropout -> pointwise 1-D convolution regressor over (hidden, time).

    Args:
        n_inp: number of input channels (features) per node and time step.
        n_out: number of output channels per node and time step.
        t_inp: number of input time steps; ``forward`` requires ``T == t_inp``
            because the regressor consumes ``hidden_dim * t_inp`` channels.
        t_out: number of predicted time steps.
        n_points: number of spatial nodes; ``forward`` requires ``S == n_points``.
        past_t: temporal receptive field of the convolution.
        hidden_dim: number of hidden channels.
        dropout: probability for ``nn.Dropout2d``.
    """

    def __init__(self, n_inp, n_out, t_inp, t_out, n_points, past_t, hidden_dim, dropout):
        super(SANN, self).__init__()
        # Variables
        self.n_inp = n_inp
        self.n_out = n_out
        self.t_inp = t_inp
        self.t_out = t_out
        self.n_points = n_points
        self.past_t = past_t
        self.hidden_dim = hidden_dim
        # Convolutional layer: kernel (past_t, n_points) maps NxCxTxS -> NxHxTx1
        self.conv_block = AgnosticConvBlock(n_inp, n_points, past_t, hidden_dim, num_conv=1)
        # Transposed conv with kernel (1, n_points) maps NxHxTx1 -> NxHxTxS
        self.convT = nn.ConvTranspose2d(in_channels=hidden_dim, out_channels=hidden_dim,
                                        kernel_size=(1, n_points))
        # Regressor layer. It must emit n_out * t_out channels so that the final
        # view to (N, n_out, t_out, n_points) is valid when n_out > 1; for
        # n_out == 1 this is exactly the previous t_out channels.
        self.regressor = ConvRegBlock(t_inp, t_out * n_out, n_points, hidden_dim)
        # Dropout
        self.drop = nn.Dropout2d(p=dropout)

    def forward(self, x):
        """Run the backbone.

        Args:
            x: tensor of shape ``(N, n_inp, t_inp, n_points)``.

        Returns:
            Tensor of shape ``(N, n_out, t_out, n_points)``.
        """
        N, _, T, S = x.size()
        # Left-pad the time axis with past_t - 1 zeros so the convolution with a
        # temporal kernel of past_t keeps T output steps and only sees the past.
        xp = F.pad(x, pad=(0, 0, self.past_t - 1, 0))
        # NxCxTxS ---> NxHxTx1
        out = self.conv_block(xp)
        out = out.view(N, self.hidden_dim, T, 1)
        # NxHxTx1 ---> NxHxTxS
        out = self.convT(out)
        # 2D dropout
        out = self.drop(out)
        # NxHxTxS ---> Nx(n_out*t_out)xS
        out = self.regressor(out.reshape(N, -1, S))
        return out.view(N, self.n_out, self.t_out, self.n_points)
class AgnosticConvBlock(nn.Module):
    """Single convolution stage that spans every spatial node at once.

    Layers (kept inside ``self.op`` so parameter names stay ``op.0.*``/``op.1.*``):
    ``Conv2d`` with kernel ``(past_t, n_points)`` -> ``BatchNorm2d`` -> ``ReLU``.

    Args:
        n_inp: input channel count.
        n_points: kernel width along the last (spatial) axis.
        past_t: kernel height along the time axis.
        hidden_dim: output channel count.
        num_conv: accepted for interface compatibility; this implementation
            always builds exactly one convolution and ignores the value.
    """

    def __init__(self, n_inp, n_points, past_t, hidden_dim, num_conv):
        super(AgnosticConvBlock, self).__init__()
        kernel = (past_t, n_points)
        conv = nn.Conv2d(n_inp, hidden_dim, kernel, bias=True)
        norm = nn.BatchNorm2d(hidden_dim, affine=True, track_running_stats=True)
        self.op = nn.Sequential(conv, norm, nn.ReLU())

    def forward(self, x):
        """Apply conv -> batch-norm -> ReLU to ``x``."""
        return self.op(x)
class ConvRegBlock(nn.Module):
    """Pointwise 1-D convolution regressor followed by batch normalization.

    Maps an input of shape ``(N, hidden_dim * t_inp, L)`` to
    ``(N, t_out, L)``. Layers live in ``self.op`` (``Conv1d`` then
    ``BatchNorm1d``), keeping parameter names ``op.0.*``/``op.1.*``.

    Args:
        t_inp: number of input time steps folded into the channel axis.
        t_out: number of output channels.
        n_points: accepted for interface compatibility; not used here.
        hidden_dim: hidden channels per input time step.
    """

    def __init__(self, t_inp, t_out, n_points, hidden_dim):
        super(ConvRegBlock, self).__init__()
        in_channels = hidden_dim * t_inp
        self.op = nn.Sequential(
            nn.Conv1d(in_channels, t_out, kernel_size=1, bias=True),
            nn.BatchNorm1d(t_out, affine=True, track_running_stats=True),
        )

    def forward(self, x):
        """Apply the 1x1 convolution and batch-norm to ``x``."""
        return self.op(x)
class ATDM(AbstractTrafficStateModel):
    """LibCity traffic-state model that delegates to the SANN backbone.

    Sizes and the scaler are read from ``data_feature``; hyperparameters
    (``hidden_size``, ``input_window``, ``output_window``, ``past_t``,
    ``dropout``, ``device``) are read from ``config`` with the defaults shown
    in ``__init__``.
    """

    def __init__(self, config, data_feature):
        super().__init__(config, data_feature)
        feature = self.data_feature
        # dataset-derived attributes
        self.adj_mx = feature.get('adj_mx')
        self.num_nodes = feature.get('num_nodes', 1)
        self.feature_dim = feature.get('feature_dim', 1)
        self.output_dim = feature.get('output_dim', 1)
        self._scaler = feature.get('scaler')
        # hyperparameters
        self.hidden_size = config.get('hidden_size', 64)
        self.device = config.get('device', torch.device('cpu'))
        self.input_window = config.get('input_window', 12)
        self.output_window = config.get('output_window', 12)
        self.past_t = config.get('past_t', 3)
        self.dropout = config.get('dropout', 0.2)
        self._logger = getLogger()
        self.sann = SANN(
            n_inp=self.feature_dim,
            n_out=self.output_dim,
            t_inp=self.input_window,
            t_out=self.output_window,
            n_points=self.num_nodes,
            past_t=self.past_t,
            hidden_dim=self.hidden_size,
            dropout=self.dropout,
        )

    def forward(self, batch):
        """Predict from ``batch['X']``.

        ``batch['X']`` is permuted with ``(0, 3, 1, 2)`` into SANN's
        channels-first layout ``B x feature_dim x input_window x num_nodes``;
        SANN's ``B x output_dim x output_window x num_nodes`` result is
        permuted back to ``B x output_window x num_nodes x output_dim``.
        """
        channels_first = batch['X'].permute(0, 3, 1, 2)
        prediction = self.sann(channels_first)
        return prediction.permute(0, 2, 3, 1)

    def calculate_loss(self, batch):
        """Masked MSE between de-normalized prediction and ``batch['y']``.

        Both tensors are truncated to the first ``output_dim`` channels of the
        last axis and passed through ``self._scaler.inverse_transform``.
        """
        target = batch['y']
        prediction = self.predict(batch)
        dim = self.output_dim
        target = self._scaler.inverse_transform(target[..., :dim])
        prediction = self._scaler.inverse_transform(prediction[..., :dim])
        return loss.masked_mse_torch(prediction, target)

    def predict(self, batch):
        """Alias for ``forward``."""
        return self.forward(batch)