import torch
import torch.nn as nn
import torch.nn.functional as F
from logging import getLogger
from libcity.model.abstract_traffic_state_model import AbstractTrafficStateModel
from libcity.model import loss
import numpy as np
import scipy.sparse as sp
from scipy.sparse import linalg
def sym_adj(adj):
"""Symmetrically normalize adjacency matrix."""
adj = sp.coo_matrix(adj)
rowsum = np.array(adj.sum(1))
d_inv_sqrt = np.power(rowsum, -0.5).flatten()
d_inv_sqrt[np.isinf(d_inv_sqrt)] = 0.
d_mat_inv_sqrt = sp.diags(d_inv_sqrt)
return adj.dot(d_mat_inv_sqrt).transpose().dot(d_mat_inv_sqrt).astype(np.float32).todense()
def asym_adj(adj):
    """Asymmetrically normalize adjacency matrix (random-walk transition matrix D^-1 A)."""
adj = sp.coo_matrix(adj)
rowsum = np.array(adj.sum(1)).flatten()
d_inv = np.power(rowsum, -1).flatten()
d_inv[np.isinf(d_inv)] = 0.
d_mat = sp.diags(d_inv)
return d_mat.dot(adj).astype(np.float32).todense()
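# A minimal illustration of the two normalizations above (a sketch using an assumed
# 3-node chain graph; not part of the original module):
# >>> a = np.array([[0., 1., 0.], [1., 0., 1.], [0., 1., 0.]])
# >>> sym_adj(a)    # D^-1/2 A^T D^-1/2: symmetric normalization
# >>> asym_adj(a)   # D^-1 A: row-normalized transition matrix, each row sums to 1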
def calculate_normalized_laplacian(adj):
    """Calculate the symmetrically normalized graph Laplacian.

    L = D^-1/2 (D - A) D^-1/2 = I - D^-1/2 A D^-1/2, where D = diag(A 1).

    :param adj: adjacency matrix
    :return: normalized Laplacian as a scipy sparse matrix
    """
adj = sp.coo_matrix(adj)
d = np.array(adj.sum(1))
d_inv_sqrt = np.power(d, -0.5).flatten()
d_inv_sqrt[np.isinf(d_inv_sqrt)] = 0.
d_mat_inv_sqrt = sp.diags(d_inv_sqrt)
normalized_laplacian = sp.eye(adj.shape[0]) - adj.dot(d_mat_inv_sqrt).transpose().dot(d_mat_inv_sqrt).tocoo()
return normalized_laplacian
def calculate_scaled_laplacian(adj_mx, lambda_max=2, undirected=True):
    """Rescale the normalized Laplacian to [-1, 1]: L' = (2 / lambda_max) * L - I."""
if undirected:
adj_mx = np.maximum.reduce([adj_mx, adj_mx.T])
lap = calculate_normalized_laplacian(adj_mx)
if lambda_max is None:
lambda_max, _ = linalg.eigsh(lap, 1, which='LM')
lambda_max = lambda_max[0]
lap = sp.csr_matrix(lap)
m, _ = lap.shape
identity = sp.identity(m, format='csr', dtype=lap.dtype)
lap = (2 / lambda_max * lap) - identity
return lap.astype(np.float32).todense()
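# Quick sanity check for the scaled Laplacian (illustrative only, assuming a small
# symmetric adjacency matrix): with lambda_max fixed to 2 the result is simply L - I,
# whose eigenvalues lie in [-1, 1], as required by a Chebyshev-style recursion.
# >>> a = np.random.rand(5, 5); a = a + a.T
# >>> lap = calculate_scaled_laplacian(a)                         # dense (5, 5) float32 matrix
# >>> lap_exact = calculate_scaled_laplacian(a, lambda_max=None)  # uses the true largest eigenvalue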
class NConv(nn.Module):
    """Graph propagation layer: multiplies node features by a support (adjacency) matrix."""
def __init__(self):
super(NConv, self).__init__()
    def forward(self, x, adj):
x = torch.einsum('ncvl,vw->ncwl', (x, adj))
return x.contiguous()
class Linear(nn.Module):
    """Per-node linear layer implemented as a 1x1 2-D convolution."""
def __init__(self, c_in, c_out):
super(Linear, self).__init__()
self.mlp = torch.nn.Conv2d(c_in, c_out, kernel_size=(1, 1), padding=(0, 0), stride=(1, 1), bias=True)
    def forward(self, x):
return self.mlp(x)
class GCN(nn.Module):
    """Diffusion graph convolution: concatenates multi-order propagation over each support and mixes channels with a 1x1 convolution."""
def __init__(self, c_in, c_out, dropout, support_len=3, order=2):
super(GCN, self).__init__()
self.nconv = NConv()
c_in = (order*support_len+1)*c_in
self.mlp = Linear(c_in, c_out)
self.dropout = dropout
self.order = order
    def forward(self, x, support):
out = [x]
for a in support:
x1 = self.nconv(x, a)
out.append(x1)
for k in range(2, self.order + 1):
x2 = self.nconv(x1, a)
out.append(x2)
x1 = x2
h = torch.cat(out, dim=1)
h = self.mlp(h)
h = F.dropout(h, self.dropout, training=self.training)
return h
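# Shape sketch for GCN (an illustrative check with assumed sizes, not part of the
# original module): each support contributes `order` propagated copies of x, plus the
# identity term, so the 1x1 mixing convolution sees (order * support_len + 1) * c_in channels.
# >>> gcn = GCN(c_in=32, c_out=32, dropout=0.3, support_len=2, order=2)
# >>> x = torch.randn(8, 32, 207, 12)                        # (batch, channel, node, time)
# >>> supports = [torch.rand(207, 207), torch.rand(207, 207)]
# >>> gcn(x, supports).shape                                 # torch.Size([8, 32, 207, 12])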
class GWNET(AbstractTrafficStateModel):
    """Graph WaveNet: gated dilated temporal convolutions interleaved with (adaptive) graph convolutions."""
def __init__(self, config, data_feature):
self.adj_mx = data_feature.get('adj_mx')
self.num_nodes = data_feature.get('num_nodes', 1)
self.feature_dim = data_feature.get('feature_dim', 2)
super().__init__(config, data_feature)
self.dropout = config.get('dropout', 0.3)
self.blocks = config.get('blocks', 4)
self.layers = config.get('layers', 2)
self.gcn_bool = config.get('gcn_bool', True)
self.addaptadj = config.get('addaptadj', True)
self.adjtype = config.get('adjtype', 'doubletransition')
self.randomadj = config.get('randomadj', True)
self.aptonly = config.get('aptonly', True)
self.kernel_size = config.get('kernel_size', 2)
self.nhid = config.get('nhid', 32)
self.residual_channels = config.get('residual_channels', self.nhid)
self.dilation_channels = config.get('dilation_channels', self.nhid)
self.skip_channels = config.get('skip_channels', self.nhid * 8)
self.end_channels = config.get('end_channels', self.nhid * 16)
self.input_window = config.get('input_window', 1)
self.output_window = config.get('output_window', 1)
self.output_dim = self.data_feature.get('output_dim', 1)
self.device = config.get('device', torch.device('cpu'))
self._logger = getLogger()
self._scaler = self.data_feature.get('scaler')
self.filter_convs = nn.ModuleList()
self.gate_convs = nn.ModuleList()
self.residual_convs = nn.ModuleList()
self.skip_convs = nn.ModuleList()
self.bn = nn.ModuleList()
self.gconv = nn.ModuleList()
self.start_conv = nn.Conv2d(in_channels=self.feature_dim,
out_channels=self.residual_channels,
kernel_size=(1, 1))
self.cal_adj(self.adjtype)
self.supports = [torch.tensor(i).to(self.device) for i in self.adj_mx]
if self.randomadj:
self.aptinit = None
else:
self.aptinit = self.supports[0]
if self.aptonly:
self.supports = None
receptive_field = self.output_dim
self.supports_len = 0
if self.supports is not None:
self.supports_len += len(self.supports)
if self.gcn_bool and self.addaptadj:
if self.aptinit is None:
if self.supports is None:
self.supports = []
                self.nodevec1 = nn.Parameter(torch.randn(self.num_nodes, 10).to(self.device), requires_grad=True)
                self.nodevec2 = nn.Parameter(torch.randn(10, self.num_nodes).to(self.device), requires_grad=True)
self.supports_len += 1
else:
if self.supports is None:
self.supports = []
m, p, n = torch.svd(self.aptinit)
initemb1 = torch.mm(m[:, :10], torch.diag(p[:10] ** 0.5))
initemb2 = torch.mm(torch.diag(p[:10] ** 0.5), n[:, :10].t())
                self.nodevec1 = nn.Parameter(initemb1, requires_grad=True)
                self.nodevec2 = nn.Parameter(initemb2, requires_grad=True)
self.supports_len += 1
for b in range(self.blocks):
additional_scope = self.kernel_size - 1
new_dilation = 1
for i in range(self.layers):
# dilated convolutions
self.filter_convs.append(nn.Conv2d(in_channels=self.residual_channels,
out_channels=self.dilation_channels,
kernel_size=(1, self.kernel_size), dilation=new_dilation))
# print(self.filter_convs[-1])
                # Conv2d here (not Conv1d): the inputs are 4-D (batch, channel, node, time)
                self.gate_convs.append(nn.Conv2d(in_channels=self.residual_channels,
                                                 out_channels=self.dilation_channels,
                                                 kernel_size=(1, self.kernel_size), dilation=new_dilation))
# print(self.gate_convs[-1])
# 1x1 convolution for residual connection
                self.residual_convs.append(nn.Conv2d(in_channels=self.dilation_channels,
                                                     out_channels=self.residual_channels,
                                                     kernel_size=(1, 1)))
                # 1x1 convolution for skip connection
                self.skip_convs.append(nn.Conv2d(in_channels=self.dilation_channels,
                                                 out_channels=self.skip_channels,
                                                 kernel_size=(1, 1)))
self.bn.append(nn.BatchNorm2d(self.residual_channels))
new_dilation *= 2
receptive_field += additional_scope
additional_scope *= 2
if self.gcn_bool:
self.gconv.append(GCN(self.dilation_channels, self.residual_channels,
self.dropout, support_len=self.supports_len))
self.end_conv_1 = nn.Conv2d(in_channels=self.skip_channels,
out_channels=self.end_channels,
kernel_size=(1, 1),
bias=True)
self.end_conv_2 = nn.Conv2d(in_channels=self.end_channels,
out_channels=self.output_window,
kernel_size=(1, 1),
bias=True)
self.receptive_field = receptive_field
self._logger.info('receptive_field: '+str(self.receptive_field))
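    # Receptive-field arithmetic (worked example under the default hyperparameters
    # kernel_size=2, blocks=4, layers=2): each block uses dilations 1 and 2, adding
    # 1 + 2 = 3 time steps, so receptive_field = 1 + 4 * 3 = 13. forward() left-pads
    # the input by one step and then up to the receptive field, so input_window=12
    # is consumed exactly, leaving a single time step before the output convolutions.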
    def forward(self, batch):
inputs = batch['X'] # (batch_size, input_window, num_nodes, feature_dim)
inputs = inputs.transpose(1, 3) # (batch_size, feature_dim, num_nodes, input_window)
inputs = nn.functional.pad(inputs, (1, 0, 0, 0)) # (batch_size, feature_dim, num_nodes, input_window+1)
in_len = inputs.size(3)
if in_len < self.receptive_field:
x = nn.functional.pad(inputs, (self.receptive_field-in_len, 0, 0, 0))
else:
x = inputs
x = self.start_conv(x) # (batch_size, residual_channels, num_nodes, self.receptive_field)
skip = 0
# calculate the current adaptive adj matrix once per iteration
new_supports = None
if self.gcn_bool and self.addaptadj and self.supports is not None:
adp = F.softmax(F.relu(torch.mm(self.nodevec1, self.nodevec2)), dim=1)
new_supports = self.supports + [adp]
# WaveNet layers
for i in range(self.blocks * self.layers):
# |----------------------------------------| *residual*
# | |
# | |-- conv -- tanh --| |
# -> dilate -|----| * ----|-- 1x1 -- + --> *input*
# |-- conv -- sigm --| |
# 1x1
# |
# ---------------------------------------> + -------------> *skip*
# (dilation, init_dilation) = self.dilations[i]
# residual = dilation_func(x, dilation, init_dilation, i)
residual = x
# (batch_size, residual_channels, num_nodes, self.receptive_field)
# dilated convolution
filter = self.filter_convs[i](residual)
# (batch_size, dilation_channels, num_nodes, receptive_field-kernel_size+1)
filter = torch.tanh(filter)
gate = self.gate_convs[i](residual)
# (batch_size, dilation_channels, num_nodes, receptive_field-kernel_size+1)
gate = torch.sigmoid(gate)
x = filter * gate
# (batch_size, dilation_channels, num_nodes, receptive_field-kernel_size+1)
# parametrized skip connection
s = x
# (batch_size, dilation_channels, num_nodes, receptive_field-kernel_size+1)
s = self.skip_convs[i](s)
# (batch_size, skip_channels, num_nodes, receptive_field-kernel_size+1)
            # align the accumulated skip path with the shorter current output before adding
            if isinstance(skip, torch.Tensor):
                skip = skip[:, :, :, -s.size(3):]
            skip = s + skip
# (batch_size, skip_channels, num_nodes, receptive_field-kernel_size+1)
if self.gcn_bool and self.supports is not None:
# (batch_size, dilation_channels, num_nodes, receptive_field-kernel_size+1)
if self.addaptadj:
x = self.gconv[i](x, new_supports)
else:
x = self.gconv[i](x, self.supports)
# (batch_size, residual_channels, num_nodes, receptive_field-kernel_size+1)
else:
# (batch_size, dilation_channels, num_nodes, receptive_field-kernel_size+1)
x = self.residual_convs[i](x)
# (batch_size, residual_channels, num_nodes, receptive_field-kernel_size+1)
# residual: (batch_size, residual_channels, num_nodes, self.receptive_field)
x = x + residual[:, :, :, -x.size(3):]
# (batch_size, residual_channels, num_nodes, receptive_field-kernel_size+1)
x = self.bn[i](x)
x = F.relu(skip)
# (batch_size, skip_channels, num_nodes, self.output_dim)
x = F.relu(self.end_conv_1(x))
# (batch_size, end_channels, num_nodes, self.output_dim)
x = self.end_conv_2(x)
# (batch_size, output_window, num_nodes, self.output_dim)
return x
    def cal_adj(self, adjtype):
        """Build the list of support matrices from self.adj_mx according to adjtype."""
if adjtype == "scalap":
self.adj_mx = [calculate_scaled_laplacian(self.adj_mx)]
elif adjtype == "normlap":
self.adj_mx = [calculate_normalized_laplacian(self.adj_mx).astype(np.float32).todense()]
elif adjtype == "symnadj":
self.adj_mx = [sym_adj(self.adj_mx)]
elif adjtype == "transition":
self.adj_mx = [asym_adj(self.adj_mx)]
elif adjtype == "doubletransition":
self.adj_mx = [asym_adj(self.adj_mx), asym_adj(np.transpose(self.adj_mx))]
elif adjtype == "identity":
self.adj_mx = [np.diag(np.ones(self.adj_mx.shape[0])).astype(np.float32)]
else:
            raise ValueError('adj type not defined: ' + adjtype)
    def calculate_loss(self, batch):
y_true = batch['y']
y_predicted = self.predict(batch)
# print('y_true', y_true.shape)
# print('y_predicted', y_predicted.shape)
y_true = self._scaler.inverse_transform(y_true[..., :self.output_dim])
y_predicted = self._scaler.inverse_transform(y_predicted[..., :self.output_dim])
return loss.masked_mae_torch(y_predicted, y_true, 0)
    def predict(self, batch):
return self.forward(batch)
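# A minimal smoke-test sketch (hypothetical values; in LibCity the config and
# data_feature dicts are normally assembled by the dataset/executor pipeline):
# >>> num_nodes = 10
# >>> data_feature = {'adj_mx': np.random.rand(num_nodes, num_nodes) + np.eye(num_nodes),
# ...                 'num_nodes': num_nodes, 'feature_dim': 2, 'output_dim': 1}
# >>> config = {'input_window': 12, 'output_window': 12, 'device': torch.device('cpu')}
# >>> model = GWNET(config, data_feature)
# >>> batch = {'X': torch.randn(8, 12, num_nodes, 2)}        # (batch, input_window, num_nodes, feature_dim)
# >>> model.predict(batch).shape                             # torch.Size([8, 12, 10, 1])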