Source code for libcity.model.trajectory_loc_prediction.LSTPM

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pack_padded_sequence
from torch.nn.utils.rnn import pad_packed_sequence
from libcity.model.abstract_model import AbstractModel


class LSTPM(AbstractModel):
    """
    reference: https://github.com/NLPWM-WHU/LSTPM
    """

    def __init__(self, config, data_feature):
        super(LSTPM, self).__init__(config, data_feature)
        self.loc_size = data_feature['loc_size']
        self.tim_size = data_feature['tim_size']
        self.hidden_size = config['hidden_size']
        self.emb_size = config['emb_size']
        self.device = config['device']
        # TODO: why embedding?
        self.loc_emb = nn.Embedding(self.loc_size, self.emb_size,
                                    padding_idx=data_feature['loc_pad'])
        # Note: no time embedding is learned; time only enters through tim_sim_matrix.
        # self.emb_tim = nn.Embedding(self.tim_size, 10)
        self.lstmcell = nn.LSTM(input_size=self.emb_size, hidden_size=self.hidden_size)
        self.lstmcell_history = nn.LSTM(
            input_size=self.emb_size, hidden_size=self.hidden_size)
        self.linear = nn.Linear(self.hidden_size * 2, self.loc_size, bias=False)
        self.dropout = nn.Dropout(config['dropout'])
        # self.user_dropout = nn.Dropout(user_dropout)  # not used in this model
        self.tim_sim_matrix = data_feature['tim_sim_matrix']
        # could be the same as self.lstmcell
        self.dilated_rnn = nn.LSTMCell(
            input_size=self.emb_size, hidden_size=self.hidden_size)
        self.linear1 = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
        self.init_weights()

    def init_weights(self):
        # Xavier init for input-hidden weights, orthogonal init for hidden-hidden
        # weights, and zeros for all biases.
        ih = (param.data for name, param in self.named_parameters() if 'weight_ih' in name)
        hh = (param.data for name, param in self.named_parameters() if 'weight_hh' in name)
        b = (param.data for name, param in self.named_parameters() if 'bias' in name)
        for t in ih:
            nn.init.xavier_uniform_(t)
        for t in hh:
            nn.init.orthogonal_(t)
        for t in b:
            nn.init.constant_(t, 0)

    def _pad_batch_of_lists_masks(self, current_session, origin_len):
        # current_session is the padded location list, so len(current_session) is the
        # padded max length; mark real check-ins with 1.0 and padded positions with 0.0.
        padded_mask_non_local = [1.0] * origin_len + [0.0] * (len(current_session) - origin_len)
        padded_mask_non_local = torch.FloatTensor(padded_mask_non_local).to(self.device)
        return padded_mask_non_local
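
    # Worked example (hypothetical values, not from LibCity): for a padded session of
    # length 5 with origin_len == 3, _pad_batch_of_lists_masks returns
    # tensor([1., 1., 1., 0., 0.]).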

    def forward(self, batch):
        batch_size = batch['current_loc'].shape[0]
        origin_len = batch.get_origin_len('current_loc')
        current_loc = batch['current_loc']
        current_tim = batch['current_tim']
        # sequence * batch_size * embedding
        items = self.loc_emb(current_loc).permute(1, 0, 2)
        current_loc = current_loc.tolist()
        # pack x and history_x
        pack_items = pack_padded_sequence(items, lengths=origin_len, enforce_sorted=False)
        h1 = torch.zeros(1, batch_size, self.hidden_size).to(self.device)
        c1 = torch.zeros(1, batch_size, self.hidden_size).to(self.device)
        out, (h1, c1) = self.lstmcell(pack_items, (h1, c1))
        # batch_size * sequence_length * hidden_size
        out, _ = pad_packed_sequence(out, batch_first=True)
        items = items.permute(1, 0, 2)  # batch_size * sequence * embedding
        y_list = []
        out_hie = []  # becomes batch_size * hidden_size after the loop
        dilated_rnn_input_index = batch['dilated_rnn_input_index']
        for ii in range(batch_size):
            # origin_cur_len
            current_session_input_dilated_rnn_index = dilated_rnn_input_index[ii].tolist()
            hiddens_current = items[ii]
            dilated_lstm_outs_h = []
            dilated_lstm_outs_c = []
            for index_dilated in range(len(current_session_input_dilated_rnn_index)):
                index_dilated_explicit = current_session_input_dilated_rnn_index[index_dilated]
                hidden_current = hiddens_current[index_dilated].unsqueeze(0)
                if index_dilated == 0:
                    h = torch.zeros(1, self.hidden_size).to(self.device)
                    c = torch.zeros(1, self.hidden_size).to(self.device)
                    (h, c) = self.dilated_rnn(hidden_current, (h, c))
                    dilated_lstm_outs_h.append(h)
                    dilated_lstm_outs_c.append(c)
                else:
                    # feed the (h, c) produced at the precomputed earlier index
                    (h, c) = self.dilated_rnn(hidden_current, (
                        dilated_lstm_outs_h[index_dilated_explicit],
                        dilated_lstm_outs_c[index_dilated_explicit]))
                    dilated_lstm_outs_h.append(h)
                    dilated_lstm_outs_c.append(c)
            out_hie.append(dilated_lstm_outs_h[-1])
            # last real (non-padded) time slot of the current session
            current_session_timid = current_tim[ii].tolist()[origin_len[ii] - 1]
            current_session_embed = out[ii]  # sequence_len * hidden_size
            # FloatTensor sequence_len * 1
            current_session_mask = self._pad_batch_of_lists_masks(
                current_loc[ii], origin_len[ii]).unsqueeze(1)
            sequence_length = origin_len[ii]
            # average pooling over the current session, 1 * hidden_size
            current_session_represent = torch.sum(
                current_session_embed * current_session_mask, dim=0
            ).unsqueeze(0) / sum(current_session_mask)
            list_for_sessions = []  # his_cnt entries of 1 * hidden_size
            h2 = torch.zeros(1, 1, self.hidden_size).to(self.device)
            c2 = torch.zeros(1, 1, self.hidden_size).to(self.device)
            # encode the historical sessions
            for jj in range(len(batch['history_loc'][ii])):
                sequence = batch['history_loc'][ii][jj]
                # his_seq_len * 1 * embedding_size
                sequence_emb = self.loc_emb(sequence).unsqueeze(1)
                sequence_emb, (h2, c2) = self.lstmcell_history(sequence_emb, (h2, c2))
                sequence_tim_id = batch['history_tim'][ii][jj].tolist()
                # reweight the history hidden states by time-slot similarity
                # tim_size
                jaccard_sim_row = torch.FloatTensor(
                    self.tim_sim_matrix[current_session_timid]).to(self.device)
                jaccard_sim_explicit = jaccard_sim_row[sequence_tim_id]  # his_seq_len
                # 1 * his_seq_len
                jaccard_sim_explicit_last = F.softmax(jaccard_sim_explicit, dim=0).unsqueeze(0)
                # 1 * hidden_size
                hidden_sequence_for_current = torch.mm(
                    jaccard_sim_explicit_last, sequence_emb.squeeze(1))
                list_for_sessions.append(hidden_sequence_for_current)
            # 1 * his_cnt
            avg_distance = batch['history_avg_distance'][ii].unsqueeze(0)
            # 1 * his_cnt * hidden_size
            sessions_represent = torch.cat(list_for_sessions, dim=0).unsqueeze(0)
            # 1 * hidden_size * 1
            current_session_represent = current_session_represent.unsqueeze(2)
            # 1 * 1 * his_cnt
            sim_between_cur_his = F.softmax(
                sessions_represent.bmm(current_session_represent).squeeze(2), dim=1).unsqueeze(1)
            # TODO: why linear1 and selu?
            # 1 * hidden_size
            out_y_current = torch.selu(
                self.linear1(sim_between_cur_his.bmm(sessions_represent).squeeze(1)))
            # 1 * hidden_size * 1
            layer_2_current = (0.5 * out_y_current
                               + 0.5 * current_session_embed[sequence_length - 1]).unsqueeze(2)
            # 1 * 1 * his_cnt
            layer_2_sims = F.softmax(
                sessions_represent.bmm(layer_2_current).squeeze(2) / avg_distance,
                dim=1).unsqueeze(1)
            # 1 * hidden_size
            out_layer_2 = layer_2_sims.bmm(sessions_represent).squeeze(1)
            y_list.append(out_layer_2)
        # batch_size * hidden_size, output of the history-attention branch
        y = torch.selu(torch.cat(y_list, dim=0))
        # batch_size * hidden_size
        out_hie = F.selu(torch.cat(out_hie, dim=0))
        # gather the final LSTM output of each current session
        final_out_index = torch.tensor(origin_len) - 1
        final_out_index = final_out_index.reshape(final_out_index.shape[0], 1, -1)
        final_out_index = final_out_index.repeat(1, 1, self.hidden_size).to(self.device)
        final_out = torch.gather(out, 1, final_out_index).squeeze(1)
        final_out = F.selu(final_out)
        final_out = (final_out + out_hie) * 0.5
        out_put_emb_v1 = torch.cat([y, final_out], dim=1)
        output_ln = self.linear(out_put_emb_v1)
        # batch_size * loc_size
        output = F.log_softmax(output_ln, dim=1)
        return output
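
    # Worked example for the time-slot weighting in forward() (hypothetical numbers):
    # if tim_sim_matrix[current_session_timid] yields similarities [0.9, 0.1, 0.5] for
    # a history session's three time slots, F.softmax turns them into weights of about
    # [0.47, 0.21, 0.32], and hidden_sequence_for_current is the weighted sum of that
    # session's three LSTM hidden states under those weights.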

    def calculate_loss(self, batch):
        criterion = nn.NLLLoss(reduction='sum').to(self.device)
        scores = self.forward(batch)
        return criterion(scores, batch['target'])

    def predict(self, batch):
        return self.forward(batch)
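

# ---------------------------------------------------------------------------
# Minimal standalone sketch (not part of the original LibCity module): it
# illustrates how the dilated LSTMCell recurrence in forward() consumes a
# precomputed ``dilated_rnn_input_index``, where step t reads the (h, c) state
# produced at an earlier index instead of step t - 1. The sizes, the toy index
# list and the random embeddings below are assumptions for illustration only.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    torch.manual_seed(0)
    emb_size, hidden_size, seq_len = 8, 16, 5
    cell = nn.LSTMCell(input_size=emb_size, hidden_size=hidden_size)
    # Hypothetical indices: step t reuses the state produced at dilated_index[t].
    dilated_index = [0, 0, 1, 1, 3]
    step_embeddings = torch.randn(seq_len, emb_size)
    outs_h, outs_c = [], []
    for t in range(seq_len):
        x_t = step_embeddings[t].unsqueeze(0)  # 1 * emb_size
        if t == 0:
            state = (torch.zeros(1, hidden_size), torch.zeros(1, hidden_size))
        else:
            state = (outs_h[dilated_index[t]], outs_c[dilated_index[t]])
        h, c = cell(x_t, state)
        outs_h.append(h)
        outs_c.append(c)
    # As in forward(), only the last hidden state per session would be kept (out_hie).
    print(outs_h[-1].shape)  # torch.Size([1, 16])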