package auth_connectors import ( "context" "crypto/tls" "encoding/json" "errors" "fmt" "net" "oc-auth/conf" "strings" "sync" "time" "github.com/coocood/freecache" "github.com/go-ldap/ldap/v3" "github.com/i-core/rlog" "go.uber.org/zap" ) var ( // errInvalidCredentials is an error that happens when a user's password is invalid. errInvalidCredentials = fmt.Errorf("invalid credentials") // errConnectionTimeout is an error that happens when no one LDAP endpoint responds. errConnectionTimeout = fmt.Errorf("connection timeout") // errMissedUsername is an error that happens errMissedUsername = errors.New("username is missed") // errUnknownUsername is an error that happens errUnknownUsername = errors.New("unknown username") ) type conn interface { Bind(bindDN, password string) error SearchUser(user string, attrs ...string) ([]map[string]interface{}, error) SearchUserRoles(user string, attrs ...string) ([]map[string]interface{}, error) Close() error } type connector interface { Connect(ctx context.Context, addr string) (conn, error) } // Config is a LDAP configuration. type Config struct { Endpoints []string `envconfig:"endpoints" required:"true" desc:"a LDAP's server URLs as \"
:\""` BindDN string `envconfig:"binddn" desc:"a LDAP bind DN"` BindPass string `envconfig:"bindpw" json:"-" desc:"a LDAP bind password"` BaseDN string `envconfig:"basedn" required:"true" desc:"a LDAP base DN for searching users"` AttrClaims map[string]string `envconfig:"attr_claims" default:"name:name,sn:family_name,givenName:given_name,mail:email" desc:"a mapping of LDAP attributes to OpenID connect claims"` RoleBaseDN string `envconfig:"role_basedn" required:"true" desc:"a LDAP base DN for searching roles"` RoleAttr string `envconfig:"role_attr" default:"description" desc:"a LDAP group's attribute that contains a role's name"` RoleClaim string `envconfig:"role_claim" default:"https://github.com/i-core/werther/claims/roles" desc:"a name of an OpenID Connect claim that contains user roles"` CacheSize int `envconfig:"cache_size" default:"512" desc:"a user info cache's size in KiB"` CacheTTL time.Duration `envconfig:"cache_ttl" default:"30m" desc:"a user info cache TTL"` IsTLS bool `envconfig:"is_tls" default:"false" desc:"should LDAP connection be established via TLS"` FlatRoleClaims bool `envconfig:"flat_role_claims" desc:"add roles claim as single list"` } // New creates a new LDAP client. func New() *Client { cnf := Config{ Endpoints: strings.Split(conf.GetConfig().LDAPEndpoints, ","), BindDN: conf.GetConfig().LDAPBindDN, BindPass: conf.GetConfig().LDAPBindPW, BaseDN: conf.GetConfig().LDAPBaseDN, RoleBaseDN: conf.GetConfig().LDAPRoleBaseDN, } return &Client{ Config: cnf, connector: &ldapConnector{BaseDN: cnf.BaseDN, RoleBaseDN: cnf.RoleBaseDN, IsTLS: cnf.IsTLS}, cache: freecache.NewCache(cnf.CacheSize * 1024), } } type Client struct { Config connector connector cache *freecache.Cache } func (cli *Client) Authenticate(ctx context.Context, username, password string) (bool, error) { if username == "" || password == "" { return false, nil } var cancel context.CancelFunc ctx, cancel = context.WithCancel(ctx) cn, ok := <-cli.connect(ctx) cancel() if !ok { return false, errConnectionTimeout } defer cn.Close() // Find a user DN by his or her username. details, err := cli.findBasicUserDetails(cn, username, []string{"dn"}) if err != nil { return false, err } if details == nil { return false, nil } if err := cn.Bind(details["dn"].(string), password); err != nil { if err == errInvalidCredentials { return false, nil } return false, err } // Clear the claims' cache because of possible re-authentication. We don't want stale claims after re-login. if ok := cli.cache.Del([]byte(username)); ok { log := rlog.FromContext(ctx) log.Debug("Cleared user's OIDC claims in the cache") } return true, nil } // Claim is the FindOIDCClaims result struct type LDAPClaim struct { Code string // the root claim name Name string // the claim name Value interface{} // the value } // FindOIDCClaims finds all OIDC claims for a user. func (cli *Client) FindOIDCClaims(ctx context.Context, username string) ([]LDAPClaim, error) { if username == "" { return nil, errMissedUsername } log := rlog.FromContext(ctx).Sugar() // Retrieving from LDAP is slow. So, we try to get claims for the given username from the cache. switch cdata, err := cli.cache.Get([]byte(username)); err { case nil: var claims []LDAPClaim if err = json.Unmarshal(cdata, &claims); err != nil { log.Info("Failed to unmarshal user's OIDC claims", zap.Error(err), "data", cdata) return nil, err } log.Debugw("Retrieved user's OIDC claims from the cache", "claims", claims) return claims, nil case freecache.ErrNotFound: log.Debug("User's OIDC claims is not found in the cache") default: log.Infow("Failed to retrieve user's OIDC claims from the cache", zap.Error(err)) } // Try to make multiple TCP connections to the LDAP server for getting claims. // Accept the first one, and cancel others. var cancel context.CancelFunc ctx, cancel = context.WithCancel(ctx) cn, ok := <-cli.connect(ctx) cancel() if !ok { return nil, errConnectionTimeout } defer cn.Close() // We need to find LDAP attribute's names for all required claims. attrs := []string{"dn"} for k := range cli.AttrClaims { attrs = append(attrs, k) } // Find the attributes in the LDAP server. details, err := cli.findBasicUserDetails(cn, username, attrs) if err != nil { return nil, err } if details == nil { return nil, errUnknownUsername } log.Infow("Retrieved user's info from LDAP", "details", details) // Transform the retrieved attributes to corresponding claims. claims := make([]LDAPClaim, 0, len(details)) for attr, v := range details { if claim, ok := cli.AttrClaims[attr]; ok { claims = append(claims, LDAPClaim{claim, claim, v}) } } // User's roles is stored in LDAP as groups. We find all groups in a role's DN // that include the user as a member. entries, err := cn.SearchUserRoles(fmt.Sprintf("%s", details["dn"]), "dn", cli.RoleAttr) if err != nil { return nil, err } roles := make(map[string]interface{}) for _, entry := range entries { roleDN, ok := entry["dn"].(string) if !ok || roleDN == "" { log.Infow("No required LDAP attribute for a role", "ldapAttribute", "dn", "entry", entry) continue } if entry[cli.RoleAttr] == nil { log.Infow("No required LDAP attribute for a role", "ldapAttribute", cli.RoleAttr, "roleDN", roleDN) continue } // Ensure that a role's DN is inside of the role's base DN. // It's sufficient to compare the DN's suffix with the base DN. n, k := len(roleDN), len(cli.RoleBaseDN) if n < k || !strings.EqualFold(roleDN[n-k:], cli.RoleBaseDN) { panic("You should never see that") } // The DN without the role's base DN must contain a CN and OU // where the CN is for uniqueness only, and the OU is an application id. path := strings.Split(roleDN[:n-k-1], ",") if len(path) != 2 { log.Infow("A role's DN without the role's base DN must contain two nodes only", "roleBaseDN", cli.RoleBaseDN, "roleDN", roleDN) continue } appID := path[1][len("OU="):] var appRoles []interface{} if v := roles[appID]; v != nil { appRoles = v.([]interface{}) } appRoles = append(appRoles, entry[cli.RoleAttr]) roles[appID] = appRoles } claims = append(claims, LDAPClaim{cli.RoleClaim, cli.RoleClaim, roles}) if cli.FlatRoleClaims { for appID, appRoles := range roles { claims = append(claims, LDAPClaim{cli.RoleClaim, cli.RoleClaim + "/" + appID, appRoles}) } } // Save the claims in the cache for future queries. cdata, err := json.Marshal(claims) if err != nil { log.Infow("Failed to marshal user's OIDC claims for caching", zap.Error(err), "claims", claims) } if err = cli.cache.Set([]byte(username), cdata, int(cli.CacheTTL.Seconds())); err != nil { log.Infow("Failed to store user's OIDC claims into the cache", zap.Error(err), "claims", claims) } return claims, nil } func (cli *Client) connect(ctx context.Context) <-chan conn { var ( wg sync.WaitGroup ch = make(chan conn) ) wg.Add(len(cli.Endpoints)) for _, addr := range cli.Endpoints { go func(addr string) { defer wg.Done() cn, err := cli.connector.Connect(ctx, addr) if err != nil { fmt.Println("Failed to create a LDAP connection", "address", addr) return } select { case <-ctx.Done(): cn.Close() fmt.Println("a LDAP connection is cancelled", "address", addr) return case ch <- cn: } }(addr) } go func() { wg.Wait() close(ch) }() return ch } // findBasicUserDetails finds user's LDAP attributes that were specified. It returns nil if no such user. func (cli *Client) findBasicUserDetails(cn conn, username string, attrs []string) (map[string]interface{}, error) { if cli.BindDN != "" { // We need to login to a LDAP server with a service account for retrieving user data. if err := cn.Bind(cli.BindDN, cli.BindPass); err != nil { return nil, errors.New(err.Error() + " : failed to login to a LDAP woth a service account") } } entries, err := cn.SearchUser(username, attrs...) if err != nil { return nil, err } if len(entries) != 1 { // We didn't find the user. return nil, nil } var ( entry = entries[0] details = make(map[string]interface{}) ) for _, attr := range attrs { if v, ok := entry[attr]; ok { details[attr] = v } } return details, nil } type ldapConnector struct { BaseDN string RoleBaseDN string IsTLS bool } func (c *ldapConnector) Connect(ctx context.Context, addr string) (conn, error) { d := net.Dialer{Timeout: ldap.DefaultTimeout} tcpcn, err := d.DialContext(ctx, "tcp", addr) if err != nil { return nil, err } if c.IsTLS { tlscn, err := tls.DialWithDialer(&d, "tcp", addr, nil) if err != nil { return nil, err } tcpcn = tlscn } ldapcn := ldap.NewConn(tcpcn, c.IsTLS) ldapcn.Start() return &ldapConn{Conn: ldapcn, BaseDN: c.BaseDN, RoleBaseDN: c.RoleBaseDN}, nil } type ldapConn struct { *ldap.Conn BaseDN string RoleBaseDN string } func (c *ldapConn) Bind(bindDN, password string) error { err := c.Conn.Bind(bindDN, password) if ldapErr, ok := err.(*ldap.Error); ok && ldapErr.ResultCode == ldap.LDAPResultInvalidCredentials { return errInvalidCredentials } return err } func (c *ldapConn) SearchUser(user string, attrs ...string) ([]map[string]interface{}, error) { query := fmt.Sprintf( "(&(|(objectClass=organizationalPerson)(objectClass=inetOrgPerson))"+ "(|(uid=%[1]s)(mail=%[1]s)(userPrincipalName=%[1]s)(sAMAccountName=%[1]s)))", user) return c.searchEntries(c.BaseDN, query, attrs) } func (c *ldapConn) SearchUserRoles(user string, attrs ...string) ([]map[string]interface{}, error) { query := fmt.Sprintf("(|"+ "(&(|(objectClass=group)(objectClass=groupOfNames))(member=%[1]s))"+ "(&(objectClass=groupOfUniqueNames)(uniqueMember=%[1]s))"+ ")", user) return c.searchEntries(c.RoleBaseDN, query, attrs) } // searchEntries executes a LDAP query, and returns a result as entries where each entry is mapping of LDAP attributes. func (c *ldapConn) searchEntries(baseDN, query string, attrs []string) ([]map[string]interface{}, error) { req := ldap.NewSearchRequest(baseDN, ldap.ScopeWholeSubtree, ldap.NeverDerefAliases, 0, 0, false, query, attrs, nil) res, err := c.Search(req) if err != nil { return nil, err } var entries []map[string]interface{} for _, v := range res.Entries { entry := map[string]interface{}{"dn": v.DN} for _, attr := range v.Attributes { // We need the first value only for the named attribute. entry[attr.Name] = attr.Values[0] } entries = append(entries, entry) } return entries, nil }