Why Would This Be Needed?

In today's web environment, it's common for a single server to host multiple domains or subdomains, and each of them needs an SSL/TLS certificate that is valid for its name to ensure secure communication. Server Name Indication (SNI) is a TLS extension that lets the client include the hostname it is trying to reach in its ClientHello message. The server can then present the matching certificate for that domain, even though all of the domains share the same IP address. This setup is particularly important for services that host several websites on shared infrastructure, because certificates can be managed per domain without a dedicated IP address or listener for each site. In Rust, combining Axum with rustls makes such configurations straightforward to handle.

Prerequisites

Let's review the crates we will need for this use case:


[dependencies]
axum = "0.7.6"
axum-server = { version = "0.7.1", features = ["tls-rustls"] }
tokio = { version = "=1.40.0", features = ["full"] }
dashmap = { version = "=5.5.3", features = ["inline", "rayon"] }
rustls = "0.23.13"
rustls-pemfile = "2.1.3"
once_cell = "1.19.0"

As the title suggests, we will use axum and axum-server as our backend framework, with axum-server's tls-rustls feature providing the TLS listener. Tokio supplies the asynchronous runtime. For TLS itself we'll use rustls, together with rustls-pemfile, which parses the PEM-encoded certificate and private key files that we turn into a rustls CertifiedKey. once_cell provides the lazily initialized static that holds our certificate store. dashmap provides the store itself: a concurrent HashMap that lets many connections look up certificates at the same time without contending on a single global lock.
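DashMap's performance under concurrency comes from sharding. Entries are spread across several internally locked shards instead of sitting behind one global lock, so lookups for different domains rarely wait on each other. The sketch below is a simplified, std-only illustration of that idea and not DashMap's real implementation. The `ShardedMap` type and its methods are hypothetical.

```rust
use std::collections::hash_map::DefaultHasher;
use std::collections::HashMap;
use std::hash::{Hash, Hasher};
use std::sync::RwLock;

/// Hypothetical sharded map: every key is routed to one of several
/// independently locked shards (a simplified model of the DashMap idea).
pub struct ShardedMap<V> {
    shards: Vec<RwLock<HashMap<String, V>>>,
}

impl<V: Clone> ShardedMap<V> {
    pub fn new(shard_count: usize) -> Self {
        assert!(shard_count > 0, "need at least one shard");
        let shards = (0..shard_count)
            .map(|_| RwLock::new(HashMap::new()))
            .collect();
        Self { shards }
    }

    // Hash the key to pick the shard that owns it.
    fn shard_for(&self, key: &str) -> &RwLock<HashMap<String, V>> {
        let mut hasher = DefaultHasher::new();
        key.hash(&mut hasher);
        let index = (hasher.finish() as usize) % self.shards.len();
        &self.shards[index]
    }

    // Writers only lock the single shard that owns the key.
    pub fn insert(&self, key: String, value: V) {
        self.shard_for(&key).write().unwrap().insert(key, value);
    }

    // Readers share the shard's lock; the value is cloned out so the
    // lock is released before returning.
    pub fn get(&self, key: &str) -> Option<V> {
        self.shard_for(key).read().unwrap().get(key).cloned()
    }
}
```

Readers of the same shard can proceed together, and a writer blocks only the one shard it touches. This is why a certificate lookup per handshake scales well when many handshakes run at once.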

Let's begin

We will first write the tls.rs file.

Structs and Global variables


use std::{fs::File, io::BufReader, sync::Arc};

use dashmap::DashMap;
use once_cell::sync::Lazy;
use rustls::{
    compress::CompressionCache,
    crypto::aws_lc_rs::{sign::any_supported_type, Ticketer},
    server::{ClientHello, ResolvesServerCert, ServerSessionMemoryCache},
    sign::CertifiedKey,
    ServerConfig,
};

//Create a global variable to store the certificates, keyed by domain name
pub static CERT_DB: Lazy<Arc<DashMap<String, TlsCollection>>> = Lazy::new(|| {
    Arc::new(DashMap::<String, TlsCollection>::default())
});

//Create a struct to resolve the server certificate
//(rustls requires resolvers to implement Debug + Send + Sync)
#[derive(Debug)]
struct ResolveServerCert;

//Implement the ResolvesServerCert trait for the ResolveServerCert struct.
//rustls calls resolve() during each handshake with the client's ClientHello.
impl ResolvesServerCert for ResolveServerCert {
    fn resolve(&self, client_hello: ClientHello<'_>) -> Option<Arc<CertifiedKey>> {
        match client_hello.server_name() {
            Some(sni) => {
                //A DashMap<String, _> can be queried with a &str directly
                if let Some(cert_key) = CERT_DB.get(sni) {
                    //Cloning the Arc is cheap; the key material is shared
                    Some(cert_key.certified_key.clone())
                } else {
                    //error!("No certificate found.");
                    None
                }
            }
            None => {
                //error!("No SNI value found.");
                //Returning None makes rustls abort the handshake
                None
            }
        }
    }
}

//A struct for holding the Arc<CertifiedKey> of one domain
pub struct TlsCollection {
    pub certified_key: Arc<CertifiedKey>,
}
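The resolver above only matches the exact SNI value. If you also want to serve a wildcard certificate, for example one stored under the key `*.example.com`, one possible lookup order is exact name first, then the wildcard for the parent domain. A wildcard covers exactly one label, so only the left-most label is replaced. The std-only sketch below assumes keys are stored in lowercase, since DNS names are case-insensitive. `FakeCert` and `lookup_cert` are hypothetical stand-ins for `TlsCollection` and the body of `resolve`.

```rust
use std::collections::HashMap;
use std::sync::Arc;

/// Hypothetical stand-in for the per-domain TLS material
/// (TlsCollection / Arc<CertifiedKey> in the article).
#[derive(Debug, PartialEq)]
pub struct FakeCert(pub &'static str);

/// Hypothetical lookup: exact (lowercased) name first, then "*.<parent>".
pub fn lookup_cert(
    db: &HashMap<String, Arc<FakeCert>>,
    sni: Option<&str>,
) -> Option<Arc<FakeCert>> {
    // No SNI (e.g. the client connected by IP address): nothing to match.
    let name = sni?.trim_end_matches('.').to_ascii_lowercase();

    if let Some(cert) = db.get(&name) {
        return Some(Arc::clone(cert));
    }

    // Replace only the left-most label with "*" and try once more.
    let (_, parent) = name.split_once('.')?;
    db.get(&format!("*.{parent}")).cloned()
}
```

Inside `resolve`, the same logic would run against `CERT_DB` and return `certified_key.clone()` instead of a `FakeCert`.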

Generating a CertifiedKey from PEM files


//A function to get a CertifiedKey out of the certificate and private key PEM files.
//Files are looked up as "<prefix><domain>cert.pem" and "<prefix><domain>key.pem";
//the first format! argument ("") is that prefix, e.g. a certificate directory.
//Note: this uses blocking std::fs I/O, which is acceptable during startup.
pub async fn get_cert_key(domain: &str) -> Option<CertifiedKey> {
    let cert_file = format!("{}{}{}", "", domain, "cert.pem");
    let key_file = format!("{}{}{}", "", domain, "key.pem");

    let cert_path = std::path::Path::new(&cert_file);
    let key_path = std::path::Path::new(&key_file);

    //is_file() returns false for paths that do not exist
    if !(cert_path.is_file() && key_path.is_file()) {
        return None;
    }

    //Return None instead of panicking if a file cannot be opened or parsed
    let cert_reader = &mut BufReader::new(File::open(cert_path).ok()?);
    let key_reader = &mut BufReader::new(File::open(key_path).ok()?);

    //Collect every CERTIFICATE block in file order; rustls expects the leaf first
    let certs = rustls_pemfile::certs(cert_reader)
        .collect::<Result<Vec<_>, _>>()
        .ok()?;

    //The first private key in the file (PKCS#1, PKCS#8 or SEC1), if there is one
    let private_key = rustls_pemfile::private_key(key_reader).ok()??;

    //Wrap the key in a signing key for the aws-lc-rs crypto provider
    let signing_key = any_supported_type(&private_key).ok()?;

    Some(CertifiedKey {
        cert: certs,
        key: signing_key,
        ocsp: None,
    })
}
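`rustls_pemfile::certs` yields every certificate block in the file in order, and rustls treats the first certificate of a `CertifiedKey` chain as the end-entity (leaf) certificate. `cert.pem` should therefore contain the leaf followed by any intermediates. `rustls_pemfile::private_key` returns the first private key it finds. The hypothetical helper below is a simplified, std-only way to check which PEM blocks a file contains. It only reads the `BEGIN` labels and does none of the decoding that rustls-pemfile performs.

```rust
/// Hypothetical helper: list the label of every "-----BEGIN <label>-----"
/// line in a PEM document, in file order. No base64 decoding or validation.
pub fn pem_labels(pem: &str) -> Vec<String> {
    pem.lines()
        .filter_map(|line| {
            line.trim()
                .strip_prefix("-----BEGIN ")
                .and_then(|rest| rest.strip_suffix("-----"))
                .map(str::to_string)
        })
        .collect()
}
```

Running it over the contents of a domain's `cert.pem` (read with `std::fs::read_to_string`) shows at a glance whether the chain file holds one certificate or several, and whether a key file accidentally ended up in the wrong place.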

Helper function for creating rustls ServerConfig


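pub fn create_server_config() -> ServerConfig {
    //No client certificates; server certificates come from our SNI resolver
    let mut server_config = ServerConfig::builder()
        .with_no_client_auth()
        .with_cert_resolver(Arc::new(ResolveServerCert));

    //Optional tuning; check each value against the rustls ServerConfig docs
    //Accept up to 2048 bytes of TLS 1.3 early (0-RTT) data; early data can be replayed
    server_config.max_early_data_size = 2048;
    //Session resumption with encrypted session tickets
    server_config.ticketer = Ticketer::new().unwrap();
    //In-memory session cache holding up to 10240 sessions
    server_config.session_storage = ServerSessionMemoryCache::new(10240);
    //ALPN protocols, most preferred first: HTTP/2, then HTTP/1.1
    server_config.alpn_protocols = vec!["h2".into(), "http/1.1".into()];
    //Cache for compressed certificate chains
    server_config.cert_compression_cache = Arc::new(CompressionCache::new(2048));
    //Send "0.5-RTT" data after the server's first flight, before the handshake completes
    server_config.send_half_rtt_data = true;
    //Number of TLS 1.3 session tickets sent after a successful handshake
    server_config.send_tls13_tickets = 4;

    server_config
}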
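rustls documents `alpn_protocols` as the server's supported protocols, most preferred first. The std-only sketch below models choosing a protocol by server preference under that assumption. It walks the server list and returns the first entry the client also offered. `select_alpn` is a hypothetical helper, not a rustls API.

```rust
/// Hypothetical model of server-preference ALPN selection.
/// `server_prefs` is ordered most preferred first, like `alpn_protocols`.
pub fn select_alpn<'a>(server_prefs: &'a [Vec<u8>], client_offers: &[Vec<u8>]) -> Option<&'a [u8]> {
    server_prefs
        .iter()
        // First server-preferred protocol that the client also offered.
        .find(|ours| client_offers.contains(*ours))
        .map(|p| p.as_slice())
}
```

In this model, a client offering both protocols ends up on HTTP/2 with the configuration above, and an HTTP/1.1-only client still gets `http/1.1`.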

Helper function for initialization during startup


//Load the certificates of all listening domains into CERT_DB.
//Domains whose PEM files are missing or invalid are skipped.
pub async fn init_cert_in_memory(listening_domains: Vec<String>) -> std::io::Result<()> {
    for domain in listening_domains.iter() {
        if let Some(existing_key) = get_cert_key(domain).await {
            CERT_DB.insert(domain.to_owned(), TlsCollection {
                certified_key: Arc::new(existing_key),
            });
        }
    }

    Ok(())
}
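`init_cert_in_memory` silently skips any domain whose PEM files are missing. A typo in a file name therefore only shows up later as failed handshakes. Below is a minimal, synchronous sketch of an alternative, assuming you would rather collect the missing domains and then log them or abort. The hypothetical `init_certs` takes the loader as a closure, standing in for `get_cert_key`, and fills a plain `HashMap`, standing in for `CERT_DB`.

```rust
use std::collections::HashMap;

/// Hypothetical variant of init_cert_in_memory that reports which domains
/// could not be loaded instead of skipping them silently.
pub fn init_certs<T, F>(domains: &[&str], mut load: F, db: &mut HashMap<String, T>) -> Vec<String>
where
    F: FnMut(&str) -> Option<T>,
{
    let mut missing = Vec::new();
    for &domain in domains {
        match load(domain) {
            // Loaded: store it under the domain name used for SNI lookups.
            Some(cert) => {
                db.insert(domain.to_string(), cert);
            }
            // Not loaded: remember the name so the caller can react.
            None => missing.push(domain.to_string()),
        }
    }
    missing
}
```

In main.rs, the returned list could be logged, or the program could refuse to start when it is non-empty.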

Finally, main.rs


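use std::{net::SocketAddr, sync::Arc};

use axum::Router;
use axum_server::tls_rustls::{RustlsAcceptor, RustlsConfig};
use tls::{create_server_config, init_cert_in_memory};

pub mod tls;

#[tokio::main]
async fn main() {
    //Load localhostcert.pem / localhostkey.pem into the certificate store
    init_cert_in_memory(vec!["localhost".into()])
        .await
        .expect("failed to load certificates");

    //Hand our rustls ServerConfig (with the SNI resolver) to axum-server
    let config = Arc::new(create_server_config());
    let axum_config = RustlsConfig::from_config(config);
    let acceptor = RustlsAcceptor::new(axum_config);

    //Add your routes here; an empty router answers every request with 404
    let router = Router::new();

    //Binding to port 443 usually requires elevated privileges on Unix-like systems
    let server = axum_server::bind(SocketAddr::from(([0, 0, 0, 0], 443)))
        .acceptor(acceptor)
        .serve(router.into_make_service());

    server.await.expect("server error");
}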
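`main` hard-codes `localhost` as the only listening domain. Below is a minimal sketch, assuming you want to supply the list at startup, for example through a hypothetical `LISTEN_DOMAINS` environment variable. `parse_domains` (also hypothetical) splits a comma-separated string, trims and lowercases each name (DNS names are case-insensitive), and drops empty entries and duplicates.

```rust
/// Hypothetical helper: turn "a.com, B.com,,a.com" into ["a.com", "b.com"].
pub fn parse_domains(raw: &str) -> Vec<String> {
    let mut domains: Vec<String> = Vec::new();
    for name in raw.split(',').map(|s| s.trim().to_ascii_lowercase()) {
        // Skip empty entries and names we have already seen.
        if !name.is_empty() && !domains.contains(&name) {
            domains.push(name);
        }
    }
    domains
}
```

It could then replace the hard-coded vector, e.g. `init_cert_in_memory(parse_domains(&std::env::var("LISTEN_DOMAINS").unwrap_or_else(|_| "localhost".to_string())))`. Keep in mind that lowercasing also changes the PEM file names that `get_cert_key` looks for.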

Final Thoughts

In conclusion, setting up an Axum server with an SNI-aware rustls configuration lets a single server, listening on one IP address and port, terminate TLS for multiple domains. rustls handles the encryption, Tokio drives the asynchronous I/O, and once_cell and DashMap give every handshake fast, concurrent access to the certificate store. Each domain is served with its own certificate and private key. Because the resolver reads CERT_DB on every handshake, inserting or replacing an entry at runtime takes effect for new connections without touching the other domains. Whether you're managing a few subdomains or many websites, the same resolver scales by adding entries to the certificate store.