/*
 * Decompiled with CFR 0.152.
 */
package de.rub.nds.tlsattacker.attacks.pkcs1;

import de.rub.nds.modifiablevariable.util.ArrayConverter;
import de.rub.nds.tlsattacker.attacks.config.BleichenbacherCommandConfig;
import de.rub.nds.tlsattacker.attacks.pkcs1.Pkcs1Vector;
import de.rub.nds.tlsattacker.core.constants.ProtocolVersion;
import de.rub.nds.tlsattacker.core.exceptions.ConfigurationException;
import java.security.InvalidKeyException;
import java.security.NoSuchAlgorithmException;
import java.security.interfaces.RSAPublicKey;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.LinkedList;
import java.util.List;
import javax.crypto.BadPaddingException;
import javax.crypto.Cipher;
import javax.crypto.IllegalBlockSizeException;
import javax.crypto.NoSuchPaddingException;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

/**
 * Builds PKCS#1 v1.5 encryption blocks (one correctly formatted, the rest deliberately malformed)
 * around a 48-byte TLS premaster secret (PMS), for Bleichenbacher-style padding-oracle checks.
 *
 * <p>Every generated block has the layout {@code 0x00 0x02 <padding> 0x00 <PMS>} before the
 * individual vector applies its modification. All members are static; the class cannot be
 * instantiated.
 */
public class Pkcs1VectorGenerator {
    private static final Logger LOGGER = LogManager.getLogger();

    /** Length of the premaster secret placed at the end of each block ("|PMS|" in vector names). */
    private static final int PMS_LENGTH = 48;

    /** Filler value (0x2A) for the padding string and for every PMS byte except the version. */
    private static final byte FILLER_BYTE = 0x2A;

    /** Value (0x17) that replaces the 0x00 or 0x02 header byte in the wrong-header vectors. */
    private static final byte WRONG_HEADER_BYTE = 0x17;

    /** Value (0x42) written over both protocol-version bytes in the wrong-version vector. */
    private static final byte WRONG_VERSION_BYTE = 0x42;

    /** Raw RSA transformation: the blocks already contain their own (possibly broken) padding. */
    private static final String RSA_TRANSFORMATION = "RSA/NONE/NoPadding";

    /**
     * Generates all plain vectors for the key size and encrypts each one with the raw RSA
     * public-key operation.
     *
     * @param publicKey RSA public key whose modulus size determines the block length and which
     *     is used to encrypt every vector
     * @param type {@code FULL} additionally adds one vector per possible 0x00 position
     * @param protocolVersion version written into the first two PMS bytes
     * @return the generated vectors, each with its plain and encrypted value set
     * @throws ConfigurationException if the cipher cannot be created or a vector cannot be
     *     encrypted
     */
    public static List<Pkcs1Vector> generatePkcs1Vectors(RSAPublicKey publicKey, BleichenbacherCommandConfig.Type type, ProtocolVersion protocolVersion) {
        List<Pkcs1Vector> vectors = generatePlainPkcs1Vectors(publicKey.getModulus().bitLength(), type, protocolVersion);
        try {
            Cipher rsa = createEncryptionCipher(publicKey);
            for (Pkcs1Vector vector : vectors) {
                vector.setEncryptedValue(rsa.doFinal(vector.getPlainValue()));
            }
            return vectors;
        } catch (InvalidKeyException | NoSuchAlgorithmException | BadPaddingException | IllegalBlockSizeException | NoSuchPaddingException ex) {
            throw new ConfigurationException("The different PKCS#1 attack vectors could not be generated.", ex);
        }
    }

    /**
     * Generates and encrypts only the correctly formatted PKCS#1 PMS vector.
     *
     * @param publicKey RSA public key used for sizing and encryption
     * @param protocolVersion version written into the first two PMS bytes
     * @return the correct vector with its plain and encrypted value set
     * @throws ConfigurationException if the cipher cannot be created or the vector cannot be
     *     encrypted
     */
    public static Pkcs1Vector generateCorrectPkcs1Vector(RSAPublicKey publicKey, ProtocolVersion protocolVersion) {
        Pkcs1Vector vector = getPlainCorrect(publicKey.getModulus().bitLength(), protocolVersion);
        try {
            Cipher rsa = createEncryptionCipher(publicKey);
            vector.setEncryptedValue(rsa.doFinal(vector.getPlainValue()));
            return vector;
        } catch (InvalidKeyException | NoSuchAlgorithmException | BadPaddingException | IllegalBlockSizeException | NoSuchPaddingException ex) {
            throw new ConfigurationException("The PKCS#1 attack vectors could not be generated.", ex);
        }
    }

    /**
     * Generates the unencrypted vectors for a modulus of the given bit length.
     *
     * @param publicKeyBitLength bit length of the RSA modulus
     * @param type {@code FULL} additionally adds one vector per possible 0x00 position
     * @param protocolVersion version written into the first two PMS bytes
     * @return a new mutable list of vectors with only the plain value set
     */
    public static List<Pkcs1Vector> generatePlainPkcs1Vectors(int publicKeyBitLength, BleichenbacherCommandConfig.Type type, ProtocolVersion protocolVersion) {
        byte[] keyBytes = createPremasterSecret(protocolVersion);
        int publicKeyByteLength = toByteLength(publicKeyBitLength);
        List<Pkcs1Vector> pkcs1Vectors = new ArrayList<>();
        pkcs1Vectors.add(new Pkcs1Vector("Correctly formatted PKCS#1 PMS message", getPaddedKey(publicKeyByteLength, keyBytes)));
        pkcs1Vectors.add(new Pkcs1Vector("Wrong first byte (0x00 set to 0x17)", getEK_WrongFirstByte(publicKeyByteLength, keyBytes)));
        pkcs1Vectors.add(new Pkcs1Vector("Wrong second byte (0x02 set to 0x17)", getEK_WrongSecondByte(publicKeyByteLength, keyBytes)));
        pkcs1Vectors.add(new Pkcs1Vector("Invalid TLS version in PMS", getEK_WrongTlsVersion(publicKeyByteLength, keyBytes)));
        pkcs1Vectors.add(new Pkcs1Vector("Correctly formatted PKCS#1 PMS message, but 1 byte shorter", getPaddedKey(publicKeyByteLength - 1, keyBytes)));
        pkcs1Vectors.add(new Pkcs1Vector("No 0x00 in message", getEK_NoNullByte(publicKeyByteLength, keyBytes)));
        pkcs1Vectors.add(new Pkcs1Vector("0x00 in PKCS#1 padding (first 8 bytes after 0x00 0x02)", getEK_NullByteInPkcsPadding(publicKeyByteLength, keyBytes)));
        pkcs1Vectors.add(new Pkcs1Vector("0x00 in some padding byte", getEK_NullByteInPadding(publicKeyByteLength, keyBytes)));
        pkcs1Vectors.add(new Pkcs1Vector("0x00 on the last position  (|PMS| = 0)", getEK_SymmetricKeyOfSize(publicKeyByteLength, keyBytes, 0)));
        pkcs1Vectors.add(new Pkcs1Vector("0x00 on the next to last position (|PMS| = 1)", getEK_SymmetricKeyOfSize(publicKeyByteLength, keyBytes, 1)));
        // Arrays.copyOf truncates to 47 bytes, or appends one 0x00 byte for 49.
        pkcs1Vectors.add(new Pkcs1Vector("Correctly formatted PKCS#1 message, (|PMS| = 47)", getPaddedKey(publicKeyByteLength, Arrays.copyOf(keyBytes, PMS_LENGTH - 1))));
        pkcs1Vectors.add(new Pkcs1Vector("Correctly formatted PKCS#1 message, (|PMS| = 49)", getPaddedKey(publicKeyByteLength, Arrays.copyOf(keyBytes, PMS_LENGTH + 1))));
        if (type == BleichenbacherCommandConfig.Type.FULL) {
            pkcs1Vectors.addAll(getEK_DifferentPositionsOf0x00(publicKeyByteLength, keyBytes));
        }
        return pkcs1Vectors;
    }

    /** Builds only the unencrypted "Correctly formatted PKCS#1 PMS message" vector. */
    private static Pkcs1Vector getPlainCorrect(int publicKeyBitLength, ProtocolVersion protocolVersion) {
        byte[] keyBytes = createPremasterSecret(protocolVersion);
        return new Pkcs1Vector("Correctly formatted PKCS#1 PMS message", getPaddedKey(toByteLength(publicKeyBitLength), keyBytes));
    }

    /**
     * Creates an RSA cipher without padding, initialised for encryption with {@code publicKey}.
     *
     * <p>NOTE(review): the "NONE" mode name is not accepted by every JCA provider; presumably a
     * provider that supports it (e.g. BouncyCastle) is registered elsewhere — confirm.
     */
    private static Cipher createEncryptionCipher(RSAPublicKey publicKey) throws NoSuchAlgorithmException, NoSuchPaddingException, InvalidKeyException {
        Cipher rsa = Cipher.getInstance(RSA_TRANSFORMATION);
        rsa.init(Cipher.ENCRYPT_MODE, publicKey);
        return rsa;
    }

    /** Builds a fresh 48-byte PMS: protocol major/minor version followed by filler bytes. */
    private static byte[] createPremasterSecret(ProtocolVersion protocolVersion) {
        byte[] pms = new byte[PMS_LENGTH];
        Arrays.fill(pms, FILLER_BYTE);
        pms[0] = protocolVersion.getMajor();
        pms[1] = protocolVersion.getMinor();
        return pms;
    }

    /**
     * Converts the modulus bit length to the block length used for the vectors.
     *
     * <p>NOTE(review): this floors. For a modulus whose bit length is not a multiple of 8, the
     * blocks are therefore one byte shorter than the modulus' octet length. Changing this needs
     * care, because at the full octet length a 0x17 first byte may exceed such a modulus.
     */
    private static int toByteLength(int bitLength) {
        return bitLength / 8;
    }

    /**
     * Builds a correctly formatted block {@code 0x00 0x02 <filler> 0x00 <symmetricKey>} of
     * {@code rsaKeyLength} bytes. {@code symmetricKey} is copied, not modified.
     *
     * @throws IllegalArgumentException if the header, separator and key do not fit
     */
    private static byte[] getPaddedKey(int rsaKeyLength, byte[] symmetricKey) {
        // Without this guard the separator/key would silently overwrite the 0x00 0x02 header
        // (or index out of bounds with an unexplained exception).
        if (rsaKeyLength < symmetricKey.length + 3) {
            throw new IllegalArgumentException("A " + rsaKeyLength + "-byte PKCS#1 block cannot hold a " + symmetricKey.length + "-byte symmetric key");
        }
        byte[] key = new byte[rsaKeyLength];
        Arrays.fill(key, FILLER_BYTE);
        key[0] = 0x00;
        key[1] = 0x02;
        int keyOffset = rsaKeyLength - symmetricKey.length;
        key[keyOffset - 1] = 0x00;
        System.arraycopy(symmetricKey, 0, key, keyOffset, symmetricKey.length);
        LOGGER.debug("Generated a PKCS1 padded message with a {}-byte symmetric key: {}", symmetricKey.length, ArrayConverter.bytesToHexString(key));
        return key;
    }

    /** Correct block whose two protocol-version bytes (first two PMS bytes) are set to 0x42. */
    private static byte[] getEK_WrongTlsVersion(int rsaKeyLength, byte[] symmetricKey) {
        byte[] key = getPaddedKey(rsaKeyLength, symmetricKey);
        int versionOffset = rsaKeyLength - symmetricKey.length;
        key[versionOffset] = WRONG_VERSION_BYTE;
        key[versionOffset + 1] = WRONG_VERSION_BYTE;
        LOGGER.debug("Generated a PKCS1 padded message with a wrong TLS version bytes: {}", ArrayConverter.bytesToHexString(key));
        return key;
    }

    /** Correct block whose leading 0x00 is replaced by 0x17. */
    private static byte[] getEK_WrongFirstByte(int rsaKeyLength, byte[] symmetricKey) {
        byte[] key = getPaddedKey(rsaKeyLength, symmetricKey);
        key[0] = WRONG_HEADER_BYTE;
        LOGGER.debug("Generated a PKCS1 padded message with a wrong first byte: {}", ArrayConverter.bytesToHexString(key));
        return key;
    }

    /** Correct block whose 0x02 block-type byte is replaced by 0x17. */
    private static byte[] getEK_WrongSecondByte(int rsaKeyLength, byte[] symmetricKey) {
        byte[] key = getPaddedKey(rsaKeyLength, symmetricKey);
        key[1] = WRONG_HEADER_BYTE;
        LOGGER.debug("Generated a PKCS1 padded message with a wrong second byte: {}", ArrayConverter.bytesToHexString(key));
        return key;
    }

    /** Correct block with every 0x00 after the header (including the separator) set to 0x01. */
    private static byte[] getEK_NoNullByte(int rsaKeyLength, byte[] symmetricKey) {
        byte[] key = getPaddedKey(rsaKeyLength, symmetricKey);
        replaceZeroBytes(key);
        LOGGER.debug("Generated a PKCS1 padded message with no separating byte: {}", ArrayConverter.bytesToHexString(key));
        return key;
    }

    /** Correct block with a 0x00 at index 3, i.e. inside the first 8 padding bytes (2..9). */
    private static byte[] getEK_NullByteInPkcsPadding(int rsaKeyLength, byte[] symmetricKey) {
        byte[] key = getPaddedKey(rsaKeyLength, symmetricKey);
        key[3] = 0;
        LOGGER.debug("Generated a PKCS1 padded message with a 0x00 byte in the PKCS1 padding: {}", ArrayConverter.bytesToHexString(key));
        return key;
    }

    /** Correct block with a 0x00 at index 11, a padding byte after the first eight. */
    private static byte[] getEK_NullByteInPadding(int rsaKeyLength, byte[] symmetricKey) {
        byte[] key = getPaddedKey(rsaKeyLength, symmetricKey);
        key[11] = 0;
        LOGGER.debug("Generated a PKCS1 padded message with a 0x00 byte in padding: {}", ArrayConverter.bytesToHexString(key));
        return key;
    }

    /**
     * Block whose only separator 0x00 sits {@code size + 1} bytes from the end, so the data
     * after it is {@code size} bytes long.
     */
    private static byte[] getEK_SymmetricKeyOfSize(int rsaKeyLength, byte[] symmetricKey, int size) {
        byte[] key = getPaddedKey(rsaKeyLength, symmetricKey);
        replaceZeroBytes(key);
        key[rsaKeyLength - size - 1] = 0;
        LOGGER.debug("Generated a PKCS1 padded symmetric key of size {}: {}", size, ArrayConverter.bytesToHexString(key));
        return key;
    }

    /**
     * One vector per index from 2 to {@code rsaKeyLength - 1}. Each vector has its single 0x00
     * at that index and no other 0x00 after the header.
     */
    private static List<Pkcs1Vector> getEK_DifferentPositionsOf0x00(int rsaKeyLength, byte[] symmetricKey) {
        List<Pkcs1Vector> vectors = new ArrayList<>();
        for (int position = 2; position < rsaKeyLength; position++) {
            byte[] key = getPaddedKey(rsaKeyLength, symmetricKey);
            replaceZeroBytes(key);
            key[position] = 0;
            vectors.add(new Pkcs1Vector("0x00 on a wrong position (" + position + ")", key));
        }
        LOGGER.debug("Generated PKCS1 vectors with different invalid 0x00 positions");
        return vectors;
    }

    /** Replaces every 0x00 after the {@code 0x00 0x02} header (index 2 onward) with 0x01, in place. */
    private static void replaceZeroBytes(byte[] key) {
        for (int i = 2; i < key.length; i++) {
            if (key[i] == 0) {
                key[i] = 1;
            }
        }
    }

    /** Static utility class; not instantiable. */
    private Pkcs1VectorGenerator() {
    }
}

