欢迎您访问 最编程 本站为您分享编程语言代码,编程技术文章!
您现在的位置是: 首页

mTLS: Netty 单向/双向 TLS 演示完整代码

最编程 2024-03-02 22:43:46
...

NettyHelper.java: 主要用于创建 EventLoopGroup,并判断是否支持 Epoll。

package org.example.netty;

import io.netty.channel.EventLoopGroup;
import io.netty.channel.epoll.Epoll;
import io.netty.channel.epoll.EpollEventLoopGroup;
import io.netty.channel.epoll.EpollSocketChannel;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.SocketChannel;
import io.netty.channel.socket.nio.NioSocketChannel;
import io.netty.handler.codec.http2.Http2SecurityUtil;
import io.netty.handler.ssl.*;
import io.netty.handler.ssl.util.InsecureTrustManagerFactory;
import io.netty.handler.ssl.util.SelfSignedCertificate;
import io.netty.util.concurrent.DefaultThreadFactory;

import javax.net.ssl.SSLException;
import java.security.cert.CertificateException;
import java.util.concurrent.ThreadFactory;


/**
 * Static helpers for choosing between Netty's native epoll transport and the portable NIO
 * transport.
 *
 * <p>Epoll is opt-in: it is selected only when the {@code netty.epoll.enable} system property
 * parses as {@code true}, the {@code os.name} system property contains "linux"
 * (case-insensitively), and {@link Epoll#isAvailable()} reports {@code true}. In every other case
 * the NIO transport is used.
 */
public class NettyHelper {
    /** System property that must be {@code "true"} for epoll to be considered at all. */
    static final String NETTY_EPOLL_ENABLE_KEY = "netty.epoll.enable";

    /** System property holding the operating system name. */
    static final String OS_NAME_KEY = "os.name";

    /** Lower-case fragment looked for inside the operating system name. */
    static final String OS_LINUX_PREFIX = "linux";

    /**
     * Creates an event loop group backed by epoll or NIO, depending on {@link #shouldEpoll()}.
     *
     * @param threads           number of event loop threads, passed straight to the group constructor
     * @param threadFactoryName pool name given to the {@link DefaultThreadFactory}
     * @return a new {@link EpollEventLoopGroup} or {@link NioEventLoopGroup}
     */
    public static EventLoopGroup eventLoopGroup(int threads, String threadFactoryName) {
        // Second argument "true" asks DefaultThreadFactory for daemon threads.
        ThreadFactory threadFactory = new DefaultThreadFactory(threadFactoryName, true);
        if (shouldEpoll()) {
            return new EpollEventLoopGroup(threads, threadFactory);
        }
        return new NioEventLoopGroup(threads, threadFactory);
    }

    /**
     * Decides whether the native epoll transport should be used.
     *
     * @return {@code true} only if epoll was explicitly enabled through
     *         {@code netty.epoll.enable}, the OS name contains "linux" and the native epoll
     *         library is available; {@code false} otherwise, including when {@code os.name}
     *         is not set
     */
    public static boolean shouldEpoll() {
        if (!Boolean.parseBoolean(System.getProperty(NETTY_EPOLL_ENABLE_KEY, "false"))) {
            return false;
        }

        String osName = System.getProperty(OS_NAME_KEY);
        if (osName == null) {
            // A missing os.name cannot be Linux; fall back to NIO instead of throwing an NPE.
            return false;
        }

        // Locale.ROOT keeps the comparison stable under locales with special casing rules
        // (for example Turkish, where 'I' lower-cases to a dotless 'ı').
        return osName.toLowerCase(java.util.Locale.ROOT).contains(OS_LINUX_PREFIX)
                && Epoll.isAvailable();
    }

    /**
     * Returns the socket channel implementation that matches the transport picked by
     * {@link #shouldEpoll()}.
     *
     * @return {@link EpollSocketChannel} when epoll should be used, otherwise
     *         {@link NioSocketChannel}
     */
    public static Class<? extends SocketChannel> socketChannelClass() {
        if (shouldEpoll()) {
            return EpollSocketChannel.class;
        }
        return NioSocketChannel.class;
    }
}

SslContexts: 创建SslContext对象的工具类

package org.example.netty.tls;

import io.netty.handler.codec.http2.Http2SecurityUtil;
import io.netty.handler.ssl.*;
import io.netty.handler.ssl.util.InsecureTrustManagerFactory;
import io.netty.handler.ssl.util.SelfSignedCertificate;
import lombok.extern.slf4j.Slf4j;

import javax.net.ssl.SSLException;
import java.io.*;
import java.net.MalformedURLException;
import java.security.Provider;
import java.security.Security;
import java.security.cert.CertificateException;

@Slf4j
public class SslContexts {


    /**
     * Builds a client-side {@link SslContext} for one-way TLS.
     *
     * <p>The context is configured with {@link InsecureTrustManagerFactory#INSTANCE}, so the
     * server certificate is not verified; this is only appropriate for demos and tests.
     * The enabled protocols are TLSv1.3 and TLSv1.2.
     *
     * @return the newly built client context
     * @throws SSLException if the builder fails to create the context
     */
    public static SslContext createTlsClientSslContext() throws SSLException {
        SslProvider sslProvider = findSslProvider();

        SslContextBuilder clientBuilder = SslContextBuilder.forClient();
        clientBuilder.sslProvider(sslProvider);
        clientBuilder.trustManager(InsecureTrustManagerFactory.INSTANCE);
        clientBuilder.protocols("TLSv1.3", "TLSv1.2");
        return clientBuilder.build();
    }


    /**
     * Builds a server-side {@link SslContext} backed by a freshly generated, temporary
     * self-signed certificate.
     *
     * <p>The enabled protocols are TLSv1.3 and TLSv1.2. No client authentication is configured.
     *
     * @return the newly built server context
     * @throws CertificateException if the self-signed certificate cannot be generated
     * @throws SSLException         if the builder fails to create the context
     */
    public static SslContext createTlsServerSslContext() throws CertificateException, SSLException {
        SslProvider sslProvider = findSslProvider();
        SelfSignedCertificate selfSigned = new SelfSignedCertificate();

        SslContextBuilder serverBuilder =
                SslContextBuilder.forServer(selfSigned.certificate(), selfSigned.privateKey());
        serverBuilder.sslProvider(sslProvider);
        serverBuilder.protocols("TLSv1.3", "TLSv1.2");
        return serverBuilder.build();
    }


    /**
     * Builds a server {@link SslContext} from a certificate chain file and an unencrypted
     * private key file, delegating to the four-argument overload with no key password and no
     * trust certificate collection.
     *
     * @param keyCertChainFile certificate chain file
     * @param keyFile          private key file
     * @return the server context produced by the delegate overload
     */
    public static SslContext createServerSslContext(File keyCertChainFile, File keyFile) {
        final String noKeyPassword = null;
        final File noTrustCertCollection = null;
        return createServerSslContext(keyCertChainFile, keyFile, noKeyPassword, noTrustCertCollection);
    }

    /**
     * Builds a server {@link SslContext} from a certificate chain file, an unencrypted private
     * key file and a trust certificate collection, delegating to the four-argument overload
     * with no key password.
     *
     * @param keyCertChainFile    certificate chain file
     * @param keyFile             private key file
     * @param trustCertCollection file with the certificates to trust; may be {@code null}
     * @return the server context produced by the delegate overload
     */
    public static SslContext createServerSslContext(File keyCertChainFile, File keyFile, File trustCertCollection) {
        final String noKeyPassword = null;
        return createServerSslContext(keyCertChainFile, keyFile, noKeyPassword, trustCertCollection);
    }

    /**
     * Builds a server {@link SslContext} from a certificate chain file and a private key file
     * protected by {@code keyPassword}, delegating to the four-argument overload with no trust
     * certificate collection.
     *
     * @param keyCertChainFile certificate chain file
     * @param keyFile          private key file
     * @param keyPassword      password of the private key, or {@code null} if it is not encrypted
     * @return the server context produced by the delegate overload
     */
    public static SslContext createServerSslContext(File keyCertChainFile, File keyFile, String keyPassword) {
        final File noTrustCertCollection = null;
        return createServerSslContext(keyCertChainFile, keyFile, keyPassword, noTrustCertCollection);
    }

    /**
     * Builds a server {@link SslContext} from files, always requesting mandatory client
     * authentication.
     *
     * <p>This overload passes {@code requireClientAuth = true} to the five-argument overload
     * regardless of whether {@code trustCertCollection} is {@code null}.
     * NOTE(review): callers that want one-way TLS must use the five-argument overload with
     * {@code false} — confirm this default is intended for the shorter overloads.
     *
     * @param keyCertChainFile    certificate chain file
     * @param keyFile             private key file
     * @param keyPassword         password of the private key, or {@code null}
     * @param trustCertCollection file with the certificates to trust, or {@code null}
     * @return the server context produced by the delegate overload
     */
    public static SslContext createServerSslContext(File keyCertChainFile, File keyFile, String keyPassword, File trustCertCollection) {
        final boolean clientAuthRequired = true;
        return createServerSslContext(keyCertChainFile, keyFile, keyPassword, trustCertCollection, clientAuthRequired);
    }

    /**
     * Builds a server {@link SslContext} from files by opening each non-null file as a stream
     * and delegating to the stream-based overload. All opened streams are closed before this
     * method returns.
     *
     * @param keyCertChainFile    certificate chain file
     * @param keyFile             private key file
     * @param keyPassword         password of the private key, or {@code null}
     * @param trustCertCollection file with the certificates to trust, or {@code null}
     * @param requireClientAuth   whether the server must require a client certificate
     * @return the server context produced by the stream-based overload
     * @throws IllegalArgumentException if opening or closing one of the files raises an
     *                                  {@link IOException}
     */
    public static SslContext createServerSslContext(File keyCertChainFile, File keyFile, String keyPassword, File trustCertCollection, boolean requireClientAuth) {
        // A null file yields a null resource, which try-with-resources simply skips on close.
        try (InputStream certChainIn = openInputStream(keyCertChainFile);
             InputStream privateKeyIn = openInputStream(keyFile);
             InputStream trustIn = openInputStream(trustCertCollection)) {
            SslContext serverContext =
                    createServerSslContext(certChainIn, privateKeyIn, keyPassword, trustIn, requireClientAuth);
            return serverContext;
        } catch (IOException ioFailure) {
            throw new IllegalArgumentException("Could not find certificate file or the certificate is invalid.", ioFailure);
        }
    }




    /**
     * Builds a server {@link SslContext} from a certificate chain stream and an unencrypted
     * private key stream, delegating to the four-argument stream overload with no key password
     * and no trust certificate collection.
     *
     * @param keyCertChainInputStream stream of the certificate chain
     * @param keyInputStream          stream of the private key
     * @return the server context produced by the delegate overload
     */
    public static SslContext createServerSslContext(InputStream keyCertChainInputStream, InputStream keyInputStream) {
        final String noKeyPassword = null;
        final InputStream noTrustCertCollection = null;
        return createServerSslContext(keyCertChainInputStream, keyInputStream, noKeyPassword, noTrustCertCollection);
    }

    /**
     * Builds a server {@link SslContext} from a certificate chain stream, an unencrypted
     * private key stream and a trust certificate stream, delegating to the four-argument stream
     * overload with no key password.
     *
     * @param keyCertChainInputStream stream of the certificate chain
     * @param keyInputStream          stream of the private key
     * @param trustCertCollection     stream of the certificates to trust; may be {@code null}
     * @return the server context produced by the delegate overload
     */
    public static SslContext createServerSslContext(InputStream keyCertChainInputStream, InputStream keyInputStream, InputStream trustCertCollection) {
        final String noKeyPassword = null;
        return createServerSslContext(keyCertChainInputStream, keyInputStream, noKeyPassword, trustCertCollection);
    }

    /**
     * Builds a server {@link SslContext} from a certificate chain stream and a private key
     * stream protected by {@code keyPassword}, delegating to the four-argument stream overload
     * with no trust certificate collection.
     *
     * @param keyCertChainInputStream stream of the certificate chain
     * @param keyInputStream          stream of the private key
     * @param keyPassword             password of the private key, or {@code null}
     * @return the server context produced by the delegate overload
     */
    public static SslContext createServerSslContext(InputStream keyCertChainInputStream, InputStream keyInputStream, String keyPassword) {
        final InputStream noTrustCertCollection = null;
        return createServerSslContext(keyCertChainInputStream, keyInputStream, keyPassword, noTrustCertCollection);
    }

    /**
     * Builds a server {@link SslContext} from streams, always requesting mandatory client
     * authentication.
     *
     * <p>This overload passes {@code requireClientAuth = true} to the five-argument overload
     * regardless of whether {@code trustCertCollection} is {@code null}.
     * NOTE(review): callers that want one-way TLS must use the five-argument overload with
     * {@code false} — confirm this default is intended for the shorter overloads.
     *
     * @param keyCertChainInputStream stream of the certificate chain
     * @param keyInputStream          stream of the private key
     * @param keyPassword             password of the private key, or {@code null}
     * @param trustCertCollection     stream of the certificates to trust, or {@code null}
     * @return the server context produced by the delegate overload
     */
    public static SslContext createServerSslContext(InputStream keyCertChainInputStream, InputStream keyInputStream, String keyPassword, InputStream trustCertCollection) {
        final boolean clientAuthRequired = true;
        return createServerSslContext(keyCertChainInputStream, keyInputStream, keyPassword, trustCertCollection, clientAuthRequired);
    }

    /**
     * Builds a server {@link SslContext} from already opened streams.
     *
     * <p>The streams are read but not closed here. The enabled protocols are TLSv1.3 and
     * TLSv1.2.
     *
     * @param keyCertChainInputStream stream of the certificate chain
     * @param keyInputStream          stream of the private key
     * @param keyPassword             password of the private key; when {@code null} the
     *                                password-less builder factory is used
     * @param trustCertCollection     stream of the certificates to trust; when {@code null} no
     *                                trust manager is set on the builder
     * @param requireClientAuth       when {@code true}, {@link ClientAuth#REQUIRE} is set
     * @return the newly built server context
     * @throws IllegalStateException if building the context raises an {@link SSLException}
     */
    public static SslContext createServerSslContext(InputStream keyCertChainInputStream, InputStream keyInputStream, String keyPassword, InputStream trustCertCollection, boolean requireClientAuth) {
        final SslContextBuilder serverBuilder = keyPassword == null
                ? SslContextBuilder.forServer(keyCertChainInputStream, keyInputStream)
                : SslContextBuilder.forServer(keyCertChainInputStream, keyInputStream, keyPassword);

        if (trustCertCollection != null) {
            serverBuilder.trustManager(trustCertCollection);
        }
        if (requireClientAuth) {
            serverBuilder.clientAuth(ClientAuth.REQUIRE);
        }

        try {
            serverBuilder.sslProvider(findSslProvider());
            serverBuilder.protocols("TLSv1.3", "TLSv1.2");
            return serverBuilder.build();
        } catch (SSLException buildFailure) {
            throw new IllegalStateException("Build SslSession failed.", buildFailure);
        }
    }


    /**
     * Builds a client {@link SslContext} that trusts every server certificate.
     *
     * <p>The context is configured with {@link InsecureTrustManagerFactory#INSTANCE}, so the
     * server certificate is not verified; this is only appropriate for demos and tests.
     * The enabled protocols are TLSv1.3 and TLSv1.2.
     *
     * @return the newly built client context
     * @throws IllegalStateException if building the context raises an {@link SSLException}
     */
    public static SslContext createClientSslContext() {
        try {
            SslProvider sslProvider = findSslProvider();

            SslContextBuilder clientBuilder = SslContextBuilder.forClient();
            clientBuilder.sslProvider(sslProvider);
            clientBuilder.trustManager(InsecureTrustManagerFactory.INSTANCE);
            clientBuilder.protocols("TLSv1.3", "TLSv1.2");
            return clientBuilder.build();
        } catch (SSLException buildFailure) {
            throw new IllegalStateException("Build SslSession failed.", buildFailure);
        }
    }

    /**
     * Builds a client {@link SslContext} that trusts the certificates in the given file and
     * presents no client certificate, delegating to the four-argument file overload.
     *
     * @param trustCertCollection file with the certificates to trust; may be {@code null}
     * @return the client context produced by the delegate overload
     */
    public static SslContext createClientSslContext(File trustCertCollection) {
        final File noKeyCertChain = null;
        final File noKey = null;
        final String noKeyPassword = null;
        return createClientSslContext(noKeyCertChain, noKey, noKeyPassword, trustCertCollection);
    }

    /**
     * Builds a client {@link SslContext} from a client certificate chain file, an unencrypted
     * private key file and a trust certificate file, delegating to the four-argument file
     * overload with no key password.
     *
     * <p>Despite their names, the first two parameters are files, not streams.
     *
     * @param keyCertChainInputStream client certificate chain file
     * @param keyInputStream          client private key file
     * @param trustCertCollection     file with the certificates to trust
     * @return the client context produced by the delegate overload
     */
    public static SslContext createClientSslContext(File keyCertChainInputStream, File keyInputStream, File trustCertCollection) {
        final String noKeyPassword = null;
        return createClientSslContext(keyCertChainInputStream, keyInputStream, noKeyPassword, trustCertCollection);
    }

    /**
     * Builds a client {@link SslContext} from files by opening each non-null file as a stream
     * and delegating to the stream-based overload. All opened streams are closed before this
     * method returns.
     *
     * @param keyCertChainFile        client certificate chain file, or {@code null}
     * @param keyFile                 client private key file, or {@code null}
     * @param keyPassword             password of the private key, or {@code null}
     * @param trustCertCollectionFile file with the certificates to trust, or {@code null}
     * @return the client context produced by the stream-based overload
     * @throws IllegalArgumentException if opening or closing one of the files raises an
     *                                  {@link IOException}
     */
    public static SslContext createClientSslContext(File keyCertChainFile, File keyFile, String keyPassword, File trustCertCollectionFile) {
        // A null file yields a null resource, which try-with-resources simply skips on close.
        try (InputStream certChainIn = openInputStream(keyCertChainFile);
             InputStream privateKeyIn = openInputStream(keyFile);
             InputStream trustIn = openInputStream(trustCertCollectionFile)) {
            SslContext clientContext = createClientSslContext(certChainIn, privateKeyIn, keyPassword, trustIn);
            return clientContext;
        } catch (IOException ioFailure) {
            throw new IllegalArgumentException("Could not find certificate file or the certificate is invalid.", ioFailure);
        }
    }

    /**
     * Builds a client {@link SslContext} that trusts the certificates read from the given
     * stream and presents no client certificate, delegating to the four-argument stream
     * overload.
     *
     * @param trustCertCollection stream of the certificates to trust; may be {@code null}
     * @return the client context produced by the delegate overload
     */
    public static SslContext createClientSslContext(InputStream trustCertCollection) {
        final InputStream noKeyCertChain = null;
        final InputStream noKey = null;
        final String noKeyPassword = null;
        return createClientSslContext(noKeyCertChain, noKey, noKeyPassword, trustCertCollection);
    }

    /**
     * Builds a client {@link SslContext} from a client certificate chain stream, an unencrypted
     * private key stream and a trust certificate stream, delegating to the four-argument stream
     * overload with no key password.
     *
     * @param keyCertChainInputStream stream of the client certificate chain
     * @param keyInputStream          stream of the client private key
     * @param trustCertCollection     stream of the certificates to trust
     * @return the client context produced by the delegate overload
     */
    public static SslContext createClientSslContext(InputStream keyCertChainInputStream, InputStream keyInputStream, InputStream trustCertCollection) {
        final String noKeyPassword = null;
        return createClientSslContext(keyCertChainInputStream, keyInputStream, noKeyPassword, trustCertCollection);
    }

    /**
     * Builds a client {@link SslContext} from already opened streams.
     *
     * <p>The streams are read but not closed here. The enabled protocols are TLSv1.3 and
     * TLSv1.2.
     *
     * <p>Unlike the server-side factories, no {@code ClientAuth} mode is set: Netty's
     * {@code SslContextBuilder} only applies that setting when building a server context, so
     * configuring it on a client builder has no effect. Whether the client presents a
     * certificate is determined solely by the key material passed in.
     *
     * @param keyCertChainInputStream stream of the client certificate chain, or {@code null}
     * @param keyInputStream          stream of the client private key, or {@code null}
     * @param keyPassword             password of the private key, or {@code null}
     * @param trustCertCollection     stream of the certificates to trust; when {@code null} no
     *                                trust manager is set on the builder
     * @return the newly built client context
     * @throws IllegalStateException if building the context raises an {@link SSLException}
     */
    public static SslContext createClientSslContext(InputStream keyCertChainInputStream, InputStream keyInputStream, String keyPassword, InputStream trustCertCollection) {
        SslContextBuilder builder = SslContextBuilder.forClient();
        if (trustCertCollection != null) {
            builder.trustManager(trustCertCollection);
        }
        // NOTE(review): the guard uses OR, so a half-specified key pair (only the chain or only
        // the key) is still forwarded to the builder — confirm whether that should be rejected.
        if (keyCertChainInputStream != null || keyInputStream != null) {
            if (keyPassword != null) {
                builder.keyManager(keyCertChainInputStream, keyInputStream, keyPassword);
            } else {
                builder.keyManager(keyCertChainInputStream, keyInputStream);
            }
        }

        try {
            SslProvider provider = findSslProvider();
            return builder
                    .sslProvider(provider)
                    .protocols("TLSv1.3", "TLSv1.2")
                    .build();
        } catch (SSLException e) {
            throw new IllegalStateException("Build SslSession failed.", e);
        }
    }


    /**
     * Builds an HTTPS server {@link SslContext} backed by a freshly generated, temporary
     * self-signed certificate.
     *
     * <p>The context uses the HTTP/2 cipher list from {@link Http2SecurityUtil#CIPHERS}
     * filtered by {@link SupportedCipherSuiteFilter}, enables TLSv1.3 and TLSv1.2, and
     * configures ALPN with HTTP/2 and HTTP/1.1 as the application protocols, in that order.
     *
     * @return the newly built server context
     * @throws CertificateException if the self-signed certificate cannot be generated
     * @throws SSLException         if the builder fails to create the context
     */
    public static SslContext createHttpsServerSslContext() throws CertificateException, SSLException {
        SslProvider sslProvider = findSslProvider();
        SelfSignedCertificate selfSigned = new SelfSignedCertificate();

        SslContextBuilder serverBuilder =
                SslContextBuilder.forServer(selfSigned.certificate(), selfSigned.privateKey());
        serverBuilder.sslProvider(sslProvider);
        serverBuilder.ciphers(Http2SecurityUtil.CIPHERS, SupportedCipherSuiteFilter.INSTANCE);
        serverBuilder.protocols("TLSv1.3", "TLSv1.2");

        ApplicationProtocolConfig alpnConfig = new ApplicationProtocolConfig(
                ApplicationProtocolConfig.Protocol.ALPN,
                ApplicationProtocolConfig.SelectorFailureBehavior.NO_ADVERTISE,
                ApplicationProtocolConfig.SelectedListenerFailureBehavior.ACCEPT,
                ApplicationProtocolNames.HTTP_2,
                ApplicationProtocolNames.HTTP_1_1);
        serverBuilder.applicationProtocolConfig(alpnConfig);

        return serverBuilder.build();
    }

    /**
     * Opens the given file for reading.
     *
     * <p>The file is opened directly with {@link FileInputStream} rather than through a
     * {@code file:} URL. The URL route returns a synthetic directory listing when the path is a
     * directory, which would later be parsed as if it were certificate data; opening the file
     * directly fails immediately with a {@link FileNotFoundException} instead.
     *
     * @param file the file to open, or {@code null}
     * @return an open stream the caller must close, or {@code null} if {@code file} is
     *         {@code null}
     * @throws IOException if the file does not exist, is a directory, or cannot be opened
     */
    public static InputStream openInputStream(File file) throws IOException {
        if (file == null) {
            return null;
        }
        return new FileInputStream(file);
    }

//
//    public static SslContext buildServerSslContext() {
//        SslContextBuilder sslClientContextBuilder;
//        InputStream serverKeyCertChainPathStream = null;
//        InputStream serverPrivateKeyPathStream = null;
//        InputStream serverTrustCertStream = null;
//        try {
//            serverKeyCertChainPathStream = sslConfig.getServerKeyCertChainPathStream();
//            serverPrivateKeyPathStream = sslConfig.getServerPrivateKeyPathStream();
//            serverTrustCertStream = sslConfig.getServerTrustCertCollectionPathStream();
//            String password = sslConfig.getServerKeyPassword();
//            if (password != null) {
//                sslClientContextBuilder = SslContextBuilder.forServer(serverKeyCertChainPathStream,
//                        serverPrivateKeyPathStream, password);
//            } else {
//                sslClientContextBuilder = SslContextBuilder.forServer(serverKeyCertChainPathStream,
//                        serverPrivateKeyPathStream);
//            }
//
//            if (serverTrustCertStream != null) {
//                sslClientContextBuilder.trustManager(serverTrustCertStream);
//                sslClientContextBuilder.clientAuth(ClientAuth.REQUIRE);
//            }
//        } catch (Exception e) {
//            throw new IllegalArgumentException("Could not find certificate file or the certificate is invalid.", e);
//        } finally {
//            safeCloseStream(serverTrustCertStream);
//            safeCloseStream(serverKeyCertChainPathStream);
//            safeCloseStream(serverPrivateKeyPathStream);
//        }
//        try {
//            return sslClientContextBuilder.sslProvider(findSslProvider()).build();
//        } catch (SSLException e) {
//            throw new IllegalStateException("Build SslSession failed.", e);
//        }
//    }
//
//    public static SslContext buildClientSslContext(URL url) {
//
//        SslContextBuilder builder = SslContextBuilder.forClient();
//        InputStream clientTrustCertCollectionPath = null;
//        InputStream clientCertChainFilePath = null;
//        InputStream clientPrivateKeyFilePath = null;
//        try {
//            clientTrustCertCollectionPath = sslConfig.getClientTrustCertCollectionPathStream();
//            if (clientTrustCertCollectionPath != null) {
//                builder.trustManager(clientTrustCertCollectionPath);
//            }
//
//            clientCertChainFilePath = sslConfig.getClientKeyCertChainPathStream();
//            clientPrivateKeyFilePath = sslConfig.getClientPrivateKeyPathStream();
//            if (clientCertChainFilePath != null && clientP