libcity.model.traffic_speed_prediction.TGCN
-
class libcity.model.traffic_speed_prediction.TGCN.TGCN(config, data_feature)
Bases: libcity.model.abstract_traffic_state_model.AbstractTrafficStateModel
-
calculate_loss(batch)
Takes a batch of data and returns the training loss; in other words, this is where the loss function is defined.
- Parameters
batch (Batch) – a batch of input
- Returns
training loss
- Return type
torch.Tensor
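Because calculate_loss is where the training objective is defined, a concrete objective may help. The sketch below uses mean squared error plus an L2 penalty on the model weights, a common choice for T-GCN-style models. The function name `tgcn_style_loss` and the `lam` default are hypothetical, and this is not LibCity's implementation.

```python
import torch


def tgcn_style_loss(y_pred: torch.Tensor, y_true: torch.Tensor,
                    params, lam: float = 1e-3) -> torch.Tensor:
    """Hypothetical loss: mean squared error plus a halved L2 penalty on the weights.

    y_pred, y_true: (batch_size, output_window, num_nodes, output_dim)
    params: iterable of tensors, e.g. model.parameters()
    lam: regularization weight (hypothetical default, not taken from LibCity)
    """
    mse = torch.mean((y_pred - y_true) ** 2)
    # Sum of squared entries of every parameter, halved.
    l2 = sum(torch.sum(p ** 2) for p in params) / 2
    return mse + lam * l2


model = torch.nn.Linear(4, 2)
y_pred = torch.zeros(8, 3, 5, 1)
y_true = torch.ones(8, 3, 5, 1)
print(tgcn_style_loss(y_pred, y_true, model.parameters(), lam=0.0).item())  # 1.0 (pure MSE)
loss = tgcn_style_loss(y_pred, y_true, model.parameters())  # MSE + L2 penalty
```

Setting `lam=0.0` reduces the objective to plain MSE. This makes it easy to check how much the weight penalty contributes.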
-
forward(batch)
- Parameters
batch – a batch of input, where batch['X'] has shape (batch_size, input_window, num_nodes, input_dim) and batch['y'] has shape (batch_size, output_window, num_nodes, output_dim)
- Returns
tensor of shape (batch_size, self.output_window, self.num_nodes, self.output_dim)
- Return type
torch.Tensor
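The two methods expect different layouts. forward receives batch['X'] as (batch_size, input_window, num_nodes, input_dim), while TGCNCell.forward takes one time step shaped (batch, num_nodes * dim). The sketch below shows that conversion. It assumes the model feeds the cell step by step, and the helper `to_cell_inputs` is hypothetical.

```python
import torch


def to_cell_inputs(x: torch.Tensor) -> list:
    """Split batch['X'] into one flattened tensor per time step (hypothetical helper).

    x: (batch_size, input_window, num_nodes, input_dim)
    returns: list of input_window tensors, each (batch_size, num_nodes * input_dim)
    """
    batch_size, input_window, num_nodes, input_dim = x.shape
    # Move time to the front so each step can be indexed directly.
    x = x.permute(1, 0, 2, 3)  # (input_window, batch_size, num_nodes, input_dim)
    # reshape (not view) because the permuted tensor is non-contiguous.
    return [x[t].reshape(batch_size, num_nodes * input_dim) for t in range(input_window)]


x = torch.randn(8, 12, 207, 1)
steps = to_cell_inputs(x)
print(len(steps), steps[0].shape)  # 12 torch.Size([8, 207])
```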
-
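predict(batch)
Takes a batch of data and returns the corresponding predictions, which should normally be the result of **multi-step prediction**; it usually calls the forward() method of nn.Module.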
- Parameters
batch (Batch) – a batch of input
- Returns
prediction result for this batch
- Return type
torch.Tensor
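The three methods above depend on one another: predict delegates to forward, and calculate_loss compares the output of predict with batch['y']. The class below is a hypothetical stand-in, not TGCN, that makes this chain concrete.

Several parts are assumptions:

- A plain dict stands in for LibCity's Batch, on the assumption that Batch supports key lookup as the forward docstring suggests.
- The "repeat the last step" forecast is only a placeholder.
- MSE is an arbitrary choice of loss.

```python
import torch
import torch.nn as nn


class ToyTrafficModel(nn.Module):
    """Hypothetical stand-in showing how predict and calculate_loss relate to forward."""

    def __init__(self, input_dim: int, output_dim: int, output_window: int):
        super().__init__()
        self.output_window = output_window
        self.proj = nn.Linear(input_dim, output_dim)

    def forward(self, batch):
        x = batch['X']                # (B, input_window, N, input_dim)
        last = self.proj(x[:, -1:])   # (B, 1, N, output_dim)
        # Placeholder multi-step output: repeat the last-step projection.
        return last.repeat(1, self.output_window, 1, 1)

    def predict(self, batch):
        # Multi-step prediction simply delegates to forward().
        return self.forward(batch)

    def calculate_loss(self, batch):
        return nn.functional.mse_loss(self.predict(batch), batch['y'])


model = ToyTrafficModel(input_dim=1, output_dim=1, output_window=3)
batch = {'X': torch.randn(4, 12, 10, 1), 'y': torch.randn(4, 3, 10, 1)}
print(model.predict(batch).shape)  # torch.Size([4, 3, 10, 1])
loss = model.calculate_loss(batch)
loss.backward()
```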
-
training: bool
-
-
class libcity.model.traffic_speed_prediction.TGCN.TGCNCell(num_units, adj_mx, num_nodes, device, input_dim=1)
Bases: torch.nn.modules.module.Module
-
_gc(inputs, state, output_size, bias_start=0.0)
Graph convolution (GCN) applied to the input and the hidden state.
- Parameters
inputs – shape (batch, self.num_nodes * self.dim)
state – shape (batch, self.num_nodes * self.gru_units)
output_size – number of output features per node
bias_start – initial value for the bias term
- Returns
shape (B, num_nodes, output_size)
- Return type
torch.Tensor
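The graph convolution that _gc performs can be written out explicitly. In this minimal sketch:

1. The input and hidden state are concatenated per node.
2. The result is aggregated over neighbors with the standard GCN normalization D^-1/2 (A + I) D^-1/2.
3. It is projected by a weight matrix plus a bias filled with bias_start.

The helpers `normalize_adj` and `graph_conv` are hypothetical, and LibCity's internals (how adj_mx is normalized and how weights are stored) may differ.

```python
import torch


def normalize_adj(adj: torch.Tensor) -> torch.Tensor:
    """Symmetric GCN normalization D^-1/2 (A + I) D^-1/2 (assumed, see lead-in)."""
    a = adj + torch.eye(adj.shape[0])
    d_inv_sqrt = a.sum(dim=1).pow(-0.5)  # self-loops keep every degree >= 1
    return d_inv_sqrt[:, None] * a * d_inv_sqrt[None, :]


def graph_conv(inputs, state, adj_norm, weight, bias, num_nodes):
    """One graph convolution over the concatenated input and hidden state.

    inputs: (B, num_nodes * dim), state: (B, num_nodes * gru_units)
    weight: (dim + gru_units, output_size), bias: (output_size,)
    returns: (B, num_nodes, output_size)
    """
    b = inputs.shape[0]
    x = torch.cat([inputs.reshape(b, num_nodes, -1),
                   state.reshape(b, num_nodes, -1)], dim=2)  # (B, N, dim + units)
    x = torch.einsum('nm,bmf->bnf', adj_norm, x)              # neighbor aggregation
    return x @ weight + bias                                   # (B, N, output_size)


adj = torch.tensor([[0., 1., 0.], [1., 0., 1.], [0., 1., 0.]])
num_nodes, dim, units, out = 3, 1, 4, 8
adj_norm = normalize_adj(adj)
w = torch.randn(dim + units, out)
bias_start = 0.0
b = torch.full((out,), bias_start)
y = graph_conv(torch.randn(2, num_nodes * dim), torch.zeros(2, num_nodes * units),
               adj_norm, w, b, num_nodes)
print(y.shape)  # torch.Size([2, 3, 8])
```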
-
forward(inputs, state)
Gated recurrent unit (GRU) with graph convolution.
- Parameters
inputs – shape (batch, self.num_nodes * self.dim)
state – shape (batch, self.num_nodes * self.gru_units)
- Returns
shape (B, num_nodes * gru_units)
- Return type
torch.Tensor
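TGCNCell.forward is a GRU whose linear maps are replaced by graph convolutions. The self-contained sketch below illustrates that idea. It is not a transcription of LibCity's TGCNCell, and the following are all assumptions:

- the class `GraphGRUCell` and its parameter layout;
- the symmetric adjacency normalization;
- gate biases initialized to 1.0;
- the update rule h' = u * h + (1 - u) * c, a convention common in graph-GRU code.

The `device` argument of the real signature is omitted.

```python
import torch
import torch.nn as nn


class GraphGRUCell(nn.Module):
    """Hypothetical GRU cell whose linear maps are graph convolutions."""

    def __init__(self, num_units: int, adj_mx: torch.Tensor, num_nodes: int, input_dim: int = 1):
        super().__init__()
        self.num_units, self.num_nodes = num_units, num_nodes
        a = adj_mx + torch.eye(num_nodes)
        d = a.sum(dim=1).pow(-0.5)
        self.register_buffer('adj_norm', d[:, None] * a * d[None, :])
        in_size = input_dim + num_units
        self.w_gates = nn.Parameter(torch.randn(in_size, 2 * num_units) * 0.1)
        self.b_gates = nn.Parameter(torch.full((2 * num_units,), 1.0))  # gates start near "open"
        self.w_cand = nn.Parameter(torch.randn(in_size, num_units) * 0.1)
        self.b_cand = nn.Parameter(torch.zeros(num_units))

    def _graph_conv(self, inputs, state, weight, bias):
        b = inputs.shape[0]
        x = torch.cat([inputs.reshape(b, self.num_nodes, -1),
                       state.reshape(b, self.num_nodes, -1)], dim=2)
        x = torch.einsum('nm,bmf->bnf', self.adj_norm, x)
        return x @ weight + bias  # (B, N, output_size)

    def forward(self, inputs, state):
        b = inputs.shape[0]
        gates = torch.sigmoid(self._graph_conv(inputs, state, self.w_gates, self.b_gates))
        r, u = torch.split(gates, self.num_units, dim=-1)  # reset / update, each (B, N, units)
        r, u = r.reshape(b, -1), u.reshape(b, -1)           # match state layout (B, N * units)
        c = torch.tanh(self._graph_conv(inputs, r * state, self.w_cand, self.b_cand)).reshape(b, -1)
        return u * state + (1.0 - u) * c                    # (B, num_nodes * num_units)


cell = GraphGRUCell(num_units=4, adj_mx=torch.tensor([[0., 1.], [1., 0.]]), num_nodes=2)
h = torch.zeros(3, 2 * 4)
for _ in range(5):
    h = cell(torch.randn(3, 2 * 1), h)
print(h.shape)  # torch.Size([3, 8])
```

The new state is a convex combination of the old state and a tanh candidate. Starting from zeros, every entry therefore stays within [-1, 1].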
-
training: bool
-