package cat.gencat.ctti.canigo.arch.security.saml.validation;

import cat.gencat.ctti.canigo.arch.security.saml.validation.util.SAMLAttributes;
import cat.gencat.ctti.canigo.arch.security.saml.validation.util.SAMLObjectPrint;
import java.io.InputStream;
import java.nio.charset.StandardCharsets;
import java.security.Provider;
import java.security.Security;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.TreeSet;
import org.apache.commons.io.IOUtils;
import org.bouncycastle.jce.provider.BouncyCastleProvider;
import org.bouncycastle.util.encoders.Base64;
import org.joda.time.DateTime;
import org.opensaml.Configuration;
import org.opensaml.DefaultBootstrap;
import org.opensaml.common.SAMLObject;
import org.opensaml.saml2.core.Assertion;
import org.opensaml.saml2.core.Attribute;
import org.opensaml.saml2.core.AttributeStatement;
import org.opensaml.saml2.core.Audience;
import org.opensaml.saml2.core.AudienceRestriction;
import org.opensaml.saml2.core.Conditions;
import org.opensaml.saml2.core.EncryptedAssertion;
import org.opensaml.saml2.core.Response;
import org.opensaml.saml2.encryption.Decrypter;
import org.opensaml.security.SAMLSignatureProfileValidator;
import org.opensaml.xml.encryption.DecryptionException;
import org.opensaml.xml.encryption.InlineEncryptedKeyResolver;
import org.opensaml.xml.parse.BasicParserPool;
import org.opensaml.xml.schema.XSAny;
import org.opensaml.xml.schema.XSString;
import org.opensaml.xml.security.credential.Credential;
import org.opensaml.xml.security.keyinfo.KeyInfoCredentialResolver;
import org.opensaml.xml.security.keyinfo.StaticKeyInfoCredentialResolver;
import org.opensaml.xml.signature.Signature;
import org.opensaml.xml.signature.SignatureValidator;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.w3c.dom.Element;

/* loaded from: input_file:cat/gencat/ctti/canigo/arch/security/saml/validation/SAMLValidator.class */
/**
 * Parses, decrypts and validates SAML 2.0 responses and assertions with OpenSAML 2.
 *
 * <p>Call {@link #init()} once before using an instance: it registers the BouncyCastle
 * security provider and bootstraps the OpenSAML library.
 *
 * <p>NOTE(review): {@link #validate(Response, Credential, Credential, String)} records its
 * outcome in an instance field ({@code validatedAssertion}), so sharing one instance between
 * concurrent callers looks unsafe - confirm how instances are scoped.
 */
public class SAMLValidator {
    private static final Logger logger = LoggerFactory.getLogger(SAMLValidator.class);

    /** Minutes added to the assertion's NotOnOrAfter before the expiry check; {@code null} adds nothing. */
    private Integer extraValidityMinutes;

    /** First assertion accepted by the last response validation, or {@code null} if none was accepted. */
    private Assertion validatedAssertion;

    /**
     * Prepares the runtime: registers the BouncyCastle provider, logs the available
     * signature algorithms and bootstraps OpenSAML.
     */
    public static void init() {
        addRequiredProviders();
        printAlgorithms();
        initOpenSAML();
    }

    /**
     * @return the grace period, in minutes, added to NotOnOrAfter; may be {@code null}
     */
    public Integer getExtraValidityMinutes() {
        return this.extraValidityMinutes;
    }

    /**
     * @param num grace period, in minutes, added to NotOnOrAfter; {@code null} disables it
     */
    public void setExtraValidityMinutes(Integer num) {
        this.extraValidityMinutes = num;
    }

    /**
     * @return the assertion accepted by the last call to
     *         {@link #validate(Response, Credential, Credential, String)}, or {@code null}
     */
    public Assertion getValidatedAssertion() {
        return this.validatedAssertion;
    }

    /**
     * @param assertion the assertion to expose through {@link #getValidatedAssertion()}
     */
    public void setValidatedAssertion(Assertion assertion) {
        this.validatedAssertion = assertion;
    }

    /** Registers BouncyCastle as a JCA security provider. */
    private static void addRequiredProviders() {
        Security.addProvider(new BouncyCastleProvider());
    }

    /** Logs, at INFO level, the sorted names of every "Signature" algorithm offered by the installed providers. */
    private static void printAlgorithms() {
        TreeSet<String> algorithms = new TreeSet<>();
        for (Provider provider : Security.getProviders()) {
            for (Provider.Service service : provider.getServices()) {
                if ("Signature".equals(service.getType())) {
                    algorithms.add(service.getAlgorithm());
                }
            }
        }
        StringBuilder sb = new StringBuilder("Available signature algorithms:");
        for (String algorithm : algorithms) {
            sb.append('\n').append(algorithm);
        }
        logger.info("algorithms: {}", sb);
    }

    /** Bootstraps OpenSAML; a bootstrap failure is logged and not propagated. */
    private static void initOpenSAML() {
        try {
            DefaultBootstrap.bootstrap();
        } catch (Exception e) {
            logger.error("Error: ", e);
        }
    }

    /**
     * Decodes a Base64 string (UTF-8 XML underneath) into a SAML {@link Response}.
     *
     * @param str Base64-encoded SAML response XML
     * @return the unmarshalled response, or {@code null} if parsing or unmarshalling failed
     * @throws ClassCastException if the XML root is a different SAML object type
     */
    public Response getSamlResponseBase64(String str) {
        return getSamlResponse(new String(Base64.decode(str), StandardCharsets.UTF_8));
    }

    /**
     * Unmarshals SAML response XML.
     *
     * @param str SAML response XML
     * @return the unmarshalled response, or {@code null} if parsing or unmarshalling failed
     * @throws ClassCastException if the XML root is a different SAML object type
     */
    public Response getSamlResponse(String str) {
        logger.debug("samlResponse {}", str);
        return (Response) decodeMessage(IOUtils.toInputStream(str, StandardCharsets.UTF_8));
    }

    /**
     * Returns the first plain (non-encrypted) assertion of a response.
     *
     * @param response the response to inspect; may be {@code null}
     * @return the first assertion, or {@code null} (after logging an error) when there is none
     */
    public Assertion getAssertion(Response response) {
        if (response == null || response.getAssertions() == null || response.getAssertions().isEmpty()) {
            logger.error("Response contains no assertions");
            return null;
        }
        return response.getAssertions().get(0);
    }

    /**
     * Decodes a Base64 string (UTF-8 XML underneath) into a SAML {@link Assertion}.
     *
     * @param str Base64-encoded assertion XML
     * @return the unmarshalled assertion, or {@code null} if parsing or unmarshalling failed
     * @throws ClassCastException if the XML root is a different SAML object type
     */
    public Assertion getSamlAssertionBase64(String str) {
        return getSamlAssertion(new String(Base64.decode(str), StandardCharsets.UTF_8));
    }

    /**
     * Unmarshals assertion XML.
     *
     * @param str assertion XML
     * @return the unmarshalled assertion, or {@code null} if parsing or unmarshalling failed
     * @throws ClassCastException if the XML root is a different SAML object type
     */
    public Assertion getSamlAssertion(String str) {
        logger.debug("samlAssertion {}", str);
        return (Assertion) decodeMessage(IOUtils.toInputStream(str, StandardCharsets.UTF_8));
    }

    /**
     * Decodes a Base64 string (UTF-8 XML underneath) into an XML {@link Signature}.
     *
     * @param str Base64-encoded signature XML
     * @return the unmarshalled signature, or {@code null} if parsing or unmarshalling failed
     * @throws ClassCastException if the XML root is a different object type
     */
    public Signature getSamlSignatureBase64(String str) {
        return getSamlSignature(new String(Base64.decode(str), StandardCharsets.UTF_8));
    }

    /**
     * Unmarshals signature XML.
     *
     * @param str signature XML
     * @return the unmarshalled signature, or {@code null} if parsing or unmarshalling failed
     * @throws ClassCastException if the XML root is a different object type
     */
    public Signature getSamlSignature(String str) {
        logger.debug("samlSignature {}", str);
        return (Signature) decodeMessage(IOUtils.toInputStream(str, StandardCharsets.UTF_8));
    }

    /**
     * Decodes a Base64 string (UTF-8 XML underneath) into an {@link EncryptedAssertion}.
     *
     * @param str Base64-encoded encrypted-assertion XML
     * @return the unmarshalled object, or {@code null} if parsing or unmarshalling failed
     * @throws ClassCastException if the XML root is a different SAML object type
     */
    public EncryptedAssertion getSamlEncryptedAssertionBase64(String str) {
        return getSamlEncryptedAssertion(new String(Base64.decode(str), StandardCharsets.UTF_8));
    }

    /**
     * Unmarshals encrypted-assertion XML.
     *
     * @param str encrypted-assertion XML
     * @return the unmarshalled object, or {@code null} if parsing or unmarshalling failed
     * @throws ClassCastException if the XML root is a different SAML object type
     */
    public EncryptedAssertion getSamlEncryptedAssertion(String str) {
        logger.debug("samlEncryptedAssertion {}", str);
        return (EncryptedAssertion) decodeMessage(IOUtils.toInputStream(str, StandardCharsets.UTF_8));
    }

    /**
     * Validates the assertions of a response (plain ones first, then decrypted ones) and stops
     * at the first one that passes. The accepted assertion, if any, is afterwards available
     * through {@link #getValidatedAssertion()}.
     *
     * @param response    the SAML response to validate
     * @param credential  credential used to decrypt encrypted assertions
     * @param credential2 credential used to verify assertion signatures
     * @param str         entity id the assertion audience must equal
     * @return the result for the accepted assertion; otherwise the result of the last assertion
     *         tried, or a failed result describing an exception raised while collecting them
     * @throws SAMLValidatorException if the response holds no assertion at all
     */
    public SAMLValidatorResult validate(Response response, Credential credential, Credential credential2, String str) {
        SAMLValidatorResult result = null;
        setValidatedAssertion(null);
        try {
            for (Assertion candidate : getResponseAssertions(response, credential)) {
                result = validate(candidate, credential2, str);
                if (result.isOk()) {
                    setValidatedAssertion(candidate);
                    break;
                }
            }
        } catch (Exception e) {
            // Any earlier per-assertion result is discarded in favour of the failure.
            result = failed(null, e);
        }
        if (result == null) {
            throw new SAMLValidatorException("No assertions found in SAML response");
        }
        return result;
    }

    /**
     * Validates one assertion: extracts its data, verifies its signature when one is present,
     * then checks audience and validity window. Never throws; failures are reported in the result.
     *
     * @param assertion  the assertion to validate
     * @param credential credential used to verify the assertion signature
     * @param str        entity id the assertion audience must equal
     * @return the extracted data with {@code ok} set, plus an error message on failure
     */
    public SAMLValidatorResult validate(Assertion assertion, Credential credential, String str) {
        SAMLValidatorResult result = null;
        try {
            result = getAssertionData(assertion);
            Signature signature = assertion.getSignature();
            if (signature != null) {
                new SAMLSignatureProfileValidator().validate(signature);
                new SignatureValidator(credential).validate(signature);
            } else {
                // NOTE(review): an unsigned assertion is still accepted here. That is only safe
                // if the enclosing response signature is verified elsewhere - confirm.
                logger.warn("Assertion is not signed; signature validation skipped");
            }
            checkConditions(result, str);
            result.setOk(true);
        } catch (Exception e) {
            result = failed(result, e);
        }
        return result;
    }

    /**
     * Marks a result as failed with the exception's message and logs the exception.
     *
     * @param partial result built so far, or {@code null} to start from an empty one
     * @param cause   the exception that aborted validation
     * @return the failed result
     */
    private static SAMLValidatorResult failed(SAMLValidatorResult partial, Exception cause) {
        SAMLValidatorResult result = partial != null ? partial : new SAMLValidatorResult();
        result.setOk(false);
        result.setErrMsg(cause.getMessage());
        logger.error(cause.getMessage(), cause);
        return result;
    }

    /**
     * Collects the plain assertions of a response followed by its decrypted encrypted assertions.
     *
     * @param response   the response to read
     * @param credential credential used for decryption
     * @return all assertions, plain ones first
     * @throws DecryptionException if an encrypted assertion cannot be decrypted
     */
    private List<Assertion> getResponseAssertions(Response response, Credential credential) throws DecryptionException {
        List<Assertion> assertions = new ArrayList<>();
        for (Assertion assertion : response.getAssertions()) {
            // Guarded because rendering the assertion is only needed for the debug log.
            if (logger.isDebugEnabled()) {
                logger.debug(SAMLObjectPrint.assertion(assertion));
            }
            assertions.add(assertion);
        }
        for (EncryptedAssertion encryptedAssertion : response.getEncryptedAssertions()) {
            if (logger.isDebugEnabled()) {
                logger.debug(encryptedAssertion.toString());
            }
            Assertion decrypted = decryptAssertion(encryptedAssertion, credential);
            if (logger.isDebugEnabled()) {
                logger.debug(SAMLObjectPrint.assertion(decrypted));
            }
            assertions.add(decrypted);
        }
        return assertions;
    }

    /**
     * Decrypts an encrypted assertion. The credential is used as key-encryption key and the
     * encrypted data key is looked up inline in the encrypted element.
     *
     * @param encryptedAssertion the encrypted assertion
     * @param credential         credential holding the decryption key
     * @return the decrypted assertion, rooted in a new document
     * @throws DecryptionException if decryption fails
     */
    public Assertion decryptAssertion(EncryptedAssertion encryptedAssertion, Credential credential) throws DecryptionException {
        List<Credential> credentials = new ArrayList<>();
        credentials.add(credential);
        Decrypter decrypter = new Decrypter((KeyInfoCredentialResolver) null, new StaticKeyInfoCredentialResolver(credentials), new InlineEncryptedKeyResolver());
        decrypter.setRootInNewDocument(true);
        return decrypter.decrypt(encryptedAssertion);
    }

    /**
     * Checks the audience and the NotBefore / NotOnOrAfter window of extracted assertion data.
     * A bound that is absent from the result is not checked.
     *
     * @param sAMLValidatorResult data extracted from the assertion
     * @param str                 entity id the audience must equal
     * @throws SAMLValidatorException if the audience differs or the current time is outside the window
     */
    private void checkConditions(SAMLValidatorResult sAMLValidatorResult, String str) throws SAMLValidatorException {
        // A null intended audience can never match; report it like any other mismatch.
        if (str == null || !str.equals(sAMLValidatorResult.getAudience())) {
            throw new SAMLValidatorException(String.format("Audience %s of the assertion is not the intended %s entityId.", sAMLValidatorResult.getAudience(), str));
        }
        DateTime now = new DateTime();
        DateTime notOnOrAfter = sAMLValidatorResult.getNotOnOrAfter();
        if (notOnOrAfter != null) {
            if (this.extraValidityMinutes != null) {
                notOnOrAfter = notOnOrAfter.plusMinutes(this.extraValidityMinutes);
            }
            // isBefore compares instants only; DateTime.equals also compares chronology/zone and
            // would miss the exact expiry instant when the two values use different time zones.
            if (!now.isBefore(notOnOrAfter)) {
                throw new SAMLValidatorException(String.format("Assertion is not valid on or after %s (%d), now is %s (%d), nooa-now=%d", notOnOrAfter.toString(), notOnOrAfter.getMillis(), now.toString(), now.getMillis(), notOnOrAfter.getMillis() - now.getMillis()));
            }
        }
        DateTime notBefore = sAMLValidatorResult.getNotBefore();
        if (notBefore != null && now.isBefore(notBefore)) {
            throw new SAMLValidatorException(String.format("Assertion is not valid before %s (%d), now is %s (%d), now-nbf=%d", notBefore.toString(), notBefore.getMillis(), now.toString(), now.getMillis(), now.getMillis() - notBefore.getMillis()));
        }
    }

    /**
     * Parses XML from a stream and unmarshals its root element into an OpenSAML object.
     * The stream is not closed here.
     *
     * @param inputStream XML input
     * @return the unmarshalled object, or {@code null} (after logging) if anything fails
     */
    protected SAMLObject decodeMessage(InputStream inputStream) {
        try {
            // NOTE(review): the parser pool is used with its default settings on input that may
            // be untrusted; confirm the OpenSAML version in use disables DTDs / external entities.
            Element root = new BasicParserPool().parse(inputStream).getDocumentElement();
            return (SAMLObject) Configuration.getUnmarshallerFactory().getUnmarshaller(root).unmarshall(root);
        } catch (Exception e) {
            logger.error(e.getMessage(), e);
            return null;
        }
    }

    /**
     * Copies subject, issuer, first audience, validity window and attributes of an assertion
     * into a new result. Performs no validation.
     *
     * <p>Only the first audience of the first audience restriction is recorded.
     *
     * @param assertion the assertion to read
     * @return a result populated with the assertion's data
     * @throws NullPointerException      if the assertion lacks a subject, name id, issuer or conditions
     * @throws IndexOutOfBoundsException if the assertion has no audience restriction or audience
     */
    public SAMLValidatorResult getAssertionData(Assertion assertion) {
        SAMLValidatorResult result = new SAMLValidatorResult();
        result.setAttributes(new HashMap<>());
        result.setSubject(assertion.getSubject().getNameID().getValue());
        result.setSubjectFormat(assertion.getSubject().getNameID().getFormat());
        result.setIssuer(assertion.getIssuer().getValue());
        result.setAudience(assertion.getConditions().getAudienceRestrictions().get(0).getAudiences().get(0).getAudienceURI());
        Conditions conditions = assertion.getConditions();
        result.setNotBefore(conditions.getNotBefore());
        result.setNotOnOrAfter(conditions.getNotOnOrAfter());
        for (org.opensaml.saml2.core.Statement statement : assertion.getStatements()) {
            if (statement instanceof AttributeStatement) {
                for (Attribute attribute : ((AttributeStatement) statement).getAttributes()) {
                    result.getAttributes().put(SAMLAttributes.getAttributeIdentifier(attribute), getAttributeValues(attribute));
                }
            }
        }
        return result;
    }

    /**
     * Extracts the textual values of an attribute. Values that are neither {@link XSAny} nor
     * {@link XSString} are skipped.
     *
     * @param attribute the attribute to read
     * @return the values in document order
     */
    private List<String> getAttributeValues(Attribute attribute) {
        List<String> values = new ArrayList<>();
        for (org.opensaml.xml.XMLObject value : attribute.getAttributeValues()) {
            if (value instanceof XSAny) {
                values.add(((XSAny) value).getTextContent());
            } else if (value instanceof XSString) {
                values.add(((XSString) value).getValue());
            }
        }
        return values;
    }
}
