import torch
import torch.nn as nn
from logging import getLogger
from libcity.model.abstract_traffic_state_model import AbstractTrafficStateModel
from libcity.model import loss
def split_cpt(value, cpt):
    """Split a tensor along dim 1 into one chunk per positive entry of ``cpt``.

    The channel dimension (dim 1) is assumed to hold ``sum(cpt)`` equally sized
    groups; every positive entry ``k`` of ``cpt`` receives ``k`` of these groups.
    Entries that are zero (or negative) produce no chunk.

    Args:
        value (torch.Tensor): tensor to split; dim 1 is the split dimension.
        cpt (Sequence[int]): number of groups assigned to each component.

    Returns:
        tuple[torch.Tensor, ...]: the chunks, in the order of the positive
        entries of ``cpt``.

    Raises:
        ValueError: if ``value`` is not a tensor, or if ``cpt`` contains no
            usable (positive) length.
    """
    if not isinstance(value, torch.Tensor):
        raise ValueError('Parameter Value should be a Tensor.')
    total = sum(cpt)
    # Guard the division below: an all-zero (or empty) cpt would otherwise
    # surface as an unrelated ZeroDivisionError.
    if total == 0:
        raise ValueError('Get empty split_list.')
    scale = int(value.size(1)) // total
    split_list = [length * scale for length in cpt if length > 0]
    if not split_list:
        raise ValueError('Get empty split_list.')
    return torch.split(value, split_size_or_sections=split_list, dim=1)
class ConcatConv(nn.Module):
    """Concatenate two feature maps on the channel axis and convolve them.

    When ``seq_len`` is given, both inputs are cut into per-step chunks along
    dim 1 and every step gets its own convolution block; the per-step results
    are concatenated again on dim 1. Without ``seq_len`` a single block is
    applied to the plain concatenation.

    Args:
        in_channels1 (int): channel count attributed to the first input.
        in_channels2 (int): channel count attributed to the second input.
        out_channels (int): total number of output channels.
        inter_channels (int): hidden channels of the two-convolution variant.
        relu_conv (bool): if True each block is conv -> SELU -> conv,
            otherwise a single convolution.
        seq_len (int | None): number of steps to process separately.
    """

    def __init__(self, in_channels1, in_channels2, out_channels, inter_channels, relu_conv=False, seq_len=None):
        super().__init__()
        self.in_channels1 = in_channels1
        self.in_channels2 = in_channels2
        self.in_channels = in_channels1 + in_channels2
        self.out_channels = out_channels
        self.inter_channels = inter_channels
        self.relu_conv = relu_conv
        self.seq_len = seq_len

        if seq_len is None:
            self.model = self._layer()
            return

        # Per-step blocks operate on one step's share of the channels.
        self.in_channels //= seq_len
        self.out_channels //= seq_len
        self.model = nn.ModuleList()
        for _ in range(seq_len):
            self.model.append(self._layer())

    def _conv_layer(self, in_channels, out_channels):
        """Return a 3x3, stride-1, padding-1 convolution."""
        return nn.Conv2d(in_channels, out_channels, 3, 1, 1)

    def _layer(self):
        """Build one block sized by the current in/out channel attributes."""
        if self.relu_conv:
            return nn.Sequential(
                self._conv_layer(self.in_channels, self.inter_channels),
                nn.SELU(inplace=True),
                self._conv_layer(self.inter_channels, self.out_channels),
            )
        return self._conv_layer(self.in_channels, self.out_channels)

    def forward(self, x, y):
        """Fuse ``x`` and ``y``; both are concatenated on dim 1 before the conv."""
        if self.seq_len is None:
            return self.model(torch.cat([x, y], dim=1))

        x_chunks = torch.split(x, self.in_channels1 // self.seq_len, dim=1)
        y_chunks = torch.split(y, self.in_channels2 // self.seq_len, dim=1)
        # Index explicitly: only the first seq_len chunks of each input are used.
        step_outputs = [
            self.model[step](torch.cat([x_chunks[step], y_chunks[step]], dim=1))
            for step in range(self.seq_len)
        ]
        return torch.cat(step_outputs, dim=1)
class ConvGate(nn.Module):
    """One convolutional gate of a ConvLSTM cell.

    The gate sums a 3x3 convolution of the input and a 3x3 convolution of the
    hidden state. With ``peephole_conn`` it additionally adds an element-wise
    affine map of the cell state (``cell * w + b``) and squashes with a
    sigmoid; without it the sum is squashed with ``tanh``.

    Args:
        in_channels (int): channels of the input tensor.
        height (int): spatial height, used for the peephole weight shape.
        width (int): spatial width, used for the peephole weight shape.
        lstm_channels (int): channels of the hidden/cell state and the output.
        peephole_conn (bool): enable the peephole term and the sigmoid output.
    """

    def __init__(self, in_channels, height, width, lstm_channels=16, peephole_conn=True):
        super().__init__()
        self.in_channels = in_channels
        self.height = height
        self.width = width
        self.lstm_channels = lstm_channels
        self.peephole_conn = peephole_conn
        self.conv_x = self._conv_layer(in_channels)
        self.conv_h = self._conv_layer(lstm_channels)
        if peephole_conn:
            self.w = nn.Parameter(torch.empty(lstm_channels, height, width))
            self.b = nn.Parameter(torch.empty(lstm_channels, 1, 1))
            nn.init.kaiming_normal_(self.w, a=0, mode='fan_in')
            # The bias used to be left as uninitialised memory, which can hold
            # arbitrary values (including inf/NaN); start it at zero instead.
            nn.init.zeros_(self.b)

    def _linear(self, x):
        """Element-wise peephole map; ``b`` broadcasts over height and width."""
        return x * self.w + self.b

    def _conv_layer(self, in_channels):
        """Return a biased 3x3, stride-1, padding-1 conv to ``lstm_channels``."""
        return nn.Conv2d(in_channels, self.lstm_channels, 3, 1, 1, bias=True)

    def forward(self, input, state):
        """Compute the gate activation.

        Args:
            input (torch.Tensor): current input frame.
            state (tuple): ``(hidden_state, cell_state)``.

        Returns:
            torch.Tensor: sigmoid activation (peephole) or tanh activation.
        """
        hidden_state, cell_state = state
        conv = self.conv_x(input) + self.conv_h(hidden_state)
        if self.peephole_conn:
            return torch.sigmoid(conv + self._linear(cell_state))
        return torch.tanh(conv)
class ConvLSTMCell(nn.Module):
    """Single ConvLSTM step assembled from four ``ConvGate`` modules.

    ``f_gate``, ``i_gate`` and ``o_gate`` use the default gate configuration;
    ``c_gate`` is created with ``peephole_conn=False`` and supplies the
    candidate cell content rather than acting as a true gate.

    Args:
        in_channels (int): channels of the input frame.
        height (int): spatial height forwarded to the gates.
        width (int): spatial width forwarded to the gates.
        lstm_channels (int): channels of the hidden and cell state.
    """

    def __init__(self, in_channels, height, width, lstm_channels=16):
        super().__init__()
        self.in_channels = in_channels
        self.height = height
        self.width = width
        self.lstm_channels = lstm_channels

        gate_args = (in_channels, height, width, lstm_channels)
        self.f_gate = ConvGate(*gate_args)
        self.i_gate = ConvGate(*gate_args)
        # Not really a gate: it yields the candidate values written to the cell.
        self.c_gate = ConvGate(*gate_args, False)
        self.o_gate = ConvGate(*gate_args)

    def forward(self, input, state):
        """Advance one step.

        Args:
            input (torch.Tensor): current input frame.
            state (tuple): ``(hidden, cell)`` from the previous step.

        Returns:
            tuple: ``(hidden, cell)`` for the current step.
        """
        prev_hidden, prev_cell = state
        forget = self.f_gate(input, state)
        write = self.i_gate(input, state)
        candidate = self.c_gate(input, state)
        next_cell = forget * prev_cell + write * candidate
        # The output gate sees the freshly updated cell state.
        expose = self.o_gate(input, (prev_hidden, next_cell))
        next_hidden = expose * torch.tanh(next_cell)
        return next_hidden, next_cell
class ConvGRUCell(nn.Module):
    """Single convolutional GRU step.

    All three convolutions consume the channel-wise concatenation of the
    hidden state and the input and emit ``lstm_channels`` channels.

    Args:
        in_channels (int): channels of the input frame.
        height (int): stored for reference; not used by the computation.
        width (int): stored for reference; not used by the computation.
        lstm_channels (int): channels of the hidden state.
    """

    def __init__(self, in_channels, height, width, lstm_channels=16):
        super().__init__()
        self.in_channels = in_channels
        self.height = height
        self.width = width
        self.lstm_channels = lstm_channels

        fused_channels = in_channels + lstm_channels
        self.z_conv = self._conv_layer(fused_channels)
        self.r_conv = self._conv_layer(fused_channels)
        self.h_conv = self._conv_layer(fused_channels)

    def _conv_layer(self, in_channels):
        """Return a biased 3x3, stride-1, padding-1 conv to ``lstm_channels``."""
        return nn.Conv2d(in_channels, self.lstm_channels, 3, 1, 1, bias=True)

    def forward(self, input, state):
        """Return the new hidden state given ``input`` and the previous ``state``."""
        gate_input = torch.cat([state, input], dim=1)
        update = torch.sigmoid(self.z_conv(gate_input))
        reset = torch.sigmoid(self.r_conv(gate_input))
        # The candidate is computed from the reset-scaled previous hidden state.
        candidate_input = torch.cat([reset * state, input], dim=1)
        candidate = torch.tanh(self.h_conv(candidate_input))
        return (1 - update) * state + update * candidate
class ConvLSTM(nn.Module):
    """Run a ConvLSTM (or ConvGRU) cell over frames stacked on the channel axis.

    The input holds consecutive frames of ``in_channels`` channels each along
    dim 1. Frames are fed to the cell in reverse order of their position on
    that axis.

    Args:
        in_channels (int): channels per frame.
        height (int): spatial height forwarded to the cell.
        width (int): spatial width forwarded to the cell.
        lstm_channels (int): channels of the recurrent state.
        all_hidden (bool): return every step's hidden state (concatenated on
            dim 1, in input-frame order) instead of only the last one computed.
        mode (str): ``'merge'`` runs one pass over all frames; ``'cpt'`` splits
            the input with ``split_cpt`` and runs one pass per part, starting
            each pass from a zero state but sharing the same cell.
        cpt (Sequence[int] | None): part lengths, required for ``mode='cpt'``.
        dropout_rate (float): ``Dropout2d`` rate applied to the input when > 0.
        last_conv (bool): apply a 3x3 convolution to every hidden state.
        conv_channels (int | None): output channels of that convolution.
        gru (bool): use ``ConvGRUCell`` instead of ``ConvLSTMCell``.

    Raises:
        ValueError: if ``last_conv`` is set without ``conv_channels``.
    """

    def __init__(self, in_channels, height, width, lstm_channels=16, all_hidden=False,
                 mode='merge', cpt=None, dropout_rate=0.5, last_conv=False,
                 conv_channels=None, gru=False):
        super().__init__()
        self.in_channels = in_channels
        self.height = height
        self.width = width
        self.lstm_channels = lstm_channels
        self.all_hidden = all_hidden
        self.mode = mode
        self.cpt = cpt
        self.dropout_rate = dropout_rate
        self.last_conv = last_conv
        self.conv_channels = conv_channels
        self.gru = gru
        if gru:
            self._lstm_cell = ConvGRUCell(in_channels, height, width, lstm_channels)
        else:
            self._lstm_cell = ConvLSTMCell(in_channels, height, width, lstm_channels)
        if last_conv:
            if self.conv_channels is None:
                raise ValueError('Parameter Out Channel is needed to enable last_conv')
            self._conv_layer = nn.Conv2d(lstm_channels, conv_channels, 3, 1, 1, bias=True)
        if dropout_rate > 0:
            self._dropout_layer = nn.Dropout2d(dropout_rate)

    def lstm_layer(self, inputs):
        """Run the recurrent cell over all frames contained in ``inputs``.

        Args:
            inputs (torch.Tensor): 4-D tensor whose dim 1 stacks the frames.

        Returns:
            torch.Tensor: the last computed hidden state, or all hidden states
            concatenated on dim 1 when ``all_hidden`` is set.
        """
        n_in, _, h_in, w_in = inputs.size()
        # Allocate the initial state next to the data (same device and dtype)
        # instead of forcing CUDA, so the layer also runs on CPU.
        if self.gru:
            state = inputs.new_zeros(n_in, self.lstm_channels, h_in, w_in)
        else:
            state = (inputs.new_zeros(n_in, self.lstm_channels, h_in, w_in),
                     inputs.new_zeros(n_in, self.lstm_channels, h_in, w_in))

        hidden_list = []
        # Frames are consumed in reverse order of their channel position.
        for frame in reversed(torch.split(inputs, self.in_channels, dim=1)):
            state = self._lstm_cell(frame, state)
            hidden = state if self.gru else state[0]
            if self.last_conv:
                hidden = self._conv_layer(hidden)
            hidden_list.append(hidden)

        if not self.all_hidden:
            return hidden_list[-1]
        # Restore input-frame order before stacking.
        hidden_list.reverse()
        return torch.cat(hidden_list, 1)

    def forward(self, inputs):
        """Apply optional dropout, then run the recurrence according to ``mode``.

        Raises:
            ValueError: if ``mode`` is ``'cpt'`` without ``cpt``, or if ``mode``
                is neither ``'merge'`` nor ``'cpt'``.
        """
        if self.dropout_rate > 0:
            inputs = self._dropout_layer(inputs)
        if self.mode == 'merge':
            return self.lstm_layer(inputs)
        if self.mode == 'cpt':
            if self.cpt is None:
                raise ValueError('Parameter \'cpt\' is required in mode \'cpt\' of ConvLSTM')
            outputs = [self.lstm_layer(part) for part in split_cpt(inputs, self.cpt)]
            return torch.cat(outputs, 1)
        # Raising a bare string is itself a TypeError; raise a real exception.
        raise ValueError(f'Invalid LSTM mode: {self.mode}')
class ExtNN(nn.Module):
    """Two-layer fully connected network for external feature vectors.

    The network maps ``in_features`` to ``out_height * out_width *
    out_channels`` values. With ``map`` enabled the result is viewed as
    ``(-1, out_channels, out_height, out_width)``.

    Args:
        in_features (int): size of one external feature vector.
        out_height (int): height of the produced map.
        out_width (int): width of the produced map.
        out_channels (int): channels of the produced map.
        inter_features (int): hidden layer width.
        map (bool): reshape the output into a feature map.
        relu (bool): append a final SELU activation.
        mode (str): ``'inter'`` applies the network to every slice along dim 1
            and concatenates the results on dim 1; any other value applies the
            network to the input as a whole.
        dropout_rate (float): dropout after the first linear layer when > 0.
    """

    def __init__(self, in_features, out_height, out_width, out_channels,
                 inter_features=10, map=True, relu=True, mode='inter', dropout_rate=0):
        super().__init__()
        self.in_features = in_features
        self.out_height = out_height
        self.out_width = out_width
        self.out_channels = out_channels
        self.inter_features = inter_features
        self.map = map
        self.relu = relu
        self.mode = mode
        self.dropout_rate = dropout_rate
        self.out_features = self.out_height * self.out_width * self.out_channels
        self.model = self.external_block()

    def external_block(self):
        """Assemble Linear -> [Dropout] -> SELU -> Linear -> [SELU]."""
        stack = [nn.Linear(self.in_features, self.inter_features)]
        if self.dropout_rate > 0:
            stack.append(nn.Dropout(self.dropout_rate))
        stack.append(nn.SELU(inplace=True))
        stack.append(nn.Linear(self.inter_features, self.out_features))
        if self.relu:
            stack.append(nn.SELU(inplace=True))
        return nn.Sequential(*stack)

    def _to_map(self, out):
        """View ``out`` as a feature map when ``map`` is enabled."""
        if not self.map:
            return out
        return out.view(-1, self.out_channels, self.out_height, self.out_width)

    def forward(self, x):
        """Embed the external features according to ``mode``."""
        if self.mode != 'inter':
            return self._to_map(self.model(x))

        per_step = []
        for step in torch.split(x, 1, dim=1):
            # Drop the singleton step axis so the linear layers see (batch, features).
            per_step.append(self._to_map(self.model(step.squeeze(1))))
        return torch.cat(per_step, 1)
class ResUnit(nn.Module):
    """Residual unit: two ``[BatchNorm] -> SELU -> 3x3 conv`` stages plus a skip.

    Args:
        filters (int): channel count, identical for input and output.
        bnmode (bool): prepend a ``BatchNorm2d`` to each stage.
    """

    def __init__(self, filters, bnmode=True):
        super().__init__()
        self.filters = filters
        self.bnmode = bnmode
        self.layer1 = self._bn_relu_conv()
        self.layer2 = self._bn_relu_conv()

    def _conv_layer(self):
        """Return a channel-preserving 3x3, stride-1, padding-1 convolution."""
        return nn.Conv2d(self.filters, self.filters, 3, 1, 1)

    def _bn_relu_conv(self):
        """Build one ``[BatchNorm] -> SELU -> conv`` stage."""
        layers = []
        if self.bnmode:
            layers.append(nn.BatchNorm2d(self.filters))
        # The activation may only work in place when BatchNorm has produced a
        # fresh tensor. Without BatchNorm it is the first op of the stage and an
        # in-place SELU would overwrite the stage input - for layer1 that is the
        # very tensor reused by the skip connection in forward().
        layers.append(nn.SELU(inplace=self.bnmode))
        layers.append(self._conv_layer())
        return nn.Sequential(*layers)

    def forward(self, x):
        """Return ``layer2(layer1(x)) + x``; ``x`` itself is left unmodified."""
        residual = self.layer2(self.layer1(x))
        return residual + x
class ResNN(nn.Module):
    """Residual feature extractor applied to channel groups of the input.

    The input is split along dim 1 into the sizes recorded in ``inlist``; each
    piece goes through its entry of ``resblocks`` and the results are
    concatenated on dim 1.

    Split modes:
        * ``'split'``: ``sum(cpt)`` equal pieces, all handled by the same
          block instance (shared weights).
        * ``'split-chans'``: like ``'split'`` with ``2 * sum(cpt)`` pieces.
        * ``'concat'``: one block over all channels.
        * ``'cpt'``: one independent block per positive entry of ``cpt``, with
          input and output widths proportional to that entry.
        * ``'cpt-sameoutput'``: like ``'cpt'`` but every block emits 2 channels.

    Args:
        in_channels (int): total input channels.
        out_channels (int): total output channels.
        inter_channels (int): width inside each residual block.
        repetation (int): number of ``ResUnit`` modules per block.
        bnmode (bool): forwarded to ``ResUnit``.
        splitmode (str): one of the modes above.
        cpt (Sequence[int] | None): component lengths (unused for ``'concat'``).

    Raises:
        ValueError: for an unknown ``splitmode``.
    """

    def __init__(self, in_channels, out_channels, inter_channels, repetation=1,
                 bnmode=True, splitmode='split', cpt=None):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.inter_channels = inter_channels
        self.repetation = repetation
        self.bnmode = bnmode
        self.inlist = []
        self.resblocks = nn.ModuleList()

        if splitmode in ('split', 'split-chans'):
            pieces = sum(cpt) * 2 if splitmode == 'split-chans' else sum(cpt)
            width_in = int(in_channels) // pieces
            width_out = int(out_channels) // pieces
            # A single block instance is registered for every piece on purpose,
            # so all pieces share one set of weights.
            shared_block = self.residual_block(width_in, width_out)
            for _ in range(pieces):
                self.inlist.append(width_in)
                self.resblocks.append(shared_block)
        elif splitmode == 'concat':
            self.inlist.append(in_channels)
            self.resblocks.append(self.residual_block(in_channels, out_channels))
        elif splitmode == 'cpt':
            total = sum(cpt)
            unit_in = int(in_channels) // total
            unit_out = int(out_channels) // total
            for length in cpt:
                if length <= 0:
                    continue
                self.inlist.append(length * unit_in)
                self.resblocks.append(self.residual_block(length * unit_in, length * unit_out))
        elif splitmode == 'cpt-sameoutput':
            unit_in = int(in_channels) // sum(cpt)
            for length in cpt:
                if length <= 0:
                    continue
                self.inlist.append(length * unit_in)
                self.resblocks.append(self.residual_block(length * unit_in, 2))
        else:
            raise ValueError('Invalid ResNN split mode')

    def residual_block(self, in_channels, out_channels):
        """Build conv -> ``repetation`` x ResUnit -> SELU -> conv."""
        entry = nn.Conv2d(in_channels, self.inter_channels, 3, 1, 1)
        units = [ResUnit(self.inter_channels, self.bnmode) for _ in range(self.repetation)]
        activation = nn.SELU(inplace=True)
        exit_conv = nn.Conv2d(self.inter_channels, out_channels, 3, 1, 1)
        return nn.Sequential(entry, *units, activation, exit_conv)

    def forward(self, x):
        """Split ``x`` by ``inlist``, run each piece through its block, re-join.

        Raises:
            ValueError: if the number of pieces differs from the number of blocks.
        """
        pieces = torch.split(x, split_size_or_sections=self.inlist, dim=1)
        if len(pieces) != len(self.resblocks):
            raise ValueError('Input length and network in_channels are inconsistent')
        return torch.cat(
            [block(piece) for block, piece in zip(self.resblocks, pieces)],
            dim=1,
        )
class ACFM(AbstractTrafficStateModel):
    """Grid traffic-state model combining residual features with attentive ConvLSTMs.

    Pipeline implemented by :meth:`forward`:

    1. ``ResNN`` extracts ``lstm_channels`` features per input frame.
    2. Optional external features are embedded by ``ExtNN`` and added.
    3. A ConvLSTM yields hidden states from which ``ConcatConv`` derives an
       attention map, separately for the local (closeness) frames and the
       global (period + trend) frames.
    4. Each attended sequence is reduced to a 2-channel prediction by its own
       ConvLSTM; both predictions are blended, either with weights computed
       from the target's external features or with fixed 0.5 / 0.5 weights.

    Args:
        config: mapping-like object providing the hyper-parameters read via
            ``config.get`` (``res_repetation``, ``res_nbfilter``, ``res_bn``,
            ``res_split_mode``, ``first_extnn_inter_channels``,
            ``first_extnn_dropout``, ``merge_mode``, ``lstm_channels``,
            ``lstm_dropout``, ``device``).
        data_feature: mapping-like object providing the data description read
            via ``data_feature.get`` (``scaler``, ``adj_mx``, ``num_nodes``,
            ``feature_dim``, ``ext_dim``, ``output_dim``, ``len_row``,
            ``len_column``, ``len_closeness``, ``len_period``, ``len_trend``).
    """

    def __init__(self, config, data_feature):
        super().__init__(config, data_feature)
        self._scaler = self.data_feature.get('scaler')
        self.adj_mx = self.data_feature.get('adj_mx')
        self.num_nodes = self.data_feature.get('num_nodes', 1)
        self.feature_dim = self.data_feature.get('feature_dim', 2)
        self.ext_dim = self.data_feature.get('ext_dim', 0)
        self.output_dim = self.data_feature.get('output_dim', 2)
        self.len_row = self.data_feature.get('len_row', 32)
        self.len_column = self.data_feature.get('len_column', 32)
        self.len_closeness = self.data_feature.get('len_closeness', 4)
        self.len_period = self.data_feature.get('len_period', 2)
        self.len_trend = self.data_feature.get('len_trend', 0)
        self._logger = getLogger()

        self.len_seq = self.len_closeness + self.len_period + self.len_trend
        self.cpt = [self.len_closeness, self.len_period, self.len_trend]
        # "local" = closeness frames, "global" = period + trend frames.
        self.len_local = self.len_closeness
        self.len_global = self.len_period + self.len_trend

        self.res_repetation = config.get('res_repetation', 12)
        self.res_nbfilter = config.get('res_nbfilter', 16)
        self.res_bn = config.get('res_bn', True)
        self.res_split_mode = config.get('res_split_mode', 'split')  # 'split', 'split-chans', 'cpt', 'concat', 'none'
        self.first_extnn_inter_channels = config.get('first_extnn_inter_channels', 40)
        self.first_extnn_dropout = config.get('first_extnn_dropout', 0.5)
        self.merge_mode = config.get('merge_mode', 'fuse')  # 'LSTM', 'fuse'
        self.lstm_channels = config.get('lstm_channels', 16)
        self.lstm_dropout = config.get('lstm_dropout', 0)
        self.device = config.get('device', torch.device('cpu'))

        # Per-frame feature extractor.
        self.resnn = ResNN(
            in_channels=self.feature_dim * self.len_seq,
            out_channels=self.lstm_channels * self.len_seq,
            inter_channels=self.res_nbfilter,
            repetation=self.res_repetation,
            bnmode=self.res_bn,
            splitmode=self.res_split_mode,
            cpt=self.cpt
        )
        # Attention ConvLSTM: one pass over the local frames, one over the global frames.
        self.conv_lstm = ConvLSTM(
            in_channels=self.lstm_channels,
            height=self.len_row,
            width=self.len_column,
            lstm_channels=self.lstm_channels,
            all_hidden=True,
            mode='cpt',
            cpt=[self.len_local, self.len_global, 0],
            dropout_rate=self.lstm_dropout,
            last_conv=False
        )
        # NOTE(review): in_channels1 is 2 per step although forward() passes
        # features carrying lstm_channels per step, so ConcatConv consumes only
        # the first 2 * len channels of them - confirm this is intended.
        self.concat_conv_c = ConcatConv(
            in_channels1=2 * self.len_local,
            in_channels2=self.lstm_channels * self.len_local,
            out_channels=self.lstm_channels * self.len_local,
            inter_channels=self.lstm_channels,
            relu_conv=True,
            seq_len=self.len_local
        )
        if self.len_global > 0:
            self.concat_conv_t = ConcatConv(
                in_channels1=2 * self.len_global,
                in_channels2=self.lstm_channels * self.len_global,
                out_channels=self.lstm_channels * self.len_global,
                inter_channels=self.lstm_channels,
                relu_conv=True,
                seq_len=self.len_global
            )
        # Prediction ConvLSTMs (local / global), each ending in a 2-channel conv.
        self.conv_lstm_c = ConvLSTM(
            in_channels=self.lstm_channels,
            height=self.len_row,
            width=self.len_column,
            lstm_channels=self.lstm_channels,
            all_hidden=False,
            mode='merge',
            dropout_rate=self.lstm_dropout,
            last_conv=True,
            conv_channels=2,
        )
        self.conv_lstm_t = ConvLSTM(
            in_channels=self.lstm_channels,
            height=self.len_row,
            width=self.len_column,
            lstm_channels=self.lstm_channels,
            all_hidden=False,
            mode='merge',
            dropout_rate=self.lstm_dropout,
            last_conv=True,
            conv_channels=2,
        )
        if self.ext_dim > 0:
            # Per-frame external embedding, added to the ResNN features.
            self.extnn = ExtNN(
                in_features=self.ext_dim,
                out_height=self.len_row,
                out_width=self.len_column,
                out_channels=self.lstm_channels,
                inter_features=self.first_extnn_inter_channels,
                mode='inter',
                dropout_rate=self.first_extnn_dropout
            )
            # Scalar blending logit computed from the target's external features.
            self.time_aware_extnn = ExtNN(
                in_features=self.ext_dim,
                out_height=1,
                out_width=1,
                out_channels=1,
                inter_features=32,
                map=False,
                relu=False,
                mode='last'
            )

    def forward(self, batch):
        """Compute the prediction for one batch.

        Args:
            batch: mapping with
                ``'X'``: (batch_size, T_c+T_p+T_t, len_row, len_column, feature_dim);
                ``'X_ext'``: (batch_size, T_c+T_p+T_t, ext_dim), read only when ``ext_dim > 0``;
                ``'y_ext'``: (batch_size, ext_dim), read only when ``ext_dim > 0``.

        Returns:
            torch.Tensor: (batch_size, 1, len_row, len_column, output_dim),
            squashed to (-1, 1) by ``tanh``.

        Raises:
            ValueError: if the shape of ``batch['X']`` does not match the
                configured sequence length, grid size and feature dimension.
        """
        x = batch['X']  # (batch_size, T_c+T_p+T_t, len_row, len_column, feature_dim)
        batch_size, len_time, len_row, len_column, input_dim = x.shape
        expected = (self.len_seq, self.len_row, self.len_column, self.feature_dim)
        if (len_time, len_row, len_column, input_dim) != expected:
            raise ValueError(
                'Unexpected input shape: got (T, row, column, feature) = {}, expected {}'.format(
                    (len_time, len_row, len_column, input_dim), expected))

        # Move the feature axis next to time BEFORE flattening. A bare view() of
        # the channels-last tensor would reinterpret memory and interleave grid
        # cells with features instead of producing per-frame channel planes.
        x = x.permute(0, 1, 4, 2, 3).reshape(
            batch_size, len_time * input_dim, len_row, len_column).to(self.device)
        features = self.resnn(x)  # (batch_size, lstm_channels * len_seq, h, w)

        if self.ext_dim > 0:
            x_ext = batch['X_ext'].to(self.device)  # (batch_size, T_c+T_p+T_t, ext_dim)
            ext = self.extnn(x_ext)  # (batch_size, lstm_channels * len_seq, h, w)
            features = features + ext  # (batch_size, lstm_channels * len_seq, h, w)

        # Attention via Conv-LSTM hidden states.
        hidden_list = self.conv_lstm(features)  # (batch_size, lstm_channels * len_seq, h, w)
        local_channels = self.lstm_channels * self.len_local

        # Local (closeness) branch: (batch_size, lstm_channels * len_local, h, w)
        hidden_list_c = hidden_list[:, :local_channels]
        features_c = features[:, :local_channels]
        attention_c = self.concat_conv_c(features_c, hidden_list_c)
        phase_c = features_c * (1 + attention_c)
        pred_c = self.conv_lstm_c(phase_c)  # (batch_size, 2, h, w)

        # Global (period + trend) branch: (batch_size, lstm_channels * len_global, h, w)
        if self.len_global > 0:
            hidden_list_t = hidden_list[:, local_channels:]
            features_t = features[:, local_channels:]
            attention_t = self.concat_conv_t(features_t, hidden_list_t)
            phase_t = features_t * (1 + attention_t)
            pred_t = self.conv_lstm_t(phase_t)  # (batch_size, 2, h, w)
        else:
            pred_t = torch.zeros(
                (batch_size, self.output_dim, self.len_row, self.len_column)).to(self.device)

        if self.ext_dim > 0:
            y_ext = batch['y_ext'].to(self.device)  # (batch_size, ext_dim)
            time_aware = self.time_aware_extnn(y_ext)  # (batch_size, 1)
            self.time_aware_c = torch.sigmoid(time_aware)  # (batch_size, 1)
            self.time_aware_t = torch.sigmoid(-1 * time_aware)  # (batch_size, 1)
            # time_aware_c + time_aware_t = 1 (a property of the sigmoid function)
            time_aware_c = self.time_aware_c.view(-1, 1, 1, 1)  # (batch_size, 1, 1, 1)
            time_aware_t = self.time_aware_t.view(-1, 1, 1, 1)  # (batch_size, 1, 1, 1)
            pred = time_aware_c * pred_c + time_aware_t * pred_t  # (batch_size, 2, h, w)
        elif self.len_global > 0:
            pred = 0.5 * pred_c + 0.5 * pred_t
        else:
            pred = pred_c

        h = torch.tanh(pred)  # (batch_size, 2, h, w)
        # Back to channels-last; permute first so each grid cell keeps its own
        # channel values instead of having memory reinterpreted.
        h = h.permute(0, 2, 3, 1).reshape(
            batch_size, 1, len_row, len_column, self.output_dim)
        return h

    def calculate_loss(self, batch):
        """Return the masked MSE between de-normalised prediction and target.

        Both tensors are cut to the first ``output_dim`` features and passed
        through ``scaler.inverse_transform`` before the loss is computed.
        """
        y_true = batch['y']
        y_predicted = self.predict(batch)
        y_true = self._scaler.inverse_transform(y_true[..., :self.output_dim])
        y_predicted = self._scaler.inverse_transform(y_predicted[..., :self.output_dim])
        return loss.masked_mse_torch(y_predicted, y_true)

    def predict(self, batch):
        """Return the model output for ``batch`` (identical to :meth:`forward`)."""
        return self.forward(batch)
"""
Best parameters:
TaxiBJ:
res_repetation = 12
res_nbfilter = 16
res_bn = True
res_split_mode = 'split' # 'split', 'split-chans', 'cpt', 'concat', 'none'
first_extnn_inter_channels = 40
first_extnn_dropout = 0.5
merge_mode = 'fuse' # 'LSTM', 'fuse'
lstm_channels = 16
lstm_dropout = 0
BikeNYC:
res_repetation = 2
res_nbfilter = 16
res_bn = True
res_split_mode = 'split' # 'split', 'split-chans', 'cpt', 'concat', 'none'
first_extnn_inter_channels = 30
first_extnn_dropout = 0.5
merge_mode = 'fuse' # 'LSTM', 'fuse'
lstm_channels = 16
lstm_dropout = 0.5
TaxiNYC:
res_repetation = 4
res_nbfilter = 32
res_bn = True
res_split_mode = 'split' # 'split', 'split-chans', 'cpt', 'concat', 'none'
first_extnn_inter_channels = 40
first_extnn_dropout = 0.5
merge_mode = 'fuse' # 'LSTM', 'fuse'
lstm_channels = 16
lstm_dropout = 0
"""