本帖最后由 zjh106 于 2023-11-03 22:27 编辑
一句话介绍功能等同于函数(x) => 1 / Math.sqrt(x) 。源码传送门: float Q_rsqrt( float number ) { long i; float x2, y; const float threehalfs = 1.5F;
x2 = number * 0.5F; y = number; i = * ( long * ) &y; // evil floating point bit level hacking i = 0x5f3759df - ( i >> 1 ); // what the fuck? y = * ( float * ) &i; y = y * ( threehalfs - ( x2 * y * y ) ); // 1st iteration // y = y * ( threehalfs - ( x2 * y * y ) ); // 2nd iteration, this can be removed
#ifndef Q3_VM #ifdef __linux__ assert( !isnan(y) ); // bk010122 - FPE? #endif #endif return y; }
前置知识牛顿迭代:本文用到的公式:y = 1 / sqrt(x) -> 1 / y^2 - x = 0 代入得结果。 IEEE-754浮点数表示:一句话总结:正数的符号位肯定为0,x = (1 + M / 2^23) * 2^(E - 127) 。 推导过程上述代码只有这几行 i = * ( long * ) &y; i = 0x5f3759df - ( i >> 1 ); y = * ( float * ) &i;
是不那么显然的。 考虑x ^ -0.5 的对数:-0.5 * (log2(1 + M / 2^23) + (E - 127)) ,log2(1 + M / 2^23) 必定属于[0, 1) ,所以可以近似为M / 2^23 。于是log2(x) ≈ (M + 2^23 * E - 2^23 * 127) / 2^23 ,其中M + 2^23 * E 就是将浮点数解释为int 后的值,记为b ,log2(x) ≈ b / 2^23 - 127 。设结果A = log2(x ^ -0.5) 的这一部分为a ,则a / 2^23 - 127 = -0.5 * (b / 2^23 - 127) ,解得a = 381 * 2^22 - (b >> 1) 。 至此我们已经能够理解整个算法。但上述代码的magic number并不是381 * 2^22 。别着急,我们马上讨论。 误差分析从直觉来看需要将y = x 稍微上移才能得到最优近似,所以我们设v = err(x) = log2(x + 1) - x, x ∈ [0, 1) 。根据参考链接2,计算v 常用的方式有: (max(err(v)) + min(err(v))) / 2 。err'(x) = 1 / (log(2) * (1 + x)) - 1 ,曲线先升后降,所以x = 1 / log(2) - 1 得最大值,结果为0.0430357 。- 取
err(x) 在[0,1] 上积分的平均。在wolframalpha中算积分Integrate[log2(1+x)-x,x] 得let g = x => (2*(x+1)*Math.log(x+1)-x*(x*Math.log(2)+2))/Math.log(4) ,代入得0.057304959 。
然而,这些都只是我个人的猜测。已知最终作者使用的magic number是0x5f3759df ,我们来算下最终作者选用的v 值。log2(1 + M / 2^23) 的最优近似为M / 2^23 + v ,于是log2(x) ≈ b / 2^23 - 127 - v => a / 2^23 - 127 - v = -0.5 * (b / 2^23 - 127 - v) ,得v = 0.0450465679 。 课后作业实现double 的版本。 #include <bits/stdc++.h> // Copyright 2023 hans7
double q_rsqrt(double v) { int64_t tmp = *reinterpret_cast<int64_t*>(&v); tmp = 6910773628200026112LL - (tmp >> 1); double res = *reinterpret_cast<double*>(&tmp); res = res * (1.5 - v * 0.5 * res * res); return res; }
int main(int, char**) { srand(time(nullptr)); int v = rand(); printf("%d %.10lf %.10lf\n", v, q_rsqrt(v), abs(q_rsqrt(v) - 1 / sqrt(v))); return 0; }
注:若转载请注明大神论坛来源(本贴地址)与作者信息。
|