Update godeps

This commit is contained in:
Manuel de Brito Fontes 2016-09-21 20:00:42 -03:00
parent a965f44f84
commit 73e22a50d2
453 changed files with 84778 additions and 70308 deletions

1153
Godeps/Godeps.json generated

File diff suppressed because it is too large Load diff

View file

@ -27,6 +27,7 @@ import (
"net/http" "net/http"
"net/url" "net/url"
"os" "os"
"runtime"
"strings" "strings"
"sync" "sync"
"time" "time"
@ -34,11 +35,20 @@ import (
"golang.org/x/net/context" "golang.org/x/net/context"
"golang.org/x/net/context/ctxhttp" "golang.org/x/net/context/ctxhttp"
"google.golang.org/cloud/internal" "cloud.google.com/go/internal"
) )
// metadataIP is the documented metadata server IP address. const (
const metadataIP = "169.254.169.254" // metadataIP is the documented metadata server IP address.
metadataIP = "169.254.169.254"
// metadataHostEnv is the environment variable specifying the
// GCE metadata hostname. If empty, the default value of
// metadataIP ("169.254.169.254") is used instead.
// This is variable name is not defined by any spec, as far as
// I know; it was made up for the Go package.
metadataHostEnv = "GCE_METADATA_HOST"
)
type cachedValue struct { type cachedValue struct {
k string k string
@ -110,7 +120,7 @@ func getETag(client *http.Client, suffix string) (value, etag string, err error)
// deployments. To enable spoofing of the metadata service, the environment // deployments. To enable spoofing of the metadata service, the environment
// variable GCE_METADATA_HOST is first inspected to decide where metadata // variable GCE_METADATA_HOST is first inspected to decide where metadata
// requests shall go. // requests shall go.
host := os.Getenv("GCE_METADATA_HOST") host := os.Getenv(metadataHostEnv)
if host == "" { if host == "" {
// Using 169.254.169.254 instead of "metadata" here because Go // Using 169.254.169.254 instead of "metadata" here because Go
// binaries built with the "netgo" tag and without cgo won't // binaries built with the "netgo" tag and without cgo won't
@ -163,32 +173,34 @@ func (c *cachedValue) get() (v string, err error) {
return return
} }
var onGCE struct { var (
sync.Mutex onGCEOnce sync.Once
set bool onGCE bool
v bool )
}
// OnGCE reports whether this process is running on Google Compute Engine. // OnGCE reports whether this process is running on Google Compute Engine.
func OnGCE() bool { func OnGCE() bool {
defer onGCE.Unlock() onGCEOnce.Do(initOnGCE)
onGCE.Lock() return onGCE
if onGCE.set { }
return onGCE.v
} func initOnGCE() {
onGCE.set = true onGCE = testOnGCE()
onGCE.v = testOnGCE()
return onGCE.v
} }
func testOnGCE() bool { func testOnGCE() bool {
// The user explicitly said they're on GCE, so trust them.
if os.Getenv(metadataHostEnv) != "" {
return true
}
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
defer cancel() defer cancel()
resc := make(chan bool, 2) resc := make(chan bool, 2)
// Try two strategies in parallel. // Try two strategies in parallel.
// See https://github.com/GoogleCloudPlatform/gcloud-golang/issues/194 // See https://github.com/GoogleCloudPlatform/google-cloud-go/issues/194
go func() { go func() {
res, err := ctxhttp.Get(ctx, metaClient, "http://"+metadataIP) res, err := ctxhttp.Get(ctx, metaClient, "http://"+metadataIP)
if err != nil { if err != nil {
@ -208,9 +220,53 @@ func testOnGCE() bool {
resc <- strsContains(addrs, metadataIP) resc <- strsContains(addrs, metadataIP)
}() }()
tryHarder := systemInfoSuggestsGCE()
if tryHarder {
res := <-resc
if res {
// The first strategy succeeded, so let's use it.
return true
}
// Wait for either the DNS or metadata server probe to
// contradict the other one and say we are running on
// GCE. Give it a lot of time to do so, since the system
// info already suggests we're running on a GCE BIOS.
timer := time.NewTimer(5 * time.Second)
defer timer.Stop()
select {
case res = <-resc:
return res
case <-timer.C:
// Too slow. Who knows what this system is.
return false
}
}
// There's no hint from the system info that we're running on
// GCE, so use the first probe's result as truth, whether it's
// true or false. The goal here is to optimize for speed for
// users who are NOT running on GCE. We can't assume that
// either a DNS lookup or an HTTP request to a blackholed IP
// address is fast. Worst case this should return when the
// metaClient's Transport.ResponseHeaderTimeout or
// Transport.Dial.Timeout fires (in two seconds).
return <-resc return <-resc
} }
// systemInfoSuggestsGCE reports whether the local system (without
// doing network requests) suggests that we're running on GCE. If this
// returns true, testOnGCE tries a bit harder to reach its metadata
// server.
func systemInfoSuggestsGCE() bool {
if runtime.GOOS != "linux" {
// We don't have any non-Linux clues available, at least yet.
return false
}
slurp, _ := ioutil.ReadFile("/sys/class/dmi/id/product_name")
name := strings.TrimSpace(string(slurp))
return name == "Google" || name == "Google Compute Engine"
}
// Subscribe subscribes to a value from the metadata service. // Subscribe subscribes to a value from the metadata service.
// The suffix is appended to "http://${GCE_METADATA_HOST}/computeMetadata/v1/". // The suffix is appended to "http://${GCE_METADATA_HOST}/computeMetadata/v1/".
// The suffix may contain query parameters. // The suffix may contain query parameters.

64
vendor/cloud.google.com/go/internal/cloud.go generated vendored Normal file
View file

@ -0,0 +1,64 @@
// Copyright 2014 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package internal provides support for the cloud packages.
//
// Users should not import this package directly.
package internal
import (
"fmt"
"net/http"
)
const userAgent = "gcloud-golang/0.1"
// Transport is an http.RoundTripper that appends Google Cloud client's
// user-agent to the original request's user-agent header.
type Transport struct {
// TODO(bradfitz): delete internal.Transport. It's too wrappy for what it does.
// Do User-Agent some other way.
// Base is the actual http.RoundTripper
// requests will use. It must not be nil.
Base http.RoundTripper
}
// RoundTrip appends a user-agent to the existing user-agent
// header and delegates the request to the base http.RoundTripper.
func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) {
req = cloneRequest(req)
ua := req.Header.Get("User-Agent")
if ua == "" {
ua = userAgent
} else {
ua = fmt.Sprintf("%s %s", ua, userAgent)
}
req.Header.Set("User-Agent", ua)
return t.Base.RoundTrip(req)
}
// cloneRequest returns a clone of the provided *http.Request.
// The clone is a shallow copy of the struct and its Header map.
func cloneRequest(r *http.Request) *http.Request {
// shallow copy of the struct
r2 := new(http.Request)
*r2 = *r
// deep copy of the Header
r2.Header = make(http.Header)
for k, s := range r.Header {
r2.Header[k] = s
}
return r2
}

View file

@ -45,6 +45,8 @@ type simpleBalancer struct {
// pinAddr is the currently pinned address; set to the empty string on // pinAddr is the currently pinned address; set to the empty string on
// intialization and shutdown. // intialization and shutdown.
pinAddr string pinAddr string
closed bool
} }
func newSimpleBalancer(eps []string) *simpleBalancer { func newSimpleBalancer(eps []string) *simpleBalancer {
@ -74,15 +76,25 @@ func (b *simpleBalancer) ConnectNotify() <-chan struct{} {
func (b *simpleBalancer) Up(addr grpc.Address) func(error) { func (b *simpleBalancer) Up(addr grpc.Address) func(error) {
b.mu.Lock() b.mu.Lock()
defer b.mu.Unlock()
// gRPC might call Up after it called Close. We add this check
// to "fix" it up at application layer. Or our simplerBalancer
// might panic since b.upc is closed.
if b.closed {
return func(err error) {}
}
if len(b.upEps) == 0 { if len(b.upEps) == 0 {
// notify waiting Get()s and pin first connected address // notify waiting Get()s and pin first connected address
close(b.upc) close(b.upc)
b.pinAddr = addr.Addr b.pinAddr = addr.Addr
} }
b.upEps[addr.Addr] = struct{}{} b.upEps[addr.Addr] = struct{}{}
b.mu.Unlock()
// notify client that a connection is up // notify client that a connection is up
b.readyOnce.Do(func() { close(b.readyc) }) b.readyOnce.Do(func() { close(b.readyc) })
return func(err error) { return func(err error) {
b.mu.Lock() b.mu.Lock()
delete(b.upEps, addr.Addr) delete(b.upEps, addr.Addr)
@ -128,13 +140,19 @@ func (b *simpleBalancer) Notify() <-chan []grpc.Address { return b.notifyCh }
func (b *simpleBalancer) Close() error { func (b *simpleBalancer) Close() error {
b.mu.Lock() b.mu.Lock()
defer b.mu.Unlock()
// In case gRPC calls close twice. TODO: remove the checking
// when we are sure that gRPC wont call close twice.
if b.closed {
return nil
}
b.closed = true
close(b.notifyCh) close(b.notifyCh)
// terminate all waiting Get()s // terminate all waiting Get()s
b.pinAddr = "" b.pinAddr = ""
if len(b.upEps) == 0 { if len(b.upEps) == 0 {
close(b.upc) close(b.upc)
} }
b.mu.Unlock()
return nil return nil
} }

View file

@ -669,6 +669,10 @@ func (w *watchGrpcStream) resumeWatchers(wc pb.Watch_WatchClient) error {
w.mu.RUnlock() w.mu.RUnlock()
for _, ws := range streams { for _, ws := range streams {
// drain recvc so no old WatchResponses (e.g., Created messages)
// are processed while resuming
ws.drain()
// pause serveStream // pause serveStream
ws.resumec <- -1 ws.resumec <- -1
@ -701,6 +705,17 @@ func (w *watchGrpcStream) resumeWatchers(wc pb.Watch_WatchClient) error {
return nil return nil
} }
// drain removes all buffered WatchResponses from the stream's receive channel.
func (ws *watcherStream) drain() {
for {
select {
case <-ws.recvc:
default:
return
}
}
}
// toPB converts an internal watch request structure to its protobuf messagefunc (wr *watchRequest) // toPB converts an internal watch request structure to its protobuf messagefunc (wr *watchRequest)
func (wr *watchRequest) toPB() *pb.WatchRequest { func (wr *watchRequest) toPB() *pb.WatchRequest {
req := &pb.WatchCreateRequest{ req := &pb.WatchCreateRequest{

22
vendor/github.com/coreos/etcd/pkg/fileutil/dir_unix.go generated vendored Normal file
View file

@ -0,0 +1,22 @@
// Copyright 2016 The etcd Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// +build !windows
package fileutil
import "os"
// OpenDir opens a directory for syncing.
func OpenDir(path string) (*os.File, error) { return os.Open(path) }

View file

@ -0,0 +1,46 @@
// Copyright 2016 The etcd Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// +build windows
package fileutil
import (
"os"
"syscall"
)
// OpenDir opens a directory in windows with write access for syncing.
func OpenDir(path string) (*os.File, error) {
fd, err := openDir(path)
if err != nil {
return nil, err
}
return os.NewFile(uintptr(fd), path), nil
}
func openDir(path string) (fd syscall.Handle, err error) {
if len(path) == 0 {
return syscall.InvalidHandle, syscall.ERROR_FILE_NOT_FOUND
}
pathp, err := syscall.UTF16PtrFromString(path)
if err != nil {
return syscall.InvalidHandle, err
}
access := uint32(syscall.GENERIC_READ | syscall.GENERIC_WRITE)
sharemode := uint32(syscall.FILE_SHARE_READ | syscall.FILE_SHARE_WRITE)
createmode := uint32(syscall.OPEN_EXISTING)
fl := uint32(syscall.FILE_FLAG_BACKUP_SEMANTICS)
return syscall.CreateFile(pathp, access, sharemode, nil, createmode, fl, 0)
}

View file

@ -96,3 +96,26 @@ func Exist(name string) bool {
_, err := os.Stat(name) _, err := os.Stat(name)
return err == nil return err == nil
} }
// ZeroToEnd zeros a file starting from SEEK_CUR to its SEEK_END. May temporarily
// shorten the length of the file.
func ZeroToEnd(f *os.File) error {
// TODO: support FALLOC_FL_ZERO_RANGE
off, err := f.Seek(0, os.SEEK_CUR)
if err != nil {
return err
}
lenf, lerr := f.Seek(0, os.SEEK_END)
if lerr != nil {
return lerr
}
if err = f.Truncate(off); err != nil {
return err
}
// make sure blocks remain allocated
if err = Preallocate(f, lenf, true); err != nil {
return err
}
_, err = f.Seek(off, os.SEEK_SET)
return err
}

View file

@ -4,18 +4,13 @@ import (
"encoding/base64" "encoding/base64"
"encoding/json" "encoding/json"
"errors" "errors"
"log"
"net/http" "net/http"
"net/url" "net/url"
"path" "path"
"strconv" "strconv"
"strings" "strings"
"time" "time"
"github.com/coreos/pkg/capnslog"
)
var (
log = capnslog.NewPackageLogger("github.com/coreos/go-oidc", "http")
) )
func WriteError(w http.ResponseWriter, code int, msg string) { func WriteError(w http.ResponseWriter, code int, msg string) {
@ -26,7 +21,9 @@ func WriteError(w http.ResponseWriter, code int, msg string) {
} }
b, err := json.Marshal(e) b, err := json.Marshal(e)
if err != nil { if err != nil {
log.Errorf("Failed marshaling %#v to JSON: %v", e, err) log.Printf("go-oidc: failed to marshal %#v: %v", e, err)
code = http.StatusInternalServerError
b = []byte(`{"error":"server_error"}`)
} }
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
w.WriteHeader(code) w.WriteHeader(code)

View file

@ -1,14 +0,0 @@
package http
import (
"net/http"
)
type LoggingMiddleware struct {
Next http.Handler
}
func (l *LoggingMiddleware) ServeHTTP(w http.ResponseWriter, r *http.Request) {
log.Infof("HTTP %s %v", r.Method, r.URL)
l.Next.ServeHTTP(w, r)
}

View file

@ -3,9 +3,9 @@ package key
import ( import (
"crypto/rand" "crypto/rand"
"crypto/rsa" "crypto/rsa"
"encoding/base64" "encoding/hex"
"encoding/json" "encoding/json"
"math/big" "io"
"time" "time"
"github.com/coreos/go-oidc/jose" "github.com/coreos/go-oidc/jose"
@ -139,15 +139,15 @@ func GeneratePrivateKey() (*PrivateKey, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
keyID := make([]byte, 20)
if _, err := io.ReadFull(rand.Reader, keyID); err != nil {
return nil, err
}
k := PrivateKey{ k := PrivateKey{
KeyID: base64BigInt(pk.PublicKey.N), KeyID: hex.EncodeToString(keyID),
PrivateKey: pk, PrivateKey: pk,
} }
return &k, nil return &k, nil
} }
func base64BigInt(b *big.Int) string {
return base64.URLEncoding.EncodeToString(b.Bytes())
}

View file

@ -2,16 +2,14 @@ package key
import ( import (
"errors" "errors"
"log"
"time" "time"
"github.com/coreos/pkg/capnslog"
ptime "github.com/coreos/pkg/timeutil" ptime "github.com/coreos/pkg/timeutil"
"github.com/jonboulle/clockwork" "github.com/jonboulle/clockwork"
) )
var ( var (
log = capnslog.NewPackageLogger("github.com/coreos/go-oidc", "key")
ErrorPrivateKeysExpired = errors.New("private keys have expired") ErrorPrivateKeysExpired = errors.New("private keys have expired")
) )
@ -67,7 +65,6 @@ func (r *PrivateKeyRotator) privateKeySet() (*PrivateKeySet, error) {
func (r *PrivateKeyRotator) nextRotation() (time.Duration, error) { func (r *PrivateKeyRotator) nextRotation() (time.Duration, error) {
pks, err := r.privateKeySet() pks, err := r.privateKeySet()
if err == ErrorNoKeys { if err == ErrorNoKeys {
log.Infof("No keys in private key set; must rotate immediately")
return 0, nil return 0, nil
} }
if err != nil { if err != nil {
@ -94,17 +91,15 @@ func (r *PrivateKeyRotator) Run() chan struct{} {
attempt := func() { attempt := func() {
k, err := r.generateKey() k, err := r.generateKey()
if err != nil { if err != nil {
log.Errorf("Failed generating signing key: %v", err) log.Printf("go-oidc: failed generating signing key: %v", err)
return return
} }
exp := r.expiresAt() exp := r.expiresAt()
if err := rotatePrivateKeys(r.repo, k, r.keep, exp); err != nil { if err := rotatePrivateKeys(r.repo, k, r.keep, exp); err != nil {
log.Errorf("Failed key rotation: %v", err) log.Printf("go-oidc: key rotation failed: %v", err)
return return
} }
log.Infof("Rotated signing keys: id=%s expiresAt=%s", k.ID(), exp)
} }
stop := make(chan struct{}) stop := make(chan struct{})
@ -118,11 +113,10 @@ func (r *PrivateKeyRotator) Run() chan struct{} {
break break
} }
sleep = ptime.ExpBackoff(sleep, time.Minute) sleep = ptime.ExpBackoff(sleep, time.Minute)
log.Errorf("error getting nextRotation, retrying in %v: %v", sleep, err) log.Printf("go-oidc: error getting nextRotation, retrying in %v: %v", sleep, err)
time.Sleep(sleep) time.Sleep(sleep)
} }
log.Infof("will rotate keys in %v", nextRotation)
select { select {
case <-r.clock.After(nextRotation): case <-r.clock.After(nextRotation):
attempt() attempt()

View file

@ -2,6 +2,7 @@ package key
import ( import (
"errors" "errors"
"log"
"time" "time"
"github.com/jonboulle/clockwork" "github.com/jonboulle/clockwork"
@ -38,15 +39,14 @@ func (s *KeySetSyncer) Run() chan struct{} {
next = timeutil.ExpBackoff(next, time.Minute) next = timeutil.ExpBackoff(next, time.Minute)
} }
if exp == 0 { if exp == 0 {
log.Errorf("Synced to already expired key set, retrying in %v: %v", next, err) log.Printf("Synced to already expired key set, retrying in %v: %v", next, err)
} else { } else {
log.Errorf("Failed syncing key set, retrying in %v: %v", next, err) log.Printf("Failed syncing key set, retrying in %v: %v", next, err)
} }
} else { } else {
failing = false failing = false
next = exp / 2 next = exp / 2
log.Infof("Synced key set, checking again in %v", next)
} }
select { select {

View file

@ -332,16 +332,16 @@ func parseTokenResponse(resp *http.Response) (result TokenResponse, err error) {
result.Scope = vals.Get("scope") result.Scope = vals.Get("scope")
} else { } else {
var r struct { var r struct {
AccessToken string `json:"access_token"` AccessToken string `json:"access_token"`
TokenType string `json:"token_type"` TokenType string `json:"token_type"`
IDToken string `json:"id_token"` IDToken string `json:"id_token"`
RefreshToken string `json:"refresh_token"` RefreshToken string `json:"refresh_token"`
Scope string `json:"scope"` Scope string `json:"scope"`
State string `json:"state"` State string `json:"state"`
ExpiresIn int `json:"expires_in"` ExpiresIn json.Number `json:"expires_in"` // Azure AD returns string
Expires int `json:"expires"` Expires int `json:"expires"`
Error string `json:"error"` Error string `json:"error"`
Desc string `json:"error_description"` Desc string `json:"error_description"`
} }
if err = json.Unmarshal(body, &r); err != nil { if err = json.Unmarshal(body, &r); err != nil {
return return
@ -355,10 +355,10 @@ func parseTokenResponse(resp *http.Response) (result TokenResponse, err error) {
result.IDToken = r.IDToken result.IDToken = r.IDToken
result.RefreshToken = r.RefreshToken result.RefreshToken = r.RefreshToken
result.Scope = r.Scope result.Scope = r.Scope
if r.ExpiresIn == 0 { if expiresIn, err := r.ExpiresIn.Int64(); err != nil {
result.Expires = r.Expires result.Expires = r.Expires
} else { } else {
result.Expires = r.ExpiresIn result.Expires = int(expiresIn)
} }
} }
return return

View file

@ -4,13 +4,13 @@ import (
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"log"
"net/http" "net/http"
"net/url" "net/url"
"strings" "strings"
"sync" "sync"
"time" "time"
"github.com/coreos/pkg/capnslog"
"github.com/coreos/pkg/timeutil" "github.com/coreos/pkg/timeutil"
"github.com/jonboulle/clockwork" "github.com/jonboulle/clockwork"
@ -18,10 +18,6 @@ import (
"github.com/coreos/go-oidc/oauth2" "github.com/coreos/go-oidc/oauth2"
) )
var (
log = capnslog.NewPackageLogger("github.com/coreos/go-oidc", "http")
)
const ( const (
// Subject Identifier types defined by the OIDC spec. Specifies if the provider // Subject Identifier types defined by the OIDC spec. Specifies if the provider
// should provide the same sub claim value to all clients (public) or a unique // should provide the same sub claim value to all clients (public) or a unique
@ -69,6 +65,8 @@ type ProviderConfig struct {
UserInfoEndpoint *url.URL UserInfoEndpoint *url.URL
KeysEndpoint *url.URL // Required KeysEndpoint *url.URL // Required
RegistrationEndpoint *url.URL RegistrationEndpoint *url.URL
EndSessionEndpoint *url.URL
CheckSessionIFrame *url.URL
// Servers MAY choose not to advertise some supported scope values even when this // Servers MAY choose not to advertise some supported scope values even when this
// parameter is used, although those defined in OpenID Core SHOULD be listed, if supported. // parameter is used, although those defined in OpenID Core SHOULD be listed, if supported.
@ -170,6 +168,8 @@ type encodableProviderConfig struct {
UserInfoEndpoint string `json:"userinfo_endpoint,omitempty"` UserInfoEndpoint string `json:"userinfo_endpoint,omitempty"`
KeysEndpoint string `json:"jwks_uri"` KeysEndpoint string `json:"jwks_uri"`
RegistrationEndpoint string `json:"registration_endpoint,omitempty"` RegistrationEndpoint string `json:"registration_endpoint,omitempty"`
EndSessionEndpoint string `json:"end_session_endpoint,omitempty"`
CheckSessionIFrame string `json:"check_session_iframe,omitempty"`
// Use 'omitempty' for all slices as per OIDC spec: // Use 'omitempty' for all slices as per OIDC spec:
// "Claims that return multiple values are represented as JSON arrays. // "Claims that return multiple values are represented as JSON arrays.
@ -219,6 +219,8 @@ func (cfg ProviderConfig) toEncodableStruct() encodableProviderConfig {
UserInfoEndpoint: uriToString(cfg.UserInfoEndpoint), UserInfoEndpoint: uriToString(cfg.UserInfoEndpoint),
KeysEndpoint: uriToString(cfg.KeysEndpoint), KeysEndpoint: uriToString(cfg.KeysEndpoint),
RegistrationEndpoint: uriToString(cfg.RegistrationEndpoint), RegistrationEndpoint: uriToString(cfg.RegistrationEndpoint),
EndSessionEndpoint: uriToString(cfg.EndSessionEndpoint),
CheckSessionIFrame: uriToString(cfg.CheckSessionIFrame),
ScopesSupported: cfg.ScopesSupported, ScopesSupported: cfg.ScopesSupported,
ResponseTypesSupported: cfg.ResponseTypesSupported, ResponseTypesSupported: cfg.ResponseTypesSupported,
ResponseModesSupported: cfg.ResponseModesSupported, ResponseModesSupported: cfg.ResponseModesSupported,
@ -260,6 +262,8 @@ func (e encodableProviderConfig) toStruct() (ProviderConfig, error) {
UserInfoEndpoint: p.parseURI(e.UserInfoEndpoint, "userinfo_endpoint"), UserInfoEndpoint: p.parseURI(e.UserInfoEndpoint, "userinfo_endpoint"),
KeysEndpoint: p.parseURI(e.KeysEndpoint, "jwks_uri"), KeysEndpoint: p.parseURI(e.KeysEndpoint, "jwks_uri"),
RegistrationEndpoint: p.parseURI(e.RegistrationEndpoint, "registration_endpoint"), RegistrationEndpoint: p.parseURI(e.RegistrationEndpoint, "registration_endpoint"),
EndSessionEndpoint: p.parseURI(e.EndSessionEndpoint, "end_session_endpoint"),
CheckSessionIFrame: p.parseURI(e.CheckSessionIFrame, "check_session_iframe"),
ScopesSupported: e.ScopesSupported, ScopesSupported: e.ScopesSupported,
ResponseTypesSupported: e.ResponseTypesSupported, ResponseTypesSupported: e.ResponseTypesSupported,
ResponseModesSupported: e.ResponseModesSupported, ResponseModesSupported: e.ResponseModesSupported,
@ -364,6 +368,8 @@ func (p ProviderConfig) Valid() error {
{p.UserInfoEndpoint, "userinfo_endpoint", false}, {p.UserInfoEndpoint, "userinfo_endpoint", false},
{p.KeysEndpoint, "jwks_uri", true}, {p.KeysEndpoint, "jwks_uri", true},
{p.RegistrationEndpoint, "registration_endpoint", false}, {p.RegistrationEndpoint, "registration_endpoint", false},
{p.EndSessionEndpoint, "end_session_endpoint", false},
{p.CheckSessionIFrame, "check_session_iframe", false},
{p.ServiceDocs, "service_documentation", false}, {p.ServiceDocs, "service_documentation", false},
{p.Policy, "op_policy_uri", false}, {p.Policy, "op_policy_uri", false},
{p.TermsOfService, "op_tos_uri", false}, {p.TermsOfService, "op_tos_uri", false},
@ -537,8 +543,6 @@ func (s *ProviderConfigSyncer) sync() (time.Duration, error) {
s.initialSyncDone = true s.initialSyncDone = true
} }
log.Infof("Updating provider config: config=%#v", cfg)
return nextSyncAfter(cfg.ExpiresAt, s.clock), nil return nextSyncAfter(cfg.ExpiresAt, s.clock), nil
} }
@ -561,10 +565,9 @@ func (n *pcsStepNext) step(fn pcsStepFunc) (next pcsStepper) {
ttl, err := fn() ttl, err := fn()
if err == nil { if err == nil {
next = &pcsStepNext{aft: ttl} next = &pcsStepNext{aft: ttl}
log.Debugf("Synced provider config, next attempt in %v", next.after())
} else { } else {
next = &pcsStepRetry{aft: time.Second} next = &pcsStepRetry{aft: time.Second}
log.Errorf("Provider config sync failed, retrying in %v: %v", next.after(), err) log.Printf("go-oidc: provider config sync falied, retyring in %v: %v", next.after(), err)
} }
return return
} }
@ -581,10 +584,9 @@ func (r *pcsStepRetry) step(fn pcsStepFunc) (next pcsStepper) {
ttl, err := fn() ttl, err := fn()
if err == nil { if err == nil {
next = &pcsStepNext{aft: ttl} next = &pcsStepNext{aft: ttl}
log.Infof("Provider config sync no longer failing")
} else { } else {
next = &pcsStepRetry{aft: timeutil.ExpBackoff(r.aft, time.Minute)} next = &pcsStepRetry{aft: timeutil.ExpBackoff(r.aft, time.Minute)}
log.Errorf("Provider config sync still failing, retrying in %v: %v", next.after(), err) log.Printf("go-oidc: provider config sync falied, retyring in %v: %v", next.after(), err)
} }
return return
} }

View file

@ -161,11 +161,18 @@ func NewJWTVerifier(issuer, clientID string, syncFunc func() error, keysFunc fun
} }
func (v *JWTVerifier) Verify(jwt jose.JWT) error { func (v *JWTVerifier) Verify(jwt jose.JWT) error {
// Verify claims before verifying the signature. This is an optimization to throw out
// tokens we know are invalid without undergoing an expensive signature check and
// possibly a re-sync event.
if err := VerifyClaims(jwt, v.issuer, v.clientID); err != nil {
return fmt.Errorf("oidc: JWT claims invalid: %v", err)
}
ok, err := VerifySignature(jwt, v.keysFunc()) ok, err := VerifySignature(jwt, v.keysFunc())
if ok { if err != nil {
goto SignatureVerified
} else if err != nil {
return fmt.Errorf("oidc: JWT signature verification failed: %v", err) return fmt.Errorf("oidc: JWT signature verification failed: %v", err)
} else if ok {
return nil
} }
if err = v.syncFunc(); err != nil { if err = v.syncFunc(); err != nil {
@ -179,10 +186,5 @@ func (v *JWTVerifier) Verify(jwt jose.JWT) error {
return errors.New("oidc: unable to verify JWT signature: no matching keys") return errors.New("oidc: unable to verify JWT signature: no matching keys")
} }
SignatureVerified:
if err := VerifyClaims(jwt, v.issuer, v.clientID); err != nil {
return fmt.Errorf("oidc: JWT claims invalid: %v", err)
}
return nil return nil
} }

View file

@ -1,8 +1,7 @@
language: go language: go
go: go:
- 1.3 - tip
- 1.4
install: install:
- export GOPATH="$HOME/gopath" - export GOPATH="$HOME/gopath"

View file

@ -1,6 +1,7 @@
# OAuth2 for Go # OAuth2 for Go
[![Build Status](https://travis-ci.org/golang/oauth2.svg?branch=master)](https://travis-ci.org/golang/oauth2) [![Build Status](https://travis-ci.org/golang/oauth2.svg?branch=master)](https://travis-ci.org/golang/oauth2)
[![GoDoc](https://godoc.org/golang.org/x/oauth2?status.svg)](https://godoc.org/golang.org/x/oauth2)
oauth2 package contains a client implementation for OAuth 2.0 spec. oauth2 package contains a client implementation for OAuth 2.0 spec.

View file

@ -1,8 +1,8 @@
// Copyright 2014 The oauth2 Authors. All rights reserved. // Copyright 2014 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
// +build appengine appenginevm // +build appengine
// App Engine hooks. // App Engine hooks.

View file

@ -1,4 +1,4 @@
// Copyright 2014 The oauth2 Authors. All rights reserved. // Copyright 2014 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
@ -14,6 +14,9 @@ import (
"golang.org/x/oauth2" "golang.org/x/oauth2"
) )
// Set at init time by appenginevm_hook.go. If true, we are on App Engine Managed VMs.
var appengineVM bool
// Set at init time by appengine_hook.go. If nil, we're not on App Engine. // Set at init time by appengine_hook.go. If nil, we're not on App Engine.
var appengineTokenFunc func(c context.Context, scopes ...string) (token string, expiry time.Time, err error) var appengineTokenFunc func(c context.Context, scopes ...string) (token string, expiry time.Time, err error)

View file

@ -1,8 +1,8 @@
// Copyright 2015 The oauth2 Authors. All rights reserved. // Copyright 2015 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
// +build appengine appenginevm // +build appengine
package google package google

14
vendor/golang.org/x/oauth2/google/appenginevm_hook.go generated vendored Normal file
View file

@ -0,0 +1,14 @@
// Copyright 2015 The oauth2 Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build appenginevm
package google
import "google.golang.org/appengine"
func init() {
appengineVM = true
appengineTokenFunc = appengine.AccessToken
}

View file

@ -1,4 +1,4 @@
// Copyright 2015 The oauth2 Authors. All rights reserved. // Copyright 2015 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
@ -14,10 +14,10 @@ import (
"path/filepath" "path/filepath"
"runtime" "runtime"
"cloud.google.com/go/compute/metadata"
"golang.org/x/net/context" "golang.org/x/net/context"
"golang.org/x/oauth2" "golang.org/x/oauth2"
"golang.org/x/oauth2/jwt" "golang.org/x/oauth2/jwt"
"google.golang.org/cloud/compute/metadata"
) )
// DefaultClient returns an HTTP Client that uses the // DefaultClient returns an HTTP Client that uses the
@ -50,7 +50,8 @@ func DefaultClient(ctx context.Context, scope ...string) (*http.Client, error) {
// On Windows, this is %APPDATA%/gcloud/application_default_credentials.json. // On Windows, this is %APPDATA%/gcloud/application_default_credentials.json.
// On other systems, $HOME/.config/gcloud/application_default_credentials.json. // On other systems, $HOME/.config/gcloud/application_default_credentials.json.
// 3. On Google App Engine it uses the appengine.AccessToken function. // 3. On Google App Engine it uses the appengine.AccessToken function.
// 4. On Google Compute Engine, it fetches credentials from the metadata server. // 4. On Google Compute Engine and Google App Engine Managed VMs, it fetches
// credentials from the metadata server.
// (In this final case any provided scopes are ignored.) // (In this final case any provided scopes are ignored.)
// //
// For more details, see: // For more details, see:
@ -84,7 +85,7 @@ func DefaultTokenSource(ctx context.Context, scope ...string) (oauth2.TokenSourc
} }
// Third, if we're on Google App Engine use those credentials. // Third, if we're on Google App Engine use those credentials.
if appengineTokenFunc != nil { if appengineTokenFunc != nil && !appengineVM {
return AppEngineTokenSource(ctx, scope...), nil return AppEngineTokenSource(ctx, scope...), nil
} }

View file

@ -1,4 +1,4 @@
// Copyright 2014 The oauth2 Authors. All rights reserved. // Copyright 2014 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
@ -21,9 +21,9 @@ import (
"strings" "strings"
"time" "time"
"cloud.google.com/go/compute/metadata"
"golang.org/x/oauth2" "golang.org/x/oauth2"
"golang.org/x/oauth2/jwt" "golang.org/x/oauth2/jwt"
"google.golang.org/cloud/compute/metadata"
) )
// Endpoint is Google's OAuth 2.0 endpoint. // Endpoint is Google's OAuth 2.0 endpoint.
@ -37,9 +37,10 @@ const JWTTokenURL = "https://accounts.google.com/o/oauth2/token"
// ConfigFromJSON uses a Google Developers Console client_credentials.json // ConfigFromJSON uses a Google Developers Console client_credentials.json
// file to construct a config. // file to construct a config.
// client_credentials.json can be downloadable from https://console.developers.google.com, // client_credentials.json can be downloaded from
// under "APIs & Auth" > "Credentials". Download the Web application credentials in the // https://console.developers.google.com, under "Credentials". Download the Web
// JSON format and provide the contents of the file as jsonKey. // application credentials in the JSON format and provide the contents of the
// file as jsonKey.
func ConfigFromJSON(jsonKey []byte, scope ...string) (*oauth2.Config, error) { func ConfigFromJSON(jsonKey []byte, scope ...string) (*oauth2.Config, error) {
type cred struct { type cred struct {
ClientID string `json:"client_id"` ClientID string `json:"client_id"`
@ -81,22 +82,29 @@ func ConfigFromJSON(jsonKey []byte, scope ...string) (*oauth2.Config, error) {
// JWTConfigFromJSON uses a Google Developers service account JSON key file to read // JWTConfigFromJSON uses a Google Developers service account JSON key file to read
// the credentials that authorize and authenticate the requests. // the credentials that authorize and authenticate the requests.
// Create a service account on "Credentials" page under "APIs & Auth" for your // Create a service account on "Credentials" for your project at
// project at https://console.developers.google.com to download a JSON key file. // https://console.developers.google.com to download a JSON key file.
func JWTConfigFromJSON(jsonKey []byte, scope ...string) (*jwt.Config, error) { func JWTConfigFromJSON(jsonKey []byte, scope ...string) (*jwt.Config, error) {
var key struct { var key struct {
Email string `json:"client_email"` Email string `json:"client_email"`
PrivateKey string `json:"private_key"` PrivateKey string `json:"private_key"`
PrivateKeyID string `json:"private_key_id"`
TokenURL string `json:"token_uri"`
} }
if err := json.Unmarshal(jsonKey, &key); err != nil { if err := json.Unmarshal(jsonKey, &key); err != nil {
return nil, err return nil, err
} }
return &jwt.Config{ config := &jwt.Config{
Email: key.Email, Email: key.Email,
PrivateKey: []byte(key.PrivateKey), PrivateKey: []byte(key.PrivateKey),
Scopes: scope, PrivateKeyID: key.PrivateKeyID,
TokenURL: JWTTokenURL, Scopes: scope,
}, nil TokenURL: key.TokenURL,
}
if config.TokenURL == "" {
config.TokenURL = JWTTokenURL
}
return config, nil
} }
// ComputeTokenSource returns a token source that fetches access tokens // ComputeTokenSource returns a token source that fetches access tokens

74
vendor/golang.org/x/oauth2/google/jwt.go generated vendored Normal file
View file

@ -0,0 +1,74 @@
// Copyright 2015 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package google
import (
"crypto/rsa"
"fmt"
"time"
"golang.org/x/oauth2"
"golang.org/x/oauth2/internal"
"golang.org/x/oauth2/jws"
)
// JWTAccessTokenSourceFromJSON uses a Google Developers service account JSON
// key file to read the credentials that authorize and authenticate the
// requests, and returns a TokenSource that does not use any OAuth2 flow but
// instead creates a JWT and sends that as the access token.
// The audience is typically a URL that specifies the scope of the credentials.
//
// Note that this is not a standard OAuth flow, but rather an
// optimization supported by a few Google services.
// Unless you know otherwise, you should use JWTConfigFromJSON instead.
func JWTAccessTokenSourceFromJSON(jsonKey []byte, audience string) (oauth2.TokenSource, error) {
cfg, err := JWTConfigFromJSON(jsonKey)
if err != nil {
return nil, fmt.Errorf("google: could not parse JSON key: %v", err)
}
pk, err := internal.ParseKey(cfg.PrivateKey)
if err != nil {
return nil, fmt.Errorf("google: could not parse key: %v", err)
}
ts := &jwtAccessTokenSource{
email: cfg.Email,
audience: audience,
pk: pk,
pkID: cfg.PrivateKeyID,
}
tok, err := ts.Token()
if err != nil {
return nil, err
}
return oauth2.ReuseTokenSource(tok, ts), nil
}
type jwtAccessTokenSource struct {
email, audience string
pk *rsa.PrivateKey
pkID string
}
func (ts *jwtAccessTokenSource) Token() (*oauth2.Token, error) {
iat := time.Now()
exp := iat.Add(time.Hour)
cs := &jws.ClaimSet{
Iss: ts.email,
Sub: ts.email,
Aud: ts.audience,
Iat: iat.Unix(),
Exp: exp.Unix(),
}
hdr := &jws.Header{
Algorithm: "RS256",
Typ: "JWT",
KeyID: string(ts.pkID),
}
msg, err := jws.Encode(hdr, cs, ts.pk)
if err != nil {
return nil, fmt.Errorf("google: could not encode JWT: %v", err)
}
return &oauth2.Token{AccessToken: msg, TokenType: "Bearer", Expiry: exp}, nil
}

View file

@ -1,4 +1,4 @@
// Copyright 2015 The oauth2 Authors. All rights reserved. // Copyright 2015 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.

View file

@ -1,4 +1,4 @@
// Copyright 2014 The oauth2 Authors. All rights reserved. // Copyright 2014 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.

View file

@ -1,4 +1,4 @@
// Copyright 2014 The oauth2 Authors. All rights reserved. // Copyright 2014 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
@ -91,24 +91,36 @@ func (e *expirationTime) UnmarshalJSON(b []byte) error {
var brokenAuthHeaderProviders = []string{ var brokenAuthHeaderProviders = []string{
"https://accounts.google.com/", "https://accounts.google.com/",
"https://www.googleapis.com/",
"https://github.com/",
"https://api.instagram.com/",
"https://www.douban.com/",
"https://api.dropbox.com/", "https://api.dropbox.com/",
"https://api.soundcloud.com/", "https://api.dropboxapi.com/",
"https://www.linkedin.com/", "https://api.instagram.com/",
"https://api.twitch.tv/", "https://api.netatmo.net/",
"https://oauth.vk.com/",
"https://api.odnoklassniki.ru/", "https://api.odnoklassniki.ru/",
"https://connect.stripe.com/",
"https://api.pushbullet.com/", "https://api.pushbullet.com/",
"https://api.soundcloud.com/",
"https://api.twitch.tv/",
"https://app.box.com/",
"https://connect.stripe.com/",
"https://login.microsoftonline.com/",
"https://login.salesforce.com/",
"https://oauth.sandbox.trainingpeaks.com/", "https://oauth.sandbox.trainingpeaks.com/",
"https://oauth.trainingpeaks.com/", "https://oauth.trainingpeaks.com/",
"https://www.strava.com/oauth/", "https://oauth.vk.com/",
"https://app.box.com/", "https://openapi.baidu.com/",
"https://slack.com/",
"https://test-sandbox.auth.corp.google.com", "https://test-sandbox.auth.corp.google.com",
"https://test.salesforce.com/",
"https://user.gini.net/", "https://user.gini.net/",
"https://www.douban.com/",
"https://www.googleapis.com/",
"https://www.linkedin.com/",
"https://www.strava.com/oauth/",
"https://www.wunderlist.com/oauth/",
"https://api.patreon.com/",
}
func RegisterBrokenAuthHeaderProvider(tokenURL string) {
brokenAuthHeaderProviders = append(brokenAuthHeaderProviders, tokenURL)
} }
// providerAuthHeaderWorks reports whether the OAuth2 server identified by the tokenURL // providerAuthHeaderWorks reports whether the OAuth2 server identified by the tokenURL
@ -134,23 +146,23 @@ func providerAuthHeaderWorks(tokenURL string) bool {
return true return true
} }
func RetrieveToken(ctx context.Context, ClientID, ClientSecret, TokenURL string, v url.Values) (*Token, error) { func RetrieveToken(ctx context.Context, clientID, clientSecret, tokenURL string, v url.Values) (*Token, error) {
hc, err := ContextClient(ctx) hc, err := ContextClient(ctx)
if err != nil { if err != nil {
return nil, err return nil, err
} }
v.Set("client_id", ClientID) v.Set("client_id", clientID)
bustedAuth := !providerAuthHeaderWorks(TokenURL) bustedAuth := !providerAuthHeaderWorks(tokenURL)
if bustedAuth && ClientSecret != "" { if bustedAuth && clientSecret != "" {
v.Set("client_secret", ClientSecret) v.Set("client_secret", clientSecret)
} }
req, err := http.NewRequest("POST", TokenURL, strings.NewReader(v.Encode())) req, err := http.NewRequest("POST", tokenURL, strings.NewReader(v.Encode()))
if err != nil { if err != nil {
return nil, err return nil, err
} }
req.Header.Set("Content-Type", "application/x-www-form-urlencoded") req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
if !bustedAuth { if !bustedAuth {
req.SetBasicAuth(ClientID, ClientSecret) req.SetBasicAuth(clientID, clientSecret)
} }
r, err := hc.Do(req) r, err := hc.Do(req)
if err != nil { if err != nil {

View file

@ -1,4 +1,4 @@
// Copyright 2014 The oauth2 Authors. All rights reserved. // Copyright 2014 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
@ -33,6 +33,11 @@ func RegisterContextClientFunc(fn ContextClientFunc) {
} }
func ContextClient(ctx context.Context) (*http.Client, error) { func ContextClient(ctx context.Context) (*http.Client, error) {
if ctx != nil {
if hc, ok := ctx.Value(HTTPClient).(*http.Client); ok {
return hc, nil
}
}
for _, fn := range contextClientFuncs { for _, fn := range contextClientFuncs {
c, err := fn(ctx) c, err := fn(ctx)
if err != nil { if err != nil {
@ -42,9 +47,6 @@ func ContextClient(ctx context.Context) (*http.Client, error) {
return c, nil return c, nil
} }
} }
if hc, ok := ctx.Value(HTTPClient).(*http.Client); ok {
return hc, nil
}
return http.DefaultClient, nil return http.DefaultClient, nil
} }

110
vendor/golang.org/x/oauth2/jws/jws.go generated vendored
View file

@ -1,9 +1,17 @@
// Copyright 2014 The oauth2 Authors. All rights reserved. // Copyright 2014 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
// Package jws provides encoding and decoding utilities for // Package jws provides a partial implementation
// signed JWS messages. // of JSON Web Signature encoding and decoding.
// It exists to support the golang.org/x/oauth2 package.
//
// See RFC 7515.
//
// Deprecated: this package is not intended for public use and might be
// removed in the future. It exists for internal use only.
// Please switch to another JWS package or copy this package into your own
// source tree.
package jws package jws
import ( import (
@ -27,8 +35,8 @@ type ClaimSet struct {
Iss string `json:"iss"` // email address of the client_id of the application making the access token request Iss string `json:"iss"` // email address of the client_id of the application making the access token request
Scope string `json:"scope,omitempty"` // space-delimited list of the permissions the application requests Scope string `json:"scope,omitempty"` // space-delimited list of the permissions the application requests
Aud string `json:"aud"` // descriptor of the intended target of the assertion (Optional). Aud string `json:"aud"` // descriptor of the intended target of the assertion (Optional).
Exp int64 `json:"exp"` // the expiration time of the assertion Exp int64 `json:"exp"` // the expiration time of the assertion (seconds since Unix epoch)
Iat int64 `json:"iat"` // the time the assertion was issued. Iat int64 `json:"iat"` // the time the assertion was issued (seconds since Unix epoch)
Typ string `json:"typ,omitempty"` // token type (Optional). Typ string `json:"typ,omitempty"` // token type (Optional).
// Email for which the application is requesting delegated access (Optional). // Email for which the application is requesting delegated access (Optional).
@ -41,23 +49,22 @@ type ClaimSet struct {
// See http://tools.ietf.org/html/draft-jones-json-web-token-10#section-4.3 // See http://tools.ietf.org/html/draft-jones-json-web-token-10#section-4.3
// This array is marshalled using custom code (see (c *ClaimSet) encode()). // This array is marshalled using custom code (see (c *ClaimSet) encode()).
PrivateClaims map[string]interface{} `json:"-"` PrivateClaims map[string]interface{} `json:"-"`
exp time.Time
iat time.Time
} }
func (c *ClaimSet) encode() (string, error) { func (c *ClaimSet) encode() (string, error) {
if c.exp.IsZero() || c.iat.IsZero() { // Reverting time back for machines whose time is not perfectly in sync.
// Reverting time back for machines whose time is not perfectly in sync. // If client machine's time is in the future according
// If client machine's time is in the future according // to Google servers, an access token will not be issued.
// to Google servers, an access token will not be issued. now := time.Now().Add(-10 * time.Second)
now := time.Now().Add(-10 * time.Second) if c.Iat == 0 {
c.iat = now c.Iat = now.Unix()
c.exp = now.Add(time.Hour) }
if c.Exp == 0 {
c.Exp = now.Add(time.Hour).Unix()
}
if c.Exp < c.Iat {
return "", fmt.Errorf("jws: invalid Exp = %v; must be later than Iat = %v", c.Exp, c.Iat)
} }
c.Exp = c.exp.Unix()
c.Iat = c.iat.Unix()
b, err := json.Marshal(c) b, err := json.Marshal(c)
if err != nil { if err != nil {
@ -65,7 +72,7 @@ func (c *ClaimSet) encode() (string, error) {
} }
if len(c.PrivateClaims) == 0 { if len(c.PrivateClaims) == 0 {
return base64Encode(b), nil return base64.RawURLEncoding.EncodeToString(b), nil
} }
// Marshal private claim set and then append it to b. // Marshal private claim set and then append it to b.
@ -83,7 +90,7 @@ func (c *ClaimSet) encode() (string, error) {
} }
b[len(b)-1] = ',' // Replace closing curly brace with a comma. b[len(b)-1] = ',' // Replace closing curly brace with a comma.
b = append(b, prv[1:]...) // Append private claims. b = append(b, prv[1:]...) // Append private claims.
return base64Encode(b), nil return base64.RawURLEncoding.EncodeToString(b), nil
} }
// Header represents the header for the signed JWS payloads. // Header represents the header for the signed JWS payloads.
@ -93,6 +100,9 @@ type Header struct {
// Represents the token type. // Represents the token type.
Typ string `json:"typ"` Typ string `json:"typ"`
// The optional hint of which key is being used.
KeyID string `json:"kid,omitempty"`
} }
func (h *Header) encode() (string, error) { func (h *Header) encode() (string, error) {
@ -100,7 +110,7 @@ func (h *Header) encode() (string, error) {
if err != nil { if err != nil {
return "", err return "", err
} }
return base64Encode(b), nil return base64.RawURLEncoding.EncodeToString(b), nil
} }
// Decode decodes a claim set from a JWS payload. // Decode decodes a claim set from a JWS payload.
@ -111,7 +121,7 @@ func Decode(payload string) (*ClaimSet, error) {
// TODO(jbd): Provide more context about the error. // TODO(jbd): Provide more context about the error.
return nil, errors.New("jws: invalid token received") return nil, errors.New("jws: invalid token received")
} }
decoded, err := base64Decode(s[1]) decoded, err := base64.RawURLEncoding.DecodeString(s[1])
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -120,8 +130,11 @@ func Decode(payload string) (*ClaimSet, error) {
return c, err return c, err
} }
// Encode encodes a signed JWS with provided header and claim set. // Signer returns a signature for the given data.
func Encode(header *Header, c *ClaimSet, signature *rsa.PrivateKey) (string, error) { type Signer func(data []byte) (sig []byte, err error)
// EncodeWithSigner encodes a header and claim set with the provided signer.
func EncodeWithSigner(header *Header, c *ClaimSet, sg Signer) (string, error) {
head, err := header.encode() head, err := header.encode()
if err != nil { if err != nil {
return "", err return "", err
@ -131,30 +144,39 @@ func Encode(header *Header, c *ClaimSet, signature *rsa.PrivateKey) (string, err
return "", err return "", err
} }
ss := fmt.Sprintf("%s.%s", head, cs) ss := fmt.Sprintf("%s.%s", head, cs)
h := sha256.New() sig, err := sg([]byte(ss))
h.Write([]byte(ss))
b, err := rsa.SignPKCS1v15(rand.Reader, signature, crypto.SHA256, h.Sum(nil))
if err != nil { if err != nil {
return "", err return "", err
} }
sig := base64Encode(b) return fmt.Sprintf("%s.%s", ss, base64.RawURLEncoding.EncodeToString(sig)), nil
return fmt.Sprintf("%s.%s", ss, sig), nil
} }
// base64Encode returns and Base64url encoded version of the input string with any // Encode encodes a signed JWS with provided header and claim set.
// trailing "=" stripped. // This invokes EncodeWithSigner using crypto/rsa.SignPKCS1v15 with the given RSA private key.
func base64Encode(b []byte) string { func Encode(header *Header, c *ClaimSet, key *rsa.PrivateKey) (string, error) {
return strings.TrimRight(base64.URLEncoding.EncodeToString(b), "=") sg := func(data []byte) (sig []byte, err error) {
} h := sha256.New()
h.Write(data)
// base64Decode decodes the Base64url encoded string return rsa.SignPKCS1v15(rand.Reader, key, crypto.SHA256, h.Sum(nil))
func base64Decode(s string) ([]byte, error) {
// add back missing padding
switch len(s) % 4 {
case 2:
s += "=="
case 3:
s += "="
} }
return base64.URLEncoding.DecodeString(s) return EncodeWithSigner(header, c, sg)
}
// Verify tests whether the provided JWT token's signature was produced by the private key
// associated with the supplied public key.
func Verify(token string, key *rsa.PublicKey) error {
parts := strings.Split(token, ".")
if len(parts) != 3 {
return errors.New("jws: invalid token received, token must have 3 parts")
}
signedContent := parts[0] + "." + parts[1]
signatureString, err := base64.RawURLEncoding.DecodeString(parts[2])
if err != nil {
return err
}
h := sha256.New()
h.Write([]byte(signedContent))
return rsa.VerifyPKCS1v15(key, crypto.SHA256, h.Sum(nil), []byte(signatureString))
} }

View file

@ -1,4 +1,4 @@
// Copyright 2014 The oauth2 Authors. All rights reserved. // Copyright 2014 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
@ -46,6 +46,10 @@ type Config struct {
// //
PrivateKey []byte PrivateKey []byte
// PrivateKeyID contains an optional hint indicating which key is being
// used.
PrivateKeyID string
// Subject is the optional user to impersonate. // Subject is the optional user to impersonate.
Subject string Subject string
@ -54,6 +58,9 @@ type Config struct {
// TokenURL is the endpoint required to complete the 2-legged JWT flow. // TokenURL is the endpoint required to complete the 2-legged JWT flow.
TokenURL string TokenURL string
// Expires optionally specifies how long the token is valid for.
Expires time.Duration
} }
// TokenSource returns a JWT TokenSource using the configuration // TokenSource returns a JWT TokenSource using the configuration
@ -95,6 +102,9 @@ func (js jwtSource) Token() (*oauth2.Token, error) {
// to be compatible with legacy OAuth 2.0 providers. // to be compatible with legacy OAuth 2.0 providers.
claimSet.Prn = subject claimSet.Prn = subject
} }
if t := js.conf.Expires; t > 0 {
claimSet.Exp = time.Now().Add(t).Unix()
}
payload, err := jws.Encode(defaultHeader, claimSet, pk) payload, err := jws.Encode(defaultHeader, claimSet, pk)
if err != nil { if err != nil {
return nil, err return nil, err

20
vendor/golang.org/x/oauth2/oauth2.go generated vendored
View file

@ -1,4 +1,4 @@
// Copyright 2014 The oauth2 Authors. All rights reserved. // Copyright 2014 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
@ -21,10 +21,26 @@ import (
// NoContext is the default context you should supply if not using // NoContext is the default context you should supply if not using
// your own context.Context (see https://golang.org/x/net/context). // your own context.Context (see https://golang.org/x/net/context).
//
// Deprecated: Use context.Background() or context.TODO() instead.
var NoContext = context.TODO() var NoContext = context.TODO()
// RegisterBrokenAuthHeaderProvider registers an OAuth2 server
// identified by the tokenURL prefix as an OAuth2 implementation
// which doesn't support the HTTP Basic authentication
// scheme to authenticate with the authorization server.
// Once a server is registered, credentials (client_id and client_secret)
// will be passed as query parameters rather than being present
// in the Authorization header.
// See https://code.google.com/p/goauth2/issues/detail?id=31 for background.
func RegisterBrokenAuthHeaderProvider(tokenURL string) {
internal.RegisterBrokenAuthHeaderProvider(tokenURL)
}
// Config describes a typical 3-legged OAuth2 flow, with both the // Config describes a typical 3-legged OAuth2 flow, with both the
// client application information and the server's endpoint URLs. // client application information and the server's endpoint URLs.
// For the client credentials 2-legged OAuth2 flow, see the clientcredentials
// package (https://golang.org/x/oauth2/clientcredentials).
type Config struct { type Config struct {
// ClientID is the application's ID. // ClientID is the application's ID.
ClientID string ClientID string
@ -283,7 +299,7 @@ func NewClient(ctx context.Context, src TokenSource) *http.Client {
if src == nil { if src == nil {
c, err := internal.ContextClient(ctx) c, err := internal.ContextClient(ctx)
if err != nil { if err != nil {
return &http.Client{Transport: internal.ErrorTransport{err}} return &http.Client{Transport: internal.ErrorTransport{Err: err}}
} }
return c return c
} }

27
vendor/golang.org/x/oauth2/token.go generated vendored
View file

@ -1,4 +1,4 @@
// Copyright 2014 The oauth2 Authors. All rights reserved. // Copyright 2014 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
@ -7,6 +7,7 @@ package oauth2
import ( import (
"net/http" "net/http"
"net/url" "net/url"
"strconv"
"strings" "strings"
"time" "time"
@ -92,14 +93,28 @@ func (t *Token) WithExtra(extra interface{}) *Token {
// Extra fields are key-value pairs returned by the server as a // Extra fields are key-value pairs returned by the server as a
// part of the token retrieval response. // part of the token retrieval response.
func (t *Token) Extra(key string) interface{} { func (t *Token) Extra(key string) interface{} {
if vals, ok := t.raw.(url.Values); ok {
// TODO(jbd): Cast numeric values to int64 or float64.
return vals.Get(key)
}
if raw, ok := t.raw.(map[string]interface{}); ok { if raw, ok := t.raw.(map[string]interface{}); ok {
return raw[key] return raw[key]
} }
return nil
vals, ok := t.raw.(url.Values)
if !ok {
return nil
}
v := vals.Get(key)
switch s := strings.TrimSpace(v); strings.Count(s, ".") {
case 0: // Contains no "."; try to parse as int
if i, err := strconv.ParseInt(s, 10, 64); err == nil {
return i
}
case 1: // Contains a single "."; try to parse as float
if f, err := strconv.ParseFloat(s, 64); err == nil {
return f
}
}
return v
} }
// expired reports whether the token is expired. // expired reports whether the token is expired.

View file

@ -1,4 +1,4 @@
// Copyright 2014 The oauth2 Authors. All rights reserved. // Copyright 2014 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.

File diff suppressed because it is too large Load diff

File diff suppressed because it is too large Load diff

View file

@ -878,23 +878,22 @@ func (c *ProjectsZonesGetServerconfigCall) Context(ctx context.Context) *Project
} }
func (c *ProjectsZonesGetServerconfigCall) doRequest(alt string) (*http.Response, error) { func (c *ProjectsZonesGetServerconfigCall) doRequest(alt string) (*http.Response, error) {
reqHeaders := make(http.Header)
reqHeaders.Set("User-Agent", c.s.userAgent())
if c.ifNoneMatch_ != "" {
reqHeaders.Set("If-None-Match", c.ifNoneMatch_)
}
var body io.Reader = nil var body io.Reader = nil
c.urlParams_.Set("alt", alt) c.urlParams_.Set("alt", alt)
urls := googleapi.ResolveRelative(c.s.BasePath, "v1/projects/{projectId}/zones/{zone}/serverconfig") urls := googleapi.ResolveRelative(c.s.BasePath, "v1/projects/{projectId}/zones/{zone}/serverconfig")
urls += "?" + c.urlParams_.Encode() urls += "?" + c.urlParams_.Encode()
req, _ := http.NewRequest("GET", urls, body) req, _ := http.NewRequest("GET", urls, body)
req.Header = reqHeaders
googleapi.Expand(req.URL, map[string]string{ googleapi.Expand(req.URL, map[string]string{
"projectId": c.projectId, "projectId": c.projectId,
"zone": c.zone, "zone": c.zone,
}) })
req.Header.Set("User-Agent", c.s.userAgent()) return gensupport.SendRequest(c.ctx_, c.s.client, req)
if c.ifNoneMatch_ != "" {
req.Header.Set("If-None-Match", c.ifNoneMatch_)
}
if c.ctx_ != nil {
return ctxhttp.Do(c.ctx_, c.s.client, req)
}
return c.s.client.Do(req)
} }
// Do executes the "container.projects.zones.getServerconfig" call. // Do executes the "container.projects.zones.getServerconfig" call.
@ -929,7 +928,8 @@ func (c *ProjectsZonesGetServerconfigCall) Do(opts ...googleapi.CallOption) (*Se
HTTPStatusCode: res.StatusCode, HTTPStatusCode: res.StatusCode,
}, },
} }
if err := json.NewDecoder(res.Body).Decode(&ret); err != nil { target := &ret
if err := json.NewDecoder(res.Body).Decode(target); err != nil {
return nil, err return nil, err
} }
return ret, nil return ret, nil
@ -1011,26 +1011,24 @@ func (c *ProjectsZonesClustersCreateCall) Context(ctx context.Context) *Projects
} }
func (c *ProjectsZonesClustersCreateCall) doRequest(alt string) (*http.Response, error) { func (c *ProjectsZonesClustersCreateCall) doRequest(alt string) (*http.Response, error) {
reqHeaders := make(http.Header)
reqHeaders.Set("User-Agent", c.s.userAgent())
var body io.Reader = nil var body io.Reader = nil
body, err := googleapi.WithoutDataWrapper.JSONReader(c.createclusterrequest) body, err := googleapi.WithoutDataWrapper.JSONReader(c.createclusterrequest)
if err != nil { if err != nil {
return nil, err return nil, err
} }
ctype := "application/json" reqHeaders.Set("Content-Type", "application/json")
c.urlParams_.Set("alt", alt) c.urlParams_.Set("alt", alt)
urls := googleapi.ResolveRelative(c.s.BasePath, "v1/projects/{projectId}/zones/{zone}/clusters") urls := googleapi.ResolveRelative(c.s.BasePath, "v1/projects/{projectId}/zones/{zone}/clusters")
urls += "?" + c.urlParams_.Encode() urls += "?" + c.urlParams_.Encode()
req, _ := http.NewRequest("POST", urls, body) req, _ := http.NewRequest("POST", urls, body)
req.Header = reqHeaders
googleapi.Expand(req.URL, map[string]string{ googleapi.Expand(req.URL, map[string]string{
"projectId": c.projectId, "projectId": c.projectId,
"zone": c.zone, "zone": c.zone,
}) })
req.Header.Set("Content-Type", ctype) return gensupport.SendRequest(c.ctx_, c.s.client, req)
req.Header.Set("User-Agent", c.s.userAgent())
if c.ctx_ != nil {
return ctxhttp.Do(c.ctx_, c.s.client, req)
}
return c.s.client.Do(req)
} }
// Do executes the "container.projects.zones.clusters.create" call. // Do executes the "container.projects.zones.clusters.create" call.
@ -1065,7 +1063,8 @@ func (c *ProjectsZonesClustersCreateCall) Do(opts ...googleapi.CallOption) (*Ope
HTTPStatusCode: res.StatusCode, HTTPStatusCode: res.StatusCode,
}, },
} }
if err := json.NewDecoder(res.Body).Decode(&ret); err != nil { target := &ret
if err := json.NewDecoder(res.Body).Decode(target); err != nil {
return nil, err return nil, err
} }
return ret, nil return ret, nil
@ -1147,21 +1146,20 @@ func (c *ProjectsZonesClustersDeleteCall) Context(ctx context.Context) *Projects
} }
func (c *ProjectsZonesClustersDeleteCall) doRequest(alt string) (*http.Response, error) { func (c *ProjectsZonesClustersDeleteCall) doRequest(alt string) (*http.Response, error) {
reqHeaders := make(http.Header)
reqHeaders.Set("User-Agent", c.s.userAgent())
var body io.Reader = nil var body io.Reader = nil
c.urlParams_.Set("alt", alt) c.urlParams_.Set("alt", alt)
urls := googleapi.ResolveRelative(c.s.BasePath, "v1/projects/{projectId}/zones/{zone}/clusters/{clusterId}") urls := googleapi.ResolveRelative(c.s.BasePath, "v1/projects/{projectId}/zones/{zone}/clusters/{clusterId}")
urls += "?" + c.urlParams_.Encode() urls += "?" + c.urlParams_.Encode()
req, _ := http.NewRequest("DELETE", urls, body) req, _ := http.NewRequest("DELETE", urls, body)
req.Header = reqHeaders
googleapi.Expand(req.URL, map[string]string{ googleapi.Expand(req.URL, map[string]string{
"projectId": c.projectId, "projectId": c.projectId,
"zone": c.zone, "zone": c.zone,
"clusterId": c.clusterId, "clusterId": c.clusterId,
}) })
req.Header.Set("User-Agent", c.s.userAgent()) return gensupport.SendRequest(c.ctx_, c.s.client, req)
if c.ctx_ != nil {
return ctxhttp.Do(c.ctx_, c.s.client, req)
}
return c.s.client.Do(req)
} }
// Do executes the "container.projects.zones.clusters.delete" call. // Do executes the "container.projects.zones.clusters.delete" call.
@ -1196,7 +1194,8 @@ func (c *ProjectsZonesClustersDeleteCall) Do(opts ...googleapi.CallOption) (*Ope
HTTPStatusCode: res.StatusCode, HTTPStatusCode: res.StatusCode,
}, },
} }
if err := json.NewDecoder(res.Body).Decode(&ret); err != nil { target := &ret
if err := json.NewDecoder(res.Body).Decode(target); err != nil {
return nil, err return nil, err
} }
return ret, nil return ret, nil
@ -1288,24 +1287,23 @@ func (c *ProjectsZonesClustersGetCall) Context(ctx context.Context) *ProjectsZon
} }
func (c *ProjectsZonesClustersGetCall) doRequest(alt string) (*http.Response, error) { func (c *ProjectsZonesClustersGetCall) doRequest(alt string) (*http.Response, error) {
reqHeaders := make(http.Header)
reqHeaders.Set("User-Agent", c.s.userAgent())
if c.ifNoneMatch_ != "" {
reqHeaders.Set("If-None-Match", c.ifNoneMatch_)
}
var body io.Reader = nil var body io.Reader = nil
c.urlParams_.Set("alt", alt) c.urlParams_.Set("alt", alt)
urls := googleapi.ResolveRelative(c.s.BasePath, "v1/projects/{projectId}/zones/{zone}/clusters/{clusterId}") urls := googleapi.ResolveRelative(c.s.BasePath, "v1/projects/{projectId}/zones/{zone}/clusters/{clusterId}")
urls += "?" + c.urlParams_.Encode() urls += "?" + c.urlParams_.Encode()
req, _ := http.NewRequest("GET", urls, body) req, _ := http.NewRequest("GET", urls, body)
req.Header = reqHeaders
googleapi.Expand(req.URL, map[string]string{ googleapi.Expand(req.URL, map[string]string{
"projectId": c.projectId, "projectId": c.projectId,
"zone": c.zone, "zone": c.zone,
"clusterId": c.clusterId, "clusterId": c.clusterId,
}) })
req.Header.Set("User-Agent", c.s.userAgent()) return gensupport.SendRequest(c.ctx_, c.s.client, req)
if c.ifNoneMatch_ != "" {
req.Header.Set("If-None-Match", c.ifNoneMatch_)
}
if c.ctx_ != nil {
return ctxhttp.Do(c.ctx_, c.s.client, req)
}
return c.s.client.Do(req)
} }
// Do executes the "container.projects.zones.clusters.get" call. // Do executes the "container.projects.zones.clusters.get" call.
@ -1340,7 +1338,8 @@ func (c *ProjectsZonesClustersGetCall) Do(opts ...googleapi.CallOption) (*Cluste
HTTPStatusCode: res.StatusCode, HTTPStatusCode: res.StatusCode,
}, },
} }
if err := json.NewDecoder(res.Body).Decode(&ret); err != nil { target := &ret
if err := json.NewDecoder(res.Body).Decode(target); err != nil {
return nil, err return nil, err
} }
return ret, nil return ret, nil
@ -1431,23 +1430,22 @@ func (c *ProjectsZonesClustersListCall) Context(ctx context.Context) *ProjectsZo
} }
func (c *ProjectsZonesClustersListCall) doRequest(alt string) (*http.Response, error) { func (c *ProjectsZonesClustersListCall) doRequest(alt string) (*http.Response, error) {
reqHeaders := make(http.Header)
reqHeaders.Set("User-Agent", c.s.userAgent())
if c.ifNoneMatch_ != "" {
reqHeaders.Set("If-None-Match", c.ifNoneMatch_)
}
var body io.Reader = nil var body io.Reader = nil
c.urlParams_.Set("alt", alt) c.urlParams_.Set("alt", alt)
urls := googleapi.ResolveRelative(c.s.BasePath, "v1/projects/{projectId}/zones/{zone}/clusters") urls := googleapi.ResolveRelative(c.s.BasePath, "v1/projects/{projectId}/zones/{zone}/clusters")
urls += "?" + c.urlParams_.Encode() urls += "?" + c.urlParams_.Encode()
req, _ := http.NewRequest("GET", urls, body) req, _ := http.NewRequest("GET", urls, body)
req.Header = reqHeaders
googleapi.Expand(req.URL, map[string]string{ googleapi.Expand(req.URL, map[string]string{
"projectId": c.projectId, "projectId": c.projectId,
"zone": c.zone, "zone": c.zone,
}) })
req.Header.Set("User-Agent", c.s.userAgent()) return gensupport.SendRequest(c.ctx_, c.s.client, req)
if c.ifNoneMatch_ != "" {
req.Header.Set("If-None-Match", c.ifNoneMatch_)
}
if c.ctx_ != nil {
return ctxhttp.Do(c.ctx_, c.s.client, req)
}
return c.s.client.Do(req)
} }
// Do executes the "container.projects.zones.clusters.list" call. // Do executes the "container.projects.zones.clusters.list" call.
@ -1482,7 +1480,8 @@ func (c *ProjectsZonesClustersListCall) Do(opts ...googleapi.CallOption) (*ListC
HTTPStatusCode: res.StatusCode, HTTPStatusCode: res.StatusCode,
}, },
} }
if err := json.NewDecoder(res.Body).Decode(&ret); err != nil { target := &ret
if err := json.NewDecoder(res.Body).Decode(target); err != nil {
return nil, err return nil, err
} }
return ret, nil return ret, nil
@ -1558,27 +1557,25 @@ func (c *ProjectsZonesClustersUpdateCall) Context(ctx context.Context) *Projects
} }
func (c *ProjectsZonesClustersUpdateCall) doRequest(alt string) (*http.Response, error) { func (c *ProjectsZonesClustersUpdateCall) doRequest(alt string) (*http.Response, error) {
reqHeaders := make(http.Header)
reqHeaders.Set("User-Agent", c.s.userAgent())
var body io.Reader = nil var body io.Reader = nil
body, err := googleapi.WithoutDataWrapper.JSONReader(c.updateclusterrequest) body, err := googleapi.WithoutDataWrapper.JSONReader(c.updateclusterrequest)
if err != nil { if err != nil {
return nil, err return nil, err
} }
ctype := "application/json" reqHeaders.Set("Content-Type", "application/json")
c.urlParams_.Set("alt", alt) c.urlParams_.Set("alt", alt)
urls := googleapi.ResolveRelative(c.s.BasePath, "v1/projects/{projectId}/zones/{zone}/clusters/{clusterId}") urls := googleapi.ResolveRelative(c.s.BasePath, "v1/projects/{projectId}/zones/{zone}/clusters/{clusterId}")
urls += "?" + c.urlParams_.Encode() urls += "?" + c.urlParams_.Encode()
req, _ := http.NewRequest("PUT", urls, body) req, _ := http.NewRequest("PUT", urls, body)
req.Header = reqHeaders
googleapi.Expand(req.URL, map[string]string{ googleapi.Expand(req.URL, map[string]string{
"projectId": c.projectId, "projectId": c.projectId,
"zone": c.zone, "zone": c.zone,
"clusterId": c.clusterId, "clusterId": c.clusterId,
}) })
req.Header.Set("Content-Type", ctype) return gensupport.SendRequest(c.ctx_, c.s.client, req)
req.Header.Set("User-Agent", c.s.userAgent())
if c.ctx_ != nil {
return ctxhttp.Do(c.ctx_, c.s.client, req)
}
return c.s.client.Do(req)
} }
// Do executes the "container.projects.zones.clusters.update" call. // Do executes the "container.projects.zones.clusters.update" call.
@ -1613,7 +1610,8 @@ func (c *ProjectsZonesClustersUpdateCall) Do(opts ...googleapi.CallOption) (*Ope
HTTPStatusCode: res.StatusCode, HTTPStatusCode: res.StatusCode,
}, },
} }
if err := json.NewDecoder(res.Body).Decode(&ret); err != nil { target := &ret
if err := json.NewDecoder(res.Body).Decode(target); err != nil {
return nil, err return nil, err
} }
return ret, nil return ret, nil
@ -1699,27 +1697,25 @@ func (c *ProjectsZonesClustersNodePoolsCreateCall) Context(ctx context.Context)
} }
func (c *ProjectsZonesClustersNodePoolsCreateCall) doRequest(alt string) (*http.Response, error) { func (c *ProjectsZonesClustersNodePoolsCreateCall) doRequest(alt string) (*http.Response, error) {
reqHeaders := make(http.Header)
reqHeaders.Set("User-Agent", c.s.userAgent())
var body io.Reader = nil var body io.Reader = nil
body, err := googleapi.WithoutDataWrapper.JSONReader(c.createnodepoolrequest) body, err := googleapi.WithoutDataWrapper.JSONReader(c.createnodepoolrequest)
if err != nil { if err != nil {
return nil, err return nil, err
} }
ctype := "application/json" reqHeaders.Set("Content-Type", "application/json")
c.urlParams_.Set("alt", alt) c.urlParams_.Set("alt", alt)
urls := googleapi.ResolveRelative(c.s.BasePath, "v1/projects/{projectId}/zones/{zone}/clusters/{clusterId}/nodePools") urls := googleapi.ResolveRelative(c.s.BasePath, "v1/projects/{projectId}/zones/{zone}/clusters/{clusterId}/nodePools")
urls += "?" + c.urlParams_.Encode() urls += "?" + c.urlParams_.Encode()
req, _ := http.NewRequest("POST", urls, body) req, _ := http.NewRequest("POST", urls, body)
req.Header = reqHeaders
googleapi.Expand(req.URL, map[string]string{ googleapi.Expand(req.URL, map[string]string{
"projectId": c.projectId, "projectId": c.projectId,
"zone": c.zone, "zone": c.zone,
"clusterId": c.clusterId, "clusterId": c.clusterId,
}) })
req.Header.Set("Content-Type", ctype) return gensupport.SendRequest(c.ctx_, c.s.client, req)
req.Header.Set("User-Agent", c.s.userAgent())
if c.ctx_ != nil {
return ctxhttp.Do(c.ctx_, c.s.client, req)
}
return c.s.client.Do(req)
} }
// Do executes the "container.projects.zones.clusters.nodePools.create" call. // Do executes the "container.projects.zones.clusters.nodePools.create" call.
@ -1754,7 +1750,8 @@ func (c *ProjectsZonesClustersNodePoolsCreateCall) Do(opts ...googleapi.CallOpti
HTTPStatusCode: res.StatusCode, HTTPStatusCode: res.StatusCode,
}, },
} }
if err := json.NewDecoder(res.Body).Decode(&ret); err != nil { target := &ret
if err := json.NewDecoder(res.Body).Decode(target); err != nil {
return nil, err return nil, err
} }
return ret, nil return ret, nil
@ -1840,22 +1837,21 @@ func (c *ProjectsZonesClustersNodePoolsDeleteCall) Context(ctx context.Context)
} }
func (c *ProjectsZonesClustersNodePoolsDeleteCall) doRequest(alt string) (*http.Response, error) { func (c *ProjectsZonesClustersNodePoolsDeleteCall) doRequest(alt string) (*http.Response, error) {
reqHeaders := make(http.Header)
reqHeaders.Set("User-Agent", c.s.userAgent())
var body io.Reader = nil var body io.Reader = nil
c.urlParams_.Set("alt", alt) c.urlParams_.Set("alt", alt)
urls := googleapi.ResolveRelative(c.s.BasePath, "v1/projects/{projectId}/zones/{zone}/clusters/{clusterId}/nodePools/{nodePoolId}") urls := googleapi.ResolveRelative(c.s.BasePath, "v1/projects/{projectId}/zones/{zone}/clusters/{clusterId}/nodePools/{nodePoolId}")
urls += "?" + c.urlParams_.Encode() urls += "?" + c.urlParams_.Encode()
req, _ := http.NewRequest("DELETE", urls, body) req, _ := http.NewRequest("DELETE", urls, body)
req.Header = reqHeaders
googleapi.Expand(req.URL, map[string]string{ googleapi.Expand(req.URL, map[string]string{
"projectId": c.projectId, "projectId": c.projectId,
"zone": c.zone, "zone": c.zone,
"clusterId": c.clusterId, "clusterId": c.clusterId,
"nodePoolId": c.nodePoolId, "nodePoolId": c.nodePoolId,
}) })
req.Header.Set("User-Agent", c.s.userAgent()) return gensupport.SendRequest(c.ctx_, c.s.client, req)
if c.ctx_ != nil {
return ctxhttp.Do(c.ctx_, c.s.client, req)
}
return c.s.client.Do(req)
} }
// Do executes the "container.projects.zones.clusters.nodePools.delete" call. // Do executes the "container.projects.zones.clusters.nodePools.delete" call.
@ -1890,7 +1886,8 @@ func (c *ProjectsZonesClustersNodePoolsDeleteCall) Do(opts ...googleapi.CallOpti
HTTPStatusCode: res.StatusCode, HTTPStatusCode: res.StatusCode,
}, },
} }
if err := json.NewDecoder(res.Body).Decode(&ret); err != nil { target := &ret
if err := json.NewDecoder(res.Body).Decode(target); err != nil {
return nil, err return nil, err
} }
return ret, nil return ret, nil
@ -1991,25 +1988,24 @@ func (c *ProjectsZonesClustersNodePoolsGetCall) Context(ctx context.Context) *Pr
} }
func (c *ProjectsZonesClustersNodePoolsGetCall) doRequest(alt string) (*http.Response, error) { func (c *ProjectsZonesClustersNodePoolsGetCall) doRequest(alt string) (*http.Response, error) {
reqHeaders := make(http.Header)
reqHeaders.Set("User-Agent", c.s.userAgent())
if c.ifNoneMatch_ != "" {
reqHeaders.Set("If-None-Match", c.ifNoneMatch_)
}
var body io.Reader = nil var body io.Reader = nil
c.urlParams_.Set("alt", alt) c.urlParams_.Set("alt", alt)
urls := googleapi.ResolveRelative(c.s.BasePath, "v1/projects/{projectId}/zones/{zone}/clusters/{clusterId}/nodePools/{nodePoolId}") urls := googleapi.ResolveRelative(c.s.BasePath, "v1/projects/{projectId}/zones/{zone}/clusters/{clusterId}/nodePools/{nodePoolId}")
urls += "?" + c.urlParams_.Encode() urls += "?" + c.urlParams_.Encode()
req, _ := http.NewRequest("GET", urls, body) req, _ := http.NewRequest("GET", urls, body)
req.Header = reqHeaders
googleapi.Expand(req.URL, map[string]string{ googleapi.Expand(req.URL, map[string]string{
"projectId": c.projectId, "projectId": c.projectId,
"zone": c.zone, "zone": c.zone,
"clusterId": c.clusterId, "clusterId": c.clusterId,
"nodePoolId": c.nodePoolId, "nodePoolId": c.nodePoolId,
}) })
req.Header.Set("User-Agent", c.s.userAgent()) return gensupport.SendRequest(c.ctx_, c.s.client, req)
if c.ifNoneMatch_ != "" {
req.Header.Set("If-None-Match", c.ifNoneMatch_)
}
if c.ctx_ != nil {
return ctxhttp.Do(c.ctx_, c.s.client, req)
}
return c.s.client.Do(req)
} }
// Do executes the "container.projects.zones.clusters.nodePools.get" call. // Do executes the "container.projects.zones.clusters.nodePools.get" call.
@ -2044,7 +2040,8 @@ func (c *ProjectsZonesClustersNodePoolsGetCall) Do(opts ...googleapi.CallOption)
HTTPStatusCode: res.StatusCode, HTTPStatusCode: res.StatusCode,
}, },
} }
if err := json.NewDecoder(res.Body).Decode(&ret); err != nil { target := &ret
if err := json.NewDecoder(res.Body).Decode(target); err != nil {
return nil, err return nil, err
} }
return ret, nil return ret, nil
@ -2143,24 +2140,23 @@ func (c *ProjectsZonesClustersNodePoolsListCall) Context(ctx context.Context) *P
} }
func (c *ProjectsZonesClustersNodePoolsListCall) doRequest(alt string) (*http.Response, error) { func (c *ProjectsZonesClustersNodePoolsListCall) doRequest(alt string) (*http.Response, error) {
reqHeaders := make(http.Header)
reqHeaders.Set("User-Agent", c.s.userAgent())
if c.ifNoneMatch_ != "" {
reqHeaders.Set("If-None-Match", c.ifNoneMatch_)
}
var body io.Reader = nil var body io.Reader = nil
c.urlParams_.Set("alt", alt) c.urlParams_.Set("alt", alt)
urls := googleapi.ResolveRelative(c.s.BasePath, "v1/projects/{projectId}/zones/{zone}/clusters/{clusterId}/nodePools") urls := googleapi.ResolveRelative(c.s.BasePath, "v1/projects/{projectId}/zones/{zone}/clusters/{clusterId}/nodePools")
urls += "?" + c.urlParams_.Encode() urls += "?" + c.urlParams_.Encode()
req, _ := http.NewRequest("GET", urls, body) req, _ := http.NewRequest("GET", urls, body)
req.Header = reqHeaders
googleapi.Expand(req.URL, map[string]string{ googleapi.Expand(req.URL, map[string]string{
"projectId": c.projectId, "projectId": c.projectId,
"zone": c.zone, "zone": c.zone,
"clusterId": c.clusterId, "clusterId": c.clusterId,
}) })
req.Header.Set("User-Agent", c.s.userAgent()) return gensupport.SendRequest(c.ctx_, c.s.client, req)
if c.ifNoneMatch_ != "" {
req.Header.Set("If-None-Match", c.ifNoneMatch_)
}
if c.ctx_ != nil {
return ctxhttp.Do(c.ctx_, c.s.client, req)
}
return c.s.client.Do(req)
} }
// Do executes the "container.projects.zones.clusters.nodePools.list" call. // Do executes the "container.projects.zones.clusters.nodePools.list" call.
@ -2195,7 +2191,8 @@ func (c *ProjectsZonesClustersNodePoolsListCall) Do(opts ...googleapi.CallOption
HTTPStatusCode: res.StatusCode, HTTPStatusCode: res.StatusCode,
}, },
} }
if err := json.NewDecoder(res.Body).Decode(&ret); err != nil { target := &ret
if err := json.NewDecoder(res.Body).Decode(target); err != nil {
return nil, err return nil, err
} }
return ret, nil return ret, nil
@ -2287,24 +2284,23 @@ func (c *ProjectsZonesOperationsGetCall) Context(ctx context.Context) *ProjectsZ
} }
func (c *ProjectsZonesOperationsGetCall) doRequest(alt string) (*http.Response, error) { func (c *ProjectsZonesOperationsGetCall) doRequest(alt string) (*http.Response, error) {
reqHeaders := make(http.Header)
reqHeaders.Set("User-Agent", c.s.userAgent())
if c.ifNoneMatch_ != "" {
reqHeaders.Set("If-None-Match", c.ifNoneMatch_)
}
var body io.Reader = nil var body io.Reader = nil
c.urlParams_.Set("alt", alt) c.urlParams_.Set("alt", alt)
urls := googleapi.ResolveRelative(c.s.BasePath, "v1/projects/{projectId}/zones/{zone}/operations/{operationId}") urls := googleapi.ResolveRelative(c.s.BasePath, "v1/projects/{projectId}/zones/{zone}/operations/{operationId}")
urls += "?" + c.urlParams_.Encode() urls += "?" + c.urlParams_.Encode()
req, _ := http.NewRequest("GET", urls, body) req, _ := http.NewRequest("GET", urls, body)
req.Header = reqHeaders
googleapi.Expand(req.URL, map[string]string{ googleapi.Expand(req.URL, map[string]string{
"projectId": c.projectId, "projectId": c.projectId,
"zone": c.zone, "zone": c.zone,
"operationId": c.operationId, "operationId": c.operationId,
}) })
req.Header.Set("User-Agent", c.s.userAgent()) return gensupport.SendRequest(c.ctx_, c.s.client, req)
if c.ifNoneMatch_ != "" {
req.Header.Set("If-None-Match", c.ifNoneMatch_)
}
if c.ctx_ != nil {
return ctxhttp.Do(c.ctx_, c.s.client, req)
}
return c.s.client.Do(req)
} }
// Do executes the "container.projects.zones.operations.get" call. // Do executes the "container.projects.zones.operations.get" call.
@ -2339,7 +2335,8 @@ func (c *ProjectsZonesOperationsGetCall) Do(opts ...googleapi.CallOption) (*Oper
HTTPStatusCode: res.StatusCode, HTTPStatusCode: res.StatusCode,
}, },
} }
if err := json.NewDecoder(res.Body).Decode(&ret); err != nil { target := &ret
if err := json.NewDecoder(res.Body).Decode(target); err != nil {
return nil, err return nil, err
} }
return ret, nil return ret, nil
@ -2430,23 +2427,22 @@ func (c *ProjectsZonesOperationsListCall) Context(ctx context.Context) *Projects
} }
func (c *ProjectsZonesOperationsListCall) doRequest(alt string) (*http.Response, error) { func (c *ProjectsZonesOperationsListCall) doRequest(alt string) (*http.Response, error) {
reqHeaders := make(http.Header)
reqHeaders.Set("User-Agent", c.s.userAgent())
if c.ifNoneMatch_ != "" {
reqHeaders.Set("If-None-Match", c.ifNoneMatch_)
}
var body io.Reader = nil var body io.Reader = nil
c.urlParams_.Set("alt", alt) c.urlParams_.Set("alt", alt)
urls := googleapi.ResolveRelative(c.s.BasePath, "v1/projects/{projectId}/zones/{zone}/operations") urls := googleapi.ResolveRelative(c.s.BasePath, "v1/projects/{projectId}/zones/{zone}/operations")
urls += "?" + c.urlParams_.Encode() urls += "?" + c.urlParams_.Encode()
req, _ := http.NewRequest("GET", urls, body) req, _ := http.NewRequest("GET", urls, body)
req.Header = reqHeaders
googleapi.Expand(req.URL, map[string]string{ googleapi.Expand(req.URL, map[string]string{
"projectId": c.projectId, "projectId": c.projectId,
"zone": c.zone, "zone": c.zone,
}) })
req.Header.Set("User-Agent", c.s.userAgent()) return gensupport.SendRequest(c.ctx_, c.s.client, req)
if c.ifNoneMatch_ != "" {
req.Header.Set("If-None-Match", c.ifNoneMatch_)
}
if c.ctx_ != nil {
return ctxhttp.Do(c.ctx_, c.s.client, req)
}
return c.s.client.Do(req)
} }
// Do executes the "container.projects.zones.operations.list" call. // Do executes the "container.projects.zones.operations.list" call.
@ -2481,7 +2477,8 @@ func (c *ProjectsZonesOperationsListCall) Do(opts ...googleapi.CallOption) (*Lis
HTTPStatusCode: res.StatusCode, HTTPStatusCode: res.StatusCode,
}, },
} }
if err := json.NewDecoder(res.Body).Decode(&ret); err != nil { target := &ret
if err := json.NewDecoder(res.Body).Decode(target); err != nil {
return nil, err return nil, err
} }
return ret, nil return ret, nil

View file

@ -12,7 +12,6 @@ import (
"time" "time"
"golang.org/x/net/context" "golang.org/x/net/context"
"golang.org/x/net/context/ctxhttp"
) )
const ( const (
@ -80,7 +79,7 @@ func (rx *ResumableUpload) doUploadRequest(ctx context.Context, data io.Reader,
req.Header.Set("Content-Range", contentRange) req.Header.Set("Content-Range", contentRange)
req.Header.Set("Content-Type", rx.MediaType) req.Header.Set("Content-Type", rx.MediaType)
req.Header.Set("User-Agent", rx.UserAgent) req.Header.Set("User-Agent", rx.UserAgent)
return ctxhttp.Do(ctx, rx.Client, req) return SendRequest(ctx, rx.Client, req)
} }
@ -135,6 +134,8 @@ func contextDone(ctx context.Context) bool {
// It retries using the provided back off strategy until cancelled or the // It retries using the provided back off strategy until cancelled or the
// strategy indicates to stop retrying. // strategy indicates to stop retrying.
// It is called from the auto-generated API code and is not visible to the user. // It is called from the auto-generated API code and is not visible to the user.
// Before sending an HTTP request, Upload calls any registered hook functions,
// and calls the returned functions after the request returns (see send.go).
// rx is private to the auto-generated API code. // rx is private to the auto-generated API code.
// Exactly one of resp or err will be nil. If resp is non-nil, the caller must call resp.Body.Close. // Exactly one of resp or err will be nil. If resp is non-nil, the caller must call resp.Body.Close.
func (rx *ResumableUpload) Upload(ctx context.Context) (resp *http.Response, err error) { func (rx *ResumableUpload) Upload(ctx context.Context) (resp *http.Response, err error) {

55
vendor/google.golang.org/api/gensupport/send.go generated vendored Normal file
View file

@ -0,0 +1,55 @@
// Copyright 2016 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package gensupport
import (
"net/http"
"golang.org/x/net/context"
"golang.org/x/net/context/ctxhttp"
)
// Hook is the type of a function that is called once before each HTTP request
// that is sent by a generated API. It returns a function that is called after
// the request returns.
// Hooks are not called if the context is nil.
type Hook func(ctx context.Context, req *http.Request) func(resp *http.Response)
var hooks []Hook
// RegisterHook registers a Hook to be called before each HTTP request by a
// generated API. Hooks are called in the order they are registered. Each
// hook can return a function; if it is non-nil, it is called after the HTTP
// request returns. These functions are called in the reverse order.
// RegisterHook should not be called concurrently with itself or SendRequest.
func RegisterHook(h Hook) {
hooks = append(hooks, h)
}
// SendRequest sends a single HTTP request using the given client.
// If ctx is non-nil, it calls all hooks, then sends the request with
// ctxhttp.Do, then calls any functions returned by the hooks in reverse order.
func SendRequest(ctx context.Context, client *http.Client, req *http.Request) (*http.Response, error) {
if ctx == nil {
return client.Do(req)
}
// Call hooks in order of registration, store returned funcs.
post := make([]func(resp *http.Response), len(hooks))
for i, h := range hooks {
fn := h(ctx, req)
post[i] = fn
}
// Send request.
resp, err := ctxhttp.Do(ctx, client, req)
// Call returned funcs in reverse order.
for i := len(post) - 1; i >= 0; i-- {
if fn := post[i]; fn != nil {
fn(resp)
}
}
return resp, err
}

View file

@ -149,12 +149,12 @@ func IsNotModified(err error) bool {
// CheckMediaResponse returns an error (of type *Error) if the response // CheckMediaResponse returns an error (of type *Error) if the response
// status code is not 2xx. Unlike CheckResponse it does not assume the // status code is not 2xx. Unlike CheckResponse it does not assume the
// body is a JSON error document. // body is a JSON error document.
// It is the caller's responsibility to close res.Body.
func CheckMediaResponse(res *http.Response) error { func CheckMediaResponse(res *http.Response) error {
if res.StatusCode >= 200 && res.StatusCode <= 299 { if res.StatusCode >= 200 && res.StatusCode <= 299 {
return nil return nil
} }
slurp, _ := ioutil.ReadAll(io.LimitReader(res.Body, 1<<20)) slurp, _ := ioutil.ReadAll(io.LimitReader(res.Body, 1<<20))
res.Body.Close()
return &Error{ return &Error{
Code: res.StatusCode, Code: res.StatusCode,
Body: string(slurp), Body: string(slurp),
@ -278,41 +278,15 @@ func ResolveRelative(basestr, relstr string) string {
return us return us
} }
// has4860Fix is whether this Go environment contains the fix for
// http://golang.org/issue/4860
var has4860Fix bool
// init initializes has4860Fix by checking the behavior of the net/http package.
func init() {
r := http.Request{
URL: &url.URL{
Scheme: "http",
Opaque: "//opaque",
},
}
b := &bytes.Buffer{}
r.Write(b)
has4860Fix = bytes.HasPrefix(b.Bytes(), []byte("GET http"))
}
// SetOpaque sets u.Opaque from u.Path such that HTTP requests to it
// don't alter any hex-escaped characters in u.Path.
func SetOpaque(u *url.URL) {
u.Opaque = "//" + u.Host + u.Path
if !has4860Fix {
u.Opaque = u.Scheme + ":" + u.Opaque
}
}
// Expand subsitutes any {encoded} strings in the URL passed in using // Expand subsitutes any {encoded} strings in the URL passed in using
// the map supplied. // the map supplied.
// //
// This calls SetOpaque to avoid encoding of the parameters in the URL path. // This calls SetOpaque to avoid encoding of the parameters in the URL path.
func Expand(u *url.URL, expansions map[string]string) { func Expand(u *url.URL, expansions map[string]string) {
expanded, err := uritemplates.Expand(u.Path, expansions) escaped, unescaped, err := uritemplates.Expand(u.Path, expansions)
if err == nil { if err == nil {
u.Path = expanded u.Path = unescaped
SetOpaque(u) u.RawPath = escaped
} }
} }

View file

@ -34,11 +34,37 @@ func pctEncode(src []byte) []byte {
return dst return dst
} }
func escape(s string, allowReserved bool) string { // pairWriter is a convenience struct which allows escaped and unescaped
// versions of the template to be written in parallel.
type pairWriter struct {
escaped, unescaped bytes.Buffer
}
// Write writes the provided string directly without any escaping.
func (w *pairWriter) Write(s string) {
w.escaped.WriteString(s)
w.unescaped.WriteString(s)
}
// Escape writes the provided string, escaping the string for the
// escaped output.
func (w *pairWriter) Escape(s string, allowReserved bool) {
w.unescaped.WriteString(s)
if allowReserved { if allowReserved {
return string(reserved.ReplaceAllFunc([]byte(s), pctEncode)) w.escaped.Write(reserved.ReplaceAllFunc([]byte(s), pctEncode))
} else {
w.escaped.Write(unreserved.ReplaceAllFunc([]byte(s), pctEncode))
} }
return string(unreserved.ReplaceAllFunc([]byte(s), pctEncode)) }
// Escaped returns the escaped string.
func (w *pairWriter) Escaped() string {
return w.escaped.String()
}
// Unescaped returns the unescaped string.
func (w *pairWriter) Unescaped() string {
return w.unescaped.String()
} }
// A uriTemplate is a parsed representation of a URI template. // A uriTemplate is a parsed representation of a URI template.
@ -170,18 +196,20 @@ func parseTerm(term string) (result templateTerm, err error) {
return result, err return result, err
} }
// Expand expands a URI template with a set of values to produce a string. // Expand expands a URI template with a set of values to produce the
func (t *uriTemplate) Expand(values map[string]string) string { // resultant URI. Two forms of the result are returned: one with all the
var buf bytes.Buffer // elements escaped, and one with the elements unescaped.
func (t *uriTemplate) Expand(values map[string]string) (escaped, unescaped string) {
var w pairWriter
for _, p := range t.parts { for _, p := range t.parts {
p.expand(&buf, values) p.expand(&w, values)
} }
return buf.String() return w.Escaped(), w.Unescaped()
} }
func (tp *templatePart) expand(buf *bytes.Buffer, values map[string]string) { func (tp *templatePart) expand(w *pairWriter, values map[string]string) {
if len(tp.raw) > 0 { if len(tp.raw) > 0 {
buf.WriteString(tp.raw) w.Write(tp.raw)
return return
} }
var first = true var first = true
@ -191,30 +219,30 @@ func (tp *templatePart) expand(buf *bytes.Buffer, values map[string]string) {
continue continue
} }
if first { if first {
buf.WriteString(tp.first) w.Write(tp.first)
first = false first = false
} else { } else {
buf.WriteString(tp.sep) w.Write(tp.sep)
} }
tp.expandString(buf, term, value) tp.expandString(w, term, value)
} }
} }
func (tp *templatePart) expandName(buf *bytes.Buffer, name string, empty bool) { func (tp *templatePart) expandName(w *pairWriter, name string, empty bool) {
if tp.named { if tp.named {
buf.WriteString(name) w.Write(name)
if empty { if empty {
buf.WriteString(tp.ifemp) w.Write(tp.ifemp)
} else { } else {
buf.WriteString("=") w.Write("=")
} }
} }
} }
func (tp *templatePart) expandString(buf *bytes.Buffer, t templateTerm, s string) { func (tp *templatePart) expandString(w *pairWriter, t templateTerm, s string) {
if len(s) > t.truncate && t.truncate > 0 { if len(s) > t.truncate && t.truncate > 0 {
s = s[:t.truncate] s = s[:t.truncate]
} }
tp.expandName(buf, t.name, len(s) == 0) tp.expandName(w, t.name, len(s) == 0)
buf.WriteString(escape(s, tp.allowReserved)) w.Escape(s, tp.allowReserved)
} }

View file

@ -4,10 +4,14 @@
package uritemplates package uritemplates
func Expand(path string, values map[string]string) (string, error) { // Expand parses then expands a URI template with a set of values to produce
// the resultant URI. Two forms of the result are returned: one with all the
// elements escaped, and one with the elements unescaped.
func Expand(path string, values map[string]string) (escaped, unescaped string, err error) {
template, err := parse(path) template, err := parse(path)
if err != nil { if err != nil {
return "", err return "", "", err
} }
return template.Expand(values), nil escaped, unescaped = template.Expand(values)
return escaped, unescaped, nil
} }

View file

@ -1,128 +0,0 @@
// Copyright 2014 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package internal provides support for the cloud packages.
//
// Users should not import this package directly.
package internal
import (
"fmt"
"net/http"
"sync"
"golang.org/x/net/context"
)
type contextKey struct{}
func WithContext(parent context.Context, projID string, c *http.Client) context.Context {
if c == nil {
panic("nil *http.Client passed to WithContext")
}
if projID == "" {
panic("empty project ID passed to WithContext")
}
return context.WithValue(parent, contextKey{}, &cloudContext{
ProjectID: projID,
HTTPClient: c,
})
}
const userAgent = "gcloud-golang/0.1"
type cloudContext struct {
ProjectID string
HTTPClient *http.Client
mu sync.Mutex // guards svc
svc map[string]interface{} // e.g. "storage" => *rawStorage.Service
}
// Service returns the result of the fill function if it's never been
// called before for the given name (which is assumed to be an API
// service name, like "datastore"). If it has already been cached, the fill
// func is not run.
// It's safe for concurrent use by multiple goroutines.
func Service(ctx context.Context, name string, fill func(*http.Client) interface{}) interface{} {
return cc(ctx).service(name, fill)
}
func (c *cloudContext) service(name string, fill func(*http.Client) interface{}) interface{} {
c.mu.Lock()
defer c.mu.Unlock()
if c.svc == nil {
c.svc = make(map[string]interface{})
} else if v, ok := c.svc[name]; ok {
return v
}
v := fill(c.HTTPClient)
c.svc[name] = v
return v
}
// Transport is an http.RoundTripper that appends
// Google Cloud client's user-agent to the original
// request's user-agent header.
type Transport struct {
// Base is the actual http.RoundTripper
// requests will use. It must not be nil.
Base http.RoundTripper
}
// RoundTrip appends a user-agent to the existing user-agent
// header and delegates the request to the base http.RoundTripper.
func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) {
req = cloneRequest(req)
ua := req.Header.Get("User-Agent")
if ua == "" {
ua = userAgent
} else {
ua = fmt.Sprintf("%s %s", ua, userAgent)
}
req.Header.Set("User-Agent", ua)
return t.Base.RoundTrip(req)
}
// cloneRequest returns a clone of the provided *http.Request.
// The clone is a shallow copy of the struct and its Header map.
func cloneRequest(r *http.Request) *http.Request {
// shallow copy of the struct
r2 := new(http.Request)
*r2 = *r
// deep copy of the Header
r2.Header = make(http.Header)
for k, s := range r.Header {
r2.Header[k] = s
}
return r2
}
func ProjID(ctx context.Context) string {
return cc(ctx).ProjectID
}
func HTTPClient(ctx context.Context) *http.Client {
return cc(ctx).HTTPClient
}
// cc returns the internal *cloudContext (cc) state for a context.Context.
// It panics if the user did it wrong.
func cc(ctx context.Context) *cloudContext {
if c, ok := ctx.Value(contextKey{}).(*cloudContext); ok {
return c
}
panic("invalid context.Context type; it should be created with cloud.NewContext")
}

View file

@ -1,17 +1,21 @@
language: go language: go
go: go:
- 1.5.3 - 1.5.4
- 1.6 - 1.6.3
go_import_path: google.golang.org/grpc
before_install: before_install:
- go get golang.org/x/tools/cmd/goimports
- go get github.com/golang/lint/golint
- go get github.com/axw/gocov/gocov - go get github.com/axw/gocov/gocov
- go get github.com/mattn/goveralls - go get github.com/mattn/goveralls
- go get golang.org/x/tools/cmd/cover - go get golang.org/x/tools/cmd/cover
install:
- mkdir -p "$GOPATH/src/google.golang.org"
- mv "$TRAVIS_BUILD_DIR" "$GOPATH/src/google.golang.org/grpc"
script: script:
- '! gofmt -s -d -l . 2>&1 | read'
- '! goimports -l . | read'
- '! golint ./... | grep -vE "(_string|\.pb)\.go:"'
- '! go tool vet -all . 2>&1 | grep -vE "constant [0-9]+ not a string in call to Errorf"'
- make test testrace - make test testrace

View file

@ -36,6 +36,7 @@ package grpc
import ( import (
"bytes" "bytes"
"io" "io"
"math"
"time" "time"
"golang.org/x/net/context" "golang.org/x/net/context"
@ -51,13 +52,20 @@ import (
func recvResponse(dopts dialOptions, t transport.ClientTransport, c *callInfo, stream *transport.Stream, reply interface{}) error { func recvResponse(dopts dialOptions, t transport.ClientTransport, c *callInfo, stream *transport.Stream, reply interface{}) error {
// Try to acquire header metadata from the server if there is any. // Try to acquire header metadata from the server if there is any.
var err error var err error
defer func() {
if err != nil {
if _, ok := err.(transport.ConnectionError); !ok {
t.CloseStream(stream, err)
}
}
}()
c.headerMD, err = stream.Header() c.headerMD, err = stream.Header()
if err != nil { if err != nil {
return err return err
} }
p := &parser{r: stream} p := &parser{r: stream}
for { for {
if err = recv(p, dopts.codec, stream, dopts.dc, reply); err != nil { if err = recv(p, dopts.codec, stream, dopts.dc, reply, math.MaxInt32); err != nil {
if err == io.EOF { if err == io.EOF {
break break
} }
@ -76,6 +84,7 @@ func sendRequest(ctx context.Context, codec Codec, compressor Compressor, callHd
} }
defer func() { defer func() {
if err != nil { if err != nil {
// If err is connection error, t will be closed, no need to close stream here.
if _, ok := err.(transport.ConnectionError); !ok { if _, ok := err.(transport.ConnectionError); !ok {
t.CloseStream(stream, err) t.CloseStream(stream, err)
} }
@ -90,7 +99,10 @@ func sendRequest(ctx context.Context, codec Codec, compressor Compressor, callHd
return nil, transport.StreamErrorf(codes.Internal, "grpc: %v", err) return nil, transport.StreamErrorf(codes.Internal, "grpc: %v", err)
} }
err = t.Write(stream, outBuf, opts) err = t.Write(stream, outBuf, opts)
if err != nil { // t.NewStream(...) could lead to an early rejection of the RPC (e.g., the service/method
// does not exist.) so that t.Write could get io.EOF from wait(...). Leave the following
// recvResponse to get the final status.
if err != nil && err != io.EOF {
return nil, err return nil, err
} }
// Sent successfully. // Sent successfully.
@ -158,9 +170,9 @@ func Invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli
if _, ok := err.(*rpcError); ok { if _, ok := err.(*rpcError); ok {
return err return err
} }
if err == errConnClosing { if err == errConnClosing || err == errConnUnavailable {
if c.failFast { if c.failFast {
return Errorf(codes.Unavailable, "%v", errConnClosing) return Errorf(codes.Unavailable, "%v", err)
} }
continue continue
} }
@ -176,7 +188,10 @@ func Invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli
put() put()
put = nil put = nil
} }
if _, ok := err.(transport.ConnectionError); ok { // Retry a non-failfast RPC when
// i) there is a connection error; or
// ii) the server started to drain before this RPC was initiated.
if _, ok := err.(transport.ConnectionError); ok || err == transport.ErrStreamDrain {
if c.failFast { if c.failFast {
return toRPCErr(err) return toRPCErr(err)
} }
@ -184,20 +199,18 @@ func Invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli
} }
return toRPCErr(err) return toRPCErr(err)
} }
// Receive the response
err = recvResponse(cc.dopts, t, &c, stream, reply) err = recvResponse(cc.dopts, t, &c, stream, reply)
if err != nil { if err != nil {
if put != nil { if put != nil {
put() put()
put = nil put = nil
} }
if _, ok := err.(transport.ConnectionError); ok { if _, ok := err.(transport.ConnectionError); ok || err == transport.ErrStreamDrain {
if c.failFast { if c.failFast {
return toRPCErr(err) return toRPCErr(err)
} }
continue continue
} }
t.CloseStream(stream, err)
return toRPCErr(err) return toRPCErr(err)
} }
if c.traceInfo.tr != nil { if c.traceInfo.tr != nil {

View file

@ -43,7 +43,6 @@ import (
"golang.org/x/net/context" "golang.org/x/net/context"
"golang.org/x/net/trace" "golang.org/x/net/trace"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials"
"google.golang.org/grpc/grpclog" "google.golang.org/grpc/grpclog"
"google.golang.org/grpc/transport" "google.golang.org/grpc/transport"
@ -68,13 +67,15 @@ var (
// errCredentialsConflict indicates that grpc.WithTransportCredentials() // errCredentialsConflict indicates that grpc.WithTransportCredentials()
// and grpc.WithInsecure() are both called for a connection. // and grpc.WithInsecure() are both called for a connection.
errCredentialsConflict = errors.New("grpc: transport credentials are set for an insecure connection (grpc.WithTransportCredentials() and grpc.WithInsecure() are both called)") errCredentialsConflict = errors.New("grpc: transport credentials are set for an insecure connection (grpc.WithTransportCredentials() and grpc.WithInsecure() are both called)")
// errNetworkIP indicates that the connection is down due to some network I/O error. // errNetworkIO indicates that the connection is down due to some network I/O error.
errNetworkIO = errors.New("grpc: failed with network I/O error") errNetworkIO = errors.New("grpc: failed with network I/O error")
// errConnDrain indicates that the connection starts to be drained and does not accept any new RPCs. // errConnDrain indicates that the connection starts to be drained and does not accept any new RPCs.
errConnDrain = errors.New("grpc: the connection is drained") errConnDrain = errors.New("grpc: the connection is drained")
// errConnClosing indicates that the connection is closing. // errConnClosing indicates that the connection is closing.
errConnClosing = errors.New("grpc: the connection is closing") errConnClosing = errors.New("grpc: the connection is closing")
errNoAddr = errors.New("grpc: there is no address available to dial") // errConnUnavailable indicates that the connection is unavailable.
errConnUnavailable = errors.New("grpc: the connection is unavailable")
errNoAddr = errors.New("grpc: there is no address available to dial")
// minimum time to give a connection to complete // minimum time to give a connection to complete
minConnectTimeout = 20 * time.Second minConnectTimeout = 20 * time.Second
) )
@ -196,9 +197,14 @@ func WithTimeout(d time.Duration) DialOption {
} }
// WithDialer returns a DialOption that specifies a function to use for dialing network addresses. // WithDialer returns a DialOption that specifies a function to use for dialing network addresses.
func WithDialer(f func(addr string, timeout time.Duration) (net.Conn, error)) DialOption { func WithDialer(f func(string, time.Duration) (net.Conn, error)) DialOption {
return func(o *dialOptions) { return func(o *dialOptions) {
o.copts.Dialer = f o.copts.Dialer = func(ctx context.Context, addr string) (net.Conn, error) {
if deadline, ok := ctx.Deadline(); ok {
return f(addr, deadline.Sub(time.Now()))
}
return f(addr, 0)
}
} }
} }
@ -209,12 +215,19 @@ func WithUserAgent(s string) DialOption {
} }
} }
// Dial creates a client connection the given target. // Dial creates a client connection to the given target.
func Dial(target string, opts ...DialOption) (*ClientConn, error) { func Dial(target string, opts ...DialOption) (*ClientConn, error) {
return DialContext(context.Background(), target, opts...)
}
// DialContext creates a client connection to the given target
// using the supplied context.
func DialContext(ctx context.Context, target string, opts ...DialOption) (*ClientConn, error) {
cc := &ClientConn{ cc := &ClientConn{
target: target, target: target,
conns: make(map[Address]*addrConn), conns: make(map[Address]*addrConn),
} }
cc.ctx, cc.cancel = context.WithCancel(ctx)
for _, opt := range opts { for _, opt := range opts {
opt(&cc.dopts) opt(&cc.dopts)
} }
@ -226,31 +239,33 @@ func Dial(target string, opts ...DialOption) (*ClientConn, error) {
if cc.dopts.bs == nil { if cc.dopts.bs == nil {
cc.dopts.bs = DefaultBackoffConfig cc.dopts.bs = DefaultBackoffConfig
} }
if cc.dopts.balancer == nil {
cc.dopts.balancer = RoundRobin(nil)
}
if err := cc.dopts.balancer.Start(target); err != nil {
return nil, err
}
var ( var (
ok bool ok bool
addrs []Address addrs []Address
) )
ch := cc.dopts.balancer.Notify() if cc.dopts.balancer == nil {
if ch == nil { // Connect to target directly if balancer is nil.
// There is no name resolver installed.
addrs = append(addrs, Address{Addr: target}) addrs = append(addrs, Address{Addr: target})
} else { } else {
addrs, ok = <-ch if err := cc.dopts.balancer.Start(target); err != nil {
if !ok || len(addrs) == 0 { return nil, err
return nil, errNoAddr }
ch := cc.dopts.balancer.Notify()
if ch == nil {
// There is no name resolver installed.
addrs = append(addrs, Address{Addr: target})
} else {
addrs, ok = <-ch
if !ok || len(addrs) == 0 {
return nil, errNoAddr
}
} }
} }
waitC := make(chan error, 1) waitC := make(chan error, 1)
go func() { go func() {
for _, a := range addrs { for _, a := range addrs {
if err := cc.newAddrConn(a, false); err != nil { if err := cc.resetAddrConn(a, false, nil); err != nil {
waitC <- err waitC <- err
return return
} }
@ -267,10 +282,15 @@ func Dial(target string, opts ...DialOption) (*ClientConn, error) {
cc.Close() cc.Close()
return nil, err return nil, err
} }
case <-cc.ctx.Done():
cc.Close()
return nil, cc.ctx.Err()
case <-timeoutCh: case <-timeoutCh:
cc.Close() cc.Close()
return nil, ErrClientConnTimeout return nil, ErrClientConnTimeout
} }
// If balancer is nil or balancer.Notify() is nil, ok will be false here.
// The lbWatcher goroutine will not be created.
if ok { if ok {
go cc.lbWatcher() go cc.lbWatcher()
} }
@ -317,6 +337,9 @@ func (s ConnectivityState) String() string {
// ClientConn represents a client connection to an RPC server. // ClientConn represents a client connection to an RPC server.
type ClientConn struct { type ClientConn struct {
ctx context.Context
cancel context.CancelFunc
target string target string
authority string authority string
dopts dialOptions dopts dialOptions
@ -347,11 +370,12 @@ func (cc *ClientConn) lbWatcher() {
} }
if !keep { if !keep {
del = append(del, c) del = append(del, c)
delete(cc.conns, c.addr)
} }
} }
cc.mu.Unlock() cc.mu.Unlock()
for _, a := range add { for _, a := range add {
cc.newAddrConn(a, true) cc.resetAddrConn(a, true, nil)
} }
for _, c := range del { for _, c := range del {
c.tearDown(errConnDrain) c.tearDown(errConnDrain)
@ -359,13 +383,17 @@ func (cc *ClientConn) lbWatcher() {
} }
} }
func (cc *ClientConn) newAddrConn(addr Address, skipWait bool) error { // resetAddrConn creates an addrConn for addr and adds it to cc.conns.
// If there is an old addrConn for addr, it will be torn down, using tearDownErr as the reason.
// If tearDownErr is nil, errConnDrain will be used instead.
func (cc *ClientConn) resetAddrConn(addr Address, skipWait bool, tearDownErr error) error {
ac := &addrConn{ ac := &addrConn{
cc: cc, cc: cc,
addr: addr, addr: addr,
dopts: cc.dopts, dopts: cc.dopts,
shutdownChan: make(chan struct{}),
} }
ac.ctx, ac.cancel = context.WithCancel(cc.ctx)
ac.stateCV = sync.NewCond(&ac.mu)
if EnableTracing { if EnableTracing {
ac.events = trace.NewEventLog("grpc.ClientConn", ac.addr.Addr) ac.events = trace.NewEventLog("grpc.ClientConn", ac.addr.Addr)
} }
@ -383,26 +411,44 @@ func (cc *ClientConn) newAddrConn(addr Address, skipWait bool) error {
} }
} }
} }
// Insert ac into ac.cc.conns. This needs to be done before any getTransport(...) is called. // Track ac in cc. This needs to be done before any getTransport(...) is called.
ac.cc.mu.Lock() cc.mu.Lock()
if ac.cc.conns == nil { if cc.conns == nil {
ac.cc.mu.Unlock() cc.mu.Unlock()
return ErrClientConnClosing return ErrClientConnClosing
} }
stale := ac.cc.conns[ac.addr] stale := cc.conns[ac.addr]
ac.cc.conns[ac.addr] = ac cc.conns[ac.addr] = ac
ac.cc.mu.Unlock() cc.mu.Unlock()
if stale != nil { if stale != nil {
// There is an addrConn alive on ac.addr already. This could be due to // There is an addrConn alive on ac.addr already. This could be due to
// i) stale's Close is undergoing; // 1) a buggy Balancer notifies duplicated Addresses;
// ii) a buggy Balancer notifies duplicated Addresses. // 2) goaway was received, a new ac will replace the old ac.
stale.tearDown(errConnDrain) // The old ac should be deleted from cc.conns, but the
// underlying transport should drain rather than close.
if tearDownErr == nil {
// tearDownErr is nil if resetAddrConn is called by
// 1) Dial
// 2) lbWatcher
// In both cases, the stale ac should drain, not close.
stale.tearDown(errConnDrain)
} else {
stale.tearDown(tearDownErr)
}
} }
ac.stateCV = sync.NewCond(&ac.mu)
// skipWait may overwrite the decision in ac.dopts.block. // skipWait may overwrite the decision in ac.dopts.block.
if ac.dopts.block && !skipWait { if ac.dopts.block && !skipWait {
if err := ac.resetTransport(false); err != nil { if err := ac.resetTransport(false); err != nil {
ac.tearDown(err) if err != errConnClosing {
// Tear down ac and delete it from cc.conns.
cc.mu.Lock()
delete(cc.conns, ac.addr)
cc.mu.Unlock()
ac.tearDown(err)
}
if e, ok := err.(transport.ConnectionError); ok && !e.Temporary() {
return e.Origin()
}
return err return err
} }
// Start to monitor the error status of transport. // Start to monitor the error status of transport.
@ -412,7 +458,10 @@ func (cc *ClientConn) newAddrConn(addr Address, skipWait bool) error {
go func() { go func() {
if err := ac.resetTransport(false); err != nil { if err := ac.resetTransport(false); err != nil {
grpclog.Printf("Failed to dial %s: %v; please retry.", ac.addr.Addr, err) grpclog.Printf("Failed to dial %s: %v; please retry.", ac.addr.Addr, err)
ac.tearDown(err) if err != errConnClosing {
// Keep this ac in cc.conns, to get the reason it's torn down.
ac.tearDown(err)
}
return return
} }
ac.transportMonitor() ac.transportMonitor()
@ -422,24 +471,48 @@ func (cc *ClientConn) newAddrConn(addr Address, skipWait bool) error {
} }
func (cc *ClientConn) getTransport(ctx context.Context, opts BalancerGetOptions) (transport.ClientTransport, func(), error) { func (cc *ClientConn) getTransport(ctx context.Context, opts BalancerGetOptions) (transport.ClientTransport, func(), error) {
addr, put, err := cc.dopts.balancer.Get(ctx, opts) var (
if err != nil { ac *addrConn
return nil, nil, toRPCErr(err) ok bool
} put func()
cc.mu.RLock() )
if cc.conns == nil { if cc.dopts.balancer == nil {
// If balancer is nil, there should be only one addrConn available.
cc.mu.RLock()
if cc.conns == nil {
cc.mu.RUnlock()
return nil, nil, toRPCErr(ErrClientConnClosing)
}
for _, ac = range cc.conns {
// Break after the first iteration to get the first addrConn.
ok = true
break
}
cc.mu.RUnlock()
} else {
var (
addr Address
err error
)
addr, put, err = cc.dopts.balancer.Get(ctx, opts)
if err != nil {
return nil, nil, toRPCErr(err)
}
cc.mu.RLock()
if cc.conns == nil {
cc.mu.RUnlock()
return nil, nil, toRPCErr(ErrClientConnClosing)
}
ac, ok = cc.conns[addr]
cc.mu.RUnlock() cc.mu.RUnlock()
return nil, nil, toRPCErr(ErrClientConnClosing)
} }
ac, ok := cc.conns[addr]
cc.mu.RUnlock()
if !ok { if !ok {
if put != nil { if put != nil {
put() put()
} }
return nil, nil, Errorf(codes.Internal, "grpc: failed to find the transport to send the rpc") return nil, nil, errConnClosing
} }
t, err := ac.wait(ctx, !opts.BlockingWait) t, err := ac.wait(ctx, cc.dopts.balancer != nil, !opts.BlockingWait)
if err != nil { if err != nil {
if put != nil { if put != nil {
put() put()
@ -451,6 +524,8 @@ func (cc *ClientConn) getTransport(ctx context.Context, opts BalancerGetOptions)
// Close tears down the ClientConn and all underlying connections. // Close tears down the ClientConn and all underlying connections.
func (cc *ClientConn) Close() error { func (cc *ClientConn) Close() error {
cc.cancel()
cc.mu.Lock() cc.mu.Lock()
if cc.conns == nil { if cc.conns == nil {
cc.mu.Unlock() cc.mu.Unlock()
@ -459,7 +534,9 @@ func (cc *ClientConn) Close() error {
conns := cc.conns conns := cc.conns
cc.conns = nil cc.conns = nil
cc.mu.Unlock() cc.mu.Unlock()
cc.dopts.balancer.Close() if cc.dopts.balancer != nil {
cc.dopts.balancer.Close()
}
for _, ac := range conns { for _, ac := range conns {
ac.tearDown(ErrClientConnClosing) ac.tearDown(ErrClientConnClosing)
} }
@ -468,11 +545,13 @@ func (cc *ClientConn) Close() error {
// addrConn is a network connection to a given address. // addrConn is a network connection to a given address.
type addrConn struct { type addrConn struct {
cc *ClientConn ctx context.Context
addr Address cancel context.CancelFunc
dopts dialOptions
shutdownChan chan struct{} cc *ClientConn
events trace.EventLog addr Address
dopts dialOptions
events trace.EventLog
mu sync.Mutex mu sync.Mutex
state ConnectivityState state ConnectivityState
@ -482,6 +561,9 @@ type addrConn struct {
// due to timeout. // due to timeout.
ready chan struct{} ready chan struct{}
transport transport.ClientTransport transport transport.ClientTransport
// The reason this addrConn is torn down.
tearDownErr error
} }
// printf records an event in ac's event log, unless ac has been closed. // printf records an event in ac's event log, unless ac has been closed.
@ -537,8 +619,7 @@ func (ac *addrConn) waitForStateChange(ctx context.Context, sourceState Connecti
} }
func (ac *addrConn) resetTransport(closeTransport bool) error { func (ac *addrConn) resetTransport(closeTransport bool) error {
var retries int for retries := 0; ; retries++ {
for {
ac.mu.Lock() ac.mu.Lock()
ac.printf("connecting") ac.printf("connecting")
if ac.state == Shutdown { if ac.state == Shutdown {
@ -558,13 +639,20 @@ func (ac *addrConn) resetTransport(closeTransport bool) error {
t.Close() t.Close()
} }
sleepTime := ac.dopts.bs.backoff(retries) sleepTime := ac.dopts.bs.backoff(retries)
ac.dopts.copts.Timeout = sleepTime timeout := minConnectTimeout
if sleepTime < minConnectTimeout { if timeout < sleepTime {
ac.dopts.copts.Timeout = minConnectTimeout timeout = sleepTime
} }
ctx, cancel := context.WithTimeout(ac.ctx, timeout)
connectTime := time.Now() connectTime := time.Now()
newTransport, err := transport.NewClientTransport(ac.addr.Addr, &ac.dopts.copts) newTransport, err := transport.NewClientTransport(ctx, ac.addr.Addr, ac.dopts.copts)
if err != nil { if err != nil {
cancel()
if e, ok := err.(transport.ConnectionError); ok && !e.Temporary() {
return err
}
grpclog.Printf("grpc: addrConn.resetTransport failed to create client transport: %v; Reconnecting to %q", err, ac.addr)
ac.mu.Lock() ac.mu.Lock()
if ac.state == Shutdown { if ac.state == Shutdown {
// ac.tearDown(...) has been invoked. // ac.tearDown(...) has been invoked.
@ -579,17 +667,12 @@ func (ac *addrConn) resetTransport(closeTransport bool) error {
ac.ready = nil ac.ready = nil
} }
ac.mu.Unlock() ac.mu.Unlock()
sleepTime -= time.Since(connectTime)
if sleepTime < 0 {
sleepTime = 0
}
closeTransport = false closeTransport = false
select { select {
case <-time.After(sleepTime): case <-time.After(sleepTime - time.Since(connectTime)):
case <-ac.shutdownChan: case <-ac.ctx.Done():
return ac.ctx.Err()
} }
retries++
grpclog.Printf("grpc: addrConn.resetTransport failed to create client transport: %v; Reconnecting to %q", err, ac.addr)
continue continue
} }
ac.mu.Lock() ac.mu.Lock()
@ -607,7 +690,9 @@ func (ac *addrConn) resetTransport(closeTransport bool) error {
close(ac.ready) close(ac.ready)
ac.ready = nil ac.ready = nil
} }
ac.down = ac.cc.dopts.balancer.Up(ac.addr) if ac.cc.dopts.balancer != nil {
ac.down = ac.cc.dopts.balancer.Up(ac.addr)
}
ac.mu.Unlock() ac.mu.Unlock()
return nil return nil
} }
@ -621,14 +706,42 @@ func (ac *addrConn) transportMonitor() {
t := ac.transport t := ac.transport
ac.mu.Unlock() ac.mu.Unlock()
select { select {
// shutdownChan is needed to detect the teardown when // This is needed to detect the teardown when
// the addrConn is idle (i.e., no RPC in flight). // the addrConn is idle (i.e., no RPC in flight).
case <-ac.shutdownChan: case <-ac.ctx.Done():
select {
case <-t.Error():
t.Close()
default:
}
return
case <-t.GoAway():
// If GoAway happens without any network I/O error, ac is closed without shutting down the
// underlying transport (the transport will be closed when all the pending RPCs finished or
// failed.).
// If GoAway and some network I/O error happen concurrently, ac and its underlying transport
// are closed.
// In both cases, a new ac is created.
select {
case <-t.Error():
ac.cc.resetAddrConn(ac.addr, true, errNetworkIO)
default:
ac.cc.resetAddrConn(ac.addr, true, errConnDrain)
}
return return
case <-t.Error(): case <-t.Error():
select {
case <-ac.ctx.Done():
t.Close()
return
case <-t.GoAway():
ac.cc.resetAddrConn(ac.addr, true, errNetworkIO)
return
default:
}
ac.mu.Lock() ac.mu.Lock()
if ac.state == Shutdown { if ac.state == Shutdown {
// ac.tearDown(...) has been invoked. // ac has been shutdown.
ac.mu.Unlock() ac.mu.Unlock()
return return
} }
@ -640,6 +753,10 @@ func (ac *addrConn) transportMonitor() {
ac.printf("transport exiting: %v", err) ac.printf("transport exiting: %v", err)
ac.mu.Unlock() ac.mu.Unlock()
grpclog.Printf("grpc: addrConn.transportMonitor exits due to: %v", err) grpclog.Printf("grpc: addrConn.transportMonitor exits due to: %v", err)
if err != errConnClosing {
// Keep this ac in cc.conns, to get the reason it's torn down.
ac.tearDown(err)
}
return return
} }
} }
@ -647,35 +764,42 @@ func (ac *addrConn) transportMonitor() {
} }
// wait blocks until i) the new transport is up or ii) ctx is done or iii) ac is closed or // wait blocks until i) the new transport is up or ii) ctx is done or iii) ac is closed or
// iv) transport is in TransientFailure and the RPC is fail-fast. // iv) transport is in TransientFailure and there's no balancer/failfast is true.
func (ac *addrConn) wait(ctx context.Context, failFast bool) (transport.ClientTransport, error) { func (ac *addrConn) wait(ctx context.Context, hasBalancer, failfast bool) (transport.ClientTransport, error) {
for { for {
ac.mu.Lock() ac.mu.Lock()
switch { switch {
case ac.state == Shutdown: case ac.state == Shutdown:
if failfast || !hasBalancer {
// RPC is failfast or balancer is nil. This RPC should fail with ac.tearDownErr.
err := ac.tearDownErr
ac.mu.Unlock()
return nil, err
}
ac.mu.Unlock() ac.mu.Unlock()
return nil, errConnClosing return nil, errConnClosing
case ac.state == Ready: case ac.state == Ready:
ct := ac.transport ct := ac.transport
ac.mu.Unlock() ac.mu.Unlock()
return ct, nil return ct, nil
case ac.state == TransientFailure && failFast: case ac.state == TransientFailure:
ac.mu.Unlock() if failfast || hasBalancer {
return nil, Errorf(codes.Unavailable, "grpc: RPC failed fast due to transport failure") ac.mu.Unlock()
default: return nil, errConnUnavailable
ready := ac.ready
if ready == nil {
ready = make(chan struct{})
ac.ready = ready
}
ac.mu.Unlock()
select {
case <-ctx.Done():
return nil, toRPCErr(ctx.Err())
// Wait until the new transport is ready or failed.
case <-ready:
} }
} }
ready := ac.ready
if ready == nil {
ready = make(chan struct{})
ac.ready = ready
}
ac.mu.Unlock()
select {
case <-ctx.Done():
return nil, toRPCErr(ctx.Err())
// Wait until the new transport is ready or failed.
case <-ready:
}
} }
} }
@ -683,24 +807,28 @@ func (ac *addrConn) wait(ctx context.Context, failFast bool) (transport.ClientTr
// TODO(zhaoq): Make this synchronous to avoid unbounded memory consumption in // TODO(zhaoq): Make this synchronous to avoid unbounded memory consumption in
// some edge cases (e.g., the caller opens and closes many addrConn's in a // some edge cases (e.g., the caller opens and closes many addrConn's in a
// tight loop. // tight loop.
// tearDown doesn't remove ac from ac.cc.conns.
func (ac *addrConn) tearDown(err error) { func (ac *addrConn) tearDown(err error) {
ac.cancel()
ac.mu.Lock() ac.mu.Lock()
defer func() { defer ac.mu.Unlock()
ac.mu.Unlock()
ac.cc.mu.Lock()
if ac.cc.conns != nil {
delete(ac.cc.conns, ac.addr)
}
ac.cc.mu.Unlock()
}()
if ac.state == Shutdown {
return
}
ac.state = Shutdown
if ac.down != nil { if ac.down != nil {
ac.down(downErrorf(false, false, "%v", err)) ac.down(downErrorf(false, false, "%v", err))
ac.down = nil ac.down = nil
} }
if err == errConnDrain && ac.transport != nil {
// GracefulClose(...) may be executed multiple times when
// i) receiving multiple GoAway frames from the server; or
// ii) there are concurrent name resolver/Balancer triggered
// address removal and GoAway.
ac.transport.GracefulClose()
}
if ac.state == Shutdown {
return
}
ac.state = Shutdown
ac.tearDownErr = err
ac.stateCV.Broadcast() ac.stateCV.Broadcast()
if ac.events != nil { if ac.events != nil {
ac.events.Finish() ac.events.Finish()
@ -710,15 +838,8 @@ func (ac *addrConn) tearDown(err error) {
close(ac.ready) close(ac.ready)
ac.ready = nil ac.ready = nil
} }
if ac.transport != nil { if ac.transport != nil && err != errConnDrain {
if err == errConnDrain { ac.transport.Close()
ac.transport.GracefulClose()
} else {
ac.transport.Close()
}
}
if ac.shutdownChan != nil {
close(ac.shutdownChan)
} }
return return
} }

View file

@ -44,7 +44,6 @@ import (
"io/ioutil" "io/ioutil"
"net" "net"
"strings" "strings"
"time"
"golang.org/x/net/context" "golang.org/x/net/context"
) )
@ -93,11 +92,12 @@ type TransportCredentials interface {
// ClientHandshake does the authentication handshake specified by the corresponding // ClientHandshake does the authentication handshake specified by the corresponding
// authentication protocol on rawConn for clients. It returns the authenticated // authentication protocol on rawConn for clients. It returns the authenticated
// connection and the corresponding auth information about the connection. // connection and the corresponding auth information about the connection.
ClientHandshake(addr string, rawConn net.Conn, timeout time.Duration) (net.Conn, AuthInfo, error) // Implementations must use the provided context to implement timely cancellation.
ClientHandshake(context.Context, string, net.Conn) (net.Conn, AuthInfo, error)
// ServerHandshake does the authentication handshake for servers. It returns // ServerHandshake does the authentication handshake for servers. It returns
// the authenticated connection and the corresponding auth information about // the authenticated connection and the corresponding auth information about
// the connection. // the connection.
ServerHandshake(rawConn net.Conn) (net.Conn, AuthInfo, error) ServerHandshake(net.Conn) (net.Conn, AuthInfo, error)
// Info provides the ProtocolInfo of this TransportCredentials. // Info provides the ProtocolInfo of this TransportCredentials.
Info() ProtocolInfo Info() ProtocolInfo
} }
@ -136,42 +136,28 @@ func (c *tlsCreds) RequireTransportSecurity() bool {
return true return true
} }
type timeoutError struct{} func (c *tlsCreds) ClientHandshake(ctx context.Context, addr string, rawConn net.Conn) (_ net.Conn, _ AuthInfo, err error) {
func (timeoutError) Error() string { return "credentials: Dial timed out" }
func (timeoutError) Timeout() bool { return true }
func (timeoutError) Temporary() bool { return true }
func (c *tlsCreds) ClientHandshake(addr string, rawConn net.Conn, timeout time.Duration) (_ net.Conn, _ AuthInfo, err error) {
// borrow some code from tls.DialWithDialer
var errChannel chan error
if timeout != 0 {
errChannel = make(chan error, 2)
time.AfterFunc(timeout, func() {
errChannel <- timeoutError{}
})
}
// use local cfg to avoid clobbering ServerName if using multiple endpoints // use local cfg to avoid clobbering ServerName if using multiple endpoints
cfg := *c.config cfg := cloneTLSConfig(c.config)
if c.config.ServerName == "" { if cfg.ServerName == "" {
colonPos := strings.LastIndex(addr, ":") colonPos := strings.LastIndex(addr, ":")
if colonPos == -1 { if colonPos == -1 {
colonPos = len(addr) colonPos = len(addr)
} }
cfg.ServerName = addr[:colonPos] cfg.ServerName = addr[:colonPos]
} }
conn := tls.Client(rawConn, &cfg) conn := tls.Client(rawConn, cfg)
if timeout == 0 { errChannel := make(chan error, 1)
err = conn.Handshake() go func() {
} else { errChannel <- conn.Handshake()
go func() { }()
errChannel <- conn.Handshake() select {
}() case err := <-errChannel:
err = <-errChannel if err != nil {
} return nil, nil, err
if err != nil { }
rawConn.Close() case <-ctx.Done():
return nil, nil, err return nil, nil, ctx.Err()
} }
// TODO(zhaoq): Omit the auth info for client now. It is more for // TODO(zhaoq): Omit the auth info for client now. It is more for
// information than anything else. // information than anything else.
@ -181,7 +167,6 @@ func (c *tlsCreds) ClientHandshake(addr string, rawConn net.Conn, timeout time.D
func (c *tlsCreds) ServerHandshake(rawConn net.Conn) (net.Conn, AuthInfo, error) { func (c *tlsCreds) ServerHandshake(rawConn net.Conn) (net.Conn, AuthInfo, error) {
conn := tls.Server(rawConn, c.config) conn := tls.Server(rawConn, c.config)
if err := conn.Handshake(); err != nil { if err := conn.Handshake(); err != nil {
rawConn.Close()
return nil, nil, err return nil, nil, err
} }
return conn, TLSInfo{conn.ConnectionState()}, nil return conn, TLSInfo{conn.ConnectionState()}, nil
@ -189,7 +174,7 @@ func (c *tlsCreds) ServerHandshake(rawConn net.Conn) (net.Conn, AuthInfo, error)
// NewTLS uses c to construct a TransportCredentials based on TLS. // NewTLS uses c to construct a TransportCredentials based on TLS.
func NewTLS(c *tls.Config) TransportCredentials { func NewTLS(c *tls.Config) TransportCredentials {
tc := &tlsCreds{c} tc := &tlsCreds{cloneTLSConfig(c)}
tc.config.NextProtos = alpnProtoStr tc.config.NextProtos = alpnProtoStr
return tc return tc
} }

View file

@ -0,0 +1,76 @@
// +build go1.7
/*
*
* Copyright 2016, Google Inc.
* All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are
* met:
*
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above
* copyright notice, this list of conditions and the following disclaimer
* in the documentation and/or other materials provided with the
* distribution.
* * Neither the name of Google Inc. nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
* "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
* LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
* A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
* OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
* SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
* LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
* DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
* THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
*/
package credentials
import (
"crypto/tls"
)
// cloneTLSConfig returns a shallow clone of the exported
// fields of cfg, ignoring the unexported sync.Once, which
// contains a mutex and must not be copied.
//
// If cfg is nil, a new zero tls.Config is returned.
//
// TODO replace this function with official clone function.
func cloneTLSConfig(cfg *tls.Config) *tls.Config {
if cfg == nil {
return &tls.Config{}
}
return &tls.Config{
Rand: cfg.Rand,
Time: cfg.Time,
Certificates: cfg.Certificates,
NameToCertificate: cfg.NameToCertificate,
GetCertificate: cfg.GetCertificate,
RootCAs: cfg.RootCAs,
NextProtos: cfg.NextProtos,
ServerName: cfg.ServerName,
ClientAuth: cfg.ClientAuth,
ClientCAs: cfg.ClientCAs,
InsecureSkipVerify: cfg.InsecureSkipVerify,
CipherSuites: cfg.CipherSuites,
PreferServerCipherSuites: cfg.PreferServerCipherSuites,
SessionTicketsDisabled: cfg.SessionTicketsDisabled,
SessionTicketKey: cfg.SessionTicketKey,
ClientSessionCache: cfg.ClientSessionCache,
MinVersion: cfg.MinVersion,
MaxVersion: cfg.MaxVersion,
CurvePreferences: cfg.CurvePreferences,
DynamicRecordSizingDisabled: cfg.DynamicRecordSizingDisabled,
Renegotiation: cfg.Renegotiation,
}
}

View file

@ -0,0 +1,74 @@
// +build !go1.7
/*
*
* Copyright 2016, Google Inc.
* All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are
* met:
*
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above
* copyright notice, this list of conditions and the following disclaimer
* in the documentation and/or other materials provided with the
* distribution.
* * Neither the name of Google Inc. nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
* "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
* LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
* A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
* OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
* SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
* LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
* DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
* THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
*/
package credentials
import (
"crypto/tls"
)
// cloneTLSConfig returns a shallow clone of the exported
// fields of cfg, ignoring the unexported sync.Once, which
// contains a mutex and must not be copied.
//
// If cfg is nil, a new zero tls.Config is returned.
//
// TODO replace this function with official clone function.
func cloneTLSConfig(cfg *tls.Config) *tls.Config {
if cfg == nil {
return &tls.Config{}
}
return &tls.Config{
Rand: cfg.Rand,
Time: cfg.Time,
Certificates: cfg.Certificates,
NameToCertificate: cfg.NameToCertificate,
GetCertificate: cfg.GetCertificate,
RootCAs: cfg.RootCAs,
NextProtos: cfg.NextProtos,
ServerName: cfg.ServerName,
ClientAuth: cfg.ClientAuth,
ClientCAs: cfg.ClientCAs,
InsecureSkipVerify: cfg.InsecureSkipVerify,
CipherSuites: cfg.CipherSuites,
PreferServerCipherSuites: cfg.PreferServerCipherSuites,
SessionTicketsDisabled: cfg.SessionTicketsDisabled,
SessionTicketKey: cfg.SessionTicketKey,
ClientSessionCache: cfg.ClientSessionCache,
MinVersion: cfg.MinVersion,
MaxVersion: cfg.MaxVersion,
CurvePreferences: cfg.CurvePreferences,
}
}

View file

@ -60,15 +60,21 @@ func encodeKeyValue(k, v string) (string, string) {
// DecodeKeyValue returns the original key and value corresponding to the // DecodeKeyValue returns the original key and value corresponding to the
// encoded data in k, v. // encoded data in k, v.
// If k is a binary header and v contains comma, v is split on comma before decoded,
// and the decoded v will be joined with comma before returned.
func DecodeKeyValue(k, v string) (string, string, error) { func DecodeKeyValue(k, v string) (string, string, error) {
if !strings.HasSuffix(k, binHdrSuffix) { if !strings.HasSuffix(k, binHdrSuffix) {
return k, v, nil return k, v, nil
} }
val, err := base64.StdEncoding.DecodeString(v) vvs := strings.Split(v, ",")
if err != nil { for i, vv := range vvs {
return "", "", err val, err := base64.StdEncoding.DecodeString(vv)
if err != nil {
return "", "", err
}
vvs[i] = string(val)
} }
return k, string(val), nil return k, strings.Join(vvs, ","), nil
} }
// MD is a mapping from metadata keys to values. Users should use the following // MD is a mapping from metadata keys to values. Users should use the following

View file

@ -227,7 +227,7 @@ type parser struct {
// No other error values or types must be returned, which also means // No other error values or types must be returned, which also means
// that the underlying io.Reader must not return an incompatible // that the underlying io.Reader must not return an incompatible
// error. // error.
func (p *parser) recvMsg() (pf payloadFormat, msg []byte, err error) { func (p *parser) recvMsg(maxMsgSize int) (pf payloadFormat, msg []byte, err error) {
if _, err := io.ReadFull(p.r, p.header[:]); err != nil { if _, err := io.ReadFull(p.r, p.header[:]); err != nil {
return 0, nil, err return 0, nil, err
} }
@ -238,6 +238,9 @@ func (p *parser) recvMsg() (pf payloadFormat, msg []byte, err error) {
if length == 0 { if length == 0 {
return pf, nil, nil return pf, nil, nil
} }
if length > uint32(maxMsgSize) {
return 0, nil, Errorf(codes.Internal, "grpc: received message length %d exceeding the max size %d", length, maxMsgSize)
}
// TODO(bradfitz,zhaoq): garbage. reuse buffer after proto decoding instead // TODO(bradfitz,zhaoq): garbage. reuse buffer after proto decoding instead
// of making it for each message: // of making it for each message:
msg = make([]byte, int(length)) msg = make([]byte, int(length))
@ -308,8 +311,8 @@ func checkRecvPayload(pf payloadFormat, recvCompress string, dc Decompressor) er
return nil return nil
} }
func recv(p *parser, c Codec, s *transport.Stream, dc Decompressor, m interface{}) error { func recv(p *parser, c Codec, s *transport.Stream, dc Decompressor, m interface{}, maxMsgSize int) error {
pf, d, err := p.recvMsg() pf, d, err := p.recvMsg(maxMsgSize)
if err != nil { if err != nil {
return err return err
} }
@ -319,11 +322,16 @@ func recv(p *parser, c Codec, s *transport.Stream, dc Decompressor, m interface{
if pf == compressionMade { if pf == compressionMade {
d, err = dc.Do(bytes.NewReader(d)) d, err = dc.Do(bytes.NewReader(d))
if err != nil { if err != nil {
return transport.StreamErrorf(codes.Internal, "grpc: failed to decompress the received message %v", err) return Errorf(codes.Internal, "grpc: failed to decompress the received message %v", err)
} }
} }
if len(d) > maxMsgSize {
// TODO: Revisit the error code. Currently keep it consistent with java
// implementation.
return Errorf(codes.Internal, "grpc: received a message of %d bytes exceeding %d limit", len(d), maxMsgSize)
}
if err := c.Unmarshal(d, m); err != nil { if err := c.Unmarshal(d, m); err != nil {
return transport.StreamErrorf(codes.Internal, "grpc: failed to unmarshal the received message %v", err) return Errorf(codes.Internal, "grpc: failed to unmarshal the received message %v", err)
} }
return nil return nil
} }

View file

@ -89,9 +89,13 @@ type service struct {
type Server struct { type Server struct {
opts options opts options
mu sync.Mutex // guards following mu sync.Mutex // guards following
lis map[net.Listener]bool lis map[net.Listener]bool
conns map[io.Closer]bool conns map[io.Closer]bool
drain bool
// A CondVar to let GracefulStop() blocks until all the pending RPCs are finished
// and all the transport goes away.
cv *sync.Cond
m map[string]*service // service name -> service info m map[string]*service // service name -> service info
events trace.EventLog events trace.EventLog
} }
@ -101,12 +105,15 @@ type options struct {
codec Codec codec Codec
cp Compressor cp Compressor
dc Decompressor dc Decompressor
maxMsgSize int
unaryInt UnaryServerInterceptor unaryInt UnaryServerInterceptor
streamInt StreamServerInterceptor streamInt StreamServerInterceptor
maxConcurrentStreams uint32 maxConcurrentStreams uint32
useHandlerImpl bool // use http.Handler-based server useHandlerImpl bool // use http.Handler-based server
} }
var defaultMaxMsgSize = 1024 * 1024 * 4 // use 4MB as the default message size limit
// A ServerOption sets options. // A ServerOption sets options.
type ServerOption func(*options) type ServerOption func(*options)
@ -117,20 +124,28 @@ func CustomCodec(codec Codec) ServerOption {
} }
} }
// RPCCompressor returns a ServerOption that sets a compressor for outbound message. // RPCCompressor returns a ServerOption that sets a compressor for outbound messages.
func RPCCompressor(cp Compressor) ServerOption { func RPCCompressor(cp Compressor) ServerOption {
return func(o *options) { return func(o *options) {
o.cp = cp o.cp = cp
} }
} }
// RPCDecompressor returns a ServerOption that sets a decompressor for inbound message. // RPCDecompressor returns a ServerOption that sets a decompressor for inbound messages.
func RPCDecompressor(dc Decompressor) ServerOption { func RPCDecompressor(dc Decompressor) ServerOption {
return func(o *options) { return func(o *options) {
o.dc = dc o.dc = dc
} }
} }
// MaxMsgSize returns a ServerOption to set the max message size in bytes for inbound mesages.
// If this is not set, gRPC uses the default 4MB.
func MaxMsgSize(m int) ServerOption {
return func(o *options) {
o.maxMsgSize = m
}
}
// MaxConcurrentStreams returns a ServerOption that will apply a limit on the number // MaxConcurrentStreams returns a ServerOption that will apply a limit on the number
// of concurrent streams to each ServerTransport. // of concurrent streams to each ServerTransport.
func MaxConcurrentStreams(n uint32) ServerOption { func MaxConcurrentStreams(n uint32) ServerOption {
@ -173,6 +188,7 @@ func StreamInterceptor(i StreamServerInterceptor) ServerOption {
// started to accept requests yet. // started to accept requests yet.
func NewServer(opt ...ServerOption) *Server { func NewServer(opt ...ServerOption) *Server {
var opts options var opts options
opts.maxMsgSize = defaultMaxMsgSize
for _, o := range opt { for _, o := range opt {
o(&opts) o(&opts)
} }
@ -186,6 +202,7 @@ func NewServer(opt ...ServerOption) *Server {
conns: make(map[io.Closer]bool), conns: make(map[io.Closer]bool),
m: make(map[string]*service), m: make(map[string]*service),
} }
s.cv = sync.NewCond(&s.mu)
if EnableTracing { if EnableTracing {
_, file, line, _ := runtime.Caller(1) _, file, line, _ := runtime.Caller(1)
s.events = trace.NewEventLog("grpc.Server", fmt.Sprintf("%s:%d", file, line)) s.events = trace.NewEventLog("grpc.Server", fmt.Sprintf("%s:%d", file, line))
@ -264,8 +281,8 @@ type ServiceInfo struct {
// GetServiceInfo returns a map from service names to ServiceInfo. // GetServiceInfo returns a map from service names to ServiceInfo.
// Service names include the package names, in the form of <package>.<service>. // Service names include the package names, in the form of <package>.<service>.
func (s *Server) GetServiceInfo() map[string]*ServiceInfo { func (s *Server) GetServiceInfo() map[string]ServiceInfo {
ret := make(map[string]*ServiceInfo) ret := make(map[string]ServiceInfo)
for n, srv := range s.m { for n, srv := range s.m {
methods := make([]MethodInfo, 0, len(srv.md)+len(srv.sd)) methods := make([]MethodInfo, 0, len(srv.md)+len(srv.sd))
for m := range srv.md { for m := range srv.md {
@ -283,7 +300,7 @@ func (s *Server) GetServiceInfo() map[string]*ServiceInfo {
}) })
} }
ret[n] = &ServiceInfo{ ret[n] = ServiceInfo{
Methods: methods, Methods: methods,
Metadata: srv.mdata, Metadata: srv.mdata,
} }
@ -468,7 +485,7 @@ func (s *Server) traceInfo(st transport.ServerTransport, stream *transport.Strea
func (s *Server) addConn(c io.Closer) bool { func (s *Server) addConn(c io.Closer) bool {
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()
if s.conns == nil { if s.conns == nil || s.drain {
return false return false
} }
s.conns[c] = true s.conns[c] = true
@ -480,6 +497,7 @@ func (s *Server) removeConn(c io.Closer) {
defer s.mu.Unlock() defer s.mu.Unlock()
if s.conns != nil { if s.conns != nil {
delete(s.conns, c) delete(s.conns, c)
s.cv.Signal()
} }
} }
@ -520,7 +538,7 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport.
} }
p := &parser{r: stream} p := &parser{r: stream}
for { for {
pf, req, err := p.recvMsg() pf, req, err := p.recvMsg(s.opts.maxMsgSize)
if err == io.EOF { if err == io.EOF {
// The entire stream is done (for unary RPC only). // The entire stream is done (for unary RPC only).
return err return err
@ -530,6 +548,10 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport.
} }
if err != nil { if err != nil {
switch err := err.(type) { switch err := err.(type) {
case *rpcError:
if err := t.WriteStatus(stream, err.code, err.desc); err != nil {
grpclog.Printf("grpc: Server.processUnaryRPC failed to write status %v", err)
}
case transport.ConnectionError: case transport.ConnectionError:
// Nothing to do here. // Nothing to do here.
case transport.StreamError: case transport.StreamError:
@ -569,6 +591,12 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport.
return err return err
} }
} }
if len(req) > s.opts.maxMsgSize {
// TODO: Revisit the error code. Currently keep it consistent with
// java implementation.
statusCode = codes.Internal
statusDesc = fmt.Sprintf("grpc: server received a message of %d bytes exceeding %d limit", len(req), s.opts.maxMsgSize)
}
if err := s.opts.codec.Unmarshal(req, v); err != nil { if err := s.opts.codec.Unmarshal(req, v); err != nil {
return err return err
} }
@ -628,13 +656,14 @@ func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transp
stream.SetSendCompress(s.opts.cp.Type()) stream.SetSendCompress(s.opts.cp.Type())
} }
ss := &serverStream{ ss := &serverStream{
t: t, t: t,
s: stream, s: stream,
p: &parser{r: stream}, p: &parser{r: stream},
codec: s.opts.codec, codec: s.opts.codec,
cp: s.opts.cp, cp: s.opts.cp,
dc: s.opts.dc, dc: s.opts.dc,
trInfo: trInfo, maxMsgSize: s.opts.maxMsgSize,
trInfo: trInfo,
} }
if ss.cp != nil { if ss.cp != nil {
ss.cbuf = new(bytes.Buffer) ss.cbuf = new(bytes.Buffer)
@ -766,14 +795,16 @@ func (s *Server) Stop() {
s.mu.Lock() s.mu.Lock()
listeners := s.lis listeners := s.lis
s.lis = nil s.lis = nil
cs := s.conns st := s.conns
s.conns = nil s.conns = nil
// interrupt GracefulStop if Stop and GracefulStop are called concurrently.
s.cv.Signal()
s.mu.Unlock() s.mu.Unlock()
for lis := range listeners { for lis := range listeners {
lis.Close() lis.Close()
} }
for c := range cs { for c := range st {
c.Close() c.Close()
} }
@ -785,6 +816,32 @@ func (s *Server) Stop() {
s.mu.Unlock() s.mu.Unlock()
} }
// GracefulStop stops the gRPC server gracefully. It stops the server to accept new
// connections and RPCs and blocks until all the pending RPCs are finished.
func (s *Server) GracefulStop() {
s.mu.Lock()
defer s.mu.Unlock()
if s.drain == true || s.conns == nil {
return
}
s.drain = true
for lis := range s.lis {
lis.Close()
}
s.lis = nil
for c := range s.conns {
c.(transport.ServerTransport).Drain()
}
for len(s.conns) != 0 {
s.cv.Wait()
}
s.conns = nil
if s.events != nil {
s.events.Finish()
s.events = nil
}
}
func init() { func init() {
internal.TestingCloseConns = func(arg interface{}) { internal.TestingCloseConns = func(arg interface{}) {
arg.(*Server).testingCloseConns() arg.(*Server).testingCloseConns()

View file

@ -37,6 +37,7 @@ import (
"bytes" "bytes"
"errors" "errors"
"io" "io"
"math"
"sync" "sync"
"time" "time"
@ -84,12 +85,9 @@ type ClientStream interface {
// Header returns the header metadata received from the server if there // Header returns the header metadata received from the server if there
// is any. It blocks if the metadata is not ready to read. // is any. It blocks if the metadata is not ready to read.
Header() (metadata.MD, error) Header() (metadata.MD, error)
// Trailer returns the trailer metadata from the server. It must be called // Trailer returns the trailer metadata from the server, if there is any.
// after stream.Recv() returns non-nil error (including io.EOF) for // It must only be called after stream.CloseAndRecv has returned, or
// bi-directional streaming and server streaming or stream.CloseAndRecv() // stream.Recv has returned a non-nil error (including io.EOF).
// returns for client streaming in order to receive trailer metadata if
// present. Otherwise, it could returns an empty MD even though trailer
// is present.
Trailer() metadata.MD Trailer() metadata.MD
// CloseSend closes the send direction of the stream. It closes the stream // CloseSend closes the send direction of the stream. It closes the stream
// when non-nil error is met. // when non-nil error is met.
@ -99,11 +97,10 @@ type ClientStream interface {
// NewClientStream creates a new Stream for the client side. This is called // NewClientStream creates a new Stream for the client side. This is called
// by generated code. // by generated code.
func NewClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, method string, opts ...CallOption) (ClientStream, error) { func NewClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, method string, opts ...CallOption) (_ ClientStream, err error) {
var ( var (
t transport.ClientTransport t transport.ClientTransport
s *transport.Stream s *transport.Stream
err error
put func() put func()
) )
c := defaultCallInfo c := defaultCallInfo
@ -120,27 +117,24 @@ func NewClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth
if cc.dopts.cp != nil { if cc.dopts.cp != nil {
callHdr.SendCompress = cc.dopts.cp.Type() callHdr.SendCompress = cc.dopts.cp.Type()
} }
cs := &clientStream{ var trInfo traceInfo
opts: opts, if EnableTracing {
c: c, trInfo.tr = trace.New("grpc.Sent."+methodFamily(method), method)
desc: desc, trInfo.firstLine.client = true
codec: cc.dopts.codec,
cp: cc.dopts.cp,
dc: cc.dopts.dc,
tracing: EnableTracing,
}
if cc.dopts.cp != nil {
callHdr.SendCompress = cc.dopts.cp.Type()
cs.cbuf = new(bytes.Buffer)
}
if cs.tracing {
cs.trInfo.tr = trace.New("grpc.Sent."+methodFamily(method), method)
cs.trInfo.firstLine.client = true
if deadline, ok := ctx.Deadline(); ok { if deadline, ok := ctx.Deadline(); ok {
cs.trInfo.firstLine.deadline = deadline.Sub(time.Now()) trInfo.firstLine.deadline = deadline.Sub(time.Now())
} }
cs.trInfo.tr.LazyLog(&cs.trInfo.firstLine, false) trInfo.tr.LazyLog(&trInfo.firstLine, false)
ctx = trace.NewContext(ctx, cs.trInfo.tr) ctx = trace.NewContext(ctx, trInfo.tr)
defer func() {
if err != nil {
// Need to call tr.finish() if error is returned.
// Because tr will not be returned to caller.
trInfo.tr.LazyPrintf("RPC: [%v]", err)
trInfo.tr.SetError()
trInfo.tr.Finish()
}
}()
} }
gopts := BalancerGetOptions{ gopts := BalancerGetOptions{
BlockingWait: !c.failFast, BlockingWait: !c.failFast,
@ -152,9 +146,9 @@ func NewClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth
if _, ok := err.(*rpcError); ok { if _, ok := err.(*rpcError); ok {
return nil, err return nil, err
} }
if err == errConnClosing { if err == errConnClosing || err == errConnUnavailable {
if c.failFast { if c.failFast {
return nil, Errorf(codes.Unavailable, "%v", errConnClosing) return nil, Errorf(codes.Unavailable, "%v", err)
} }
continue continue
} }
@ -168,9 +162,8 @@ func NewClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth
put() put()
put = nil put = nil
} }
if _, ok := err.(transport.ConnectionError); ok { if _, ok := err.(transport.ConnectionError); ok || err == transport.ErrStreamDrain {
if c.failFast { if c.failFast {
cs.finish(err)
return nil, toRPCErr(err) return nil, toRPCErr(err)
} }
continue continue
@ -179,16 +172,43 @@ func NewClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth
} }
break break
} }
cs.put = put cs := &clientStream{
cs.t = t opts: opts,
cs.s = s c: c,
cs.p = &parser{r: s} desc: desc,
// Listen on ctx.Done() to detect cancellation when there is no pending codec: cc.dopts.codec,
// I/O operations on this stream. cp: cc.dopts.cp,
dc: cc.dopts.dc,
put: put,
t: t,
s: s,
p: &parser{r: s},
tracing: EnableTracing,
trInfo: trInfo,
}
if cc.dopts.cp != nil {
cs.cbuf = new(bytes.Buffer)
}
// Listen on ctx.Done() to detect cancellation and s.Done() to detect normal termination
// when there is no pending I/O operations on this stream.
go func() { go func() {
select { select {
case <-t.Error(): case <-t.Error():
// Incur transport error, simply exit. // Incur transport error, simply exit.
case <-s.Done():
// TODO: The trace of the RPC is terminated here when there is no pending
// I/O, which is probably not the optimal solution.
if s.StatusCode() == codes.OK {
cs.finish(nil)
} else {
cs.finish(Errorf(s.StatusCode(), "%s", s.StatusDesc()))
}
cs.closeTransportStream(nil)
case <-s.GoAway():
cs.finish(errConnDrain)
cs.closeTransportStream(errConnDrain)
case <-s.Context().Done(): case <-s.Context().Done():
err := s.Context().Err() err := s.Context().Err()
cs.finish(err) cs.finish(err)
@ -251,7 +271,17 @@ func (cs *clientStream) SendMsg(m interface{}) (err error) {
if err != nil { if err != nil {
cs.finish(err) cs.finish(err)
} }
if err == nil || err == io.EOF { if err == nil {
return
}
if err == io.EOF {
// Specialize the process for server streaming. SendMesg is only called
// once when creating the stream object. io.EOF needs to be skipped when
// the rpc is early finished (before the stream object is created.).
// TODO: It is probably better to move this into the generated code.
if !cs.desc.ClientStreams && cs.desc.ServerStreams {
err = nil
}
return return
} }
if _, ok := err.(transport.ConnectionError); !ok { if _, ok := err.(transport.ConnectionError); !ok {
@ -272,7 +302,7 @@ func (cs *clientStream) SendMsg(m interface{}) (err error) {
} }
func (cs *clientStream) RecvMsg(m interface{}) (err error) { func (cs *clientStream) RecvMsg(m interface{}) (err error) {
err = recv(cs.p, cs.codec, cs.s, cs.dc, m) err = recv(cs.p, cs.codec, cs.s, cs.dc, m, math.MaxInt32)
defer func() { defer func() {
// err != nil indicates the termination of the stream. // err != nil indicates the termination of the stream.
if err != nil { if err != nil {
@ -291,7 +321,7 @@ func (cs *clientStream) RecvMsg(m interface{}) (err error) {
return return
} }
// Special handling for client streaming rpc. // Special handling for client streaming rpc.
err = recv(cs.p, cs.codec, cs.s, cs.dc, m) err = recv(cs.p, cs.codec, cs.s, cs.dc, m, math.MaxInt32)
cs.closeTransportStream(err) cs.closeTransportStream(err)
if err == nil { if err == nil {
return toRPCErr(errors.New("grpc: client streaming protocol violation: get <nil>, want <EOF>")) return toRPCErr(errors.New("grpc: client streaming protocol violation: get <nil>, want <EOF>"))
@ -326,7 +356,7 @@ func (cs *clientStream) CloseSend() (err error) {
} }
}() }()
if err == nil || err == io.EOF { if err == nil || err == io.EOF {
return return nil
} }
if _, ok := err.(transport.ConnectionError); !ok { if _, ok := err.(transport.ConnectionError); !ok {
cs.closeTransportStream(err) cs.closeTransportStream(err)
@ -392,6 +422,7 @@ type serverStream struct {
cp Compressor cp Compressor
dc Decompressor dc Decompressor
cbuf *bytes.Buffer cbuf *bytes.Buffer
maxMsgSize int
statusCode codes.Code statusCode codes.Code
statusDesc string statusDesc string
trInfo *traceInfo trInfo *traceInfo
@ -458,5 +489,5 @@ func (ss *serverStream) RecvMsg(m interface{}) (err error) {
ss.mu.Unlock() ss.mu.Unlock()
} }
}() }()
return recv(ss.p, ss.codec, ss.s, ss.dc, m) return recv(ss.p, ss.codec, ss.s, ss.dc, m, ss.maxMsgSize)
} }

View file

@ -72,6 +72,11 @@ type resetStream struct {
func (*resetStream) item() {} func (*resetStream) item() {}
type goAway struct {
}
func (*goAway) item() {}
type flushIO struct { type flushIO struct {
} }

46
vendor/google.golang.org/grpc/transport/go16.go generated vendored Normal file
View file

@ -0,0 +1,46 @@
// +build go1.6,!go1.7
/*
* Copyright 2016, Google Inc.
* All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are
* met:
*
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above
* copyright notice, this list of conditions and the following disclaimer
* in the documentation and/or other materials provided with the
* distribution.
* * Neither the name of Google Inc. nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
* "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
* LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
* A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
* OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
* SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
* LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
* DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
* THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
*/
package transport
import (
"net"
"golang.org/x/net/context"
)
// dialContext connects to the address on the named network.
func dialContext(ctx context.Context, network, address string) (net.Conn, error) {
return (&net.Dialer{Cancel: ctx.Done()}).Dial(network, address)
}

46
vendor/google.golang.org/grpc/transport/go17.go generated vendored Normal file
View file

@ -0,0 +1,46 @@
// +build go1.7
/*
* Copyright 2016, Google Inc.
* All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are
* met:
*
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above
* copyright notice, this list of conditions and the following disclaimer
* in the documentation and/or other materials provided with the
* distribution.
* * Neither the name of Google Inc. nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
* "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
* LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
* A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
* OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
* SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
* LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
* DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
* THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
*/
package transport
import (
"net"
"golang.org/x/net/context"
)
// dialContext connects to the address on the named network.
func dialContext(ctx context.Context, network, address string) (net.Conn, error) {
return (&net.Dialer{}).DialContext(ctx, network, address)
}

View file

@ -83,7 +83,7 @@ func NewServerHandlerTransport(w http.ResponseWriter, r *http.Request) (ServerTr
} }
if v := r.Header.Get("grpc-timeout"); v != "" { if v := r.Header.Get("grpc-timeout"); v != "" {
to, err := timeoutDecode(v) to, err := decodeTimeout(v)
if err != nil { if err != nil {
return nil, StreamErrorf(codes.Internal, "malformed time-out: %v", err) return nil, StreamErrorf(codes.Internal, "malformed time-out: %v", err)
} }
@ -194,7 +194,7 @@ func (ht *serverHandlerTransport) WriteStatus(s *Stream, statusCode codes.Code,
h := ht.rw.Header() h := ht.rw.Header()
h.Set("Grpc-Status", fmt.Sprintf("%d", statusCode)) h.Set("Grpc-Status", fmt.Sprintf("%d", statusCode))
if statusDesc != "" { if statusDesc != "" {
h.Set("Grpc-Message", statusDesc) h.Set("Grpc-Message", encodeGrpcMessage(statusDesc))
} }
if md := s.Trailer(); len(md) > 0 { if md := s.Trailer(); len(md) > 0 {
for k, vv := range md { for k, vv := range md {
@ -370,6 +370,10 @@ func (ht *serverHandlerTransport) runStream() {
} }
} }
func (ht *serverHandlerTransport) Drain() {
panic("Drain() is not implemented")
}
// mapRecvMsgError returns the non-nil err into the appropriate // mapRecvMsgError returns the non-nil err into the appropriate
// error value as expected by callers of *grpc.parser.recvMsg. // error value as expected by callers of *grpc.parser.recvMsg.
// In particular, in can only be: // In particular, in can only be:

View file

@ -35,6 +35,7 @@ package transport
import ( import (
"bytes" "bytes"
"fmt"
"io" "io"
"math" "math"
"net" "net"
@ -71,6 +72,9 @@ type http2Client struct {
shutdownChan chan struct{} shutdownChan chan struct{}
// errorChan is closed to notify the I/O error to the caller. // errorChan is closed to notify the I/O error to the caller.
errorChan chan struct{} errorChan chan struct{}
// goAway is closed to notify the upper layer (i.e., addrConn.transportMonitor)
// that the server sent GoAway on this transport.
goAway chan struct{}
framer *framer framer *framer
hBuf *bytes.Buffer // the buffer for HPACK encoding hBuf *bytes.Buffer // the buffer for HPACK encoding
@ -97,41 +101,44 @@ type http2Client struct {
maxStreams int maxStreams int
// the per-stream outbound flow control window size set by the peer. // the per-stream outbound flow control window size set by the peer.
streamSendQuota uint32 streamSendQuota uint32
// goAwayID records the Last-Stream-ID in the GoAway frame from the server.
goAwayID uint32
// prevGoAway ID records the Last-Stream-ID in the previous GOAway frame.
prevGoAwayID uint32
}
func dial(fn func(context.Context, string) (net.Conn, error), ctx context.Context, addr string) (net.Conn, error) {
if fn != nil {
return fn(ctx, addr)
}
return dialContext(ctx, "tcp", addr)
} }
// newHTTP2Client constructs a connected ClientTransport to addr based on HTTP2 // newHTTP2Client constructs a connected ClientTransport to addr based on HTTP2
// and starts to receive messages on it. Non-nil error returns if construction // and starts to receive messages on it. Non-nil error returns if construction
// fails. // fails.
func newHTTP2Client(addr string, opts *ConnectOptions) (_ ClientTransport, err error) { func newHTTP2Client(ctx context.Context, addr string, opts ConnectOptions) (_ ClientTransport, err error) {
if opts.Dialer == nil {
// Set the default Dialer.
opts.Dialer = func(addr string, timeout time.Duration) (net.Conn, error) {
return net.DialTimeout("tcp", addr, timeout)
}
}
scheme := "http" scheme := "http"
startT := time.Now() conn, connErr := dial(opts.Dialer, ctx, addr)
timeout := opts.Timeout
conn, connErr := opts.Dialer(addr, timeout)
if connErr != nil { if connErr != nil {
return nil, ConnectionErrorf("transport: %v", connErr) return nil, ConnectionErrorf(true, connErr, "transport: %v", connErr)
} }
var authInfo credentials.AuthInfo // Any further errors will close the underlying connection
if opts.TransportCredentials != nil { defer func(conn net.Conn) {
scheme = "https"
if timeout > 0 {
timeout -= time.Since(startT)
}
conn, authInfo, connErr = opts.TransportCredentials.ClientHandshake(addr, conn, timeout)
}
if connErr != nil {
return nil, ConnectionErrorf("transport: %v", connErr)
}
defer func() {
if err != nil { if err != nil {
conn.Close() conn.Close()
} }
}() }(conn)
var authInfo credentials.AuthInfo
if creds := opts.TransportCredentials; creds != nil {
scheme = "https"
conn, authInfo, connErr = creds.ClientHandshake(ctx, addr, conn)
}
if connErr != nil {
// Credentials handshake error is not a temporary error (unless the error
// was the connection closing).
return nil, ConnectionErrorf(connErr == io.EOF, connErr, "transport: %v", connErr)
}
ua := primaryUA ua := primaryUA
if opts.UserAgent != "" { if opts.UserAgent != "" {
ua = opts.UserAgent + " " + ua ua = opts.UserAgent + " " + ua
@ -147,6 +154,7 @@ func newHTTP2Client(addr string, opts *ConnectOptions) (_ ClientTransport, err e
writableChan: make(chan int, 1), writableChan: make(chan int, 1),
shutdownChan: make(chan struct{}), shutdownChan: make(chan struct{}),
errorChan: make(chan struct{}), errorChan: make(chan struct{}),
goAway: make(chan struct{}),
framer: newFramer(conn), framer: newFramer(conn),
hBuf: &buf, hBuf: &buf,
hEnc: hpack.NewEncoder(&buf), hEnc: hpack.NewEncoder(&buf),
@ -168,11 +176,11 @@ func newHTTP2Client(addr string, opts *ConnectOptions) (_ ClientTransport, err e
n, err := t.conn.Write(clientPreface) n, err := t.conn.Write(clientPreface)
if err != nil { if err != nil {
t.Close() t.Close()
return nil, ConnectionErrorf("transport: %v", err) return nil, ConnectionErrorf(true, err, "transport: %v", err)
} }
if n != len(clientPreface) { if n != len(clientPreface) {
t.Close() t.Close()
return nil, ConnectionErrorf("transport: preface mismatch, wrote %d bytes; want %d", n, len(clientPreface)) return nil, ConnectionErrorf(true, err, "transport: preface mismatch, wrote %d bytes; want %d", n, len(clientPreface))
} }
if initialWindowSize != defaultWindowSize { if initialWindowSize != defaultWindowSize {
err = t.framer.writeSettings(true, http2.Setting{ err = t.framer.writeSettings(true, http2.Setting{
@ -184,13 +192,13 @@ func newHTTP2Client(addr string, opts *ConnectOptions) (_ ClientTransport, err e
} }
if err != nil { if err != nil {
t.Close() t.Close()
return nil, ConnectionErrorf("transport: %v", err) return nil, ConnectionErrorf(true, err, "transport: %v", err)
} }
// Adjust the connection flow control window if needed. // Adjust the connection flow control window if needed.
if delta := uint32(initialConnWindowSize - defaultWindowSize); delta > 0 { if delta := uint32(initialConnWindowSize - defaultWindowSize); delta > 0 {
if err := t.framer.writeWindowUpdate(true, 0, delta); err != nil { if err := t.framer.writeWindowUpdate(true, 0, delta); err != nil {
t.Close() t.Close()
return nil, ConnectionErrorf("transport: %v", err) return nil, ConnectionErrorf(true, err, "transport: %v", err)
} }
} }
go t.controller() go t.controller()
@ -202,6 +210,8 @@ func (t *http2Client) newStream(ctx context.Context, callHdr *CallHdr) *Stream {
// TODO(zhaoq): Handle uint32 overflow of Stream.id. // TODO(zhaoq): Handle uint32 overflow of Stream.id.
s := &Stream{ s := &Stream{
id: t.nextID, id: t.nextID,
done: make(chan struct{}),
goAway: make(chan struct{}),
method: callHdr.Method, method: callHdr.Method,
sendCompress: callHdr.SendCompress, sendCompress: callHdr.SendCompress,
buf: newRecvBuffer(), buf: newRecvBuffer(),
@ -216,8 +226,9 @@ func (t *http2Client) newStream(ctx context.Context, callHdr *CallHdr) *Stream {
// Make a stream be able to cancel the pending operations by itself. // Make a stream be able to cancel the pending operations by itself.
s.ctx, s.cancel = context.WithCancel(ctx) s.ctx, s.cancel = context.WithCancel(ctx)
s.dec = &recvBufferReader{ s.dec = &recvBufferReader{
ctx: s.ctx, ctx: s.ctx,
recv: s.buf, goAway: s.goAway,
recv: s.buf,
} }
return s return s
} }
@ -271,6 +282,10 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea
t.mu.Unlock() t.mu.Unlock()
return nil, ErrConnClosing return nil, ErrConnClosing
} }
if t.state == draining {
t.mu.Unlock()
return nil, ErrStreamDrain
}
if t.state != reachable { if t.state != reachable {
t.mu.Unlock() t.mu.Unlock()
return nil, ErrConnClosing return nil, ErrConnClosing
@ -278,7 +293,7 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea
checkStreamsQuota := t.streamsQuota != nil checkStreamsQuota := t.streamsQuota != nil
t.mu.Unlock() t.mu.Unlock()
if checkStreamsQuota { if checkStreamsQuota {
sq, err := wait(ctx, t.shutdownChan, t.streamsQuota.acquire()) sq, err := wait(ctx, nil, nil, t.shutdownChan, t.streamsQuota.acquire())
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -287,7 +302,7 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea
t.streamsQuota.add(sq - 1) t.streamsQuota.add(sq - 1)
} }
} }
if _, err := wait(ctx, t.shutdownChan, t.writableChan); err != nil { if _, err := wait(ctx, nil, nil, t.shutdownChan, t.writableChan); err != nil {
// Return the quota back now because there is no stream returned to the caller. // Return the quota back now because there is no stream returned to the caller.
if _, ok := err.(StreamError); ok && checkStreamsQuota { if _, ok := err.(StreamError); ok && checkStreamsQuota {
t.streamsQuota.add(1) t.streamsQuota.add(1)
@ -295,6 +310,15 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea
return nil, err return nil, err
} }
t.mu.Lock() t.mu.Lock()
if t.state == draining {
t.mu.Unlock()
if checkStreamsQuota {
t.streamsQuota.add(1)
}
// Need to make t writable again so that the rpc in flight can still proceed.
t.writableChan <- 0
return nil, ErrStreamDrain
}
if t.state != reachable { if t.state != reachable {
t.mu.Unlock() t.mu.Unlock()
return nil, ErrConnClosing return nil, ErrConnClosing
@ -329,7 +353,7 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea
t.hEnc.WriteField(hpack.HeaderField{Name: "grpc-encoding", Value: callHdr.SendCompress}) t.hEnc.WriteField(hpack.HeaderField{Name: "grpc-encoding", Value: callHdr.SendCompress})
} }
if timeout > 0 { if timeout > 0 {
t.hEnc.WriteField(hpack.HeaderField{Name: "grpc-timeout", Value: timeoutEncode(timeout)}) t.hEnc.WriteField(hpack.HeaderField{Name: "grpc-timeout", Value: encodeTimeout(timeout)})
} }
for k, v := range authData { for k, v := range authData {
// Capital header names are illegal in HTTP/2. // Capital header names are illegal in HTTP/2.
@ -384,7 +408,7 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea
} }
if err != nil { if err != nil {
t.notifyError(err) t.notifyError(err)
return nil, ConnectionErrorf("transport: %v", err) return nil, ConnectionErrorf(true, err, "transport: %v", err)
} }
} }
t.writableChan <- 0 t.writableChan <- 0
@ -403,22 +427,17 @@ func (t *http2Client) CloseStream(s *Stream, err error) {
if t.streamsQuota != nil { if t.streamsQuota != nil {
updateStreams = true updateStreams = true
} }
if t.state == draining && len(t.activeStreams) == 1 { delete(t.activeStreams, s.id)
if t.state == draining && len(t.activeStreams) == 0 {
// The transport is draining and s is the last live stream on t. // The transport is draining and s is the last live stream on t.
t.mu.Unlock() t.mu.Unlock()
t.Close() t.Close()
return return
} }
delete(t.activeStreams, s.id)
t.mu.Unlock() t.mu.Unlock()
if updateStreams { if updateStreams {
t.streamsQuota.add(1) t.streamsQuota.add(1)
} }
// In case stream sending and receiving are invoked in separate
// goroutines (e.g., bi-directional streaming), the caller needs
// to call cancel on the stream to interrupt the blocking on
// other goroutines.
s.cancel()
s.mu.Lock() s.mu.Lock()
if q := s.fc.resetPendingData(); q > 0 { if q := s.fc.resetPendingData(); q > 0 {
if n := t.fc.onRead(q); n > 0 { if n := t.fc.onRead(q); n > 0 {
@ -445,13 +464,13 @@ func (t *http2Client) CloseStream(s *Stream, err error) {
// accessed any more. // accessed any more.
func (t *http2Client) Close() (err error) { func (t *http2Client) Close() (err error) {
t.mu.Lock() t.mu.Lock()
if t.state == reachable {
close(t.errorChan)
}
if t.state == closing { if t.state == closing {
t.mu.Unlock() t.mu.Unlock()
return return
} }
if t.state == reachable || t.state == draining {
close(t.errorChan)
}
t.state = closing t.state = closing
t.mu.Unlock() t.mu.Unlock()
close(t.shutdownChan) close(t.shutdownChan)
@ -475,10 +494,35 @@ func (t *http2Client) Close() (err error) {
func (t *http2Client) GracefulClose() error { func (t *http2Client) GracefulClose() error {
t.mu.Lock() t.mu.Lock()
if t.state == closing { switch t.state {
case unreachable:
// The server may close the connection concurrently. t is not available for
// any streams. Close it now.
t.mu.Unlock()
t.Close()
return nil
case closing:
t.mu.Unlock() t.mu.Unlock()
return nil return nil
} }
// Notify the streams which were initiated after the server sent GOAWAY.
select {
case <-t.goAway:
n := t.prevGoAwayID
if n == 0 && t.nextID > 1 {
n = t.nextID - 2
}
m := t.goAwayID + 2
if m == 2 {
m = 1
}
for i := m; i <= n; i += 2 {
if s, ok := t.activeStreams[i]; ok {
close(s.goAway)
}
}
default:
}
if t.state == draining { if t.state == draining {
t.mu.Unlock() t.mu.Unlock()
return nil return nil
@ -504,15 +548,15 @@ func (t *http2Client) Write(s *Stream, data []byte, opts *Options) error {
size := http2MaxFrameLen size := http2MaxFrameLen
s.sendQuotaPool.add(0) s.sendQuotaPool.add(0)
// Wait until the stream has some quota to send the data. // Wait until the stream has some quota to send the data.
sq, err := wait(s.ctx, t.shutdownChan, s.sendQuotaPool.acquire()) sq, err := wait(s.ctx, s.done, s.goAway, t.shutdownChan, s.sendQuotaPool.acquire())
if err != nil { if err != nil {
return err return err
} }
t.sendQuotaPool.add(0) t.sendQuotaPool.add(0)
// Wait until the transport has some quota to send the data. // Wait until the transport has some quota to send the data.
tq, err := wait(s.ctx, t.shutdownChan, t.sendQuotaPool.acquire()) tq, err := wait(s.ctx, s.done, s.goAway, t.shutdownChan, t.sendQuotaPool.acquire())
if err != nil { if err != nil {
if _, ok := err.(StreamError); ok { if _, ok := err.(StreamError); ok || err == io.EOF {
t.sendQuotaPool.cancel() t.sendQuotaPool.cancel()
} }
return err return err
@ -544,8 +588,8 @@ func (t *http2Client) Write(s *Stream, data []byte, opts *Options) error {
// Indicate there is a writer who is about to write a data frame. // Indicate there is a writer who is about to write a data frame.
t.framer.adjustNumWriters(1) t.framer.adjustNumWriters(1)
// Got some quota. Try to acquire writing privilege on the transport. // Got some quota. Try to acquire writing privilege on the transport.
if _, err := wait(s.ctx, t.shutdownChan, t.writableChan); err != nil { if _, err := wait(s.ctx, s.done, s.goAway, t.shutdownChan, t.writableChan); err != nil {
if _, ok := err.(StreamError); ok { if _, ok := err.(StreamError); ok || err == io.EOF {
// Return the connection quota back. // Return the connection quota back.
t.sendQuotaPool.add(len(p)) t.sendQuotaPool.add(len(p))
} }
@ -578,7 +622,7 @@ func (t *http2Client) Write(s *Stream, data []byte, opts *Options) error {
// invoked. // invoked.
if err := t.framer.writeData(forceFlush, s.id, endStream, p); err != nil { if err := t.framer.writeData(forceFlush, s.id, endStream, p); err != nil {
t.notifyError(err) t.notifyError(err)
return ConnectionErrorf("transport: %v", err) return ConnectionErrorf(true, err, "transport: %v", err)
} }
if t.framer.adjustNumWriters(-1) == 0 { if t.framer.adjustNumWriters(-1) == 0 {
t.framer.flushWrite() t.framer.flushWrite()
@ -593,11 +637,7 @@ func (t *http2Client) Write(s *Stream, data []byte, opts *Options) error {
} }
s.mu.Lock() s.mu.Lock()
if s.state != streamDone { if s.state != streamDone {
if s.state == streamReadDone { s.state = streamWriteDone
s.state = streamDone
} else {
s.state = streamWriteDone
}
} }
s.mu.Unlock() s.mu.Unlock()
return nil return nil
@ -630,7 +670,7 @@ func (t *http2Client) updateWindow(s *Stream, n uint32) {
func (t *http2Client) handleData(f *http2.DataFrame) { func (t *http2Client) handleData(f *http2.DataFrame) {
size := len(f.Data()) size := len(f.Data())
if err := t.fc.onData(uint32(size)); err != nil { if err := t.fc.onData(uint32(size)); err != nil {
t.notifyError(ConnectionErrorf("%v", err)) t.notifyError(ConnectionErrorf(true, err, "%v", err))
return return
} }
// Select the right stream to dispatch. // Select the right stream to dispatch.
@ -655,6 +695,7 @@ func (t *http2Client) handleData(f *http2.DataFrame) {
s.state = streamDone s.state = streamDone
s.statusCode = codes.Internal s.statusCode = codes.Internal
s.statusDesc = err.Error() s.statusDesc = err.Error()
close(s.done)
s.mu.Unlock() s.mu.Unlock()
s.write(recvMsg{err: io.EOF}) s.write(recvMsg{err: io.EOF})
t.controlBuf.put(&resetStream{s.id, http2.ErrCodeFlowControl}) t.controlBuf.put(&resetStream{s.id, http2.ErrCodeFlowControl})
@ -672,13 +713,14 @@ func (t *http2Client) handleData(f *http2.DataFrame) {
// the read direction is closed, and set the status appropriately. // the read direction is closed, and set the status appropriately.
if f.FrameHeader.Flags.Has(http2.FlagDataEndStream) { if f.FrameHeader.Flags.Has(http2.FlagDataEndStream) {
s.mu.Lock() s.mu.Lock()
if s.state == streamWriteDone { if s.state == streamDone {
s.state = streamDone s.mu.Unlock()
} else { return
s.state = streamReadDone
} }
s.state = streamDone
s.statusCode = codes.Internal s.statusCode = codes.Internal
s.statusDesc = "server closed the stream without sending trailers" s.statusDesc = "server closed the stream without sending trailers"
close(s.done)
s.mu.Unlock() s.mu.Unlock()
s.write(recvMsg{err: io.EOF}) s.write(recvMsg{err: io.EOF})
} }
@ -704,6 +746,8 @@ func (t *http2Client) handleRSTStream(f *http2.RSTStreamFrame) {
grpclog.Println("transport: http2Client.handleRSTStream found no mapped gRPC status for the received http2 error ", f.ErrCode) grpclog.Println("transport: http2Client.handleRSTStream found no mapped gRPC status for the received http2 error ", f.ErrCode)
s.statusCode = codes.Unknown s.statusCode = codes.Unknown
} }
s.statusDesc = fmt.Sprintf("stream terminated by RST_STREAM with error code: %d", f.ErrCode)
close(s.done)
s.mu.Unlock() s.mu.Unlock()
s.write(recvMsg{err: io.EOF}) s.write(recvMsg{err: io.EOF})
} }
@ -728,7 +772,32 @@ func (t *http2Client) handlePing(f *http2.PingFrame) {
} }
func (t *http2Client) handleGoAway(f *http2.GoAwayFrame) { func (t *http2Client) handleGoAway(f *http2.GoAwayFrame) {
// TODO(zhaoq): GoAwayFrame handler to be implemented t.mu.Lock()
if t.state == reachable || t.state == draining {
if f.LastStreamID > 0 && f.LastStreamID%2 != 1 {
t.mu.Unlock()
t.notifyError(ConnectionErrorf(true, nil, "received illegal http2 GOAWAY frame: stream ID %d is even", f.LastStreamID))
return
}
select {
case <-t.goAway:
id := t.goAwayID
// t.goAway has been closed (i.e.,multiple GoAways).
if id < f.LastStreamID {
t.mu.Unlock()
t.notifyError(ConnectionErrorf(true, nil, "received illegal http2 GOAWAY frame: previously recv GOAWAY frame with LastStramID %d, currently recv %d", id, f.LastStreamID))
return
}
t.prevGoAwayID = id
t.goAwayID = f.LastStreamID
t.mu.Unlock()
return
default:
}
t.goAwayID = f.LastStreamID
close(t.goAway)
}
t.mu.Unlock()
} }
func (t *http2Client) handleWindowUpdate(f *http2.WindowUpdateFrame) { func (t *http2Client) handleWindowUpdate(f *http2.WindowUpdateFrame) {
@ -780,11 +849,11 @@ func (t *http2Client) operateHeaders(frame *http2.MetaHeadersFrame) {
if len(state.mdata) > 0 { if len(state.mdata) > 0 {
s.trailer = state.mdata s.trailer = state.mdata
} }
s.state = streamDone
s.statusCode = state.statusCode s.statusCode = state.statusCode
s.statusDesc = state.statusDesc s.statusDesc = state.statusDesc
close(s.done)
s.state = streamDone
s.mu.Unlock() s.mu.Unlock()
s.write(recvMsg{err: io.EOF}) s.write(recvMsg{err: io.EOF})
} }
@ -937,13 +1006,22 @@ func (t *http2Client) Error() <-chan struct{} {
return t.errorChan return t.errorChan
} }
func (t *http2Client) GoAway() <-chan struct{} {
return t.goAway
}
func (t *http2Client) notifyError(err error) { func (t *http2Client) notifyError(err error) {
t.mu.Lock() t.mu.Lock()
defer t.mu.Unlock()
// make sure t.errorChan is closed only once. // make sure t.errorChan is closed only once.
if t.state == draining {
t.mu.Unlock()
t.Close()
return
}
if t.state == reachable { if t.state == reachable {
t.state = unreachable t.state = unreachable
close(t.errorChan) close(t.errorChan)
grpclog.Printf("transport: http2Client.notifyError got notified that the client transport was broken %v.", err) grpclog.Printf("transport: http2Client.notifyError got notified that the client transport was broken %v.", err)
} }
t.mu.Unlock()
} }

View file

@ -111,12 +111,12 @@ func newHTTP2Server(conn net.Conn, maxStreams uint32, authInfo credentials.AuthI
Val: uint32(initialWindowSize)}) Val: uint32(initialWindowSize)})
} }
if err := framer.writeSettings(true, settings...); err != nil { if err := framer.writeSettings(true, settings...); err != nil {
return nil, ConnectionErrorf("transport: %v", err) return nil, ConnectionErrorf(true, err, "transport: %v", err)
} }
// Adjust the connection flow control window if needed. // Adjust the connection flow control window if needed.
if delta := uint32(initialConnWindowSize - defaultWindowSize); delta > 0 { if delta := uint32(initialConnWindowSize - defaultWindowSize); delta > 0 {
if err := framer.writeWindowUpdate(true, 0, delta); err != nil { if err := framer.writeWindowUpdate(true, 0, delta); err != nil {
return nil, ConnectionErrorf("transport: %v", err) return nil, ConnectionErrorf(true, err, "transport: %v", err)
} }
} }
var buf bytes.Buffer var buf bytes.Buffer
@ -142,7 +142,7 @@ func newHTTP2Server(conn net.Conn, maxStreams uint32, authInfo credentials.AuthI
} }
// operateHeader takes action on the decoded headers. // operateHeader takes action on the decoded headers.
func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func(*Stream)) { func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func(*Stream)) (close bool) {
buf := newRecvBuffer() buf := newRecvBuffer()
s := &Stream{ s := &Stream{
id: frame.Header().StreamID, id: frame.Header().StreamID,
@ -205,6 +205,13 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func(
t.controlBuf.put(&resetStream{s.id, http2.ErrCodeRefusedStream}) t.controlBuf.put(&resetStream{s.id, http2.ErrCodeRefusedStream})
return return
} }
if s.id%2 != 1 || s.id <= t.maxStreamID {
t.mu.Unlock()
// illegal gRPC stream id.
grpclog.Println("transport: http2Server.HandleStreams received an illegal stream id: ", s.id)
return true
}
t.maxStreamID = s.id
s.sendQuotaPool = newQuotaPool(int(t.streamSendQuota)) s.sendQuotaPool = newQuotaPool(int(t.streamSendQuota))
t.activeStreams[s.id] = s t.activeStreams[s.id] = s
t.mu.Unlock() t.mu.Unlock()
@ -212,6 +219,7 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func(
t.updateWindow(s, uint32(n)) t.updateWindow(s, uint32(n))
} }
handle(s) handle(s)
return
} }
// HandleStreams receives incoming streams using the given handler. This is // HandleStreams receives incoming streams using the given handler. This is
@ -231,6 +239,10 @@ func (t *http2Server) HandleStreams(handle func(*Stream)) {
} }
frame, err := t.framer.readFrame() frame, err := t.framer.readFrame()
if err == io.EOF || err == io.ErrUnexpectedEOF {
t.Close()
return
}
if err != nil { if err != nil {
grpclog.Printf("transport: http2Server.HandleStreams failed to read frame: %v", err) grpclog.Printf("transport: http2Server.HandleStreams failed to read frame: %v", err)
t.Close() t.Close()
@ -257,20 +269,20 @@ func (t *http2Server) HandleStreams(handle func(*Stream)) {
t.controlBuf.put(&resetStream{se.StreamID, se.Code}) t.controlBuf.put(&resetStream{se.StreamID, se.Code})
continue continue
} }
if err == io.EOF || err == io.ErrUnexpectedEOF {
t.Close()
return
}
grpclog.Printf("transport: http2Server.HandleStreams failed to read frame: %v", err)
t.Close() t.Close()
return return
} }
switch frame := frame.(type) { switch frame := frame.(type) {
case *http2.MetaHeadersFrame: case *http2.MetaHeadersFrame:
id := frame.Header().StreamID if t.operateHeaders(frame, handle) {
if id%2 != 1 || id <= t.maxStreamID {
// illegal gRPC stream id.
grpclog.Println("transport: http2Server.HandleStreams received an illegal stream id: ", id)
t.Close() t.Close()
break break
} }
t.maxStreamID = id
t.operateHeaders(frame, handle)
case *http2.DataFrame: case *http2.DataFrame:
t.handleData(frame) t.handleData(frame)
case *http2.RSTStreamFrame: case *http2.RSTStreamFrame:
@ -282,7 +294,7 @@ func (t *http2Server) HandleStreams(handle func(*Stream)) {
case *http2.WindowUpdateFrame: case *http2.WindowUpdateFrame:
t.handleWindowUpdate(frame) t.handleWindowUpdate(frame)
case *http2.GoAwayFrame: case *http2.GoAwayFrame:
break // TODO: Handle GoAway from the client appropriately.
default: default:
grpclog.Printf("transport: http2Server.HandleStreams found unhandled frame type %v.", frame) grpclog.Printf("transport: http2Server.HandleStreams found unhandled frame type %v.", frame)
} }
@ -364,11 +376,7 @@ func (t *http2Server) handleData(f *http2.DataFrame) {
// Received the end of stream from the client. // Received the end of stream from the client.
s.mu.Lock() s.mu.Lock()
if s.state != streamDone { if s.state != streamDone {
if s.state == streamWriteDone { s.state = streamReadDone
s.state = streamDone
} else {
s.state = streamReadDone
}
} }
s.mu.Unlock() s.mu.Unlock()
s.write(recvMsg{err: io.EOF}) s.write(recvMsg{err: io.EOF})
@ -440,7 +448,7 @@ func (t *http2Server) writeHeaders(s *Stream, b *bytes.Buffer, endStream bool) e
} }
if err != nil { if err != nil {
t.Close() t.Close()
return ConnectionErrorf("transport: %v", err) return ConnectionErrorf(true, err, "transport: %v", err)
} }
} }
return nil return nil
@ -455,7 +463,7 @@ func (t *http2Server) WriteHeader(s *Stream, md metadata.MD) error {
} }
s.headerOk = true s.headerOk = true
s.mu.Unlock() s.mu.Unlock()
if _, err := wait(s.ctx, t.shutdownChan, t.writableChan); err != nil { if _, err := wait(s.ctx, nil, nil, t.shutdownChan, t.writableChan); err != nil {
return err return err
} }
t.hBuf.Reset() t.hBuf.Reset()
@ -495,7 +503,7 @@ func (t *http2Server) WriteStatus(s *Stream, statusCode codes.Code, statusDesc s
headersSent = true headersSent = true
} }
s.mu.Unlock() s.mu.Unlock()
if _, err := wait(s.ctx, t.shutdownChan, t.writableChan); err != nil { if _, err := wait(s.ctx, nil, nil, t.shutdownChan, t.writableChan); err != nil {
return err return err
} }
t.hBuf.Reset() t.hBuf.Reset()
@ -508,7 +516,7 @@ func (t *http2Server) WriteStatus(s *Stream, statusCode codes.Code, statusDesc s
Name: "grpc-status", Name: "grpc-status",
Value: strconv.Itoa(int(statusCode)), Value: strconv.Itoa(int(statusCode)),
}) })
t.hEnc.WriteField(hpack.HeaderField{Name: "grpc-message", Value: statusDesc}) t.hEnc.WriteField(hpack.HeaderField{Name: "grpc-message", Value: encodeGrpcMessage(statusDesc)})
// Attach the trailer metadata. // Attach the trailer metadata.
for k, v := range s.trailer { for k, v := range s.trailer {
// Clients don't tolerate reading restricted headers after some non restricted ones were sent. // Clients don't tolerate reading restricted headers after some non restricted ones were sent.
@ -544,7 +552,7 @@ func (t *http2Server) Write(s *Stream, data []byte, opts *Options) error {
} }
s.mu.Unlock() s.mu.Unlock()
if writeHeaderFrame { if writeHeaderFrame {
if _, err := wait(s.ctx, t.shutdownChan, t.writableChan); err != nil { if _, err := wait(s.ctx, nil, nil, t.shutdownChan, t.writableChan); err != nil {
return err return err
} }
t.hBuf.Reset() t.hBuf.Reset()
@ -560,7 +568,7 @@ func (t *http2Server) Write(s *Stream, data []byte, opts *Options) error {
} }
if err := t.framer.writeHeaders(false, p); err != nil { if err := t.framer.writeHeaders(false, p); err != nil {
t.Close() t.Close()
return ConnectionErrorf("transport: %v", err) return ConnectionErrorf(true, err, "transport: %v", err)
} }
t.writableChan <- 0 t.writableChan <- 0
} }
@ -572,13 +580,13 @@ func (t *http2Server) Write(s *Stream, data []byte, opts *Options) error {
size := http2MaxFrameLen size := http2MaxFrameLen
s.sendQuotaPool.add(0) s.sendQuotaPool.add(0)
// Wait until the stream has some quota to send the data. // Wait until the stream has some quota to send the data.
sq, err := wait(s.ctx, t.shutdownChan, s.sendQuotaPool.acquire()) sq, err := wait(s.ctx, nil, nil, t.shutdownChan, s.sendQuotaPool.acquire())
if err != nil { if err != nil {
return err return err
} }
t.sendQuotaPool.add(0) t.sendQuotaPool.add(0)
// Wait until the transport has some quota to send the data. // Wait until the transport has some quota to send the data.
tq, err := wait(s.ctx, t.shutdownChan, t.sendQuotaPool.acquire()) tq, err := wait(s.ctx, nil, nil, t.shutdownChan, t.sendQuotaPool.acquire())
if err != nil { if err != nil {
if _, ok := err.(StreamError); ok { if _, ok := err.(StreamError); ok {
t.sendQuotaPool.cancel() t.sendQuotaPool.cancel()
@ -604,7 +612,7 @@ func (t *http2Server) Write(s *Stream, data []byte, opts *Options) error {
t.framer.adjustNumWriters(1) t.framer.adjustNumWriters(1)
// Got some quota. Try to acquire writing privilege on the // Got some quota. Try to acquire writing privilege on the
// transport. // transport.
if _, err := wait(s.ctx, t.shutdownChan, t.writableChan); err != nil { if _, err := wait(s.ctx, nil, nil, t.shutdownChan, t.writableChan); err != nil {
if _, ok := err.(StreamError); ok { if _, ok := err.(StreamError); ok {
// Return the connection quota back. // Return the connection quota back.
t.sendQuotaPool.add(ps) t.sendQuotaPool.add(ps)
@ -634,7 +642,7 @@ func (t *http2Server) Write(s *Stream, data []byte, opts *Options) error {
} }
if err := t.framer.writeData(forceFlush, s.id, false, p); err != nil { if err := t.framer.writeData(forceFlush, s.id, false, p); err != nil {
t.Close() t.Close()
return ConnectionErrorf("transport: %v", err) return ConnectionErrorf(true, err, "transport: %v", err)
} }
if t.framer.adjustNumWriters(-1) == 0 { if t.framer.adjustNumWriters(-1) == 0 {
t.framer.flushWrite() t.framer.flushWrite()
@ -679,6 +687,17 @@ func (t *http2Server) controller() {
} }
case *resetStream: case *resetStream:
t.framer.writeRSTStream(true, i.streamID, i.code) t.framer.writeRSTStream(true, i.streamID, i.code)
case *goAway:
t.mu.Lock()
if t.state == closing {
t.mu.Unlock()
// The transport is closing.
return
}
sid := t.maxStreamID
t.state = draining
t.mu.Unlock()
t.framer.writeGoAway(true, sid, http2.ErrCodeNo, nil)
case *flushIO: case *flushIO:
t.framer.flushWrite() t.framer.flushWrite()
case *ping: case *ping:
@ -724,6 +743,9 @@ func (t *http2Server) Close() (err error) {
func (t *http2Server) closeStream(s *Stream) { func (t *http2Server) closeStream(s *Stream) {
t.mu.Lock() t.mu.Lock()
delete(t.activeStreams, s.id) delete(t.activeStreams, s.id)
if t.state == draining && len(t.activeStreams) == 0 {
defer t.Close()
}
t.mu.Unlock() t.mu.Unlock()
// In case stream sending and receiving are invoked in separate // In case stream sending and receiving are invoked in separate
// goroutines (e.g., bi-directional streaming), cancel needs to be // goroutines (e.g., bi-directional streaming), cancel needs to be
@ -746,3 +768,7 @@ func (t *http2Server) closeStream(s *Stream) {
func (t *http2Server) RemoteAddr() net.Addr { func (t *http2Server) RemoteAddr() net.Addr {
return t.conn.RemoteAddr() return t.conn.RemoteAddr()
} }
func (t *http2Server) Drain() {
t.controlBuf.put(&goAway{})
}

View file

@ -35,6 +35,7 @@ package transport
import ( import (
"bufio" "bufio"
"bytes"
"fmt" "fmt"
"io" "io"
"net" "net"
@ -174,11 +175,11 @@ func (d *decodeState) processHeaderField(f hpack.HeaderField) {
} }
d.statusCode = codes.Code(code) d.statusCode = codes.Code(code)
case "grpc-message": case "grpc-message":
d.statusDesc = f.Value d.statusDesc = decodeGrpcMessage(f.Value)
case "grpc-timeout": case "grpc-timeout":
d.timeoutSet = true d.timeoutSet = true
var err error var err error
d.timeout, err = timeoutDecode(f.Value) d.timeout, err = decodeTimeout(f.Value)
if err != nil { if err != nil {
d.setErr(StreamErrorf(codes.Internal, "transport: malformed time-out: %v", err)) d.setErr(StreamErrorf(codes.Internal, "transport: malformed time-out: %v", err))
return return
@ -251,7 +252,7 @@ func div(d, r time.Duration) int64 {
} }
// TODO(zhaoq): It is the simplistic and not bandwidth efficient. Improve it. // TODO(zhaoq): It is the simplistic and not bandwidth efficient. Improve it.
func timeoutEncode(t time.Duration) string { func encodeTimeout(t time.Duration) string {
if d := div(t, time.Nanosecond); d <= maxTimeoutValue { if d := div(t, time.Nanosecond); d <= maxTimeoutValue {
return strconv.FormatInt(d, 10) + "n" return strconv.FormatInt(d, 10) + "n"
} }
@ -271,7 +272,7 @@ func timeoutEncode(t time.Duration) string {
return strconv.FormatInt(div(t, time.Hour), 10) + "H" return strconv.FormatInt(div(t, time.Hour), 10) + "H"
} }
func timeoutDecode(s string) (time.Duration, error) { func decodeTimeout(s string) (time.Duration, error) {
size := len(s) size := len(s)
if size < 2 { if size < 2 {
return 0, fmt.Errorf("transport: timeout string is too short: %q", s) return 0, fmt.Errorf("transport: timeout string is too short: %q", s)
@ -288,6 +289,80 @@ func timeoutDecode(s string) (time.Duration, error) {
return d * time.Duration(t), nil return d * time.Duration(t), nil
} }
const (
spaceByte = ' '
tildaByte = '~'
percentByte = '%'
)
// encodeGrpcMessage is used to encode status code in header field
// "grpc-message".
// It checks to see if each individual byte in msg is an
// allowable byte, and then either percent encoding or passing it through.
// When percent encoding, the byte is converted into hexadecimal notation
// with a '%' prepended.
func encodeGrpcMessage(msg string) string {
if msg == "" {
return ""
}
lenMsg := len(msg)
for i := 0; i < lenMsg; i++ {
c := msg[i]
if !(c >= spaceByte && c < tildaByte && c != percentByte) {
return encodeGrpcMessageUnchecked(msg)
}
}
return msg
}
func encodeGrpcMessageUnchecked(msg string) string {
var buf bytes.Buffer
lenMsg := len(msg)
for i := 0; i < lenMsg; i++ {
c := msg[i]
if c >= spaceByte && c < tildaByte && c != percentByte {
buf.WriteByte(c)
} else {
buf.WriteString(fmt.Sprintf("%%%02X", c))
}
}
return buf.String()
}
// decodeGrpcMessage decodes the msg encoded by encodeGrpcMessage.
func decodeGrpcMessage(msg string) string {
if msg == "" {
return ""
}
lenMsg := len(msg)
for i := 0; i < lenMsg; i++ {
if msg[i] == percentByte && i+2 < lenMsg {
return decodeGrpcMessageUnchecked(msg)
}
}
return msg
}
func decodeGrpcMessageUnchecked(msg string) string {
var buf bytes.Buffer
lenMsg := len(msg)
for i := 0; i < lenMsg; i++ {
c := msg[i]
if c == percentByte && i+2 < lenMsg {
parsed, err := strconv.ParseInt(msg[i+1:i+3], 16, 8)
if err != nil {
buf.WriteByte(c)
} else {
buf.WriteByte(byte(parsed))
i += 2
}
} else {
buf.WriteByte(c)
}
}
return buf.String()
}
type framer struct { type framer struct {
numWriters int32 numWriters int32
reader io.Reader reader io.Reader

51
vendor/google.golang.org/grpc/transport/pre_go16.go generated vendored Normal file
View file

@ -0,0 +1,51 @@
// +build !go1.6
/*
* Copyright 2016, Google Inc.
* All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are
* met:
*
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above
* copyright notice, this list of conditions and the following disclaimer
* in the documentation and/or other materials provided with the
* distribution.
* * Neither the name of Google Inc. nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
* "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
* LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
* A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
* OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
* SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
* LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
* DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
* THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
*/
package transport
import (
"net"
"time"
"golang.org/x/net/context"
)
// dialContext connects to the address on the named network.
func dialContext(ctx context.Context, network, address string) (net.Conn, error) {
var dialer net.Dialer
if deadline, ok := ctx.Deadline(); ok {
dialer.Timeout = deadline.Sub(time.Now())
}
return dialer.Dial(network, address)
}

View file

@ -44,7 +44,6 @@ import (
"io" "io"
"net" "net"
"sync" "sync"
"time"
"golang.org/x/net/context" "golang.org/x/net/context"
"golang.org/x/net/trace" "golang.org/x/net/trace"
@ -120,10 +119,11 @@ func (b *recvBuffer) get() <-chan item {
// recvBufferReader implements io.Reader interface to read the data from // recvBufferReader implements io.Reader interface to read the data from
// recvBuffer. // recvBuffer.
type recvBufferReader struct { type recvBufferReader struct {
ctx context.Context ctx context.Context
recv *recvBuffer goAway chan struct{}
last *bytes.Reader // Stores the remaining data in the previous calls. recv *recvBuffer
err error last *bytes.Reader // Stores the remaining data in the previous calls.
err error
} }
// Read reads the next len(p) bytes from last. If last is drained, it tries to // Read reads the next len(p) bytes from last. If last is drained, it tries to
@ -141,6 +141,8 @@ func (r *recvBufferReader) Read(p []byte) (n int, err error) {
select { select {
case <-r.ctx.Done(): case <-r.ctx.Done():
return 0, ContextErr(r.ctx.Err()) return 0, ContextErr(r.ctx.Err())
case <-r.goAway:
return 0, ErrStreamDrain
case i := <-r.recv.get(): case i := <-r.recv.get():
r.recv.load() r.recv.load()
m := i.(*recvMsg) m := i.(*recvMsg)
@ -158,7 +160,7 @@ const (
streamActive streamState = iota streamActive streamState = iota
streamWriteDone // EndStream sent streamWriteDone // EndStream sent
streamReadDone // EndStream received streamReadDone // EndStream received
streamDone // sendDone and recvDone or RSTStreamFrame is sent or received. streamDone // the entire stream is finished.
) )
// Stream represents an RPC in the transport layer. // Stream represents an RPC in the transport layer.
@ -169,6 +171,10 @@ type Stream struct {
// ctx is the associated context of the stream. // ctx is the associated context of the stream.
ctx context.Context ctx context.Context
cancel context.CancelFunc cancel context.CancelFunc
// done is closed when the final status arrives.
done chan struct{}
// goAway is closed when the server sent GoAways signal before this stream was initiated.
goAway chan struct{}
// method records the associated RPC method of the stream. // method records the associated RPC method of the stream.
method string method string
recvCompress string recvCompress string
@ -214,6 +220,18 @@ func (s *Stream) SetSendCompress(str string) {
s.sendCompress = str s.sendCompress = str
} }
// Done returns a chanel which is closed when it receives the final status
// from the server.
func (s *Stream) Done() <-chan struct{} {
return s.done
}
// GoAway returns a channel which is closed when the server sent GoAways signal
// before this stream was initiated.
func (s *Stream) GoAway() <-chan struct{} {
return s.goAway
}
// Header acquires the key-value pairs of header metadata once it // Header acquires the key-value pairs of header metadata once it
// is available. It blocks until i) the metadata is ready or ii) there is no // is available. It blocks until i) the metadata is ready or ii) there is no
// header metadata or iii) the stream is cancelled/expired. // header metadata or iii) the stream is cancelled/expired.
@ -221,6 +239,8 @@ func (s *Stream) Header() (metadata.MD, error) {
select { select {
case <-s.ctx.Done(): case <-s.ctx.Done():
return nil, ContextErr(s.ctx.Err()) return nil, ContextErr(s.ctx.Err())
case <-s.goAway:
return nil, ErrStreamDrain
case <-s.headerChan: case <-s.headerChan:
return s.header.Copy(), nil return s.header.Copy(), nil
} }
@ -335,19 +355,17 @@ type ConnectOptions struct {
// UserAgent is the application user agent. // UserAgent is the application user agent.
UserAgent string UserAgent string
// Dialer specifies how to dial a network address. // Dialer specifies how to dial a network address.
Dialer func(string, time.Duration) (net.Conn, error) Dialer func(context.Context, string) (net.Conn, error)
// PerRPCCredentials stores the PerRPCCredentials required to issue RPCs. // PerRPCCredentials stores the PerRPCCredentials required to issue RPCs.
PerRPCCredentials []credentials.PerRPCCredentials PerRPCCredentials []credentials.PerRPCCredentials
// TransportCredentials stores the Authenticator required to setup a client connection. // TransportCredentials stores the Authenticator required to setup a client connection.
TransportCredentials credentials.TransportCredentials TransportCredentials credentials.TransportCredentials
// Timeout specifies the timeout for dialing a ClientTransport.
Timeout time.Duration
} }
// NewClientTransport establishes the transport with the required ConnectOptions // NewClientTransport establishes the transport with the required ConnectOptions
// and returns it to the caller. // and returns it to the caller.
func NewClientTransport(target string, opts *ConnectOptions) (ClientTransport, error) { func NewClientTransport(ctx context.Context, target string, opts ConnectOptions) (ClientTransport, error) {
return newHTTP2Client(target, opts) return newHTTP2Client(ctx, target, opts)
} }
// Options provides additional hints and information for message // Options provides additional hints and information for message
@ -417,6 +435,11 @@ type ClientTransport interface {
// and create a new one) in error case. It should not return nil // and create a new one) in error case. It should not return nil
// once the transport is initiated. // once the transport is initiated.
Error() <-chan struct{} Error() <-chan struct{}
// GoAway returns a channel that is closed when ClientTranspor
// receives the draining signal from the server (e.g., GOAWAY frame in
// HTTP/2).
GoAway() <-chan struct{}
} }
// ServerTransport is the common interface for all gRPC server-side transport // ServerTransport is the common interface for all gRPC server-side transport
@ -448,6 +471,9 @@ type ServerTransport interface {
// RemoteAddr returns the remote network address. // RemoteAddr returns the remote network address.
RemoteAddr() net.Addr RemoteAddr() net.Addr
// Drain notifies the client this ServerTransport stops accepting new RPCs.
Drain()
} }
// StreamErrorf creates an StreamError with the specified error code and description. // StreamErrorf creates an StreamError with the specified error code and description.
@ -459,9 +485,11 @@ func StreamErrorf(c codes.Code, format string, a ...interface{}) StreamError {
} }
// ConnectionErrorf creates an ConnectionError with the specified error description. // ConnectionErrorf creates an ConnectionError with the specified error description.
func ConnectionErrorf(format string, a ...interface{}) ConnectionError { func ConnectionErrorf(temp bool, e error, format string, a ...interface{}) ConnectionError {
return ConnectionError{ return ConnectionError{
Desc: fmt.Sprintf(format, a...), Desc: fmt.Sprintf(format, a...),
temp: temp,
err: e,
} }
} }
@ -469,14 +497,36 @@ func ConnectionErrorf(format string, a ...interface{}) ConnectionError {
// entire connection and the retry of all the active streams. // entire connection and the retry of all the active streams.
type ConnectionError struct { type ConnectionError struct {
Desc string Desc string
temp bool
err error
} }
func (e ConnectionError) Error() string { func (e ConnectionError) Error() string {
return fmt.Sprintf("connection error: desc = %q", e.Desc) return fmt.Sprintf("connection error: desc = %q", e.Desc)
} }
// ErrConnClosing indicates that the transport is closing. // Temporary indicates if this connection error is temporary or fatal.
var ErrConnClosing = ConnectionError{Desc: "transport is closing"} func (e ConnectionError) Temporary() bool {
return e.temp
}
// Origin returns the original error of this connection error.
func (e ConnectionError) Origin() error {
// Never return nil error here.
// If the original error is nil, return itself.
if e.err == nil {
return e
}
return e.err
}
var (
// ErrConnClosing indicates that the transport is closing.
ErrConnClosing = ConnectionError{Desc: "transport is closing", temp: true}
// ErrStreamDrain indicates that the stream is rejected by the server because
// the server stops accepting new RPCs.
ErrStreamDrain = StreamErrorf(codes.Unavailable, "the server stops accepting new RPCs")
)
// StreamError is an error that only affects one stream within a connection. // StreamError is an error that only affects one stream within a connection.
type StreamError struct { type StreamError struct {
@ -501,12 +551,25 @@ func ContextErr(err error) StreamError {
// wait blocks until it can receive from ctx.Done, closing, or proceed. // wait blocks until it can receive from ctx.Done, closing, or proceed.
// If it receives from ctx.Done, it returns 0, the StreamError for ctx.Err. // If it receives from ctx.Done, it returns 0, the StreamError for ctx.Err.
// If it receives from done, it returns 0, io.EOF if ctx is not done; otherwise
// it return the StreamError for ctx.Err.
// If it receives from goAway, it returns 0, ErrStreamDrain.
// If it receives from closing, it returns 0, ErrConnClosing. // If it receives from closing, it returns 0, ErrConnClosing.
// If it receives from proceed, it returns the received integer, nil. // If it receives from proceed, it returns the received integer, nil.
func wait(ctx context.Context, closing <-chan struct{}, proceed <-chan int) (int, error) { func wait(ctx context.Context, done, goAway, closing <-chan struct{}, proceed <-chan int) (int, error) {
select { select {
case <-ctx.Done(): case <-ctx.Done():
return 0, ContextErr(ctx.Err()) return 0, ContextErr(ctx.Err())
case <-done:
// User cancellation has precedence.
select {
case <-ctx.Done():
return 0, ContextErr(ctx.Err())
default:
}
return 0, io.EOF
case <-goAway:
return 0, ErrStreamDrain
case <-closing: case <-closing:
return 0, ErrConnClosing return 0, ErrConnClosing
case i := <-proceed: case i := <-proceed:

1
vendor/k8s.io/client-go generated vendored
View file

@ -1 +0,0 @@
kubernetes/staging/src/k8s.io/client-go

View file

@ -325,13 +325,13 @@ func NewGenericServerResponse(code int, verb string, qualifiedResource unversion
default: default:
if code >= 500 { if code >= 500 {
reason = unversioned.StatusReasonInternalError reason = unversioned.StatusReasonInternalError
message = "an error on the server has prevented the request from succeeding" message = fmt.Sprintf("an error on the server (%q) has prevented the request from succeeding", serverMessage)
} }
} }
switch { switch {
case !qualifiedResource.IsEmpty() && len(name) > 0: case !qualifiedResource.Empty() && len(name) > 0:
message = fmt.Sprintf("%s (%s %s %s)", message, strings.ToLower(verb), qualifiedResource.String(), name) message = fmt.Sprintf("%s (%s %s %s)", message, strings.ToLower(verb), qualifiedResource.String(), name)
case !qualifiedResource.IsEmpty(): case !qualifiedResource.Empty():
message = fmt.Sprintf("%s (%s %s)", message, strings.ToLower(verb), qualifiedResource.String()) message = fmt.Sprintf("%s (%s %s)", message, strings.ToLower(verb), qualifiedResource.String())
} }
var causes []unversioned.StatusCause var causes []unversioned.StatusCause

View file

@ -30,6 +30,7 @@ import (
"k8s.io/client-go/1.4/pkg/fields" "k8s.io/client-go/1.4/pkg/fields"
"k8s.io/client-go/1.4/pkg/labels" "k8s.io/client-go/1.4/pkg/labels"
"k8s.io/client-go/1.4/pkg/runtime" "k8s.io/client-go/1.4/pkg/runtime"
"k8s.io/client-go/1.4/pkg/selection"
"k8s.io/client-go/1.4/pkg/types" "k8s.io/client-go/1.4/pkg/types"
"k8s.io/client-go/1.4/pkg/util/sets" "k8s.io/client-go/1.4/pkg/util/sets"
@ -222,6 +223,10 @@ func IsServiceIPSet(service *Service) bool {
// this function aims to check if the service's cluster IP is requested or not // this function aims to check if the service's cluster IP is requested or not
func IsServiceIPRequested(service *Service) bool { func IsServiceIPRequested(service *Service) bool {
// ExternalName services are CNAME aliases to external ones. Ignore the IP.
if service.Spec.Type == ServiceTypeExternalName {
return false
}
return service.Spec.ClusterIP == "" return service.Spec.ClusterIP == ""
} }
@ -379,20 +384,20 @@ func NodeSelectorRequirementsAsSelector(nsm []NodeSelectorRequirement) (labels.S
} }
selector := labels.NewSelector() selector := labels.NewSelector()
for _, expr := range nsm { for _, expr := range nsm {
var op labels.Operator var op selection.Operator
switch expr.Operator { switch expr.Operator {
case NodeSelectorOpIn: case NodeSelectorOpIn:
op = labels.InOperator op = selection.In
case NodeSelectorOpNotIn: case NodeSelectorOpNotIn:
op = labels.NotInOperator op = selection.NotIn
case NodeSelectorOpExists: case NodeSelectorOpExists:
op = labels.ExistsOperator op = selection.Exists
case NodeSelectorOpDoesNotExist: case NodeSelectorOpDoesNotExist:
op = labels.DoesNotExistOperator op = selection.DoesNotExist
case NodeSelectorOpGt: case NodeSelectorOpGt:
op = labels.GreaterThanOperator op = selection.GreaterThan
case NodeSelectorOpLt: case NodeSelectorOpLt:
op = labels.LessThanOperator op = selection.LessThan
default: default:
return nil, fmt.Errorf("%q is not a valid node selector operator", expr.Operator) return nil, fmt.Errorf("%q is not a valid node selector operator", expr.Operator)
} }
@ -433,6 +438,20 @@ const (
// PreferAvoidPodsAnnotationKey represents the key of preferAvoidPods data (json serialized) // PreferAvoidPodsAnnotationKey represents the key of preferAvoidPods data (json serialized)
// in the Annotations of a Node. // in the Annotations of a Node.
PreferAvoidPodsAnnotationKey string = "scheduler.alpha.kubernetes.io/preferAvoidPods" PreferAvoidPodsAnnotationKey string = "scheduler.alpha.kubernetes.io/preferAvoidPods"
// SysctlsPodAnnotationKey represents the key of sysctls which are set for the infrastructure
// container of a pod. The annotation value is a comma separated list of sysctl_name=value
// key-value pairs. Only a limited set of whitelisted and isolated sysctls is supported by
// the kubelet. Pods with other sysctls will fail to launch.
SysctlsPodAnnotationKey string = "security.alpha.kubernetes.io/sysctls"
// UnsafeSysctlsPodAnnotationKey represents the key of sysctls which are set for the infrastructure
// container of a pod. The annotation value is a comma separated list of sysctl_name=value
// key-value pairs. Unsafe sysctls must be explicitly enabled for a kubelet. They are properly
// namespaced to a pod or a container, but their isolation is usually unclear or weak. Their use
// is at-your-own-risk. Pods that attempt to set an unsafe sysctl that is not enabled for a kubelet
// will fail to launch.
UnsafeSysctlsPodAnnotationKey string = "security.alpha.kubernetes.io/unsafe-sysctls"
) )
// GetAffinityFromPod gets the json serialized affinity data from Pod.Annotations // GetAffinityFromPod gets the json serialized affinity data from Pod.Annotations
@ -517,3 +536,51 @@ func GetAvoidPodsFromNodeAnnotations(annotations map[string]string) (AvoidPods,
} }
return avoidPods, nil return avoidPods, nil
} }
// SysctlsFromPodAnnotations parses the sysctl annotations into a slice of safe Sysctls
// and a slice of unsafe Sysctls. This is only a convenience wrapper around
// SysctlsFromPodAnnotation.
func SysctlsFromPodAnnotations(a map[string]string) ([]Sysctl, []Sysctl, error) {
safe, err := SysctlsFromPodAnnotation(a[SysctlsPodAnnotationKey])
if err != nil {
return nil, nil, err
}
unsafe, err := SysctlsFromPodAnnotation(a[UnsafeSysctlsPodAnnotationKey])
if err != nil {
return nil, nil, err
}
return safe, unsafe, nil
}
// SysctlsFromPodAnnotation parses an annotation value into a slice of Sysctls.
func SysctlsFromPodAnnotation(annotation string) ([]Sysctl, error) {
if len(annotation) == 0 {
return nil, nil
}
kvs := strings.Split(annotation, ",")
sysctls := make([]Sysctl, len(kvs))
for i, kv := range kvs {
cs := strings.Split(kv, "=")
if len(cs) != 2 {
return nil, fmt.Errorf("sysctl %q not of the format sysctl_name=value", kv)
}
sysctls[i].Name = cs[0]
sysctls[i].Value = cs[1]
}
return sysctls, nil
}
// PodAnnotationsFromSysctls creates an annotation value for a slice of Sysctls.
func PodAnnotationsFromSysctls(sysctls []Sysctl) string {
if len(sysctls) == 0 {
return ""
}
kvs := make([]string, len(sysctls))
for i := range sysctls {
kvs[i] = fmt.Sprintf("%s=%s", sysctls[i].Name, sysctls[i].Value)
}
return strings.Join(kvs, ",")
}

View file

@ -21,6 +21,7 @@ import (
"k8s.io/client-go/1.4/pkg/api/meta" "k8s.io/client-go/1.4/pkg/api/meta"
"k8s.io/client-go/1.4/pkg/api/unversioned" "k8s.io/client-go/1.4/pkg/api/unversioned"
"k8s.io/client-go/1.4/pkg/runtime"
"k8s.io/client-go/1.4/pkg/util/sets" "k8s.io/client-go/1.4/pkg/util/sets"
) )
@ -34,14 +35,21 @@ func RegisterRESTMapper(m meta.RESTMapper) {
RESTMapper = append(RESTMapper.(meta.MultiRESTMapper), m) RESTMapper = append(RESTMapper.(meta.MultiRESTMapper), m)
} }
// Instantiates a DefaultRESTMapper based on types registered in api.Scheme
func NewDefaultRESTMapper(defaultGroupVersions []unversioned.GroupVersion, interfacesFunc meta.VersionInterfacesFunc, func NewDefaultRESTMapper(defaultGroupVersions []unversioned.GroupVersion, interfacesFunc meta.VersionInterfacesFunc,
importPathPrefix string, ignoredKinds, rootScoped sets.String) *meta.DefaultRESTMapper { importPathPrefix string, ignoredKinds, rootScoped sets.String) *meta.DefaultRESTMapper {
return NewDefaultRESTMapperFromScheme(defaultGroupVersions, interfacesFunc, importPathPrefix, ignoredKinds, rootScoped, Scheme)
}
// Instantiates a DefaultRESTMapper based on types registered in the given scheme.
func NewDefaultRESTMapperFromScheme(defaultGroupVersions []unversioned.GroupVersion, interfacesFunc meta.VersionInterfacesFunc,
importPathPrefix string, ignoredKinds, rootScoped sets.String, scheme *runtime.Scheme) *meta.DefaultRESTMapper {
mapper := meta.NewDefaultRESTMapper(defaultGroupVersions, interfacesFunc) mapper := meta.NewDefaultRESTMapper(defaultGroupVersions, interfacesFunc)
// enumerate all supported versions, get the kinds, and register with the mapper how to address // enumerate all supported versions, get the kinds, and register with the mapper how to address
// our resources. // our resources.
for _, gv := range defaultGroupVersions { for _, gv := range defaultGroupVersions {
for kind, oType := range Scheme.KnownTypes(gv) { for kind, oType := range scheme.KnownTypes(gv) {
gvk := gv.WithKind(kind) gvk := gv.WithKind(kind)
// TODO: Remove import path check. // TODO: Remove import path check.
// We check the import path because we currently stuff both "api" and "extensions" objects // We check the import path because we currently stuff both "api" and "extensions" objects

View file

@ -131,3 +131,10 @@ func (meta *ObjectMeta) SetOwnerReferences(references []metatypes.OwnerReference
} }
meta.OwnerReferences = newReferences meta.OwnerReferences = newReferences
} }
func (meta *ObjectMeta) GetClusterName() string {
return meta.ClusterName
}
func (meta *ObjectMeta) SetClusterName(clusterName string) {
meta.ClusterName = clusterName
}

View file

@ -62,6 +62,8 @@ type Object interface {
SetFinalizers(finalizers []string) SetFinalizers(finalizers []string)
GetOwnerReferences() []metatypes.OwnerReference GetOwnerReferences() []metatypes.OwnerReference
SetOwnerReferences([]metatypes.OwnerReference) SetOwnerReferences([]metatypes.OwnerReference)
GetClusterName() string
SetClusterName(clusterName string)
} }
var _ Object = &runtime.Unstructured{} var _ Object = &runtime.Unstructured{}
@ -161,16 +163,16 @@ type RESTMapping struct {
// TODO(caesarxuchao): Add proper multi-group support so that kinds & resources are // TODO(caesarxuchao): Add proper multi-group support so that kinds & resources are
// scoped to groups. See http://issues.k8s.io/12413 and http://issues.k8s.io/10009. // scoped to groups. See http://issues.k8s.io/12413 and http://issues.k8s.io/10009.
type RESTMapper interface { type RESTMapper interface {
// KindFor takes a partial resource and returns back the single match. Returns an error if there are multiple matches // KindFor takes a partial resource and returns the single match. Returns an error if there are multiple matches
KindFor(resource unversioned.GroupVersionResource) (unversioned.GroupVersionKind, error) KindFor(resource unversioned.GroupVersionResource) (unversioned.GroupVersionKind, error)
// KindsFor takes a partial resource and returns back the list of potential kinds in priority order // KindsFor takes a partial resource and returns the list of potential kinds in priority order
KindsFor(resource unversioned.GroupVersionResource) ([]unversioned.GroupVersionKind, error) KindsFor(resource unversioned.GroupVersionResource) ([]unversioned.GroupVersionKind, error)
// ResourceFor takes a partial resource and returns back the single match. Returns an error if there are multiple matches // ResourceFor takes a partial resource and returns the single match. Returns an error if there are multiple matches
ResourceFor(input unversioned.GroupVersionResource) (unversioned.GroupVersionResource, error) ResourceFor(input unversioned.GroupVersionResource) (unversioned.GroupVersionResource, error)
// ResourcesFor takes a partial resource and returns back the list of potential resource in priority order // ResourcesFor takes a partial resource and returns the list of potential resource in priority order
ResourcesFor(input unversioned.GroupVersionResource) ([]unversioned.GroupVersionResource, error) ResourcesFor(input unversioned.GroupVersionResource) ([]unversioned.GroupVersionResource, error)
// RESTMapping identifies a preferred resource mapping for the provided group kind. // RESTMapping identifies a preferred resource mapping for the provided group kind.

View file

@ -183,17 +183,17 @@ func (m *DefaultRESTMapper) ResourceSingularizer(resourceType string) (string, e
if !ok { if !ok {
continue continue
} }
if singular.IsEmpty() { if singular.Empty() {
singular = currSingular singular = currSingular
continue continue
} }
if currSingular.Resource != singular.Resource { if currSingular.Resource != singular.Resource {
return resourceType, fmt.Errorf("multiple possibile singular resources (%v) found for %v", resources, resourceType) return resourceType, fmt.Errorf("multiple possible singular resources (%v) found for %v", resources, resourceType)
} }
} }
if singular.IsEmpty() { if singular.Empty() {
return resourceType, fmt.Errorf("no singular of resource %v has been defined", resourceType) return resourceType, fmt.Errorf("no singular of resource %v has been defined", resourceType)
} }

View file

@ -0,0 +1,31 @@
/*
Copyright 2016 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package meta
import (
"k8s.io/client-go/1.4/pkg/api/unversioned"
"k8s.io/client-go/1.4/pkg/runtime"
)
// InterfacesForUnstructured returns VersionInterfaces suitable for
// dealing with runtime.Unstructured objects.
func InterfacesForUnstructured(unversioned.GroupVersion) (*VersionInterfaces, error) {
return &VersionInterfaces{
ObjectConvertor: &runtime.UnstructuredObjectConverter{},
MetadataAccessor: NewAccessor(),
}, nil
}

View file

@ -30,7 +30,7 @@ const (
// TODO: to be de!eted after v1.3 is released. PodSpec has a dedicated Subdomain field. // TODO: to be de!eted after v1.3 is released. PodSpec has a dedicated Subdomain field.
// The annotation value is a string specifying the subdomain e.g. "my-web-service" // The annotation value is a string specifying the subdomain e.g. "my-web-service"
// If specified, on the the pod itself, "<hostname>.my-web-service.<namespace>.svc.<cluster domain>" would resolve to // If specified, on the pod itself, "<hostname>.my-web-service.<namespace>.svc.<cluster domain>" would resolve to
// the pod's IP. // the pod's IP.
// If there is a headless service named "my-web-service" in the same namespace as the pod, then, // If there is a headless service named "my-web-service" in the same namespace as the pod, then,
// <hostname>.my-web-service.<namespace>.svc.<cluster domain>" would be resolved by the cluster DNS Server. // <hostname>.my-web-service.<namespace>.svc.<cluster domain>" would be resolved by the cluster DNS Server.

View file

@ -45,12 +45,12 @@ var Unversioned = unversioned.GroupVersion{Group: "", Version: "v1"}
// ParameterCodec handles versioning of objects that are converted to query parameters. // ParameterCodec handles versioning of objects that are converted to query parameters.
var ParameterCodec = runtime.NewParameterCodec(Scheme) var ParameterCodec = runtime.NewParameterCodec(Scheme)
// Kind takes an unqualified kind and returns back a Group qualified GroupKind // Kind takes an unqualified kind and returns a Group qualified GroupKind
func Kind(kind string) unversioned.GroupKind { func Kind(kind string) unversioned.GroupKind {
return SchemeGroupVersion.WithKind(kind).GroupKind() return SchemeGroupVersion.WithKind(kind).GroupKind()
} }
// Resource takes an unqualified resource and returns back a Group qualified GroupResource // Resource takes an unqualified resource and returns a Group qualified GroupResource
func Resource(resource string) unversioned.GroupResource { func Resource(resource string) unversioned.GroupResource {
return SchemeGroupVersion.WithResource(resource).GroupResource() return SchemeGroupVersion.WithResource(resource).GroupResource()
} }
@ -62,7 +62,7 @@ var (
func init() { func init() {
// TODO(lavalamp): move this call to scheme builder above. Can't // TODO(lavalamp): move this call to scheme builder above. Can't
// remove it from here because lots of people inapropriately rely on it // remove it from here because lots of people inappropriately rely on it
// (specifically the unversioned time conversion). Can't have it in // (specifically the unversioned time conversion). Can't have it in
// both places because then it gets double registered. Consequence of // both places because then it gets double registered. Consequence of
// current state is that it only ever gets registered in the main // current state is that it only ever gets registered in the main

View file

@ -40,7 +40,8 @@
"secret": null, "secret": null,
"nfs": null, "nfs": null,
"iscsi": null, "iscsi": null,
"glusterfs": null "glusterfs": null,
"quobyte": null
} }
], ],
"containers": [ "containers": [

Some files were not shown because too many files have changed in this diff Show more