# Source code for libcity.model.traffic_speed_prediction.TGCLSTM

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torch.nn.parameter import Parameter
from libcity.model import loss
import math
from libcity.model.abstract_traffic_state_model import AbstractTrafficStateModel


class FilterLinear(nn.Module):
    """A linear layer whose weight matrix is element-wise masked by a fixed
    0/1 matrix, so only the connections allowed by the graph adjacency
    pattern carry a trainable weight.

    Args:
        device: torch device the parameters and mask are moved to.
        input_dim: number of input features per node.
        output_dim: number of output features per node.
        in_features: total input size (input_dim * num_nodes).
        out_features: total output size (output_dim * num_nodes).
        filter_square_matrix: [num_nodes, num_nodes] matrix whose entries
            are 0 or 1; tiled to the full [out_features, in_features]
            weight shape and used as a multiplicative mask.
        bias: whether to learn an additive bias.
    """

    def __init__(self, device, input_dim, output_dim, in_features, out_features,
                 filter_square_matrix, bias=True):
        super(FilterLinear, self).__init__()
        self.device = device
        self.in_features = in_features
        self.out_features = out_features
        self.num_nodes = filter_square_matrix.shape[0]
        # Fixed, non-trainable mask tiled to the weight shape
        # [out_features, in_features]. torch.autograd.Variable is deprecated;
        # a detached tensor with requires_grad=False is equivalent.
        self.filter_square_matrix = \
            filter_square_matrix.repeat(output_dim, input_dim).to(device).detach()
        self.filter_square_matrix.requires_grad_(False)

        self.weight = Parameter(torch.Tensor(out_features, in_features).to(device))
        if bias:
            self.bias = Parameter(torch.Tensor(out_features).to(device))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        """Uniform init in [-1/sqrt(fan_in), 1/sqrt(fan_in)], like nn.Linear."""
        stdv = 1. / math.sqrt(self.weight.size(1))
        self.weight.data.uniform_(-stdv, stdv)
        if self.bias is not None:
            self.bias.data.uniform_(-stdv, stdv)

    def forward(self, input):
        """Apply the masked linear map: input @ (mask * weight)^T + bias."""
        return F.linear(input, self.filter_square_matrix.mul(self.weight), self.bias)

    def __repr__(self):
        return self.__class__.__name__ + '(' \
            + 'in_features=' + str(self.in_features) \
            + ', out_features=' + str(self.out_features) \
            + ', bias=' + str(self.bias is not None) + ')'
class TGCLSTM(AbstractTrafficStateModel):
    """Traffic Graph Convolutional LSTM: an LSTM cell whose input transform
    is a stack of K-hop graph convolutions (FilterLinear layers, one per hop).
    """

    def __init__(self, config, data_feature):
        """
        Args:
            config: dict-like model config (K_hop_numbers, back_length,
                dataset_class, device, Clamp_A).
            data_feature: dict with num_nodes, feature_dim, output_dim,
                adj_mx, scaler and (for TGCLSTMDataset) FFR matrices.
        """
        super(TGCLSTM, self).__init__(config, data_feature)
        self.num_nodes = self.data_feature.get('num_nodes', 1)
        self.input_dim = self.data_feature.get('feature_dim', 1)
        self.in_features = self.input_dim * self.num_nodes
        self.output_dim = self.data_feature.get('output_dim', 1)
        self.out_features = self.output_dim * self.num_nodes
        self.K = config.get('K_hop_numbers', 3)
        self.back_length = config.get('back_length', 3)
        self.dataset_class = config.get('dataset_class', 'TrafficSpeedDataset')
        self.device = config.get('device', torch.device('cpu'))
        self._scaler = self.data_feature.get('scaler')

        # Binarize the adjacency on a tensor copy. The original thresholded
        # data_feature['adj_mx'] in place, mutating the caller's shared array.
        adj = torch.as_tensor(data_feature['adj_mx'], dtype=torch.float32,
                              device=self.device)
        adj = (adj > 1e-4).float()

        # A_list[k] is the (k+1)-hop reachability matrix (optionally clamped
        # to [0, 1] and, for TGCLSTMDataset, scaled by free-flow reachability).
        self.A_list = []
        adj_temp = torch.eye(self.num_nodes, self.num_nodes, device=self.device)
        for i in range(self.K):
            adj_temp = torch.matmul(adj_temp, adj)
            if config.get('Clamp_A', True):
                # confine elements of A
                adj_temp = torch.clamp(adj_temp, max=1.)
            if self.dataset_class == "TGCLSTMDataset":
                self.A_list.append(
                    torch.mul(adj_temp,
                              torch.Tensor(data_feature['FFR'][self.back_length])
                              .to(self.device)))
            else:
                self.A_list.append(adj_temp)

        # a length adjustable Module List for hosting all graph convolutions
        self.gc_list = nn.ModuleList(
            [FilterLinear(self.device, self.input_dim, self.output_dim,
                          self.in_features, self.out_features, self.A_list[i],
                          bias=False) for i in range(self.K)])

        hidden_size = self.out_features
        input_size = self.out_features * self.K
        # LSTM gate transforms over [graph-conv features ; hidden state]
        self.fl = nn.Linear(input_size + hidden_size, hidden_size)
        self.il = nn.Linear(input_size + hidden_size, hidden_size)
        self.ol = nn.Linear(input_size + hidden_size, hidden_size)
        self.Cl = nn.Linear(input_size + hidden_size, hidden_size)

        # initialize the neighbor weight for the cell state
        self.Neighbor_weight = Parameter(
            torch.FloatTensor(self.out_features, self.out_features).to(self.device))
        stdv = 1. / math.sqrt(self.out_features)
        self.Neighbor_weight.data.uniform_(-stdv, stdv)
[docs] def step(self, step_input, hidden_state, cell_state): x = step_input # [batch_size, in_features] gc = self.gc_list[0](x) # [batch_size, out_features] for i in range(1, self.K): gc = torch.cat((gc, self.gc_list[i](x)), 1) # [batch_size, out_features * K] combined = torch.cat((gc, hidden_state), 1) # [batch_size, out_features * (K+1)] # fl: nn.linear(out_features * (K+1), out_features) f = torch.sigmoid(self.fl(combined)) i = torch.sigmoid(self.il(combined)) o = torch.sigmoid(self.ol(combined)) c_ = torch.tanh(self.Cl(combined)) nc = torch.matmul(cell_state, torch.mul( Variable(self.A_list[-1].repeat(self.output_dim, self.output_dim), requires_grad=False).to(self.device), self.Neighbor_weight)) cell_state = f * nc + i * c_ # [batch_size, out_features] hidden_state = o * torch.tanh(cell_state) # [batch_size, out_features] return hidden_state, cell_state, gc
[docs] def bi_torch(self, a): a[a < 0] = 0 a[a > 0] = 1 return a
[docs] def forward(self, batch): inputs = batch['X'] # [batch_size, input_window, num_nodes, input_dim] batch_size = inputs.size(0) time_step = inputs.size(1) hidden_state, cell_state = self.init_hidden(batch_size) # [batch_size, out_features] outputs = None for i in range(time_step): step_input = torch.squeeze(torch.transpose(inputs[:, i:i + 1, :, :], 2, 3)).reshape(batch_size, -1) hidden_state, cell_state, gc = self.step(step_input, hidden_state, cell_state) # gc: [batch_size, out_features * K] if outputs is None: outputs = hidden_state.unsqueeze(1) # [batch_size, 1, out_features] else: outputs = torch.cat((outputs, hidden_state.unsqueeze(1)), 1) # [batch_size, input_window, out_features] output = torch.transpose(torch.squeeze(outputs[:, -1, :]).reshape(batch_size, self.output_dim, self.num_nodes), 1, 2).unsqueeze(1) return output # [batch_size, 1, num_nodes, out_dim]
[docs] def calculate_loss(self, batch): y_true = batch['y'] y_predicted = self.predict(batch) y_true = self._scaler.inverse_transform(y_true[..., :self.output_dim]) y_predicted = self._scaler.inverse_transform(y_predicted[..., :self.output_dim]) return loss.masked_mse_torch(y_predicted, y_true)
[docs] def predict(self, batch): x = batch['X'] y = batch['y'] output_length = y.shape[1] y_preds = [] x_ = x.clone() for i in range(output_length): batch_tmp = {'X': x_} y_ = self.forward(batch_tmp) y_preds.append(y_.clone()) if y_.shape[3] < x_.shape[3]: y_ = torch.cat([y_, y[:, i:i + 1, :, self.output_dim:]], dim=3) x_ = torch.cat([x_[:, 1:, :, :], y_], dim=1) y_preds = torch.cat(y_preds, dim=1) # [batch_size, output_window, batch_size, output_dim] return y_preds
[docs] def init_hidden(self, batch_size): hidden_state = Variable(torch.zeros(batch_size, self.out_features).to(self.device)) cell_state = Variable(torch.zeros(batch_size, self.out_features).to(self.device)) return hidden_state, cell_state