Update godeps

This commit is contained in:
Manuel de Brito Fontes 2016-09-21 20:00:42 -03:00
parent a965f44f84
commit 73e22a50d2
453 changed files with 84778 additions and 70308 deletions

1153
Godeps/Godeps.json generated

File diff suppressed because it is too large Load diff

View file

@ -27,6 +27,7 @@ import (
"net/http"
"net/url"
"os"
"runtime"
"strings"
"sync"
"time"
@ -34,11 +35,20 @@ import (
"golang.org/x/net/context"
"golang.org/x/net/context/ctxhttp"
"google.golang.org/cloud/internal"
"cloud.google.com/go/internal"
)
// metadataIP is the documented metadata server IP address.
const metadataIP = "169.254.169.254"
const (
// metadataIP is the documented metadata server IP address.
metadataIP = "169.254.169.254"
// metadataHostEnv is the environment variable specifying the
// GCE metadata hostname. If empty, the default value of
// metadataIP ("169.254.169.254") is used instead.
// This is variable name is not defined by any spec, as far as
// I know; it was made up for the Go package.
metadataHostEnv = "GCE_METADATA_HOST"
)
type cachedValue struct {
k string
@ -110,7 +120,7 @@ func getETag(client *http.Client, suffix string) (value, etag string, err error)
// deployments. To enable spoofing of the metadata service, the environment
// variable GCE_METADATA_HOST is first inspected to decide where metadata
// requests shall go.
host := os.Getenv("GCE_METADATA_HOST")
host := os.Getenv(metadataHostEnv)
if host == "" {
// Using 169.254.169.254 instead of "metadata" here because Go
// binaries built with the "netgo" tag and without cgo won't
@ -163,32 +173,34 @@ func (c *cachedValue) get() (v string, err error) {
return
}
var onGCE struct {
sync.Mutex
set bool
v bool
}
var (
onGCEOnce sync.Once
onGCE bool
)
// OnGCE reports whether this process is running on Google Compute Engine.
func OnGCE() bool {
defer onGCE.Unlock()
onGCE.Lock()
if onGCE.set {
return onGCE.v
}
onGCE.set = true
onGCE.v = testOnGCE()
return onGCE.v
onGCEOnce.Do(initOnGCE)
return onGCE
}
func initOnGCE() {
onGCE = testOnGCE()
}
func testOnGCE() bool {
// The user explicitly said they're on GCE, so trust them.
if os.Getenv(metadataHostEnv) != "" {
return true
}
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
resc := make(chan bool, 2)
// Try two strategies in parallel.
// See https://github.com/GoogleCloudPlatform/gcloud-golang/issues/194
// See https://github.com/GoogleCloudPlatform/google-cloud-go/issues/194
go func() {
res, err := ctxhttp.Get(ctx, metaClient, "http://"+metadataIP)
if err != nil {
@ -208,9 +220,53 @@ func testOnGCE() bool {
resc <- strsContains(addrs, metadataIP)
}()
tryHarder := systemInfoSuggestsGCE()
if tryHarder {
res := <-resc
if res {
// The first strategy succeeded, so let's use it.
return true
}
// Wait for either the DNS or metadata server probe to
// contradict the other one and say we are running on
// GCE. Give it a lot of time to do so, since the system
// info already suggests we're running on a GCE BIOS.
timer := time.NewTimer(5 * time.Second)
defer timer.Stop()
select {
case res = <-resc:
return res
case <-timer.C:
// Too slow. Who knows what this system is.
return false
}
}
// There's no hint from the system info that we're running on
// GCE, so use the first probe's result as truth, whether it's
// true or false. The goal here is to optimize for speed for
// users who are NOT running on GCE. We can't assume that
// either a DNS lookup or an HTTP request to a blackholed IP
// address is fast. Worst case this should return when the
// metaClient's Transport.ResponseHeaderTimeout or
// Transport.Dial.Timeout fires (in two seconds).
return <-resc
}
// systemInfoSuggestsGCE reports whether the local system (without
// doing network requests) suggests that we're running on GCE. If this
// returns true, testOnGCE tries a bit harder to reach its metadata
// server.
func systemInfoSuggestsGCE() bool {
if runtime.GOOS != "linux" {
// We don't have any non-Linux clues available, at least yet.
return false
}
slurp, _ := ioutil.ReadFile("/sys/class/dmi/id/product_name")
name := strings.TrimSpace(string(slurp))
return name == "Google" || name == "Google Compute Engine"
}
// Subscribe subscribes to a value from the metadata service.
// The suffix is appended to "http://${GCE_METADATA_HOST}/computeMetadata/v1/".
// The suffix may contain query parameters.

64
vendor/cloud.google.com/go/internal/cloud.go generated vendored Normal file
View file

@ -0,0 +1,64 @@
// Copyright 2014 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package internal provides support for the cloud packages.
//
// Users should not import this package directly.
package internal
import (
"fmt"
"net/http"
)
const userAgent = "gcloud-golang/0.1"
// Transport is an http.RoundTripper that appends Google Cloud client's
// user-agent to the original request's user-agent header.
type Transport struct {
// TODO(bradfitz): delete internal.Transport. It's too wrappy for what it does.
// Do User-Agent some other way.
// Base is the actual http.RoundTripper
// requests will use. It must not be nil.
Base http.RoundTripper
}
// RoundTrip appends a user-agent to the existing user-agent
// header and delegates the request to the base http.RoundTripper.
func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) {
req = cloneRequest(req)
ua := req.Header.Get("User-Agent")
if ua == "" {
ua = userAgent
} else {
ua = fmt.Sprintf("%s %s", ua, userAgent)
}
req.Header.Set("User-Agent", ua)
return t.Base.RoundTrip(req)
}
// cloneRequest returns a clone of the provided *http.Request.
// The clone is a shallow copy of the struct and its Header map.
func cloneRequest(r *http.Request) *http.Request {
// shallow copy of the struct
r2 := new(http.Request)
*r2 = *r
// deep copy of the Header
r2.Header = make(http.Header)
for k, s := range r.Header {
r2.Header[k] = s
}
return r2
}

View file

@ -45,6 +45,8 @@ type simpleBalancer struct {
// pinAddr is the currently pinned address; set to the empty string on
// intialization and shutdown.
pinAddr string
closed bool
}
func newSimpleBalancer(eps []string) *simpleBalancer {
@ -74,15 +76,25 @@ func (b *simpleBalancer) ConnectNotify() <-chan struct{} {
func (b *simpleBalancer) Up(addr grpc.Address) func(error) {
b.mu.Lock()
defer b.mu.Unlock()
// gRPC might call Up after it called Close. We add this check
// to "fix" it up at application layer. Or our simplerBalancer
// might panic since b.upc is closed.
if b.closed {
return func(err error) {}
}
if len(b.upEps) == 0 {
// notify waiting Get()s and pin first connected address
close(b.upc)
b.pinAddr = addr.Addr
}
b.upEps[addr.Addr] = struct{}{}
b.mu.Unlock()
// notify client that a connection is up
b.readyOnce.Do(func() { close(b.readyc) })
return func(err error) {
b.mu.Lock()
delete(b.upEps, addr.Addr)
@ -128,13 +140,19 @@ func (b *simpleBalancer) Notify() <-chan []grpc.Address { return b.notifyCh }
func (b *simpleBalancer) Close() error {
b.mu.Lock()
defer b.mu.Unlock()
// In case gRPC calls close twice. TODO: remove the checking
// when we are sure that gRPC wont call close twice.
if b.closed {
return nil
}
b.closed = true
close(b.notifyCh)
// terminate all waiting Get()s
b.pinAddr = ""
if len(b.upEps) == 0 {
close(b.upc)
}
b.mu.Unlock()
return nil
}

View file

@ -669,6 +669,10 @@ func (w *watchGrpcStream) resumeWatchers(wc pb.Watch_WatchClient) error {
w.mu.RUnlock()
for _, ws := range streams {
// drain recvc so no old WatchResponses (e.g., Created messages)
// are processed while resuming
ws.drain()
// pause serveStream
ws.resumec <- -1
@ -701,6 +705,17 @@ func (w *watchGrpcStream) resumeWatchers(wc pb.Watch_WatchClient) error {
return nil
}
// drain removes all buffered WatchResponses from the stream's receive channel.
func (ws *watcherStream) drain() {
for {
select {
case <-ws.recvc:
default:
return
}
}
}
// toPB converts an internal watch request structure to its protobuf messagefunc (wr *watchRequest)
func (wr *watchRequest) toPB() *pb.WatchRequest {
req := &pb.WatchCreateRequest{

22
vendor/github.com/coreos/etcd/pkg/fileutil/dir_unix.go generated vendored Normal file
View file

@ -0,0 +1,22 @@
// Copyright 2016 The etcd Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// +build !windows
package fileutil
import "os"
// OpenDir opens a directory for syncing.
func OpenDir(path string) (*os.File, error) { return os.Open(path) }

View file

@ -0,0 +1,46 @@
// Copyright 2016 The etcd Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// +build windows
package fileutil
import (
"os"
"syscall"
)
// OpenDir opens a directory in windows with write access for syncing.
func OpenDir(path string) (*os.File, error) {
fd, err := openDir(path)
if err != nil {
return nil, err
}
return os.NewFile(uintptr(fd), path), nil
}
func openDir(path string) (fd syscall.Handle, err error) {
if len(path) == 0 {
return syscall.InvalidHandle, syscall.ERROR_FILE_NOT_FOUND
}
pathp, err := syscall.UTF16PtrFromString(path)
if err != nil {
return syscall.InvalidHandle, err
}
access := uint32(syscall.GENERIC_READ | syscall.GENERIC_WRITE)
sharemode := uint32(syscall.FILE_SHARE_READ | syscall.FILE_SHARE_WRITE)
createmode := uint32(syscall.OPEN_EXISTING)
fl := uint32(syscall.FILE_FLAG_BACKUP_SEMANTICS)
return syscall.CreateFile(pathp, access, sharemode, nil, createmode, fl, 0)
}

View file

@ -96,3 +96,26 @@ func Exist(name string) bool {
_, err := os.Stat(name)
return err == nil
}
// ZeroToEnd zeros a file starting from SEEK_CUR to its SEEK_END. May temporarily
// shorten the length of the file.
func ZeroToEnd(f *os.File) error {
// TODO: support FALLOC_FL_ZERO_RANGE
off, err := f.Seek(0, os.SEEK_CUR)
if err != nil {
return err
}
lenf, lerr := f.Seek(0, os.SEEK_END)
if lerr != nil {
return lerr
}
if err = f.Truncate(off); err != nil {
return err
}
// make sure blocks remain allocated
if err = Preallocate(f, lenf, true); err != nil {
return err
}
_, err = f.Seek(off, os.SEEK_SET)
return err
}

View file

@ -4,18 +4,13 @@ import (
"encoding/base64"
"encoding/json"
"errors"
"log"
"net/http"
"net/url"
"path"
"strconv"
"strings"
"time"
"github.com/coreos/pkg/capnslog"
)
var (
log = capnslog.NewPackageLogger("github.com/coreos/go-oidc", "http")
)
func WriteError(w http.ResponseWriter, code int, msg string) {
@ -26,7 +21,9 @@ func WriteError(w http.ResponseWriter, code int, msg string) {
}
b, err := json.Marshal(e)
if err != nil {
log.Errorf("Failed marshaling %#v to JSON: %v", e, err)
log.Printf("go-oidc: failed to marshal %#v: %v", e, err)
code = http.StatusInternalServerError
b = []byte(`{"error":"server_error"}`)
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(code)

View file

@ -1,14 +0,0 @@
package http
import (
"net/http"
)
type LoggingMiddleware struct {
Next http.Handler
}
func (l *LoggingMiddleware) ServeHTTP(w http.ResponseWriter, r *http.Request) {
log.Infof("HTTP %s %v", r.Method, r.URL)
l.Next.ServeHTTP(w, r)
}

View file

@ -3,9 +3,9 @@ package key
import (
"crypto/rand"
"crypto/rsa"
"encoding/base64"
"encoding/hex"
"encoding/json"
"math/big"
"io"
"time"
"github.com/coreos/go-oidc/jose"
@ -139,15 +139,15 @@ func GeneratePrivateKey() (*PrivateKey, error) {
if err != nil {
return nil, err
}
keyID := make([]byte, 20)
if _, err := io.ReadFull(rand.Reader, keyID); err != nil {
return nil, err
}
k := PrivateKey{
KeyID: base64BigInt(pk.PublicKey.N),
KeyID: hex.EncodeToString(keyID),
PrivateKey: pk,
}
return &k, nil
}
func base64BigInt(b *big.Int) string {
return base64.URLEncoding.EncodeToString(b.Bytes())
}

View file

@ -2,16 +2,14 @@ package key
import (
"errors"
"log"
"time"
"github.com/coreos/pkg/capnslog"
ptime "github.com/coreos/pkg/timeutil"
"github.com/jonboulle/clockwork"
)
var (
log = capnslog.NewPackageLogger("github.com/coreos/go-oidc", "key")
ErrorPrivateKeysExpired = errors.New("private keys have expired")
)
@ -67,7 +65,6 @@ func (r *PrivateKeyRotator) privateKeySet() (*PrivateKeySet, error) {
func (r *PrivateKeyRotator) nextRotation() (time.Duration, error) {
pks, err := r.privateKeySet()
if err == ErrorNoKeys {
log.Infof("No keys in private key set; must rotate immediately")
return 0, nil
}
if err != nil {
@ -94,17 +91,15 @@ func (r *PrivateKeyRotator) Run() chan struct{} {
attempt := func() {
k, err := r.generateKey()
if err != nil {
log.Errorf("Failed generating signing key: %v", err)
log.Printf("go-oidc: failed generating signing key: %v", err)
return
}
exp := r.expiresAt()
if err := rotatePrivateKeys(r.repo, k, r.keep, exp); err != nil {
log.Errorf("Failed key rotation: %v", err)
log.Printf("go-oidc: key rotation failed: %v", err)
return
}
log.Infof("Rotated signing keys: id=%s expiresAt=%s", k.ID(), exp)
}
stop := make(chan struct{})
@ -118,11 +113,10 @@ func (r *PrivateKeyRotator) Run() chan struct{} {
break
}
sleep = ptime.ExpBackoff(sleep, time.Minute)
log.Errorf("error getting nextRotation, retrying in %v: %v", sleep, err)
log.Printf("go-oidc: error getting nextRotation, retrying in %v: %v", sleep, err)
time.Sleep(sleep)
}
log.Infof("will rotate keys in %v", nextRotation)
select {
case <-r.clock.After(nextRotation):
attempt()

View file

@ -2,6 +2,7 @@ package key
import (
"errors"
"log"
"time"
"github.com/jonboulle/clockwork"
@ -38,15 +39,14 @@ func (s *KeySetSyncer) Run() chan struct{} {
next = timeutil.ExpBackoff(next, time.Minute)
}
if exp == 0 {
log.Errorf("Synced to already expired key set, retrying in %v: %v", next, err)
log.Printf("Synced to already expired key set, retrying in %v: %v", next, err)
} else {
log.Errorf("Failed syncing key set, retrying in %v: %v", next, err)
log.Printf("Failed syncing key set, retrying in %v: %v", next, err)
}
} else {
failing = false
next = exp / 2
log.Infof("Synced key set, checking again in %v", next)
}
select {

View file

@ -332,16 +332,16 @@ func parseTokenResponse(resp *http.Response) (result TokenResponse, err error) {
result.Scope = vals.Get("scope")
} else {
var r struct {
AccessToken string `json:"access_token"`
TokenType string `json:"token_type"`
IDToken string `json:"id_token"`
RefreshToken string `json:"refresh_token"`
Scope string `json:"scope"`
State string `json:"state"`
ExpiresIn int `json:"expires_in"`
Expires int `json:"expires"`
Error string `json:"error"`
Desc string `json:"error_description"`
AccessToken string `json:"access_token"`
TokenType string `json:"token_type"`
IDToken string `json:"id_token"`
RefreshToken string `json:"refresh_token"`
Scope string `json:"scope"`
State string `json:"state"`
ExpiresIn json.Number `json:"expires_in"` // Azure AD returns string
Expires int `json:"expires"`
Error string `json:"error"`
Desc string `json:"error_description"`
}
if err = json.Unmarshal(body, &r); err != nil {
return
@ -355,10 +355,10 @@ func parseTokenResponse(resp *http.Response) (result TokenResponse, err error) {
result.IDToken = r.IDToken
result.RefreshToken = r.RefreshToken
result.Scope = r.Scope
if r.ExpiresIn == 0 {
if expiresIn, err := r.ExpiresIn.Int64(); err != nil {
result.Expires = r.Expires
} else {
result.Expires = r.ExpiresIn
result.Expires = int(expiresIn)
}
}
return

View file

@ -4,13 +4,13 @@ import (
"encoding/json"
"errors"
"fmt"
"log"
"net/http"
"net/url"
"strings"
"sync"
"time"
"github.com/coreos/pkg/capnslog"
"github.com/coreos/pkg/timeutil"
"github.com/jonboulle/clockwork"
@ -18,10 +18,6 @@ import (
"github.com/coreos/go-oidc/oauth2"
)
var (
log = capnslog.NewPackageLogger("github.com/coreos/go-oidc", "http")
)
const (
// Subject Identifier types defined by the OIDC spec. Specifies if the provider
// should provide the same sub claim value to all clients (public) or a unique
@ -69,6 +65,8 @@ type ProviderConfig struct {
UserInfoEndpoint *url.URL
KeysEndpoint *url.URL // Required
RegistrationEndpoint *url.URL
EndSessionEndpoint *url.URL
CheckSessionIFrame *url.URL
// Servers MAY choose not to advertise some supported scope values even when this
// parameter is used, although those defined in OpenID Core SHOULD be listed, if supported.
@ -170,6 +168,8 @@ type encodableProviderConfig struct {
UserInfoEndpoint string `json:"userinfo_endpoint,omitempty"`
KeysEndpoint string `json:"jwks_uri"`
RegistrationEndpoint string `json:"registration_endpoint,omitempty"`
EndSessionEndpoint string `json:"end_session_endpoint,omitempty"`
CheckSessionIFrame string `json:"check_session_iframe,omitempty"`
// Use 'omitempty' for all slices as per OIDC spec:
// "Claims that return multiple values are represented as JSON arrays.
@ -219,6 +219,8 @@ func (cfg ProviderConfig) toEncodableStruct() encodableProviderConfig {
UserInfoEndpoint: uriToString(cfg.UserInfoEndpoint),
KeysEndpoint: uriToString(cfg.KeysEndpoint),
RegistrationEndpoint: uriToString(cfg.RegistrationEndpoint),
EndSessionEndpoint: uriToString(cfg.EndSessionEndpoint),
CheckSessionIFrame: uriToString(cfg.CheckSessionIFrame),
ScopesSupported: cfg.ScopesSupported,
ResponseTypesSupported: cfg.ResponseTypesSupported,
ResponseModesSupported: cfg.ResponseModesSupported,
@ -260,6 +262,8 @@ func (e encodableProviderConfig) toStruct() (ProviderConfig, error) {
UserInfoEndpoint: p.parseURI(e.UserInfoEndpoint, "userinfo_endpoint"),
KeysEndpoint: p.parseURI(e.KeysEndpoint, "jwks_uri"),
RegistrationEndpoint: p.parseURI(e.RegistrationEndpoint, "registration_endpoint"),
EndSessionEndpoint: p.parseURI(e.EndSessionEndpoint, "end_session_endpoint"),
CheckSessionIFrame: p.parseURI(e.CheckSessionIFrame, "check_session_iframe"),
ScopesSupported: e.ScopesSupported,
ResponseTypesSupported: e.ResponseTypesSupported,
ResponseModesSupported: e.ResponseModesSupported,
@ -364,6 +368,8 @@ func (p ProviderConfig) Valid() error {
{p.UserInfoEndpoint, "userinfo_endpoint", false},
{p.KeysEndpoint, "jwks_uri", true},
{p.RegistrationEndpoint, "registration_endpoint", false},
{p.EndSessionEndpoint, "end_session_endpoint", false},
{p.CheckSessionIFrame, "check_session_iframe", false},
{p.ServiceDocs, "service_documentation", false},
{p.Policy, "op_policy_uri", false},
{p.TermsOfService, "op_tos_uri", false},
@ -537,8 +543,6 @@ func (s *ProviderConfigSyncer) sync() (time.Duration, error) {
s.initialSyncDone = true
}
log.Infof("Updating provider config: config=%#v", cfg)
return nextSyncAfter(cfg.ExpiresAt, s.clock), nil
}
@ -561,10 +565,9 @@ func (n *pcsStepNext) step(fn pcsStepFunc) (next pcsStepper) {
ttl, err := fn()
if err == nil {
next = &pcsStepNext{aft: ttl}
log.Debugf("Synced provider config, next attempt in %v", next.after())
} else {
next = &pcsStepRetry{aft: time.Second}
log.Errorf("Provider config sync failed, retrying in %v: %v", next.after(), err)
log.Printf("go-oidc: provider config sync falied, retyring in %v: %v", next.after(), err)
}
return
}
@ -581,10 +584,9 @@ func (r *pcsStepRetry) step(fn pcsStepFunc) (next pcsStepper) {
ttl, err := fn()
if err == nil {
next = &pcsStepNext{aft: ttl}
log.Infof("Provider config sync no longer failing")
} else {
next = &pcsStepRetry{aft: timeutil.ExpBackoff(r.aft, time.Minute)}
log.Errorf("Provider config sync still failing, retrying in %v: %v", next.after(), err)
log.Printf("go-oidc: provider config sync falied, retyring in %v: %v", next.after(), err)
}
return
}

View file

@ -161,11 +161,18 @@ func NewJWTVerifier(issuer, clientID string, syncFunc func() error, keysFunc fun
}
func (v *JWTVerifier) Verify(jwt jose.JWT) error {
// Verify claims before verifying the signature. This is an optimization to throw out
// tokens we know are invalid without undergoing an expensive signature check and
// possibly a re-sync event.
if err := VerifyClaims(jwt, v.issuer, v.clientID); err != nil {
return fmt.Errorf("oidc: JWT claims invalid: %v", err)
}
ok, err := VerifySignature(jwt, v.keysFunc())
if ok {
goto SignatureVerified
} else if err != nil {
if err != nil {
return fmt.Errorf("oidc: JWT signature verification failed: %v", err)
} else if ok {
return nil
}
if err = v.syncFunc(); err != nil {
@ -179,10 +186,5 @@ func (v *JWTVerifier) Verify(jwt jose.JWT) error {
return errors.New("oidc: unable to verify JWT signature: no matching keys")
}
SignatureVerified:
if err := VerifyClaims(jwt, v.issuer, v.clientID); err != nil {
return fmt.Errorf("oidc: JWT claims invalid: %v", err)
}
return nil
}

View file

@ -1,8 +1,7 @@
language: go
go:
- 1.3
- 1.4
- tip
install:
- export GOPATH="$HOME/gopath"

View file

@ -1,6 +1,7 @@
# OAuth2 for Go
[![Build Status](https://travis-ci.org/golang/oauth2.svg?branch=master)](https://travis-ci.org/golang/oauth2)
[![GoDoc](https://godoc.org/golang.org/x/oauth2?status.svg)](https://godoc.org/golang.org/x/oauth2)
oauth2 package contains a client implementation for OAuth 2.0 spec.

View file

@ -1,8 +1,8 @@
// Copyright 2014 The oauth2 Authors. All rights reserved.
// Copyright 2014 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build appengine appenginevm
// +build appengine
// App Engine hooks.

View file

@ -1,4 +1,4 @@
// Copyright 2014 The oauth2 Authors. All rights reserved.
// Copyright 2014 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
@ -14,6 +14,9 @@ import (
"golang.org/x/oauth2"
)
// Set at init time by appenginevm_hook.go. If true, we are on App Engine Managed VMs.
var appengineVM bool
// Set at init time by appengine_hook.go. If nil, we're not on App Engine.
var appengineTokenFunc func(c context.Context, scopes ...string) (token string, expiry time.Time, err error)

View file

@ -1,8 +1,8 @@
// Copyright 2015 The oauth2 Authors. All rights reserved.
// Copyright 2015 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build appengine appenginevm
// +build appengine
package google

14
vendor/golang.org/x/oauth2/google/appenginevm_hook.go generated vendored Normal file
View file

@ -0,0 +1,14 @@
// Copyright 2015 The oauth2 Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build appenginevm
package google
import "google.golang.org/appengine"
func init() {
appengineVM = true
appengineTokenFunc = appengine.AccessToken
}

View file

@ -1,4 +1,4 @@
// Copyright 2015 The oauth2 Authors. All rights reserved.
// Copyright 2015 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
@ -14,10 +14,10 @@ import (
"path/filepath"
"runtime"
"cloud.google.com/go/compute/metadata"
"golang.org/x/net/context"
"golang.org/x/oauth2"
"golang.org/x/oauth2/jwt"
"google.golang.org/cloud/compute/metadata"
)
// DefaultClient returns an HTTP Client that uses the
@ -50,7 +50,8 @@ func DefaultClient(ctx context.Context, scope ...string) (*http.Client, error) {
// On Windows, this is %APPDATA%/gcloud/application_default_credentials.json.
// On other systems, $HOME/.config/gcloud/application_default_credentials.json.
// 3. On Google App Engine it uses the appengine.AccessToken function.
// 4. On Google Compute Engine, it fetches credentials from the metadata server.
// 4. On Google Compute Engine and Google App Engine Managed VMs, it fetches
// credentials from the metadata server.
// (In this final case any provided scopes are ignored.)
//
// For more details, see:
@ -84,7 +85,7 @@ func DefaultTokenSource(ctx context.Context, scope ...string) (oauth2.TokenSourc
}
// Third, if we're on Google App Engine use those credentials.
if appengineTokenFunc != nil {
if appengineTokenFunc != nil && !appengineVM {
return AppEngineTokenSource(ctx, scope...), nil
}

View file

@ -1,4 +1,4 @@
// Copyright 2014 The oauth2 Authors. All rights reserved.
// Copyright 2014 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
@ -21,9 +21,9 @@ import (
"strings"
"time"
"cloud.google.com/go/compute/metadata"
"golang.org/x/oauth2"
"golang.org/x/oauth2/jwt"
"google.golang.org/cloud/compute/metadata"
)
// Endpoint is Google's OAuth 2.0 endpoint.
@ -37,9 +37,10 @@ const JWTTokenURL = "https://accounts.google.com/o/oauth2/token"
// ConfigFromJSON uses a Google Developers Console client_credentials.json
// file to construct a config.
// client_credentials.json can be downloadable from https://console.developers.google.com,
// under "APIs & Auth" > "Credentials". Download the Web application credentials in the
// JSON format and provide the contents of the file as jsonKey.
// client_credentials.json can be downloaded from
// https://console.developers.google.com, under "Credentials". Download the Web
// application credentials in the JSON format and provide the contents of the
// file as jsonKey.
func ConfigFromJSON(jsonKey []byte, scope ...string) (*oauth2.Config, error) {
type cred struct {
ClientID string `json:"client_id"`
@ -81,22 +82,29 @@ func ConfigFromJSON(jsonKey []byte, scope ...string) (*oauth2.Config, error) {
// JWTConfigFromJSON uses a Google Developers service account JSON key file to read
// the credentials that authorize and authenticate the requests.
// Create a service account on "Credentials" page under "APIs & Auth" for your
// project at https://console.developers.google.com to download a JSON key file.
// Create a service account on "Credentials" for your project at
// https://console.developers.google.com to download a JSON key file.
func JWTConfigFromJSON(jsonKey []byte, scope ...string) (*jwt.Config, error) {
var key struct {
Email string `json:"client_email"`
PrivateKey string `json:"private_key"`
Email string `json:"client_email"`
PrivateKey string `json:"private_key"`
PrivateKeyID string `json:"private_key_id"`
TokenURL string `json:"token_uri"`
}
if err := json.Unmarshal(jsonKey, &key); err != nil {
return nil, err
}
return &jwt.Config{
Email: key.Email,
PrivateKey: []byte(key.PrivateKey),
Scopes: scope,
TokenURL: JWTTokenURL,
}, nil
config := &jwt.Config{
Email: key.Email,
PrivateKey: []byte(key.PrivateKey),
PrivateKeyID: key.PrivateKeyID,
Scopes: scope,
TokenURL: key.TokenURL,
}
if config.TokenURL == "" {
config.TokenURL = JWTTokenURL
}
return config, nil
}
// ComputeTokenSource returns a token source that fetches access tokens

74
vendor/golang.org/x/oauth2/google/jwt.go generated vendored Normal file
View file

@ -0,0 +1,74 @@
// Copyright 2015 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package google
import (
"crypto/rsa"
"fmt"
"time"
"golang.org/x/oauth2"
"golang.org/x/oauth2/internal"
"golang.org/x/oauth2/jws"
)
// JWTAccessTokenSourceFromJSON uses a Google Developers service account JSON
// key file to read the credentials that authorize and authenticate the
// requests, and returns a TokenSource that does not use any OAuth2 flow but
// instead creates a JWT and sends that as the access token.
// The audience is typically a URL that specifies the scope of the credentials.
//
// Note that this is not a standard OAuth flow, but rather an
// optimization supported by a few Google services.
// Unless you know otherwise, you should use JWTConfigFromJSON instead.
func JWTAccessTokenSourceFromJSON(jsonKey []byte, audience string) (oauth2.TokenSource, error) {
cfg, err := JWTConfigFromJSON(jsonKey)
if err != nil {
return nil, fmt.Errorf("google: could not parse JSON key: %v", err)
}
pk, err := internal.ParseKey(cfg.PrivateKey)
if err != nil {
return nil, fmt.Errorf("google: could not parse key: %v", err)
}
ts := &jwtAccessTokenSource{
email: cfg.Email,
audience: audience,
pk: pk,
pkID: cfg.PrivateKeyID,
}
tok, err := ts.Token()
if err != nil {
return nil, err
}
return oauth2.ReuseTokenSource(tok, ts), nil
}
type jwtAccessTokenSource struct {
email, audience string
pk *rsa.PrivateKey
pkID string
}
func (ts *jwtAccessTokenSource) Token() (*oauth2.Token, error) {
iat := time.Now()
exp := iat.Add(time.Hour)
cs := &jws.ClaimSet{
Iss: ts.email,
Sub: ts.email,
Aud: ts.audience,
Iat: iat.Unix(),
Exp: exp.Unix(),
}
hdr := &jws.Header{
Algorithm: "RS256",
Typ: "JWT",
KeyID: string(ts.pkID),
}
msg, err := jws.Encode(hdr, cs, ts.pk)
if err != nil {
return nil, fmt.Errorf("google: could not encode JWT: %v", err)
}
return &oauth2.Token{AccessToken: msg, TokenType: "Bearer", Expiry: exp}, nil
}

View file

@ -1,4 +1,4 @@
// Copyright 2015 The oauth2 Authors. All rights reserved.
// Copyright 2015 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

View file

@ -1,4 +1,4 @@
// Copyright 2014 The oauth2 Authors. All rights reserved.
// Copyright 2014 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

View file

@ -1,4 +1,4 @@
// Copyright 2014 The oauth2 Authors. All rights reserved.
// Copyright 2014 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
@ -91,24 +91,36 @@ func (e *expirationTime) UnmarshalJSON(b []byte) error {
var brokenAuthHeaderProviders = []string{
"https://accounts.google.com/",
"https://www.googleapis.com/",
"https://github.com/",
"https://api.instagram.com/",
"https://www.douban.com/",
"https://api.dropbox.com/",
"https://api.soundcloud.com/",
"https://www.linkedin.com/",
"https://api.twitch.tv/",
"https://oauth.vk.com/",
"https://api.dropboxapi.com/",
"https://api.instagram.com/",
"https://api.netatmo.net/",
"https://api.odnoklassniki.ru/",
"https://connect.stripe.com/",
"https://api.pushbullet.com/",
"https://api.soundcloud.com/",
"https://api.twitch.tv/",
"https://app.box.com/",
"https://connect.stripe.com/",
"https://login.microsoftonline.com/",
"https://login.salesforce.com/",
"https://oauth.sandbox.trainingpeaks.com/",
"https://oauth.trainingpeaks.com/",
"https://www.strava.com/oauth/",
"https://app.box.com/",
"https://oauth.vk.com/",
"https://openapi.baidu.com/",
"https://slack.com/",
"https://test-sandbox.auth.corp.google.com",
"https://test.salesforce.com/",
"https://user.gini.net/",
"https://www.douban.com/",
"https://www.googleapis.com/",
"https://www.linkedin.com/",
"https://www.strava.com/oauth/",
"https://www.wunderlist.com/oauth/",
"https://api.patreon.com/",
}
func RegisterBrokenAuthHeaderProvider(tokenURL string) {
brokenAuthHeaderProviders = append(brokenAuthHeaderProviders, tokenURL)
}
// providerAuthHeaderWorks reports whether the OAuth2 server identified by the tokenURL
@ -134,23 +146,23 @@ func providerAuthHeaderWorks(tokenURL string) bool {
return true
}
func RetrieveToken(ctx context.Context, ClientID, ClientSecret, TokenURL string, v url.Values) (*Token, error) {
func RetrieveToken(ctx context.Context, clientID, clientSecret, tokenURL string, v url.Values) (*Token, error) {
hc, err := ContextClient(ctx)
if err != nil {
return nil, err
}
v.Set("client_id", ClientID)
bustedAuth := !providerAuthHeaderWorks(TokenURL)
if bustedAuth && ClientSecret != "" {
v.Set("client_secret", ClientSecret)
v.Set("client_id", clientID)
bustedAuth := !providerAuthHeaderWorks(tokenURL)
if bustedAuth && clientSecret != "" {
v.Set("client_secret", clientSecret)
}
req, err := http.NewRequest("POST", TokenURL, strings.NewReader(v.Encode()))
req, err := http.NewRequest("POST", tokenURL, strings.NewReader(v.Encode()))
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
if !bustedAuth {
req.SetBasicAuth(ClientID, ClientSecret)
req.SetBasicAuth(clientID, clientSecret)
}
r, err := hc.Do(req)
if err != nil {

View file

@ -1,4 +1,4 @@
// Copyright 2014 The oauth2 Authors. All rights reserved.
// Copyright 2014 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
@ -33,6 +33,11 @@ func RegisterContextClientFunc(fn ContextClientFunc) {
}
func ContextClient(ctx context.Context) (*http.Client, error) {
if ctx != nil {
if hc, ok := ctx.Value(HTTPClient).(*http.Client); ok {
return hc, nil
}
}
for _, fn := range contextClientFuncs {
c, err := fn(ctx)
if err != nil {
@ -42,9 +47,6 @@ func ContextClient(ctx context.Context) (*http.Client, error) {
return c, nil
}
}
if hc, ok := ctx.Value(HTTPClient).(*http.Client); ok {
return hc, nil
}
return http.DefaultClient, nil
}

110
vendor/golang.org/x/oauth2/jws/jws.go generated vendored
View file

@ -1,9 +1,17 @@
// Copyright 2014 The oauth2 Authors. All rights reserved.
// Copyright 2014 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package jws provides encoding and decoding utilities for
// signed JWS messages.
// Package jws provides a partial implementation
// of JSON Web Signature encoding and decoding.
// It exists to support the golang.org/x/oauth2 package.
//
// See RFC 7515.
//
// Deprecated: this package is not intended for public use and might be
// removed in the future. It exists for internal use only.
// Please switch to another JWS package or copy this package into your own
// source tree.
package jws
import (
@ -27,8 +35,8 @@ type ClaimSet struct {
Iss string `json:"iss"` // email address of the client_id of the application making the access token request
Scope string `json:"scope,omitempty"` // space-delimited list of the permissions the application requests
Aud string `json:"aud"` // descriptor of the intended target of the assertion (Optional).
Exp int64 `json:"exp"` // the expiration time of the assertion
Iat int64 `json:"iat"` // the time the assertion was issued.
Exp int64 `json:"exp"` // the expiration time of the assertion (seconds since Unix epoch)
Iat int64 `json:"iat"` // the time the assertion was issued (seconds since Unix epoch)
Typ string `json:"typ,omitempty"` // token type (Optional).
// Email for which the application is requesting delegated access (Optional).
@ -41,23 +49,22 @@ type ClaimSet struct {
// See http://tools.ietf.org/html/draft-jones-json-web-token-10#section-4.3
// This array is marshalled using custom code (see (c *ClaimSet) encode()).
PrivateClaims map[string]interface{} `json:"-"`
exp time.Time
iat time.Time
}
func (c *ClaimSet) encode() (string, error) {
if c.exp.IsZero() || c.iat.IsZero() {
// Reverting time back for machines whose time is not perfectly in sync.
// If client machine's time is in the future according
// to Google servers, an access token will not be issued.
now := time.Now().Add(-10 * time.Second)
c.iat = now
c.exp = now.Add(time.Hour)
// Reverting time back for machines whose time is not perfectly in sync.
// If client machine's time is in the future according
// to Google servers, an access token will not be issued.
now := time.Now().Add(-10 * time.Second)
if c.Iat == 0 {
c.Iat = now.Unix()
}
if c.Exp == 0 {
c.Exp = now.Add(time.Hour).Unix()
}
if c.Exp < c.Iat {
return "", fmt.Errorf("jws: invalid Exp = %v; must be later than Iat = %v", c.Exp, c.Iat)
}
c.Exp = c.exp.Unix()
c.Iat = c.iat.Unix()
b, err := json.Marshal(c)
if err != nil {
@ -65,7 +72,7 @@ func (c *ClaimSet) encode() (string, error) {
}
if len(c.PrivateClaims) == 0 {
return base64Encode(b), nil
return base64.RawURLEncoding.EncodeToString(b), nil
}
// Marshal private claim set and then append it to b.
@ -83,7 +90,7 @@ func (c *ClaimSet) encode() (string, error) {
}
b[len(b)-1] = ',' // Replace closing curly brace with a comma.
b = append(b, prv[1:]...) // Append private claims.
return base64Encode(b), nil
return base64.RawURLEncoding.EncodeToString(b), nil
}
// Header represents the header for the signed JWS payloads.
@ -93,6 +100,9 @@ type Header struct {
// Represents the token type.
Typ string `json:"typ"`
// The optional hint of which key is being used.
KeyID string `json:"kid,omitempty"`
}
func (h *Header) encode() (string, error) {
@ -100,7 +110,7 @@ func (h *Header) encode() (string, error) {
if err != nil {
return "", err
}
return base64Encode(b), nil
return base64.RawURLEncoding.EncodeToString(b), nil
}
// Decode decodes a claim set from a JWS payload.
@ -111,7 +121,7 @@ func Decode(payload string) (*ClaimSet, error) {
// TODO(jbd): Provide more context about the error.
return nil, errors.New("jws: invalid token received")
}
decoded, err := base64Decode(s[1])
decoded, err := base64.RawURLEncoding.DecodeString(s[1])
if err != nil {
return nil, err
}
@ -120,8 +130,11 @@ func Decode(payload string) (*ClaimSet, error) {
return c, err
}
// Encode encodes a signed JWS with provided header and claim set.
func Encode(header *Header, c *ClaimSet, signature *rsa.PrivateKey) (string, error) {
// Signer returns a signature for the given data.
type Signer func(data []byte) (sig []byte, err error)
// EncodeWithSigner encodes a header and claim set with the provided signer.
func EncodeWithSigner(header *Header, c *ClaimSet, sg Signer) (string, error) {
head, err := header.encode()
if err != nil {
return "", err
@ -131,30 +144,39 @@ func Encode(header *Header, c *ClaimSet, signature *rsa.PrivateKey) (string, err
return "", err
}
ss := fmt.Sprintf("%s.%s", head, cs)
h := sha256.New()
h.Write([]byte(ss))
b, err := rsa.SignPKCS1v15(rand.Reader, signature, crypto.SHA256, h.Sum(nil))
sig, err := sg([]byte(ss))
if err != nil {
return "", err
}
sig := base64Encode(b)
return fmt.Sprintf("%s.%s", ss, sig), nil
return fmt.Sprintf("%s.%s", ss, base64.RawURLEncoding.EncodeToString(sig)), nil
}
// base64Encode returns and Base64url encoded version of the input string with any
// trailing "=" stripped.
func base64Encode(b []byte) string {
return strings.TrimRight(base64.URLEncoding.EncodeToString(b), "=")
}
// base64Decode decodes the Base64url encoded string
func base64Decode(s string) ([]byte, error) {
// add back missing padding
switch len(s) % 4 {
case 2:
s += "=="
case 3:
s += "="
// Encode encodes a signed JWS with provided header and claim set.
// This invokes EncodeWithSigner using crypto/rsa.SignPKCS1v15 with the given RSA private key.
func Encode(header *Header, c *ClaimSet, key *rsa.PrivateKey) (string, error) {
sg := func(data []byte) (sig []byte, err error) {
h := sha256.New()
h.Write(data)
return rsa.SignPKCS1v15(rand.Reader, key, crypto.SHA256, h.Sum(nil))
}
return base64.URLEncoding.DecodeString(s)
return EncodeWithSigner(header, c, sg)
}
// Verify tests whether the provided JWT token's signature was produced by the private key
// associated with the supplied public key.
func Verify(token string, key *rsa.PublicKey) error {
parts := strings.Split(token, ".")
if len(parts) != 3 {
return errors.New("jws: invalid token received, token must have 3 parts")
}
signedContent := parts[0] + "." + parts[1]
signatureString, err := base64.RawURLEncoding.DecodeString(parts[2])
if err != nil {
return err
}
h := sha256.New()
h.Write([]byte(signedContent))
return rsa.VerifyPKCS1v15(key, crypto.SHA256, h.Sum(nil), []byte(signatureString))
}

View file

@ -1,4 +1,4 @@
// Copyright 2014 The oauth2 Authors. All rights reserved.
// Copyright 2014 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
@ -46,6 +46,10 @@ type Config struct {
//
PrivateKey []byte
// PrivateKeyID contains an optional hint indicating which key is being
// used.
PrivateKeyID string
// Subject is the optional user to impersonate.
Subject string
@ -54,6 +58,9 @@ type Config struct {
// TokenURL is the endpoint required to complete the 2-legged JWT flow.
TokenURL string
// Expires optionally specifies how long the token is valid for.
Expires time.Duration
}
// TokenSource returns a JWT TokenSource using the configuration
@ -95,6 +102,9 @@ func (js jwtSource) Token() (*oauth2.Token, error) {
// to be compatible with legacy OAuth 2.0 providers.
claimSet.Prn = subject
}
if t := js.conf.Expires; t > 0 {
claimSet.Exp = time.Now().Add(t).Unix()
}
payload, err := jws.Encode(defaultHeader, claimSet, pk)
if err != nil {
return nil, err

20
vendor/golang.org/x/oauth2/oauth2.go generated vendored
View file

@ -1,4 +1,4 @@
// Copyright 2014 The oauth2 Authors. All rights reserved.
// Copyright 2014 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
@ -21,10 +21,26 @@ import (
// NoContext is the default context you should supply if not using
// your own context.Context (see https://golang.org/x/net/context).
//
// Deprecated: Use context.Background() or context.TODO() instead.
var NoContext = context.TODO()
// RegisterBrokenAuthHeaderProvider registers an OAuth2 server
// identified by the tokenURL prefix as an OAuth2 implementation
// which doesn't support the HTTP Basic authentication
// scheme to authenticate with the authorization server.
// Once a server is registered, credentials (client_id and client_secret)
// will be passed as query parameters rather than being present
// in the Authorization header.
// See https://code.google.com/p/goauth2/issues/detail?id=31 for background.
func RegisterBrokenAuthHeaderProvider(tokenURL string) {
internal.RegisterBrokenAuthHeaderProvider(tokenURL)
}
// Config describes a typical 3-legged OAuth2 flow, with both the
// client application information and the server's endpoint URLs.
// For the client credentials 2-legged OAuth2 flow, see the clientcredentials
// package (https://golang.org/x/oauth2/clientcredentials).
type Config struct {
// ClientID is the application's ID.
ClientID string
@ -283,7 +299,7 @@ func NewClient(ctx context.Context, src TokenSource) *http.Client {
if src == nil {
c, err := internal.ContextClient(ctx)
if err != nil {
return &http.Client{Transport: internal.ErrorTransport{err}}
return &http.Client{Transport: internal.ErrorTransport{Err: err}}
}
return c
}

27
vendor/golang.org/x/oauth2/token.go generated vendored
View file

@ -1,4 +1,4 @@
// Copyright 2014 The oauth2 Authors. All rights reserved.
// Copyright 2014 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
@ -7,6 +7,7 @@ package oauth2
import (
"net/http"
"net/url"
"strconv"
"strings"
"time"
@ -92,14 +93,28 @@ func (t *Token) WithExtra(extra interface{}) *Token {
// Extra fields are key-value pairs returned by the server as a
// part of the token retrieval response.
func (t *Token) Extra(key string) interface{} {
if vals, ok := t.raw.(url.Values); ok {
// TODO(jbd): Cast numeric values to int64 or float64.
return vals.Get(key)
}
if raw, ok := t.raw.(map[string]interface{}); ok {
return raw[key]
}
return nil
vals, ok := t.raw.(url.Values)
if !ok {
return nil
}
v := vals.Get(key)
switch s := strings.TrimSpace(v); strings.Count(s, ".") {
case 0: // Contains no "."; try to parse as int
if i, err := strconv.ParseInt(s, 10, 64); err == nil {
return i
}
case 1: // Contains a single "."; try to parse as float
if f, err := strconv.ParseFloat(s, 64); err == nil {
return f
}
}
return v
}
// expired reports whether the token is expired.

View file

@ -1,4 +1,4 @@
// Copyright 2014 The oauth2 Authors. All rights reserved.
// Copyright 2014 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

File diff suppressed because it is too large Load diff

File diff suppressed because it is too large Load diff

View file

@ -878,23 +878,22 @@ func (c *ProjectsZonesGetServerconfigCall) Context(ctx context.Context) *Project
}
func (c *ProjectsZonesGetServerconfigCall) doRequest(alt string) (*http.Response, error) {
reqHeaders := make(http.Header)
reqHeaders.Set("User-Agent", c.s.userAgent())
if c.ifNoneMatch_ != "" {
reqHeaders.Set("If-None-Match", c.ifNoneMatch_)
}
var body io.Reader = nil
c.urlParams_.Set("alt", alt)
urls := googleapi.ResolveRelative(c.s.BasePath, "v1/projects/{projectId}/zones/{zone}/serverconfig")
urls += "?" + c.urlParams_.Encode()
req, _ := http.NewRequest("GET", urls, body)
req.Header = reqHeaders
googleapi.Expand(req.URL, map[string]string{
"projectId": c.projectId,
"zone": c.zone,
})
req.Header.Set("User-Agent", c.s.userAgent())
if c.ifNoneMatch_ != "" {
req.Header.Set("If-None-Match", c.ifNoneMatch_)
}
if c.ctx_ != nil {
return ctxhttp.Do(c.ctx_, c.s.client, req)
}
return c.s.client.Do(req)
return gensupport.SendRequest(c.ctx_, c.s.client, req)
}
// Do executes the "container.projects.zones.getServerconfig" call.
@ -929,7 +928,8 @@ func (c *ProjectsZonesGetServerconfigCall) Do(opts ...googleapi.CallOption) (*Se
HTTPStatusCode: res.StatusCode,
},
}
if err := json.NewDecoder(res.Body).Decode(&ret); err != nil {
target := &ret
if err := json.NewDecoder(res.Body).Decode(target); err != nil {
return nil, err
}
return ret, nil
@ -1011,26 +1011,24 @@ func (c *ProjectsZonesClustersCreateCall) Context(ctx context.Context) *Projects
}
func (c *ProjectsZonesClustersCreateCall) doRequest(alt string) (*http.Response, error) {
reqHeaders := make(http.Header)
reqHeaders.Set("User-Agent", c.s.userAgent())
var body io.Reader = nil
body, err := googleapi.WithoutDataWrapper.JSONReader(c.createclusterrequest)
if err != nil {
return nil, err
}
ctype := "application/json"
reqHeaders.Set("Content-Type", "application/json")
c.urlParams_.Set("alt", alt)
urls := googleapi.ResolveRelative(c.s.BasePath, "v1/projects/{projectId}/zones/{zone}/clusters")
urls += "?" + c.urlParams_.Encode()
req, _ := http.NewRequest("POST", urls, body)
req.Header = reqHeaders
googleapi.Expand(req.URL, map[string]string{
"projectId": c.projectId,
"zone": c.zone,
})
req.Header.Set("Content-Type", ctype)
req.Header.Set("User-Agent", c.s.userAgent())
if c.ctx_ != nil {
return ctxhttp.Do(c.ctx_, c.s.client, req)
}
return c.s.client.Do(req)
return gensupport.SendRequest(c.ctx_, c.s.client, req)
}
// Do executes the "container.projects.zones.clusters.create" call.
@ -1065,7 +1063,8 @@ func (c *ProjectsZonesClustersCreateCall) Do(opts ...googleapi.CallOption) (*Ope
HTTPStatusCode: res.StatusCode,
},
}
if err := json.NewDecoder(res.Body).Decode(&ret); err != nil {
target := &ret
if err := json.NewDecoder(res.Body).Decode(target); err != nil {
return nil, err
}
return ret, nil
@ -1147,21 +1146,20 @@ func (c *ProjectsZonesClustersDeleteCall) Context(ctx context.Context) *Projects
}
func (c *ProjectsZonesClustersDeleteCall) doRequest(alt string) (*http.Response, error) {
reqHeaders := make(http.Header)
reqHeaders.Set("User-Agent", c.s.userAgent())
var body io.Reader = nil
c.urlParams_.Set("alt", alt)
urls := googleapi.ResolveRelative(c.s.BasePath, "v1/projects/{projectId}/zones/{zone}/clusters/{clusterId}")
urls += "?" + c.urlParams_.Encode()
req, _ := http.NewRequest("DELETE", urls, body)
req.Header = reqHeaders
googleapi.Expand(req.URL, map[string]string{
"projectId": c.projectId,
"zone": c.zone,
"clusterId": c.clusterId,
})
req.Header.Set("User-Agent", c.s.userAgent())
if c.ctx_ != nil {
return ctxhttp.Do(c.ctx_, c.s.client, req)
}
return c.s.client.Do(req)
return gensupport.SendRequest(c.ctx_, c.s.client, req)
}
// Do executes the "container.projects.zones.clusters.delete" call.
@ -1196,7 +1194,8 @@ func (c *ProjectsZonesClustersDeleteCall) Do(opts ...googleapi.CallOption) (*Ope
HTTPStatusCode: res.StatusCode,
},
}
if err := json.NewDecoder(res.Body).Decode(&ret); err != nil {
target := &ret
if err := json.NewDecoder(res.Body).Decode(target); err != nil {
return nil, err
}
return ret, nil
@ -1288,24 +1287,23 @@ func (c *ProjectsZonesClustersGetCall) Context(ctx context.Context) *ProjectsZon
}
func (c *ProjectsZonesClustersGetCall) doRequest(alt string) (*http.Response, error) {
reqHeaders := make(http.Header)
reqHeaders.Set("User-Agent", c.s.userAgent())
if c.ifNoneMatch_ != "" {
reqHeaders.Set("If-None-Match", c.ifNoneMatch_)
}
var body io.Reader = nil
c.urlParams_.Set("alt", alt)
urls := googleapi.ResolveRelative(c.s.BasePath, "v1/projects/{projectId}/zones/{zone}/clusters/{clusterId}")
urls += "?" + c.urlParams_.Encode()
req, _ := http.NewRequest("GET", urls, body)
req.Header = reqHeaders
googleapi.Expand(req.URL, map[string]string{
"projectId": c.projectId,
"zone": c.zone,
"clusterId": c.clusterId,
})
req.Header.Set("User-Agent", c.s.userAgent())
if c.ifNoneMatch_ != "" {
req.Header.Set("If-None-Match", c.ifNoneMatch_)
}
if c.ctx_ != nil {
return ctxhttp.Do(c.ctx_, c.s.client, req)
}
return c.s.client.Do(req)
return gensupport.SendRequest(c.ctx_, c.s.client, req)
}
// Do executes the "container.projects.zones.clusters.get" call.
@ -1340,7 +1338,8 @@ func (c *ProjectsZonesClustersGetCall) Do(opts ...googleapi.CallOption) (*Cluste
HTTPStatusCode: res.StatusCode,
},
}
if err := json.NewDecoder(res.Body).Decode(&ret); err != nil {
target := &ret
if err := json.NewDecoder(res.Body).Decode(target); err != nil {
return nil, err
}
return ret, nil
@ -1431,23 +1430,22 @@ func (c *ProjectsZonesClustersListCall) Context(ctx context.Context) *ProjectsZo
}
func (c *ProjectsZonesClustersListCall) doRequest(alt string) (*http.Response, error) {
reqHeaders := make(http.Header)
reqHeaders.Set("User-Agent", c.s.userAgent())
if c.ifNoneMatch_ != "" {
reqHeaders.Set("If-None-Match", c.ifNoneMatch_)
}
var body io.Reader = nil
c.urlParams_.Set("alt", alt)
urls := googleapi.ResolveRelative(c.s.BasePath, "v1/projects/{projectId}/zones/{zone}/clusters")
urls += "?" + c.urlParams_.Encode()
req, _ := http.NewRequest("GET", urls, body)
req.Header = reqHeaders
googleapi.Expand(req.URL, map[string]string{
"projectId": c.projectId,
"zone": c.zone,
})
req.Header.Set("User-Agent", c.s.userAgent())
if c.ifNoneMatch_ != "" {
req.Header.Set("If-None-Match", c.ifNoneMatch_)
}
if c.ctx_ != nil {
return ctxhttp.Do(c.ctx_, c.s.client, req)
}
return c.s.client.Do(req)
return gensupport.SendRequest(c.ctx_, c.s.client, req)
}
// Do executes the "container.projects.zones.clusters.list" call.
@ -1482,7 +1480,8 @@ func (c *ProjectsZonesClustersListCall) Do(opts ...googleapi.CallOption) (*ListC
HTTPStatusCode: res.StatusCode,
},
}
if err := json.NewDecoder(res.Body).Decode(&ret); err != nil {
target := &ret
if err := json.NewDecoder(res.Body).Decode(target); err != nil {
return nil, err
}
return ret, nil
@ -1558,27 +1557,25 @@ func (c *ProjectsZonesClustersUpdateCall) Context(ctx context.Context) *Projects
}
func (c *ProjectsZonesClustersUpdateCall) doRequest(alt string) (*http.Response, error) {
reqHeaders := make(http.Header)
reqHeaders.Set("User-Agent", c.s.userAgent())
var body io.Reader = nil
body, err := googleapi.WithoutDataWrapper.JSONReader(c.updateclusterrequest)
if err != nil {
return nil, err
}
ctype := "application/json"
reqHeaders.Set("Content-Type", "application/json")
c.urlParams_.Set("alt", alt)
urls := googleapi.ResolveRelative(c.s.BasePath, "v1/projects/{projectId}/zones/{zone}/clusters/{clusterId}")
urls += "?" + c.urlParams_.Encode()
req, _ := http.NewRequest("PUT", urls, body)
req.Header = reqHeaders
googleapi.Expand(req.URL, map[string]string{
"projectId": c.projectId,
"zone": c.zone,
"clusterId": c.clusterId,
})
req.Header.Set("Content-Type", ctype)
req.Header.Set("User-Agent", c.s.userAgent())
if c.ctx_ != nil {
return ctxhttp.Do(c.ctx_, c.s.client, req)
}
return c.s.client.Do(req)
return gensupport.SendRequest(c.ctx_, c.s.client, req)
}
// Do executes the "container.projects.zones.clusters.update" call.
@ -1613,7 +1610,8 @@ func (c *ProjectsZonesClustersUpdateCall) Do(opts ...googleapi.CallOption) (*Ope
HTTPStatusCode: res.StatusCode,
},
}
if err := json.NewDecoder(res.Body).Decode(&ret); err != nil {
target := &ret
if err := json.NewDecoder(res.Body).Decode(target); err != nil {
return nil, err
}
return ret, nil
@ -1699,27 +1697,25 @@ func (c *ProjectsZonesClustersNodePoolsCreateCall) Context(ctx context.Context)
}
func (c *ProjectsZonesClustersNodePoolsCreateCall) doRequest(alt string) (*http.Response, error) {
reqHeaders := make(http.Header)
reqHeaders.Set("User-Agent", c.s.userAgent())
var body io.Reader = nil
body, err := googleapi.WithoutDataWrapper.JSONReader(c.createnodepoolrequest)
if err != nil {
return nil, err
}
ctype := "application/json"
reqHeaders.Set("Content-Type", "application/json")
c.urlParams_.Set("alt", alt)
urls := googleapi.ResolveRelative(c.s.BasePath, "v1/projects/{projectId}/zones/{zone}/clusters/{clusterId}/nodePools")
urls += "?" + c.urlParams_.Encode()
req, _ := http.NewRequest("POST", urls, body)
req.Header = reqHeaders
googleapi.Expand(req.URL, map[string]string{
"projectId": c.projectId,
"zone": c.zone,
"clusterId": c.clusterId,
})
req.Header.Set("Content-Type", ctype)
req.Header.Set("User-Agent", c.s.userAgent())
if c.ctx_ != nil {
return ctxhttp.Do(c.ctx_, c.s.client, req)
}
return c.s.client.Do(req)
return gensupport.SendRequest(c.ctx_, c.s.client, req)
}
// Do executes the "container.projects.zones.clusters.nodePools.create" call.
@ -1754,7 +1750,8 @@ func (c *ProjectsZonesClustersNodePoolsCreateCall) Do(opts ...googleapi.CallOpti
HTTPStatusCode: res.StatusCode,
},
}
if err := json.NewDecoder(res.Body).Decode(&ret); err != nil {
target := &ret
if err := json.NewDecoder(res.Body).Decode(target); err != nil {
return nil, err
}
return ret, nil
@ -1840,22 +1837,21 @@ func (c *ProjectsZonesClustersNodePoolsDeleteCall) Context(ctx context.Context)
}
func (c *ProjectsZonesClustersNodePoolsDeleteCall) doRequest(alt string) (*http.Response, error) {
reqHeaders := make(http.Header)
reqHeaders.Set("User-Agent", c.s.userAgent())
var body io.Reader = nil
c.urlParams_.Set("alt", alt)
urls := googleapi.ResolveRelative(c.s.BasePath, "v1/projects/{projectId}/zones/{zone}/clusters/{clusterId}/nodePools/{nodePoolId}")
urls += "?" + c.urlParams_.Encode()
req, _ := http.NewRequest("DELETE", urls, body)
req.Header = reqHeaders
googleapi.Expand(req.URL, map[string]string{
"projectId": c.projectId,
"zone": c.zone,
"clusterId": c.clusterId,
"nodePoolId": c.nodePoolId,
})
req.Header.Set("User-Agent", c.s.userAgent())
if c.ctx_ != nil {
return ctxhttp.Do(c.ctx_, c.s.client, req)
}
return c.s.client.Do(req)
return gensupport.SendRequest(c.ctx_, c.s.client, req)
}
// Do executes the "container.projects.zones.clusters.nodePools.delete" call.
@ -1890,7 +1886,8 @@ func (c *ProjectsZonesClustersNodePoolsDeleteCall) Do(opts ...googleapi.CallOpti
HTTPStatusCode: res.StatusCode,
},
}
if err := json.NewDecoder(res.Body).Decode(&ret); err != nil {
target := &ret
if err := json.NewDecoder(res.Body).Decode(target); err != nil {
return nil, err
}
return ret, nil
@ -1991,25 +1988,24 @@ func (c *ProjectsZonesClustersNodePoolsGetCall) Context(ctx context.Context) *Pr
}
func (c *ProjectsZonesClustersNodePoolsGetCall) doRequest(alt string) (*http.Response, error) {
reqHeaders := make(http.Header)
reqHeaders.Set("User-Agent", c.s.userAgent())
if c.ifNoneMatch_ != "" {
reqHeaders.Set("If-None-Match", c.ifNoneMatch_)
}
var body io.Reader = nil
c.urlParams_.Set("alt", alt)
urls := googleapi.ResolveRelative(c.s.BasePath, "v1/projects/{projectId}/zones/{zone}/clusters/{clusterId}/nodePools/{nodePoolId}")
urls += "?" + c.urlParams_.Encode()
req, _ := http.NewRequest("GET", urls, body)
req.Header = reqHeaders
googleapi.Expand(req.URL, map[string]string{
"projectId": c.projectId,
"zone": c.zone,
"clusterId": c.clusterId,
"nodePoolId": c.nodePoolId,
})
req.Header.Set("User-Agent", c.s.userAgent())
if c.ifNoneMatch_ != "" {
req.Header.Set("If-None-Match", c.ifNoneMatch_)
}
if c.ctx_ != nil {
return ctxhttp.Do(c.ctx_, c.s.client, req)
}
return c.s.client.Do(req)
return gensupport.SendRequest(c.ctx_, c.s.client, req)
}
// Do executes the "container.projects.zones.clusters.nodePools.get" call.
@ -2044,7 +2040,8 @@ func (c *ProjectsZonesClustersNodePoolsGetCall) Do(opts ...googleapi.CallOption)
HTTPStatusCode: res.StatusCode,
},
}
if err := json.NewDecoder(res.Body).Decode(&ret); err != nil {
target := &ret
if err := json.NewDecoder(res.Body).Decode(target); err != nil {
return nil, err
}
return ret, nil
@ -2143,24 +2140,23 @@ func (c *ProjectsZonesClustersNodePoolsListCall) Context(ctx context.Context) *P
}
func (c *ProjectsZonesClustersNodePoolsListCall) doRequest(alt string) (*http.Response, error) {
reqHeaders := make(http.Header)
reqHeaders.Set("User-Agent", c.s.userAgent())
if c.ifNoneMatch_ != "" {
reqHeaders.Set("If-None-Match", c.ifNoneMatch_)
}
var body io.Reader = nil
c.urlParams_.Set("alt", alt)
urls := googleapi.ResolveRelative(c.s.BasePath, "v1/projects/{projectId}/zones/{zone}/clusters/{clusterId}/nodePools")
urls += "?" + c.urlParams_.Encode()
req, _ := http.NewRequest("GET", urls, body)
req.Header = reqHeaders
googleapi.Expand(req.URL, map[string]string{
"projectId": c.projectId,
"zone": c.zone,
"clusterId": c.clusterId,
})
req.Header.Set("User-Agent", c.s.userAgent())
if c.ifNoneMatch_ != "" {
req.Header.Set("If-None-Match", c.ifNoneMatch_)
}
if c.ctx_ != nil {
return ctxhttp.Do(c.ctx_, c.s.client, req)
}
return c.s.client.Do(req)
return gensupport.SendRequest(c.ctx_, c.s.client, req)
}
// Do executes the "container.projects.zones.clusters.nodePools.list" call.
@ -2195,7 +2191,8 @@ func (c *ProjectsZonesClustersNodePoolsListCall) Do(opts ...googleapi.CallOption
HTTPStatusCode: res.StatusCode,
},
}
if err := json.NewDecoder(res.Body).Decode(&ret); err != nil {
target := &ret
if err := json.NewDecoder(res.Body).Decode(target); err != nil {
return nil, err
}
return ret, nil
@ -2287,24 +2284,23 @@ func (c *ProjectsZonesOperationsGetCall) Context(ctx context.Context) *ProjectsZ
}
func (c *ProjectsZonesOperationsGetCall) doRequest(alt string) (*http.Response, error) {
reqHeaders := make(http.Header)
reqHeaders.Set("User-Agent", c.s.userAgent())
if c.ifNoneMatch_ != "" {
reqHeaders.Set("If-None-Match", c.ifNoneMatch_)
}
var body io.Reader = nil
c.urlParams_.Set("alt", alt)
urls := googleapi.ResolveRelative(c.s.BasePath, "v1/projects/{projectId}/zones/{zone}/operations/{operationId}")
urls += "?" + c.urlParams_.Encode()
req, _ := http.NewRequest("GET", urls, body)
req.Header = reqHeaders
googleapi.Expand(req.URL, map[string]string{
"projectId": c.projectId,
"zone": c.zone,
"operationId": c.operationId,
})
req.Header.Set("User-Agent", c.s.userAgent())
if c.ifNoneMatch_ != "" {
req.Header.Set("If-None-Match", c.ifNoneMatch_)
}
if c.ctx_ != nil {
return ctxhttp.Do(c.ctx_, c.s.client, req)
}
return c.s.client.Do(req)
return gensupport.SendRequest(c.ctx_, c.s.client, req)
}
// Do executes the "container.projects.zones.operations.get" call.
@ -2339,7 +2335,8 @@ func (c *ProjectsZonesOperationsGetCall) Do(opts ...googleapi.CallOption) (*Oper
HTTPStatusCode: res.StatusCode,
},
}
if err := json.NewDecoder(res.Body).Decode(&ret); err != nil {
target := &ret
if err := json.NewDecoder(res.Body).Decode(target); err != nil {
return nil, err
}
return ret, nil
@ -2430,23 +2427,22 @@ func (c *ProjectsZonesOperationsListCall) Context(ctx context.Context) *Projects
}
func (c *ProjectsZonesOperationsListCall) doRequest(alt string) (*http.Response, error) {
reqHeaders := make(http.Header)
reqHeaders.Set("User-Agent", c.s.userAgent())
if c.ifNoneMatch_ != "" {
reqHeaders.Set("If-None-Match", c.ifNoneMatch_)
}
var body io.Reader = nil
c.urlParams_.Set("alt", alt)
urls := googleapi.ResolveRelative(c.s.BasePath, "v1/projects/{projectId}/zones/{zone}/operations")
urls += "?" + c.urlParams_.Encode()
req, _ := http.NewRequest("GET", urls, body)
req.Header = reqHeaders
googleapi.Expand(req.URL, map[string]string{
"projectId": c.projectId,
"zone": c.zone,
})
req.Header.Set("User-Agent", c.s.userAgent())
if c.ifNoneMatch_ != "" {
req.Header.Set("If-None-Match", c.ifNoneMatch_)
}
if c.ctx_ != nil {
return ctxhttp.Do(c.ctx_, c.s.client, req)
}
return c.s.client.Do(req)
return gensupport.SendRequest(c.ctx_, c.s.client, req)
}
// Do executes the "container.projects.zones.operations.list" call.
@ -2481,7 +2477,8 @@ func (c *ProjectsZonesOperationsListCall) Do(opts ...googleapi.CallOption) (*Lis
HTTPStatusCode: res.StatusCode,
},
}
if err := json.NewDecoder(res.Body).Decode(&ret); err != nil {
target := &ret
if err := json.NewDecoder(res.Body).Decode(target); err != nil {
return nil, err
}
return ret, nil

View file

@ -12,7 +12,6 @@ import (
"time"
"golang.org/x/net/context"
"golang.org/x/net/context/ctxhttp"
)
const (
@ -80,7 +79,7 @@ func (rx *ResumableUpload) doUploadRequest(ctx context.Context, data io.Reader,
req.Header.Set("Content-Range", contentRange)
req.Header.Set("Content-Type", rx.MediaType)
req.Header.Set("User-Agent", rx.UserAgent)
return ctxhttp.Do(ctx, rx.Client, req)
return SendRequest(ctx, rx.Client, req)
}
@ -135,6 +134,8 @@ func contextDone(ctx context.Context) bool {
// It retries using the provided back off strategy until cancelled or the
// strategy indicates to stop retrying.
// It is called from the auto-generated API code and is not visible to the user.
// Before sending an HTTP request, Upload calls any registered hook functions,
// and calls the returned functions after the request returns (see send.go).
// rx is private to the auto-generated API code.
// Exactly one of resp or err will be nil. If resp is non-nil, the caller must call resp.Body.Close.
func (rx *ResumableUpload) Upload(ctx context.Context) (resp *http.Response, err error) {

55
vendor/google.golang.org/api/gensupport/send.go generated vendored Normal file
View file

@ -0,0 +1,55 @@
// Copyright 2016 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package gensupport
import (
"net/http"
"golang.org/x/net/context"
"golang.org/x/net/context/ctxhttp"
)
// Hook is the type of a function that is called once before each HTTP request
// that is sent by a generated API. It returns a function that is called after
// the request returns.
// Hooks are not called if the context is nil.
type Hook func(ctx context.Context, req *http.Request) func(resp *http.Response)
var hooks []Hook
// RegisterHook registers a Hook to be called before each HTTP request by a
// generated API. Hooks are called in the order they are registered. Each
// hook can return a function; if it is non-nil, it is called after the HTTP
// request returns. These functions are called in the reverse order.
// RegisterHook should not be called concurrently with itself or SendRequest.
func RegisterHook(h Hook) {
hooks = append(hooks, h)
}
// SendRequest sends a single HTTP request using the given client.
// If ctx is non-nil, it calls all hooks, then sends the request with
// ctxhttp.Do, then calls any functions returned by the hooks in reverse order.
func SendRequest(ctx context.Context, client *http.Client, req *http.Request) (*http.Response, error) {
if ctx == nil {
return client.Do(req)
}
// Call hooks in order of registration, store returned funcs.
post := make([]func(resp *http.Response), len(hooks))
for i, h := range hooks {
fn := h(ctx, req)
post[i] = fn
}
// Send request.
resp, err := ctxhttp.Do(ctx, client, req)
// Call returned funcs in reverse order.
for i := len(post) - 1; i >= 0; i-- {
if fn := post[i]; fn != nil {
fn(resp)
}
}
return resp, err
}

View file

@ -149,12 +149,12 @@ func IsNotModified(err error) bool {
// CheckMediaResponse returns an error (of type *Error) if the response
// status code is not 2xx. Unlike CheckResponse it does not assume the
// body is a JSON error document.
// It is the caller's responsibility to close res.Body.
func CheckMediaResponse(res *http.Response) error {
if res.StatusCode >= 200 && res.StatusCode <= 299 {
return nil
}
slurp, _ := ioutil.ReadAll(io.LimitReader(res.Body, 1<<20))
res.Body.Close()
return &Error{
Code: res.StatusCode,
Body: string(slurp),
@ -278,41 +278,15 @@ func ResolveRelative(basestr, relstr string) string {
return us
}
// has4860Fix is whether this Go environment contains the fix for
// http://golang.org/issue/4860
var has4860Fix bool
// init initializes has4860Fix by checking the behavior of the net/http package.
func init() {
r := http.Request{
URL: &url.URL{
Scheme: "http",
Opaque: "//opaque",
},
}
b := &bytes.Buffer{}
r.Write(b)
has4860Fix = bytes.HasPrefix(b.Bytes(), []byte("GET http"))
}
// SetOpaque sets u.Opaque from u.Path such that HTTP requests to it
// don't alter any hex-escaped characters in u.Path.
func SetOpaque(u *url.URL) {
u.Opaque = "//" + u.Host + u.Path
if !has4860Fix {
u.Opaque = u.Scheme + ":" + u.Opaque
}
}
// Expand subsitutes any {encoded} strings in the URL passed in using
// the map supplied.
//
// This calls SetOpaque to avoid encoding of the parameters in the URL path.
func Expand(u *url.URL, expansions map[string]string) {
expanded, err := uritemplates.Expand(u.Path, expansions)
escaped, unescaped, err := uritemplates.Expand(u.Path, expansions)
if err == nil {
u.Path = expanded
SetOpaque(u)
u.Path = unescaped
u.RawPath = escaped
}
}

View file

@ -34,11 +34,37 @@ func pctEncode(src []byte) []byte {
return dst
}
func escape(s string, allowReserved bool) string {
// pairWriter is a convenience struct which allows escaped and unescaped
// versions of the template to be written in parallel.
type pairWriter struct {
escaped, unescaped bytes.Buffer
}
// Write writes the provided string directly without any escaping.
func (w *pairWriter) Write(s string) {
w.escaped.WriteString(s)
w.unescaped.WriteString(s)
}
// Escape writes the provided string, escaping the string for the
// escaped output.
func (w *pairWriter) Escape(s string, allowReserved bool) {
w.unescaped.WriteString(s)
if allowReserved {
return string(reserved.ReplaceAllFunc([]byte(s), pctEncode))
w.escaped.Write(reserved.ReplaceAllFunc([]byte(s), pctEncode))
} else {
w.escaped.Write(unreserved.ReplaceAllFunc([]byte(s), pctEncode))
}
return string(unreserved.ReplaceAllFunc([]byte(s), pctEncode))
}
// Escaped returns the escaped string.
func (w *pairWriter) Escaped() string {
return w.escaped.String()
}
// Unescaped returns the unescaped string.
func (w *pairWriter) Unescaped() string {
return w.unescaped.String()
}
// A uriTemplate is a parsed representation of a URI template.
@ -170,18 +196,20 @@ func parseTerm(term string) (result templateTerm, err error) {
return result, err
}
// Expand expands a URI template with a set of values to produce a string.
func (t *uriTemplate) Expand(values map[string]string) string {
var buf bytes.Buffer
// Expand expands a URI template with a set of values to produce the
// resultant URI. Two forms of the result are returned: one with all the
// elements escaped, and one with the elements unescaped.
func (t *uriTemplate) Expand(values map[string]string) (escaped, unescaped string) {
var w pairWriter
for _, p := range t.parts {
p.expand(&buf, values)
p.expand(&w, values)
}
return buf.String()
return w.Escaped(), w.Unescaped()
}
func (tp *templatePart) expand(buf *bytes.Buffer, values map[string]string) {
func (tp *templatePart) expand(w *pairWriter, values map[string]string) {
if len(tp.raw) > 0 {
buf.WriteString(tp.raw)
w.Write(tp.raw)
return
}
var first = true
@ -191,30 +219,30 @@ func (tp *templatePart) expand(buf *bytes.Buffer, values map[string]string) {
continue
}
if first {
buf.WriteString(tp.first)
w.Write(tp.first)
first = false
} else {
buf.WriteString(tp.sep)
w.Write(tp.sep)
}
tp.expandString(buf, term, value)
tp.expandString(w, term, value)
}
}
func (tp *templatePart) expandName(buf *bytes.Buffer, name string, empty bool) {
func (tp *templatePart) expandName(w *pairWriter, name string, empty bool) {
if tp.named {
buf.WriteString(name)
w.Write(name)
if empty {
buf.WriteString(tp.ifemp)
w.Write(tp.ifemp)
} else {
buf.WriteString("=")
w.Write("=")
}
}
}
func (tp *templatePart) expandString(buf *bytes.Buffer, t templateTerm, s string) {
func (tp *templatePart) expandString(w *pairWriter, t templateTerm, s string) {
if len(s) > t.truncate && t.truncate > 0 {
s = s[:t.truncate]
}
tp.expandName(buf, t.name, len(s) == 0)
buf.WriteString(escape(s, tp.allowReserved))
tp.expandName(w, t.name, len(s) == 0)
w.Escape(s, tp.allowReserved)
}

View file

@ -4,10 +4,14 @@
package uritemplates
func Expand(path string, values map[string]string) (string, error) {
// Expand parses then expands a URI template with a set of values to produce
// the resultant URI. Two forms of the result are returned: one with all the
// elements escaped, and one with the elements unescaped.
func Expand(path string, values map[string]string) (escaped, unescaped string, err error) {
template, err := parse(path)
if err != nil {
return "", err
return "", "", err
}
return template.Expand(values), nil
escaped, unescaped = template.Expand(values)
return escaped, unescaped, nil
}

View file

@ -1,128 +0,0 @@
// Copyright 2014 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package internal provides support for the cloud packages.
//
// Users should not import this package directly.
package internal
import (
"fmt"
"net/http"
"sync"
"golang.org/x/net/context"
)
type contextKey struct{}
func WithContext(parent context.Context, projID string, c *http.Client) context.Context {
if c == nil {
panic("nil *http.Client passed to WithContext")
}
if projID == "" {
panic("empty project ID passed to WithContext")
}
return context.WithValue(parent, contextKey{}, &cloudContext{
ProjectID: projID,
HTTPClient: c,
})
}
const userAgent = "gcloud-golang/0.1"
type cloudContext struct {
ProjectID string
HTTPClient *http.Client
mu sync.Mutex // guards svc
svc map[string]interface{} // e.g. "storage" => *rawStorage.Service
}
// Service returns the result of the fill function if it's never been
// called before for the given name (which is assumed to be an API
// service name, like "datastore"). If it has already been cached, the fill
// func is not run.
// It's safe for concurrent use by multiple goroutines.
func Service(ctx context.Context, name string, fill func(*http.Client) interface{}) interface{} {
return cc(ctx).service(name, fill)
}
func (c *cloudContext) service(name string, fill func(*http.Client) interface{}) interface{} {
c.mu.Lock()
defer c.mu.Unlock()
if c.svc == nil {
c.svc = make(map[string]interface{})
} else if v, ok := c.svc[name]; ok {
return v
}
v := fill(c.HTTPClient)
c.svc[name] = v
return v
}
// Transport is an http.RoundTripper that appends
// Google Cloud client's user-agent to the original
// request's user-agent header.
type Transport struct {
// Base is the actual http.RoundTripper
// requests will use. It must not be nil.
Base http.RoundTripper
}
// RoundTrip appends a user-agent to the existing user-agent
// header and delegates the request to the base http.RoundTripper.
func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) {
req = cloneRequest(req)
ua := req.Header.Get("User-Agent")
if ua == "" {
ua = userAgent
} else {
ua = fmt.Sprintf("%s %s", ua, userAgent)
}
req.Header.Set("User-Agent", ua)
return t.Base.RoundTrip(req)
}
// cloneRequest returns a clone of the provided *http.Request.
// The clone is a shallow copy of the struct and its Header map.
func cloneRequest(r *http.Request) *http.Request {
// shallow copy of the struct
r2 := new(http.Request)
*r2 = *r
// deep copy of the Header
r2.Header = make(http.Header)
for k, s := range r.Header {
r2.Header[k] = s
}
return r2
}
func ProjID(ctx context.Context) string {
return cc(ctx).ProjectID
}
func HTTPClient(ctx context.Context) *http.Client {
return cc(ctx).HTTPClient
}
// cc returns the internal *cloudContext (cc) state for a context.Context.
// It panics if the user did it wrong.
func cc(ctx context.Context) *cloudContext {
if c, ok := ctx.Value(contextKey{}).(*cloudContext); ok {
return c
}
panic("invalid context.Context type; it should be created with cloud.NewContext")
}

View file

@ -1,17 +1,21 @@
language: go
go:
- 1.5.3
- 1.6
- 1.5.4
- 1.6.3
go_import_path: google.golang.org/grpc
before_install:
- go get golang.org/x/tools/cmd/goimports
- go get github.com/golang/lint/golint
- go get github.com/axw/gocov/gocov
- go get github.com/mattn/goveralls
- go get golang.org/x/tools/cmd/cover
install:
- mkdir -p "$GOPATH/src/google.golang.org"
- mv "$TRAVIS_BUILD_DIR" "$GOPATH/src/google.golang.org/grpc"
script:
- '! gofmt -s -d -l . 2>&1 | read'
- '! goimports -l . | read'
- '! golint ./... | grep -vE "(_string|\.pb)\.go:"'
- '! go tool vet -all . 2>&1 | grep -vE "constant [0-9]+ not a string in call to Errorf"'
- make test testrace

View file

@ -36,6 +36,7 @@ package grpc
import (
"bytes"
"io"
"math"
"time"
"golang.org/x/net/context"
@ -51,13 +52,20 @@ import (
func recvResponse(dopts dialOptions, t transport.ClientTransport, c *callInfo, stream *transport.Stream, reply interface{}) error {
// Try to acquire header metadata from the server if there is any.
var err error
defer func() {
if err != nil {
if _, ok := err.(transport.ConnectionError); !ok {
t.CloseStream(stream, err)
}
}
}()
c.headerMD, err = stream.Header()
if err != nil {
return err
}
p := &parser{r: stream}
for {
if err = recv(p, dopts.codec, stream, dopts.dc, reply); err != nil {
if err = recv(p, dopts.codec, stream, dopts.dc, reply, math.MaxInt32); err != nil {
if err == io.EOF {
break
}
@ -76,6 +84,7 @@ func sendRequest(ctx context.Context, codec Codec, compressor Compressor, callHd
}
defer func() {
if err != nil {
// If err is connection error, t will be closed, no need to close stream here.
if _, ok := err.(transport.ConnectionError); !ok {
t.CloseStream(stream, err)
}
@ -90,7 +99,10 @@ func sendRequest(ctx context.Context, codec Codec, compressor Compressor, callHd
return nil, transport.StreamErrorf(codes.Internal, "grpc: %v", err)
}
err = t.Write(stream, outBuf, opts)
if err != nil {
// t.NewStream(...) could lead to an early rejection of the RPC (e.g., the service/method
// does not exist.) so that t.Write could get io.EOF from wait(...). Leave the following
// recvResponse to get the final status.
if err != nil && err != io.EOF {
return nil, err
}
// Sent successfully.
@ -158,9 +170,9 @@ func Invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli
if _, ok := err.(*rpcError); ok {
return err
}
if err == errConnClosing {
if err == errConnClosing || err == errConnUnavailable {
if c.failFast {
return Errorf(codes.Unavailable, "%v", errConnClosing)
return Errorf(codes.Unavailable, "%v", err)
}
continue
}
@ -176,7 +188,10 @@ func Invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli
put()
put = nil
}
if _, ok := err.(transport.ConnectionError); ok {
// Retry a non-failfast RPC when
// i) there is a connection error; or
// ii) the server started to drain before this RPC was initiated.
if _, ok := err.(transport.ConnectionError); ok || err == transport.ErrStreamDrain {
if c.failFast {
return toRPCErr(err)
}
@ -184,20 +199,18 @@ func Invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli
}
return toRPCErr(err)
}
// Receive the response
err = recvResponse(cc.dopts, t, &c, stream, reply)
if err != nil {
if put != nil {
put()
put = nil
}
if _, ok := err.(transport.ConnectionError); ok {
if _, ok := err.(transport.ConnectionError); ok || err == transport.ErrStreamDrain {
if c.failFast {
return toRPCErr(err)
}
continue
}
t.CloseStream(stream, err)
return toRPCErr(err)
}
if c.traceInfo.tr != nil {

View file

@ -43,7 +43,6 @@ import (
"golang.org/x/net/context"
"golang.org/x/net/trace"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/grpclog"
"google.golang.org/grpc/transport"
@ -68,13 +67,15 @@ var (
// errCredentialsConflict indicates that grpc.WithTransportCredentials()
// and grpc.WithInsecure() are both called for a connection.
errCredentialsConflict = errors.New("grpc: transport credentials are set for an insecure connection (grpc.WithTransportCredentials() and grpc.WithInsecure() are both called)")
// errNetworkIP indicates that the connection is down due to some network I/O error.
// errNetworkIO indicates that the connection is down due to some network I/O error.
errNetworkIO = errors.New("grpc: failed with network I/O error")
// errConnDrain indicates that the connection starts to be drained and does not accept any new RPCs.
errConnDrain = errors.New("grpc: the connection is drained")
// errConnClosing indicates that the connection is closing.
errConnClosing = errors.New("grpc: the connection is closing")
errNoAddr = errors.New("grpc: there is no address available to dial")
// errConnUnavailable indicates that the connection is unavailable.
errConnUnavailable = errors.New("grpc: the connection is unavailable")
errNoAddr = errors.New("grpc: there is no address available to dial")
// minimum time to give a connection to complete
minConnectTimeout = 20 * time.Second
)
@ -196,9 +197,14 @@ func WithTimeout(d time.Duration) DialOption {
}
// WithDialer returns a DialOption that specifies a function to use for dialing network addresses.
func WithDialer(f func(addr string, timeout time.Duration) (net.Conn, error)) DialOption {
func WithDialer(f func(string, time.Duration) (net.Conn, error)) DialOption {
return func(o *dialOptions) {
o.copts.Dialer = f
o.copts.Dialer = func(ctx context.Context, addr string) (net.Conn, error) {
if deadline, ok := ctx.Deadline(); ok {
return f(addr, deadline.Sub(time.Now()))
}
return f(addr, 0)
}
}
}
@ -209,12 +215,19 @@ func WithUserAgent(s string) DialOption {
}
}
// Dial creates a client connection the given target.
// Dial creates a client connection to the given target.
func Dial(target string, opts ...DialOption) (*ClientConn, error) {
return DialContext(context.Background(), target, opts...)
}
// DialContext creates a client connection to the given target
// using the supplied context.
func DialContext(ctx context.Context, target string, opts ...DialOption) (*ClientConn, error) {
cc := &ClientConn{
target: target,
conns: make(map[Address]*addrConn),
}
cc.ctx, cc.cancel = context.WithCancel(ctx)
for _, opt := range opts {
opt(&cc.dopts)
}
@ -226,31 +239,33 @@ func Dial(target string, opts ...DialOption) (*ClientConn, error) {
if cc.dopts.bs == nil {
cc.dopts.bs = DefaultBackoffConfig
}
if cc.dopts.balancer == nil {
cc.dopts.balancer = RoundRobin(nil)
}
if err := cc.dopts.balancer.Start(target); err != nil {
return nil, err
}
var (
ok bool
addrs []Address
)
ch := cc.dopts.balancer.Notify()
if ch == nil {
// There is no name resolver installed.
if cc.dopts.balancer == nil {
// Connect to target directly if balancer is nil.
addrs = append(addrs, Address{Addr: target})
} else {
addrs, ok = <-ch
if !ok || len(addrs) == 0 {
return nil, errNoAddr
if err := cc.dopts.balancer.Start(target); err != nil {
return nil, err
}
ch := cc.dopts.balancer.Notify()
if ch == nil {
// There is no name resolver installed.
addrs = append(addrs, Address{Addr: target})
} else {
addrs, ok = <-ch
if !ok || len(addrs) == 0 {
return nil, errNoAddr
}
}
}
waitC := make(chan error, 1)
go func() {
for _, a := range addrs {
if err := cc.newAddrConn(a, false); err != nil {
if err := cc.resetAddrConn(a, false, nil); err != nil {
waitC <- err
return
}
@ -267,10 +282,15 @@ func Dial(target string, opts ...DialOption) (*ClientConn, error) {
cc.Close()
return nil, err
}
case <-cc.ctx.Done():
cc.Close()
return nil, cc.ctx.Err()
case <-timeoutCh:
cc.Close()
return nil, ErrClientConnTimeout
}
// If balancer is nil or balancer.Notify() is nil, ok will be false here.
// The lbWatcher goroutine will not be created.
if ok {
go cc.lbWatcher()
}
@ -317,6 +337,9 @@ func (s ConnectivityState) String() string {
// ClientConn represents a client connection to an RPC server.
type ClientConn struct {
ctx context.Context
cancel context.CancelFunc
target string
authority string
dopts dialOptions
@ -347,11 +370,12 @@ func (cc *ClientConn) lbWatcher() {
}
if !keep {
del = append(del, c)
delete(cc.conns, c.addr)
}
}
cc.mu.Unlock()
for _, a := range add {
cc.newAddrConn(a, true)
cc.resetAddrConn(a, true, nil)
}
for _, c := range del {
c.tearDown(errConnDrain)
@ -359,13 +383,17 @@ func (cc *ClientConn) lbWatcher() {
}
}
func (cc *ClientConn) newAddrConn(addr Address, skipWait bool) error {
// resetAddrConn creates an addrConn for addr and adds it to cc.conns.
// If there is an old addrConn for addr, it will be torn down, using tearDownErr as the reason.
// If tearDownErr is nil, errConnDrain will be used instead.
func (cc *ClientConn) resetAddrConn(addr Address, skipWait bool, tearDownErr error) error {
ac := &addrConn{
cc: cc,
addr: addr,
dopts: cc.dopts,
shutdownChan: make(chan struct{}),
cc: cc,
addr: addr,
dopts: cc.dopts,
}
ac.ctx, ac.cancel = context.WithCancel(cc.ctx)
ac.stateCV = sync.NewCond(&ac.mu)
if EnableTracing {
ac.events = trace.NewEventLog("grpc.ClientConn", ac.addr.Addr)
}
@ -383,26 +411,44 @@ func (cc *ClientConn) newAddrConn(addr Address, skipWait bool) error {
}
}
}
// Insert ac into ac.cc.conns. This needs to be done before any getTransport(...) is called.
ac.cc.mu.Lock()
if ac.cc.conns == nil {
ac.cc.mu.Unlock()
// Track ac in cc. This needs to be done before any getTransport(...) is called.
cc.mu.Lock()
if cc.conns == nil {
cc.mu.Unlock()
return ErrClientConnClosing
}
stale := ac.cc.conns[ac.addr]
ac.cc.conns[ac.addr] = ac
ac.cc.mu.Unlock()
stale := cc.conns[ac.addr]
cc.conns[ac.addr] = ac
cc.mu.Unlock()
if stale != nil {
// There is an addrConn alive on ac.addr already. This could be due to
// i) stale's Close is undergoing;
// ii) a buggy Balancer notifies duplicated Addresses.
stale.tearDown(errConnDrain)
// 1) a buggy Balancer notifies duplicated Addresses;
// 2) goaway was received, a new ac will replace the old ac.
// The old ac should be deleted from cc.conns, but the
// underlying transport should drain rather than close.
if tearDownErr == nil {
// tearDownErr is nil if resetAddrConn is called by
// 1) Dial
// 2) lbWatcher
// In both cases, the stale ac should drain, not close.
stale.tearDown(errConnDrain)
} else {
stale.tearDown(tearDownErr)
}
}
ac.stateCV = sync.NewCond(&ac.mu)
// skipWait may overwrite the decision in ac.dopts.block.
if ac.dopts.block && !skipWait {
if err := ac.resetTransport(false); err != nil {
ac.tearDown(err)
if err != errConnClosing {
// Tear down ac and delete it from cc.conns.
cc.mu.Lock()
delete(cc.conns, ac.addr)
cc.mu.Unlock()
ac.tearDown(err)
}
if e, ok := err.(transport.ConnectionError); ok && !e.Temporary() {
return e.Origin()
}
return err
}
// Start to monitor the error status of transport.
@ -412,7 +458,10 @@ func (cc *ClientConn) newAddrConn(addr Address, skipWait bool) error {
go func() {
if err := ac.resetTransport(false); err != nil {
grpclog.Printf("Failed to dial %s: %v; please retry.", ac.addr.Addr, err)
ac.tearDown(err)
if err != errConnClosing {
// Keep this ac in cc.conns, to get the reason it's torn down.
ac.tearDown(err)
}
return
}
ac.transportMonitor()
@ -422,24 +471,48 @@ func (cc *ClientConn) newAddrConn(addr Address, skipWait bool) error {
}
func (cc *ClientConn) getTransport(ctx context.Context, opts BalancerGetOptions) (transport.ClientTransport, func(), error) {
addr, put, err := cc.dopts.balancer.Get(ctx, opts)
if err != nil {
return nil, nil, toRPCErr(err)
}
cc.mu.RLock()
if cc.conns == nil {
var (
ac *addrConn
ok bool
put func()
)
if cc.dopts.balancer == nil {
// If balancer is nil, there should be only one addrConn available.
cc.mu.RLock()
if cc.conns == nil {
cc.mu.RUnlock()
return nil, nil, toRPCErr(ErrClientConnClosing)
}
for _, ac = range cc.conns {
// Break after the first iteration to get the first addrConn.
ok = true
break
}
cc.mu.RUnlock()
} else {
var (
addr Address
err error
)
addr, put, err = cc.dopts.balancer.Get(ctx, opts)
if err != nil {
return nil, nil, toRPCErr(err)
}
cc.mu.RLock()
if cc.conns == nil {
cc.mu.RUnlock()
return nil, nil, toRPCErr(ErrClientConnClosing)
}
ac, ok = cc.conns[addr]
cc.mu.RUnlock()
return nil, nil, toRPCErr(ErrClientConnClosing)
}
ac, ok := cc.conns[addr]
cc.mu.RUnlock()
if !ok {
if put != nil {
put()
}
return nil, nil, Errorf(codes.Internal, "grpc: failed to find the transport to send the rpc")
return nil, nil, errConnClosing
}
t, err := ac.wait(ctx, !opts.BlockingWait)
t, err := ac.wait(ctx, cc.dopts.balancer != nil, !opts.BlockingWait)
if err != nil {
if put != nil {
put()
@ -451,6 +524,8 @@ func (cc *ClientConn) getTransport(ctx context.Context, opts BalancerGetOptions)
// Close tears down the ClientConn and all underlying connections.
func (cc *ClientConn) Close() error {
cc.cancel()
cc.mu.Lock()
if cc.conns == nil {
cc.mu.Unlock()
@ -459,7 +534,9 @@ func (cc *ClientConn) Close() error {
conns := cc.conns
cc.conns = nil
cc.mu.Unlock()
cc.dopts.balancer.Close()
if cc.dopts.balancer != nil {
cc.dopts.balancer.Close()
}
for _, ac := range conns {
ac.tearDown(ErrClientConnClosing)
}
@ -468,11 +545,13 @@ func (cc *ClientConn) Close() error {
// addrConn is a network connection to a given address.
type addrConn struct {
cc *ClientConn
addr Address
dopts dialOptions
shutdownChan chan struct{}
events trace.EventLog
ctx context.Context
cancel context.CancelFunc
cc *ClientConn
addr Address
dopts dialOptions
events trace.EventLog
mu sync.Mutex
state ConnectivityState
@ -482,6 +561,9 @@ type addrConn struct {
// due to timeout.
ready chan struct{}
transport transport.ClientTransport
// The reason this addrConn is torn down.
tearDownErr error
}
// printf records an event in ac's event log, unless ac has been closed.
@ -537,8 +619,7 @@ func (ac *addrConn) waitForStateChange(ctx context.Context, sourceState Connecti
}
func (ac *addrConn) resetTransport(closeTransport bool) error {
var retries int
for {
for retries := 0; ; retries++ {
ac.mu.Lock()
ac.printf("connecting")
if ac.state == Shutdown {
@ -558,13 +639,20 @@ func (ac *addrConn) resetTransport(closeTransport bool) error {
t.Close()
}
sleepTime := ac.dopts.bs.backoff(retries)
ac.dopts.copts.Timeout = sleepTime
if sleepTime < minConnectTimeout {
ac.dopts.copts.Timeout = minConnectTimeout
timeout := minConnectTimeout
if timeout < sleepTime {
timeout = sleepTime
}
ctx, cancel := context.WithTimeout(ac.ctx, timeout)
connectTime := time.Now()
newTransport, err := transport.NewClientTransport(ac.addr.Addr, &ac.dopts.copts)
newTransport, err := transport.NewClientTransport(ctx, ac.addr.Addr, ac.dopts.copts)
if err != nil {
cancel()
if e, ok := err.(transport.ConnectionError); ok && !e.Temporary() {
return err
}
grpclog.Printf("grpc: addrConn.resetTransport failed to create client transport: %v; Reconnecting to %q", err, ac.addr)
ac.mu.Lock()
if ac.state == Shutdown {
// ac.tearDown(...) has been invoked.
@ -579,17 +667,12 @@ func (ac *addrConn) resetTransport(closeTransport bool) error {
ac.ready = nil
}
ac.mu.Unlock()
sleepTime -= time.Since(connectTime)
if sleepTime < 0 {
sleepTime = 0
}
closeTransport = false
select {
case <-time.After(sleepTime):
case <-ac.shutdownChan:
case <-time.After(sleepTime - time.Since(connectTime)):
case <-ac.ctx.Done():
return ac.ctx.Err()
}
retries++
grpclog.Printf("grpc: addrConn.resetTransport failed to create client transport: %v; Reconnecting to %q", err, ac.addr)
continue
}
ac.mu.Lock()
@ -607,7 +690,9 @@ func (ac *addrConn) resetTransport(closeTransport bool) error {
close(ac.ready)
ac.ready = nil
}
ac.down = ac.cc.dopts.balancer.Up(ac.addr)
if ac.cc.dopts.balancer != nil {
ac.down = ac.cc.dopts.balancer.Up(ac.addr)
}
ac.mu.Unlock()
return nil
}
@ -621,14 +706,42 @@ func (ac *addrConn) transportMonitor() {
t := ac.transport
ac.mu.Unlock()
select {
// shutdownChan is needed to detect the teardown when
// This is needed to detect the teardown when
// the addrConn is idle (i.e., no RPC in flight).
case <-ac.shutdownChan:
case <-ac.ctx.Done():
select {
case <-t.Error():
t.Close()
default:
}
return
case <-t.GoAway():
// If GoAway happens without any network I/O error, ac is closed without shutting down the
// underlying transport (the transport will be closed when all the pending RPCs finished or
// failed.).
// If GoAway and some network I/O error happen concurrently, ac and its underlying transport
// are closed.
// In both cases, a new ac is created.
select {
case <-t.Error():
ac.cc.resetAddrConn(ac.addr, true, errNetworkIO)
default:
ac.cc.resetAddrConn(ac.addr, true, errConnDrain)
}
return
case <-t.Error():
select {
case <-ac.ctx.Done():
t.Close()
return
case <-t.GoAway():
ac.cc.resetAddrConn(ac.addr, true, errNetworkIO)
return
default:
}
ac.mu.Lock()
if ac.state == Shutdown {
// ac.tearDown(...) has been invoked.
// ac has been shutdown.
ac.mu.Unlock()
return
}
@ -640,6 +753,10 @@ func (ac *addrConn) transportMonitor() {
ac.printf("transport exiting: %v", err)
ac.mu.Unlock()
grpclog.Printf("grpc: addrConn.transportMonitor exits due to: %v", err)
if err != errConnClosing {
// Keep this ac in cc.conns, to get the reason it's torn down.
ac.tearDown(err)
}
return
}
}
@ -647,35 +764,42 @@ func (ac *addrConn) transportMonitor() {
}
// wait blocks until i) the new transport is up or ii) ctx is done or iii) ac is closed or
// iv) transport is in TransientFailure and the RPC is fail-fast.
func (ac *addrConn) wait(ctx context.Context, failFast bool) (transport.ClientTransport, error) {
// iv) transport is in TransientFailure and there's no balancer/failfast is true.
func (ac *addrConn) wait(ctx context.Context, hasBalancer, failfast bool) (transport.ClientTransport, error) {
for {
ac.mu.Lock()
switch {
case ac.state == Shutdown:
if failfast || !hasBalancer {
// RPC is failfast or balancer is nil. This RPC should fail with ac.tearDownErr.
err := ac.tearDownErr
ac.mu.Unlock()
return nil, err
}
ac.mu.Unlock()
return nil, errConnClosing
case ac.state == Ready:
ct := ac.transport
ac.mu.Unlock()
return ct, nil
case ac.state == TransientFailure && failFast:
ac.mu.Unlock()
return nil, Errorf(codes.Unavailable, "grpc: RPC failed fast due to transport failure")
default:
ready := ac.ready
if ready == nil {
ready = make(chan struct{})
ac.ready = ready
}
ac.mu.Unlock()
select {
case <-ctx.Done():
return nil, toRPCErr(ctx.Err())
// Wait until the new transport is ready or failed.
case <-ready:
case ac.state == TransientFailure:
if failfast || hasBalancer {
ac.mu.Unlock()
return nil, errConnUnavailable
}
}
ready := ac.ready
if ready == nil {
ready = make(chan struct{})
ac.ready = ready
}
ac.mu.Unlock()
select {
case <-ctx.Done():
return nil, toRPCErr(ctx.Err())
// Wait until the new transport is ready or failed.
case <-ready:
}
}
}
@ -683,24 +807,28 @@ func (ac *addrConn) wait(ctx context.Context, failFast bool) (transport.ClientTr
// TODO(zhaoq): Make this synchronous to avoid unbounded memory consumption in
// some edge cases (e.g., the caller opens and closes many addrConn's in a
// tight loop.
// tearDown doesn't remove ac from ac.cc.conns.
func (ac *addrConn) tearDown(err error) {
ac.cancel()
ac.mu.Lock()
defer func() {
ac.mu.Unlock()
ac.cc.mu.Lock()
if ac.cc.conns != nil {
delete(ac.cc.conns, ac.addr)
}
ac.cc.mu.Unlock()
}()
if ac.state == Shutdown {
return
}
ac.state = Shutdown
defer ac.mu.Unlock()
if ac.down != nil {
ac.down(downErrorf(false, false, "%v", err))
ac.down = nil
}
if err == errConnDrain && ac.transport != nil {
// GracefulClose(...) may be executed multiple times when
// i) receiving multiple GoAway frames from the server; or
// ii) there are concurrent name resolver/Balancer triggered
// address removal and GoAway.
ac.transport.GracefulClose()
}
if ac.state == Shutdown {
return
}
ac.state = Shutdown
ac.tearDownErr = err
ac.stateCV.Broadcast()
if ac.events != nil {
ac.events.Finish()
@ -710,15 +838,8 @@ func (ac *addrConn) tearDown(err error) {
close(ac.ready)
ac.ready = nil
}
if ac.transport != nil {
if err == errConnDrain {
ac.transport.GracefulClose()
} else {
ac.transport.Close()
}
}
if ac.shutdownChan != nil {
close(ac.shutdownChan)
if ac.transport != nil && err != errConnDrain {
ac.transport.Close()
}
return
}

View file

@ -44,7 +44,6 @@ import (
"io/ioutil"
"net"
"strings"
"time"
"golang.org/x/net/context"
)
@ -93,11 +92,12 @@ type TransportCredentials interface {
// ClientHandshake does the authentication handshake specified by the corresponding
// authentication protocol on rawConn for clients. It returns the authenticated
// connection and the corresponding auth information about the connection.
ClientHandshake(addr string, rawConn net.Conn, timeout time.Duration) (net.Conn, AuthInfo, error)
// Implementations must use the provided context to implement timely cancellation.
ClientHandshake(context.Context, string, net.Conn) (net.Conn, AuthInfo, error)
// ServerHandshake does the authentication handshake for servers. It returns
// the authenticated connection and the corresponding auth information about
// the connection.
ServerHandshake(rawConn net.Conn) (net.Conn, AuthInfo, error)
ServerHandshake(net.Conn) (net.Conn, AuthInfo, error)
// Info provides the ProtocolInfo of this TransportCredentials.
Info() ProtocolInfo
}
@ -136,42 +136,28 @@ func (c *tlsCreds) RequireTransportSecurity() bool {
return true
}
type timeoutError struct{}
func (timeoutError) Error() string { return "credentials: Dial timed out" }
func (timeoutError) Timeout() bool { return true }
func (timeoutError) Temporary() bool { return true }
func (c *tlsCreds) ClientHandshake(addr string, rawConn net.Conn, timeout time.Duration) (_ net.Conn, _ AuthInfo, err error) {
// borrow some code from tls.DialWithDialer
var errChannel chan error
if timeout != 0 {
errChannel = make(chan error, 2)
time.AfterFunc(timeout, func() {
errChannel <- timeoutError{}
})
}
func (c *tlsCreds) ClientHandshake(ctx context.Context, addr string, rawConn net.Conn) (_ net.Conn, _ AuthInfo, err error) {
// use local cfg to avoid clobbering ServerName if using multiple endpoints
cfg := *c.config
if c.config.ServerName == "" {
cfg := cloneTLSConfig(c.config)
if cfg.ServerName == "" {
colonPos := strings.LastIndex(addr, ":")
if colonPos == -1 {
colonPos = len(addr)
}
cfg.ServerName = addr[:colonPos]
}
conn := tls.Client(rawConn, &cfg)
if timeout == 0 {
err = conn.Handshake()
} else {
go func() {
errChannel <- conn.Handshake()
}()
err = <-errChannel
}
if err != nil {
rawConn.Close()
return nil, nil, err
conn := tls.Client(rawConn, cfg)
errChannel := make(chan error, 1)
go func() {
errChannel <- conn.Handshake()
}()
select {
case err := <-errChannel:
if err != nil {
return nil, nil, err
}
case <-ctx.Done():
return nil, nil, ctx.Err()
}
// TODO(zhaoq): Omit the auth info for client now. It is more for
// information than anything else.
@ -181,7 +167,6 @@ func (c *tlsCreds) ClientHandshake(addr string, rawConn net.Conn, timeout time.D
func (c *tlsCreds) ServerHandshake(rawConn net.Conn) (net.Conn, AuthInfo, error) {
conn := tls.Server(rawConn, c.config)
if err := conn.Handshake(); err != nil {
rawConn.Close()
return nil, nil, err
}
return conn, TLSInfo{conn.ConnectionState()}, nil
@ -189,7 +174,7 @@ func (c *tlsCreds) ServerHandshake(rawConn net.Conn) (net.Conn, AuthInfo, error)
// NewTLS uses c to construct a TransportCredentials based on TLS.
func NewTLS(c *tls.Config) TransportCredentials {
tc := &tlsCreds{c}
tc := &tlsCreds{cloneTLSConfig(c)}
tc.config.NextProtos = alpnProtoStr
return tc
}

View file

@ -0,0 +1,76 @@
// +build go1.7
/*
*
* Copyright 2016, Google Inc.
* All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are
* met:
*
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above
* copyright notice, this list of conditions and the following disclaimer
* in the documentation and/or other materials provided with the
* distribution.
* * Neither the name of Google Inc. nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
* "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
* LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
* A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
* OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
* SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
* LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
* DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
* THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
*/
package credentials
import (
"crypto/tls"
)
// cloneTLSConfig returns a shallow clone of the exported
// fields of cfg, ignoring the unexported sync.Once, which
// contains a mutex and must not be copied.
//
// If cfg is nil, a new zero tls.Config is returned.
//
// TODO replace this function with official clone function.
func cloneTLSConfig(cfg *tls.Config) *tls.Config {
if cfg == nil {
return &tls.Config{}
}
return &tls.Config{
Rand: cfg.Rand,
Time: cfg.Time,
Certificates: cfg.Certificates,
NameToCertificate: cfg.NameToCertificate,
GetCertificate: cfg.GetCertificate,
RootCAs: cfg.RootCAs,
NextProtos: cfg.NextProtos,
ServerName: cfg.ServerName,
ClientAuth: cfg.ClientAuth,
ClientCAs: cfg.ClientCAs,
InsecureSkipVerify: cfg.InsecureSkipVerify,
CipherSuites: cfg.CipherSuites,
PreferServerCipherSuites: cfg.PreferServerCipherSuites,
SessionTicketsDisabled: cfg.SessionTicketsDisabled,
SessionTicketKey: cfg.SessionTicketKey,
ClientSessionCache: cfg.ClientSessionCache,
MinVersion: cfg.MinVersion,
MaxVersion: cfg.MaxVersion,
CurvePreferences: cfg.CurvePreferences,
DynamicRecordSizingDisabled: cfg.DynamicRecordSizingDisabled,
Renegotiation: cfg.Renegotiation,
}
}

View file

@ -0,0 +1,74 @@
// +build !go1.7
/*
*
* Copyright 2016, Google Inc.
* All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are
* met:
*
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above
* copyright notice, this list of conditions and the following disclaimer
* in the documentation and/or other materials provided with the
* distribution.
* * Neither the name of Google Inc. nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
* "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
* LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
* A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
* OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
* SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
* LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
* DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
* THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
*/
package credentials
import (
"crypto/tls"
)
// cloneTLSConfig returns a shallow clone of the exported
// fields of cfg, ignoring the unexported sync.Once, which
// contains a mutex and must not be copied.
//
// If cfg is nil, a new zero tls.Config is returned.
//
// TODO replace this function with official clone function.
func cloneTLSConfig(cfg *tls.Config) *tls.Config {
if cfg == nil {
return &tls.Config{}
}
return &tls.Config{
Rand: cfg.Rand,
Time: cfg.Time,
Certificates: cfg.Certificates,
NameToCertificate: cfg.NameToCertificate,
GetCertificate: cfg.GetCertificate,
RootCAs: cfg.RootCAs,
NextProtos: cfg.NextProtos,
ServerName: cfg.ServerName,
ClientAuth: cfg.ClientAuth,
ClientCAs: cfg.ClientCAs,
InsecureSkipVerify: cfg.InsecureSkipVerify,
CipherSuites: cfg.CipherSuites,
PreferServerCipherSuites: cfg.PreferServerCipherSuites,
SessionTicketsDisabled: cfg.SessionTicketsDisabled,
SessionTicketKey: cfg.SessionTicketKey,
ClientSessionCache: cfg.ClientSessionCache,
MinVersion: cfg.MinVersion,
MaxVersion: cfg.MaxVersion,
CurvePreferences: cfg.CurvePreferences,
}
}

View file

@ -60,15 +60,21 @@ func encodeKeyValue(k, v string) (string, string) {
// DecodeKeyValue returns the original key and value corresponding to the
// encoded data in k, v.
// If k is a binary header and v contains comma, v is split on comma before decoded,
// and the decoded v will be joined with comma before returned.
func DecodeKeyValue(k, v string) (string, string, error) {
if !strings.HasSuffix(k, binHdrSuffix) {
return k, v, nil
}
val, err := base64.StdEncoding.DecodeString(v)
if err != nil {
return "", "", err
vvs := strings.Split(v, ",")
for i, vv := range vvs {
val, err := base64.StdEncoding.DecodeString(vv)
if err != nil {
return "", "", err
}
vvs[i] = string(val)
}
return k, string(val), nil
return k, strings.Join(vvs, ","), nil
}
// MD is a mapping from metadata keys to values. Users should use the following

View file

@ -227,7 +227,7 @@ type parser struct {
// No other error values or types must be returned, which also means
// that the underlying io.Reader must not return an incompatible
// error.
func (p *parser) recvMsg() (pf payloadFormat, msg []byte, err error) {
func (p *parser) recvMsg(maxMsgSize int) (pf payloadFormat, msg []byte, err error) {
if _, err := io.ReadFull(p.r, p.header[:]); err != nil {
return 0, nil, err
}
@ -238,6 +238,9 @@ func (p *parser) recvMsg() (pf payloadFormat, msg []byte, err error) {
if length == 0 {
return pf, nil, nil
}
if length > uint32(maxMsgSize) {
return 0, nil, Errorf(codes.Internal, "grpc: received message length %d exceeding the max size %d", length, maxMsgSize)
}
// TODO(bradfitz,zhaoq): garbage. reuse buffer after proto decoding instead
// of making it for each message:
msg = make([]byte, int(length))
@ -308,8 +311,8 @@ func checkRecvPayload(pf payloadFormat, recvCompress string, dc Decompressor) er
return nil
}
func recv(p *parser, c Codec, s *transport.Stream, dc Decompressor, m interface{}) error {
pf, d, err := p.recvMsg()
func recv(p *parser, c Codec, s *transport.Stream, dc Decompressor, m interface{}, maxMsgSize int) error {
pf, d, err := p.recvMsg(maxMsgSize)
if err != nil {
return err
}
@ -319,11 +322,16 @@ func recv(p *parser, c Codec, s *transport.Stream, dc Decompressor, m interface{
if pf == compressionMade {
d, err = dc.Do(bytes.NewReader(d))
if err != nil {
return transport.StreamErrorf(codes.Internal, "grpc: failed to decompress the received message %v", err)
return Errorf(codes.Internal, "grpc: failed to decompress the received message %v", err)
}
}
if len(d) > maxMsgSize {
// TODO: Revisit the error code. Currently keep it consistent with java
// implementation.
return Errorf(codes.Internal, "grpc: received a message of %d bytes exceeding %d limit", len(d), maxMsgSize)
}
if err := c.Unmarshal(d, m); err != nil {
return transport.StreamErrorf(codes.Internal, "grpc: failed to unmarshal the received message %v", err)
return Errorf(codes.Internal, "grpc: failed to unmarshal the received message %v", err)
}
return nil
}

View file

@ -89,9 +89,13 @@ type service struct {
type Server struct {
opts options
mu sync.Mutex // guards following
lis map[net.Listener]bool
conns map[io.Closer]bool
mu sync.Mutex // guards following
lis map[net.Listener]bool
conns map[io.Closer]bool
drain bool
// A CondVar to let GracefulStop() blocks until all the pending RPCs are finished
// and all the transport goes away.
cv *sync.Cond
m map[string]*service // service name -> service info
events trace.EventLog
}
@ -101,12 +105,15 @@ type options struct {
codec Codec
cp Compressor
dc Decompressor
maxMsgSize int
unaryInt UnaryServerInterceptor
streamInt StreamServerInterceptor
maxConcurrentStreams uint32
useHandlerImpl bool // use http.Handler-based server
}
var defaultMaxMsgSize = 1024 * 1024 * 4 // use 4MB as the default message size limit
// A ServerOption sets options.
type ServerOption func(*options)
@ -117,20 +124,28 @@ func CustomCodec(codec Codec) ServerOption {
}
}
// RPCCompressor returns a ServerOption that sets a compressor for outbound message.
// RPCCompressor returns a ServerOption that sets a compressor for outbound messages.
func RPCCompressor(cp Compressor) ServerOption {
return func(o *options) {
o.cp = cp
}
}
// RPCDecompressor returns a ServerOption that sets a decompressor for inbound message.
// RPCDecompressor returns a ServerOption that sets a decompressor for inbound messages.
func RPCDecompressor(dc Decompressor) ServerOption {
return func(o *options) {
o.dc = dc
}
}
// MaxMsgSize returns a ServerOption to set the max message size in bytes for inbound mesages.
// If this is not set, gRPC uses the default 4MB.
func MaxMsgSize(m int) ServerOption {
return func(o *options) {
o.maxMsgSize = m
}
}
// MaxConcurrentStreams returns a ServerOption that will apply a limit on the number
// of concurrent streams to each ServerTransport.
func MaxConcurrentStreams(n uint32) ServerOption {
@ -173,6 +188,7 @@ func StreamInterceptor(i StreamServerInterceptor) ServerOption {
// started to accept requests yet.
func NewServer(opt ...ServerOption) *Server {
var opts options
opts.maxMsgSize = defaultMaxMsgSize
for _, o := range opt {
o(&opts)
}
@ -186,6 +202,7 @@ func NewServer(opt ...ServerOption) *Server {
conns: make(map[io.Closer]bool),
m: make(map[string]*service),
}
s.cv = sync.NewCond(&s.mu)
if EnableTracing {
_, file, line, _ := runtime.Caller(1)
s.events = trace.NewEventLog("grpc.Server", fmt.Sprintf("%s:%d", file, line))
@ -264,8 +281,8 @@ type ServiceInfo struct {
// GetServiceInfo returns a map from service names to ServiceInfo.
// Service names include the package names, in the form of <package>.<service>.
func (s *Server) GetServiceInfo() map[string]*ServiceInfo {
ret := make(map[string]*ServiceInfo)
func (s *Server) GetServiceInfo() map[string]ServiceInfo {
ret := make(map[string]ServiceInfo)
for n, srv := range s.m {
methods := make([]MethodInfo, 0, len(srv.md)+len(srv.sd))
for m := range srv.md {
@ -283,7 +300,7 @@ func (s *Server) GetServiceInfo() map[string]*ServiceInfo {
})
}
ret[n] = &ServiceInfo{
ret[n] = ServiceInfo{
Methods: methods,
Metadata: srv.mdata,
}
@ -468,7 +485,7 @@ func (s *Server) traceInfo(st transport.ServerTransport, stream *transport.Strea
func (s *Server) addConn(c io.Closer) bool {
s.mu.Lock()
defer s.mu.Unlock()
if s.conns == nil {
if s.conns == nil || s.drain {
return false
}
s.conns[c] = true
@ -480,6 +497,7 @@ func (s *Server) removeConn(c io.Closer) {
defer s.mu.Unlock()
if s.conns != nil {
delete(s.conns, c)
s.cv.Signal()
}
}
@ -520,7 +538,7 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport.
}
p := &parser{r: stream}
for {
pf, req, err := p.recvMsg()
pf, req, err := p.recvMsg(s.opts.maxMsgSize)
if err == io.EOF {
// The entire stream is done (for unary RPC only).
return err
@ -530,6 +548,10 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport.
}
if err != nil {
switch err := err.(type) {
case *rpcError:
if err := t.WriteStatus(stream, err.code, err.desc); err != nil {
grpclog.Printf("grpc: Server.processUnaryRPC failed to write status %v", err)
}
case transport.ConnectionError:
// Nothing to do here.
case transport.StreamError:
@ -569,6 +591,12 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport.
return err
}
}
if len(req) > s.opts.maxMsgSize {
// TODO: Revisit the error code. Currently keep it consistent with
// java implementation.
statusCode = codes.Internal
statusDesc = fmt.Sprintf("grpc: server received a message of %d bytes exceeding %d limit", len(req), s.opts.maxMsgSize)
}
if err := s.opts.codec.Unmarshal(req, v); err != nil {
return err
}
@ -628,13 +656,14 @@ func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transp
stream.SetSendCompress(s.opts.cp.Type())
}
ss := &serverStream{
t: t,
s: stream,
p: &parser{r: stream},
codec: s.opts.codec,
cp: s.opts.cp,
dc: s.opts.dc,
trInfo: trInfo,
t: t,
s: stream,
p: &parser{r: stream},
codec: s.opts.codec,
cp: s.opts.cp,
dc: s.opts.dc,
maxMsgSize: s.opts.maxMsgSize,
trInfo: trInfo,
}
if ss.cp != nil {
ss.cbuf = new(bytes.Buffer)
@ -766,14 +795,16 @@ func (s *Server) Stop() {
s.mu.Lock()
listeners := s.lis
s.lis = nil
cs := s.conns
st := s.conns
s.conns = nil
// interrupt GracefulStop if Stop and GracefulStop are called concurrently.
s.cv.Signal()
s.mu.Unlock()
for lis := range listeners {
lis.Close()
}
for c := range cs {
for c := range st {
c.Close()
}
@ -785,6 +816,32 @@ func (s *Server) Stop() {
s.mu.Unlock()
}
// GracefulStop stops the gRPC server gracefully. It stops the server to accept new
// connections and RPCs and blocks until all the pending RPCs are finished.
func (s *Server) GracefulStop() {
s.mu.Lock()
defer s.mu.Unlock()
if s.drain == true || s.conns == nil {
return
}
s.drain = true
for lis := range s.lis {
lis.Close()
}
s.lis = nil
for c := range s.conns {
c.(transport.ServerTransport).Drain()
}
for len(s.conns) != 0 {
s.cv.Wait()
}
s.conns = nil
if s.events != nil {
s.events.Finish()
s.events = nil
}
}
func init() {
internal.TestingCloseConns = func(arg interface{}) {
arg.(*Server).testingCloseConns()

View file

@ -37,6 +37,7 @@ import (
"bytes"
"errors"
"io"
"math"
"sync"
"time"
@ -84,12 +85,9 @@ type ClientStream interface {
// Header returns the header metadata received from the server if there
// is any. It blocks if the metadata is not ready to read.
Header() (metadata.MD, error)
// Trailer returns the trailer metadata from the server. It must be called
// after stream.Recv() returns non-nil error (including io.EOF) for
// bi-directional streaming and server streaming or stream.CloseAndRecv()
// returns for client streaming in order to receive trailer metadata if
// present. Otherwise, it could returns an empty MD even though trailer
// is present.
// Trailer returns the trailer metadata from the server, if there is any.
// It must only be called after stream.CloseAndRecv has returned, or
// stream.Recv has returned a non-nil error (including io.EOF).
Trailer() metadata.MD
// CloseSend closes the send direction of the stream. It closes the stream
// when non-nil error is met.
@ -99,11 +97,10 @@ type ClientStream interface {
// NewClientStream creates a new Stream for the client side. This is called
// by generated code.
func NewClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, method string, opts ...CallOption) (ClientStream, error) {
func NewClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, method string, opts ...CallOption) (_ ClientStream, err error) {
var (
t transport.ClientTransport
s *transport.Stream
err error
put func()
)
c := defaultCallInfo
@ -120,27 +117,24 @@ func NewClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth
if cc.dopts.cp != nil {
callHdr.SendCompress = cc.dopts.cp.Type()
}
cs := &clientStream{
opts: opts,
c: c,
desc: desc,
codec: cc.dopts.codec,
cp: cc.dopts.cp,
dc: cc.dopts.dc,
tracing: EnableTracing,
}
if cc.dopts.cp != nil {
callHdr.SendCompress = cc.dopts.cp.Type()
cs.cbuf = new(bytes.Buffer)
}
if cs.tracing {
cs.trInfo.tr = trace.New("grpc.Sent."+methodFamily(method), method)
cs.trInfo.firstLine.client = true
var trInfo traceInfo
if EnableTracing {
trInfo.tr = trace.New("grpc.Sent."+methodFamily(method), method)
trInfo.firstLine.client = true
if deadline, ok := ctx.Deadline(); ok {
cs.trInfo.firstLine.deadline = deadline.Sub(time.Now())
trInfo.firstLine.deadline = deadline.Sub(time.Now())
}
cs.trInfo.tr.LazyLog(&cs.trInfo.firstLine, false)
ctx = trace.NewContext(ctx, cs.trInfo.tr)
trInfo.tr.LazyLog(&trInfo.firstLine, false)
ctx = trace.NewContext(ctx, trInfo.tr)
defer func() {
if err != nil {
// Need to call tr.finish() if error is returned.
// Because tr will not be returned to caller.
trInfo.tr.LazyPrintf("RPC: [%v]", err)
trInfo.tr.SetError()
trInfo.tr.Finish()
}
}()
}
gopts := BalancerGetOptions{
BlockingWait: !c.failFast,
@ -152,9 +146,9 @@ func NewClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth
if _, ok := err.(*rpcError); ok {
return nil, err
}
if err == errConnClosing {
if err == errConnClosing || err == errConnUnavailable {
if c.failFast {
return nil, Errorf(codes.Unavailable, "%v", errConnClosing)
return nil, Errorf(codes.Unavailable, "%v", err)
}
continue
}
@ -168,9 +162,8 @@ func NewClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth
put()
put = nil
}
if _, ok := err.(transport.ConnectionError); ok {
if _, ok := err.(transport.ConnectionError); ok || err == transport.ErrStreamDrain {
if c.failFast {
cs.finish(err)
return nil, toRPCErr(err)
}
continue
@ -179,16 +172,43 @@ func NewClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth
}
break
}
cs.put = put
cs.t = t
cs.s = s
cs.p = &parser{r: s}
// Listen on ctx.Done() to detect cancellation when there is no pending
// I/O operations on this stream.
cs := &clientStream{
opts: opts,
c: c,
desc: desc,
codec: cc.dopts.codec,
cp: cc.dopts.cp,
dc: cc.dopts.dc,
put: put,
t: t,
s: s,
p: &parser{r: s},
tracing: EnableTracing,
trInfo: trInfo,
}
if cc.dopts.cp != nil {
cs.cbuf = new(bytes.Buffer)
}
// Listen on ctx.Done() to detect cancellation and s.Done() to detect normal termination
// when there is no pending I/O operations on this stream.
go func() {
select {
case <-t.Error():
// Incur transport error, simply exit.
case <-s.Done():
// TODO: The trace of the RPC is terminated here when there is no pending
// I/O, which is probably not the optimal solution.
if s.StatusCode() == codes.OK {
cs.finish(nil)
} else {
cs.finish(Errorf(s.StatusCode(), "%s", s.StatusDesc()))
}
cs.closeTransportStream(nil)
case <-s.GoAway():
cs.finish(errConnDrain)
cs.closeTransportStream(errConnDrain)
case <-s.Context().Done():
err := s.Context().Err()
cs.finish(err)
@ -251,7 +271,17 @@ func (cs *clientStream) SendMsg(m interface{}) (err error) {
if err != nil {
cs.finish(err)
}
if err == nil || err == io.EOF {
if err == nil {
return
}
if err == io.EOF {
// Specialize the process for server streaming. SendMesg is only called
// once when creating the stream object. io.EOF needs to be skipped when
// the rpc is early finished (before the stream object is created.).
// TODO: It is probably better to move this into the generated code.
if !cs.desc.ClientStreams && cs.desc.ServerStreams {
err = nil
}
return
}
if _, ok := err.(transport.ConnectionError); !ok {
@ -272,7 +302,7 @@ func (cs *clientStream) SendMsg(m interface{}) (err error) {
}
func (cs *clientStream) RecvMsg(m interface{}) (err error) {
err = recv(cs.p, cs.codec, cs.s, cs.dc, m)
err = recv(cs.p, cs.codec, cs.s, cs.dc, m, math.MaxInt32)
defer func() {
// err != nil indicates the termination of the stream.
if err != nil {
@ -291,7 +321,7 @@ func (cs *clientStream) RecvMsg(m interface{}) (err error) {
return
}
// Special handling for client streaming rpc.
err = recv(cs.p, cs.codec, cs.s, cs.dc, m)
err = recv(cs.p, cs.codec, cs.s, cs.dc, m, math.MaxInt32)
cs.closeTransportStream(err)
if err == nil {
return toRPCErr(errors.New("grpc: client streaming protocol violation: get <nil>, want <EOF>"))
@ -326,7 +356,7 @@ func (cs *clientStream) CloseSend() (err error) {
}
}()
if err == nil || err == io.EOF {
return
return nil
}
if _, ok := err.(transport.ConnectionError); !ok {
cs.closeTransportStream(err)
@ -392,6 +422,7 @@ type serverStream struct {
cp Compressor
dc Decompressor
cbuf *bytes.Buffer
maxMsgSize int
statusCode codes.Code
statusDesc string
trInfo *traceInfo
@ -458,5 +489,5 @@ func (ss *serverStream) RecvMsg(m interface{}) (err error) {
ss.mu.Unlock()
}
}()
return recv(ss.p, ss.codec, ss.s, ss.dc, m)
return recv(ss.p, ss.codec, ss.s, ss.dc, m, ss.maxMsgSize)
}

View file

@ -72,6 +72,11 @@ type resetStream struct {
func (*resetStream) item() {}
type goAway struct {
}
func (*goAway) item() {}
type flushIO struct {
}

46
vendor/google.golang.org/grpc/transport/go16.go generated vendored Normal file
View file

@ -0,0 +1,46 @@
// +build go1.6,!go1.7
/*
* Copyright 2016, Google Inc.
* All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are
* met:
*
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above
* copyright notice, this list of conditions and the following disclaimer
* in the documentation and/or other materials provided with the
* distribution.
* * Neither the name of Google Inc. nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
* "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
* LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
* A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
* OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
* SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
* LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
* DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
* THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
*/
package transport
import (
"net"
"golang.org/x/net/context"
)
// dialContext connects to the address on the named network.
func dialContext(ctx context.Context, network, address string) (net.Conn, error) {
return (&net.Dialer{Cancel: ctx.Done()}).Dial(network, address)
}

46
vendor/google.golang.org/grpc/transport/go17.go generated vendored Normal file
View file

@ -0,0 +1,46 @@
// +build go1.7
/*
* Copyright 2016, Google Inc.
* All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are
* met:
*
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above
* copyright notice, this list of conditions and the following disclaimer
* in the documentation and/or other materials provided with the
* distribution.
* * Neither the name of Google Inc. nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
* "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
* LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
* A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
* OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
* SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
* LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
* DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
* THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
*/
package transport
import (
"net"
"golang.org/x/net/context"
)
// dialContext connects to the address on the named network.
func dialContext(ctx context.Context, network, address string) (net.Conn, error) {
return (&net.Dialer{}).DialContext(ctx, network, address)
}

View file

@ -83,7 +83,7 @@ func NewServerHandlerTransport(w http.ResponseWriter, r *http.Request) (ServerTr
}
if v := r.Header.Get("grpc-timeout"); v != "" {
to, err := timeoutDecode(v)
to, err := decodeTimeout(v)
if err != nil {
return nil, StreamErrorf(codes.Internal, "malformed time-out: %v", err)
}
@ -194,7 +194,7 @@ func (ht *serverHandlerTransport) WriteStatus(s *Stream, statusCode codes.Code,
h := ht.rw.Header()
h.Set("Grpc-Status", fmt.Sprintf("%d", statusCode))
if statusDesc != "" {
h.Set("Grpc-Message", statusDesc)
h.Set("Grpc-Message", encodeGrpcMessage(statusDesc))
}
if md := s.Trailer(); len(md) > 0 {
for k, vv := range md {
@ -370,6 +370,10 @@ func (ht *serverHandlerTransport) runStream() {
}
}
func (ht *serverHandlerTransport) Drain() {
panic("Drain() is not implemented")
}
// mapRecvMsgError returns the non-nil err into the appropriate
// error value as expected by callers of *grpc.parser.recvMsg.
// In particular, in can only be:

View file

@ -35,6 +35,7 @@ package transport
import (
"bytes"
"fmt"
"io"
"math"
"net"
@ -71,6 +72,9 @@ type http2Client struct {
shutdownChan chan struct{}
// errorChan is closed to notify the I/O error to the caller.
errorChan chan struct{}
// goAway is closed to notify the upper layer (i.e., addrConn.transportMonitor)
// that the server sent GoAway on this transport.
goAway chan struct{}
framer *framer
hBuf *bytes.Buffer // the buffer for HPACK encoding
@ -97,41 +101,44 @@ type http2Client struct {
maxStreams int
// the per-stream outbound flow control window size set by the peer.
streamSendQuota uint32
// goAwayID records the Last-Stream-ID in the GoAway frame from the server.
goAwayID uint32
// prevGoAway ID records the Last-Stream-ID in the previous GOAway frame.
prevGoAwayID uint32
}
func dial(fn func(context.Context, string) (net.Conn, error), ctx context.Context, addr string) (net.Conn, error) {
if fn != nil {
return fn(ctx, addr)
}
return dialContext(ctx, "tcp", addr)
}
// newHTTP2Client constructs a connected ClientTransport to addr based on HTTP2
// and starts to receive messages on it. Non-nil error returns if construction
// fails.
func newHTTP2Client(addr string, opts *ConnectOptions) (_ ClientTransport, err error) {
if opts.Dialer == nil {
// Set the default Dialer.
opts.Dialer = func(addr string, timeout time.Duration) (net.Conn, error) {
return net.DialTimeout("tcp", addr, timeout)
}
}
func newHTTP2Client(ctx context.Context, addr string, opts ConnectOptions) (_ ClientTransport, err error) {
scheme := "http"
startT := time.Now()
timeout := opts.Timeout
conn, connErr := opts.Dialer(addr, timeout)
conn, connErr := dial(opts.Dialer, ctx, addr)
if connErr != nil {
return nil, ConnectionErrorf("transport: %v", connErr)
return nil, ConnectionErrorf(true, connErr, "transport: %v", connErr)
}
var authInfo credentials.AuthInfo
if opts.TransportCredentials != nil {
scheme = "https"
if timeout > 0 {
timeout -= time.Since(startT)
}
conn, authInfo, connErr = opts.TransportCredentials.ClientHandshake(addr, conn, timeout)
}
if connErr != nil {
return nil, ConnectionErrorf("transport: %v", connErr)
}
defer func() {
// Any further errors will close the underlying connection
defer func(conn net.Conn) {
if err != nil {
conn.Close()
}
}()
}(conn)
var authInfo credentials.AuthInfo
if creds := opts.TransportCredentials; creds != nil {
scheme = "https"
conn, authInfo, connErr = creds.ClientHandshake(ctx, addr, conn)
}
if connErr != nil {
// Credentials handshake error is not a temporary error (unless the error
// was the connection closing).
return nil, ConnectionErrorf(connErr == io.EOF, connErr, "transport: %v", connErr)
}
ua := primaryUA
if opts.UserAgent != "" {
ua = opts.UserAgent + " " + ua
@ -147,6 +154,7 @@ func newHTTP2Client(addr string, opts *ConnectOptions) (_ ClientTransport, err e
writableChan: make(chan int, 1),
shutdownChan: make(chan struct{}),
errorChan: make(chan struct{}),
goAway: make(chan struct{}),
framer: newFramer(conn),
hBuf: &buf,
hEnc: hpack.NewEncoder(&buf),
@ -168,11 +176,11 @@ func newHTTP2Client(addr string, opts *ConnectOptions) (_ ClientTransport, err e
n, err := t.conn.Write(clientPreface)
if err != nil {
t.Close()
return nil, ConnectionErrorf("transport: %v", err)
return nil, ConnectionErrorf(true, err, "transport: %v", err)
}
if n != len(clientPreface) {
t.Close()
return nil, ConnectionErrorf("transport: preface mismatch, wrote %d bytes; want %d", n, len(clientPreface))
return nil, ConnectionErrorf(true, err, "transport: preface mismatch, wrote %d bytes; want %d", n, len(clientPreface))
}
if initialWindowSize != defaultWindowSize {
err = t.framer.writeSettings(true, http2.Setting{
@ -184,13 +192,13 @@ func newHTTP2Client(addr string, opts *ConnectOptions) (_ ClientTransport, err e
}
if err != nil {
t.Close()
return nil, ConnectionErrorf("transport: %v", err)
return nil, ConnectionErrorf(true, err, "transport: %v", err)
}
// Adjust the connection flow control window if needed.
if delta := uint32(initialConnWindowSize - defaultWindowSize); delta > 0 {
if err := t.framer.writeWindowUpdate(true, 0, delta); err != nil {
t.Close()
return nil, ConnectionErrorf("transport: %v", err)
return nil, ConnectionErrorf(true, err, "transport: %v", err)
}
}
go t.controller()
@ -202,6 +210,8 @@ func (t *http2Client) newStream(ctx context.Context, callHdr *CallHdr) *Stream {
// TODO(zhaoq): Handle uint32 overflow of Stream.id.
s := &Stream{
id: t.nextID,
done: make(chan struct{}),
goAway: make(chan struct{}),
method: callHdr.Method,
sendCompress: callHdr.SendCompress,
buf: newRecvBuffer(),
@ -216,8 +226,9 @@ func (t *http2Client) newStream(ctx context.Context, callHdr *CallHdr) *Stream {
// Make a stream be able to cancel the pending operations by itself.
s.ctx, s.cancel = context.WithCancel(ctx)
s.dec = &recvBufferReader{
ctx: s.ctx,
recv: s.buf,
ctx: s.ctx,
goAway: s.goAway,
recv: s.buf,
}
return s
}
@ -271,6 +282,10 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea
t.mu.Unlock()
return nil, ErrConnClosing
}
if t.state == draining {
t.mu.Unlock()
return nil, ErrStreamDrain
}
if t.state != reachable {
t.mu.Unlock()
return nil, ErrConnClosing
@ -278,7 +293,7 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea
checkStreamsQuota := t.streamsQuota != nil
t.mu.Unlock()
if checkStreamsQuota {
sq, err := wait(ctx, t.shutdownChan, t.streamsQuota.acquire())
sq, err := wait(ctx, nil, nil, t.shutdownChan, t.streamsQuota.acquire())
if err != nil {
return nil, err
}
@ -287,7 +302,7 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea
t.streamsQuota.add(sq - 1)
}
}
if _, err := wait(ctx, t.shutdownChan, t.writableChan); err != nil {
if _, err := wait(ctx, nil, nil, t.shutdownChan, t.writableChan); err != nil {
// Return the quota back now because there is no stream returned to the caller.
if _, ok := err.(StreamError); ok && checkStreamsQuota {
t.streamsQuota.add(1)
@ -295,6 +310,15 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea
return nil, err
}
t.mu.Lock()
if t.state == draining {
t.mu.Unlock()
if checkStreamsQuota {
t.streamsQuota.add(1)
}
// Need to make t writable again so that the rpc in flight can still proceed.
t.writableChan <- 0
return nil, ErrStreamDrain
}
if t.state != reachable {
t.mu.Unlock()
return nil, ErrConnClosing
@ -329,7 +353,7 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea
t.hEnc.WriteField(hpack.HeaderField{Name: "grpc-encoding", Value: callHdr.SendCompress})
}
if timeout > 0 {
t.hEnc.WriteField(hpack.HeaderField{Name: "grpc-timeout", Value: timeoutEncode(timeout)})
t.hEnc.WriteField(hpack.HeaderField{Name: "grpc-timeout", Value: encodeTimeout(timeout)})
}
for k, v := range authData {
// Capital header names are illegal in HTTP/2.
@ -384,7 +408,7 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea
}
if err != nil {
t.notifyError(err)
return nil, ConnectionErrorf("transport: %v", err)
return nil, ConnectionErrorf(true, err, "transport: %v", err)
}
}
t.writableChan <- 0
@ -403,22 +427,17 @@ func (t *http2Client) CloseStream(s *Stream, err error) {
if t.streamsQuota != nil {
updateStreams = true
}
if t.state == draining && len(t.activeStreams) == 1 {
delete(t.activeStreams, s.id)
if t.state == draining && len(t.activeStreams) == 0 {
// The transport is draining and s is the last live stream on t.
t.mu.Unlock()
t.Close()
return
}
delete(t.activeStreams, s.id)
t.mu.Unlock()
if updateStreams {
t.streamsQuota.add(1)
}
// In case stream sending and receiving are invoked in separate
// goroutines (e.g., bi-directional streaming), the caller needs
// to call cancel on the stream to interrupt the blocking on
// other goroutines.
s.cancel()
s.mu.Lock()
if q := s.fc.resetPendingData(); q > 0 {
if n := t.fc.onRead(q); n > 0 {
@ -445,13 +464,13 @@ func (t *http2Client) CloseStream(s *Stream, err error) {
// accessed any more.
func (t *http2Client) Close() (err error) {
t.mu.Lock()
if t.state == reachable {
close(t.errorChan)
}
if t.state == closing {
t.mu.Unlock()
return
}
if t.state == reachable || t.state == draining {
close(t.errorChan)
}
t.state = closing
t.mu.Unlock()
close(t.shutdownChan)
@ -475,10 +494,35 @@ func (t *http2Client) Close() (err error) {
func (t *http2Client) GracefulClose() error {
t.mu.Lock()
if t.state == closing {
switch t.state {
case unreachable:
// The server may close the connection concurrently. t is not available for
// any streams. Close it now.
t.mu.Unlock()
t.Close()
return nil
case closing:
t.mu.Unlock()
return nil
}
// Notify the streams which were initiated after the server sent GOAWAY.
select {
case <-t.goAway:
n := t.prevGoAwayID
if n == 0 && t.nextID > 1 {
n = t.nextID - 2
}
m := t.goAwayID + 2
if m == 2 {
m = 1
}
for i := m; i <= n; i += 2 {
if s, ok := t.activeStreams[i]; ok {
close(s.goAway)
}
}
default:
}
if t.state == draining {
t.mu.Unlock()
return nil
@ -504,15 +548,15 @@ func (t *http2Client) Write(s *Stream, data []byte, opts *Options) error {
size := http2MaxFrameLen
s.sendQuotaPool.add(0)
// Wait until the stream has some quota to send the data.
sq, err := wait(s.ctx, t.shutdownChan, s.sendQuotaPool.acquire())
sq, err := wait(s.ctx, s.done, s.goAway, t.shutdownChan, s.sendQuotaPool.acquire())
if err != nil {
return err
}
t.sendQuotaPool.add(0)
// Wait until the transport has some quota to send the data.
tq, err := wait(s.ctx, t.shutdownChan, t.sendQuotaPool.acquire())
tq, err := wait(s.ctx, s.done, s.goAway, t.shutdownChan, t.sendQuotaPool.acquire())
if err != nil {
if _, ok := err.(StreamError); ok {
if _, ok := err.(StreamError); ok || err == io.EOF {
t.sendQuotaPool.cancel()
}
return err
@ -544,8 +588,8 @@ func (t *http2Client) Write(s *Stream, data []byte, opts *Options) error {
// Indicate there is a writer who is about to write a data frame.
t.framer.adjustNumWriters(1)
// Got some quota. Try to acquire writing privilege on the transport.
if _, err := wait(s.ctx, t.shutdownChan, t.writableChan); err != nil {
if _, ok := err.(StreamError); ok {
if _, err := wait(s.ctx, s.done, s.goAway, t.shutdownChan, t.writableChan); err != nil {
if _, ok := err.(StreamError); ok || err == io.EOF {
// Return the connection quota back.
t.sendQuotaPool.add(len(p))
}
@ -578,7 +622,7 @@ func (t *http2Client) Write(s *Stream, data []byte, opts *Options) error {
// invoked.
if err := t.framer.writeData(forceFlush, s.id, endStream, p); err != nil {
t.notifyError(err)
return ConnectionErrorf("transport: %v", err)
return ConnectionErrorf(true, err, "transport: %v", err)
}
if t.framer.adjustNumWriters(-1) == 0 {
t.framer.flushWrite()
@ -593,11 +637,7 @@ func (t *http2Client) Write(s *Stream, data []byte, opts *Options) error {
}
s.mu.Lock()
if s.state != streamDone {
if s.state == streamReadDone {
s.state = streamDone
} else {
s.state = streamWriteDone
}
s.state = streamWriteDone
}
s.mu.Unlock()
return nil
@ -630,7 +670,7 @@ func (t *http2Client) updateWindow(s *Stream, n uint32) {
func (t *http2Client) handleData(f *http2.DataFrame) {
size := len(f.Data())
if err := t.fc.onData(uint32(size)); err != nil {
t.notifyError(ConnectionErrorf("%v", err))
t.notifyError(ConnectionErrorf(true, err, "%v", err))
return
}
// Select the right stream to dispatch.
@ -655,6 +695,7 @@ func (t *http2Client) handleData(f *http2.DataFrame) {
s.state = streamDone
s.statusCode = codes.Internal
s.statusDesc = err.Error()
close(s.done)
s.mu.Unlock()
s.write(recvMsg{err: io.EOF})
t.controlBuf.put(&resetStream{s.id, http2.ErrCodeFlowControl})
@ -672,13 +713,14 @@ func (t *http2Client) handleData(f *http2.DataFrame) {
// the read direction is closed, and set the status appropriately.
if f.FrameHeader.Flags.Has(http2.FlagDataEndStream) {
s.mu.Lock()
if s.state == streamWriteDone {
s.state = streamDone
} else {
s.state = streamReadDone
if s.state == streamDone {
s.mu.Unlock()
return
}
s.state = streamDone
s.statusCode = codes.Internal
s.statusDesc = "server closed the stream without sending trailers"
close(s.done)
s.mu.Unlock()
s.write(recvMsg{err: io.EOF})
}
@ -704,6 +746,8 @@ func (t *http2Client) handleRSTStream(f *http2.RSTStreamFrame) {
grpclog.Println("transport: http2Client.handleRSTStream found no mapped gRPC status for the received http2 error ", f.ErrCode)
s.statusCode = codes.Unknown
}
s.statusDesc = fmt.Sprintf("stream terminated by RST_STREAM with error code: %d", f.ErrCode)
close(s.done)
s.mu.Unlock()
s.write(recvMsg{err: io.EOF})
}
@ -728,7 +772,32 @@ func (t *http2Client) handlePing(f *http2.PingFrame) {
}
func (t *http2Client) handleGoAway(f *http2.GoAwayFrame) {
// TODO(zhaoq): GoAwayFrame handler to be implemented
t.mu.Lock()
if t.state == reachable || t.state == draining {
if f.LastStreamID > 0 && f.LastStreamID%2 != 1 {
t.mu.Unlock()
t.notifyError(ConnectionErrorf(true, nil, "received illegal http2 GOAWAY frame: stream ID %d is even", f.LastStreamID))
return
}
select {
case <-t.goAway:
id := t.goAwayID
// t.goAway has been closed (i.e.,multiple GoAways).
if id < f.LastStreamID {
t.mu.Unlock()
t.notifyError(ConnectionErrorf(true, nil, "received illegal http2 GOAWAY frame: previously recv GOAWAY frame with LastStramID %d, currently recv %d", id, f.LastStreamID))
return
}
t.prevGoAwayID = id
t.goAwayID = f.LastStreamID
t.mu.Unlock()
return
default:
}
t.goAwayID = f.LastStreamID
close(t.goAway)
}
t.mu.Unlock()
}
func (t *http2Client) handleWindowUpdate(f *http2.WindowUpdateFrame) {
@ -780,11 +849,11 @@ func (t *http2Client) operateHeaders(frame *http2.MetaHeadersFrame) {
if len(state.mdata) > 0 {
s.trailer = state.mdata
}
s.state = streamDone
s.statusCode = state.statusCode
s.statusDesc = state.statusDesc
close(s.done)
s.state = streamDone
s.mu.Unlock()
s.write(recvMsg{err: io.EOF})
}
@ -937,13 +1006,22 @@ func (t *http2Client) Error() <-chan struct{} {
return t.errorChan
}
func (t *http2Client) GoAway() <-chan struct{} {
return t.goAway
}
func (t *http2Client) notifyError(err error) {
t.mu.Lock()
defer t.mu.Unlock()
// make sure t.errorChan is closed only once.
if t.state == draining {
t.mu.Unlock()
t.Close()
return
}
if t.state == reachable {
t.state = unreachable
close(t.errorChan)
grpclog.Printf("transport: http2Client.notifyError got notified that the client transport was broken %v.", err)
}
t.mu.Unlock()
}

View file

@ -111,12 +111,12 @@ func newHTTP2Server(conn net.Conn, maxStreams uint32, authInfo credentials.AuthI
Val: uint32(initialWindowSize)})
}
if err := framer.writeSettings(true, settings...); err != nil {
return nil, ConnectionErrorf("transport: %v", err)
return nil, ConnectionErrorf(true, err, "transport: %v", err)
}
// Adjust the connection flow control window if needed.
if delta := uint32(initialConnWindowSize - defaultWindowSize); delta > 0 {
if err := framer.writeWindowUpdate(true, 0, delta); err != nil {
return nil, ConnectionErrorf("transport: %v", err)
return nil, ConnectionErrorf(true, err, "transport: %v", err)
}
}
var buf bytes.Buffer
@ -142,7 +142,7 @@ func newHTTP2Server(conn net.Conn, maxStreams uint32, authInfo credentials.AuthI
}
// operateHeader takes action on the decoded headers.
func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func(*Stream)) {
func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func(*Stream)) (close bool) {
buf := newRecvBuffer()
s := &Stream{
id: frame.Header().StreamID,
@ -205,6 +205,13 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func(
t.controlBuf.put(&resetStream{s.id, http2.ErrCodeRefusedStream})
return
}
if s.id%2 != 1 || s.id <= t.maxStreamID {
t.mu.Unlock()
// illegal gRPC stream id.
grpclog.Println("transport: http2Server.HandleStreams received an illegal stream id: ", s.id)
return true
}
t.maxStreamID = s.id
s.sendQuotaPool = newQuotaPool(int(t.streamSendQuota))
t.activeStreams[s.id] = s
t.mu.Unlock()
@ -212,6 +219,7 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func(
t.updateWindow(s, uint32(n))
}
handle(s)
return
}
// HandleStreams receives incoming streams using the given handler. This is
@ -231,6 +239,10 @@ func (t *http2Server) HandleStreams(handle func(*Stream)) {
}
frame, err := t.framer.readFrame()
if err == io.EOF || err == io.ErrUnexpectedEOF {
t.Close()
return
}
if err != nil {
grpclog.Printf("transport: http2Server.HandleStreams failed to read frame: %v", err)
t.Close()
@ -257,20 +269,20 @@ func (t *http2Server) HandleStreams(handle func(*Stream)) {
t.controlBuf.put(&resetStream{se.StreamID, se.Code})
continue
}
if err == io.EOF || err == io.ErrUnexpectedEOF {
t.Close()
return
}
grpclog.Printf("transport: http2Server.HandleStreams failed to read frame: %v", err)
t.Close()
return
}
switch frame := frame.(type) {
case *http2.MetaHeadersFrame:
id := frame.Header().StreamID
if id%2 != 1 || id <= t.maxStreamID {
// illegal gRPC stream id.
grpclog.Println("transport: http2Server.HandleStreams received an illegal stream id: ", id)
if t.operateHeaders(frame, handle) {
t.Close()
break
}
t.maxStreamID = id
t.operateHeaders(frame, handle)
case *http2.DataFrame:
t.handleData(frame)
case *http2.RSTStreamFrame:
@ -282,7 +294,7 @@ func (t *http2Server) HandleStreams(handle func(*Stream)) {
case *http2.WindowUpdateFrame:
t.handleWindowUpdate(frame)
case *http2.GoAwayFrame:
break
// TODO: Handle GoAway from the client appropriately.
default:
grpclog.Printf("transport: http2Server.HandleStreams found unhandled frame type %v.", frame)
}
@ -364,11 +376,7 @@ func (t *http2Server) handleData(f *http2.DataFrame) {
// Received the end of stream from the client.
s.mu.Lock()
if s.state != streamDone {
if s.state == streamWriteDone {
s.state = streamDone
} else {
s.state = streamReadDone
}
s.state = streamReadDone
}
s.mu.Unlock()
s.write(recvMsg{err: io.EOF})
@ -440,7 +448,7 @@ func (t *http2Server) writeHeaders(s *Stream, b *bytes.Buffer, endStream bool) e
}
if err != nil {
t.Close()
return ConnectionErrorf("transport: %v", err)
return ConnectionErrorf(true, err, "transport: %v", err)
}
}
return nil
@ -455,7 +463,7 @@ func (t *http2Server) WriteHeader(s *Stream, md metadata.MD) error {
}
s.headerOk = true
s.mu.Unlock()
if _, err := wait(s.ctx, t.shutdownChan, t.writableChan); err != nil {
if _, err := wait(s.ctx, nil, nil, t.shutdownChan, t.writableChan); err != nil {
return err
}
t.hBuf.Reset()
@ -495,7 +503,7 @@ func (t *http2Server) WriteStatus(s *Stream, statusCode codes.Code, statusDesc s
headersSent = true
}
s.mu.Unlock()
if _, err := wait(s.ctx, t.shutdownChan, t.writableChan); err != nil {
if _, err := wait(s.ctx, nil, nil, t.shutdownChan, t.writableChan); err != nil {
return err
}
t.hBuf.Reset()
@ -508,7 +516,7 @@ func (t *http2Server) WriteStatus(s *Stream, statusCode codes.Code, statusDesc s
Name: "grpc-status",
Value: strconv.Itoa(int(statusCode)),
})
t.hEnc.WriteField(hpack.HeaderField{Name: "grpc-message", Value: statusDesc})
t.hEnc.WriteField(hpack.HeaderField{Name: "grpc-message", Value: encodeGrpcMessage(statusDesc)})
// Attach the trailer metadata.
for k, v := range s.trailer {
// Clients don't tolerate reading restricted headers after some non restricted ones were sent.
@ -544,7 +552,7 @@ func (t *http2Server) Write(s *Stream, data []byte, opts *Options) error {
}
s.mu.Unlock()
if writeHeaderFrame {
if _, err := wait(s.ctx, t.shutdownChan, t.writableChan); err != nil {
if _, err := wait(s.ctx, nil, nil, t.shutdownChan, t.writableChan); err != nil {
return err
}
t.hBuf.Reset()
@ -560,7 +568,7 @@ func (t *http2Server) Write(s *Stream, data []byte, opts *Options) error {
}
if err := t.framer.writeHeaders(false, p); err != nil {
t.Close()
return ConnectionErrorf("transport: %v", err)
return ConnectionErrorf(true, err, "transport: %v", err)
}
t.writableChan <- 0
}
@ -572,13 +580,13 @@ func (t *http2Server) Write(s *Stream, data []byte, opts *Options) error {
size := http2MaxFrameLen
s.sendQuotaPool.add(0)
// Wait until the stream has some quota to send the data.
sq, err := wait(s.ctx, t.shutdownChan, s.sendQuotaPool.acquire())
sq, err := wait(s.ctx, nil, nil, t.shutdownChan, s.sendQuotaPool.acquire())
if err != nil {
return err
}
t.sendQuotaPool.add(0)
// Wait until the transport has some quota to send the data.
tq, err := wait(s.ctx, t.shutdownChan, t.sendQuotaPool.acquire())
tq, err := wait(s.ctx, nil, nil, t.shutdownChan, t.sendQuotaPool.acquire())
if err != nil {
if _, ok := err.(StreamError); ok {
t.sendQuotaPool.cancel()
@ -604,7 +612,7 @@ func (t *http2Server) Write(s *Stream, data []byte, opts *Options) error {
t.framer.adjustNumWriters(1)
// Got some quota. Try to acquire writing privilege on the
// transport.
if _, err := wait(s.ctx, t.shutdownChan, t.writableChan); err != nil {
if _, err := wait(s.ctx, nil, nil, t.shutdownChan, t.writableChan); err != nil {
if _, ok := err.(StreamError); ok {
// Return the connection quota back.
t.sendQuotaPool.add(ps)
@ -634,7 +642,7 @@ func (t *http2Server) Write(s *Stream, data []byte, opts *Options) error {
}
if err := t.framer.writeData(forceFlush, s.id, false, p); err != nil {
t.Close()
return ConnectionErrorf("transport: %v", err)
return ConnectionErrorf(true, err, "transport: %v", err)
}
if t.framer.adjustNumWriters(-1) == 0 {
t.framer.flushWrite()
@ -679,6 +687,17 @@ func (t *http2Server) controller() {
}
case *resetStream:
t.framer.writeRSTStream(true, i.streamID, i.code)
case *goAway:
t.mu.Lock()
if t.state == closing {
t.mu.Unlock()
// The transport is closing.
return
}
sid := t.maxStreamID
t.state = draining
t.mu.Unlock()
t.framer.writeGoAway(true, sid, http2.ErrCodeNo, nil)
case *flushIO:
t.framer.flushWrite()
case *ping:
@ -724,6 +743,9 @@ func (t *http2Server) Close() (err error) {
func (t *http2Server) closeStream(s *Stream) {
t.mu.Lock()
delete(t.activeStreams, s.id)
if t.state == draining && len(t.activeStreams) == 0 {
defer t.Close()
}
t.mu.Unlock()
// In case stream sending and receiving are invoked in separate
// goroutines (e.g., bi-directional streaming), cancel needs to be
@ -746,3 +768,7 @@ func (t *http2Server) closeStream(s *Stream) {
func (t *http2Server) RemoteAddr() net.Addr {
return t.conn.RemoteAddr()
}
func (t *http2Server) Drain() {
t.controlBuf.put(&goAway{})
}

View file

@ -35,6 +35,7 @@ package transport
import (
"bufio"
"bytes"
"fmt"
"io"
"net"
@ -174,11 +175,11 @@ func (d *decodeState) processHeaderField(f hpack.HeaderField) {
}
d.statusCode = codes.Code(code)
case "grpc-message":
d.statusDesc = f.Value
d.statusDesc = decodeGrpcMessage(f.Value)
case "grpc-timeout":
d.timeoutSet = true
var err error
d.timeout, err = timeoutDecode(f.Value)
d.timeout, err = decodeTimeout(f.Value)
if err != nil {
d.setErr(StreamErrorf(codes.Internal, "transport: malformed time-out: %v", err))
return
@ -251,7 +252,7 @@ func div(d, r time.Duration) int64 {
}
// TODO(zhaoq): It is the simplistic and not bandwidth efficient. Improve it.
func timeoutEncode(t time.Duration) string {
func encodeTimeout(t time.Duration) string {
if d := div(t, time.Nanosecond); d <= maxTimeoutValue {
return strconv.FormatInt(d, 10) + "n"
}
@ -271,7 +272,7 @@ func timeoutEncode(t time.Duration) string {
return strconv.FormatInt(div(t, time.Hour), 10) + "H"
}
func timeoutDecode(s string) (time.Duration, error) {
func decodeTimeout(s string) (time.Duration, error) {
size := len(s)
if size < 2 {
return 0, fmt.Errorf("transport: timeout string is too short: %q", s)
@ -288,6 +289,80 @@ func timeoutDecode(s string) (time.Duration, error) {
return d * time.Duration(t), nil
}
const (
spaceByte = ' '
tildaByte = '~'
percentByte = '%'
)
// encodeGrpcMessage is used to encode status code in header field
// "grpc-message".
// It checks to see if each individual byte in msg is an
// allowable byte, and then either percent encoding or passing it through.
// When percent encoding, the byte is converted into hexadecimal notation
// with a '%' prepended.
func encodeGrpcMessage(msg string) string {
if msg == "" {
return ""
}
lenMsg := len(msg)
for i := 0; i < lenMsg; i++ {
c := msg[i]
if !(c >= spaceByte && c < tildaByte && c != percentByte) {
return encodeGrpcMessageUnchecked(msg)
}
}
return msg
}
func encodeGrpcMessageUnchecked(msg string) string {
var buf bytes.Buffer
lenMsg := len(msg)
for i := 0; i < lenMsg; i++ {
c := msg[i]
if c >= spaceByte && c < tildaByte && c != percentByte {
buf.WriteByte(c)
} else {
buf.WriteString(fmt.Sprintf("%%%02X", c))
}
}
return buf.String()
}
// decodeGrpcMessage decodes the msg encoded by encodeGrpcMessage.
func decodeGrpcMessage(msg string) string {
if msg == "" {
return ""
}
lenMsg := len(msg)
for i := 0; i < lenMsg; i++ {
if msg[i] == percentByte && i+2 < lenMsg {
return decodeGrpcMessageUnchecked(msg)
}
}
return msg
}
func decodeGrpcMessageUnchecked(msg string) string {
var buf bytes.Buffer
lenMsg := len(msg)
for i := 0; i < lenMsg; i++ {
c := msg[i]
if c == percentByte && i+2 < lenMsg {
parsed, err := strconv.ParseInt(msg[i+1:i+3], 16, 8)
if err != nil {
buf.WriteByte(c)
} else {
buf.WriteByte(byte(parsed))
i += 2
}
} else {
buf.WriteByte(c)
}
}
return buf.String()
}
type framer struct {
numWriters int32
reader io.Reader

51
vendor/google.golang.org/grpc/transport/pre_go16.go generated vendored Normal file
View file

@ -0,0 +1,51 @@
// +build !go1.6
/*
* Copyright 2016, Google Inc.
* All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are
* met:
*
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above
* copyright notice, this list of conditions and the following disclaimer
* in the documentation and/or other materials provided with the
* distribution.
* * Neither the name of Google Inc. nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
* "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
* LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
* A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
* OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
* SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
* LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
* DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
* THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
*/
package transport
import (
"net"
"time"
"golang.org/x/net/context"
)
// dialContext connects to the address on the named network.
func dialContext(ctx context.Context, network, address string) (net.Conn, error) {
var dialer net.Dialer
if deadline, ok := ctx.Deadline(); ok {
dialer.Timeout = deadline.Sub(time.Now())
}
return dialer.Dial(network, address)
}

View file

@ -44,7 +44,6 @@ import (
"io"
"net"
"sync"
"time"
"golang.org/x/net/context"
"golang.org/x/net/trace"
@ -120,10 +119,11 @@ func (b *recvBuffer) get() <-chan item {
// recvBufferReader implements io.Reader interface to read the data from
// recvBuffer.
type recvBufferReader struct {
ctx context.Context
recv *recvBuffer
last *bytes.Reader // Stores the remaining data in the previous calls.
err error
ctx context.Context
goAway chan struct{}
recv *recvBuffer
last *bytes.Reader // Stores the remaining data in the previous calls.
err error
}
// Read reads the next len(p) bytes from last. If last is drained, it tries to
@ -141,6 +141,8 @@ func (r *recvBufferReader) Read(p []byte) (n int, err error) {
select {
case <-r.ctx.Done():
return 0, ContextErr(r.ctx.Err())
case <-r.goAway:
return 0, ErrStreamDrain
case i := <-r.recv.get():
r.recv.load()
m := i.(*recvMsg)
@ -158,7 +160,7 @@ const (
streamActive streamState = iota
streamWriteDone // EndStream sent
streamReadDone // EndStream received
streamDone // sendDone and recvDone or RSTStreamFrame is sent or received.
streamDone // the entire stream is finished.
)
// Stream represents an RPC in the transport layer.
@ -169,6 +171,10 @@ type Stream struct {
// ctx is the associated context of the stream.
ctx context.Context
cancel context.CancelFunc
// done is closed when the final status arrives.
done chan struct{}
// goAway is closed when the server sent GoAways signal before this stream was initiated.
goAway chan struct{}
// method records the associated RPC method of the stream.
method string
recvCompress string
@ -214,6 +220,18 @@ func (s *Stream) SetSendCompress(str string) {
s.sendCompress = str
}
// Done returns a chanel which is closed when it receives the final status
// from the server.
func (s *Stream) Done() <-chan struct{} {
return s.done
}
// GoAway returns a channel which is closed when the server sent GoAways signal
// before this stream was initiated.
func (s *Stream) GoAway() <-chan struct{} {
return s.goAway
}
// Header acquires the key-value pairs of header metadata once it
// is available. It blocks until i) the metadata is ready or ii) there is no
// header metadata or iii) the stream is cancelled/expired.
@ -221,6 +239,8 @@ func (s *Stream) Header() (metadata.MD, error) {
select {
case <-s.ctx.Done():
return nil, ContextErr(s.ctx.Err())
case <-s.goAway:
return nil, ErrStreamDrain
case <-s.headerChan:
return s.header.Copy(), nil
}
@ -335,19 +355,17 @@ type ConnectOptions struct {
// UserAgent is the application user agent.
UserAgent string
// Dialer specifies how to dial a network address.
Dialer func(string, time.Duration) (net.Conn, error)
Dialer func(context.Context, string) (net.Conn, error)
// PerRPCCredentials stores the PerRPCCredentials required to issue RPCs.
PerRPCCredentials []credentials.PerRPCCredentials
// TransportCredentials stores the Authenticator required to setup a client connection.
TransportCredentials credentials.TransportCredentials
// Timeout specifies the timeout for dialing a ClientTransport.
Timeout time.Duration
}
// NewClientTransport establishes the transport with the required ConnectOptions
// and returns it to the caller.
func NewClientTransport(target string, opts *ConnectOptions) (ClientTransport, error) {
return newHTTP2Client(target, opts)
func NewClientTransport(ctx context.Context, target string, opts ConnectOptions) (ClientTransport, error) {
return newHTTP2Client(ctx, target, opts)
}
// Options provides additional hints and information for message
@ -417,6 +435,11 @@ type ClientTransport interface {
// and create a new one) in error case. It should not return nil
// once the transport is initiated.
Error() <-chan struct{}
// GoAway returns a channel that is closed when ClientTranspor
// receives the draining signal from the server (e.g., GOAWAY frame in
// HTTP/2).
GoAway() <-chan struct{}
}
// ServerTransport is the common interface for all gRPC server-side transport
@ -448,6 +471,9 @@ type ServerTransport interface {
// RemoteAddr returns the remote network address.
RemoteAddr() net.Addr
// Drain notifies the client this ServerTransport stops accepting new RPCs.
Drain()
}
// StreamErrorf creates an StreamError with the specified error code and description.
@ -459,9 +485,11 @@ func StreamErrorf(c codes.Code, format string, a ...interface{}) StreamError {
}
// ConnectionErrorf creates an ConnectionError with the specified error description.
func ConnectionErrorf(format string, a ...interface{}) ConnectionError {
func ConnectionErrorf(temp bool, e error, format string, a ...interface{}) ConnectionError {
return ConnectionError{
Desc: fmt.Sprintf(format, a...),
temp: temp,
err: e,
}
}
@ -469,14 +497,36 @@ func ConnectionErrorf(format string, a ...interface{}) ConnectionError {
// entire connection and the retry of all the active streams.
type ConnectionError struct {
Desc string
temp bool
err error
}
func (e ConnectionError) Error() string {
return fmt.Sprintf("connection error: desc = %q", e.Desc)
}
// ErrConnClosing indicates that the transport is closing.
var ErrConnClosing = ConnectionError{Desc: "transport is closing"}
// Temporary indicates if this connection error is temporary or fatal.
func (e ConnectionError) Temporary() bool {
return e.temp
}
// Origin returns the original error of this connection error.
func (e ConnectionError) Origin() error {
// Never return nil error here.
// If the original error is nil, return itself.
if e.err == nil {
return e
}
return e.err
}
var (
// ErrConnClosing indicates that the transport is closing.
ErrConnClosing = ConnectionError{Desc: "transport is closing", temp: true}
// ErrStreamDrain indicates that the stream is rejected by the server because
// the server stops accepting new RPCs.
ErrStreamDrain = StreamErrorf(codes.Unavailable, "the server stops accepting new RPCs")
)
// StreamError is an error that only affects one stream within a connection.
type StreamError struct {
@ -501,12 +551,25 @@ func ContextErr(err error) StreamError {
// wait blocks until it can receive from ctx.Done, closing, or proceed.
// If it receives from ctx.Done, it returns 0, the StreamError for ctx.Err.
// If it receives from done, it returns 0, io.EOF if ctx is not done; otherwise
// it return the StreamError for ctx.Err.
// If it receives from goAway, it returns 0, ErrStreamDrain.
// If it receives from closing, it returns 0, ErrConnClosing.
// If it receives from proceed, it returns the received integer, nil.
func wait(ctx context.Context, closing <-chan struct{}, proceed <-chan int) (int, error) {
func wait(ctx context.Context, done, goAway, closing <-chan struct{}, proceed <-chan int) (int, error) {
select {
case <-ctx.Done():
return 0, ContextErr(ctx.Err())
case <-done:
// User cancellation has precedence.
select {
case <-ctx.Done():
return 0, ContextErr(ctx.Err())
default:
}
return 0, io.EOF
case <-goAway:
return 0, ErrStreamDrain
case <-closing:
return 0, ErrConnClosing
case i := <-proceed:

1
vendor/k8s.io/client-go generated vendored
View file

@ -1 +0,0 @@
kubernetes/staging/src/k8s.io/client-go

View file

@ -325,13 +325,13 @@ func NewGenericServerResponse(code int, verb string, qualifiedResource unversion
default:
if code >= 500 {
reason = unversioned.StatusReasonInternalError
message = "an error on the server has prevented the request from succeeding"
message = fmt.Sprintf("an error on the server (%q) has prevented the request from succeeding", serverMessage)
}
}
switch {
case !qualifiedResource.IsEmpty() && len(name) > 0:
case !qualifiedResource.Empty() && len(name) > 0:
message = fmt.Sprintf("%s (%s %s %s)", message, strings.ToLower(verb), qualifiedResource.String(), name)
case !qualifiedResource.IsEmpty():
case !qualifiedResource.Empty():
message = fmt.Sprintf("%s (%s %s)", message, strings.ToLower(verb), qualifiedResource.String())
}
var causes []unversioned.StatusCause

View file

@ -30,6 +30,7 @@ import (
"k8s.io/client-go/1.4/pkg/fields"
"k8s.io/client-go/1.4/pkg/labels"
"k8s.io/client-go/1.4/pkg/runtime"
"k8s.io/client-go/1.4/pkg/selection"
"k8s.io/client-go/1.4/pkg/types"
"k8s.io/client-go/1.4/pkg/util/sets"
@ -222,6 +223,10 @@ func IsServiceIPSet(service *Service) bool {
// this function aims to check if the service's cluster IP is requested or not
func IsServiceIPRequested(service *Service) bool {
// ExternalName services are CNAME aliases to external ones. Ignore the IP.
if service.Spec.Type == ServiceTypeExternalName {
return false
}
return service.Spec.ClusterIP == ""
}
@ -379,20 +384,20 @@ func NodeSelectorRequirementsAsSelector(nsm []NodeSelectorRequirement) (labels.S
}
selector := labels.NewSelector()
for _, expr := range nsm {
var op labels.Operator
var op selection.Operator
switch expr.Operator {
case NodeSelectorOpIn:
op = labels.InOperator
op = selection.In
case NodeSelectorOpNotIn:
op = labels.NotInOperator
op = selection.NotIn
case NodeSelectorOpExists:
op = labels.ExistsOperator
op = selection.Exists
case NodeSelectorOpDoesNotExist:
op = labels.DoesNotExistOperator
op = selection.DoesNotExist
case NodeSelectorOpGt:
op = labels.GreaterThanOperator
op = selection.GreaterThan
case NodeSelectorOpLt:
op = labels.LessThanOperator
op = selection.LessThan
default:
return nil, fmt.Errorf("%q is not a valid node selector operator", expr.Operator)
}
@ -433,6 +438,20 @@ const (
// PreferAvoidPodsAnnotationKey represents the key of preferAvoidPods data (json serialized)
// in the Annotations of a Node.
PreferAvoidPodsAnnotationKey string = "scheduler.alpha.kubernetes.io/preferAvoidPods"
// SysctlsPodAnnotationKey represents the key of sysctls which are set for the infrastructure
// container of a pod. The annotation value is a comma separated list of sysctl_name=value
// key-value pairs. Only a limited set of whitelisted and isolated sysctls is supported by
// the kubelet. Pods with other sysctls will fail to launch.
SysctlsPodAnnotationKey string = "security.alpha.kubernetes.io/sysctls"
// UnsafeSysctlsPodAnnotationKey represents the key of sysctls which are set for the infrastructure
// container of a pod. The annotation value is a comma separated list of sysctl_name=value
// key-value pairs. Unsafe sysctls must be explicitly enabled for a kubelet. They are properly
// namespaced to a pod or a container, but their isolation is usually unclear or weak. Their use
// is at-your-own-risk. Pods that attempt to set an unsafe sysctl that is not enabled for a kubelet
// will fail to launch.
UnsafeSysctlsPodAnnotationKey string = "security.alpha.kubernetes.io/unsafe-sysctls"
)
// GetAffinityFromPod gets the json serialized affinity data from Pod.Annotations
@ -517,3 +536,51 @@ func GetAvoidPodsFromNodeAnnotations(annotations map[string]string) (AvoidPods,
}
return avoidPods, nil
}
// SysctlsFromPodAnnotations parses the sysctl annotations into a slice of safe Sysctls
// and a slice of unsafe Sysctls. This is only a convenience wrapper around
// SysctlsFromPodAnnotation.
func SysctlsFromPodAnnotations(a map[string]string) ([]Sysctl, []Sysctl, error) {
safe, err := SysctlsFromPodAnnotation(a[SysctlsPodAnnotationKey])
if err != nil {
return nil, nil, err
}
unsafe, err := SysctlsFromPodAnnotation(a[UnsafeSysctlsPodAnnotationKey])
if err != nil {
return nil, nil, err
}
return safe, unsafe, nil
}
// SysctlsFromPodAnnotation parses an annotation value into a slice of Sysctls.
func SysctlsFromPodAnnotation(annotation string) ([]Sysctl, error) {
if len(annotation) == 0 {
return nil, nil
}
kvs := strings.Split(annotation, ",")
sysctls := make([]Sysctl, len(kvs))
for i, kv := range kvs {
cs := strings.Split(kv, "=")
if len(cs) != 2 {
return nil, fmt.Errorf("sysctl %q not of the format sysctl_name=value", kv)
}
sysctls[i].Name = cs[0]
sysctls[i].Value = cs[1]
}
return sysctls, nil
}
// PodAnnotationsFromSysctls creates an annotation value for a slice of Sysctls.
func PodAnnotationsFromSysctls(sysctls []Sysctl) string {
if len(sysctls) == 0 {
return ""
}
kvs := make([]string, len(sysctls))
for i := range sysctls {
kvs[i] = fmt.Sprintf("%s=%s", sysctls[i].Name, sysctls[i].Value)
}
return strings.Join(kvs, ",")
}

View file

@ -21,6 +21,7 @@ import (
"k8s.io/client-go/1.4/pkg/api/meta"
"k8s.io/client-go/1.4/pkg/api/unversioned"
"k8s.io/client-go/1.4/pkg/runtime"
"k8s.io/client-go/1.4/pkg/util/sets"
)
@ -34,14 +35,21 @@ func RegisterRESTMapper(m meta.RESTMapper) {
RESTMapper = append(RESTMapper.(meta.MultiRESTMapper), m)
}
// Instantiates a DefaultRESTMapper based on types registered in api.Scheme
func NewDefaultRESTMapper(defaultGroupVersions []unversioned.GroupVersion, interfacesFunc meta.VersionInterfacesFunc,
importPathPrefix string, ignoredKinds, rootScoped sets.String) *meta.DefaultRESTMapper {
return NewDefaultRESTMapperFromScheme(defaultGroupVersions, interfacesFunc, importPathPrefix, ignoredKinds, rootScoped, Scheme)
}
// Instantiates a DefaultRESTMapper based on types registered in the given scheme.
func NewDefaultRESTMapperFromScheme(defaultGroupVersions []unversioned.GroupVersion, interfacesFunc meta.VersionInterfacesFunc,
importPathPrefix string, ignoredKinds, rootScoped sets.String, scheme *runtime.Scheme) *meta.DefaultRESTMapper {
mapper := meta.NewDefaultRESTMapper(defaultGroupVersions, interfacesFunc)
// enumerate all supported versions, get the kinds, and register with the mapper how to address
// our resources.
for _, gv := range defaultGroupVersions {
for kind, oType := range Scheme.KnownTypes(gv) {
for kind, oType := range scheme.KnownTypes(gv) {
gvk := gv.WithKind(kind)
// TODO: Remove import path check.
// We check the import path because we currently stuff both "api" and "extensions" objects

View file

@ -131,3 +131,10 @@ func (meta *ObjectMeta) SetOwnerReferences(references []metatypes.OwnerReference
}
meta.OwnerReferences = newReferences
}
func (meta *ObjectMeta) GetClusterName() string {
return meta.ClusterName
}
func (meta *ObjectMeta) SetClusterName(clusterName string) {
meta.ClusterName = clusterName
}

View file

@ -62,6 +62,8 @@ type Object interface {
SetFinalizers(finalizers []string)
GetOwnerReferences() []metatypes.OwnerReference
SetOwnerReferences([]metatypes.OwnerReference)
GetClusterName() string
SetClusterName(clusterName string)
}
var _ Object = &runtime.Unstructured{}
@ -161,16 +163,16 @@ type RESTMapping struct {
// TODO(caesarxuchao): Add proper multi-group support so that kinds & resources are
// scoped to groups. See http://issues.k8s.io/12413 and http://issues.k8s.io/10009.
type RESTMapper interface {
// KindFor takes a partial resource and returns back the single match. Returns an error if there are multiple matches
// KindFor takes a partial resource and returns the single match. Returns an error if there are multiple matches
KindFor(resource unversioned.GroupVersionResource) (unversioned.GroupVersionKind, error)
// KindsFor takes a partial resource and returns back the list of potential kinds in priority order
// KindsFor takes a partial resource and returns the list of potential kinds in priority order
KindsFor(resource unversioned.GroupVersionResource) ([]unversioned.GroupVersionKind, error)
// ResourceFor takes a partial resource and returns back the single match. Returns an error if there are multiple matches
// ResourceFor takes a partial resource and returns the single match. Returns an error if there are multiple matches
ResourceFor(input unversioned.GroupVersionResource) (unversioned.GroupVersionResource, error)
// ResourcesFor takes a partial resource and returns back the list of potential resource in priority order
// ResourcesFor takes a partial resource and returns the list of potential resource in priority order
ResourcesFor(input unversioned.GroupVersionResource) ([]unversioned.GroupVersionResource, error)
// RESTMapping identifies a preferred resource mapping for the provided group kind.

View file

@ -183,17 +183,17 @@ func (m *DefaultRESTMapper) ResourceSingularizer(resourceType string) (string, e
if !ok {
continue
}
if singular.IsEmpty() {
if singular.Empty() {
singular = currSingular
continue
}
if currSingular.Resource != singular.Resource {
return resourceType, fmt.Errorf("multiple possibile singular resources (%v) found for %v", resources, resourceType)
return resourceType, fmt.Errorf("multiple possible singular resources (%v) found for %v", resources, resourceType)
}
}
if singular.IsEmpty() {
if singular.Empty() {
return resourceType, fmt.Errorf("no singular of resource %v has been defined", resourceType)
}

View file

@ -0,0 +1,31 @@
/*
Copyright 2016 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package meta
import (
"k8s.io/client-go/1.4/pkg/api/unversioned"
"k8s.io/client-go/1.4/pkg/runtime"
)
// InterfacesForUnstructured returns VersionInterfaces suitable for
// dealing with runtime.Unstructured objects.
func InterfacesForUnstructured(unversioned.GroupVersion) (*VersionInterfaces, error) {
return &VersionInterfaces{
ObjectConvertor: &runtime.UnstructuredObjectConverter{},
MetadataAccessor: NewAccessor(),
}, nil
}

View file

@ -30,7 +30,7 @@ const (
// TODO: to be de!eted after v1.3 is released. PodSpec has a dedicated Subdomain field.
// The annotation value is a string specifying the subdomain e.g. "my-web-service"
// If specified, on the the pod itself, "<hostname>.my-web-service.<namespace>.svc.<cluster domain>" would resolve to
// If specified, on the pod itself, "<hostname>.my-web-service.<namespace>.svc.<cluster domain>" would resolve to
// the pod's IP.
// If there is a headless service named "my-web-service" in the same namespace as the pod, then,
// <hostname>.my-web-service.<namespace>.svc.<cluster domain>" would be resolved by the cluster DNS Server.

View file

@ -45,12 +45,12 @@ var Unversioned = unversioned.GroupVersion{Group: "", Version: "v1"}
// ParameterCodec handles versioning of objects that are converted to query parameters.
var ParameterCodec = runtime.NewParameterCodec(Scheme)
// Kind takes an unqualified kind and returns back a Group qualified GroupKind
// Kind takes an unqualified kind and returns a Group qualified GroupKind
func Kind(kind string) unversioned.GroupKind {
return SchemeGroupVersion.WithKind(kind).GroupKind()
}
// Resource takes an unqualified resource and returns back a Group qualified GroupResource
// Resource takes an unqualified resource and returns a Group qualified GroupResource
func Resource(resource string) unversioned.GroupResource {
return SchemeGroupVersion.WithResource(resource).GroupResource()
}
@ -62,7 +62,7 @@ var (
func init() {
// TODO(lavalamp): move this call to scheme builder above. Can't
// remove it from here because lots of people inapropriately rely on it
// remove it from here because lots of people inappropriately rely on it
// (specifically the unversioned time conversion). Can't have it in
// both places because then it gets double registered. Consequence of
// current state is that it only ever gets registered in the main

View file

@ -40,7 +40,8 @@
"secret": null,
"nfs": null,
"iscsi": null,
"glusterfs": null
"glusterfs": null,
"quobyte": null
}
],
"containers": [

Some files were not shown because too many files have changed in this diff Show more