from logging import getLogger
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from libcity.model import loss
from libcity.model.abstract_traffic_state_model import AbstractTrafficStateModel


class SSelfAttention(nn.Module):
    """Spatial self-attention: multi-head attention computed across the node
    dimension, independently for each time step."""

    def __init__(self, embed_dim, num_heads):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        assert (
            self.head_dim * num_heads == embed_dim
        ), "Embedding dim needs to be divisible by num_heads"
        self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.fc_out = nn.Linear(num_heads * self.head_dim, embed_dim)

    def forward(self, values, keys, query):
        # query: (batch_size, num_nodes, input_window, embed_dim)
        batch_size, num_nodes, input_window, embed_dim = query.shape
        # split the embedding into num_heads chunks of size head_dim
        values = values.reshape(batch_size, num_nodes, input_window, self.num_heads, self.head_dim)
        keys = keys.reshape(batch_size, num_nodes, input_window, self.num_heads, self.head_dim)
        query = query.reshape(batch_size, num_nodes, input_window, self.num_heads, self.head_dim)
        values = self.values(values)
        keys = self.keys(keys)
        queries = self.queries(query)
        # attention scores between pairs of nodes (q, k) for every time step and head
        energy = torch.einsum("bqthd,bkthd->bqkth", [queries, keys])
        attention = torch.softmax(energy / (self.embed_dim ** (1 / 2)), dim=2)
        out = torch.einsum("bqkth,bkthd->bqthd", [attention, values]).reshape(
            batch_size, num_nodes, input_window, self.num_heads * self.head_dim
        )
        out = self.fc_out(out)
        return out


class TSelfAttention(nn.Module):
    """Temporal self-attention: multi-head attention computed across the time
    dimension, independently for each node."""

    def __init__(self, embed_dim, num_heads):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        assert (
            self.head_dim * num_heads == embed_dim
        ), "Embedding dim needs to be divisible by num_heads"
        self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.fc_out = nn.Linear(num_heads * self.head_dim, embed_dim)

    def forward(self, values, keys, query):
        # query: (batch_size, num_nodes, input_window, embed_dim)
        batch_size, num_nodes, input_window, embed_dim = query.shape
        values = values.reshape(batch_size, num_nodes, input_window, self.num_heads, self.head_dim)
        keys = keys.reshape(batch_size, num_nodes, input_window, self.num_heads, self.head_dim)
        query = query.reshape(batch_size, num_nodes, input_window, self.num_heads, self.head_dim)
        values = self.values(values)
        keys = self.keys(keys)
        queries = self.queries(query)
        # attention scores between pairs of time steps (q, k) for every node and head
        energy = torch.einsum("bnqhd,bnkhd->bnqkh", [queries, keys])
        attention = torch.softmax(energy / (self.embed_dim ** (1 / 2)), dim=3)
        out = torch.einsum("bnqkh,bnkhd->bnqhd", [attention, values]).reshape(
            batch_size, num_nodes, input_window, self.num_heads * self.head_dim
        )
        out = self.fc_out(out)
        return out


class GraphConvolution(nn.Module):
    """A single graph convolution layer: a linear projection of node features
    followed by aggregation over the adjacency matrix."""

    def __init__(self, in_features, out_features, bias=True, device=torch.device('cpu')):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = nn.Parameter(torch.FloatTensor(in_features, out_features).to(device))
        if bias:
            self.bias = nn.Parameter(torch.FloatTensor(out_features).to(device))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        stdv = 1. / math.sqrt(self.weight.size(1))
        self.weight.data.uniform_(-stdv, stdv)
        if self.bias is not None:
            self.bias.data.uniform_(-stdv, stdv)

    def forward(self, x, adj_mx):
        # x: (batch_size, num_nodes, in_features), adj_mx: (num_nodes, num_nodes)
        support = torch.einsum("bnd,dh->bnh", [x, self.weight])
        output = torch.einsum("mn,bnh->bmh", [adj_mx, support])
        if self.bias is not None:
            return output + self.bias
        else:
            return output

    def __repr__(self):
        return self.__class__.__name__ + ' (' \
               + str(self.in_features) + ' -> ' \
               + str(self.out_features) + ')'
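

# Illustrative usage of GraphConvolution (a sketch, not part of the original
# module): the layer expects batched node features and an adjacency matrix of
# shape (num_nodes, num_nodes); the sizes below are arbitrary assumptions.
#
#     gc = GraphConvolution(in_features=64, out_features=32)
#     x = torch.randn(8, 207, 64)   # (batch_size, num_nodes, in_features)
#     adj = torch.eye(207)          # placeholder adjacency matrix
#     out = gc(x, adj)              # (batch_size, num_nodes, out_features)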


class GCN(nn.Module):
    """Two-layer graph convolutional network with ReLU and dropout between the layers."""

    def __init__(self, nfeat, nhid, nclass, dropout_rate=0, device=torch.device('cpu')):
        super().__init__()
        self.gc1 = GraphConvolution(nfeat, nhid, device=device)
        self.gc2 = GraphConvolution(nhid, nclass, device=device)
        self.dropout_rate = dropout_rate

    def forward(self, x, adj_mx):
        x = F.relu(self.gc1(x, adj_mx))
        x = F.dropout(x, self.dropout_rate, training=self.training)
        x = self.gc2(x, adj_mx)
        return F.log_softmax(x, dim=2)


class Encoder(nn.Module):
    """Stack of spatio-temporal transformer blocks applied to the embedded input."""

    def __init__(self, adj_mx, embed_dim=64, num_layers=3, num_heads=2, TG_per_day=288,
                 forward_expansion=4, dropout_rate=0, device=torch.device('cpu')):
        super().__init__()
        self.layers = nn.ModuleList([
            STTransformerBlock(
                adj_mx, embed_dim=embed_dim, num_heads=num_heads, TG_per_day=TG_per_day,
                forward_expansion=forward_expansion, dropout_rate=dropout_rate, device=device
            )
            for _ in range(num_layers)
        ])
        self.dropout_layer = nn.Dropout(dropout_rate)

    def forward(self, x):
        out = self.dropout_layer(x)
        for layer in self.layers:
            # each block performs self-attention, so query, key and value are all `out`
            out = layer(out, out, out)
        return out


class STTN(AbstractTrafficStateModel):
    def __init__(self, config, data_feature):
        super().__init__(config, data_feature)
        self._scaler = self.data_feature.get('scaler')
        self.adj_mx = self.data_feature.get('adj_mx', 1)
        # self.num_nodes = self.data_feature.get('num_nodes', 1)
        self.feature_dim = self.data_feature.get('feature_dim', 1)
        self.output_dim = self.data_feature.get('output_dim', 1)
        # self.len_row = self.data_feature.get('len_row', 1)
        # self.len_column = self.data_feature.get('len_column', 1)
        self._logger = getLogger()
        self.device = config.get('device', torch.device('cpu'))
        self.embed_dim = config.get('embed_dim', 64)
        self.num_layers = config.get('num_layers', 3)
        self.num_heads = config.get('num_heads', 2)
        self.TG_per_day = config.get('TG_in_one_day', 288)  # number of time intervals per day
        self.forward_expansion = config.get('forward_expansion', 4)
        self.dropout_rate = config.get('dropout_rate', 0)
        self.input_window = config.get('input_window', 1)
        self.output_window = config.get('output_window', 1)
        # 1x1 convolution lifting the raw features to the embedding dimension
        self.conv1 = nn.Conv2d(self.feature_dim, self.embed_dim, 1)
        self.transformer = Transformer(
            self.adj_mx, embed_dim=self.embed_dim, num_layers=self.num_layers, num_heads=self.num_heads,
            TG_per_day=self.TG_per_day, forward_expansion=self.forward_expansion, dropout_rate=self.dropout_rate,
            device=self.device,
        )
        # 1x1 convolutions mapping input_window -> output_window and embed_dim -> output_dim
        self.conv2 = nn.Conv2d(self.input_window, self.output_window, 1)
        self.conv3 = nn.Conv2d(self.embed_dim, self.output_dim, 1)
        self.act_layer = nn.ReLU()

    def forward(self, batch):
        inputs = batch['X']  # (batch_size, input_window, num_nodes, feature_dim)
        inputs = inputs.permute(0, 3, 2, 1)  # (batch_size, feature_dim, num_nodes, input_window)
        input_transformer = self.conv1(inputs)  # (batch_size, embed_dim, num_nodes, input_window)
        input_transformer = input_transformer.permute(0, 2, 3, 1)  # (batch_size, num_nodes, input_window, embed_dim)
        output_transformer = self.transformer(input_transformer)
        output_transformer = output_transformer.permute(0, 2, 1, 3)  # (batch_size, input_window, num_nodes, embed_dim)
        out = self.act_layer(self.conv2(output_transformer))  # (batch_size, output_window, num_nodes, embed_dim)
        out = out.permute(0, 3, 2, 1)  # (batch_size, embed_dim, num_nodes, output_window)
        out = self.conv3(out)  # (batch_size, output_dim, num_nodes, output_window)
        out = out.permute(0, 3, 2, 1)  # (batch_size, output_window, num_nodes, output_dim)
        return out

    def calculate_loss(self, batch):
        y_true = batch['y']
        y_predicted = self.predict(batch)
        y_true = self._scaler.inverse_transform(y_true[..., :self.output_dim])
        y_predicted = self._scaler.inverse_transform(y_predicted[..., :self.output_dim])
        return loss.masked_mae_torch(y_predicted, y_true)

    def predict(self, batch):
        return self.forward(batch)
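

if __name__ == "__main__":
    # Minimal smoke test (illustrative sketch, not part of the original model):
    # exercises the attention and graph-convolution blocks defined above with
    # random tensors; all sizes below are arbitrary assumptions.
    batch_size, num_nodes, input_window, embed_dim = 2, 5, 12, 64
    x = torch.randn(batch_size, num_nodes, input_window, embed_dim)

    s_attn = SSelfAttention(embed_dim, num_heads=2)
    t_attn = TSelfAttention(embed_dim, num_heads=2)
    print(s_attn(x, x, x).shape)  # torch.Size([2, 5, 12, 64])
    print(t_attn(x, x, x).shape)  # torch.Size([2, 5, 12, 64])

    gcn = GCN(nfeat=embed_dim, nhid=32, nclass=embed_dim)
    node_feats = torch.randn(batch_size, num_nodes, embed_dim)
    adj_mx = torch.eye(num_nodes)  # placeholder adjacency matrix
    print(gcn(node_feats, adj_mx).shape)  # torch.Size([2, 5, 64])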