From 94c2b61903c7f00e5c7f563ae287ab42f0456222 Mon Sep 17 00:00:00 2001 From: Glenn Lewis Date: Thu, 2 Oct 2014 22:44:50 -0700 Subject: [PATCH] Locally cache oauth tokens. This change is for both App Engine and Managed VMs so that these apps can scale without running into the app_identity_service quota limit due to calling appengine.AccessToken too frequently. An added benefit of caching is that calls to Google APIs will be significantly sped up due to removing the round-trip calls to the api_identity_service. --- google/appengine.go | 76 +++++++++++- google/appengine_test.go | 239 +++++++++++++++++++++++++++++++++++++ google/appenginevm.go | 76 +++++++++++- google/appenginevm_test.go | 238 ++++++++++++++++++++++++++++++++++++ google/example_test.go | 2 + 5 files changed, 625 insertions(+), 6 deletions(-) create mode 100644 google/appengine_test.go create mode 100644 google/appenginevm_test.go diff --git a/google/appengine.go b/google/appengine.go index 64afbd3..ead2cdc 100644 --- a/google/appengine.go +++ b/google/appengine.go @@ -8,13 +8,53 @@ package google import ( "net/http" + "strings" + "sync" + "time" "github.com/golang/oauth2" "appengine" + "appengine/memcache" "appengine/urlfetch" ) +// aeMemcache wraps the needed Memcache functionality to make it easy to mock +type aeMemcache struct{} + +func (m *aeMemcache) Get(c appengine.Context, key string, tok *oauth2.Token) (*memcache.Item, error) { + return memcache.Gob.Get(c, key, tok) +} + +func (m *aeMemcache) Set(c appengine.Context, item *memcache.Item) error { + return memcache.Gob.Set(c, item) +} + +type memcacher interface { + Get(c appengine.Context, key string, tok *oauth2.Token) (*memcache.Item, error) + Set(c appengine.Context, item *memcache.Item) error +} + +// memcacheGob enables mocking of the memcache.Gob calls for unit testing. +var memcacheGob memcacher = &aeMemcache{} + +// accessTokenFunc enables mocking of the appengine.AccessToken call for unit testing. +var accessTokenFunc = appengine.AccessToken + +// safetyMargin is used to avoid clock-skew problems. +// 5 minutes is conservative because tokens are valid for 60 minutes. +const safetyMargin = 5 * time.Minute + +// mu protects multiple threads from attempting to fetch a token at the same time. +var mu sync.Mutex + +// tokens implements a local cache of tokens to prevent hitting quota limits for appengine.AccessToken calls. +var tokens map[string]*oauth2.Token + +func init() { + tokens = make(map[string]*oauth2.Token) +} + // AppEngineConfig represents a configuration for an // App Engine application's Google service account. type AppEngineConfig struct { @@ -43,15 +83,45 @@ func (c *AppEngineConfig) NewTransport() *oauth2.Transport { } // FetchToken fetches a new access token for the provided scopes. +// Tokens are cached locally and also with Memcache so that the app can scale +// without hitting quota limits by calling appengine.AccessToken too frequently. func (c *AppEngineConfig) FetchToken(existing *oauth2.Token) (*oauth2.Token, error) { - token, expiry, err := appengine.AccessToken(c.context, c.scopes...) + mu.Lock() + defer mu.Unlock() + key := ":" + strings.Join(c.scopes, "_") + now := time.Now().Add(safetyMargin) + if t, ok := tokens[key]; ok && !t.Expiry.Before(now) { + return t, nil + } + delete(tokens, key) + + // Attempt to get token from Memcache + tok := new(oauth2.Token) + _, err := memcacheGob.Get(c.context, key, tok) + if err == nil && !tok.Expiry.Before(now) { + tokens[key] = tok // Save token locally + return tok, nil + } + + token, expiry, err := accessTokenFunc(c.context, c.scopes...) if err != nil { return nil, err } - return &oauth2.Token{ + t := &oauth2.Token{ AccessToken: token, Expiry: expiry, - }, nil + } + tokens[key] = t + // Also back up token in Memcache + if err = memcacheGob.Set(c.context, &memcache.Item{ + Key: key, + Value: []byte{}, + Object: *t, + Expiration: expiry.Sub(now), + }); err != nil { + c.context.Errorf("unexpected memcache.Set error: %v", err) + } + return t, nil } func (c *AppEngineConfig) transport() http.RoundTripper { diff --git a/google/appengine_test.go b/google/appengine_test.go new file mode 100644 index 0000000..9637493 --- /dev/null +++ b/google/appengine_test.go @@ -0,0 +1,239 @@ +// Copyright 2014 The oauth2 Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build appengine,!appenginevm + +package google + +import ( + "fmt" + "log" + "sync" + "testing" + "time" + + "github.com/golang/oauth2" + + "appengine" + "appengine/memcache" +) + +type tokMap map[string]*oauth2.Token + +type mockMemcache struct { + mu sync.RWMutex + vals tokMap + getCount, setCount int +} + +func (m *mockMemcache) Get(c appengine.Context, key string, tok *oauth2.Token) (*memcache.Item, error) { + m.mu.Lock() + defer m.mu.Unlock() + m.getCount++ + v, ok := m.vals[key] + if !ok { + return nil, fmt.Errorf("unexpected test error: key %q not found", key) + } + *tok = *v + return nil, nil // memcache.Item is ignored anyway - return nil +} + +func (m *mockMemcache) Set(c appengine.Context, item *memcache.Item) error { + m.mu.Lock() + defer m.mu.Unlock() + m.setCount++ + tok, ok := item.Object.(oauth2.Token) + if !ok { + log.Fatalf("unexpected test error: item.Object is not an oauth2.Token: %#v", item) + } + m.vals[item.Key] = &tok + return nil +} + +var accessTokenCount = 0 + +func mockAccessToken(c appengine.Context, scopes ...string) (token string, expiry time.Time, err error) { + accessTokenCount++ + return "mytoken", time.Now(), nil +} + +const ( + testScope = "myscope" + testScopeKey = ":" + testScope +) + +func init() { + accessTokenFunc = mockAccessToken +} + +func TestFetchTokenLocalCacheMiss(t *testing.T) { + m := &mockMemcache{vals: make(tokMap)} + memcacheGob = m + accessTokenCount = 0 + delete(tokens, testScopeKey) // clear local cache + config := NewAppEngineConfig(nil, testScope) + _, err := config.FetchToken(nil) + if err != nil { + t.Errorf("unable to FetchToken: %v", err) + } + if w := 1; m.getCount != w { + t.Errorf("bad memcache.Get count: got %v, want %v", m.getCount, w) + } + if w := 1; accessTokenCount != w { + t.Errorf("bad AccessToken count: got %v, want %v", accessTokenCount, w) + } + if w := 1; m.setCount != w { + t.Errorf("bad memcache.Set count: got %v, want %v", m.setCount, w) + } + // Make sure local cache has been populated + _, ok := tokens[testScopeKey] + if !ok { + t.Errorf("local cache not populated!") + } +} + +func TestFetchTokenLocalCacheHit(t *testing.T) { + m := &mockMemcache{vals: make(tokMap)} + memcacheGob = m + accessTokenCount = 0 + // Pre-populate the local cache + tokens[testScopeKey] = &oauth2.Token{ + AccessToken: "mytoken", + Expiry: time.Now().Add(1 * time.Hour), + } + config := NewAppEngineConfig(nil, testScope) + _, err := config.FetchToken(nil) + if err != nil { + t.Errorf("unable to FetchToken: %v", err) + } + if w := 0; m.getCount != w { + t.Errorf("bad memcache.Get count: got %v, want %v", m.getCount, w) + } + if w := 0; accessTokenCount != w { + t.Errorf("bad AccessToken count: got %v, want %v", accessTokenCount, w) + } + if w := 0; m.setCount != w { + t.Errorf("bad memcache.Set count: got %v, want %v", m.setCount, w) + } + // Make sure local cache remains populated + _, ok := tokens[testScopeKey] + if !ok { + t.Errorf("local cache not populated!") + } +} + +func TestFetchTokenMemcacheHit(t *testing.T) { + m := &mockMemcache{vals: make(tokMap)} + memcacheGob = m + accessTokenCount = 0 + delete(tokens, testScopeKey) // clear local cache + // Pre-populate the memcache + tok := &oauth2.Token{ + AccessToken: "mytoken", + Expiry: time.Now().Add(1 * time.Hour), + } + m.Set(nil, &memcache.Item{ + Key: testScopeKey, + Object: *tok, + Expiration: 1 * time.Hour, + }) + m.setCount = 0 + config := NewAppEngineConfig(nil, testScope) + _, err := config.FetchToken(nil) + if err != nil { + t.Errorf("unable to FetchToken: %v", err) + } + if w := 1; m.getCount != w { + t.Errorf("bad memcache.Get count: got %v, want %v", m.getCount, w) + } + if w := 0; accessTokenCount != w { + t.Errorf("bad AccessToken count: got %v, want %v", accessTokenCount, w) + } + if w := 0; m.setCount != w { + t.Errorf("bad memcache.Set count: got %v, want %v", m.setCount, w) + } + // Make sure local cache has been populated + _, ok := tokens[testScopeKey] + if !ok { + t.Errorf("local cache not populated!") + } +} + +func TestFetchTokenLocalCacheExpired(t *testing.T) { + m := &mockMemcache{vals: make(tokMap)} + memcacheGob = m + accessTokenCount = 0 + // Pre-populate the local cache + tokens[testScopeKey] = &oauth2.Token{ + AccessToken: "mytoken", + Expiry: time.Now().Add(-1 * time.Hour), + } + // Pre-populate the memcache + tok := &oauth2.Token{ + AccessToken: "mytoken", + Expiry: time.Now().Add(1 * time.Hour), + } + m.Set(nil, &memcache.Item{ + Key: testScopeKey, + Object: *tok, + Expiration: 1 * time.Hour, + }) + m.setCount = 0 + config := NewAppEngineConfig(nil, testScope) + _, err := config.FetchToken(nil) + if err != nil { + t.Errorf("unable to FetchToken: %v", err) + } + if w := 1; m.getCount != w { + t.Errorf("bad memcache.Get count: got %v, want %v", m.getCount, w) + } + if w := 0; accessTokenCount != w { + t.Errorf("bad AccessToken count: got %v, want %v", accessTokenCount, w) + } + if w := 0; m.setCount != w { + t.Errorf("bad memcache.Set count: got %v, want %v", m.setCount, w) + } + // Make sure local cache remains populated + _, ok := tokens[testScopeKey] + if !ok { + t.Errorf("local cache not populated!") + } +} + +func TestFetchTokenMemcacheExpired(t *testing.T) { + m := &mockMemcache{vals: make(tokMap)} + memcacheGob = m + accessTokenCount = 0 + delete(tokens, testScopeKey) // clear local cache + // Pre-populate the memcache + tok := &oauth2.Token{ + AccessToken: "mytoken", + Expiry: time.Now().Add(-1 * time.Hour), + } + m.Set(nil, &memcache.Item{ + Key: testScopeKey, + Object: *tok, + Expiration: -1 * time.Hour, + }) + m.setCount = 0 + config := NewAppEngineConfig(nil, testScope) + _, err := config.FetchToken(nil) + if err != nil { + t.Errorf("unable to FetchToken: %v", err) + } + if w := 1; m.getCount != w { + t.Errorf("bad memcache.Get count: got %v, want %v", m.getCount, w) + } + if w := 1; accessTokenCount != w { + t.Errorf("bad AccessToken count: got %v, want %v", accessTokenCount, w) + } + if w := 1; m.setCount != w { + t.Errorf("bad memcache.Set count: got %v, want %v", m.setCount, w) + } + // Make sure local cache has been populated + _, ok := tokens[testScopeKey] + if !ok { + t.Errorf("local cache not populated!") + } +} diff --git a/google/appenginevm.go b/google/appenginevm.go index 291791a..ef88850 100644 --- a/google/appenginevm.go +++ b/google/appenginevm.go @@ -8,11 +8,51 @@ package google import ( "net/http" + "strings" + "sync" + "time" "github.com/golang/oauth2" "google.golang.org/appengine" + "google.golang.org/appengine/memcache" ) +// aeMemcache wraps the needed Memcache functionality to make it easy to mock +type aeMemcache struct{} + +func (m *aeMemcache) Get(c appengine.Context, key string, tok *oauth2.Token) (*memcache.Item, error) { + return memcache.Gob.Get(c, key, tok) +} + +func (m *aeMemcache) Set(c appengine.Context, item *memcache.Item) error { + return memcache.Gob.Set(c, item) +} + +type memcacher interface { + Get(c appengine.Context, key string, tok *oauth2.Token) (*memcache.Item, error) + Set(c appengine.Context, item *memcache.Item) error +} + +// memcacheGob enables mocking of the memcache.Gob calls for unit testing. +var memcacheGob memcacher = &aeMemcache{} + +// accessTokenFunc enables mocking of the appengine.AccessToken call for unit testing. +var accessTokenFunc = appengine.AccessToken + +// safetyMargin is used to avoid clock-skew problems. +// 5 minutes is conservative because tokens are valid for 60 minutes. +const safetyMargin = 5 * time.Minute + +// mu protects multiple threads from attempting to fetch a token at the same time. +var mu sync.Mutex + +// tokens implements a local cache of tokens to prevent hitting quota limits for appengine.AccessToken calls. +var tokens map[string]*oauth2.Token + +func init() { + tokens = make(map[string]*oauth2.Token) +} + // AppEngineConfig represents a configuration for an // App Engine application's Google service account. type AppEngineConfig struct { @@ -41,15 +81,45 @@ func (c *AppEngineConfig) NewTransport() *oauth2.Transport { } // FetchToken fetches a new access token for the provided scopes. +// Tokens are cached locally and also with Memcache so that the app can scale +// without hitting quota limits by calling appengine.AccessToken too frequently. func (c *AppEngineConfig) FetchToken(existing *oauth2.Token) (*oauth2.Token, error) { - token, expiry, err := appengine.AccessToken(c.context, c.scopes...) + mu.Lock() + defer mu.Unlock() + key := ":" + strings.Join(c.scopes, "_") + now := time.Now().Add(safetyMargin) + if t, ok := tokens[key]; ok && !t.Expiry.Before(now) { + return t, nil + } + delete(tokens, key) + + // Attempt to get token from Memcache + tok := new(oauth2.Token) + _, err := memcacheGob.Get(c.context, key, tok) + if err == nil && !tok.Expiry.Before(now) { + tokens[key] = tok // Save token locally + return tok, nil + } + + token, expiry, err := accessTokenFunc(c.context, c.scopes...) if err != nil { return nil, err } - return &oauth2.Token{ + t := &oauth2.Token{ AccessToken: token, Expiry: expiry, - }, nil + } + tokens[key] = t + // Also back up token in Memcache + if err = memcacheGob.Set(c.context, &memcache.Item{ + Key: key, + Value: []byte{}, + Object: *t, + Expiration: expiry.Sub(now), + }); err != nil { + c.context.Errorf("unexpected memcache.Set error: %v", err) + } + return t, nil } func (c *AppEngineConfig) transport() http.RoundTripper { diff --git a/google/appenginevm_test.go b/google/appenginevm_test.go new file mode 100644 index 0000000..dc3ecf5 --- /dev/null +++ b/google/appenginevm_test.go @@ -0,0 +1,238 @@ +// Copyright 2014 The oauth2 Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build appenginevm !appengine + +package google + +import ( + "fmt" + "log" + "sync" + "testing" + "time" + + "github.com/golang/oauth2" + "google.golang.org/appengine" + "google.golang.org/appengine/memcache" +) + +type tokMap map[string]*oauth2.Token + +type mockMemcache struct { + mu sync.RWMutex + vals tokMap + getCount, setCount int +} + +func (m *mockMemcache) Get(c appengine.Context, key string, tok *oauth2.Token) (*memcache.Item, error) { + m.mu.Lock() + defer m.mu.Unlock() + m.getCount++ + v, ok := m.vals[key] + if !ok { + return nil, fmt.Errorf("unexpected test error: key %q not found", key) + } + *tok = *v + return nil, nil // memcache.Item is ignored anyway - return nil +} + +func (m *mockMemcache) Set(c appengine.Context, item *memcache.Item) error { + m.mu.Lock() + defer m.mu.Unlock() + m.setCount++ + tok, ok := item.Object.(oauth2.Token) + if !ok { + log.Fatalf("unexpected test error: item.Object is not an oauth2.Token: %#v", item) + } + m.vals[item.Key] = &tok + return nil +} + +var accessTokenCount = 0 + +func mockAccessToken(c appengine.Context, scopes ...string) (token string, expiry time.Time, err error) { + accessTokenCount++ + return "mytoken", time.Now(), nil +} + +const ( + testScope = "myscope" + testScopeKey = ":" + testScope +) + +func init() { + accessTokenFunc = mockAccessToken +} + +func TestFetchTokenLocalCacheMiss(t *testing.T) { + m := &mockMemcache{vals: make(tokMap)} + memcacheGob = m + accessTokenCount = 0 + delete(tokens, testScopeKey) // clear local cache + config := NewAppEngineConfig(nil, testScope) + _, err := config.FetchToken(nil) + if err != nil { + t.Errorf("unable to FetchToken: %v", err) + } + if w := 1; m.getCount != w { + t.Errorf("bad memcache.Get count: got %v, want %v", m.getCount, w) + } + if w := 1; accessTokenCount != w { + t.Errorf("bad AccessToken count: got %v, want %v", accessTokenCount, w) + } + if w := 1; m.setCount != w { + t.Errorf("bad memcache.Set count: got %v, want %v", m.setCount, w) + } + // Make sure local cache has been populated + _, ok := tokens[testScopeKey] + if !ok { + t.Errorf("local cache not populated!") + } +} + +func TestFetchTokenLocalCacheHit(t *testing.T) { + m := &mockMemcache{vals: make(tokMap)} + memcacheGob = m + accessTokenCount = 0 + // Pre-populate the local cache + tokens[testScopeKey] = &oauth2.Token{ + AccessToken: "mytoken", + Expiry: time.Now().Add(1 * time.Hour), + } + config := NewAppEngineConfig(nil, testScope) + _, err := config.FetchToken(nil) + if err != nil { + t.Errorf("unable to FetchToken: %v", err) + } + if w := 0; m.getCount != w { + t.Errorf("bad memcache.Get count: got %v, want %v", m.getCount, w) + } + if w := 0; accessTokenCount != w { + t.Errorf("bad AccessToken count: got %v, want %v", accessTokenCount, w) + } + if w := 0; m.setCount != w { + t.Errorf("bad memcache.Set count: got %v, want %v", m.setCount, w) + } + // Make sure local cache remains populated + _, ok := tokens[testScopeKey] + if !ok { + t.Errorf("local cache not populated!") + } +} + +func TestFetchTokenMemcacheHit(t *testing.T) { + m := &mockMemcache{vals: make(tokMap)} + memcacheGob = m + accessTokenCount = 0 + delete(tokens, testScopeKey) // clear local cache + // Pre-populate the memcache + tok := &oauth2.Token{ + AccessToken: "mytoken", + Expiry: time.Now().Add(1 * time.Hour), + } + m.Set(nil, &memcache.Item{ + Key: testScopeKey, + Object: *tok, + Expiration: 1 * time.Hour, + }) + m.setCount = 0 + config := NewAppEngineConfig(nil, testScope) + _, err := config.FetchToken(nil) + if err != nil { + t.Errorf("unable to FetchToken: %v", err) + } + if w := 1; m.getCount != w { + t.Errorf("bad memcache.Get count: got %v, want %v", m.getCount, w) + } + if w := 0; accessTokenCount != w { + t.Errorf("bad AccessToken count: got %v, want %v", accessTokenCount, w) + } + if w := 0; m.setCount != w { + t.Errorf("bad memcache.Set count: got %v, want %v", m.setCount, w) + } + // Make sure local cache has been populated + _, ok := tokens[testScopeKey] + if !ok { + t.Errorf("local cache not populated!") + } +} + +func TestFetchTokenLocalCacheExpired(t *testing.T) { + m := &mockMemcache{vals: make(tokMap)} + memcacheGob = m + accessTokenCount = 0 + // Pre-populate the local cache + tokens[testScopeKey] = &oauth2.Token{ + AccessToken: "mytoken", + Expiry: time.Now().Add(-1 * time.Hour), + } + // Pre-populate the memcache + tok := &oauth2.Token{ + AccessToken: "mytoken", + Expiry: time.Now().Add(1 * time.Hour), + } + m.Set(nil, &memcache.Item{ + Key: testScopeKey, + Object: *tok, + Expiration: 1 * time.Hour, + }) + m.setCount = 0 + config := NewAppEngineConfig(nil, testScope) + _, err := config.FetchToken(nil) + if err != nil { + t.Errorf("unable to FetchToken: %v", err) + } + if w := 1; m.getCount != w { + t.Errorf("bad memcache.Get count: got %v, want %v", m.getCount, w) + } + if w := 0; accessTokenCount != w { + t.Errorf("bad AccessToken count: got %v, want %v", accessTokenCount, w) + } + if w := 0; m.setCount != w { + t.Errorf("bad memcache.Set count: got %v, want %v", m.setCount, w) + } + // Make sure local cache remains populated + _, ok := tokens[testScopeKey] + if !ok { + t.Errorf("local cache not populated!") + } +} + +func TestFetchTokenMemcacheExpired(t *testing.T) { + m := &mockMemcache{vals: make(tokMap)} + memcacheGob = m + accessTokenCount = 0 + delete(tokens, testScopeKey) // clear local cache + // Pre-populate the memcache + tok := &oauth2.Token{ + AccessToken: "mytoken", + Expiry: time.Now().Add(-1 * time.Hour), + } + m.Set(nil, &memcache.Item{ + Key: testScopeKey, + Object: *tok, + Expiration: -1 * time.Hour, + }) + m.setCount = 0 + config := NewAppEngineConfig(nil, testScope) + _, err := config.FetchToken(nil) + if err != nil { + t.Errorf("unable to FetchToken: %v", err) + } + if w := 1; m.getCount != w { + t.Errorf("bad memcache.Get count: got %v, want %v", m.getCount, w) + } + if w := 1; accessTokenCount != w { + t.Errorf("bad AccessToken count: got %v, want %v", accessTokenCount, w) + } + if w := 1; m.setCount != w { + t.Errorf("bad memcache.Set count: got %v, want %v", m.setCount, w) + } + // Make sure local cache has been populated + _, ok := tokens[testScopeKey] + if !ok { + t.Errorf("local cache not populated!") + } +} diff --git a/google/example_test.go b/google/example_test.go index 979efc5..5471d93 100644 --- a/google/example_test.go +++ b/google/example_test.go @@ -2,6 +2,8 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. +// +build appenginevm !appengine + package google_test import (