Search Apps Documentation Source Content File Folder Download Copy Actions Download

arithmetic.gno

10.12 Kb · 450 lines
  1// arithmetic provides arithmetic operations for Uint objects.
  2// This includes basic binary operations such as addition, subtraction, multiplication, division, and modulo operations
  3// as well as overflow checks, and negation. These functions are essential for numeric
  4// calculations using 256-bit unsigned integers.
  5package uint256
  6
  7import "math/bits"
  8
  9// Add sets z to the sum x+y and returns z.
 10func (z *Uint) Add(x, y *Uint) *Uint {
 11	var carry uint64
 12	z[0], carry = bits.Add64(x[0], y[0], 0)
 13	z[1], carry = bits.Add64(x[1], y[1], carry)
 14	z[2], carry = bits.Add64(x[2], y[2], carry)
 15	z[3], _ = bits.Add64(x[3], y[3], carry)
 16	return z
 17}
 18
 19// AddOverflow sets z to the sum x+y and returns z and true if overflow occurred.
 20func (z *Uint) AddOverflow(x, y *Uint) (*Uint, bool) {
 21	var carry uint64
 22	z[0], carry = bits.Add64(x[0], y[0], 0)
 23	z[1], carry = bits.Add64(x[1], y[1], carry)
 24	z[2], carry = bits.Add64(x[2], y[2], carry)
 25	z[3], carry = bits.Add64(x[3], y[3], carry)
 26	return z, carry != 0
 27}
 28
 29// Sub sets z to the difference x-y and returns z.
 30func (z *Uint) Sub(x, y *Uint) *Uint {
 31	var carry uint64
 32	z[0], carry = bits.Sub64(x[0], y[0], 0)
 33	z[1], carry = bits.Sub64(x[1], y[1], carry)
 34	z[2], carry = bits.Sub64(x[2], y[2], carry)
 35	z[3], _ = bits.Sub64(x[3], y[3], carry)
 36	return z
 37}
 38
 39// SubOverflow sets z to the difference x-y and returns z and true if underflow occurred.
 40func (z *Uint) SubOverflow(x, y *Uint) (*Uint, bool) {
 41	var carry uint64
 42	z[0], carry = bits.Sub64(x[0], y[0], 0)
 43	z[1], carry = bits.Sub64(x[1], y[1], carry)
 44	z[2], carry = bits.Sub64(x[2], y[2], carry)
 45	z[3], carry = bits.Sub64(x[3], y[3], carry)
 46	return z, carry != 0
 47}
 48
 49// Neg returns -x mod 2^256.
 50func (z *Uint) Neg(x *Uint) *Uint {
 51	return z.Sub(Zero(), x)
 52}
 53
 54// Mul sets z to the product x*y and returns z.
 55func (z *Uint) Mul(x, y *Uint) *Uint {
 56	var (
 57		res              Uint
 58		carry            uint64
 59		res1, res2, res3 uint64
 60	)
 61
 62	carry, res[0] = bits.Mul64(x[0], y[0])
 63	carry, res1 = umulHop(carry, x[1], y[0])
 64	carry, res2 = umulHop(carry, x[2], y[0])
 65	res3 = x[3]*y[0] + carry
 66
 67	carry, res[1] = umulHop(res1, x[0], y[1])
 68	carry, res2 = umulStep(res2, x[1], y[1], carry)
 69	res3 = res3 + x[2]*y[1] + carry
 70
 71	carry, res[2] = umulHop(res2, x[0], y[2])
 72	res3 = res3 + x[1]*y[2] + carry
 73
 74	res[3] = res3 + x[0]*y[3]
 75
 76	return z.Set(&res)
 77}
 78
 79// MulOverflow sets z to the product x*y and returns z and true if overflow occurred.
 80func (z *Uint) MulOverflow(x, y *Uint) (*Uint, bool) {
 81	p := umul(x, y)
 82	copy(z[:], p[:4])
 83	return z, (p[4] | p[5] | p[6] | p[7]) != 0
 84}
 85
 86// Div sets z to the quotient x/y and returns z.
 87// It panics if y == 0.
 88func (z *Uint) Div(x, y *Uint) *Uint {
 89	if y.IsZero() {
 90		panic("division by zero")
 91	}
 92	if y.Gt(x) {
 93		return z.Clear()
 94	}
 95	if x.Eq(y) {
 96		return z.SetOne()
 97	}
 98	// Shortcut some cases
 99	if x.IsUint64() {
100		return z.SetUint64(x.Uint64() / y.Uint64())
101	}
102
103	// At this point, we know
104	// x/y ; x > y > 0
105
106	var quot Uint
107	udivrem(quot[:], x[:], y)
108	return z.Set(&quot)
109}
110
111// Mod sets z to the modulus x%y and returns z.
112// It panics if y == 0.
113func (z *Uint) Mod(x, y *Uint) *Uint {
114	if y.IsZero() {
115		panic("modulo by zero")
116	}
117	if x.IsZero() {
118		return z.Clear()
119	}
120	switch x.Cmp(y) {
121	case -1:
122		// x < y
123		copy(z[:], x[:])
124		return z
125	case 0:
126		// x == y
127		return z.Clear() // They are equal
128	}
129
130	// At this point:
131	// x != 0
132	// y != 0
133	// x > y
134
135	// Shortcut trivial case
136	if x.IsUint64() {
137		return z.SetUint64(x.Uint64() % y.Uint64())
138	}
139
140	var quot Uint
141	*z = udivrem(quot[:], x[:], y)
142	return z
143}
144
145// MulMod sets z to (x * y) mod m and returns z.
146// It panics if m == 0.
147func (z *Uint) MulMod(x, y, m *Uint) *Uint {
148	if m.IsZero() {
149		panic("modulo by zero")
150	}
151	if x.IsZero() || y.IsZero() {
152		return z.Clear()
153	}
154	p := umul(x, y)
155
156	if m[3] != 0 {
157		mu := Reciprocal(m)
158		r := reduce4(p, m, mu)
159		return z.Set(&r)
160	}
161
162	var (
163		pl Uint
164		ph Uint
165	)
166
167	pl[0], pl[1], pl[2], pl[3] = p[0], p[1], p[2], p[3]
168	ph[0], ph[1], ph[2], ph[3] = p[4], p[5], p[6], p[7]
169
170	// If the multiplication is within 256 bits use Mod().
171	if ph.IsZero() {
172		return z.Mod(&pl, m)
173	}
174
175	var quot [8]uint64
176	rem := udivrem(quot[:], p[:], m)
177	return z.Set(&rem)
178}
179
180// DivMod sets z to the quotient x/y and m to the modulus x%y, returning the pair (z, m).
181// It panics if y == 0.
182func (z *Uint) DivMod(x, y, m *Uint) (*Uint, *Uint) {
183	if y.IsZero() {
184		panic("division by zero")
185	}
186
187	switch x.Cmp(y) {
188	case -1:
189		// x < y
190		return z.Clear(), m.Set(x)
191	case 0:
192		// x == y
193		return z.SetOne(), m.Clear()
194	}
195
196	// At this point:
197	// x != 0
198	// y != 0
199	// x > y
200
201	// Shortcut trivial case
202	if x.IsUint64() {
203		x0, y0 := x.Uint64(), y.Uint64()
204		return z.SetUint64(x0 / y0), m.SetUint64(x0 % y0)
205	}
206
207	var quot Uint
208	*m = udivrem(quot[:], x[:], y)
209	*z = quot
210	return z, m
211}
212
213// udivrem divides u by d and produces both quotient and remainder.
214// The quotient is stored in provided quot - len(u)-len(d)+1 words.
215// It loosely follows the Knuth's division algorithm (sometimes referenced as "schoolbook" division) using 64-bit words.
216// See Knuth, Volume 2, section 4.3.1, Algorithm D.
217func udivrem(quot, u []uint64, d *Uint) (rem Uint) {
218	var dLen int
219	for i := len(d) - 1; i >= 0; i-- {
220		if d[i] != 0 {
221			dLen = i + 1
222			break
223		}
224	}
225
226	shift := uint(bits.LeadingZeros64(d[dLen-1]))
227
228	var dnStorage Uint
229	dn := dnStorage[:dLen]
230	for i := dLen - 1; i > 0; i-- {
231		dn[i] = (d[i] << shift) | (d[i-1] >> (64 - shift))
232	}
233	dn[0] = d[0] << shift
234
235	var uLen int
236	for i := len(u) - 1; i >= 0; i-- {
237		if u[i] != 0 {
238			uLen = i + 1
239			break
240		}
241	}
242
243	if uLen < dLen {
244		copy(rem[:], u)
245		return rem
246	}
247
248	var unStorage [9]uint64
249	un := unStorage[:uLen+1]
250	un[uLen] = u[uLen-1] >> (64 - shift)
251	for i := uLen - 1; i > 0; i-- {
252		un[i] = (u[i] << shift) | (u[i-1] >> (64 - shift))
253	}
254	un[0] = u[0] << shift
255
256	if dLen == 1 {
257		r := udivremBy1(quot, un, dn[0])
258		rem.SetUint64(r >> shift)
259		return rem
260	}
261
262	udivremKnuth(quot, un, dn)
263
264	for i := 0; i < dLen-1; i++ {
265		rem[i] = (un[i] >> shift) | (un[i+1] << (64 - shift))
266	}
267	rem[dLen-1] = un[dLen-1] >> shift
268
269	return rem
270}
271
272// umul computes full 256 x 256 -> 512 multiplication.
273func umul(x, y *Uint) [8]uint64 {
274	var res [8]uint64
275
276	topX := highestNonZeroWord(x)
277	topY := highestNonZeroWord(y)
278
279	if topX < 0 || topY < 0 {
280		return res
281	}
282
283	lenX := topX + 1
284	lenY := topY + 1
285
286	for i := 0; i < lenX; i++ {
287		xi := x[i]
288		if xi == 0 {
289			continue
290		}
291		var carry uint64
292		k := i
293		for j := 0; j < lenY; j++ {
294			hi, lo := bits.Mul64(xi, y[j])
295			lo, c := bits.Add64(lo, res[k], 0)
296			hi += c
297			lo, c = bits.Add64(lo, carry, 0)
298			hi += c
299			res[k] = lo
300			carry = hi
301			k++
302		}
303		res[i+lenY] = carry
304	}
305
306	return res
307}
308
309// highestNonZeroWord returns the highest index with non-zero value or -1 if the Uint is zero.
310func highestNonZeroWord(u *Uint) int {
311	for i := 3; i >= 0; i-- {
312		if u[i] != 0 {
313			return i
314		}
315	}
316	return -1
317}
318
// umulStep returns the 128-bit value z + x*y + carry as (hi, lo).
// The sum cannot exceed 2^128-1, so no information is lost.
func umulStep(z, x, y, carry uint64) (hi, lo uint64) {
	h, l := bits.Mul64(x, y)
	var c1, c2 uint64
	l, c1 = bits.Add64(l, carry, 0)
	l, c2 = bits.Add64(l, z, 0)
	return h + c1 + c2, l
}
328
// umulHop returns the 128-bit value z + x*y as (hi, lo).
func umulHop(z, x, y uint64) (hi, lo uint64) {
	p1, p0 := bits.Mul64(x, y)
	sum, c := bits.Add64(p0, z, 0)
	return p1 + c, sum
}
336
337// udivremBy1 divides u by single normalized word d and produces both quotient and remainder.
338// The quotient is stored in provided quot.
339func udivremBy1(quot, u []uint64, d uint64) (rem uint64) {
340	reciprocal := reciprocal2by1(d)
341	rem = u[len(u)-1] // Set the top word as remainder.
342	for j := len(u) - 2; j >= 0; j-- {
343		quot[j], rem = udivrem2by1(rem, u[j], d, reciprocal)
344	}
345	return rem
346}
347
348// udivremKnuth implements the division of u by normalized multiple word d from the Knuth's division algorithm.
349// The quotient is stored in provided quot - len(u)-len(d) words.
350// Updates u to contain the remainder - len(d) words.
351func udivremKnuth(quot, u, d []uint64) {
352	dLen := len(d)
353	dh := d[dLen-1]
354	dl := d[dLen-2]
355	reciprocal := reciprocal2by1(dh)
356
357	for j := len(u) - dLen - 1; j >= 0; j-- {
358		u2 := u[j+dLen]
359		u1 := u[j+dLen-1]
360		u0 := u[j+dLen-2]
361
362		var qhat, rhat uint64
363		if u2 >= dh { // Division overflows.
364			qhat = MAX_UINT64
365			// NOTE: Add "qhat one to big" adjustment (not needed for correctness, but helps avoiding "add back" case).
366		} else {
367			qhat, rhat = udivrem2by1(u2, u1, dh, reciprocal)
368			ph, pl := bits.Mul64(qhat, dl)
369			if ph > rhat || (ph == rhat && pl > u0) {
370				qhat--
371				// NOTE: Add "qhat one to big" adjustment (not needed for correctness, but helps avoiding "add back" case).
372			}
373		}
374
375		// Multiply and subtract.
376		borrow := subMulTo(u[j:], d, qhat)
377		u[j+dLen] = u2 - borrow
378		if u2 < borrow { // Too much subtracted, add back.
379			qhat--
380			u[j+dLen] += addTo(u[j:], d)
381		}
382
383		quot[j] = qhat // Store quotient digit.
384	}
385}
386
387// isBitSet returns true if bit n-th is set, where n = 0 is LSB.
388// The n must be <= 255.
389func (z *Uint) isBitSet(n uint) bool {
390	return (z[n/64] & (1 << (n % 64))) != 0
391}
392
393func (z *Uint) IsOverflow() bool {
394	return z.isBitSet(255)
395}
396
// addTo adds y into x in place (x += y over the first len(y) words) and
// returns the final carry (0 or 1). Requires len(x) >= len(y).
func addTo(x, y []uint64) uint64 {
	var c uint64
	for i, yi := range y {
		x[i], c = bits.Add64(x[i], yi, c)
	}
	return c
}
406
// subMulTo subtracts y*multiplier from x in place over the first len(y) words
// and returns the borrow that would have to be taken from x[len(y)].
// Requires len(x) >= len(y).
func subMulTo(x, y []uint64, multiplier uint64) uint64 {
	var borrow uint64
	for i, yi := range y {
		hi, lo := bits.Mul64(yi, multiplier)
		diff, b1 := bits.Sub64(x[i], borrow, 0)
		diff, b2 := bits.Sub64(diff, lo, 0)
		x[i] = diff
		borrow = hi + b1 + b2
	}
	return borrow
}
420
421// reciprocal2by1 computes <^d, ^0> / d.
422func reciprocal2by1(d uint64) uint64 {
423	reciprocal, _ := bits.Div64(^d, MAX_UINT64, d)
424	return reciprocal
425}
426
427// udivrem2by1 divides <uh, ul> / d and produces both quotient and remainder.
428// It uses the provided d's reciprocal.
429// Implementation ported from https://github.com/chfast/intx and is based on
430// "Improved division by invariant integers", Algorithm 4.
431func udivrem2by1(uh, ul, d, reciprocal uint64) (quot, rem uint64) {
432	qh, ql := bits.Mul64(reciprocal, uh)
433	ql, carry := bits.Add64(ql, ul, 0)
434	qh, _ = bits.Add64(qh, uh, carry)
435	qh++
436
437	r := ul - qh*d
438
439	if r > ql {
440		qh--
441		r += d
442	}
443
444	if r >= d {
445		qh++
446		r -= d
447	}
448
449	return qh, r
450}