多维数组转一维

其实本质上是权重累加,越靠外的维度,权重越大。二维数组转换时,感觉背个公式就得了。后面碰到三维,四维的时候,就有些焦头烂额了。

对于行优先(row-major)存储的二维数组 [rows][cols],坐标 (x, y) 对应的一维索引为:

1
index = x * cols + y

其中 x 为行号,y 为列号。这个公式的含义是:先跳过 x 个完整的行(每行 cols 个元素),再取该行中的第 y 个元素。

反向映射同样直接:

1
2
x = index / cols
y = index % cols

整数除法得到行号,取余得到列号。

三维数组的推广

对于三维数组 [dim_x][dim_y][dim_z],坐标 (x, y, z) 的一维索引为:

1
index = x * (dim_y * dim_z) + y * dim_z + z

拆解如下:

  • x * (dim_y * dim_z):跳过前面 x 个完整的二维平面,每个平面有 dim_y × dim_z 个元素
  • y * dim_z:在当前平面内,跳过 y 个完整的行,每行有 dim_z 个元素
  • + z:在当前行内,取第 z 个元素

反向映射通过连续整除和取余完成:

1
2
3
4
5
total_xy = dim_y * dim_z
x = index / total_xy
remainder = index % total_xy
y = remainder / dim_z
z = remainder % dim_z

通用规律

对于 N 维数组,各维长度分别为 d1, d2, ..., dn,坐标 (i1, i2, ..., in) 在行优先存储下的一维索引通式为:

1
2
3
4
5
index = i1 × (d2×d3×...×dn)
+ i2 × (d3×d4×...×dn)
+ ...
+ i_{n-1} × (dn)
+ in

每一维度的权重是该维度之后所有维度的长度乘积。最内层维度的权重为 1,最外层维度的权重最大。

反向映射是从最外层开始,逐层做整数除法和取余:

1
2
3
4
5
6
7
8
9
weight = d2 × d3 × ... × dn
i1 = index / weight
remainder = index % weight

weight = d3 × d4 × ... × dn
i2 = remainder / weight
remainder = remainder % weight

...

既然

既然有如此规律,不如写个类来encoder和decoder?

先了解一下std::mdspan。

std::mdspan

在C++23之前,操作多维数组是一件尴尬的事。

原生数组声明为int arr[3][4][5]后,一旦被传递,维度信息就会丢失。传递指针需要手动携带各维长度,索引计算要自己写x * 4 * 5 + y * 5 + z。使用vector<vector<vector<int>>>则导致内存不连续、访问开销高、缓存友好性差。

工业界存在大量多维数组处理需求:图像处理、科学计算、机器学习张量操作、物理模拟。这些场景需要一种类型:描述一块连续内存,将其解释为特定维度的多维数组,且不额外拷贝数据。这就是mdspan

mdspan多维跨度视图(Multidimensional Span)的缩写。它有三个核心特征:

  1. 非拥有:不分配内存,不管理生命周期,仅对现有连续内存提供视图
  2. 多维索引:支持[i][j][k](i,j,k)风格的多维下标访问
  3. 布局策略:可指定行优先、列优先、自定义映射规则

mdspan的设计目标是与span形成镜像:span是一维连续序列的非拥有视图,mdspan是多维连续序列的非拥有视图。

设计类

mdspan 启发,先来只做坐标与索引的转换,不涉足内存管理、访问器、布局策略的全特化。专注做好一件事。

设计一个类,就两个功能,给定N维坐标,返回一维索引。给定一维索引,返回N维坐标。

首先先搞清楚维度的长度和步长这两个概念:多维数组索引转换的核心在于区分“维度长度”和“步长”:维度长度表示该维度有多少个元素,是静态的容量描述;步长表示该维度坐标每增加1时,一维索引需要增加的数值,是动态的间距描述。行优先布局下,最内层维度的步长为1,向外每一层的步长等于内层所有维度的长度乘积;列优先则相反,最外层步长为1,向内累乘。编码(坐标转索引)是将各维坐标乘以对应步长后求和,解码(索引转坐标)则是从外层开始连续除以步长并取余。步长才是内存寻址的真正“进位制”,维度长度仅用于计算步长和边界检查,不能直接用于解码。

由于我变参模板类写得太少掌握的不好,让小d老师帮我理解代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
#include <vector>      // std::vector 动态数组
#include <array> // std::array 静态数组(编译期大小)
#include <concepts> // std::unsigned_integral 概念约束
#include <stdexcept> // std::out_of_range 异常
#include <cstddef> // std::size_t

/**
* @brief 行优先多维索引转换器(运行时维度版本)
*
* 将N维坐标与一维线性索引互相转换。
* 布局规则:最内层维度在内存中连续(C风格行优先)。
*
* 使用示例:
* row_major<> rm({3, 4, 5}); // 三维数组 [3][4][5]
* size_t idx = rm.encode(1, 2, 3); // 坐标→索引: 33
* auto coords = rm.decode(33); // 索引→坐标: {1,2,3}
*
* @tparam T 无符号整数类型,默认为 size_t
*/
template<std::unsigned_integral T = std::size_t>
class row_major {
private:
// -----------------------------------------------------------------
// 核心数据成员
// -----------------------------------------------------------------

std::vector<T> dims_; // 各维长度数组,dims_[0]是最外层
// 例: [3,4,5] 表示3层、4行、5列

std::vector<T> strides_; // 各维步长数组,strides_[0]是最外层
// 例: [20,5,1] 层步长20,行步长5,列步长1
// 含义:该维坐标+1,内存偏移增加这么多元素

std::size_t rank_; // 维度数,即 dims_.size()

T total_size_; // 总元素数,各维长度乘积

public:
// -----------------------------------------------------------------
// 构造函数1:从初始化列表构造
// 调用形式:row_major rm({3, 4, 5});
// -----------------------------------------------------------------

/**
* @brief 从初始化列表构造维度信息
* @param dims 各维长度,如 {3,4,5}
*/
explicit row_major(std::initializer_list<T> dims)
: dims_(dims), // 拷贝初始化列表到vector
rank_(dims.size()), // 维度数 = 列表长度
total_size_(1) // 初始化为1,后面累乘
{
// 预分配步长数组空间
strides_.resize(rank_);

// 行优先步长计算:从最内层(下标最大)向外层(下标0)计算
// 逆向遍历技巧:for (size_t i = rank_; i-- > 0; )
// 优点:无符号类型不会出现 i>=0 死循环,且 rank_=0 时安全
for (std::size_t i = rank_; i-- > 0; ) {
// 当前维度的步长 = 内层所有维度的总元素数
strides_[i] = total_size_;

// 总元素数累乘当前维度长度,为外层计算做准备
total_size_ *= dims_[i];

// 以 {3,4,5} 为例:
// 初始 total_size_ = 1
// i=2: strides_[2] = 1, total_size_ = 1*5 = 5
// i=1: strides_[1] = 5, total_size_ = 5*4 = 20
// i=0: strides_[0] = 20, total_size_ = 20*3 = 60
}
}

// -----------------------------------------------------------------
// 构造函数2:从 vector 构造
// 调用形式:vector<T> d{3,4,5}; row_major rm(d);
// -----------------------------------------------------------------

/**
* @brief 从 vector 构造维度信息
* @param dims 各维长度 vector
*/
explicit row_major(const std::vector<T>& dims)
: dims_(dims), // 拷贝整个vector
rank_(dims.size()),
total_size_(1)
{
strides_.resize(rank_);

// 与初始化列表版本完全相同的步长计算逻辑
for (std::size_t i = rank_; i-- > 0; ) {
strides_[i] = total_size_;
total_size_ *= dims_[i];
}
}

// -----------------------------------------------------------------
// 编码函数1:变参模板版本(最自然的调用形式)
// 调用形式:rm.encode(1, 2, 3)
// -----------------------------------------------------------------

/**
* @brief 将多维坐标转换为一维索引(变参版本)
* @tparam Indices 各维坐标的类型包
* @param indices 各维坐标值包(个数必须等于维度数)
* @return 一维线性索引
*
* 实现原理:offset = Σ(坐标[i] × 步长[i])
*/
template<typename... Indices>
requires (sizeof...(Indices) == rank_)
T encode(Indices... indices) const {
// 将参数包展开为初始化列表,构造 vector
// 注意:这里不能用 std::array,因为 rank_ 是运行时值
std::vector<T> idx = {static_cast<T>(indices)...};
// 包展开过程:
// indices = 1, 2, 3
// {static_cast<T>(indices)...}
// → {static_cast<T>(1), static_cast<T>(2), static_cast<T>(3)}
// → {1, 2, 3}

T offset = 0;
for (std::size_t i = 0; i < rank_; ++i) {
offset += idx[i] * strides_[i];
}
return offset;
}

// -----------------------------------------------------------------
// 编码函数2:vector 参数版本
// 调用形式:rm.encode({1, 2, 3}) 或 rm.encode(vec)
// -----------------------------------------------------------------

/**
* @brief 将多维坐标转换为一维索引(vector版本)
* @param indices 各维坐标 vector
* @return 一维线性索引
* @throw std::out_of_range 坐标数量与维度数不匹配
*/
T encode(const std::vector<T>& indices) const {
// 运行时检查维度匹配
if (indices.size() != rank_) {
throw std::out_of_range("row_major::encode: 维度不匹配");
}

T offset = 0;
for (std::size_t i = 0; i < rank_; ++i) {
offset += indices[i] * strides_[i];
}
return offset;
}

// -----------------------------------------------------------------
// 编码函数3:数组参数版本(仅当维度数编译期已知时可用)
// 调用形式:std::array<T,3> arr{1,2,3}; rm.encode(arr);
// 注意:此版本要求调用者确保 N == rank_,否则行为未定义
// -----------------------------------------------------------------

/**
* @brief 将多维坐标转换为一维索引(array版本)
* @tparam N 数组大小(编译期常量)
* @param indices 各维坐标 array
* @return 一维线性索引
*
* 注意:此函数不检查 N 与 rank_ 是否相等!
* 调用者必须保证 N == rank_,否则越界访问
*/
template<std::size_t N>
T encode(const std::array<T, N>& indices) const {
static_assert(N > 0, "维度必须大于0");

T offset = 0;
// 只取前 rank_ 个元素,假设调用者保证 N >= rank_
for (std::size_t i = 0; i < rank_; ++i) {
offset += indices[i] * strides_[i];
}
return offset;
}

// -----------------------------------------------------------------
// 解码函数1:返回 vector
// 调用形式:auto coords = rm.decode(33);
// -----------------------------------------------------------------

/**
* @brief 将一维索引转换为多维坐标
* @param index 一维线性索引
* @return 各维坐标 vector
* @throw std::out_of_range 索引超出总元素数
*
* 实现原理:连续除以步长并取余
* 以 [3][4][5] strides=[20,5,1] index=33 为例:
* i=0: coords[0]=33/20=1, remaining=33%20=13
* i=1: coords[1]=13/5=2, remaining=13%5=3
* i=2: coords[2]=3/1=3, remaining=3%1=0
* 结果: {1,2,3}
*/
std::vector<T> decode(T index) const {
if (index >= total_size_) {
throw std::out_of_range("row_major::decode: 索引越界");
}

std::vector<T> coords(rank_); // 分配并零初始化
T remaining = index;

for (std::size_t i = 0; i < rank_; ++i) {
coords[i] = remaining / strides_[i];
remaining %= strides_[i];
}

return coords; // 返回值优化(RVO),不拷贝
}

// -----------------------------------------------------------------
// 解码函数2:写入预先分配的缓冲区(C风格接口)
// 调用形式:T buf[3]; rm.decode(33, buf);
// -----------------------------------------------------------------

/**
* @brief 将一维索引转换为多维坐标,写入输出缓冲区
* @param index 一维线性索引
* @param out_coords 输出缓冲区指针,必须指向至少 rank_ 个元素的空间
* @throw std::out_of_range 索引超出总元素数
*
* 此接口用于性能敏感场景,避免 vector 分配开销
*/
void decode(T index, T* out_coords) const {
if (index >= total_size_) {
throw std::out_of_range("row_major::decode: 索引越界");
}

T remaining = index;
for (std::size_t i = 0; i < rank_; ++i) {
out_coords[i] = remaining / strides_[i];
remaining %= strides_[i];
}
}

// -----------------------------------------------------------------
// 查询接口(getters)
// -----------------------------------------------------------------

/// @brief 返回维度数
std::size_t rank() const noexcept {
return rank_;
}

/// @brief 返回总元素数
T size() const noexcept {
return total_size_;
}

/// @brief 返回第 i 维的长度
T dim(std::size_t i) const {
return dims_[i]; // 调用者保证 i < rank_
}

/// @brief 返回第 i 维的步长
T stride(std::size_t i) const {
return strides_[i]; // 调用者保证 i < rank_
}

/// @brief 返回所有维度的长度数组(只读)
const std::vector<T>& dims() const noexcept {
return dims_;
}

/// @brief 返回所有维度的步长数组(只读)
const std::vector<T>& strides() const noexcept {
return strides_;
}
};