diff --git a/images/custom-error-pages/rootfs/main.go b/images/custom-error-pages/rootfs/main.go index 7d3d73029..3a0c04b60 100644 --- a/images/custom-error-pages/rootfs/main.go +++ b/images/custom-error-pages/rootfs/main.go @@ -97,6 +97,41 @@ func main() { http.ListenAndServe(fmt.Sprintf(":8080"), nil) } +// lookupAvailableExtensionsForCode returns a list of available extensions for a given code in the error files directory +func lookupAvailableExtensionsForCode(errFilesPath string, code int) ([]string, error) { + files, err := os.ReadDir(errFilesPath) + if err != nil { + return []string{}, fmt.Errorf("unable to read error files directory: %v", err) + } + availableFormats := []string{} + for _, f := range files { + if !f.IsDir() && strings.HasPrefix(f.Name(), fmt.Sprintf("%v", code)) { + availableFormats = append(availableFormats, f.Name()) + } + } + return availableFormats, nil +} + +// findAcceptedFormat returns the extension for a given Accept header in the list of available extensions +func findAcceptedFormat(acceptHeader string, availExts []string) (string, string, error) { + formats := strings.Split(acceptHeader, ",") + for _, format := range formats { + mimetype := strings.TrimSpace(format) + exts, err := mime.ExtensionsByType(mimetype) + if err != nil { + continue + } + for _, ext := range exts { + for _, aext := range availExts { + if ext == aext { + return ext, mimetype, nil + } + } + } + } + return "", "", fmt.Errorf("no available handler found for '%s' in '%s'", formats, availExts) +} + func errorHandler(path, defaultFormat string) func(http.ResponseWriter, *http.Request) { defaultExts, err := mime.ExtensionsByType(defaultFormat) if err != nil || len(defaultExts) == 0 { @@ -126,17 +161,6 @@ func errorHandler(path, defaultFormat string) func(http.ResponseWriter, *http.Re log.Printf("format not specified. Using %v", format) } - cext, err := mime.ExtensionsByType(format) - if err != nil { - log.Printf("unexpected error reading media type extension: %v. Using %v", err, ext) - format = defaultFormat - } else if len(cext) == 0 { - log.Printf("couldn't get media type extension. Using %v", ext) - } else { - ext = cext[0] - } - w.Header().Set(ContentType, format) - errCode := r.Header.Get(CodeHeader) code, err := strconv.Atoi(errCode) if err != nil { @@ -145,6 +169,21 @@ func errorHandler(path, defaultFormat string) func(http.ResponseWriter, *http.Re } w.WriteHeader(code) + exts, err := lookupAvailableExtensionsForCode(path, code) + if err != nil { + panic("unexpected error reading available formats") + } + + aext, aformat, err := findAcceptedFormat(format, exts) + if err != nil { + log.Printf("unexpected error reading media type extension: %v. Using %v", err, ext) + format = defaultFormat + } else { + ext = aext + format = aformat + } + w.Header().Set(ContentType, format) + if !strings.HasPrefix(ext, ".") { ext = "." + ext }