refactor(server): extract struct AuthRequestManager

This commit is contained in:
MuXiu1997
2023-01-26 05:00:14 +08:00
parent 28e126f2e4
commit 6ba37f6fe0
3 changed files with 53 additions and 12 deletions

View File

@@ -25,10 +25,10 @@ type App struct {
Server *http.Server
Engine *gin.Engine
GitHubOAuthConfig *oauth2.Config
AuthRequestManager *cache.Cache
AuthRequestManager *AuthRequestManager
}
func NewApp(config *Config, server *http.Server, engine *gin.Engine, authRequestManager *cache.Cache) *App {
func NewApp(config *Config, server *http.Server, engine *gin.Engine, authRequestManager *AuthRequestManager) *App {
server.Addr = config.ServerAddress
server.Handler = engine
@@ -59,7 +59,7 @@ func NewDefaultApp() *App {
ReadHeaderTimeout: 5 * time.Second,
},
gin.Default(),
cache.New(10*time.Minute, 30*time.Minute),
NewAuthRequestManager(cache.New(10*time.Minute, 30*time.Minute)),
)
}

View File

@@ -0,0 +1,39 @@
package traefik_github_oauth_server
import (
"github.com/muxiu1997/traefik-github-oauth-plugin/internal/app/traefik-github-oauth-server/model"
"github.com/patrickmn/go-cache"
"github.com/rs/xid"
)
type AuthRequestManager struct {
cache *cache.Cache
}
func NewAuthRequestManager(cache *cache.Cache) *AuthRequestManager {
return &AuthRequestManager{
cache: cache,
}
}
func (m *AuthRequestManager) Insert(aq *model.AuthRequest) string {
rid := xid.New().String()
m.cache.SetDefault(rid, aq)
return rid
}
func (m *AuthRequestManager) Get(rid string) (*model.AuthRequest, bool) {
authRequest, found := m.cache.Get(rid)
if !found {
return nil, false
}
return authRequest.(*model.AuthRequest), true
}
func (m *AuthRequestManager) Pop(rid string) (*model.AuthRequest, bool) {
aq, found := m.Get(rid)
if found {
m.cache.Delete(rid)
}
return aq, found
}

View File

@@ -11,7 +11,6 @@ import (
server "github.com/muxiu1997/traefik-github-oauth-plugin/internal/app/traefik-github-oauth-server"
"github.com/muxiu1997/traefik-github-oauth-plugin/internal/app/traefik-github-oauth-server/model"
"github.com/muxiu1997/traefik-github-oauth-plugin/internal/pkg/constant"
"github.com/rs/xid"
"github.com/spf13/cast"
"golang.org/x/oauth2"
)
@@ -24,8 +23,7 @@ func generateOAuthPageURL(app *server.App) gin.HandlerFunc {
return
}
rid := xid.New().String()
app.AuthRequestManager.SetDefault(rid, &model.AuthRequest{
rid := app.AuthRequestManager.Insert(&model.AuthRequest{
RedirectURI: body.RedirectURI,
AuthURL: body.AuthURL,
})
@@ -57,22 +55,27 @@ func redirect(app *server.App) gin.HandlerFunc {
if err != nil {
return
}
authRequestCache, found := app.AuthRequestManager.Get(query.RID)
authRequest, found := app.AuthRequestManager.Get(query.RID)
if !found {
c.String(http.StatusBadRequest, "invalid rid")
return
}
authRequest := authRequestCache.(*model.AuthRequest)
user, err := oAuthCodeToUser(c.Request.Context(), app.GitHubOAuthConfig, query.Code)
if err != nil {
c.String(http.StatusInternalServerError, err.Error())
return
}
authRequest.GitHubUserID = cast.ToString(user.GetID())
authRequest.GitHubUserLogin = user.GetLogin()
authURL, _ := url.Parse(authRequest.AuthURL)
authURL, err := url.Parse(authRequest.AuthURL)
if err != nil {
c.String(http.StatusInternalServerError, "invalid auth url: %s", authRequest.AuthURL)
return
}
authURLQuery := authURL.Query()
authURLQuery.Set(constant.QUERY_KEY_REQUEST_ID, query.RID)
authURL.RawQuery = authURLQuery.Encode()
@@ -88,13 +91,12 @@ func getAuthResult(app *server.App) gin.HandlerFunc {
if err != nil {
return
}
authRequestCache, found := app.AuthRequestManager.Get(query.RID)
authRequest, found := app.AuthRequestManager.Pop(query.RID)
if !found {
c.String(http.StatusBadRequest, "invalid rid")
return
}
defer app.AuthRequestManager.Delete(query.RID)
authRequest := authRequestCache.(*model.AuthRequest)
c.JSON(
http.StatusOK,