# Source code for libcity.data.dataset.trajectory_encoder.serm_encoder

import os
import pandas as pd
from libcity.data.dataset.trajectory_encoder.abstract_trajectory_encoder import AbstractTrajectoryEncoder
from libcity.utils import parse_time

# Config keys whose values are folded into the dataset cache file name,
# so different preprocessing settings never collide on the same cache file.
parameter_list = ['dataset', 'min_session_len', 'min_sessions', 'traj_encoder', 'cut_method',
                  'window_size', 'history_type', 'min_checkins', 'max_session_len']
# Path to the pre-trained GloVe word vectors (Twitter corpus, 50 dimensions)
# required by SERM's semantic channel.
WORD_VEC_PATH = './raw_data/word_vec/glove.twitter.27B.50d.txt'


class SermEncoder(AbstractTrajectoryEncoder):
    """Trajectory encoder for the SERM model.

    Densely re-encodes user and location ids, encodes each check-in time
    into one of 48 slots (weekday hours 0-23, weekend hours 24-47) and
    attaches GloVe word ids for the semantic (text) channel.
    """

    def __init__(self, config):
        super().__init__(config)
        self.uid = 0
        # Some raw loc ids never appear in the data, so locations are
        # re-encoded into a dense 0..n-1 id space on first sight.
        self.location2id = {}
        self.loc_id = 0
        self.tim_max = 47  # 48 time slots: weekday 0-23, weekend 24-47
        self.word_vec = []  # vectors of words actually used, indexed by word id
        self.word_index = {}  # word -> word id
        self.word_id = 0
        self.text_vec = self.load_wordvec()
        self.history_type = self.config['history_type']
        self.feature_dict = {'current_loc': 'int', 'current_tim': 'int',
                             'target': 'int', 'uid': 'int', 'text': 'no_tensor'}
        # if config['evaluate_method'] == 'sample':
        #     self.feature_dict['neg_loc'] = 'int'
        #     parameter_list.append('neg_samples')
        # Build the cache suffix from every configured parameter
        # (join instead of repeated string +=).
        parameters_str = ''.join('_' + str(self.config[key])
                                 for key in parameter_list if key in self.config)
        self.cache_file_name = os.path.join(
            './libcity/cache/dataset_cache/',
            'trajectory_{}.json'.format(parameters_str))
        # Load the POI profile: for these datasets the semantic information
        # (venue category) lives in the .geo table.
        self.poi_profile = None
        if self.config['dataset'] in ['foursquare_tky', 'foursquare_nyk', 'foursquare_serm']:
            self.poi_profile = pd.read_csv('./raw_data/{}/{}.geo'.format(
                self.config['dataset'], self.config['dataset']))

    def encode(self, uid, trajectories, negative_sample=None):
        """standard encoder use the same method as DeepMove

        Recode poi id. Encode timestamp with its hour.

        Args:
            uid ([type]): same as AbstractTrajectoryEncoder
            trajectories ([type]): same as AbstractTrajectoryEncoder
                trajectory1 = [
                    (location ID, timestamp, timezone_offset_in_minutes),
                    (location ID, timestamp, timezone_offset_in_minutes),
                    .....
                ]
        """
        # Re-encode uid densely (the raw uid passed in is discarded).
        uid = self.uid
        self.uid += 1
        encoded_trajectories = []
        for index, traj in enumerate(trajectories):
            current_loc = []
            current_tim = []
            current_word_vec = []
            for point in traj:
                loc = point[4]
                now_time = parse_time(point[2])
                if loc not in self.location2id:
                    self.location2id[loc] = self.loc_id
                    self.loc_id += 1
                current_loc.append(self.location2id[loc])
                # Weekday hours map to 0-23, weekend hours to 24-47.
                time_code = self._time_encode(now_time)
                current_tim.append(time_code)
                # Semantic channel: word ids for this point's venue category.
                current_word_vec.append(self.get_text_from_point(point))
            # One trajectory yields several training samples: the first i+1
            # points predict point i+2 (predict 2nd from 1st, 3rd from first
            # two, and so on).
            for i in range(len(current_loc) - 1):
                trace = []
                target = current_loc[i + 1]
                trace.append(current_loc[:i + 1])
                trace.append(current_tim[:i + 1])
                trace.append(target)
                trace.append(uid)
                trace.append(current_word_vec[:i + 1])
                # if negative_sample is not None:
                #     neg_loc = []
                #     for neg in negative_sample[index]:
                #         if neg not in self.location2id:
                #             self.location2id[neg] = self.loc_id
                #             self.loc_id += 1
                #         neg_loc.append(self.location2id[neg])
                #     trace.append(neg_loc)
                encoded_trajectories.append(trace)
        return encoded_trajectories

    def gen_data_feature(self):
        """Populate self.pad_item / self.data_feature after encoding.

        Pad ids sit one past the largest real id, so the vocabulary sizes
        below include the pad slot.
        """
        loc_pad = self.loc_id
        tim_pad = self.tim_max + 1
        self.pad_item = {
            'current_loc': loc_pad,
            'current_tim': tim_pad
        }
        self.data_feature = {
            'loc_size': self.loc_id + 1,
            'tim_size': self.tim_max + 2,
            'uid_size': self.uid,
            'loc_pad': loc_pad,
            'tim_pad': tim_pad,
            'text_size': len(self.word_index),
            'word_vec': self.word_vec
        }

    def _time_encode(self, time):
        """Map a datetime to a slot: weekday hours 0-23, weekend hours 24-47."""
        if time.weekday() in [0, 1, 2, 3, 4]:
            return time.hour
        else:
            return time.hour + 24

    def load_wordvec(self, vecpath=WORD_VEC_PATH):
        """Load GloVe vectors from *vecpath* into a {word: list[float]} dict.

        Raises:
            FileNotFoundError: if the GloVe file has not been downloaded.
        """
        if not os.path.exists(vecpath):
            raise FileNotFoundError('SERM need Glove word vectors. Please download serm_glove_word_vec.zip from'
                                    ' BaiduDisk or Google Drive, and unzip it to raw_data directory')
        word_vec = {}
        with open(vecpath, 'r', encoding='utf-8') as f:
            for line in f:
                # Each line: "<word> <v1> <v2> ... <vN>".
                attrs = line.replace('\n', '').split(' ')
                word_vec[attrs[0]] = [float(attr) for attr in attrs[1:]]
        return word_vec

    def get_text_from_point(self, point):
        """Return the list of word ids for the point's venue category.

        Only words that have a GloVe vector are kept; a new word is assigned
        an id on first sight and its vector is appended to self.word_vec.
        Datasets without semantic info yield an empty list.
        """
        if self.config['dataset'] in ['foursquare_tky', 'foursquare_nyk', 'foursquare_serm']:
            # Semantic info lives in the geo table.
            # NOTE(review): point[4] is used as a POSITIONAL index into the geo
            # frame via iloc — this assumes geo ids are exactly 0..n-1 in file
            # order; confirm against the dataset's .geo files.
            words = self.poi_profile.iloc[point[4]]['venue_category_name'].split(' ')
            word_index = []
            for w in words:
                w = w.lower()
                if (w in self.text_vec) and (w not in self.word_index):
                    self.word_index[w] = self.word_id
                    self.word_id += 1
                    self.word_vec.append(self.text_vec[w])
                if w in self.word_index:
                    word_index.append(self.word_index[w])
            return word_index
        else:
            return []