Source code for libcity.model.trajectory_loc_prediction.ATSTLSTM

from math import sqrt

import torch
import torch.nn as nn
from torch.nn.functional import normalize
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

from libcity.model.abstract_model import AbstractModel


class Attn(nn.Module):
    """Attention module that computes a weighted average over the LSTM hidden-state outputs."""

    def __init__(self, hidden_size):
        """
        Args:
            hidden_size (int): size of the hidden-state vectors
        """
        super(Attn, self).__init__()
        self.sqrt_rec_size = 1. / sqrt(hidden_size)
        # context vector
        self.zu = nn.Linear(in_features=hidden_size, out_features=1, bias=False)
        self.softmax = nn.Softmax(dim=2)

    def forward(self, x):
        """
        Args:
            x (torch.Tensor): shape (batch_size, seq_len, hidden_size), the hidden-state sequence

        Returns:
            torch.Tensor: shape (batch_size, hidden_size), the attention-pooled context vector
        """
        w = self.zu(x) * self.sqrt_rec_size
        w = w.permute(0, 2, 1)
        w = self.softmax(w)  # batch_size * 1 * seq_len
        c = torch.bmm(w, x)
        return c.squeeze(1)
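

# Usage sketch (illustrative, not part of the original module): the attention
# module pools a padded sequence of hidden states into one context vector.
#
#     attn = Attn(hidden_size=64)
#     x = torch.randn(8, 20, 64)   # 8 sequences of length 20
#     c = attn(x)                  # c.shape == torch.Size([8, 64])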


class ATSTLSTM(AbstractModel):
    """ATST-LSTM next-location prediction model for trajectories."""

    def __init__(self, config, data_feature):
        """
        Args:
            config: model configuration; provides hidden_size and device
            data_feature: dataset feature dictionary; provides loc_size, uid_size and loc_pad
        """
        super(ATSTLSTM, self).__init__(config, data_feature)
        self.hidden_size = config['hidden_size']
        self.loc_size = data_feature['loc_size']
        self.uid_size = data_feature['uid_size']
        self.device = config['device']
        # build the network
        self.loc_embedding = nn.Embedding(num_embeddings=self.loc_size, embedding_dim=self.hidden_size,
                                          padding_idx=data_feature['loc_pad'])
        self.user_embedding = nn.Embedding(num_embeddings=self.uid_size, embedding_dim=self.hidden_size)
        # Wv: projects the location embedding
        self.wv = nn.Linear(in_features=self.hidden_size, out_features=self.hidden_size, bias=False)
        # Wl: projects the spatial (distance) interval
        self.wl = nn.Linear(in_features=1, out_features=self.hidden_size, bias=False)
        # Wt: projects the temporal interval
        self.wt = nn.Linear(in_features=1, out_features=self.hidden_size, bias=False)
        # Wn: projects the attention-pooled trajectory representation
        self.wn = nn.Linear(in_features=self.hidden_size, out_features=self.hidden_size, bias=False)
        # Wp: projects the user embedding
        self.wp = nn.Linear(in_features=self.hidden_size, out_features=self.hidden_size, bias=False)
        self.lstm = nn.LSTM(input_size=self.hidden_size, hidden_size=self.hidden_size)
        self.attn = Attn(hidden_size=self.hidden_size)
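
    # Instantiation sketch (illustrative values, not from the original source;
    # assumes AbstractModel.__init__ only forwards to nn.Module):
    #     config = {'hidden_size': 64, 'device': torch.device('cpu')}
    #     data_feature = {'loc_size': 1000, 'uid_size': 200, 'loc_pad': 0}
    #     model = ATSTLSTM(config, data_feature)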

    def forward(self, batch):
        # batch_size * padded_seq_len
        loc = batch['current_loc']
        dis = batch['current_dis']
        tim = batch['current_tim']
        # batch_size * neg_samples
        loc_neg = batch['loc_neg']
        dis_neg = batch['dis_neg']
        tim_neg = batch['tim_neg']
        # batch_size
        uid = batch['uid']
        # batch_size * 1
        target_loc = batch['target_loc'].unsqueeze(1)
        target_dis = batch['target_dis'].unsqueeze(1)
        target_tim = batch['target_tim'].unsqueeze(1)
        origin_len = batch.get_origin_len('current_loc')
        padded_seq_len = loc.shape[1]
        # concatenate all inputs so the embedding is computed in one pass
        total_loc = torch.cat([loc, target_loc, loc_neg], dim=1)
        total_dis = torch.cat([dis, target_dis, dis_neg], dim=1).unsqueeze(2)
        total_tim = torch.cat([tim, target_tim, tim_neg], dim=1).unsqueeze(2)
        # embedding
        total_loc_emb = self.loc_embedding(total_loc)
        # batch_size * total_len * hidden_size
        total_emb = self.wv(total_loc_emb) + self.wl(total_dis) + self.wt(total_tim)
        # split the embedding back into the current sequence and the candidate locations
        current_emb, rest_emb = torch.split(total_emb, [padded_seq_len, total_emb.shape[1] - padded_seq_len], dim=1)
        # lstm
        pack_current_emb = pack_padded_sequence(current_emb, lengths=origin_len, enforce_sorted=False,
                                                batch_first=True)
        lstm_out, (h_n, c_n) = self.lstm(pack_current_emb)
        lstm_out, _ = pad_packed_sequence(lstm_out, batch_first=True)
        # attention pooling
        rn = self.attn(lstm_out)  # batch_size * hidden_size
        # user latent vector
        pu = self.user_embedding(uid)  # batch_size * hidden_size
        # wn * rn + wp * pu
        first_part = self.wn(rn) + self.wp(pu)  # batch_size * hidden_size
        first_part = first_part.unsqueeze(2)  # batch_size * hidden_size * 1
        output = torch.bmm(rest_emb, first_part).squeeze(2)  # batch_size * (neg_samples + 1)
        return output
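
    # Scoring note (derived from the code above): the score of each candidate
    # location v is a dot product
    #     score(v) = e_v^T (Wn * r_n + Wp * p_u)
    # where e_v = Wv*emb(v) + Wl*dist_v + Wt*time_v, r_n is the attention-pooled
    # trajectory representation and p_u is the user embedding. Column 0 of the
    # output is the positive (target) location, columns 1.. are the negative samples.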

    def predict(self, batch):
        return self.forward(batch)

    def calculate_loss(self, batch):
        """
        Compute the model loss (without the regularization term).

        Args:
            batch (libcity.data.batch): input batch

        Returns:
            torch.Tensor: shape (), the loss
        """
        score = self.predict(batch)
        # normalize the scores here, otherwise the loss can overflow to inf
        score = normalize(score, dim=1)
        # score_pos is batch_size * 1, score_neg is batch_size * neg_samples
        score_pos, score_neg = torch.split(score, [1, score.shape[1] - 1], dim=1)
        loss = -(score_pos - score_neg).sigmoid().log().sum()
        return loss
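

# Loss note (restating the code above): calculate_loss is a BPR-style pairwise
# ranking objective,
#     loss = -sum_{i, j} log sigmoid(score_pos[i] - score_neg[i, j])
# which pushes the target location to score higher than every negative sample.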