WeightedRandomSampler加权随机采样

torch.utils.data.WeightedRandomSampler(weights, num_samples, replacement=True)

iter(torch.multinomial(self.weights, self.num_samples,
self.replacement).tolist())

* weights 为每个样本下标（index）对应的权重，权重越大的下标被取到的概率越高；权重之和不必为 1
* num_samples： 生成的采样长度
* replacement： 是否为有放回取样；为 False 时同一个下标最多只会被取到一次，因此 num_samples 不能超过非零权重的个数
* multinomial： 多项分布采样函数（不是伯努利分布），即按照 weights 给定的权重从下标 {0, 1, …, n-1} 中抽取 num_samples 个下标，其中 n 为 weights 的长度

import torchvision from torchvision import transforms from torch.utils.data
import sampler from torch.utils.data import DataLoader from
torch.utils.data.sampler import * transform = transforms.Compose([
torchvision.transforms.ToTensor(), torchvision.transforms.Normalize((0.1307,),
(0.3081,)) ]) trainset = torchvision.datasets.MNIST( root='dataset/',
#如果为true，则从 Internet 下载数据集并将其放在根目录中。 如果已下载数据集，则不会再次下载。 transform=transform ) ##

label in trainset] sampler = WeightedRandomSampler(weights,num_samples=10,
sampler=sampler)

SubsetRandomSampler索引随机采样

torch.utils.data.SubsetRandomSampler(indices)

(self.indices[i] for i in torch.randperm(len(self.indices)))

* torch.randperm(n) 返回 0 到 n-1 的一个随机排列，这里用它来打乱 indices 的取用顺序
* indices为给定的下标数组

import torchvision from torchvision import transforms from torch.utils.data
import sampler from torch.utils.data import DataLoader transform =
transforms.Compose([ torchvision.transforms.ToTensor(),
torchvision.transforms.Normalize((0.1307,), (0.3081,)) ]) trainset =
torchvision.datasets.MNIST( root='dataset/', train=True, #如果为True，从 training.pt

int(len(trainset) * 0.8) index_list = list(range(len(trainset))) train_idx,
val_idx = index_list[:split_num], index_list[split_num:] train_sampler =
sampler.SubsetRandomSampler(train_idx) val_sampler =
batch_size=100, sampler=val_sampler)

GitHub

Gitee