/*
 * Decompiled with CFR 0.152.
 */
package org.bouncycastle.crypto.kems;

import java.math.BigInteger;
import org.bouncycastle.crypto.Digest;
import org.bouncycastle.crypto.EncapsulatedSecretExtractor;
import org.bouncycastle.crypto.kems.SAKKEKEMSGenerator;
import org.bouncycastle.crypto.params.SAKKEPrivateKeyParameters;
import org.bouncycastle.crypto.params.SAKKEPublicKeyParameters;
import org.bouncycastle.math.ec.ECCurve;
import org.bouncycastle.math.ec.ECPoint;
import org.bouncycastle.util.Arrays;
import org.bouncycastle.util.BigIntegers;

/**
 * Receiver-side (decapsulation) half of the SAKKE key encapsulation mechanism
 * (RFC 6508, Section 6.2.2).
 *
 * <p>On construction the receiver key {@code K_(b,S) = [(b + a)^-1 mod q]P} is derived from the
 * identifier {@code b} and the master secret {@code a}. {@link #extractSecret(byte[])} then
 * recovers the shared secret value (SSV) from an encapsulation {@code R_(b,S) || H} and
 * re-derives {@code R_(b,S)} to check that the encapsulation is consistent.
 *
 * <p>The encapsulation layout is fixed: a 257-byte point encoding followed by a 16-byte
 * {@code H}. NOTE(review): these sizes appear to correspond to an uncompressed point over a
 * 1024-bit prime and {@code n = 128} (the RFC 6509 parameter set); confirm no other parameter
 * set is expected before relying on them.
 */
public class SAKKEKEMExtractor
implements EncapsulatedSecretExtractor {
    /** Length in bytes of the encoded point {@code R_(b,S)} at the start of an encapsulation. */
    private static final int R_ENCODING_LENGTH = 257;
    /** Length in bytes of the masked value {@code H} that follows {@code R_(b,S)}. */
    private static final int H_LENGTH = 16;
    /** Total encapsulation length: {@code R_(b,S) || H}. */
    private static final int ENCAPSULATION_LENGTH = R_ENCODING_LENGTH + H_LENGTH;

    private static final BigInteger TWO = BigInteger.valueOf(2L);
    private static final BigInteger THREE = BigInteger.valueOf(3L);

    private final ECCurve curve;
    /** Field prime. */
    private final BigInteger p;
    /** Subgroup order used for scalar arithmetic and the Miller loop. */
    private final BigInteger q;
    /** Base point P from the public parameters. */
    private final ECPoint P;
    /** KMS public key Z_S from the public parameters. */
    private final ECPoint Z_S;
    /** Receiver secret key K_(b,S), normalized to affine form. */
    private final ECPoint K_bs;
    /** Security parameter n (bit length of the SSV). */
    private final int n;
    /** Receiver identifier b. */
    private final BigInteger identifier;
    private final Digest digest;

    /**
     * Creates an extractor for the receiver described by {@code privateKey}.
     *
     * @param privateKey master secret together with the public SAKKE parameters
     */
    public SAKKEKEMExtractor(SAKKEPrivateKeyParameters privateKey) {
        SAKKEPublicKeyParameters publicKey = privateKey.getPublicParams();
        this.curve = publicKey.getCurve();
        this.q = publicKey.getQ();
        this.P = publicKey.getPoint();
        this.p = publicKey.getPrime();
        this.Z_S = publicKey.getZ();
        this.identifier = publicKey.getIdentifier();
        // K_(b,S) = [(b + a)^-1 mod q]P
        this.K_bs = this.P.multiply(this.identifier.add(privateKey.getMasterSecret()).modInverse(this.q)).normalize();
        this.n = publicKey.getN();
        this.digest = publicKey.getDigest();
    }

    /**
     * Recovers the shared secret value from an encapsulation.
     *
     * @param encapsulation {@code R_(b,S) || H}; only the first {@link #getEncapsulationLength()}
     *                      bytes are read
     * @return the SSV as an {@code n / 8}-byte unsigned big-endian array
     * @throws IllegalArgumentException if the encapsulation is shorter than
     *                                  {@link #getEncapsulationLength()}, or if {@code R_(b,S)}
     *                                  cannot be decoded or paired
     * @throws IllegalStateException    if the re-derived {@code R_(b,S)} does not match
     */
    @Override
    public byte[] extractSecret(byte[] encapsulation) {
        if (encapsulation.length < ENCAPSULATION_LENGTH) {
            throw new IllegalArgumentException("encapsulation too short: expected "
                + ENCAPSULATION_LENGTH + " bytes, got " + encapsulation.length);
        }

        // Step 1: parse R_(b,S) || H.
        ECPoint R_bS = this.curve.decodePoint(Arrays.copyOfRange(encapsulation, 0, R_ENCODING_LENGTH));
        BigInteger H = BigIntegers.fromUnsignedByteArray(encapsulation, R_ENCODING_LENGTH, H_LENGTH);

        // Step 2: w = <R_(b,S), K_(b,S)>.
        BigInteger w = computePairing(R_bS, this.K_bs, this.p, this.q);

        // Step 3: SSV = H XOR HashToIntegerRange(w, 2^n, Hash).
        // NOTE(review): BigInteger.toByteArray() yields a minimal two's-complement encoding
        // (possibly with a leading 0x00), not a fixed-length octet string. It must stay
        // consistent with SAKKEKEMSGenerator; confirm against the RFC test vectors before changing.
        BigInteger twoToN = BigInteger.ONE.shiftLeft(this.n);
        BigInteger mask = SAKKEKEMSGenerator.hashToIntegerRange(w.toByteArray(), twoToN, this.digest);
        BigInteger ssv = H.xor(mask).mod(this.p);

        // Step 4: r = HashToIntegerRange(SSV || b, q, Hash).
        BigInteger b = this.identifier;
        BigInteger r = SAKKEKEMSGenerator.hashToIntegerRange(
            Arrays.concatenate(ssv.toByteArray(), b.toByteArray()), this.q, this.digest);

        // Step 5: TEST = [r]([b]P + Z_S) must equal R_(b,S).
        ECPoint bP = this.P.multiply(b).normalize();
        ECPoint test = bP.add(this.Z_S).multiply(r).normalize();
        if (!R_bS.equals(test)) {
            throw new IllegalStateException("Validation of R_bS failed");
        }
        return BigIntegers.asUnsignedByteArray(this.n / 8, ssv);
    }

    /** @return the fixed encapsulation length, {@code R_(b,S) || H}, in bytes */
    @Override
    public int getEncapsulationLength() {
        return ENCAPSULATION_LENGTH;
    }

    /**
     * Computes the pairing {@code <R, Q>} with Miller's algorithm over F_p^2 = F_p[i] (i^2 = -1),
     * following RFC 6508, Section 3.2.
     *
     * <p>The tangent gradient {@code 3(x^2 - 1) / 2y} assumes a curve y^2 = x^3 - 3x + B. The
     * accumulated value {@code v} is raised to {@code c = (p + 1) / q}, and the result
     * {@code t = v^c} is returned as {@code t_imag / t_real mod p}.
     *
     * @param R first argument; expected to be a point of order q
     * @param Q second argument
     * @param p field prime
     * @param q subgroup order; must divide p + 1
     * @return the pairing value as an element of F_p
     * @throws IllegalArgumentException if q does not divide p + 1, if R or Q is the point at
     *                                  infinity, or if the Miller loop reaches a degenerate step
     *                                  (which cannot happen when R has order q)
     */
    static BigInteger computePairing(ECPoint R, ECPoint Q, BigInteger p, BigInteger q) {
        if (R.isInfinity() || Q.isInfinity()) {
            throw new IllegalArgumentException("pairing inputs must not be the point at infinity");
        }
        BigInteger[] cofactorAndRemainder = p.add(BigInteger.ONE).divideAndRemainder(q);
        if (cofactorAndRemainder[1].signum() != 0) {
            throw new IllegalArgumentException("q does not divide p + 1");
        }
        BigInteger c = cofactorAndRemainder[0];

        BigInteger[] v = new BigInteger[]{BigInteger.ONE, BigInteger.ZERO};
        ECPoint C = R;
        BigInteger qMinusOne = q.subtract(BigInteger.ONE);
        BigInteger Qx = Q.getAffineXCoord().toBigInteger();
        BigInteger Qy = Q.getAffineYCoord().toBigInteger();
        BigInteger Rx = R.getAffineXCoord().toBigInteger();
        BigInteger Ry = R.getAffineYCoord().toBigInteger();

        // Walk the bits of q - 1 from the second most significant down to bit 0.
        for (int i = qMinusOne.bitLength() - 2; i >= 0; --i) {
            BigInteger Cx = C.getAffineXCoord().toBigInteger();
            BigInteger Cy = C.getAffineYCoord().toBigInteger();
            if (Cy.signum() == 0) {
                // C has order 2, so R cannot have odd prime order q; 2*C_y would not be invertible.
                throw new IllegalArgumentException("R is not a point of order q");
            }
            // Doubling step: gradient of the tangent at C.
            BigInteger l = THREE.multiply(Cx.multiply(Cx).subtract(BigInteger.ONE))
                .multiply(Cy.multiply(TWO).modInverse(p)).mod(p);
            v = fp2PointSquare(v[0], v[1], p);
            v = fp2Multiply(v[0], v[1], l.multiply(Qx.add(Cx)).subtract(Cy), Qy, p);
            C = C.twice().normalize();

            if (qMinusOne.testBit(i)) {
                Cx = C.getAffineXCoord().toBigInteger();
                Cy = C.getAffineYCoord().toBigInteger();
                if (Cx.equals(Rx)) {
                    // C = +/-R at this step is impossible for a point of order q; C_x - R_x would be 0.
                    throw new IllegalArgumentException("R is not a point of order q");
                }
                // Addition step: gradient of the chord through C and R.
                l = Cy.subtract(Ry).multiply(Cx.subtract(Rx).modInverse(p)).mod(p);
                v = fp2Multiply(v[0], v[1], l.multiply(Qx.add(Cx)).subtract(Cy), Qy, p);
                C = C.add(R).normalize();
            }
        }

        v = fp2Pow(v, c, p);
        return v[1].multiply(v[0].modInverse(p)).mod(p);
    }

    /**
     * Raises an F_p^2 element to a positive power using left-to-right square-and-multiply.
     * For e = 4 this performs exactly two squarings.
     */
    private static BigInteger[] fp2Pow(BigInteger[] base, BigInteger e, BigInteger p) {
        BigInteger[] result = base;
        for (int i = e.bitLength() - 2; i >= 0; --i) {
            result = fp2PointSquare(result[0], result[1], p);
            if (e.testBit(i)) {
                result = fp2Multiply(result[0], result[1], base[0], base[1], p);
            }
        }
        return result;
    }

    /**
     * Multiplies (x_real + x_imag*i) by (y_real + y_imag*i) in F_p[i] with i^2 = -1.
     *
     * @return {real, imag}, each reduced into [0, p)
     */
    static BigInteger[] fp2Multiply(BigInteger x_real, BigInteger x_imag, BigInteger y_real, BigInteger y_imag, BigInteger p) {
        return new BigInteger[]{x_real.multiply(y_real).subtract(x_imag.multiply(y_imag)).mod(p), x_real.multiply(y_imag).add(x_imag.multiply(y_real)).mod(p)};
    }

    /**
     * Squares (currentX + currentY*i) in F_p[i] with i^2 = -1, using
     * (x + y)(x - y) for the real part and 2xy for the imaginary part.
     *
     * @return {real, imag}, each reduced into [0, p)
     */
    static BigInteger[] fp2PointSquare(BigInteger currentX, BigInteger currentY, BigInteger p) {
        BigInteger xPlusY = currentX.add(currentY).mod(p);
        BigInteger xMinusY = currentX.subtract(currentY).mod(p);
        BigInteger newX = xPlusY.multiply(xMinusY).mod(p);
        BigInteger newY = currentX.multiply(currentY).multiply(BigInteger.valueOf(2L)).mod(p);
        return new BigInteger[]{newX, newY};
    }
}

