mirror of
				https://github.com/go-gitea/gitea.git
				synced 2025-10-26 12:27:06 +00:00 
			
		
		
		
	upgrade version of lib/pq to v1.1.0 (#6640)
Adds SCRAM-SHA-256 authentication
This commit is contained in:
		
							
								
								
									
										259
									
								
								vendor/github.com/lib/pq/conn.go
									
									
									
										generated
									
									
										vendored
									
									
								
							
							
						
						
									
										259
									
								
								vendor/github.com/lib/pq/conn.go
									
									
									
										generated
									
									
										vendored
									
									
								
							| @@ -2,7 +2,9 @@ package pq | ||||
|  | ||||
| import ( | ||||
| 	"bufio" | ||||
| 	"context" | ||||
| 	"crypto/md5" | ||||
| 	"crypto/sha256" | ||||
| 	"database/sql" | ||||
| 	"database/sql/driver" | ||||
| 	"encoding/binary" | ||||
| @@ -20,6 +22,7 @@ import ( | ||||
| 	"unicode" | ||||
|  | ||||
| 	"github.com/lib/pq/oid" | ||||
| 	"github.com/lib/pq/scram" | ||||
| ) | ||||
|  | ||||
| // Common error types | ||||
| @@ -89,13 +92,24 @@ type Dialer interface { | ||||
| 	DialTimeout(network, address string, timeout time.Duration) (net.Conn, error) | ||||
| } | ||||
|  | ||||
| type defaultDialer struct{} | ||||
|  | ||||
| func (d defaultDialer) Dial(ntw, addr string) (net.Conn, error) { | ||||
| 	return net.Dial(ntw, addr) | ||||
| type DialerContext interface { | ||||
| 	DialContext(ctx context.Context, network, address string) (net.Conn, error) | ||||
| } | ||||
| func (d defaultDialer) DialTimeout(ntw, addr string, timeout time.Duration) (net.Conn, error) { | ||||
| 	return net.DialTimeout(ntw, addr, timeout) | ||||
|  | ||||
| type defaultDialer struct { | ||||
| 	d net.Dialer | ||||
| } | ||||
|  | ||||
| func (d defaultDialer) Dial(network, address string) (net.Conn, error) { | ||||
| 	return d.d.Dial(network, address) | ||||
| } | ||||
| func (d defaultDialer) DialTimeout(network, address string, timeout time.Duration) (net.Conn, error) { | ||||
| 	ctx, cancel := context.WithTimeout(context.Background(), timeout) | ||||
| 	defer cancel() | ||||
| 	return d.DialContext(ctx, network, address) | ||||
| } | ||||
| func (d defaultDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) { | ||||
| 	return d.d.DialContext(ctx, network, address) | ||||
| } | ||||
|  | ||||
| type conn struct { | ||||
| @@ -244,90 +258,35 @@ func (cn *conn) writeBuf(b byte) *writeBuf { | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // Open opens a new connection to the database. name is a connection string. | ||||
| // Open opens a new connection to the database. dsn is a connection string. | ||||
| // Most users should only use it through database/sql package from the standard | ||||
| // library. | ||||
| func Open(name string) (_ driver.Conn, err error) { | ||||
| 	return DialOpen(defaultDialer{}, name) | ||||
| func Open(dsn string) (_ driver.Conn, err error) { | ||||
| 	return DialOpen(defaultDialer{}, dsn) | ||||
| } | ||||
|  | ||||
| // DialOpen opens a new connection to the database using a dialer. | ||||
| func DialOpen(d Dialer, name string) (_ driver.Conn, err error) { | ||||
| func DialOpen(d Dialer, dsn string) (_ driver.Conn, err error) { | ||||
| 	c, err := NewConnector(dsn) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	c.dialer = d | ||||
| 	return c.open(context.Background()) | ||||
| } | ||||
|  | ||||
| func (c *Connector) open(ctx context.Context) (cn *conn, err error) { | ||||
| 	// Handle any panics during connection initialization.  Note that we | ||||
| 	// specifically do *not* want to use errRecover(), as that would turn any | ||||
| 	// connection errors into ErrBadConns, hiding the real error message from | ||||
| 	// the user. | ||||
| 	defer errRecoverNoErrBadConn(&err) | ||||
|  | ||||
| 	o := make(values) | ||||
| 	o := c.opts | ||||
|  | ||||
| 	// A number of defaults are applied here, in this order: | ||||
| 	// | ||||
| 	// * Very low precedence defaults applied in every situation | ||||
| 	// * Environment variables | ||||
| 	// * Explicitly passed connection information | ||||
| 	o["host"] = "localhost" | ||||
| 	o["port"] = "5432" | ||||
| 	// N.B.: Extra float digits should be set to 3, but that breaks | ||||
| 	// Postgres 8.4 and older, where the max is 2. | ||||
| 	o["extra_float_digits"] = "2" | ||||
| 	for k, v := range parseEnviron(os.Environ()) { | ||||
| 		o[k] = v | ||||
| 	} | ||||
|  | ||||
| 	if strings.HasPrefix(name, "postgres://") || strings.HasPrefix(name, "postgresql://") { | ||||
| 		name, err = ParseURL(name) | ||||
| 		if err != nil { | ||||
| 			return nil, err | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	if err := parseOpts(name, o); err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	// Use the "fallback" application name if necessary | ||||
| 	if fallback, ok := o["fallback_application_name"]; ok { | ||||
| 		if _, ok := o["application_name"]; !ok { | ||||
| 			o["application_name"] = fallback | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	// We can't work with any client_encoding other than UTF-8 currently. | ||||
| 	// However, we have historically allowed the user to set it to UTF-8 | ||||
| 	// explicitly, and there's no reason to break such programs, so allow that. | ||||
| 	// Note that the "options" setting could also set client_encoding, but | ||||
| 	// parsing its value is not worth it.  Instead, we always explicitly send | ||||
| 	// client_encoding as a separate run-time parameter, which should override | ||||
| 	// anything set in options. | ||||
| 	if enc, ok := o["client_encoding"]; ok && !isUTF8(enc) { | ||||
| 		return nil, errors.New("client_encoding must be absent or 'UTF8'") | ||||
| 	} | ||||
| 	o["client_encoding"] = "UTF8" | ||||
| 	// DateStyle needs a similar treatment. | ||||
| 	if datestyle, ok := o["datestyle"]; ok { | ||||
| 		if datestyle != "ISO, MDY" { | ||||
| 			panic(fmt.Sprintf("setting datestyle must be absent or %v; got %v", | ||||
| 				"ISO, MDY", datestyle)) | ||||
| 		} | ||||
| 	} else { | ||||
| 		o["datestyle"] = "ISO, MDY" | ||||
| 	} | ||||
|  | ||||
| 	// If a user is not provided by any other means, the last | ||||
| 	// resort is to use the current operating system provided user | ||||
| 	// name. | ||||
| 	if _, ok := o["user"]; !ok { | ||||
| 		u, err := userCurrent() | ||||
| 		if err != nil { | ||||
| 			return nil, err | ||||
| 		} | ||||
| 		o["user"] = u | ||||
| 	} | ||||
|  | ||||
| 	cn := &conn{ | ||||
| 	cn = &conn{ | ||||
| 		opts:   o, | ||||
| 		dialer: d, | ||||
| 		dialer: c.dialer, | ||||
| 	} | ||||
| 	err = cn.handleDriverSettings(o) | ||||
| 	if err != nil { | ||||
| @@ -335,7 +294,7 @@ func DialOpen(d Dialer, name string) (_ driver.Conn, err error) { | ||||
| 	} | ||||
| 	cn.handlePgpass(o) | ||||
|  | ||||
| 	cn.c, err = dial(d, o) | ||||
| 	cn.c, err = dial(ctx, c.dialer, o) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| @@ -364,10 +323,10 @@ func DialOpen(d Dialer, name string) (_ driver.Conn, err error) { | ||||
| 	return cn, err | ||||
| } | ||||
|  | ||||
| func dial(d Dialer, o values) (net.Conn, error) { | ||||
| 	ntw, addr := network(o) | ||||
| func dial(ctx context.Context, d Dialer, o values) (net.Conn, error) { | ||||
| 	network, address := network(o) | ||||
| 	// SSL is not necessary or supported over UNIX domain sockets | ||||
| 	if ntw == "unix" { | ||||
| 	if network == "unix" { | ||||
| 		o["sslmode"] = "disable" | ||||
| 	} | ||||
|  | ||||
| @@ -378,19 +337,30 @@ func dial(d Dialer, o values) (net.Conn, error) { | ||||
| 			return nil, fmt.Errorf("invalid value for parameter connect_timeout: %s", err) | ||||
| 		} | ||||
| 		duration := time.Duration(seconds) * time.Second | ||||
|  | ||||
| 		// connect_timeout should apply to the entire connection establishment | ||||
| 		// procedure, so we both use a timeout for the TCP connection | ||||
| 		// establishment and set a deadline for doing the initial handshake. | ||||
| 		// The deadline is then reset after startup() is done. | ||||
| 		deadline := time.Now().Add(duration) | ||||
| 		conn, err := d.DialTimeout(ntw, addr, duration) | ||||
| 		var conn net.Conn | ||||
| 		if dctx, ok := d.(DialerContext); ok { | ||||
| 			ctx, cancel := context.WithTimeout(ctx, duration) | ||||
| 			defer cancel() | ||||
| 			conn, err = dctx.DialContext(ctx, network, address) | ||||
| 		} else { | ||||
| 			conn, err = d.DialTimeout(network, address, duration) | ||||
| 		} | ||||
| 		if err != nil { | ||||
| 			return nil, err | ||||
| 		} | ||||
| 		err = conn.SetDeadline(deadline) | ||||
| 		return conn, err | ||||
| 	} | ||||
| 	return d.Dial(ntw, addr) | ||||
| 	if dctx, ok := d.(DialerContext); ok { | ||||
| 		return dctx.DialContext(ctx, network, address) | ||||
| 	} | ||||
| 	return d.Dial(network, address) | ||||
| } | ||||
|  | ||||
| func network(o values) (string, string) { | ||||
| @@ -704,7 +674,7 @@ func (cn *conn) simpleQuery(q string) (res *rows, err error) { | ||||
| 			// res might be non-nil here if we received a previous | ||||
| 			// CommandComplete, but that's fine; just overwrite it | ||||
| 			res = &rows{cn: cn} | ||||
| 			res.colNames, res.colFmts, res.colTyps = parsePortalRowDescribe(r) | ||||
| 			res.rowsHeader = parsePortalRowDescribe(r) | ||||
|  | ||||
| 			// To work around a bug in QueryRow in Go 1.2 and earlier, wait | ||||
| 			// until the first DataRow has been received. | ||||
| @@ -861,17 +831,15 @@ func (cn *conn) query(query string, args []driver.Value) (_ *rows, err error) { | ||||
| 		cn.readParseResponse() | ||||
| 		cn.readBindResponse() | ||||
| 		rows := &rows{cn: cn} | ||||
| 		rows.colNames, rows.colFmts, rows.colTyps = cn.readPortalDescribeResponse() | ||||
| 		rows.rowsHeader = cn.readPortalDescribeResponse() | ||||
| 		cn.postExecuteWorkaround() | ||||
| 		return rows, nil | ||||
| 	} | ||||
| 	st := cn.prepareTo(query, "") | ||||
| 	st.exec(args) | ||||
| 	return &rows{ | ||||
| 		cn:       cn, | ||||
| 		colNames: st.colNames, | ||||
| 		colTyps:  st.colTyps, | ||||
| 		colFmts:  st.colFmts, | ||||
| 		cn:         cn, | ||||
| 		rowsHeader: st.rowsHeader, | ||||
| 	}, nil | ||||
| } | ||||
|  | ||||
| @@ -992,7 +960,6 @@ func (cn *conn) recv() (t byte, r *readBuf) { | ||||
| 		if err != nil { | ||||
| 			panic(err) | ||||
| 		} | ||||
|  | ||||
| 		switch t { | ||||
| 		case 'E': | ||||
| 			panic(parseError(r)) | ||||
| @@ -1163,6 +1130,55 @@ func (cn *conn) auth(r *readBuf, o values) { | ||||
| 		if r.int32() != 0 { | ||||
| 			errorf("unexpected authentication response: %q", t) | ||||
| 		} | ||||
| 	case 10: | ||||
| 		sc := scram.NewClient(sha256.New, o["user"], o["password"]) | ||||
| 		sc.Step(nil) | ||||
| 		if sc.Err() != nil { | ||||
| 			errorf("SCRAM-SHA-256 error: %s", sc.Err().Error()) | ||||
| 		} | ||||
| 		scOut := sc.Out() | ||||
|  | ||||
| 		w := cn.writeBuf('p') | ||||
| 		w.string("SCRAM-SHA-256") | ||||
| 		w.int32(len(scOut)) | ||||
| 		w.bytes(scOut) | ||||
| 		cn.send(w) | ||||
|  | ||||
| 		t, r := cn.recv() | ||||
| 		if t != 'R' { | ||||
| 			errorf("unexpected password response: %q", t) | ||||
| 		} | ||||
|  | ||||
| 		if r.int32() != 11 { | ||||
| 			errorf("unexpected authentication response: %q", t) | ||||
| 		} | ||||
|  | ||||
| 		nextStep := r.next(len(*r)) | ||||
| 		sc.Step(nextStep) | ||||
| 		if sc.Err() != nil { | ||||
| 			errorf("SCRAM-SHA-256 error: %s", sc.Err().Error()) | ||||
| 		} | ||||
|  | ||||
| 		scOut = sc.Out() | ||||
| 		w = cn.writeBuf('p') | ||||
| 		w.bytes(scOut) | ||||
| 		cn.send(w) | ||||
|  | ||||
| 		t, r = cn.recv() | ||||
| 		if t != 'R' { | ||||
| 			errorf("unexpected password response: %q", t) | ||||
| 		} | ||||
|  | ||||
| 		if r.int32() != 12 { | ||||
| 			errorf("unexpected authentication response: %q", t) | ||||
| 		} | ||||
|  | ||||
| 		nextStep = r.next(len(*r)) | ||||
| 		sc.Step(nextStep) | ||||
| 		if sc.Err() != nil { | ||||
| 			errorf("SCRAM-SHA-256 error: %s", sc.Err().Error()) | ||||
| 		} | ||||
|  | ||||
| 	default: | ||||
| 		errorf("unknown authentication response: %d", code) | ||||
| 	} | ||||
| @@ -1180,12 +1196,10 @@ var colFmtDataAllBinary = []byte{0, 1, 0, 1} | ||||
| var colFmtDataAllText = []byte{0, 0} | ||||
|  | ||||
| type stmt struct { | ||||
| 	cn         *conn | ||||
| 	name       string | ||||
| 	colNames   []string | ||||
| 	colFmts    []format | ||||
| 	cn   *conn | ||||
| 	name string | ||||
| 	rowsHeader | ||||
| 	colFmtData []byte | ||||
| 	colTyps    []fieldDesc | ||||
| 	paramTyps  []oid.Oid | ||||
| 	closed     bool | ||||
| } | ||||
| @@ -1231,10 +1245,8 @@ func (st *stmt) Query(v []driver.Value) (r driver.Rows, err error) { | ||||
|  | ||||
| 	st.exec(v) | ||||
| 	return &rows{ | ||||
| 		cn:       st.cn, | ||||
| 		colNames: st.colNames, | ||||
| 		colTyps:  st.colTyps, | ||||
| 		colFmts:  st.colFmts, | ||||
| 		cn:         st.cn, | ||||
| 		rowsHeader: st.rowsHeader, | ||||
| 	}, nil | ||||
| } | ||||
|  | ||||
| @@ -1344,16 +1356,22 @@ func (cn *conn) parseComplete(commandTag string) (driver.Result, string) { | ||||
| 	return driver.RowsAffected(n), commandTag | ||||
| } | ||||
|  | ||||
| type rows struct { | ||||
| 	cn       *conn | ||||
| 	finish   func() | ||||
| type rowsHeader struct { | ||||
| 	colNames []string | ||||
| 	colTyps  []fieldDesc | ||||
| 	colFmts  []format | ||||
| 	done     bool | ||||
| 	rb       readBuf | ||||
| 	result   driver.Result | ||||
| 	tag      string | ||||
| } | ||||
|  | ||||
| type rows struct { | ||||
| 	cn     *conn | ||||
| 	finish func() | ||||
| 	rowsHeader | ||||
| 	done   bool | ||||
| 	rb     readBuf | ||||
| 	result driver.Result | ||||
| 	tag    string | ||||
|  | ||||
| 	next *rowsHeader | ||||
| } | ||||
|  | ||||
| func (rs *rows) Close() error { | ||||
| @@ -1440,7 +1458,8 @@ func (rs *rows) Next(dest []driver.Value) (err error) { | ||||
| 			} | ||||
| 			return | ||||
| 		case 'T': | ||||
| 			rs.colNames, rs.colFmts, rs.colTyps = parsePortalRowDescribe(&rs.rb) | ||||
| 			next := parsePortalRowDescribe(&rs.rb) | ||||
| 			rs.next = &next | ||||
| 			return io.EOF | ||||
| 		default: | ||||
| 			errorf("unexpected message after execute: %q", t) | ||||
| @@ -1449,10 +1468,16 @@ func (rs *rows) Next(dest []driver.Value) (err error) { | ||||
| } | ||||
|  | ||||
| func (rs *rows) HasNextResultSet() bool { | ||||
| 	return !rs.done | ||||
| 	hasNext := rs.next != nil && !rs.done | ||||
| 	return hasNext | ||||
| } | ||||
|  | ||||
| func (rs *rows) NextResultSet() error { | ||||
| 	if rs.next == nil { | ||||
| 		return io.EOF | ||||
| 	} | ||||
| 	rs.rowsHeader = *rs.next | ||||
| 	rs.next = nil | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| @@ -1630,13 +1655,13 @@ func (cn *conn) readStatementDescribeResponse() (paramTyps []oid.Oid, colNames [ | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (cn *conn) readPortalDescribeResponse() (colNames []string, colFmts []format, colTyps []fieldDesc) { | ||||
| func (cn *conn) readPortalDescribeResponse() rowsHeader { | ||||
| 	t, r := cn.recv1() | ||||
| 	switch t { | ||||
| 	case 'T': | ||||
| 		return parsePortalRowDescribe(r) | ||||
| 	case 'n': | ||||
| 		return nil, nil, nil | ||||
| 		return rowsHeader{} | ||||
| 	case 'E': | ||||
| 		err := parseError(r) | ||||
| 		cn.readReadyForQuery() | ||||
| @@ -1742,11 +1767,11 @@ func parseStatementRowDescribe(r *readBuf) (colNames []string, colTyps []fieldDe | ||||
| 	return | ||||
| } | ||||
|  | ||||
| func parsePortalRowDescribe(r *readBuf) (colNames []string, colFmts []format, colTyps []fieldDesc) { | ||||
| func parsePortalRowDescribe(r *readBuf) rowsHeader { | ||||
| 	n := r.int16() | ||||
| 	colNames = make([]string, n) | ||||
| 	colFmts = make([]format, n) | ||||
| 	colTyps = make([]fieldDesc, n) | ||||
| 	colNames := make([]string, n) | ||||
| 	colFmts := make([]format, n) | ||||
| 	colTyps := make([]fieldDesc, n) | ||||
| 	for i := range colNames { | ||||
| 		colNames[i] = r.string() | ||||
| 		r.next(6) | ||||
| @@ -1755,7 +1780,11 @@ func parsePortalRowDescribe(r *readBuf) (colNames []string, colFmts []format, co | ||||
| 		colTyps[i].Mod = r.int32() | ||||
| 		colFmts[i] = format(r.int16()) | ||||
| 	} | ||||
| 	return | ||||
| 	return rowsHeader{ | ||||
| 		colNames: colNames, | ||||
| 		colFmts:  colFmts, | ||||
| 		colTyps:  colTyps, | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // parseEnviron tries to mimic some of libpq's environment handling | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 techknowlogick
					techknowlogick