Add per worker cache for certificate and private key

Originally certificate and private key are converted to DER format and
then passed to `set_der_cert` and `set_der_priv_key`. The major drawback
of this is the input of `set_der_cert` cannot be cached. The input is
DER bytes. Inside `set_der_cert`, a new `X509 *` is always created from
it. This means redundant work and causes a memory waste.

By switching to `set_cert` and `set_priv_key` APIs, we can pass in an
opaque `X509 *` pointer. Since `X509 *` is reference counted, we can
easily add a cache for it, thus it can be reused by multiple
connections.

In benchmark, it saves ~20KB memory per connection.
This commit is contained in:
Dayang Shen 2022-02-17 14:22:57 +08:00
parent e00b45beb5
commit e3e8df6aff
2 changed files with 86 additions and 44 deletions

View file

@ -1,4 +1,5 @@
local http = require("resty.http") local http = require("resty.http")
local lrucache = require("resty.lrucache")
local ssl = require("ngx.ssl") local ssl = require("ngx.ssl")
local ocsp = require("ngx.ocsp") local ocsp = require("ngx.ocsp")
local ngx = ngx local ngx = ngx
@ -19,29 +20,39 @@ local certificate_data = ngx.shared.certificate_data
local certificate_servers = ngx.shared.certificate_servers local certificate_servers = ngx.shared.certificate_servers
local ocsp_response_cache = ngx.shared.ocsp_response_cache local ocsp_response_cache = ngx.shared.ocsp_response_cache
local function get_der_cert_and_priv_key(pem_cert_key) local CACHE_SIZE = 1000
local der_cert, der_cert_err = ssl.cert_pem_to_der(pem_cert_key) local cache
if not der_cert then do
return nil, nil, "failed to convert certificate chain from PEM to DER: " .. der_cert_err local err
cache, err = lrucache.new(CACHE_SIZE)
if not cache then
return error("failed to create the certificate cache: " .. (err or "unknown"))
end end
local der_priv_key, dev_priv_key_err = ssl.priv_key_pem_to_der(pem_cert_key)
if not der_priv_key then
return nil, nil, "failed to convert private key from PEM to DER: " .. dev_priv_key_err
end
return der_cert, der_priv_key, nil
end end
local function set_der_cert_and_key(der_cert, der_priv_key) local function get_cert_and_priv_key(pem_cert_key)
local set_cert_ok, set_cert_err = ssl.set_der_cert(der_cert) local cert, cert_err = ssl.parse_pem_cert(pem_cert_key)
if not set_cert_ok then if not cert then
return "failed to set DER cert: " .. set_cert_err return nil, nil, "failed to parse PEM certificate chain: " .. cert_err
end end
local set_priv_key_ok, set_priv_key_err = ssl.set_der_priv_key(der_priv_key) local priv_key, priv_key_err = ssl.parse_pem_priv_key(pem_cert_key)
if not priv_key then
return nil, nil, "failed to parse PEM private key: " .. priv_key_err
end
return cert, priv_key, nil
end
local function set_cert_and_key(cert, priv_key)
local set_cert_ok, set_cert_err = ssl.set_cert(cert)
if not set_cert_ok then
return "failed to set cert: " .. set_cert_err
end
local set_priv_key_ok, set_priv_key_err = ssl.set_priv_key(priv_key)
if not set_priv_key_ok then if not set_priv_key_ok then
return "failed to set DER private key: " .. set_priv_key_err return "failed to set private key: " .. set_priv_key_err
end end
end end
@ -221,6 +232,10 @@ function _M.configured_for_current_request()
return ngx.ctx.cert_configured_for_current_request return ngx.ctx.cert_configured_for_current_request
end end
function _M.flush_cache()
cache:flush_all()
end
function _M.call() function _M.call()
local hostname, hostname_err = ssl.server_name() local hostname, hostname_err = ssl.server_name()
if hostname_err then if hostname_err then
@ -232,35 +247,47 @@ function _M.call()
hostname = DEFAULT_CERT_HOSTNAME hostname = DEFAULT_CERT_HOSTNAME
end end
local pem_cert local cert, priv_key, get_err
local pem_cert_uid = get_pem_cert_uid(hostname) local pem_cert_uid = get_pem_cert_uid(hostname)
if not pem_cert_uid then if not pem_cert_uid then
pem_cert_uid = get_pem_cert_uid(DEFAULT_CERT_HOSTNAME) pem_cert_uid = get_pem_cert_uid(DEFAULT_CERT_HOSTNAME)
end end
if pem_cert_uid then if not pem_cert_uid then
pem_cert = certificate_data:get(pem_cert_uid)
end
if not pem_cert then
ngx.log(ngx.ERR, "certificate not found, falling back to fake certificate for hostname: " ngx.log(ngx.ERR, "certificate not found, falling back to fake certificate for hostname: "
.. tostring(hostname)) .. tostring(hostname))
return return
end end
local cached_entry = cache:get(pem_cert_uid)
if cached_entry then
cert = cached_entry.cert
priv_key = cached_entry.priv_key
else
local pem_cert = certificate_data:get(pem_cert_uid)
if not pem_cert then
ngx.log(ngx.ERR, "certificate not found, falling back to fake certificate for hostname: "
.. tostring(hostname))
return
end
cert, priv_key, get_err = get_cert_and_priv_key(pem_cert)
if get_err then
ngx.log(ngx.ERR, get_err)
return ngx.exit(ngx.ERROR)
end
cache:set(pem_cert_uid, { cert = cert, priv_key = priv_key })
end
local clear_ok, clear_err = ssl.clear_certs() local clear_ok, clear_err = ssl.clear_certs()
if not clear_ok then if not clear_ok then
ngx.log(ngx.ERR, "failed to clear existing (fallback) certificates: " .. clear_err) ngx.log(ngx.ERR, "failed to clear existing (fallback) certificates: " .. clear_err)
return ngx.exit(ngx.ERROR) return ngx.exit(ngx.ERROR)
end end
local der_cert, der_priv_key, der_err = get_der_cert_and_priv_key(pem_cert) local set_err = set_cert_and_key(cert, priv_key)
if der_err then if set_err then
ngx.log(ngx.ERR, der_err) ngx.log(ngx.ERR, set_err)
return ngx.exit(ngx.ERROR)
end
local set_der_err = set_der_cert_and_key(der_cert, der_priv_key)
if set_der_err then
ngx.log(ngx.ERR, set_der_err)
return ngx.exit(ngx.ERROR) return ngx.exit(ngx.ERROR)
end end

View file

@ -16,22 +16,22 @@ local DEFAULT_UUID = "00000000-0000-0000-0000-000000000000"
local function assert_certificate_is_set(cert) local function assert_certificate_is_set(cert)
spy.on(ngx, "log") spy.on(ngx, "log")
spy.on(ssl, "set_der_cert") spy.on(ssl, "set_cert")
spy.on(ssl, "set_der_priv_key") spy.on(ssl, "set_priv_key")
assert.has_no.errors(certificate.call) assert.has_no.errors(certificate.call)
assert.spy(ngx.log).was_not_called_with(ngx.ERR, _) assert.spy(ngx.log).was_not_called_with(ngx.ERR, _)
assert.spy(ssl.set_der_cert).was_called_with(ssl.cert_pem_to_der(cert)) assert.spy(ssl.set_cert).was_called_with("cert")
assert.spy(ssl.set_der_priv_key).was_called_with(ssl.priv_key_pem_to_der(cert)) assert.spy(ssl.set_priv_key).was_called_with("priv_key")
end end
local function refute_certificate_is_set() local function refute_certificate_is_set()
spy.on(ssl, "set_der_cert") spy.on(ssl, "set_cert")
spy.on(ssl, "set_der_priv_key") spy.on(ssl, "set_priv_key")
assert.has_no.errors(certificate.call) assert.has_no.errors(certificate.call)
assert.spy(ssl.set_der_cert).was_not_called() assert.spy(ssl.set_cert).was_not_called()
assert.spy(ssl.set_der_priv_key).was_not_called() assert.spy(ssl.set_priv_key).was_not_called()
end end
local function set_certificate(hostname, certificate, uuid) local function set_certificate(hostname, certificate, uuid)
@ -52,8 +52,22 @@ describe("Certificate", function()
before_each(function() before_each(function()
ssl.server_name = function() return "hostname", nil end ssl.server_name = function() return "hostname", nil end
ssl.clear_certs = function() return true, "" end ssl.clear_certs = function() return true, "" end
ssl.set_der_cert = function(cert) return true, "" end ssl.parse_pem_cert = function(cert)
ssl.set_der_priv_key = function(priv_key) return true, "" end if cert == "invalid" then
return nil, "bad format"
else
return "cert", nil
end
end
ssl.parse_pem_priv_key = function(priv_key)
if priv_key == "invalid" then
return nil, "bad format"
else
return "priv_key", nil
end
end
ssl.set_cert = function(cert) return true, "" end
ssl.set_priv_key = function(priv_key) return true, "" end
ngx.exit = function(status) end ngx.exit = function(status) end
@ -65,6 +79,7 @@ describe("Certificate", function()
ngx = unmocked_ngx ngx = unmocked_ngx
ngx.shared.certificate_data:flush_all() ngx.shared.certificate_data:flush_all()
ngx.shared.certificate_servers:flush_all() ngx.shared.certificate_servers:flush_all()
certificate.flush_cache()
end) end)
it("sets certificate and key when hostname is found in dictionary", function() it("sets certificate and key when hostname is found in dictionary", function()
@ -101,12 +116,12 @@ describe("Certificate", function()
end) end)
it("logs error message when certificate in dictionary is invalid", function() it("logs error message when certificate in dictionary is invalid", function()
set_certificate("hostname", "something invalid", UUID) set_certificate("hostname", "invalid", UUID)
spy.on(ngx, "log") spy.on(ngx, "log")
refute_certificate_is_set() refute_certificate_is_set()
assert.spy(ngx.log).was_called_with(ngx.ERR, "failed to convert certificate chain from PEM to DER: PEM_read_bio_X509_AUX() failed") assert.spy(ngx.log).was_called_with(ngx.ERR, "failed to parse PEM certificate chain: bad format")
end) end)
it("uses default certificate when there's none found for given hostname", function() it("uses default certificate when there's none found for given hostname", function()
@ -126,7 +141,7 @@ describe("Certificate", function()
spy.on(ngx, "log") spy.on(ngx, "log")
refute_certificate_is_set() refute_certificate_is_set()
assert.spy(ngx.log).was_called_with(ngx.ERR, "failed to convert certificate chain from PEM to DER: PEM_read_bio_X509_AUX() failed") assert.spy(ngx.log).was_called_with(ngx.ERR, "failed to parse PEM certificate chain: bad format")
end) end)
describe("OCSP stapling", function() describe("OCSP stapling", function()