Source code for libcity.model.trajectory_loc_prediction.SERM

import torch
import torch.nn as nn
import numpy as np
from libcity.model.abstract_model import AbstractModel
from torch.nn.utils.rnn import pack_padded_sequence
from torch.nn.utils.rnn import pad_packed_sequence


class EmbeddingMatrix(nn.Module):
    # text embedding

    def __init__(self, input_size, output_size, word_vec):
        super(EmbeddingMatrix, self).__init__()
        self.input_size = input_size
        self.output_size = output_size
        self.layer = nn.Linear(in_features=self.input_size, out_features=self.output_size, bias=False)
        self.init_weight(word_vec)
    def init_weight(self, word_vec):
        # word_vec is the initial weight matrix of the text embedding, passed in via data_feature.
        # self.layer.weight has shape output_size * input_size, i.e.
        # length_of_pretrained_glove_vector (50) * text_size (the size of the dictionary).
        # Following the paper's source code, word_vec has shape
        # text_size (the size of the dictionary) * length_of_pretrained_glove_vector.
        word_vec = torch.Tensor(word_vec).t()  # transpose
        self.layer.weight = nn.Parameter(word_vec)
    def forward(self, x):
        # x: batch * seq * input_size
        # return torch.matmul(x, self.weights)
        # batch*seq*text_size * text_size*output_size = batch*seq*output_size
        return self.layer(x)  # batch * seq * output_size
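# A minimal usage sketch (not part of the original source): EmbeddingMatrix is a
# bias-free linear layer whose weights are initialised from pre-trained word
# vectors, so feeding it an averaged one-hot bag-of-words returns the average of
# those words' vectors. The toy vocabulary size and the random word_vec matrix
# below are assumptions for illustration only.
def _example_embedding_matrix():
    vocab_size, glove_dim = 8, 50
    word_vec = np.random.rand(vocab_size, glove_dim)  # vocab_size * glove_dim, stands in for GloVe
    emb = EmbeddingMatrix(vocab_size, glove_dim, word_vec)
    # averaged one-hot vector for a check-in whose text contains words 1 and 3
    bow = torch.FloatTensor(np.eye(vocab_size)[[1, 3]].mean(axis=0)).view(1, 1, -1)
    out = emb(bow)  # batch * seq * glove_dim = (1, 1, 50)
    assert out.shape == (1, 1, glove_dim)
    return out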
class SERM(AbstractModel):

    def __init__(self, config, data_feature):
        super(SERM, self).__init__(config, data_feature)
        # initialize parameters
        # print(config['dataset_class'])
        self.loc_size = data_feature['loc_size']
        self.loc_emb_size = config['loc_emb_size']
        self.tim_size = data_feature['tim_size']
        self.tim_emb_size = config['tim_emb_size']
        self.user_size = data_feature['uid_size']
        # following the paper, the user embedding is added to the location logits,
        # so its size must equal loc_size
        self.user_emb_size = data_feature['loc_size']
        self.text_size = data_feature['text_size']
        # constrained by the dimensionality of the pre-trained word vectors
        self.text_emb_size = len(data_feature['word_vec'][0])
        self.hidden_size = config['hidden_size']
        self.word_one_hot_matrix = np.eye(self.text_size)
        self.device = config['device']

        # Embedding layer
        self.emb_loc = nn.Embedding(num_embeddings=self.loc_size, embedding_dim=self.loc_emb_size,
                                    padding_idx=data_feature['loc_pad'])
        self.emb_tim = nn.Embedding(num_embeddings=self.tim_size, embedding_dim=self.tim_emb_size,
                                    padding_idx=data_feature['tim_pad'])
        self.emb_user = nn.Embedding(num_embeddings=self.user_size, embedding_dim=self.user_emb_size)
        self.emb_text = EmbeddingMatrix(self.text_size, self.text_emb_size, data_feature['word_vec'])

        # lstm layer
        self.lstm = nn.LSTM(input_size=self.loc_emb_size + self.tim_emb_size + self.text_emb_size,
                            hidden_size=self.hidden_size)
        # self.lstm = nn.LSTM(input_size=self.loc_emb_size + self.tim_emb_size, hidden_size=self.hidden_size)

        # dense layer
        self.dense = nn.Linear(in_features=self.hidden_size, out_features=self.loc_size)

        # init weight
        self.apply(self._init_weight)

    def _init_weight(self, module):
        if isinstance(module, nn.Embedding):
            nn.init.xavier_normal_(module.weight)
        elif isinstance(module, nn.Linear):
            nn.init.xavier_uniform_(module.weight)
        elif isinstance(module, nn.LSTM):
            for name, param in module.named_parameters():
                if 'weight_ih' in name:
                    nn.init.xavier_uniform_(param.data)
                elif 'weight_hh' in name:
                    nn.init.orthogonal_(param.data)
                elif 'bias' in name:
                    nn.init.constant_(param.data, 0)
    def forward(self, batch):
        loc = batch['current_loc']
        tim = batch['current_tim']
        user = batch['uid']
        text = batch['text']
        max_len = batch['current_loc'].shape[1]
        text_pad = np.zeros((self.text_size))
        # text is given as word indices, so it still needs one-hot encoding
        one_hot_text = []
        for word_index in text:
            one_hot_text_a_slice = []
            for words in word_index:
                if len(words) == 0:
                    one_hot_text_a_slice.append(np.zeros((self.text_size)))
                else:
                    one_hot_text_a_slice.append(np.sum(self.word_one_hot_matrix[words], axis=0) / len(words))
            # pad
            one_hot_text_a_slice += [text_pad] * (max_len - len(one_hot_text_a_slice))
            one_hot_text.append(np.array(one_hot_text_a_slice))
        # batch_size * seq_len * text_size
        one_hot_text = torch.FloatTensor(one_hot_text).to(self.device)
        loc_emb = self.emb_loc(loc)
        tim_emb = self.emb_tim(tim)
        user_emb = self.emb_user(user)
        text_emb = self.emb_text(one_hot_text)
        # change batch*seq*emb_size to seq*batch*emb_size
        x = torch.cat([loc_emb, tim_emb, text_emb], dim=2).permute(1, 0, 2)
        # attrs_latent = torch.cat([loc_emb, tim_emb], dim=2).permute(1, 0, 2)
        # print(attrs_latent.size())
        # pack x
        seq_len = batch.get_origin_len('current_loc')
        pack_x = pack_padded_sequence(x, lengths=seq_len, enforce_sorted=False)
        lstm_out, (h_n, c_n) = self.lstm(pack_x)  # seq*batch*hidden_size
        # print(lstm_out.size())
        # unpack
        lstm_out, out_len = pad_packed_sequence(lstm_out, batch_first=True)
        # user_emb is batch*loc_size, so we take the output of the last valid
        # time step of lstm_out for each sequence
        for i in range(lstm_out.shape[0]):
            if i == 0:
                out = lstm_out[0][seq_len[i] - 1].reshape(1, -1)  # .reshape(1, -1) flattens to a single row
            else:
                out = torch.cat((out, lstm_out[i][seq_len[i] - 1].reshape(1, -1)), 0)
        dense = self.dense(out)  # batch * loc_size
        out_vec = torch.add(dense, user_emb)  # batch * loc_size
        pred = nn.LogSoftmax(dim=1)(out_vec)  # result
        # print(pred.size())
        return pred  # batch * loc_size
    def predict(self, batch):
        return self.forward(batch)
    def calculate_loss(self, batch):
        criterion = nn.NLLLoss()
        scores = self.forward(batch)  # batch * loc_size
        return criterion(scores, batch['target'])
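

# A minimal end-to-end sketch (not part of the original source). All config and
# data_feature values, sizes, and the fake batch below are made-up placeholders;
# in LibCity the batch object comes from the data loader, so only the two
# accessors used by forward() (dict indexing and get_origin_len) are imitated.
def _example_run_serm():
    config = {'loc_emb_size': 32, 'tim_emb_size': 8, 'hidden_size': 64, 'device': 'cpu'}
    data_feature = {
        'loc_size': 100, 'tim_size': 48, 'uid_size': 10, 'text_size': 20,
        'loc_pad': 0, 'tim_pad': 0,
        'word_vec': np.random.rand(20, 50),  # stands in for pre-trained GloVe vectors
    }
    model = SERM(config, data_feature)

    class _FakeBatch(dict):
        # imitates the batch interface used by forward()
        def get_origin_len(self, key):
            return [3, 2]  # true (unpadded) lengths of the two trajectories

    batch = _FakeBatch()
    batch['current_loc'] = torch.LongTensor([[5, 7, 9], [4, 6, 0]])    # 0 is the pad location
    batch['current_tim'] = torch.LongTensor([[1, 2, 3], [10, 11, 0]])  # 0 is the pad time slot
    batch['uid'] = torch.LongTensor([1, 2])
    batch['text'] = [[[1, 2], [3], [4, 5]], [[6], []]]  # word indices per check-in
    batch['target'] = torch.LongTensor([8, 3])

    scores = model.predict(batch)        # (batch_size, loc_size) log-probabilities
    loss = model.calculate_loss(batch)
    return scores.shape, loss.item()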