Add per worker cache for certificate and private key
Originally certificate and private key are converted to DER format and then passed to `set_der_cert` and `set_der_priv_key`. The major drawback of this is the input of `set_der_cert` cannot be cached. The input is DER bytes. Inside `set_der_cert`, a new `X509 *` is always created from it. This means redundant work and causes a memory waste. By switching to `set_cert` and `set_priv_key` APIs, we can pass in an opaque `X509 *` pointer. Since `X509 *` is reference counted, we can easily add a cache for it, thus it can be reused by multiple connections. In benchmark, it saves ~20KB memory per connection.
This commit is contained in:
parent
e00b45beb5
commit
e3e8df6aff
2 changed files with 86 additions and 44 deletions
|
@ -1,4 +1,5 @@
|
||||||
local http = require("resty.http")
|
local http = require("resty.http")
|
||||||
|
local lrucache = require("resty.lrucache")
|
||||||
local ssl = require("ngx.ssl")
|
local ssl = require("ngx.ssl")
|
||||||
local ocsp = require("ngx.ocsp")
|
local ocsp = require("ngx.ocsp")
|
||||||
local ngx = ngx
|
local ngx = ngx
|
||||||
|
@ -19,29 +20,39 @@ local certificate_data = ngx.shared.certificate_data
|
||||||
local certificate_servers = ngx.shared.certificate_servers
|
local certificate_servers = ngx.shared.certificate_servers
|
||||||
local ocsp_response_cache = ngx.shared.ocsp_response_cache
|
local ocsp_response_cache = ngx.shared.ocsp_response_cache
|
||||||
|
|
||||||
local function get_der_cert_and_priv_key(pem_cert_key)
|
local CACHE_SIZE = 1000
|
||||||
local der_cert, der_cert_err = ssl.cert_pem_to_der(pem_cert_key)
|
local cache
|
||||||
if not der_cert then
|
do
|
||||||
return nil, nil, "failed to convert certificate chain from PEM to DER: " .. der_cert_err
|
local err
|
||||||
|
cache, err = lrucache.new(CACHE_SIZE)
|
||||||
|
if not cache then
|
||||||
|
return error("failed to create the certificate cache: " .. (err or "unknown"))
|
||||||
end
|
end
|
||||||
|
|
||||||
local der_priv_key, dev_priv_key_err = ssl.priv_key_pem_to_der(pem_cert_key)
|
|
||||||
if not der_priv_key then
|
|
||||||
return nil, nil, "failed to convert private key from PEM to DER: " .. dev_priv_key_err
|
|
||||||
end
|
|
||||||
|
|
||||||
return der_cert, der_priv_key, nil
|
|
||||||
end
|
end
|
||||||
|
|
||||||
local function set_der_cert_and_key(der_cert, der_priv_key)
|
local function get_cert_and_priv_key(pem_cert_key)
|
||||||
local set_cert_ok, set_cert_err = ssl.set_der_cert(der_cert)
|
local cert, cert_err = ssl.parse_pem_cert(pem_cert_key)
|
||||||
if not set_cert_ok then
|
if not cert then
|
||||||
return "failed to set DER cert: " .. set_cert_err
|
return nil, nil, "failed to parse PEM certificate chain: " .. cert_err
|
||||||
end
|
end
|
||||||
|
|
||||||
local set_priv_key_ok, set_priv_key_err = ssl.set_der_priv_key(der_priv_key)
|
local priv_key, priv_key_err = ssl.parse_pem_priv_key(pem_cert_key)
|
||||||
|
if not priv_key then
|
||||||
|
return nil, nil, "failed to parse PEM private key: " .. priv_key_err
|
||||||
|
end
|
||||||
|
|
||||||
|
return cert, priv_key, nil
|
||||||
|
end
|
||||||
|
|
||||||
|
local function set_cert_and_key(cert, priv_key)
|
||||||
|
local set_cert_ok, set_cert_err = ssl.set_cert(cert)
|
||||||
|
if not set_cert_ok then
|
||||||
|
return "failed to set cert: " .. set_cert_err
|
||||||
|
end
|
||||||
|
|
||||||
|
local set_priv_key_ok, set_priv_key_err = ssl.set_priv_key(priv_key)
|
||||||
if not set_priv_key_ok then
|
if not set_priv_key_ok then
|
||||||
return "failed to set DER private key: " .. set_priv_key_err
|
return "failed to set private key: " .. set_priv_key_err
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
|
@ -221,6 +232,10 @@ function _M.configured_for_current_request()
|
||||||
return ngx.ctx.cert_configured_for_current_request
|
return ngx.ctx.cert_configured_for_current_request
|
||||||
end
|
end
|
||||||
|
|
||||||
|
function _M.flush_cache()
|
||||||
|
cache:flush_all()
|
||||||
|
end
|
||||||
|
|
||||||
function _M.call()
|
function _M.call()
|
||||||
local hostname, hostname_err = ssl.server_name()
|
local hostname, hostname_err = ssl.server_name()
|
||||||
if hostname_err then
|
if hostname_err then
|
||||||
|
@ -232,35 +247,47 @@ function _M.call()
|
||||||
hostname = DEFAULT_CERT_HOSTNAME
|
hostname = DEFAULT_CERT_HOSTNAME
|
||||||
end
|
end
|
||||||
|
|
||||||
local pem_cert
|
local cert, priv_key, get_err
|
||||||
local pem_cert_uid = get_pem_cert_uid(hostname)
|
local pem_cert_uid = get_pem_cert_uid(hostname)
|
||||||
if not pem_cert_uid then
|
if not pem_cert_uid then
|
||||||
pem_cert_uid = get_pem_cert_uid(DEFAULT_CERT_HOSTNAME)
|
pem_cert_uid = get_pem_cert_uid(DEFAULT_CERT_HOSTNAME)
|
||||||
end
|
end
|
||||||
if pem_cert_uid then
|
if not pem_cert_uid then
|
||||||
pem_cert = certificate_data:get(pem_cert_uid)
|
|
||||||
end
|
|
||||||
if not pem_cert then
|
|
||||||
ngx.log(ngx.ERR, "certificate not found, falling back to fake certificate for hostname: "
|
ngx.log(ngx.ERR, "certificate not found, falling back to fake certificate for hostname: "
|
||||||
.. tostring(hostname))
|
.. tostring(hostname))
|
||||||
return
|
return
|
||||||
end
|
end
|
||||||
|
|
||||||
|
local cached_entry = cache:get(pem_cert_uid)
|
||||||
|
if cached_entry then
|
||||||
|
cert = cached_entry.cert
|
||||||
|
priv_key = cached_entry.priv_key
|
||||||
|
else
|
||||||
|
local pem_cert = certificate_data:get(pem_cert_uid)
|
||||||
|
if not pem_cert then
|
||||||
|
ngx.log(ngx.ERR, "certificate not found, falling back to fake certificate for hostname: "
|
||||||
|
.. tostring(hostname))
|
||||||
|
return
|
||||||
|
end
|
||||||
|
|
||||||
|
cert, priv_key, get_err = get_cert_and_priv_key(pem_cert)
|
||||||
|
if get_err then
|
||||||
|
ngx.log(ngx.ERR, get_err)
|
||||||
|
return ngx.exit(ngx.ERROR)
|
||||||
|
end
|
||||||
|
|
||||||
|
cache:set(pem_cert_uid, { cert = cert, priv_key = priv_key })
|
||||||
|
end
|
||||||
|
|
||||||
local clear_ok, clear_err = ssl.clear_certs()
|
local clear_ok, clear_err = ssl.clear_certs()
|
||||||
if not clear_ok then
|
if not clear_ok then
|
||||||
ngx.log(ngx.ERR, "failed to clear existing (fallback) certificates: " .. clear_err)
|
ngx.log(ngx.ERR, "failed to clear existing (fallback) certificates: " .. clear_err)
|
||||||
return ngx.exit(ngx.ERROR)
|
return ngx.exit(ngx.ERROR)
|
||||||
end
|
end
|
||||||
|
|
||||||
local der_cert, der_priv_key, der_err = get_der_cert_and_priv_key(pem_cert)
|
local set_err = set_cert_and_key(cert, priv_key)
|
||||||
if der_err then
|
if set_err then
|
||||||
ngx.log(ngx.ERR, der_err)
|
ngx.log(ngx.ERR, set_err)
|
||||||
return ngx.exit(ngx.ERROR)
|
|
||||||
end
|
|
||||||
|
|
||||||
local set_der_err = set_der_cert_and_key(der_cert, der_priv_key)
|
|
||||||
if set_der_err then
|
|
||||||
ngx.log(ngx.ERR, set_der_err)
|
|
||||||
return ngx.exit(ngx.ERROR)
|
return ngx.exit(ngx.ERROR)
|
||||||
end
|
end
|
||||||
|
|
||||||
|
|
|
@ -16,22 +16,22 @@ local DEFAULT_UUID = "00000000-0000-0000-0000-000000000000"
|
||||||
|
|
||||||
local function assert_certificate_is_set(cert)
|
local function assert_certificate_is_set(cert)
|
||||||
spy.on(ngx, "log")
|
spy.on(ngx, "log")
|
||||||
spy.on(ssl, "set_der_cert")
|
spy.on(ssl, "set_cert")
|
||||||
spy.on(ssl, "set_der_priv_key")
|
spy.on(ssl, "set_priv_key")
|
||||||
|
|
||||||
assert.has_no.errors(certificate.call)
|
assert.has_no.errors(certificate.call)
|
||||||
assert.spy(ngx.log).was_not_called_with(ngx.ERR, _)
|
assert.spy(ngx.log).was_not_called_with(ngx.ERR, _)
|
||||||
assert.spy(ssl.set_der_cert).was_called_with(ssl.cert_pem_to_der(cert))
|
assert.spy(ssl.set_cert).was_called_with("cert")
|
||||||
assert.spy(ssl.set_der_priv_key).was_called_with(ssl.priv_key_pem_to_der(cert))
|
assert.spy(ssl.set_priv_key).was_called_with("priv_key")
|
||||||
end
|
end
|
||||||
|
|
||||||
local function refute_certificate_is_set()
|
local function refute_certificate_is_set()
|
||||||
spy.on(ssl, "set_der_cert")
|
spy.on(ssl, "set_cert")
|
||||||
spy.on(ssl, "set_der_priv_key")
|
spy.on(ssl, "set_priv_key")
|
||||||
|
|
||||||
assert.has_no.errors(certificate.call)
|
assert.has_no.errors(certificate.call)
|
||||||
assert.spy(ssl.set_der_cert).was_not_called()
|
assert.spy(ssl.set_cert).was_not_called()
|
||||||
assert.spy(ssl.set_der_priv_key).was_not_called()
|
assert.spy(ssl.set_priv_key).was_not_called()
|
||||||
end
|
end
|
||||||
|
|
||||||
local function set_certificate(hostname, certificate, uuid)
|
local function set_certificate(hostname, certificate, uuid)
|
||||||
|
@ -52,8 +52,22 @@ describe("Certificate", function()
|
||||||
before_each(function()
|
before_each(function()
|
||||||
ssl.server_name = function() return "hostname", nil end
|
ssl.server_name = function() return "hostname", nil end
|
||||||
ssl.clear_certs = function() return true, "" end
|
ssl.clear_certs = function() return true, "" end
|
||||||
ssl.set_der_cert = function(cert) return true, "" end
|
ssl.parse_pem_cert = function(cert)
|
||||||
ssl.set_der_priv_key = function(priv_key) return true, "" end
|
if cert == "invalid" then
|
||||||
|
return nil, "bad format"
|
||||||
|
else
|
||||||
|
return "cert", nil
|
||||||
|
end
|
||||||
|
end
|
||||||
|
ssl.parse_pem_priv_key = function(priv_key)
|
||||||
|
if priv_key == "invalid" then
|
||||||
|
return nil, "bad format"
|
||||||
|
else
|
||||||
|
return "priv_key", nil
|
||||||
|
end
|
||||||
|
end
|
||||||
|
ssl.set_cert = function(cert) return true, "" end
|
||||||
|
ssl.set_priv_key = function(priv_key) return true, "" end
|
||||||
|
|
||||||
ngx.exit = function(status) end
|
ngx.exit = function(status) end
|
||||||
|
|
||||||
|
@ -65,6 +79,7 @@ describe("Certificate", function()
|
||||||
ngx = unmocked_ngx
|
ngx = unmocked_ngx
|
||||||
ngx.shared.certificate_data:flush_all()
|
ngx.shared.certificate_data:flush_all()
|
||||||
ngx.shared.certificate_servers:flush_all()
|
ngx.shared.certificate_servers:flush_all()
|
||||||
|
certificate.flush_cache()
|
||||||
end)
|
end)
|
||||||
|
|
||||||
it("sets certificate and key when hostname is found in dictionary", function()
|
it("sets certificate and key when hostname is found in dictionary", function()
|
||||||
|
@ -101,12 +116,12 @@ describe("Certificate", function()
|
||||||
end)
|
end)
|
||||||
|
|
||||||
it("logs error message when certificate in dictionary is invalid", function()
|
it("logs error message when certificate in dictionary is invalid", function()
|
||||||
set_certificate("hostname", "something invalid", UUID)
|
set_certificate("hostname", "invalid", UUID)
|
||||||
|
|
||||||
spy.on(ngx, "log")
|
spy.on(ngx, "log")
|
||||||
|
|
||||||
refute_certificate_is_set()
|
refute_certificate_is_set()
|
||||||
assert.spy(ngx.log).was_called_with(ngx.ERR, "failed to convert certificate chain from PEM to DER: PEM_read_bio_X509_AUX() failed")
|
assert.spy(ngx.log).was_called_with(ngx.ERR, "failed to parse PEM certificate chain: bad format")
|
||||||
end)
|
end)
|
||||||
|
|
||||||
it("uses default certificate when there's none found for given hostname", function()
|
it("uses default certificate when there's none found for given hostname", function()
|
||||||
|
@ -126,7 +141,7 @@ describe("Certificate", function()
|
||||||
spy.on(ngx, "log")
|
spy.on(ngx, "log")
|
||||||
|
|
||||||
refute_certificate_is_set()
|
refute_certificate_is_set()
|
||||||
assert.spy(ngx.log).was_called_with(ngx.ERR, "failed to convert certificate chain from PEM to DER: PEM_read_bio_X509_AUX() failed")
|
assert.spy(ngx.log).was_called_with(ngx.ERR, "failed to parse PEM certificate chain: bad format")
|
||||||
end)
|
end)
|
||||||
|
|
||||||
describe("OCSP stapling", function()
|
describe("OCSP stapling", function()
|
||||||
|
|
Loading…
Reference in a new issue