Source code for libcity.model.road_representation.ChebConv

import torch
import torch.nn as nn
import numpy as np
from logging import getLogger
from libcity.model.abstract_traffic_state_model import AbstractTrafficStateModel
from libcity.model import loss
from libcity.model import utils


class ChebConvModule(nn.Module):
    """K-order Chebyshev graph convolution.

    Maps an (N, C) node-feature matrix to an (N, F) output matrix, using
    support matrices of shape (N, N) derived from the adjacency matrix.
    (Road-representation models do not share a common base class, so this
    layer is a plain ``nn.Module``.)
    """

    def __init__(self, num_nodes, max_diffusion_step, adj_mx, device, input_dim,
                 output_dim, filter_type):
        """Build the support matrices and the layer parameters.

        Args:
            num_nodes: number of nodes N.
            max_diffusion_step: Chebyshev order K (highest polynomial degree).
            adj_mx: adjacency information handed to
                ``utils.get_supports_matrix``. The exact expected format is
                defined there (TODO confirm).
            device: torch device for the supports and parameters.
            input_dim: input feature size C.
            output_dim: output feature size F.
            filter_type: forwarded to ``utils.get_supports_matrix`` and
                selects which support matrices are built (e.g.
                'dual_random_walk').
        """
        super().__init__()
        self._num_nodes = num_nodes
        self._max_diffusion_step = max_diffusion_step
        # Build the support matrices and store them as sparse tensors.
        # NOTE(review): these are plain attributes, not registered buffers, so
        # Module.to(device) will not move them; they live on `device` only.
        supports = utils.get_supports_matrix(adj_mx=adj_mx, filter_type=filter_type)
        self._supports = [utils.build_sparse_matrix(device, support) for support in supports]
        self._device = device
        # One identity term plus K Chebyshev terms per support matrix.
        self._ks = len(self._supports) * self._max_diffusion_step + 1  # Ks
        self._input_dim = input_dim
        self._output_dim = output_dim
        shape = (self._input_dim * self._ks, self._output_dim)
        self.weight = torch.nn.Parameter(torch.empty(*shape, device=self._device))
        self.biases = torch.nn.Parameter(torch.empty(self._output_dim, device=self._device))
        torch.nn.init.xavier_normal_(self.weight)
        torch.nn.init.constant_(self.biases, 0)

    @staticmethod
    def _concat(x, x_):
        """Append ``x_`` as a new leading slice of ``x``.

        Returns ``cat([x, x_[None]], 0)``. It is no longer used by ``forward``
        and is kept only for backward compatibility.
        """
        x_ = x_.unsqueeze(0)
        return torch.cat([x, x_], dim=0)

    def forward(self, x):
        """Apply the graph convolution.

        Args:
            x: dense tensor of shape (N, input_dim).

        Returns:
            torch.Tensor of shape (N, output_dim).
        """
        num_nodes, input_dim = x.shape
        # T_0(S) x = x, shared by every support.
        terms = [x]
        if self._max_diffusion_step > 0:
            for support in self._supports:
                # Restart the recurrence from x for EVERY support, so that
                # T_0 is the input itself rather than a term left over from
                # the previous support's recurrence.
                t_prev = x
                t_curr = torch.sparse.mm(support, x)  # T_1(S) x = S x, (N, C)
                terms.append(t_curr)
                for _ in range(2, self._max_diffusion_step + 1):
                    # T_k(S) x = 2 S T_{k-1}(S) x - T_{k-2}(S) x
                    t_next = 2 * torch.sparse.mm(support, t_curr) - t_prev
                    terms.append(t_next)
                    t_prev, t_curr = t_curr, t_next
        # Collect all terms at once instead of re-concatenating per step.
        out = torch.stack(terms, dim=0)  # (Ks, N, C)
        out = out.permute(1, 2, 0)  # (N, C, Ks)
        out = out.reshape(num_nodes, input_dim * self._ks)  # (N, C * Ks)
        # (N, C * Ks) @ (C * Ks, F) -> (N, F)
        return torch.matmul(out, self.weight) + self.biases
class ChebConv(AbstractTrafficStateModel):
    """Auto-encoding road-representation model made of two ChebConvModule layers.

    The encoder maps node features (N, feature_dim) to embeddings
    (N, output_dim). The decoder maps the embeddings back to
    (N, feature_dim).
    """

    def __init__(self, config, data_feature):
        super().__init__(config, data_feature)
        self.adj_mx = data_feature.get('adj_mx')
        self.num_nodes = data_feature.get('num_nodes', 1)
        self.feature_dim = data_feature.get('feature_dim', 1)
        # Publish the dataset-derived sizes back into the shared config.
        config['num_nodes'] = self.num_nodes
        config['feature_dim'] = self.feature_dim
        self.device = config.get('device', torch.device('cpu'))
        self._logger = getLogger()
        self._scaler = self.data_feature.get('scaler')
        self.max_diffusion_step = config.get('max_diffusion_step', 2)
        self.output_dim = config.get('output_dim', 32)
        self.filter_type = config.get('filter_type', 'dual_random_walk')
        # model / dataset / exp_id are only used to name the cached embedding file.
        self.model = config.get('model', '')
        self.dataset = config.get('dataset', '')
        self.exp_id = config.get('exp_id', None)

        # Settings common to the encoder and decoder layers.
        shared_kwargs = dict(
            num_nodes=self.num_nodes,
            max_diffusion_step=self.max_diffusion_step,
            adj_mx=self.adj_mx,
            device=self.device,
            filter_type=self.filter_type,
        )
        self.encoder = ChebConvModule(input_dim=self.feature_dim,
                                      output_dim=self.output_dim,
                                      **shared_kwargs)
        self.decoder = ChebConvModule(input_dim=self.output_dim,
                                      output_dim=self.feature_dim,
                                      **shared_kwargs)

    def forward(self, batch):
        """Reconstruct node features from themselves (auto-regressive task).

        Side effect: on every call, the encoder output is written with
        ``np.save`` to
        ``./libcity/cache/<exp_id>/evaluate_cache/embedding_<model>_<dataset>_<output_dim>.npy``.
        NOTE(review): np.save does not create missing directories, so the
        cache directory must already exist.

        Args:
            batch: dict-like object; ``batch['node_features']`` is a tensor of
                shape (N, feature_dim).

        Returns:
            torch.Tensor of shape (N, feature_dim).
        """
        node_features = batch['node_features']
        embedding = self.encoder(node_features)  # (N, output_dim)
        cache_path = './libcity/cache/{}/evaluate_cache/embedding_{}_{}_{}.npy'.format(
            self.exp_id, self.model, self.dataset, self.output_dim)
        np.save(cache_path, embedding.detach().cpu().numpy())
        return self.decoder(embedding)  # (N, feature_dim)

    def calculate_loss(self, batch):
        """Masked MSE between de-normalised predictions and labels.

        Args:
            batch: dict-like object with keys 'node_features', 'node_labels'
                and 'mask'.

        Returns:
            The value of ``loss.masked_mse_torch`` over the masked entries
            (presumably a scalar tensor — confirm in ``loss``).
        """
        labels = batch['node_labels']  # (N, feature_dim)
        predictions = self.predict(batch)  # (N, feature_dim)
        labels = self._scaler.inverse_transform(labels)
        predictions = self._scaler.inverse_transform(predictions)
        selected = batch['mask']
        return loss.masked_mse_torch(predictions[selected], labels[selected])

    def predict(self, batch):
        """Return the reconstruction for ``batch`` (delegates to ``forward``).

        Args:
            batch: dict-like object with key 'node_features'.

        Returns:
            torch.Tensor of shape (N, feature_dim).
        """
        return self.forward(batch)