mirror of
				https://github.com/go-gitea/gitea.git
				synced 2025-10-26 04:17:08 +00:00 
			
		
		
		
	Refactor auth package (#17962)
This commit is contained in:
		
							
								
								
									
										397
									
								
								models/auth/source.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										397
									
								
								models/auth/source.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,397 @@ | ||||
| // Copyright 2014 The Gogs Authors. All rights reserved. | ||||
| // Copyright 2019 The Gitea Authors. All rights reserved. | ||||
| // Use of this source code is governed by a MIT-style | ||||
| // license that can be found in the LICENSE file. | ||||
|  | ||||
| package auth | ||||
|  | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"reflect" | ||||
|  | ||||
| 	"code.gitea.io/gitea/models/db" | ||||
| 	"code.gitea.io/gitea/modules/log" | ||||
| 	"code.gitea.io/gitea/modules/timeutil" | ||||
|  | ||||
| 	"xorm.io/xorm" | ||||
| 	"xorm.io/xorm/convert" | ||||
| ) | ||||
|  | ||||
| // Type represents an login type. | ||||
| type Type int | ||||
|  | ||||
| // Note: new type must append to the end of list to maintain compatibility. | ||||
| const ( | ||||
| 	NoType Type = iota | ||||
| 	Plain       // 1 | ||||
| 	LDAP        // 2 | ||||
| 	SMTP        // 3 | ||||
| 	PAM         // 4 | ||||
| 	DLDAP       // 5 | ||||
| 	OAuth2      // 6 | ||||
| 	SSPI        // 7 | ||||
| ) | ||||
|  | ||||
| // String returns the string name of the LoginType | ||||
| func (typ Type) String() string { | ||||
| 	return Names[typ] | ||||
| } | ||||
|  | ||||
| // Int returns the int value of the LoginType | ||||
| func (typ Type) Int() int { | ||||
| 	return int(typ) | ||||
| } | ||||
|  | ||||
| // Names contains the name of LoginType values. | ||||
| var Names = map[Type]string{ | ||||
| 	LDAP:   "LDAP (via BindDN)", | ||||
| 	DLDAP:  "LDAP (simple auth)", // Via direct bind | ||||
| 	SMTP:   "SMTP", | ||||
| 	PAM:    "PAM", | ||||
| 	OAuth2: "OAuth2", | ||||
| 	SSPI:   "SPNEGO with SSPI", | ||||
| } | ||||
|  | ||||
| // Config represents login config as far as the db is concerned | ||||
| type Config interface { | ||||
| 	convert.Conversion | ||||
| } | ||||
|  | ||||
| // SkipVerifiable configurations provide a IsSkipVerify to check if SkipVerify is set | ||||
| type SkipVerifiable interface { | ||||
| 	IsSkipVerify() bool | ||||
| } | ||||
|  | ||||
| // HasTLSer configurations provide a HasTLS to check if TLS can be enabled | ||||
| type HasTLSer interface { | ||||
| 	HasTLS() bool | ||||
| } | ||||
|  | ||||
| // UseTLSer configurations provide a HasTLS to check if TLS is enabled | ||||
| type UseTLSer interface { | ||||
| 	UseTLS() bool | ||||
| } | ||||
|  | ||||
| // SSHKeyProvider configurations provide ProvidesSSHKeys to check if they provide SSHKeys | ||||
| type SSHKeyProvider interface { | ||||
| 	ProvidesSSHKeys() bool | ||||
| } | ||||
|  | ||||
| // RegisterableSource configurations provide RegisterSource which needs to be run on creation | ||||
| type RegisterableSource interface { | ||||
| 	RegisterSource() error | ||||
| 	UnregisterSource() error | ||||
| } | ||||
|  | ||||
| var registeredConfigs = map[Type]func() Config{} | ||||
|  | ||||
| // RegisterTypeConfig register a config for a provided type | ||||
| func RegisterTypeConfig(typ Type, exemplar Config) { | ||||
| 	if reflect.TypeOf(exemplar).Kind() == reflect.Ptr { | ||||
| 		// Pointer: | ||||
| 		registeredConfigs[typ] = func() Config { | ||||
| 			return reflect.New(reflect.ValueOf(exemplar).Elem().Type()).Interface().(Config) | ||||
| 		} | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	// Not a Pointer | ||||
| 	registeredConfigs[typ] = func() Config { | ||||
| 		return reflect.New(reflect.TypeOf(exemplar)).Elem().Interface().(Config) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // SourceSettable configurations can have their authSource set on them | ||||
| type SourceSettable interface { | ||||
| 	SetAuthSource(*Source) | ||||
| } | ||||
|  | ||||
| // Source represents an external way for authorizing users. | ||||
| type Source struct { | ||||
| 	ID            int64 `xorm:"pk autoincr"` | ||||
| 	Type          Type | ||||
| 	Name          string             `xorm:"UNIQUE"` | ||||
| 	IsActive      bool               `xorm:"INDEX NOT NULL DEFAULT false"` | ||||
| 	IsSyncEnabled bool               `xorm:"INDEX NOT NULL DEFAULT false"` | ||||
| 	Cfg           convert.Conversion `xorm:"TEXT"` | ||||
|  | ||||
| 	CreatedUnix timeutil.TimeStamp `xorm:"INDEX created"` | ||||
| 	UpdatedUnix timeutil.TimeStamp `xorm:"INDEX updated"` | ||||
| } | ||||
|  | ||||
| // TableName xorm will read the table name from this method | ||||
| func (Source) TableName() string { | ||||
| 	return "login_source" | ||||
| } | ||||
|  | ||||
| func init() { | ||||
| 	db.RegisterModel(new(Source)) | ||||
| } | ||||
|  | ||||
| // BeforeSet is invoked from XORM before setting the value of a field of this object. | ||||
| func (source *Source) BeforeSet(colName string, val xorm.Cell) { | ||||
| 	if colName == "type" { | ||||
| 		typ := Type(db.Cell2Int64(val)) | ||||
| 		constructor, ok := registeredConfigs[typ] | ||||
| 		if !ok { | ||||
| 			return | ||||
| 		} | ||||
| 		source.Cfg = constructor() | ||||
| 		if settable, ok := source.Cfg.(SourceSettable); ok { | ||||
| 			settable.SetAuthSource(source) | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // TypeName return name of this login source type. | ||||
| func (source *Source) TypeName() string { | ||||
| 	return Names[source.Type] | ||||
| } | ||||
|  | ||||
| // IsLDAP returns true of this source is of the LDAP type. | ||||
| func (source *Source) IsLDAP() bool { | ||||
| 	return source.Type == LDAP | ||||
| } | ||||
|  | ||||
| // IsDLDAP returns true of this source is of the DLDAP type. | ||||
| func (source *Source) IsDLDAP() bool { | ||||
| 	return source.Type == DLDAP | ||||
| } | ||||
|  | ||||
| // IsSMTP returns true of this source is of the SMTP type. | ||||
| func (source *Source) IsSMTP() bool { | ||||
| 	return source.Type == SMTP | ||||
| } | ||||
|  | ||||
| // IsPAM returns true of this source is of the PAM type. | ||||
| func (source *Source) IsPAM() bool { | ||||
| 	return source.Type == PAM | ||||
| } | ||||
|  | ||||
| // IsOAuth2 returns true of this source is of the OAuth2 type. | ||||
| func (source *Source) IsOAuth2() bool { | ||||
| 	return source.Type == OAuth2 | ||||
| } | ||||
|  | ||||
| // IsSSPI returns true of this source is of the SSPI type. | ||||
| func (source *Source) IsSSPI() bool { | ||||
| 	return source.Type == SSPI | ||||
| } | ||||
|  | ||||
| // HasTLS returns true of this source supports TLS. | ||||
| func (source *Source) HasTLS() bool { | ||||
| 	hasTLSer, ok := source.Cfg.(HasTLSer) | ||||
| 	return ok && hasTLSer.HasTLS() | ||||
| } | ||||
|  | ||||
| // UseTLS returns true of this source is configured to use TLS. | ||||
| func (source *Source) UseTLS() bool { | ||||
| 	useTLSer, ok := source.Cfg.(UseTLSer) | ||||
| 	return ok && useTLSer.UseTLS() | ||||
| } | ||||
|  | ||||
| // SkipVerify returns true if this source is configured to skip SSL | ||||
| // verification. | ||||
| func (source *Source) SkipVerify() bool { | ||||
| 	skipVerifiable, ok := source.Cfg.(SkipVerifiable) | ||||
| 	return ok && skipVerifiable.IsSkipVerify() | ||||
| } | ||||
|  | ||||
| // CreateSource inserts a AuthSource in the DB if not already | ||||
| // existing with the given name. | ||||
| func CreateSource(source *Source) error { | ||||
| 	has, err := db.GetEngine(db.DefaultContext).Where("name=?", source.Name).Exist(new(Source)) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} else if has { | ||||
| 		return ErrSourceAlreadyExist{source.Name} | ||||
| 	} | ||||
| 	// Synchronization is only available with LDAP for now | ||||
| 	if !source.IsLDAP() { | ||||
| 		source.IsSyncEnabled = false | ||||
| 	} | ||||
|  | ||||
| 	_, err = db.GetEngine(db.DefaultContext).Insert(source) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	if !source.IsActive { | ||||
| 		return nil | ||||
| 	} | ||||
|  | ||||
| 	if settable, ok := source.Cfg.(SourceSettable); ok { | ||||
| 		settable.SetAuthSource(source) | ||||
| 	} | ||||
|  | ||||
| 	registerableSource, ok := source.Cfg.(RegisterableSource) | ||||
| 	if !ok { | ||||
| 		return nil | ||||
| 	} | ||||
|  | ||||
| 	err = registerableSource.RegisterSource() | ||||
| 	if err != nil { | ||||
| 		// remove the AuthSource in case of errors while registering configuration | ||||
| 		if _, err := db.GetEngine(db.DefaultContext).Delete(source); err != nil { | ||||
| 			log.Error("CreateSource: Error while wrapOpenIDConnectInitializeError: %v", err) | ||||
| 		} | ||||
| 	} | ||||
| 	return err | ||||
| } | ||||
|  | ||||
| // Sources returns a slice of all login sources found in DB. | ||||
| func Sources() ([]*Source, error) { | ||||
| 	auths := make([]*Source, 0, 6) | ||||
| 	return auths, db.GetEngine(db.DefaultContext).Find(&auths) | ||||
| } | ||||
|  | ||||
| // SourcesByType returns all sources of the specified type | ||||
| func SourcesByType(loginType Type) ([]*Source, error) { | ||||
| 	sources := make([]*Source, 0, 1) | ||||
| 	if err := db.GetEngine(db.DefaultContext).Where("type = ?", loginType).Find(&sources); err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	return sources, nil | ||||
| } | ||||
|  | ||||
| // AllActiveSources returns all active sources | ||||
| func AllActiveSources() ([]*Source, error) { | ||||
| 	sources := make([]*Source, 0, 5) | ||||
| 	if err := db.GetEngine(db.DefaultContext).Where("is_active = ?", true).Find(&sources); err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	return sources, nil | ||||
| } | ||||
|  | ||||
| // ActiveSources returns all active sources of the specified type | ||||
| func ActiveSources(tp Type) ([]*Source, error) { | ||||
| 	sources := make([]*Source, 0, 1) | ||||
| 	if err := db.GetEngine(db.DefaultContext).Where("is_active = ? and type = ?", true, tp).Find(&sources); err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	return sources, nil | ||||
| } | ||||
|  | ||||
| // IsSSPIEnabled returns true if there is at least one activated login | ||||
| // source of type LoginSSPI | ||||
| func IsSSPIEnabled() bool { | ||||
| 	if !db.HasEngine { | ||||
| 		return false | ||||
| 	} | ||||
| 	sources, err := ActiveSources(SSPI) | ||||
| 	if err != nil { | ||||
| 		log.Error("ActiveSources: %v", err) | ||||
| 		return false | ||||
| 	} | ||||
| 	return len(sources) > 0 | ||||
| } | ||||
|  | ||||
| // GetSourceByID returns login source by given ID. | ||||
| func GetSourceByID(id int64) (*Source, error) { | ||||
| 	source := new(Source) | ||||
| 	if id == 0 { | ||||
| 		source.Cfg = registeredConfigs[NoType]() | ||||
| 		// Set this source to active | ||||
| 		// FIXME: allow disabling of db based password authentication in future | ||||
| 		source.IsActive = true | ||||
| 		return source, nil | ||||
| 	} | ||||
|  | ||||
| 	has, err := db.GetEngine(db.DefaultContext).ID(id).Get(source) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} else if !has { | ||||
| 		return nil, ErrSourceNotExist{id} | ||||
| 	} | ||||
| 	return source, nil | ||||
| } | ||||
|  | ||||
| // UpdateSource updates a Source record in DB. | ||||
| func UpdateSource(source *Source) error { | ||||
| 	var originalSource *Source | ||||
| 	if source.IsOAuth2() { | ||||
| 		// keep track of the original values so we can restore in case of errors while registering OAuth2 providers | ||||
| 		var err error | ||||
| 		if originalSource, err = GetSourceByID(source.ID); err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	_, err := db.GetEngine(db.DefaultContext).ID(source.ID).AllCols().Update(source) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	if !source.IsActive { | ||||
| 		return nil | ||||
| 	} | ||||
|  | ||||
| 	if settable, ok := source.Cfg.(SourceSettable); ok { | ||||
| 		settable.SetAuthSource(source) | ||||
| 	} | ||||
|  | ||||
| 	registerableSource, ok := source.Cfg.(RegisterableSource) | ||||
| 	if !ok { | ||||
| 		return nil | ||||
| 	} | ||||
|  | ||||
| 	err = registerableSource.RegisterSource() | ||||
| 	if err != nil { | ||||
| 		// restore original values since we cannot update the provider it self | ||||
| 		if _, err := db.GetEngine(db.DefaultContext).ID(source.ID).AllCols().Update(originalSource); err != nil { | ||||
| 			log.Error("UpdateSource: Error while wrapOpenIDConnectInitializeError: %v", err) | ||||
| 		} | ||||
| 	} | ||||
| 	return err | ||||
| } | ||||
|  | ||||
| // CountSources returns number of login sources. | ||||
| func CountSources() int64 { | ||||
| 	count, _ := db.GetEngine(db.DefaultContext).Count(new(Source)) | ||||
| 	return count | ||||
| } | ||||
|  | ||||
| // ErrSourceNotExist represents a "SourceNotExist" kind of error. | ||||
| type ErrSourceNotExist struct { | ||||
| 	ID int64 | ||||
| } | ||||
|  | ||||
| // IsErrSourceNotExist checks if an error is a ErrSourceNotExist. | ||||
| func IsErrSourceNotExist(err error) bool { | ||||
| 	_, ok := err.(ErrSourceNotExist) | ||||
| 	return ok | ||||
| } | ||||
|  | ||||
| func (err ErrSourceNotExist) Error() string { | ||||
| 	return fmt.Sprintf("login source does not exist [id: %d]", err.ID) | ||||
| } | ||||
|  | ||||
| // ErrSourceAlreadyExist represents a "SourceAlreadyExist" kind of error. | ||||
| type ErrSourceAlreadyExist struct { | ||||
| 	Name string | ||||
| } | ||||
|  | ||||
| // IsErrSourceAlreadyExist checks if an error is a ErrSourceAlreadyExist. | ||||
| func IsErrSourceAlreadyExist(err error) bool { | ||||
| 	_, ok := err.(ErrSourceAlreadyExist) | ||||
| 	return ok | ||||
| } | ||||
|  | ||||
| func (err ErrSourceAlreadyExist) Error() string { | ||||
| 	return fmt.Sprintf("login source already exists [name: %s]", err.Name) | ||||
| } | ||||
|  | ||||
| // ErrSourceInUse represents a "SourceInUse" kind of error. | ||||
| type ErrSourceInUse struct { | ||||
| 	ID int64 | ||||
| } | ||||
|  | ||||
| // IsErrSourceInUse checks if an error is a ErrSourceInUse. | ||||
| func IsErrSourceInUse(err error) bool { | ||||
| 	_, ok := err.(ErrSourceInUse) | ||||
| 	return ok | ||||
| } | ||||
|  | ||||
| func (err ErrSourceInUse) Error() string { | ||||
| 	return fmt.Sprintf("login source is still used by some users [id: %d]", err.ID) | ||||
| } | ||||
		Reference in New Issue
	
	Block a user
	 Lunny Xiao
					Lunny Xiao