fix(custom-error-pages): support multiple accept header values

Signed-off-by: Zadkiel Aharonian <zadkiel.aharonian@gmail.com>
This commit is contained in:
Zadkiel Aharonian 2022-01-19 17:16:05 +00:00 committed by GitHub
parent da51393cac
commit e212c52f25
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -97,6 +97,41 @@ func main() {
http.ListenAndServe(fmt.Sprintf(":8080"), nil) http.ListenAndServe(fmt.Sprintf(":8080"), nil)
} }
// lookupAvailableExtensionsForCode returns a list of available extensions for a given code in the error files directory
func lookupAvailableExtensionsForCode(errFilesPath string, code int) ([]string, error) {
files, err := os.ReadDir(errFilesPath)
if err != nil {
return []string{}, fmt.Errorf("unable to read error files directory: %v", err)
}
availableFormats := []string{}
for _, f := range files {
if !f.IsDir() && strings.HasPrefix(f.Name(), fmt.Sprintf("%v", code)) {
availableFormats = append(availableFormats, f.Name())
}
}
return availableFormats, nil
}
// findAcceptedFormat returns the extension for a given Accept header in the list of available extensions
func findAcceptedFormat(acceptHeader string, availExts []string) (string, string, error) {
formats := strings.Split(acceptHeader, ",")
for _, format := range formats {
mimetype := strings.TrimSpace(format)
exts, err := mime.ExtensionsByType(mimetype)
if err != nil {
continue
}
for _, ext := range exts {
for _, aext := range availExts {
if ext == aext {
return ext, mimetype, nil
}
}
}
}
return "", "", fmt.Errorf("no available handler found for '%s' in '%s'", formats, availExts)
}
func errorHandler(path, defaultFormat string) func(http.ResponseWriter, *http.Request) { func errorHandler(path, defaultFormat string) func(http.ResponseWriter, *http.Request) {
defaultExts, err := mime.ExtensionsByType(defaultFormat) defaultExts, err := mime.ExtensionsByType(defaultFormat)
if err != nil || len(defaultExts) == 0 { if err != nil || len(defaultExts) == 0 {
@ -126,17 +161,6 @@ func errorHandler(path, defaultFormat string) func(http.ResponseWriter, *http.Re
log.Printf("format not specified. Using %v", format) log.Printf("format not specified. Using %v", format)
} }
cext, err := mime.ExtensionsByType(format)
if err != nil {
log.Printf("unexpected error reading media type extension: %v. Using %v", err, ext)
format = defaultFormat
} else if len(cext) == 0 {
log.Printf("couldn't get media type extension. Using %v", ext)
} else {
ext = cext[0]
}
w.Header().Set(ContentType, format)
errCode := r.Header.Get(CodeHeader) errCode := r.Header.Get(CodeHeader)
code, err := strconv.Atoi(errCode) code, err := strconv.Atoi(errCode)
if err != nil { if err != nil {
@ -145,6 +169,21 @@ func errorHandler(path, defaultFormat string) func(http.ResponseWriter, *http.Re
} }
w.WriteHeader(code) w.WriteHeader(code)
exts, err := lookupAvailableExtensionsForCode(path, code)
if err != nil {
panic("unexpected error reading available formats")
}
aext, aformat, err := findAcceptedFormat(format, exts)
if err != nil {
log.Printf("unexpected error reading media type extension: %v. Using %v", err, ext)
format = defaultFormat
} else {
ext = aext
format = aformat
}
w.Header().Set(ContentType, format)
if !strings.HasPrefix(ext, ".") { if !strings.HasPrefix(ext, ".") {
ext = "." + ext ext = "." + ext
} }