多维数组转一维
其实本质上是权重累加,越靠外的维度,权重越大。二维数组转换时,感觉背个公式就得了。后面碰到三维,四维的时候,就有些焦头烂额了。
对于行优先(row-major)存储的二维数组 [rows][cols],坐标 (x, y) 对应的一维索引为:
其中 x 为行号,y 为列号。这个公式的含义是:先跳过 x 个完整的行(每行 cols 个元素),再取该行中的第 y 个元素。
反向映射同样直接:
1 2
| x = index / cols y = index % cols
|
整数除法得到行号,取余得到列号。
三维数组的推广
对于三维数组 [dim_x][dim_y][dim_z],坐标 (x, y, z) 的一维索引为:
1
| index = x * (dim_y * dim_z) + y * dim_z + z
|
拆解如下:
x * (dim_y * dim_z):跳过前面 x 个完整的二维平面,每个平面有 dim_y × dim_z 个元素
y * dim_z:在当前平面内,跳过 y 个完整的行,每行有 dim_z 个元素
+ z:在当前行内,取第 z 个元素
反向映射通过连续整除和取余完成:
1 2 3 4 5
| total_xy = dim_y * dim_z x = index / total_xy remainder = index % total_xy y = remainder / dim_z z = remainder % dim_z
|
通用规律
对于 N 维数组,各维长度分别为 d1, d2, ..., dn,坐标 (i1, i2, ..., in) 在行优先存储下的一维索引通式为:
1 2 3 4 5
| index = i1 × (d2×d3×...×dn) + i2 × (d3×d4×...×dn) + ... + i_{n-1} × (dn) + in
|
每一维度的权重是该维度之后所有维度的长度乘积。最内层维度的权重为 1,最外层维度的权重最大。
反向映射是从最外层开始,逐层做整数除法和取余:
1 2 3 4 5 6 7 8 9
| weight = d2 × d3 × ... × dn i1 = index / weight remainder = index % weight
weight = d3 × d4 × ... × dn i2 = remainder / weight remainder = remainder % weight
...
|
既然
既然有如此规律,不如写个类来encoder和decoder?
先了解一下std::mdspan。
std::mdspan
在C++23之前,操作多维数组是一件尴尬的事。
原生数组声明为int arr[3][4][5]后,一旦被传递,维度信息就会丢失。传递指针需要手动携带各维长度,索引计算要自己写x * 4 * 5 + y * 5 + z。使用vector<vector<vector<int>>>则导致内存不连续、访问开销高、缓存友好性差。
工业界存在大量多维数组处理需求:图像处理、科学计算、机器学习张量操作、物理模拟。这些场景需要一种类型:描述一块连续内存,将其解释为特定维度的多维数组,且不额外拷贝数据。这就是mdspan。
mdspan是多维跨度视图(Multidimensional Span)的缩写。它有三个核心特征:
- 非拥有:不分配内存,不管理生命周期,仅对现有连续内存提供视图
- 多维索引:支持
[i][j][k]或(i,j,k)风格的多维下标访问
- 布局策略:可指定行优先、列优先、自定义映射规则
mdspan的设计目标是与span形成镜像:span是一维连续序列的非拥有视图,mdspan是多维连续序列的非拥有视图。
设计类
受 mdspan 启发,先来只做坐标与索引的转换,不涉足内存管理、访问器、布局策略的全特化。专注做好一件事。
设计一个类,就两个功能,给定N维坐标,返回一维索引。给定一维索引,返回N维坐标。
首先先搞清楚维度的长度和步长这两个概念:多维数组索引转换的核心在于区分“维度长度”和“步长”:维度长度表示该维度有多少个元素,是静态的容量描述;步长表示该维度坐标每增加1时,一维索引需要增加的数值,是动态的间距描述。行优先布局下,最内层维度的步长为1,向外每一层的步长等于内层所有维度的长度乘积;列优先则相反,最外层步长为1,向内累乘。编码(坐标转索引)是将各维坐标乘以对应步长后求和,解码(索引转坐标)则是从外层开始连续除以步长并取余。步长才是内存寻址的真正“进位制”,维度长度仅用于计算步长和边界检查,不能直接用于解码。
由于我变参模板类写得太少掌握的不好,让小d老师帮我理解代码:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272
| #include <vector> #include <array> #include <concepts> #include <stdexcept> #include <cstddef>
template<std::unsigned_integral T = std::size_t> class row_major { private: std::vector<T> dims_; std::vector<T> strides_; std::size_t rank_; T total_size_;
public:
explicit row_major(std::initializer_list<T> dims) : dims_(dims), // 拷贝初始化列表到vector rank_(dims.size()), // 维度数 = 列表长度 total_size_(1) { strides_.resize(rank_); for (std::size_t i = rank_; i-- > 0; ) { strides_[i] = total_size_; total_size_ *= dims_[i]; } }
explicit row_major(const std::vector<T>& dims) : dims_(dims), // 拷贝整个vector rank_(dims.size()), total_size_(1) { strides_.resize(rank_); for (std::size_t i = rank_; i-- > 0; ) { strides_[i] = total_size_; total_size_ *= dims_[i]; } }
template<typename... Indices> requires (sizeof...(Indices) == rank_) T encode(Indices... indices) const { std::vector<T> idx = {static_cast<T>(indices)...}; T offset = 0; for (std::size_t i = 0; i < rank_; ++i) { offset += idx[i] * strides_[i]; } return offset; }
T encode(const std::vector<T>& indices) const { if (indices.size() != rank_) { throw std::out_of_range("row_major::encode: 维度不匹配"); } T offset = 0; for (std::size_t i = 0; i < rank_; ++i) { offset += indices[i] * strides_[i]; } return offset; }
template<std::size_t N> T encode(const std::array<T, N>& indices) const { static_assert(N > 0, "维度必须大于0"); T offset = 0; for (std::size_t i = 0; i < rank_; ++i) { offset += indices[i] * strides_[i]; } return offset; }
std::vector<T> decode(T index) const { if (index >= total_size_) { throw std::out_of_range("row_major::decode: 索引越界"); } std::vector<T> coords(rank_); T remaining = index; for (std::size_t i = 0; i < rank_; ++i) { coords[i] = remaining / strides_[i]; remaining %= strides_[i]; } return coords; }
void decode(T index, T* out_coords) const { if (index >= total_size_) { throw std::out_of_range("row_major::decode: 索引越界"); } T remaining = index; for (std::size_t i = 0; i < rank_; ++i) { out_coords[i] = remaining / strides_[i]; remaining %= strides_[i]; } } std::size_t rank() const noexcept { return rank_; } T size() const noexcept { return total_size_; } T dim(std::size_t i) const { return dims_[i]; } T stride(std::size_t i) const { return strides_[i]; } const std::vector<T>& dims() const noexcept { return dims_; } const std::vector<T>& strides() const noexcept { return strides_; } };
|