WeightedRandomSampler加权随机采样

平衡不平衡数据的抽取
torch.utils.data.WeightedRandomSampler(weights, num_samples, replacement=True)
其中__iter__为:
iter(torch.multinomial(self.weights, self.num_samples,
self.replacement).tolist())
其中

* weights为index权重,权重越大的取到的概率越高
* num_samples: 生成的采样长度
* replacement:是否为有放回取样
* multinomial: 伯努利随机数生成函数,也就是根据概率设定生成{0,1,…,n}

如果label为1,那么对应的该类别被取出来的概率是另外一个类别的2倍
import torchvision from torchvision import transforms from torch.utils.data
import sampler from torch.utils.data import DataLoader from
torch.utils.data.sampler import * transform = transforms.Compose([
torchvision.transforms.ToTensor(), torchvision.transforms.Normalize((0.1307,),
(0.3081,)) ]) trainset = torchvision.datasets.MNIST( root='dataset/',
train=True, #如果为True,从 training.pt 创建数据,否则从 test.pt 创建数据。 download=True,
#如果为true,则从 Internet 下载数据集并将其放在根目录中。 如果已下载数据集,则不会再次下载。 transform=transform ) ##
如果label为1,那么对应的该类别被取出来的概率是另外一个类别的2倍 weights = [2 if label == 1 else 1 for data,
label in trainset] sampler = WeightedRandomSampler(weights,num_samples=10,
replacement=True) dataloader = DataLoader(trainset, batch_size=16,
sampler=sampler)

SubsetRandomSampler索引随机采样

根据index从数据集中抽取这些index对应的图片,然后随机排序
torch.utils.data.SubsetRandomSampler(indices)
其中__iter__为:
(self.indices[i] for i in torch.randperm(len(self.indices)))
其中

* torch.randperm对数组随机排序
* indices为给定的下标数组
所以SubsetRandomSampler的功能是在给定一个数据集下标后,对该下标数组随机排序,然后不放回取样
 

如果我要划分train_set和test_set, 那么读进整个数据集来再split比较慢

不如我直接生成train_set的index和test_set的index这样就可以很快了,所以就出现了SubsetRandomSampler
import torchvision from torchvision import transforms from torch.utils.data
import sampler from torch.utils.data import DataLoader transform =
transforms.Compose([ torchvision.transforms.ToTensor(),
torchvision.transforms.Normalize((0.1307,), (0.3081,)) ]) trainset =
torchvision.datasets.MNIST( root='dataset/', train=True, #如果为True,从 training.pt
创建数据,否则从 test.pt 创建数据。 download=True, #如果为true,则从 Internet 下载数据集并将其放在根目录中。
如果已下载数据集,则不会再次下载。 transform=transform ) testset = torchvision.datasets.MNIST(
root='dataset/', train=False, download=True, transform=transform ) split_num =
int(len(trainset) * 0.8) index_list = list(range(len(trainset))) train_idx,
val_idx = index_list[:split_num], index_list[split_num:] train_sampler =
sampler.SubsetRandomSampler(train_idx) val_sampler =
sampler.SubsetRandomSampler(val_idx) loader_train = DataLoader(trainset,
batch_size=100, sampler=train_sampler) loader_val = DataLoader(trainset,
batch_size=100, sampler=val_sampler)

技术
今日推荐
PPT
阅读数 119
下载桌面版
GitHub
百度网盘(提取码:draw)
Gitee
云服务器优惠
阿里云优惠券
腾讯云优惠券
华为云优惠券
站点信息
问题反馈
邮箱:ixiaoyang8@qq.com
QQ群:766591547
关注微信