在 PyTorch 中展平

ID:20440 / 打印

在 pytorch 中展平

请我喝杯咖啡☕

*备忘录:

  • 我的帖子解释了 flatten() 和 ravel()。
  • 我的帖子解释了 unflatten()。

flatten() 可以通过从零个或多个元素的 0d 或多个 d 张量中选择维度来移除零个或多个维度,得到零个或多个元素的 1d 或多个 d 张量,如下所示:

*备忘录:

  • 初始化的第一个参数是 start_dim(optional-default:1-type:int)。
  • 初始化的第二个参数是 end_dim(可选-默认:-1-类型:int)。
  • 第一个参数是输入(必需类型:int、float、complex 或 bool 的张量)。
  • flatten() 可以将 0d 张量更改为 1d 张量。
  • flatten() 对于一维张量没有任何作用。
  • flatten() 和 flatten() 的区别是:
    • flatten() 的 start_dim 默认值为 1,而 flatten() 的 start_dim 默认值为 0。
    • 基本上,flatten() 用于定义模型,而 flatten() 不用于定义模型。
import torch from torch import nn  flatten = nn.Flatten() flatten # Flatten(start_dim=1, end_dim=-1)  flatten.start_dim # 1  flatten.end_dim # -1  my_tensor = torch.tensor(7)  flatten = nn.Flatten(start_dim=0, end_dim=0) flatten = nn.Flatten(start_dim=0, end_dim=-1) flatten = nn.Flatten(start_dim=-1, end_dim=0) flatten = nn.Flatten(start_dim=-1, end_dim=-1) flatten(input=my_tensor) # tensor([7])  my_tensor = torch.tensor([7, 1, -8, 3, -6, 0])  flatten = nn.Flatten(start_dim=0, end_dim=0) flatten = nn.Flatten(start_dim=0, end_dim=-1) flatten = nn.Flatten(start_dim=-1, end_dim=0) flatten = nn.Flatten(start_dim=-1, end_dim=-1) flatten(input=my_tensor) # tensor([7, 1, -8, 3, -6, 0])  my_tensor = torch.tensor([[7, 1, -8], [3, -6, 0]])  flatten = nn.Flatten(start_dim=0, end_dim=1) flatten = nn.Flatten(start_dim=0, end_dim=-1) flatten = nn.Flatten(start_dim=-2, end_dim=1) flatten = nn.Flatten(start_dim=-2, end_dim=-1) flatten(input=my_tensor) # tensor([7, 1, -8, 3, -6, 0])  flatten = nn.Flatten() flatten = nn.Flatten(start_dim=0, end_dim=0) flatten = nn.Flatten(start_dim=-1, end_dim=-1) flatten = nn.Flatten(start_dim=0, end_dim=-2) flatten = nn.Flatten(start_dim=1, end_dim=1) flatten = nn.Flatten(start_dim=1, end_dim=-1) flatten = nn.Flatten(start_dim=-1, end_dim=1) flatten = nn.Flatten(start_dim=-1, end_dim=-1) flatten = nn.Flatten(start_dim=-2, end_dim=0) flatten = nn.Flatten(start_dim=-2, end_dim=-2) flatten(input=my_tensor) # tensor([[7, 1, -8], [3, -6, 0]])  my_tensor = torch.tensor([[[7], [1], [-8]], [[3], [-6], [0]]])  flatten = nn.Flatten(start_dim=0, end_dim=2) flatten = nn.Flatten(start_dim=0, end_dim=-1) flatten = nn.Flatten(start_dim=-3, end_dim=2) flatten = nn.Flatten(start_dim=-3, end_dim=-1) flatten(input=my_tensor) # tensor([7, 1, -8, 3, -6, 0])  flatten = nn.Flatten(start_dim=0, end_dim=0) flatten = nn.Flatten(start_dim=0, end_dim=-3) flatten = nn.Flatten(start_dim=1, end_dim=1) flatten = nn.Flatten(start_dim=1, end_dim=-2) flatten = nn.Flatten(start_dim=2, end_dim=2) flatten = nn.Flatten(start_dim=2, end_dim=-1) flatten = nn.Flatten(start_dim=-1, end_dim=2) flatten = nn.Flatten(start_dim=-1, end_dim=-1) flatten = nn.Flatten(start_dim=-2, end_dim=1) flatten = nn.Flatten(start_dim=-2, end_dim=-2) flatten = nn.Flatten(start_dim=-3, end_dim=0) flatten = nn.Flatten(start_dim=-3, end_dim=-3) flatten(input=my_tensor) # tensor([[[7], [1], [-8]], [[3], [-6], [0]]])  flatten = nn.Flatten(start_dim=0, end_dim=1) flatten = nn.Flatten(start_dim=0, end_dim=-2) flatten = nn.Flatten(start_dim=-3, end_dim=1) flatten = nn.Flatten(start_dim=-3, end_dim=-2) flatten(input=my_tensor) # tensor([[7], [1], [-8], [3], [-6], [0]])  flatten = nn.Flatten() flatten = nn.Flatten(start_dim=1, end_dim=2) flatten = nn.Flatten(start_dim=1, end_dim=-1) flatten = nn.Flatten(start_dim=-2, end_dim=2) flatten = nn.Flatten(start_dim=-2, end_dim=-1) flatten(input=my_tensor) # tensor([[7, 1, -8], [3, -6, 0]])  my_tensor = torch.tensor([[[7.], [1.], [-8.]], [[3.], [-6.], [0.]]])  flatten = nn.Flatten() flatten(input=my_tensor) # tensor([[7., 1., -8.], [3., -6., 0.]])  my_tensor = torch.tensor([[[7.+0.j], [1.+0.j], [-8.+0.j]],                           [[3.+0.j], [-6.+0.j], [0.+0.j]]]) flatten = nn.Flatten() flatten(input=my_tensor) # tensor([[7.+0.j, 1.+0.j, -8.+0.j], #         [3.+0.j, -6.+0.j, 0.+0.j]])  my_tensor = torch.tensor([[[True], [False], [True]],                           [[False], [True], [False]]]) flatten = nn.Flatten() flatten(input=my_tensor) # tensor([[True, False, True], #         [False, True, False]]) 
上一篇: Windows 7 上如何快速安装最新版 PyTorch?
下一篇: Deepin 15.10 安装 OpenSSL 1.1.1d 后编译 Python 报错:如何解决 "libssl.so.1.1: version `OPENSSL_1_1_1' not found"?

作者:admin @ 24资源网   2025-01-14

本站所有软件、源码、文章均有网友提供,如有侵权联系308410122@qq.com

与本文相关文章

发表评论:

◎欢迎参与讨论,请在这里发表您的看法、交流您的观点。