sqrt_price_math.gno
11.11 Kb · 329 lines
1package gnsmath
2
3import (
4 i256 "gno.land/p/gnoswap/int256"
5 u256 "gno.land/p/gnoswap/uint256"
6)
7
8var (
9 q96 = u256.Zero().Lsh(u256.One(), 96) // 2^96
10 max160 = u256.Zero().Sub(u256.Zero().Lsh(u256.One(), 160), u256.One()) // 2^160 - 1
11 maxInt256 = u256.Zero().Sub(u256.Zero().Lsh(u256.One(), 255), u256.One()) // 2^255 - 1
12
13 MIN_SQRT_RATIO = u256.MustFromDecimal("4295128739")
14 MAX_SQRT_RATIO = u256.MustFromDecimal("1461446703485210103287273052203988822378723970342")
15)
16
17// getNextPriceAmount0Add calculates the next sqrt price when adding token0 liquidity,
18// rounding up to ensure conservative pricing for the protocol.
19// This internal function handles the case where token0 is being added to the pool.
20func getNextPriceAmount0Add(
21 currentSqrtPriceX96, liquidity, amountToAdd *u256.Uint,
22) *u256.Uint {
23 // liquidityShifted = liquidity << 96
24 liquidityShifted := u256.Zero().Lsh(liquidity, Q96_RESOLUTION)
25 // amountTimesSqrtPrice = amount * sqrtPrice
26 amountTimesSqrtPrice := u256.Zero().Mul(amountToAdd, currentSqrtPriceX96)
27
28 // Overflow check: Ensure (amountTimesSqrtPrice / amountToAdd) == currentSqrtPriceX96
29 quotientCheck := u256.Zero().Div(amountTimesSqrtPrice, amountToAdd)
30 if quotientCheck.Eq(currentSqrtPriceX96) {
31 // denominator = liquidityShifted + amountTimesSqrtPrice
32 denominator := u256.Zero().Add(liquidityShifted, amountTimesSqrtPrice)
33 // only take this path when denominator >= liquidityShifted
34 if denominator.Gte(liquidityShifted) {
35 return u256.MulDivRoundingUp(liquidityShifted, currentSqrtPriceX96, denominator)
36 }
37 }
38
39 // fallback: liquidityShifted / ((liquidityShifted / sqrtPrice) + amount)
40 divValue := u256.Zero().Div(liquidityShifted, currentSqrtPriceX96)
41 denominator := u256.Zero().Add(divValue, amountToAdd)
42 return u256.DivRoundingUp(liquidityShifted, denominator)
43}
44
45// getNextPriceAmount0Remove calculates the next sqrt price when removing token0 liquidity,
46// rounding up to ensure conservative pricing for the protocol.
47// This internal function handles the case where token0 is being removed from the pool.
48// Panics if validation checks fail (invalid pool sqrt price calculation).
49func getNextPriceAmount0Remove(
50 currentSqrtPriceX96, liquidity, amountToRemove *u256.Uint,
51) *u256.Uint {
52 // liquidityShifted = liquidity << 96
53 liquidityShifted := u256.Zero().Lsh(liquidity, Q96_RESOLUTION)
54 // amountTimesSqrtPrice = amountToRemove * currentSqrtPriceX96
55 amountTimesSqrtPrice := u256.Zero().Mul(amountToRemove, currentSqrtPriceX96)
56
57 // Validation checks
58 quotientCheck := u256.Zero().Div(amountTimesSqrtPrice, amountToRemove)
59 if !quotientCheck.Eq(currentSqrtPriceX96) || !liquidityShifted.Gt(amountTimesSqrtPrice) {
60 panic(errInvalidPoolSqrtPrice)
61 }
62
63 denominator := u256.Zero().Sub(liquidityShifted, amountTimesSqrtPrice)
64 return u256.MulDivRoundingUp(liquidityShifted, currentSqrtPriceX96, denominator)
65}
66
67// getNextSqrtPriceFromAmount0RoundingUp calculates the next sqrt price based on token0 amount,
68// always rounding up to ensure conservative pricing in both exact output and exact input cases.
69// The add parameter determines whether liquidity is being added (true) or removed (false).
70func getNextSqrtPriceFromAmount0RoundingUp(
71 sqrtPX96 *u256.Uint,
72 liquidity *u256.Uint,
73 amount *u256.Uint,
74 add bool,
75) *u256.Uint {
76 // Shortcut: if no amount, return original price
77 if amount.IsZero() {
78 return sqrtPX96
79 }
80
81 if add {
82 return getNextPriceAmount0Add(sqrtPX96, liquidity, amount)
83 }
84 return getNextPriceAmount0Remove(sqrtPX96, liquidity, amount)
85}
86
87// getNextPriceAmount1Add calculates the next sqrt price when adding token1,
88// preserving rounding-down logic for the final result.
89// This internal function handles the case where token1 is being added to the pool.
90func getNextPriceAmount1Add(
91 sqrtPX96, liquidity, amount *u256.Uint,
92) *u256.Uint {
93 var quotient *u256.Uint
94
95 if amount.Lte(max160) {
96 // Use local variables to avoid allocation conflicts
97 shifted := u256.Zero().Lsh(amount, Q96_RESOLUTION)
98 quotient = u256.Zero().Div(shifted, liquidity)
99 } else {
100 quotient = u256.MulDiv(amount, q96, liquidity)
101 }
102
103 result, overflow := u256.Zero().AddOverflow(sqrtPX96, quotient)
104 if overflow || result.Gt(max160) {
105 panic(errSqrtPriceOverflow)
106 }
107
108 return result
109}
110
111// getNextPriceAmount1Remove calculates the next sqrt price when removing token1,
112// preserving rounding-down logic for the final result.
113// This internal function handles the case where token1 is being removed from the pool.
114// Panics if sqrt price would exceed quotient.
115func getNextPriceAmount1Remove(
116 sqrtPX96, liquidity, amount *u256.Uint,
117) *u256.Uint {
118 var quotient *u256.Uint
119
120 if amount.Lte(max160) {
121 shifted := u256.Zero().Lsh(amount, Q96_RESOLUTION)
122 quotient = u256.DivRoundingUp(shifted, liquidity)
123 } else {
124 quotient = u256.MulDivRoundingUp(amount, q96, liquidity)
125 }
126
127 if !sqrtPX96.Gt(quotient) {
128 panic(errSqrtPriceExceedsQuotient)
129 }
130
131 return u256.Zero().Sub(sqrtPX96, quotient)
132}
133
134// getNextSqrtPriceFromAmount1RoundingDown calculates the next sqrt price based on token1 amount,
135// always rounding down to ensure conservative pricing in both exact output and exact input cases.
136// The add parameter determines whether liquidity is being added (true) or removed (false).
137func getNextSqrtPriceFromAmount1RoundingDown(
138 sqrtPX96,
139 liquidity,
140 amount *u256.Uint,
141 add bool,
142) *u256.Uint {
143 // Shortcut: if no amount, return original price
144 if amount.IsZero() {
145 return sqrtPX96
146 }
147
148 if add {
149 return getNextPriceAmount1Add(sqrtPX96, liquidity, amount)
150 }
151 return getNextPriceAmount1Remove(sqrtPX96, liquidity, amount)
152}
153
154// getNextSqrtPriceFromInput calculates the next sqrt price after adding tokens to the pool,
155// rounding up for conservative pricing in both swap directions.
156// The zeroForOne parameter indicates swap direction (token0 for token1 when true).
157// Panics if sqrtPX96 or liquidity is zero.
158func getNextSqrtPriceFromInput(
159 sqrtPX96, liquidity, amountIn *u256.Uint,
160 zeroForOne bool,
161) *u256.Uint {
162 if sqrtPX96.IsZero() {
163 panic(errSqrtPriceZero)
164 }
165
166 if liquidity.IsZero() {
167 panic(errLiquidityZero)
168 }
169
170 if zeroForOne {
171 return getNextSqrtPriceFromAmount0RoundingUp(sqrtPX96, liquidity, amountIn, true)
172 }
173
174 return getNextSqrtPriceFromAmount1RoundingDown(sqrtPX96, liquidity, amountIn, true)
175}
176
177// getNextSqrtPriceFromOutput calculates the next sqrt price after removing tokens from the pool,
178// using different rounding directions based on swap direction.
179// The zeroForOne parameter indicates swap direction (token0 for token1 when true).
180// Panics if sqrtPX96 or liquidity is zero.
181func getNextSqrtPriceFromOutput(
182 sqrtPX96, liquidity, amountOut *u256.Uint,
183 zeroForOne bool,
184) *u256.Uint {
185 if sqrtPX96.IsZero() {
186 panic(errSqrtPriceZero)
187 }
188
189 if liquidity.IsZero() {
190 panic(errLiquidityZero)
191 }
192
193 if zeroForOne {
194 return getNextSqrtPriceFromAmount1RoundingDown(sqrtPX96, liquidity, amountOut, false)
195 }
196
197 return getNextSqrtPriceFromAmount0RoundingUp(sqrtPX96, liquidity, amountOut, false)
198}
199
200// getAmount0DeltaHelper calculates the absolute token0 amount difference between two price ranges,
201// automatically swapping inputs to ensure correct ordering. The roundUp parameter controls
202// rounding direction for the final result to ensure conservative AMM calculations.
203// Panics if sqrtRatioAX96 is zero.
204func getAmount0DeltaHelper(
205 sqrtRatioAX96, sqrtRatioBX96, liquidity *u256.Uint,
206 roundUp bool,
207) *u256.Uint {
208 if sqrtRatioAX96.Gt(sqrtRatioBX96) {
209 sqrtRatioAX96, sqrtRatioBX96 = sqrtRatioBX96, sqrtRatioAX96
210 }
211
212 // Use local variables for thread safety
213 numerator := u256.Zero().Lsh(liquidity, Q96_RESOLUTION)
214 difference := u256.Zero().Sub(sqrtRatioBX96, sqrtRatioAX96)
215
216 if sqrtRatioAX96.IsZero() {
217 panic(errSqrtRatioAX96Zero)
218 }
219
220 if roundUp {
221 intermediate := u256.MulDivRoundingUp(numerator, difference, sqrtRatioBX96)
222 return u256.DivRoundingUp(intermediate, sqrtRatioAX96)
223 }
224
225 intermediate := u256.MulDiv(numerator, difference, sqrtRatioBX96)
226 return u256.Zero().Div(intermediate, sqrtRatioAX96)
227}
228
229// getAmount1DeltaHelper calculates the absolute token1 amount difference between two price ranges,
230// automatically swapping inputs to ensure correct ordering. The roundUp parameter controls
231// rounding direction for the final result to ensure conservative AMM calculations.
232func getAmount1DeltaHelper(
233 sqrtRatioAX96, sqrtRatioBX96, liquidity *u256.Uint,
234 roundUp bool,
235) *u256.Uint {
236 if sqrtRatioAX96.Gt(sqrtRatioBX96) {
237 sqrtRatioAX96, sqrtRatioBX96 = sqrtRatioBX96, sqrtRatioAX96
238 }
239
240 // amount1 = liquidity * (sqrtB - sqrtA) / 2^96
241 // Use local variable for thread safety
242 difference := u256.Zero().Sub(sqrtRatioBX96, sqrtRatioAX96)
243
244 if roundUp {
245 return u256.MulDivRoundingUp(liquidity, difference, q96)
246 }
247
248 return u256.MulDiv(liquidity, difference, q96)
249}
250
251// GetAmount0Delta calculates the token0 amount difference within a price range, returning
252// a signed int256 value that is negative when liquidity is negative. Rounds down for
253// negative liquidity and up for positive liquidity.
254//
255// Parameters:
256// - sqrtRatioAX96: first sqrt price in Q96 format
257// - sqrtRatioBX96: second sqrt price in Q96 format
258// - liquidity: signed liquidity value
259//
260// Returns the token0 amount difference as a signed int256 value.
261//
262// Panics if any input is nil or if the result overflows int256.
263func GetAmount0Delta(
264 sqrtRatioAX96, sqrtRatioBX96 *u256.Uint,
265 liquidity *i256.Int,
266) *i256.Int {
267 if sqrtRatioAX96 == nil || sqrtRatioBX96 == nil || liquidity == nil {
268 panic(errGetAmount0DeltaNilInput)
269 }
270
271 if liquidity.IsNeg() {
272 u := getAmount0DeltaHelper(sqrtRatioAX96, sqrtRatioBX96, liquidity.Abs(), false)
273 if u.Gt(maxInt256) {
274 // if u > (2**255 - 1), cannot cast to int256
275 panic(errAmount0DeltaOverflow)
276 }
277
278 // Convert to i256 and negate properly
279 return i256.Zero().Neg(i256.FromUint256(u))
280 }
281
282 u := getAmount0DeltaHelper(sqrtRatioAX96, sqrtRatioBX96, liquidity.Abs(), true)
283 if u.Gt(maxInt256) {
284 // if u > (2**255 - 1), cannot cast to int256
285 panic(errAmount0DeltaOverflow)
286 }
287
288 return i256.FromUint256(u)
289}
290
291// GetAmount1Delta calculates the token1 amount difference within a price range, returning
292// a signed int256 value that is negative when liquidity is negative. Rounds down for
293// negative liquidity and up for positive liquidity.
294//
295// Parameters:
296// - sqrtRatioAX96: first sqrt price in Q96 format
297// - sqrtRatioBX96: second sqrt price in Q96 format
298// - liquidity: signed liquidity value
299//
300// Returns the token1 amount difference as a signed int256 value.
301//
302// Panics if any input is nil or if the result overflows int256.
303func GetAmount1Delta(
304 sqrtRatioAX96, sqrtRatioBX96 *u256.Uint,
305 liquidity *i256.Int,
306) *i256.Int {
307 if sqrtRatioAX96 == nil || sqrtRatioBX96 == nil || liquidity == nil {
308 panic(errGetAmount1DeltaNilInput)
309 }
310
311 if liquidity.IsNeg() {
312 u := getAmount1DeltaHelper(sqrtRatioAX96, sqrtRatioBX96, liquidity.Abs(), false)
313 if u.Gt(maxInt256) {
314 // if u > (2**255 - 1), cannot cast to int256
315 panic(errAmount1DeltaOverflow)
316 }
317
318 // Convert to i256 and negate properly
319 return i256.Zero().Neg(i256.FromUint256(u))
320 }
321
322 u := getAmount1DeltaHelper(sqrtRatioAX96, sqrtRatioBX96, liquidity.Abs(), true)
323 if u.Gt(maxInt256) {
324 // if u > (2**255 - 1), cannot cast to int256
325 panic(errAmount1DeltaOverflow)
326 }
327
328 return i256.FromUint256(u)
329}