// Copyright 2020-2025 Consensys Software Inc.
// Licensed under the Apache License, Version 2.0. See the LICENSE file for details.

// Code generated by consensys/gnark-crypto DO NOT EDIT

package kzg

import (
	"errors"
	"hash"
	"math/big"
	"sync"

	"github.com/consensys/gnark-crypto/ecc"
	"github.com/consensys/gnark-crypto/ecc/bw6-761"
	"github.com/consensys/gnark-crypto/ecc/bw6-761/fr"
	"github.com/consensys/gnark-crypto/fiat-shamir"

	"github.com/consensys/gnark-crypto/internal/parallel"
)

// Sentinel errors returned by the commitment, opening and verification
// functions of this package. Compare against them with errors.Is.
var (
	ErrInvalidNbDigests              = errors.New("number of digests is not the same as the number of polynomials")
	ErrZeroNbDigests                 = errors.New("number of digests is zero")
	ErrInvalidPolynomialSize         = errors.New("invalid polynomial size (larger than SRS or == 0)")
	ErrVerifyOpeningProof            = errors.New("can't verify opening proof")
	ErrVerifyBatchOpeningSinglePoint = errors.New("can't verify batch opening proof at single point")
	ErrMinSRSSize                    = errors.New("minimum srs size is 2")
)

// Digest is the commitment of a polynomial: a single G₁ point, as produced by
// Commit (the multi-exponentiation of the coefficients with the SRS powers).
type Digest = bw6761.G1Affine

// ProvingKey used to create or open commitments.
//
// G1 holds the SRS powers in G₁. Its length bounds the number of coefficients
// a polynomial may have to be committed (see Commit / Open).
type ProvingKey struct {
	G1 []bw6761.G1Affine // [G₁, [α]G₁, [α²]G₁, ...]
}

// VerifyingKey used to verify opening proofs.
//
//   - G2 holds [G₂, [α]G₂].
//   - G1 is the G₁ generator. NewSRS sets it to the same point as Pk.G1[0].
//   - Lines[0] and Lines[1] are the pairing lines precomputed from G2[0] and
//     G2[1] respectively. Verify uses them so that the G₂ side of the pairing
//     check is not recomputed.
type VerifyingKey struct {
	G2    [2]bw6761.G2Affine // [G₂, [α]G₂ ]
	G1    bw6761.G1Affine
	Lines [2][2][len(bw6761.LoopCounter) - 1]bw6761.LineEvaluationAff // precomputed pairing lines corresponding to G₂, [α]G₂
}

// SRS (structured reference string) bundles the ProvingKey and the VerifyingKey
// derived from the same secret α.
//
// In production it must be computed through an MPC ceremony so that nobody
// knows α. NewSRS is only suitable for tests and benchmarks.
type SRS struct {
	Pk ProvingKey
	Vk VerifyingKey
}

// TODO @Tabaie get rid of this and use the polynomial package
// eval returns p(point), where p holds the coefficients of
// ∑_{i<len(p)}p[i]Xⁱ (constant term first).
//
// The evaluation uses Horner's rule, walking from the highest coefficient down.
// p must not be empty.
func eval(p []fr.Element, point fr.Element) fr.Element {
	top := len(p) - 1
	acc := p[top]
	for k := top - 1; k >= 0; k-- {
		acc.Mul(&acc, &point)
		acc.Add(&acc, &p[k])
	}
	return acc
}

// NewSRS returns a new SRS of the given size, using bAlpha as the secret α.
//
// In production, a SRS generated through MPC should be used.
//
// Set Alpha = -1 to generate quickly a balanced, valid SRS (useful for benchmarking).
//
// The returned SRS implements io.ReaderFrom and io.WriterTo.
//
// Errors:
//   - ErrMinSRSSize if size < 2.
//   - The error from fr.Generator if the order-4 element used by the
//     bAlpha == -1 shortcut cannot be computed. Previously this case returned
//     an uninitialised SRS with a nil error.
func NewSRS(size uint64, bAlpha *big.Int) (*SRS, error) {

	if size < 2 {
		return nil, ErrMinSRSSize
	}
	var srs SRS
	srs.Pk.G1 = make([]bw6761.G1Affine, size)

	// α reduced into the scalar field. Its canonical big.Int form is also used
	// for the G₂ side, so the G₁ and G₂ halves of the SRS come from exactly the
	// same scalar whatever the sign or size of bAlpha.
	var alpha fr.Element
	alpha.SetBigInt(bAlpha)
	var bAlphaReduced big.Int
	alpha.BigInt(&bAlphaReduced)

	var bMOne big.Int
	bMOne.SetInt64(-1)

	_, _, gen1Aff, gen2Aff := bw6761.Generators()

	// In this case α is replaced by t, an element of order 4. The SRS <tⁱ[G₁]>
	// therefore cycles with period 4, so only 4 distinct points are needed and
	// no batch scalar multiplication is run. α = 1 is not used because it
	// tampers the benchmarks (the SRS would not be balanced).
	if bAlpha.Cmp(&bMOne) == 0 {

		t, err := fr.Generator(4)
		if err != nil {
			return nil, err
		}
		var bt big.Int
		t.BigInt(&bt)

		var g [4]bw6761.G1Affine
		g[0] = gen1Aff
		for i := 1; i < 4; i++ {
			g[i].ScalarMultiplication(&g[i-1], &bt)
		}
		parallel.Execute(int(size), func(start, end int) {
			for i := start; i < end; i++ {
				srs.Pk.G1[i] = g[i%4]
			}
		})
		srs.Vk.G1 = gen1Aff
		srs.Vk.G2[0] = gen2Aff
		srs.Vk.G2[1].ScalarMultiplication(&srs.Vk.G2[0], &bt)
		srs.Vk.Lines[0] = bw6761.PrecomputeLines(srs.Vk.G2[0])
		srs.Vk.Lines[1] = bw6761.PrecomputeLines(srs.Vk.G2[1])
		return &srs, nil
	}
	srs.Pk.G1[0] = gen1Aff
	srs.Vk.G1 = gen1Aff
	srs.Vk.G2[0] = gen2Aff
	srs.Vk.G2[1].ScalarMultiplication(&gen2Aff, &bAlphaReduced)
	srs.Vk.Lines[0] = bw6761.PrecomputeLines(srs.Vk.G2[0])
	srs.Vk.Lines[1] = bw6761.PrecomputeLines(srs.Vk.G2[1])

	// alphas = [α, α², ..., α^(size-1)]. size >= 2, so this slice is non-empty.
	alphas := make([]fr.Element, size-1)
	alphas[0] = alpha
	for i := 1; i < len(alphas); i++ {
		alphas[i].Mul(&alphas[i-1], &alpha)
	}
	g1s := bw6761.BatchScalarMultiplicationG1(&gen1Aff, alphas)
	copy(srs.Pk.G1[1:], g1s)

	return &srs, nil
}

// OpeningProof KZG proof for opening at a single point.
//
// implements io.ReaderFrom and io.WriterTo
type OpeningProof struct {
	// H commitment to the quotient polynomial (f - f(z))/(x-z), as computed by Open
	H bw6761.G1Affine

	// ClaimedValue purported value f(z)
	ClaimedValue fr.Element
}

// BatchOpeningProof opening proof for many polynomials at the same point
//
// implements io.ReaderFrom and io.WriterTo
type BatchOpeningProof struct {
	// H commitment to the quotient polynomial ∑ᵢγⁱ(fᵢ - fᵢ(z))/(x-z), where γ is
	// the Fiat-Shamir challenge derived in BatchOpenSinglePoint
	H bw6761.G1Affine

	// ClaimedValues purported values fᵢ(z), in the same order as the polynomials/digests
	ClaimedValues []fr.Element
}

// Commit returns the KZG commitment ∑ᵢ p[i]·pk.G1[i] of the polynomial p. This
// is a multi-exponentiation of its coefficients with the first len(p) SRS points.
// It is assumed that the polynomial is in canonical form, in Montgomery form.
//
// An optional nbTasks[0] sets the parallelism of the multi-exponentiation.
// Returns ErrInvalidPolynomialSize if p is empty or longer than pk.G1, and
// propagates any error from the multi-exponentiation.
func Commit(p []fr.Element, pk ProvingKey, nbTasks ...int) (Digest, error) {
	n := len(p)
	if n == 0 || n > len(pk.G1) {
		return Digest{}, ErrInvalidPolynomialSize
	}

	var cfg ecc.MultiExpConfig
	if len(nbTasks) != 0 {
		cfg.NbTasks = nbTasks[0]
	}

	var digest Digest
	if _, err := digest.MultiExp(pk.G1[:n], p, cfg); err != nil {
		return Digest{}, err
	}
	return digest, nil
}

// Open computes a KZG opening proof of the polynomial p at the given point.
// p is given by its coefficients in canonical basis, constant term first.
//
// The proof carries the claimed value p(point) and the commitment to the
// quotient (p - p(point))/(X - point). p itself is not modified: the in-place
// division runs on a scratch copy.
//
// Returns ErrInvalidPolynomialSize if p is empty or has more coefficients than
// pk.G1 has points. Otherwise it propagates errors from Commit.
//
// NOTE(review): for a constant polynomial (len(p) == 1) the quotient is empty,
// so Commit rejects it with ErrInvalidPolynomialSize.
func Open(p []fr.Element, point fr.Element, pk ProvingKey) (OpeningProof, error) {
	if n := len(p); n == 0 || n > len(pk.G1) {
		return OpeningProof{}, ErrInvalidPolynomialSize
	}

	claimed := eval(p, point)

	// dividePolyByXminusA overwrites its input, so work on a private copy.
	// The returned quotient aliases that copy.
	scratch := make([]fr.Element, len(p))
	copy(scratch, p)
	quotient := dividePolyByXminusA(scratch, claimed, point)

	hCommit, err := Commit(quotient, pk)
	if err != nil {
		return OpeningProof{}, err
	}

	var proof OpeningProof
	proof.ClaimedValue = claimed
	proof.H.Set(&hCommit)
	return proof, nil
}

// Verify verifies a KZG opening proof at a single point.
//
// With C = commitment = [f(α)]G₁, H = proof.H = [H(α)]G₁ and a = point, it checks
// that f(α) - f(a) = (α - a)·H(α) holds "in the exponent" via one pairing check.
// It uses vk.G1 and the precomputed lines vk.Lines (for G₂ and [α]G₂).
//
// Returns nil on success, ErrVerifyOpeningProof if the check fails, or the error
// reported by the pairing routine.
func Verify(commitment *Digest, proof *OpeningProof, point fr.Element, vk VerifyingKey) error {

	// [f(a)]G₁ + [-a]([H(α)]G₁) = [f(a) - a*H(α)]G₁
	var totalG1 bw6761.G1Jac
	var pointNeg fr.Element
	var cmInt, pointInt big.Int
	proof.ClaimedValue.BigInt(&cmInt)
	pointNeg.Neg(&point).BigInt(&pointInt)
	totalG1.JointScalarMultiplication(&vk.G1, &proof.H, &cmInt, &pointInt)

	// [f(a) - a*H(α)]G₁ + [-f(α)]G₁  = [f(a) - f(α) - a*H(α)]G₁
	var commitmentJac bw6761.G1Jac
	commitmentJac.FromAffine(commitment)
	totalG1.SubAssign(&commitmentJac)

	// e([f(a)-f(α)-aH(α)]G₁, G₂).e([H(α)]G₁, [α]G₂) == 1
	// which holds iff f(α) - f(a) = (α - a)H(α)
	var totalG1Aff bw6761.G1Affine
	totalG1Aff.FromJacobian(&totalG1)
	check, err := bw6761.PairingCheckFixedQ(
		[]bw6761.G1Affine{totalG1Aff, proof.H},
		vk.Lines[:],
	)

	if err != nil {
		return err
	}
	if !check {
		return ErrVerifyOpeningProof
	}
	return nil
}

// BatchOpenSinglePoint creates a batch opening proof at point of a list of polynomials.
// It's an interactive protocol, made non-interactive using Fiat Shamir.
//
// * point is the point at which the polynomials are opened.
// * digests is the list of committed polynomials to open, need to derive the challenge using Fiat Shamir.
// * polynomials is the list of polynomials to open, they are supposed to be of the same size.
// * dataTranscript extra data that might be needed to derive the challenge used for folding
//
// The proof holds the claimed values fᵢ(point) and a commitment to the quotient of
// the folded polynomial ∑ᵢγⁱfᵢ by (X - point). γ is derived by deriveGamma from
// point, digests, the claimed values and dataTranscript.
//
// Errors:
//   - ErrInvalidNbDigests if len(digests) != len(polynomials).
//   - ErrZeroNbDigests if there is nothing to open.
//   - ErrInvalidPolynomialSize if a polynomial is empty or larger than the SRS.
//   - Otherwise, errors from the transcript or from Commit are propagated.
func BatchOpenSinglePoint(polynomials [][]fr.Element, digests []Digest, point fr.Element, hf hash.Hash, pk ProvingKey, dataTranscript ...[]byte) (BatchOpeningProof, error) {

	// check for invalid sizes
	nbDigests := len(digests)
	if nbDigests != len(polynomials) {
		return BatchOpeningProof{}, ErrInvalidNbDigests
	}

	// An empty batch cannot be opened. Without this guard the folding below
	// reads ClaimedValues[-1] inside a goroutine (crashing the whole process)
	// and indexes polynomials[0].
	if nbDigests == 0 {
		return BatchOpeningProof{}, ErrZeroNbDigests
	}

	// TODO ensure the polynomials are of the same size
	largestPoly := -1
	for _, p := range polynomials {
		if len(p) == 0 || len(p) > len(pk.G1) {
			return BatchOpeningProof{}, ErrInvalidPolynomialSize
		}
		if len(p) > largestPoly {
			largestPoly = len(p)
		}
	}

	var res BatchOpeningProof

	// compute the purported values
	res.ClaimedValues = make([]fr.Element, len(polynomials))
	var wg sync.WaitGroup
	wg.Add(len(polynomials))
	for i := 0; i < len(polynomials); i++ {
		go func(_i int) {
			res.ClaimedValues[_i] = eval(polynomials[_i], point)
			wg.Done()
		}(i)
	}

	// wait for polynomial evaluations to be completed (res.ClaimedValues)
	wg.Wait()

	// derive the challenge γ, bound to the point and the commitments
	gamma, err := deriveGamma(point, digests, res.ClaimedValues, hf, dataTranscript...)
	if err != nil {
		return BatchOpeningProof{}, err
	}

	// ∑ᵢγⁱfᵢ(a), computed with Horner's rule concurrently with the polynomial folding
	var foldedEvaluations fr.Element
	chSumGammai := make(chan struct{}, 1)
	go func() {
		foldedEvaluations = res.ClaimedValues[nbDigests-1]
		for i := nbDigests - 2; i >= 0; i-- {
			foldedEvaluations.Mul(&foldedEvaluations, &gamma).
				Add(&foldedEvaluations, &res.ClaimedValues[i])
		}
		close(chSumGammai)
	}()

	// compute ∑ᵢγⁱfᵢ
	// note: if we are willing to parallelize that, we could clone the poly and scale them by
	// gamma n in parallel, before reducing into foldedPolynomials
	foldedPolynomials := make([]fr.Element, largestPoly)
	copy(foldedPolynomials, polynomials[0])
	// gammas[k] = γ^(k+1), so polynomial i is scaled by gammas[i-1] = γⁱ
	gammas := make([]fr.Element, len(polynomials))
	gammas[0] = gamma
	for i := 1; i < len(polynomials); i++ {
		gammas[i].Mul(&gammas[i-1], &gamma)
	}

	for i := 1; i < len(polynomials); i++ {
		i := i
		parallel.Execute(len(polynomials[i]), func(start, end int) {
			var pj fr.Element
			for j := start; j < end; j++ {
				pj.Mul(&polynomials[i][j], &gammas[i-1])
				foldedPolynomials[j].Add(&foldedPolynomials[j], &pj)
			}
		})
	}

	// compute H
	<-chSumGammai
	h := dividePolyByXminusA(foldedPolynomials, foldedEvaluations, point)
	foldedPolynomials = nil // same memory as h

	res.H, err = Commit(h, pk)
	if err != nil {
		return BatchOpeningProof{}, err
	}

	return res, nil
}

// FoldProof folds the digests and the claimed values of batchOpeningProof, using
// the Fiat-Shamir challenge γ, into an opening proof at a single point.
//
// * digests list of digests on which batchOpeningProof is based
// * batchOpeningProof opening proof of digests
// * point the point at which the polynomials were opened
// * hf hash used to derive γ (the same one the prover passed to BatchOpenSinglePoint)
// * dataTranscript extra data needed to derive the challenge used for folding.
// * returns the folded version of batchOpeningProof, Digest, the folded version of digests (∑ᵢγⁱdᵢ)
//
// Errors:
//   - ErrInvalidNbDigests if the numbers of digests and claimed values differ.
//   - ErrZeroNbDigests if there are no digests.
//   - Otherwise, errors from the transcript or from the multi-exponentiation are
//     propagated unchanged.
func FoldProof(digests []Digest, batchOpeningProof *BatchOpeningProof, point fr.Element, hf hash.Hash, dataTranscript ...[]byte) (OpeningProof, Digest, error) {

	nbDigests := len(digests)

	// check consistency between numbers of claims vs number of digests
	if nbDigests != len(batchOpeningProof.ClaimedValues) {
		return OpeningProof{}, Digest{}, ErrInvalidNbDigests
	}

	// an empty batch cannot be folded (gammai[0] below would be out of range)
	if nbDigests == 0 {
		return OpeningProof{}, Digest{}, ErrZeroNbDigests
	}

	// derive the challenge γ, bound to the point and the commitments
	gamma, err := deriveGamma(point, digests, batchOpeningProof.ClaimedValues, hf, dataTranscript...)
	if err != nil {
		// propagate the transcript error instead of masking it as a size mismatch
		return OpeningProof{}, Digest{}, err
	}

	// fold the claimed values and digests
	// gammai = [1,γ,γ²,..,γⁿ⁻¹]
	gammai := make([]fr.Element, nbDigests)
	gammai[0].SetOne()
	if nbDigests > 1 {
		gammai[1] = gamma
	}
	for i := 2; i < nbDigests; i++ {
		gammai[i].Mul(&gammai[i-1], &gamma)
	}

	foldedDigests, foldedEvaluations, err := fold(digests, batchOpeningProof.ClaimedValues, gammai)
	if err != nil {
		return OpeningProof{}, Digest{}, err
	}

	// create the folded opening proof
	var res OpeningProof
	res.ClaimedValue.Set(&foldedEvaluations)
	res.H.Set(&batchOpeningProof.H)

	return res, foldedDigests, nil
}

// BatchVerifySinglePoint verifies a batched opening proof at a single point of a list of polynomials.
//
// It first folds the batch into one opening proof with FoldProof, then checks
// that proof against the folded digest with Verify. Any error from either
// step is returned as is.
//
// * digests list of digests on which opening proof is done
// * batchOpeningProof proof of correct opening on the digests
// * dataTranscript extra data that might be needed to derive the challenge used for the folding
func BatchVerifySinglePoint(digests []Digest, batchOpeningProof *BatchOpeningProof, point fr.Element, hf hash.Hash, vk VerifyingKey, dataTranscript ...[]byte) error {
	single, digest, err := FoldProof(digests, batchOpeningProof, point, hf, dataTranscript...)
	if err != nil {
		return err
	}
	return Verify(&digest, &single, point, vk)
}

// BatchVerifyMultiPoints batch verifies a list of opening proofs at different points.
// The purpose of the batching is to have a single pairing check for verifying several proofs.
//
// The proofs are combined with random coefficients λᵢ (λ₀ = 1, the others sampled
// with SetRandom). The function then checks
// ∑ᵢλᵢ(fᵢ(α) - fᵢ(aᵢ) + aᵢHᵢ(α)) = α·∑ᵢλᵢHᵢ(α) through one call to PairingCheckFixedQ.
// A single digest is delegated to Verify.
//
// * digests list of committed polynomials
// * proofs list of opening proofs, one for each digest
// * points the list of points at which the opening are done
//
// Errors: ErrInvalidNbDigests on length mismatch, ErrZeroNbDigests on empty input,
// ErrVerifyOpeningProof if the check fails. Errors from sampling,
// multi-exponentiation or pairing are propagated.
func BatchVerifyMultiPoints(digests []Digest, proofs []OpeningProof, points []fr.Element, vk VerifyingKey) error {

	// check consistency nb proofs vs nb digests vs nb points
	if len(digests) != len(proofs) || len(digests) != len(points) {
		return ErrInvalidNbDigests
	}

	// len(digests) should be nonzero because of randomNumbers
	if len(digests) == 0 {
		return ErrZeroNbDigests
	}

	// if only one digest, call Verify
	if len(digests) == 1 {
		return Verify(&digests[0], &proofs[0], points[0], vk)
	}

	// sample random numbers λᵢ for batching (λ₀ is fixed to 1)
	randomNumbers := make([]fr.Element, len(digests))
	randomNumbers[0].SetOne()
	for i := 1; i < len(randomNumbers); i++ {
		if _, err := randomNumbers[i].SetRandom(); err != nil {
			return err
		}
	}

	// fold the committed quotients: compute ∑ᵢλᵢ[Hᵢ(α)]G₁
	var foldedQuotients bw6761.G1Affine
	quotients := make([]bw6761.G1Affine, len(proofs))
	for i := 0; i < len(randomNumbers); i++ {
		quotients[i].Set(&proofs[i].H)
	}
	config := ecc.MultiExpConfig{}
	if _, err := foldedQuotients.MultiExp(quotients, randomNumbers, config); err != nil {
		return err
	}

	// gather the claimed evaluations fᵢ(aᵢ)
	evals := make([]fr.Element, len(digests))
	for i := 0; i < len(randomNumbers); i++ {
		evals[i].Set(&proofs[i].ClaimedValue)
	}

	// fold the digests: ∑ᵢλᵢ[f_i(α)]G₁
	// fold the evals  : ∑ᵢλᵢfᵢ(aᵢ)
	foldedDigests, foldedEvals, err := fold(digests, evals, randomNumbers)
	if err != nil {
		return err
	}

	// compute commitment to folded Eval  [∑ᵢλᵢfᵢ(aᵢ)]G₁
	var foldedEvalsCommit bw6761.G1Affine
	var foldedEvalsBigInt big.Int
	foldedEvals.BigInt(&foldedEvalsBigInt)
	foldedEvalsCommit.ScalarMultiplication(&vk.G1, &foldedEvalsBigInt)

	// compute foldedDigests = ∑ᵢλᵢ[fᵢ(α)]G₁ - [∑ᵢλᵢfᵢ(aᵢ)]G₁
	foldedDigests.Sub(&foldedDigests, &foldedEvalsCommit)

	// combine the points and the quotients using λᵢ
	// ∑ᵢλᵢ[p_i]([Hᵢ(α)]G₁)
	// note: randomNumbers is overwritten in place with λᵢpᵢ; λᵢ alone is no longer needed
	var foldedPointsQuotients bw6761.G1Affine
	for i := 0; i < len(randomNumbers); i++ {
		randomNumbers[i].Mul(&randomNumbers[i], &points[i])
	}
	_, err = foldedPointsQuotients.MultiExp(quotients, randomNumbers, config)
	if err != nil {
		return err
	}

	// ∑ᵢλᵢ[f_i(α)]G₁ - [∑ᵢλᵢfᵢ(aᵢ)]G₁ + ∑ᵢλᵢ[p_i]([Hᵢ(α)]G₁)
	// = [∑ᵢλᵢf_i(α) - ∑ᵢλᵢfᵢ(aᵢ) + ∑ᵢλᵢpᵢHᵢ(α)]G₁
	foldedDigests.Add(&foldedDigests, &foldedPointsQuotients)

	// -∑ᵢλᵢ[Hᵢ(α)]G₁
	foldedQuotients.Neg(&foldedQuotients)

	// pairing check
	// e([∑ᵢλᵢ(fᵢ(α) - fᵢ(pᵢ) + pᵢHᵢ(α))]G₁, G₂).e([-∑ᵢλᵢHᵢ(α)]G₁, [α]G₂) == 1
	check, err := bw6761.PairingCheckFixedQ(
		[]bw6761.G1Affine{foldedDigests, foldedQuotients},
		vk.Lines[:],
	)
	if err != nil {
		return err
	}
	if !check {
		return ErrVerifyOpeningProof
	}
	return nil

}

// fold folds digests and evaluations using the list of factors as random numbers.
//
// * di list of digests to fold
// * fai list of evaluations to fold (at least len(di) entries)
// * ci list of multiplicative factors used for the folding (in Montgomery form)
//
// * Returns ∑ᵢcᵢdᵢ, ∑ᵢcᵢf(aᵢ)
//
// Length consistency between di, fai and ci must be checked by the caller.
// On a multi-exponentiation error, both partial results are returned alongside it.
func fold(di []Digest, fai []fr.Element, ci []fr.Element) (Digest, fr.Element, error) {

	// ∑ᵢcᵢf(aᵢ), over the first len(di) evaluations
	var sumEvals fr.Element
	for i := range di {
		var term fr.Element
		term.Mul(&fai[i], &ci[i])
		sumEvals.Add(&sumEvals, &term)
	}

	// ∑ᵢ[cᵢ]([fᵢ(α)]G₁)
	var sumDigests Digest
	if _, err := sumDigests.MultiExp(di, ci, ecc.MultiExpConfig{}); err != nil {
		return sumDigests, sumEvals, err
	}
	return sumDigests, sumEvals, nil
}

// deriveGamma derives the folding challenge γ using Fiat Shamir.
//
// Under the challenge name "gamma", the transcript binds, in order: the point,
// every digest, every claimed value, then every dataTranscript entry. The
// resulting hash output is mapped to a field element with SetBytes. Any
// transcript error is returned together with the zero element.
func deriveGamma(point fr.Element, digests []Digest, claimedValues []fr.Element, hf hash.Hash, dataTranscript ...[]byte) (fr.Element, error) {
	const label = "gamma"
	transcript := fiatshamir.NewTranscript(hf, label)
	bind := func(b []byte) error {
		return transcript.Bind(label, b)
	}

	if err := bind(point.Marshal()); err != nil {
		return fr.Element{}, err
	}
	for i := range digests {
		if err := bind(digests[i].Marshal()); err != nil {
			return fr.Element{}, err
		}
	}
	for i := range claimedValues {
		if err := bind(claimedValues[i].Marshal()); err != nil {
			return fr.Element{}, err
		}
	}
	for _, extra := range dataTranscript {
		if err := bind(extra); err != nil {
			return fr.Element{}, err
		}
	}

	challenge, err := transcript.ComputeChallenge(label)
	if err != nil {
		return fr.Element{}, err
	}
	var gamma fr.Element
	gamma.SetBytes(challenge)
	return gamma, nil
}

// dividePolyByXminusA computes (f-f(a))/(x-a), in canonical basis, in regular form.
// fa is expected to be f(a).
//
// The division is synthetic (Ruffini) and runs in place. f's memory is reused for
// the result, which is returned as the sub-slice f[1:] (degree deg(f)-1). After
// the call, f[0] holds the remainder f(a) - fa. f must not be empty.
func dividePolyByXminusA(f []fr.Element, fa, a fr.Element) []fr.Element {
	// f ← f - f(a)
	f[0].Sub(&f[0], &fa)

	// fold each coefficient, scaled by a, into the next lower one
	var carry fr.Element
	for k := len(f) - 1; k > 0; k-- {
		carry.Mul(&f[k], &a)
		f[k-1].Add(&f[k-1], &carry)
	}

	return f[1:]
}
