张量的索引、分片、合并以及维度调整

2021/6/15 10:51:30

本文主要是介绍张量的索引、分片、合并以及维度调整,对大家解决编程问题具有一定的参考价值,需要的程序猿们随着小编来一起学习吧!

文章目录

  • 张量的符号索引
    • 一维张量索引
    • 二维张量索引
  • tensor.view()方法
  • 张量的分片函数
    • 分块:chunk函数
    • 拆分:split函数
  • 张量的合并操作
  • 拼接函数:cat
    • 堆叠函数:stack


张量的符号索引

张量也是有序序列,我们可以根据每个元素在系统内的顺序“编号”,来找出特定的元素,也就是索引。

一维张量索引

一维张量的索引过程和Python原生对象类型的索引一致,基本格式遵循`[start: end: step]

import torch
t1 = torch.arange(1, 11)
# 张量索引出来的结果还是零维张量, 而不是单独的数。要转化成单独的数,需要使用item()方法。
t1[0]  # tensor(1)
# 冒号分隔,表示对某个区域进行索引,也就是所谓的切片
t1[1: 8] # tensor([2, 3, 4, 5, 6, 7, 8])
# 第二个冒号,表示索引的间隔,在张量的索引中,step位必须大于0
t1[1: 8: 2] # tensor([2, 4, 6, 8])
# 冒号前后没有值,表示索引这个区域
t1[1: : 2]  # 从第二个元素开始索引,一直到结尾,并且每隔两个数取一个
t1[: 8: 2]  # 从第一个元素开始索引到第9个元素(不包含),并且每隔两个数取一个

二维张量索引

二维张量的索引逻辑和一维张量的索引逻辑基本相同,二维张量可以视为两个一维张量组合而成,而在实际的索引过程中,需要用逗号进行分隔,分别表示对哪个一维张量进行索引、以及具体的一维张量的索引。

import torch
t2 = torch.arange(1, 10).reshape(3, 3)
t2[0, 1]    # 表示索引第一行、第二个(第二列的)元素
t2[0, ::2]                # 表示索引第一行、每隔两个元素取一个
t2[::2, ::2]              # 表示每隔两行取一行、并且每一行中每隔两个元素取一个
t2[[0, 2], 1]              # 索引第一行、第三行、第二列的元素

tensor.view()方法

PyTorch中的.view()方法。该方法会返回一个类似视图的结果,该结果和原张量对象共享一块数据存储空间,并且通过.view()方法,还可以改变对象结构,生成一个不同结构,但共享一个存储空间的张量。当然,共享一个存储空间,也就代表二者是“浅拷贝”的关系,修改其中一个,另一个也会同步进行更改。

t = torch.arange(6).reshape(2, 3)
te = t.view(3, 2)              # 构建一个数据相同,但形状不同的“视图”
tr = t.view(1, 2, 3)           # 维度也可以修改

“视图”的作用就是节省空间,而值得注意的是,很多切分张量的方法中,返回结果都是“视图”,而不是新生成一个对象。

张量的分片函数

分块:chunk函数

chunk函数能够按照某维度,对张量进行均匀切分,并且返回结果是原张量的视图。

t2 = torch.arange(12).reshape(4, 3)
tc = torch.chunk(t2, 4, dim=0)           # 在第零个维度上(按行),进行四等分
# 当原张量不能均分时,chunk不会报错,但会返回其他均分的结果
torch.chunk(t2, 3, dim=0)            # 次一级均分结果

拆分:split函数

split既能进行均分,也能进行自定义切分。当然,需要注意的是,和chunk函数一样,split返回结果也是view。

t2 = torch.arange(12).reshape(4, 3)
torch.split(t2, 2, 0)  # 第二个参数只输入一个数值时表示均分,第三个参数表示切分的维度
torch.split(t2, [1, 3], 0)  # 第二个参数输入一个序列时,表示按照序列数值进行切分,也就是1/3分
# 注意,当第二个参数位输入一个序列时,序列的各数值的和必须等于对应维度下形状分量的取值。
torch.split(t2, [1, 1, 2], 0) 
ts = torch.split(t2, [1, 2], 1) 

张量的合并操作

张量的合并操作类似与列表的追加元素,可以拼接、也可以堆叠。

拼接函数:cat

PyTorch中,可以使用cat函数实现张量的拼接。
注意理解,拼接的本质是实现元素的堆积,也就是构成a、b两个二维张量的各一维张量的堆积,最终还是构成二维向量。

a = torch.zeros(2, 3)
b = torch.ones(2, 3)
torch.cat([a, b])                  # 按照行进行拼接,dim默认取值为0
torch.cat([a, b], 1)               # 按照列进行拼接

堆叠函数:stack

和拼接不同,堆叠不是将元素拆分重装,而是简单的将各参与堆叠的对象分装到一个更高维度的张量里。

a = torch.zeros(2, 3)
b = torch.ones(2, 3)
torch.stack([a, b])  # 堆叠之后,生成一个三维张量
torch.stack([a, b]).shape

注意对比二者区别,拼接之后维度不变,堆叠之后维度升高。拼接是把一个个元素单独提取出来之后再放到二维张量中,而堆叠则是直接将两个二维张量封装到一个三维张量中,因此,堆叠的要求更高,参与堆叠的张量必须形状完全相同。



这篇关于张量的索引、分片、合并以及维度调整的文章就介绍到这儿,希望我们推荐的文章对大家有所帮助,也希望大家多多支持为之网!


扫一扫关注最新编程教程