soteria/internal/server/auth_support.go

121 lines
3.0 KiB
Go

package server
import (
	"context"
	"crypto/subtle"
	"fmt"
	"net/http"
	"strings"
)
// authorize decides whether r may proceed. It returns the resolved identity,
// an HTTP status code, and a non-nil error when the request is rejected.
//
// The checks run in this order:
//
//  1. If s.cfg.AuthRequired is false, the request is accepted with a zero
//     identity.
//  2. If the Authorization header uses the Bearer scheme (matched
//     case-insensitively) and carries a non-empty token equal to one of
//     s.cfg.AuthBearerTokens, the fixed "service-token" identity is returned.
//     The configured AllowedGroups are not consulted on this path.
//  3. Otherwise the identity is read from the X-Auth-Request-* headers, with
//     the X-Forwarded-* headers as fallback. A request that supplies neither
//     a user nor an email is rejected with 401.
//  4. If s.cfg.AllowedGroups is non-empty, the identity must belong to at
//     least one of those groups, otherwise the request is rejected with 403.
//
// A bearer token that matches nothing is not rejected on its own; the request
// simply falls through to the header-based checks.
//
// NOTE(review): the identity headers are taken at face value, which presumably
// relies on a trusted authenticating proxy in front of this server that
// overwrites any client-supplied copies — confirm against the deployment.
func (s *Server) authorize(r *http.Request) (authIdentity, int, error) {
	if !s.cfg.AuthRequired {
		return authIdentity{}, http.StatusOK, nil
	}

	const bearerPrefix = "bearer "
	authorization := strings.TrimSpace(r.Header.Get("Authorization"))
	if len(authorization) >= len(bearerPrefix) &&
		strings.EqualFold(authorization[:len(bearerPrefix)], bearerPrefix) {
		token := strings.TrimSpace(authorization[len(bearerPrefix):])
		// An empty token never matches, even if an empty string is configured.
		if token != "" {
			presented := []byte(token)
			for _, expected := range s.cfg.AuthBearerTokens {
				// Constant-time comparison so response timing does not reveal
				// how much of a configured token the caller guessed correctly.
				if subtle.ConstantTimeCompare(presented, []byte(expected)) == 1 {
					return authIdentity{
						Authenticated: true,
						User:          "service-token",
						Groups:        []string{"service-token"},
					}, http.StatusOK, nil
				}
			}
		}
	}

	identity := authIdentity{
		Authenticated: true,
		User:          firstHeader(r, "X-Auth-Request-User", "X-Forwarded-User"),
		Email:         firstHeader(r, "X-Auth-Request-Email", "X-Forwarded-Email"),
		Groups:        normalizeGroups(splitGroups(firstHeader(r, "X-Auth-Request-Groups", "X-Forwarded-Groups"))),
	}
	if identity.User == "" && identity.Email == "" {
		return authIdentity{}, http.StatusUnauthorized, fmt.Errorf("authentication required")
	}

	// No group restriction configured: any identified requester is accepted.
	if len(s.cfg.AllowedGroups) == 0 {
		return identity, http.StatusOK, nil
	}
	if hasAllowedGroup(identity.Groups, s.cfg.AllowedGroups) {
		return identity, http.StatusOK, nil
	}
	return authIdentity{}, http.StatusForbidden, fmt.Errorf("access requires one of: %s", strings.Join(s.cfg.AllowedGroups, ", "))
}
// requesterFromContext returns the authIdentity stored in ctx under
// authContextKey. When no value is stored there, or the stored value is not an
// authIdentity, the zero authIdentity is returned.
func requesterFromContext(ctx context.Context) authIdentity {
	if stored, ok := ctx.Value(authContextKey).(authIdentity); ok {
		return stored
	}
	return authIdentity{}
}
// currentRequester returns a display name for the identity carried by ctx.
// It prefers the user name, then the email address; an identity that has
// neither yields "authenticated" when it is flagged as authenticated and
// "anonymous" otherwise.
func currentRequester(ctx context.Context) string {
	who := requesterFromContext(ctx)
	switch {
	case who.User != "":
		return who.User
	case who.Email != "":
		return who.Email
	case who.Authenticated:
		return "authenticated"
	default:
		return "anonymous"
	}
}
// authzReason maps the outcome of an authorization attempt to a short reason
// label. A nil err always yields "unknown"; otherwise 401 maps to
// "unauthenticated", 403 to "forbidden_group", and any other status to "error".
func authzReason(status int, err error) string {
	if err == nil {
		return "unknown"
	}
	if status == http.StatusUnauthorized {
		return "unauthenticated"
	}
	if status == http.StatusForbidden {
		return "forbidden_group"
	}
	return "error"
}
// hasAllowedGroup reports whether any group in actual also appears in allowed.
// Both lists are passed through normalizeGroups before comparison, and the
// comparison itself is an exact, case-sensitive string match. An empty list on
// either side yields false.
func hasAllowedGroup(actual, allowed []string) bool {
	permitted := make(map[string]bool, len(allowed))
	for _, name := range normalizeGroups(allowed) {
		permitted[name] = true
	}
	for _, name := range normalizeGroups(actual) {
		if permitted[name] {
			return true
		}
	}
	return false
}
// normalizeGroups returns a cleaned copy of values: each entry has surrounding
// whitespace trimmed and then a single leading "/" removed, and entries that
// end up empty are dropped. Input order is preserved and the result is never
// nil, even for nil or empty input.
func normalizeGroups(values []string) []string {
	cleaned := make([]string, 0, len(values))
	for i := range values {
		name := strings.TrimPrefix(strings.TrimSpace(values[i]), "/")
		if name != "" {
			cleaned = append(cleaned, name)
		}
	}
	return cleaned
}
// firstHeader looks up each header in names, in the order given, and returns
// the first value that is non-empty after trimming surrounding whitespace.
// It returns "" when none of the headers carries such a value.
func firstHeader(r *http.Request, names ...string) string {
	for i := range names {
		if found := strings.TrimSpace(r.Header.Get(names[i])); found != "" {
			return found
		}
	}
	return ""
}
// splitGroups breaks a raw group list into its entries, treating both ','
// and ';' as separators. Surrounding whitespace is trimmed from raw as a
// whole, but not from the individual entries. Empty fields between
// separators are dropped. A blank raw value yields nil.
func splitGroups(raw string) []string {
	trimmed := strings.TrimSpace(raw)
	if trimmed == "" {
		return nil
	}
	isSeparator := func(c rune) bool {
		switch c {
		case ',', ';':
			return true
		}
		return false
	}
	return strings.FieldsFunc(trimmed, isSeparator)
}