• 首页 首页 icon
  • 工具库 工具库 icon
    • IP查询 IP查询 icon
  • 内容库 内容库 icon
    • 快讯库 快讯库 icon
    • 精品库 精品库 icon
    • 问答库 问答库 icon
  • 更多 更多 icon
    • 服务条款 服务条款 icon

PyTorch的TensorDataset

武飞扬头像
Iareges
帮助1

一、简介

顾名思义,torch.utils.data 中的 TensorDataset 基于一系列张量构建数据集。这些张量的形状可以不尽相同,但第一个维度必须具有相同大小,这是为了保证在使用 DataLoader 时可以正常地返回一个批量的数据。

二、源码解读

以下是 TensorDataset 的源码:

class TensorDataset(Dataset[Tuple[Tensor, ...]]):
    r"""Dataset wrapping tensors.

    Each sample will be retrieved by indexing tensors along the first dimension.

    Args:
        *tensors (Tensor): tensors that have the same size of the first dimension.
    """
    tensors: Tuple[Tensor, ...]

    def __init__(self, *tensors: Tensor) -> None:
        assert all(tensors[0].size(0) == tensor.size(0)
                   for tensor in tensors), "Size mismatch between tensors"
        self.tensors = tensors

    def __getitem__(self, index):
        return tuple(tensor[index] for tensor in self.tensors)

    def __len__(self):
        return self.tensors[0].size(0)

*tensors 告诉我们实例化 TensorDataset 时传入的是一系列张量,即:

dataset = TensorDataset(tensor_1, tensor_2, ..., tensor_n)

随后的 assert 是用来确保传入的这些张量中,每个张量在第一个维度的大小都等于第一个张量在第一个维度的大小,即要求所有张量在第一个维度的大小都相同。

__getitem__ 方法返回的结果等价于

return tensor_1[index], tensor_2[index], ..., tensor_n[index]

从这行代码可以看出,如果 n n n 个张量在第一个维度的大小不完全相同,则必然会有一个张量出现 IndexError。确保第一个维度大小相同也是为了之后传入 DataLoader 中能够正常地以一个批量的形式加载。

__len__ 就不用多说了,因为所有张量的第一个维度大小都相同,所以直接返回传入的第一个张量在第一个维度的大小即可。

📌 TensorDataset 将张量的第一个维度视为数据集大小的维度,数据集在传入 DataLoader 后,该维度也是 batch_size 所在的维度

三、通过例子进一步理解

假设当前目录下存放一个 data.csv 文件,其中的每一行的后六个数字代表样本对应的特征向量,第一个数字代表该样本对应的标签。

1.0000, 0.9449, -0.8295, -0.7112, -0.7005, -0.2167, -0.7059
1.0000, -2.1290, 0.3062, -0.2188, -1.3525, 1.6726, -0.8547
-1.0000, -1.5803, 0.6320, -1.9216, -0.0722, 1.4919, -0.3219
1.0000, -0.2993, 0.3256, 0.3015, 0.4959, -0.1034, -1.0536
-1.0000, -0.0025, 0.8698, 0.9149, 1.4535, 1.1784, 0.1983
-1.0000, -0.5881, -0.5728, 2.5740, 0.9449, 1.9096, 0.3761
1.0000, -0.9585, -1.3368, -1.1004, 0.6487, 1.7098, 1.5862
-1.0000, 1.4861, 1.3814, 0.7968, 0.5741, 1.0919, -0.1592

接下来我们分别用普通方法和 TensorDataset 方法来构建数据集。

普通方法:

import torch
from torch.utils.data import Dataset
import pandas as pd


class MyDataset(Dataset):

    def __init__(self):
        self.data = pd.read_csv('data.csv', header=None).values

    def __getitem__(self, idx):
        feature = torch.from_numpy(self.data[idx, 1:])
        label = torch.tensor(self.data[idx, 0])
        return feature, label

    def __len__(self):
        return len(self.data)


mydataset = MyDataset()

TensorDataset 方法:

import torch
from torch.utils.data import TensorDataset
import pandas as pd

data = pd.read_csv('data.csv', header=None).values
features = torch.from_numpy(data[:, 1:])
labels = torch.from_numpy(data[:, 0])

mydataset = TensorDataset(features, labels)

这篇好文章是转载于:学新通技术网

  • 版权申明: 本站部分内容来自互联网,仅供学习及演示用,请勿用于商业和其他非法用途。如果侵犯了您的权益请与我们联系,请提供相关证据及您的身份证明,我们将在收到邮件后48小时内删除。
  • 本站站名: 学新通技术网
  • 本文地址: /boutique/detail/tanhgcfifk
系列文章
更多 icon
同类精品
更多 icon
继续加载