需求:使用pandas读取ratings.csv文件,把读取的内容按时间戳排序后写入文件中
csv文件:ratings.csv
userId,movieId,rating,timestamp 1,296,5.0,1147880044 1,306,3.5,1147868817 1,307
,5.0,1147868828 1,665,5.0,1147878820 1,899,3.5,1147868510 1,1088,4.0,1147868495
1,1175,3.5,1147868826 1,1217,3.5,1147878326 1,1237,5.0,1147868839 1,1250,4.0,
1147868414
test.py
import pandas as pd # usecols=[0,1,2,3] 使用前4列,全读取的话usecols可省 # nrows
读取csv文件前10000行数据 data = pd.read_csv('../data/rating/ratings.csv',nrows = 10000,
encoding='utf-8',usecols=[0,1,2,3]) # data里面按时间戳排序 data = data.sort_values(by=[
"timestamp"],ascending=[False]) print(data) # 删除时间戳这一列 axis=1表示对列操作 data = data.
drop('timestamp',axis=1) # 将数据从DataFrame转成list列表 data = data.values.tolist()
print(data) # 写入到文件中 with open("../data/rating/demo_data01",'w',encoding='utf-8'
) as f: for list in data: # 转成列表时 userId 和 movieId 是float类型,我需要转成int类型 list[0] =
int(list[0]) list[1] = int(list[1]) s=str(list[0])+'\t'+str(list[1])+'\t'+str(
list[2])+'\n' f.write(s)
然而,我在构建完demo_data01后使用这个输入集到模型网络中有问题
报错:
InvalidArgumentError (see above for traceback): indices[5519] = 4979 is not in
[0, 3287) [[Node: embedding_lookup_1 = GatherV2[Taxis=DT_INT32, Tindices=
DT_INT32, Tparams=DT_FLOAT, _device=
"/job:localhost/replica:0/task:0/device:CPU:0"](Variable_1/read,
_arg_Placeholder_1_0_1, embedding_lookup_1/axis)]]
我透,[0,3287)是我的模型从1w条数据中抽取了75个用户,包含3287条电影ID及3287条电影评分。
这个4979是从哪冒出来的???

原因:在模型中大量用到了range()函数,1w条数据随机抽取了几千条,这意味输入集的 userId 和 movieId
需要在一个连续的区间里面,上面这个报错是 movieId = 4979 不在连续区间 [0,3287)中。所以需要对 userId 和 movieId
做数据预处理让 movieId 处于[0,3287)这个区间中( userId 对我来说没必要处理)。
import pandas as pd #添加数据 #usecols=[0,1,2]使用前三列,不要时间戳了 data = pd.read_csv(
'../data/rating/ratings.csv',nrows = 10000,encoding='utf-8',usecols=[0,1,2,3])
# 根据 timestamp 降序排序 data = data.sort_values(by=["timestamp"],ascending=[False])
# print(data) #删除 timestamp 列 data = data.drop('timestamp',axis=1) # 对 movieID
升序排序 data = data.sort_values(by=["movieId"],ascending=[True]) # 取出 movieId 列 x =
data.drop(['userId','rating'],axis=1) #print(data) #print(x.drop_duplicates())
# movieId列去掉重复值,保证做字典时value值是唯一的 y = x.drop_duplicates().values.tolist() # 创建字典
d={} # 循环字典映射 index : movieId # index:movieId {0: [1], 1: [2], 2: [3], 3: [5],
4: [6],`...} for index in range(len(y)): d[index] = y[index] #print(d) #data =
data.values.tolist() # 写入demo_data01 with open("../data/rating/demo_data01",'w',
encoding='utf-8') as f: for i in range(len(data)): for k in d: # 根据 movieId 这个
value 值找 key 值,用 key 值取代 movieId if data[i][1] in d[k]: s = str(int(data[i][0]))
+'\t'+str(k)+'\t'+str(data[i][2])+'\n' f.write(s)
经过上述操作后,数据集已经能成功在模型中运行了!

好吧,老师让我把userId也映射成连续的区间
test.py
import pandas as pd import utils #添加数据,取前10w行 统计得到用户ID数量:757,电影ID数量:9786 data =
pd.read_csv('../data/rating/ratings.csv',nrows = 100000,encoding='utf-8',
usecols=[0,1,2,3]) #usecols=[0,1,2]使用前三列,不要时间戳了 # 根据时间戳删除timestamp列降序排序 data =
data.sort_values(by=["timestamp"],ascending=[False]) # print(data) #删除timestamp列
data= data.drop('timestamp',axis=1) # 对 movieID 升序排序 data = data.sort_values(by
=["movieId"],ascending=[True]) # 取出movieId列 x = data.drop(['userId','rating'],
axis=1) # 取出userId列 x1 = data.drop(['movieId','rating'],axis=1) #print(data)
#print(x.drop_duplicates()) # movieId列去重,保证做字典时value值是唯一的 y = x.drop_duplicates(
).values.tolist() # userId列去重,保证做字典时value值是唯一的 y1 = x1.drop_duplicates().values.
tolist() # 电影字典 d={} # 用户字典 d1={} # 循环字典映射 index : movieId # index:movieId {0:
[1], 1: [2], 2: [3], 3: [5], 4: [6],`...},后期需要把这个字典返回出去以便找出真正的电影ID for index in
range(len(y)): d[index] = y[index] for index in range(len(y1)): d1[index] = y1[
index] # 保存文件以便以后找到原始ID,你也可以用其他方法保存 # 获取时 u_dict =
utils.pickle_load(“../data/rating/user_dict”) utils.pickle_save(d,
"../data/rating/movie_dict") utils.pickle_save(d1,"../data/rating/user_dict")
print(d) print(d1) print('用户ID数量:%d,电影ID数量:%d'%(len(d1),len(d))) data = data.
values.tolist() # 写入demo_data01,pycharm 新建一个file(text) with open(
"../data/rating/demo_data01",'w',encoding='utf-8') as f: for i in range(len(data
)): s='' for u in d1: if data[i][0] in d1[u]: s = s+str(u)+'\t' for k in d: if
data[i][1] in d[k]: s = s + str(k)+'\t'+str(data[i][2])+'\n' f.write(s)
utils.py
#coding: utf-8 import pickle import time class Log(): def log(self, text,
log_time=False): print('log: %s' % text) if log_time: print('time: %s' % time.
asctime(time.localtime(time.time()))) def pickle_save(object, file_path): f =
open(file_path, 'wb') pickle.dump(object, f) def pickle_load(file_path): f =
open(file_path, 'rb') return pickle.load(f)

技术
下载桌面版
GitHub
百度网盘(提取码:draw)
Gitee
云服务器优惠
阿里云优惠券
腾讯云优惠券
华为云优惠券
站点信息
问题反馈
邮箱:ixiaoyang8@qq.com
QQ群:766591547
关注微信