libcity.model.trajectory_loc_prediction.HSTLSTM
-
class
libcity.model.trajectory_loc_prediction.HSTLSTM.
HSTLSTM
(config, data_feature)
Bases: libcity.model.abstract_model.AbstractModel
RNN classifier using ST-LSTM as its core.
-
cal_inter
(ld, hd, l, h, embed)
Calculate a linear interpolation.
- Parameters
ld – distances to the lower bound, shape (batch_size, step)
hd – distances to the higher bound, shape (batch_size, step)
l – lower bound indexes, shape (batch_size, step)
h – higher bound indexes, shape (batch_size, step)
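As an illustration of the interpolation described above, here is a minimal pure-Python sketch. It assumes that `embed` behaves like a lookup table from slot index to vector and that the nearer bound receives the larger weight; the names `interpolate_embedding`, `interpolate_batch`, and `table` are hypothetical and are not part of LibCity.

```python
def interpolate_embedding(ld, hd, l, h, table):
    """Hypothetical scalar version: blend the embeddings of two bounding slots.

    ld, hd: normalized distances to the lower / higher bound (ld + hd == 1).
    l, h:   indexes of the lower / higher bound in `table`.
    table:  list of embedding vectors (stand-in for an embedding layer).
    """
    low_vec, high_vec = table[l], table[h]
    # Assumption: the nearer bound gets the larger weight, so the lower
    # embedding is weighted by the distance to the higher bound and vice versa.
    return [hd * a + ld * b for a, b in zip(low_vec, high_vec)]


def interpolate_batch(ld, hd, l, h, table):
    """Apply the scalar version to nested lists shaped (batch_size, step)."""
    return [
        [interpolate_embedding(*args, table) for args in zip(r_ld, r_hd, r_l, r_h)]
        for r_ld, r_hd, r_l, r_h in zip(ld, hd, l, h)
    ]
```

The result has one vector per (batch, step) position, that is, shape (batch_size, step, embedding_size) in this sketch.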
-
calculate_loss
(batch)
Train model using one batch of data and return loss value.
- Parameters
model – one instance of STLSTMClassifier.
batch_l – batch of input location sequences, size (batch_size, time_step, input_size)
batch_t – batch of temporal interval values, size (batch_size, step)
batch_d – batch of spatial distance values, size (batch_size, step)
batch_label – batch of labels, size (batch_size)
- Returns
loss value.
-
forward
(batch_l, batch_t, batch_d, origin_len)
Perform forward propagation of the ST-LSTM classifier.
- Parameters
batch_l – batch of input location sequences, size (batch_size, time_step, input_size)
batch_t – batch of temporal interval values, size (batch_size, step)
batch_d – batch of spatial distance values, size (batch_size, step)
- Returns
prediction result of this batch, size (batch_size, output_size, step).
-
predict
(batch)
Predict a batch of data.
- Parameters
batch_l – batch of input location sequences, size (batch_size, time_step, input_size)
batch_t – batch of temporal interval values, size (batch_size, step)
batch_d – batch of spatial distance values, size (batch_size, step)
- Returns
batch of predicted class indices, size (batch_size).
-
training
: bool
-
-
class
libcity.model.trajectory_loc_prediction.HSTLSTM.
STLSTM
(input_size, hidden_size, bias=True)
Bases: torch.nn.modules.module.Module
One-layer, batch-first Spatial-Temporal LSTM network.
Kong D, Wu F. HST-LSTM: A Hierarchical Spatial-Temporal Long-Short Term Memory Network for Location Prediction[C]//IJCAI. 2018: 2341-2347.
Examples
>>> import torch
>>> from libcity.model.trajectory_loc_prediction.HSTLSTM import STLSTM
>>> st_lstm = STLSTM(10, 20)  # input_size=10, hidden_size=20
>>> input_l = torch.randn(6, 3, 10)  # (batch_size, step, input_size)
>>> input_s = torch.randn(6, 3, 10)
>>> input_q = torch.randn(6, 3, 10)
>>> hidden_out, cell_out = st_lstm(input_l, input_s, input_q)
-
forward
(input_l, input_s, input_q, hc=None)
Perform forward propagation of the ST-LSTM network.
- Parameters
input_l – input of location embedding vectors, shape (batch_size, step, input_size)
input_s – input of spatial embedding vectors, shape (batch_size, step, input_size)
input_q – input of temporal embedding vectors, shape (batch_size, step, input_size)
hc – tuple containing the initial hidden state and cell state, optional.
- Returns
hidden states and cell states produced by iterating through the steps.
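To make "iterating through the steps" concrete, the following pure-Python sketch unrolls a generic recurrent cell over a batch-first sequence. It is a simplified illustration, not the LibCity implementation: `unroll_batch_first` and the toy `running_sum_cell` are hypothetical, and a single input and a single state stand in for the three inputs and the (hidden, cell) pair.

```python
def unroll_batch_first(cell, inputs, state):
    """Run `cell` over every step of a batch-first sequence.

    inputs: nested list shaped (batch_size, step, input_size)
    state:  list with one initial state per sample, length batch_size
    cell:   hypothetical callable (step_inputs, state) -> new state
    Returns (per-step states shaped (batch_size, step), final state).
    """
    num_steps = len(inputs[0])
    outputs = [[] for _ in inputs]
    for t in range(num_steps):
        # Slice out step t for every sample: shape (batch_size, input_size).
        step_inputs = [sample[t] for sample in inputs]
        state = cell(step_inputs, state)
        # Re-stack the new states batch-first.
        for b, s in enumerate(state):
            outputs[b].append(s)
    return outputs, state


def running_sum_cell(step_inputs, state):
    """Toy cell: add the sum of each sample's features to its state."""
    return [s + sum(x) for x, s in zip(step_inputs, state)]
```

With this layout the caller passes data as (batch_size, step, ...) and the loop handles the per-step slicing internally.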
-
training
: bool
-
-
class
libcity.model.trajectory_loc_prediction.HSTLSTM.
STLSTMCell
(input_size, hidden_size, bias=True)
Bases: torch.nn.modules.module.Module
A Spatial-Temporal Long Short Term Memory (ST-LSTM) cell.
Kong D, Wu F. HST-LSTM: A Hierarchical Spatial-Temporal Long-Short Term Memory Network for Location Prediction[C]//IJCAI. 2018: 2341-2347.
Examples
>>> import torch
>>> from libcity.model.trajectory_loc_prediction.HSTLSTM import STLSTMCell
>>> st_lstm = STLSTMCell(10, 20)  # input_size=10, hidden_size=20
>>> input_l = torch.randn(6, 3, 10)  # indexed step-first here: 6 steps, batch of 3
>>> input_s = torch.randn(6, 3, 10)
>>> input_q = torch.randn(6, 3, 10)
>>> hc = (torch.randn(3, 20), torch.randn(3, 20))  # (hidden, cell), each (batch_size, hidden_size)
>>> output = []
>>> for i in range(6):
...     hc = st_lstm(input_l[i], input_s[i], input_q[i], hc)
...     output.append(hc[0])
-
forward
(input_l, input_s, input_q, hc=None)
Perform one step of forward propagation of ST-LSTM.
- Parameters
input_l – input of location embedding vector, shape (batch_size, input_size)
input_s – input of spatial embedding vector, shape (batch_size, input_size)
input_q – input of temporal embedding vector, shape (batch_size, input_size)
hc – tuple containing the hidden state and cell state of the previous step.
- Returns
hidden state and cell state of this step.
-
st_lstm_cell_cal
(input_l, input_s, input_q, hidden, cell, w_ih, w_hh, w_s, w_q, b_ih, b_hh)
Perform the calculation for one step of ST-LSTM.
- Parameters
input_l – input of location embedding, shape (batch_size, input_size)
input_s – input of spatial embedding, shape (batch_size, input_size)
input_q – input of temporal embedding, shape (batch_size, input_size)
hidden – hidden state from the previous step, shape (batch_size, hidden_size)
cell – cell state from the previous step, shape (batch_size, hidden_size)
w_ih – chunk of weights for processing the input tensor, shape (4 * hidden_size, input_size)
w_hh – chunk of weights for processing the hidden state tensor, shape (4 * hidden_size, hidden_size)
w_s – chunk of weights for processing the spatial embedding input, shape (3 * hidden_size, input_size)
w_q – chunk of weights for processing the temporal embedding input, shape (3 * hidden_size, input_size)
b_ih – chunk of biases for processing the input tensor, shape (4 * hidden_size)
b_hh – chunk of biases for processing the hidden state tensor, shape (4 * hidden_size)
- Returns
hidden state and cell state of this step.
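The weight shapes listed for st_lstm_cell_cal (four chunks of hidden_size rows in w_ih and w_hh, three in w_s and w_q) are consistent with an ST-LSTM formulation in which the spatial and temporal embeddings enter only the input, forget, and output gates. The equations below sketch that reading. The symbol names are chosen here for illustration, each bias is assumed to combine the matching chunks of b_ih and b_hh, and the last two lines are the standard LSTM state update, assumed rather than confirmed by this page.

```latex
\begin{aligned}
i_t &= \sigma\left(W_{xi} x_t + W_{hi} h_{t-1} + W_{si} s_t + W_{qi} q_t + b_i\right) \\
f_t &= \sigma\left(W_{xf} x_t + W_{hf} h_{t-1} + W_{sf} s_t + W_{qf} q_t + b_f\right) \\
o_t &= \sigma\left(W_{xo} x_t + W_{ho} h_{t-1} + W_{so} s_t + W_{qo} q_t + b_o\right) \\
\tilde{c}_t &= \tanh\left(W_{xc} x_t + W_{hc} h_{t-1} + b_c\right) \\
c_t &= f_t \odot c_{t-1} + i_t \odot \tilde{c}_t \\
h_t &= o_t \odot \tanh\left(c_t\right)
\end{aligned}
```

Here x_t, s_t, and q_t correspond to input_l, input_s, and input_q, while h_{t-1} and c_{t-1} correspond to hidden and cell.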
-
training
: bool
-
-
libcity.model.trajectory_loc_prediction.HSTLSTM.
cal_slot_distance
(value, slots)
Calculate a value's distance to its nearest lower bound and higher bound in slots.
- Parameters
value – the value to be calculated.
slots – values of slots, which must be sorted.
- Returns
normalized distances to the lower bound and higher bound, and the indexes of the lower bound and higher bound.
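A minimal sketch of this kind of lookup, using the standard-library `bisect` module, is given below. The function name `slot_distance` is hypothetical, and the clamping of values that fall outside the slot range is an assumption of the sketch, not documented behaviour of cal_slot_distance.

```python
from bisect import bisect_right


def slot_distance(value, slots):
    """Hypothetical helper: locate `value` between two neighbouring slots.

    Returns (dist_to_lower, dist_to_higher, lower_index, higher_index),
    with both distances normalized by the width of the enclosing interval.
    `slots` must be sorted in ascending order.
    """
    higher = bisect_right(slots, value)  # index of first slot greater than value
    lower = higher - 1
    # Assumption: out-of-range values are clamped to the nearest slot.
    if lower < 0:
        return 0.0, 1.0, 0, 0
    if higher == len(slots):
        return 1.0, 0.0, lower, lower
    width = slots[higher] - slots[lower]
    return (
        (value - slots[lower]) / width,
        (slots[higher] - value) / width,
        lower,
        higher,
    )
```

For example, with slots `[0.0, 10.0, 20.0]` the value `2.5` lies a quarter of the way from slot 0 to slot 1, so the two normalized distances are 0.25 and 0.75.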
-
libcity.model.trajectory_loc_prediction.HSTLSTM.
cal_slot_distance_batch
(batch_value, slots)
Apply cal_slot_distance to a batch of data.
- Parameters
batch_value – a batch of values, size (batch_size, step)
slots – values of slots, which must be sorted.
- Returns
batch of distances and indexes, all with shape (batch_size, step).
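The batch form can be pictured as the scalar lookup mapped over a (batch_size, step) grid, returning four grids of the same shape. The sketch below uses nested Python lists in place of tensors. Both function names are hypothetical, and the scalar helper is a stand-in that is assumed to follow the same contract as cal_slot_distance.

```python
from bisect import bisect_right


def slot_distance(value, slots):
    """Hypothetical scalar helper; out-of-range values are clamped (assumption)."""
    higher = bisect_right(slots, value)
    lower = higher - 1
    if lower < 0:
        return 0.0, 1.0, 0, 0
    if higher == len(slots):
        return 1.0, 0.0, lower, lower
    width = slots[higher] - slots[lower]
    return (value - slots[lower]) / width, (slots[higher] - value) / width, lower, higher


def slot_distance_batch(batch_value, slots):
    """Map slot_distance over a (batch_size, step) nested list.

    Returns four nested lists (ld, hd, l, h), each shaped like batch_value.
    """
    ld, hd, l, h = [], [], [], []
    for row in batch_value:
        results = [slot_distance(v, slots) for v in row]
        # Split the per-value 4-tuples into four parallel rows.
        ld.append([r[0] for r in results])
        hd.append([r[1] for r in results])
        l.append([r[2] for r in results])
        h.append([r[3] for r in results])
    return ld, hd, l, h
```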
-
libcity.model.trajectory_loc_prediction.HSTLSTM.
construct_slots
(min_value, max_value, num_slots, type)
Construct values of slots given a minimum and maximum value.
- Parameters
min_value – minimum value.
max_value – maximum value.
num_slots – number of slots to construct.
type – type of slots to construct, 'linear' or 'exp'.
- Returns
values of slots.
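One plausible way to build such slots is sketched below. It assumes that 'linear' means evenly spaced values and 'exp' means values evenly spaced in log space, and it starts at min_value and stops one step short of max_value; whether the real function includes the upper endpoint is not stated on this page. The name `make_slots` and the parameter name `kind` (used to avoid shadowing the built-in `type`) are hypothetical.

```python
import math


def make_slots(min_value, max_value, num_slots, kind):
    """Hypothetical helper returning `num_slots` ascending slot values."""
    if kind == 'linear':
        step = (max_value - min_value) / num_slots
        return [min_value + step * i for i in range(num_slots)]
    if kind == 'exp':
        # Assumption: geometric spacing, which requires min_value > 0.
        step = (math.log(max_value) - math.log(min_value)) / num_slots
        return [min_value * math.exp(step * i) for i in range(num_slots)]
    raise ValueError("kind must be 'linear' or 'exp'")
```

Exponential spacing gives finer resolution near min_value, which can suit quantities such as time intervals or distances that span several orders of magnitude.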