arithmetic.gno
10.12 Kb · 450 lines
// arithmetic provides arithmetic operations for Uint objects.
// This includes the basic binary operations (addition, subtraction, multiplication,
// division, and modulo) as well as overflow checks and negation. These functions are
// essential for numeric calculations using 256-bit unsigned integers.
5package uint256
6
7import "math/bits"
8
9// Add sets z to the sum x+y and returns z.
10func (z *Uint) Add(x, y *Uint) *Uint {
11 var carry uint64
12 z[0], carry = bits.Add64(x[0], y[0], 0)
13 z[1], carry = bits.Add64(x[1], y[1], carry)
14 z[2], carry = bits.Add64(x[2], y[2], carry)
15 z[3], _ = bits.Add64(x[3], y[3], carry)
16 return z
17}
18
19// AddOverflow sets z to the sum x+y and returns z and true if overflow occurred.
20func (z *Uint) AddOverflow(x, y *Uint) (*Uint, bool) {
21 var carry uint64
22 z[0], carry = bits.Add64(x[0], y[0], 0)
23 z[1], carry = bits.Add64(x[1], y[1], carry)
24 z[2], carry = bits.Add64(x[2], y[2], carry)
25 z[3], carry = bits.Add64(x[3], y[3], carry)
26 return z, carry != 0
27}
28
29// Sub sets z to the difference x-y and returns z.
30func (z *Uint) Sub(x, y *Uint) *Uint {
31 var carry uint64
32 z[0], carry = bits.Sub64(x[0], y[0], 0)
33 z[1], carry = bits.Sub64(x[1], y[1], carry)
34 z[2], carry = bits.Sub64(x[2], y[2], carry)
35 z[3], _ = bits.Sub64(x[3], y[3], carry)
36 return z
37}
38
39// SubOverflow sets z to the difference x-y and returns z and true if underflow occurred.
40func (z *Uint) SubOverflow(x, y *Uint) (*Uint, bool) {
41 var carry uint64
42 z[0], carry = bits.Sub64(x[0], y[0], 0)
43 z[1], carry = bits.Sub64(x[1], y[1], carry)
44 z[2], carry = bits.Sub64(x[2], y[2], carry)
45 z[3], carry = bits.Sub64(x[3], y[3], carry)
46 return z, carry != 0
47}
48
49// Neg returns -x mod 2^256.
50func (z *Uint) Neg(x *Uint) *Uint {
51 return z.Sub(Zero(), x)
52}
53
54// Mul sets z to the product x*y and returns z.
55func (z *Uint) Mul(x, y *Uint) *Uint {
56 var (
57 res Uint
58 carry uint64
59 res1, res2, res3 uint64
60 )
61
62 carry, res[0] = bits.Mul64(x[0], y[0])
63 carry, res1 = umulHop(carry, x[1], y[0])
64 carry, res2 = umulHop(carry, x[2], y[0])
65 res3 = x[3]*y[0] + carry
66
67 carry, res[1] = umulHop(res1, x[0], y[1])
68 carry, res2 = umulStep(res2, x[1], y[1], carry)
69 res3 = res3 + x[2]*y[1] + carry
70
71 carry, res[2] = umulHop(res2, x[0], y[2])
72 res3 = res3 + x[1]*y[2] + carry
73
74 res[3] = res3 + x[0]*y[3]
75
76 return z.Set(&res)
77}
78
79// MulOverflow sets z to the product x*y and returns z and true if overflow occurred.
80func (z *Uint) MulOverflow(x, y *Uint) (*Uint, bool) {
81 p := umul(x, y)
82 copy(z[:], p[:4])
83 return z, (p[4] | p[5] | p[6] | p[7]) != 0
84}
85
86// Div sets z to the quotient x/y and returns z.
87// It panics if y == 0.
88func (z *Uint) Div(x, y *Uint) *Uint {
89 if y.IsZero() {
90 panic("division by zero")
91 }
92 if y.Gt(x) {
93 return z.Clear()
94 }
95 if x.Eq(y) {
96 return z.SetOne()
97 }
98 // Shortcut some cases
99 if x.IsUint64() {
100 return z.SetUint64(x.Uint64() / y.Uint64())
101 }
102
103 // At this point, we know
104 // x/y ; x > y > 0
105
106 var quot Uint
107 udivrem(quot[:], x[:], y)
108 return z.Set(")
109}
110
111// Mod sets z to the modulus x%y and returns z.
112// It panics if y == 0.
113func (z *Uint) Mod(x, y *Uint) *Uint {
114 if y.IsZero() {
115 panic("modulo by zero")
116 }
117 if x.IsZero() {
118 return z.Clear()
119 }
120 switch x.Cmp(y) {
121 case -1:
122 // x < y
123 copy(z[:], x[:])
124 return z
125 case 0:
126 // x == y
127 return z.Clear() // They are equal
128 }
129
130 // At this point:
131 // x != 0
132 // y != 0
133 // x > y
134
135 // Shortcut trivial case
136 if x.IsUint64() {
137 return z.SetUint64(x.Uint64() % y.Uint64())
138 }
139
140 var quot Uint
141 *z = udivrem(quot[:], x[:], y)
142 return z
143}
144
145// MulMod sets z to (x * y) mod m and returns z.
146// It panics if m == 0.
147func (z *Uint) MulMod(x, y, m *Uint) *Uint {
148 if m.IsZero() {
149 panic("modulo by zero")
150 }
151 if x.IsZero() || y.IsZero() {
152 return z.Clear()
153 }
154 p := umul(x, y)
155
156 if m[3] != 0 {
157 mu := Reciprocal(m)
158 r := reduce4(p, m, mu)
159 return z.Set(&r)
160 }
161
162 var (
163 pl Uint
164 ph Uint
165 )
166
167 pl[0], pl[1], pl[2], pl[3] = p[0], p[1], p[2], p[3]
168 ph[0], ph[1], ph[2], ph[3] = p[4], p[5], p[6], p[7]
169
170 // If the multiplication is within 256 bits use Mod().
171 if ph.IsZero() {
172 return z.Mod(&pl, m)
173 }
174
175 var quot [8]uint64
176 rem := udivrem(quot[:], p[:], m)
177 return z.Set(&rem)
178}
179
180// DivMod sets z to the quotient x/y and m to the modulus x%y, returning the pair (z, m).
181// It panics if y == 0.
182func (z *Uint) DivMod(x, y, m *Uint) (*Uint, *Uint) {
183 if y.IsZero() {
184 panic("division by zero")
185 }
186
187 switch x.Cmp(y) {
188 case -1:
189 // x < y
190 return z.Clear(), m.Set(x)
191 case 0:
192 // x == y
193 return z.SetOne(), m.Clear()
194 }
195
196 // At this point:
197 // x != 0
198 // y != 0
199 // x > y
200
201 // Shortcut trivial case
202 if x.IsUint64() {
203 x0, y0 := x.Uint64(), y.Uint64()
204 return z.SetUint64(x0 / y0), m.SetUint64(x0 % y0)
205 }
206
207 var quot Uint
208 *m = udivrem(quot[:], x[:], y)
209 *z = quot
210 return z, m
211}
212
// udivrem divides the multi-word number u by d and returns the remainder.
// Words are little-endian: u[0] and d[0] are the least significant words.
//
// The quotient words are written into quot, which needs at least
// len(u)-dLen+1 words, where dLen is the number of significant words of d.
// Words of quot above the highest computed quotient digit are not written;
// callers in this file pass zero-initialized storage. u itself is not
// modified: a shifted copy is built in local storage.
//
// d must be non-zero. With d == 0, dLen stays 0 and d[dLen-1] panics with an
// index out of range. Callers in this file check for zero before calling.
//
// It loosely follows Knuth's long ("schoolbook") division using 64-bit words.
// See Knuth, Volume 2, section 4.3.1, Algorithm D.
func udivrem(quot, u []uint64, d *Uint) (rem Uint) {
	// Number of significant words in the divisor.
	var dLen int
	for i := len(d) - 1; i >= 0; i-- {
		if d[i] != 0 {
			dLen = i + 1
			break
		}
	}

	// Normalize: shift d left so the top bit of its top word is set, as
	// reciprocal2by1 (via bits.Div64) requires. When shift == 0, the
	// expression x >> (64 - shift) is x >> 64, which Go defines as 0 for
	// unsigned x, so no special case is needed.
	shift := uint(bits.LeadingZeros64(d[dLen-1]))

	var dnStorage Uint
	dn := dnStorage[:dLen]
	for i := dLen - 1; i > 0; i-- {
		dn[i] = (d[i] << shift) | (d[i-1] >> (64 - shift))
	}
	dn[0] = d[0] << shift

	// Number of significant words in the dividend.
	var uLen int
	for i := len(u) - 1; i >= 0; i-- {
		if u[i] != 0 {
			uLen = i + 1
			break
		}
	}

	// Dividend shorter than divisor: the quotient stays 0 and u is the remainder.
	if uLen < dLen {
		copy(rem[:], u)
		return rem
	}

	// Shift u by the same amount into un, which has one extra top word to
	// receive the bits shifted out. Nine words cover an 8-word u (MulMod
	// passes the full 512-bit product) plus that extra word.
	var unStorage [9]uint64
	un := unStorage[:uLen+1]
	un[uLen] = u[uLen-1] >> (64 - shift)
	for i := uLen - 1; i > 0; i-- {
		un[i] = (u[i] << shift) | (u[i-1] >> (64 - shift))
	}
	un[0] = u[0] << shift

	// Single-word divisor: use the 1-word routine, then undo the
	// normalization shift on the remainder.
	if dLen == 1 {
		r := udivremBy1(quot, un, dn[0])
		rem.SetUint64(r >> shift)
		return rem
	}

	// Multi-word divisor: udivremKnuth leaves the normalized remainder in the
	// low dLen words of un.
	udivremKnuth(quot, un, dn)

	// Denormalize: shift the remainder right by shift bits.
	for i := 0; i < dLen-1; i++ {
		rem[i] = (un[i] >> shift) | (un[i+1] << (64 - shift))
	}
	rem[dLen-1] = un[dLen-1] >> shift

	return rem
}
271
272// umul computes full 256 x 256 -> 512 multiplication.
273func umul(x, y *Uint) [8]uint64 {
274 var res [8]uint64
275
276 topX := highestNonZeroWord(x)
277 topY := highestNonZeroWord(y)
278
279 if topX < 0 || topY < 0 {
280 return res
281 }
282
283 lenX := topX + 1
284 lenY := topY + 1
285
286 for i := 0; i < lenX; i++ {
287 xi := x[i]
288 if xi == 0 {
289 continue
290 }
291 var carry uint64
292 k := i
293 for j := 0; j < lenY; j++ {
294 hi, lo := bits.Mul64(xi, y[j])
295 lo, c := bits.Add64(lo, res[k], 0)
296 hi += c
297 lo, c = bits.Add64(lo, carry, 0)
298 hi += c
299 res[k] = lo
300 carry = hi
301 k++
302 }
303 res[i+lenY] = carry
304 }
305
306 return res
307}
308
309// highestNonZeroWord returns the highest index with non-zero value or -1 if the Uint is zero.
310func highestNonZeroWord(u *Uint) int {
311 for i := 3; i >= 0; i-- {
312 if u[i] != 0 {
313 return i
314 }
315 }
316 return -1
317}
318
// umulStep returns the 128-bit value z + x*y + carry as (hi, lo), where the
// result equals hi*2^64 + lo. The sum never exceeds 2^128-1, so nothing is lost.
func umulStep(z, x, y, carry uint64) (hi, lo uint64) {
	prodHi, prodLo := bits.Mul64(x, y)
	var c1, c2 uint64
	prodLo, c1 = bits.Add64(prodLo, carry, 0)
	prodLo, c2 = bits.Add64(prodLo, z, 0)
	return prodHi + c1 + c2, prodLo
}
328
// umulHop returns the 128-bit value z + x*y as (hi, lo), where the result
// equals hi*2^64 + lo.
func umulHop(z, x, y uint64) (hi, lo uint64) {
	p1, p0 := bits.Mul64(x, y)
	sum, c := bits.Add64(p0, z, 0)
	return p1 + c, sum
}
336
337// udivremBy1 divides u by single normalized word d and produces both quotient and remainder.
338// The quotient is stored in provided quot.
339func udivremBy1(quot, u []uint64, d uint64) (rem uint64) {
340 reciprocal := reciprocal2by1(d)
341 rem = u[len(u)-1] // Set the top word as remainder.
342 for j := len(u) - 2; j >= 0; j-- {
343 quot[j], rem = udivrem2by1(rem, u[j], d, reciprocal)
344 }
345 return rem
346}
347
// udivremKnuth divides u by the multi-word divisor d using Knuth's Algorithm D
// (Volume 2, section 4.3.1). Words are little-endian.
//
// Preconditions relied on by the code below:
//   - len(d) >= 2, since d[len(d)-2] is read;
//   - d is normalized (top bit of its top word set); reciprocal2by1 needs this;
//   - u has a spare top word, so u[j+len(d)] exists for every quotient digit j.
//
// Quotient digits are written to quot[0 : len(u)-len(d)]. u is updated in
// place; on return its low len(d) words hold the (still normalized) remainder.
func udivremKnuth(quot, u, d []uint64) {
	dLen := len(d)
	dh := d[dLen-1]
	dl := d[dLen-2]
	reciprocal := reciprocal2by1(dh)

	// Produce one quotient digit per iteration, from most to least significant.
	for j := len(u) - dLen - 1; j >= 0; j-- {
		// Top three words of the current partial remainder u[j : j+dLen+1].
		u2 := u[j+dLen]
		u1 := u[j+dLen-1]
		u0 := u[j+dLen-2]

		var qhat, rhat uint64
		if u2 >= dh { // Division overflows.
			// <u2, u1> / dh would not fit in 64 bits; use the largest digit.
			qhat = MAX_UINT64
			// NOTE: a "qhat one too big" adjustment could be added here (not needed
			// for correctness, but it would help avoid the "add back" case).
		} else {
			// Estimate qhat from the top two words, then refine it with the
			// second divisor word dl (Knuth step D3).
			qhat, rhat = udivrem2by1(u2, u1, dh, reciprocal)
			ph, pl := bits.Mul64(qhat, dl)
			if ph > rhat || (ph == rhat && pl > u0) {
				qhat--
				// NOTE: a further "qhat one too big" check could be added here (not
				// needed for correctness, but it would help avoid the "add back" case).
			}
		}

		// Multiply and subtract (step D4): u[j : j+dLen+1] -= qhat * d.
		borrow := subMulTo(u[j:], d, qhat)
		u[j+dLen] = u2 - borrow
		if u2 < borrow { // Too much subtracted, add back.
			// qhat was too large (step D6): add one multiple of d back.
			qhat--
			u[j+dLen] += addTo(u[j:], d)
		}

		quot[j] = qhat // Store quotient digit.
	}
}
386
387// isBitSet returns true if bit n-th is set, where n = 0 is LSB.
388// The n must be <= 255.
389func (z *Uint) isBitSet(n uint) bool {
390 return (z[n/64] & (1 << (n % 64))) != 0
391}
392
393func (z *Uint) IsOverflow() bool {
394 return z.isBitSet(255)
395}
396
// addTo adds y into x in place (x += y over the first len(y) words) and
// returns the carry out of the last added word. Requires len(x) >= len(y).
// Words of x beyond len(y) are left untouched.
func addTo(x, y []uint64) uint64 {
	var c uint64
	for i, v := range y {
		x[i], c = bits.Add64(x[i], v, c)
	}
	return c
}
406
// subMulTo subtracts y*multiplier from x in place over the first len(y) words
// and returns the amount still owed to the next-higher word (borrow plus the
// high half of the last partial product). Requires len(x) >= len(y).
func subMulTo(x, y []uint64, multiplier uint64) uint64 {
	var borrow uint64
	for i, yi := range y {
		prodHi, prodLo := bits.Mul64(yi, multiplier)
		diff, b1 := bits.Sub64(x[i], borrow, 0)
		diff, b2 := bits.Sub64(diff, prodLo, 0)
		x[i] = diff
		borrow = prodHi + b1 + b2
	}
	return borrow
}
420
421// reciprocal2by1 computes <^d, ^0> / d.
422func reciprocal2by1(d uint64) uint64 {
423 reciprocal, _ := bits.Div64(^d, MAX_UINT64, d)
424 return reciprocal
425}
426
// udivrem2by1 divides the 128-bit value <uh, ul> (uh is the high word) by d and
// returns the quotient and remainder. It uses d's precomputed reciprocal from
// reciprocal2by1, so d is expected to be normalized (top bit set).
//
// The quotient must fit in 64 bits, which requires uh < d. udivremKnuth checks
// u2 < dh before calling; udivremBy1 passes its running remainder.
//
// Implementation ported from https://github.com/chfast/intx and is based on
// "Improved division by invariant integers", Algorithm 4.
func udivrem2by1(uh, ul, d, reciprocal uint64) (quot, rem uint64) {
	// <qh, ql> = reciprocal*uh + <uh, ul>, in 128-bit arithmetic (the final
	// carry out of qh is discarded).
	qh, ql := bits.Mul64(reciprocal, uh)
	ql, carry := bits.Add64(ql, ul, 0)
	qh, _ = bits.Add64(qh, uh, carry)
	// Candidate quotient digit.
	qh++

	// Candidate remainder, computed modulo 2^64.
	r := ul - qh*d

	// First correction: step the candidate down by one when r > ql.
	if r > ql {
		qh--
		r += d
	}

	// Second correction: the remainder must end up below d.
	if r >= d {
		qh++
		r -= d
	}

	return qh, r
}