mirror of
https://github.com/MuXiu1997/traefik-github-oauth-plugin
synced 2025-12-17 18:31:27 +00:00
refactor(server): extract struct AuthRequestManager
This commit is contained in:
@@ -25,10 +25,10 @@ type App struct {
|
||||
Server *http.Server
|
||||
Engine *gin.Engine
|
||||
GitHubOAuthConfig *oauth2.Config
|
||||
AuthRequestManager *cache.Cache
|
||||
AuthRequestManager *AuthRequestManager
|
||||
}
|
||||
|
||||
func NewApp(config *Config, server *http.Server, engine *gin.Engine, authRequestManager *cache.Cache) *App {
|
||||
func NewApp(config *Config, server *http.Server, engine *gin.Engine, authRequestManager *AuthRequestManager) *App {
|
||||
server.Addr = config.ServerAddress
|
||||
server.Handler = engine
|
||||
|
||||
@@ -59,7 +59,7 @@ func NewDefaultApp() *App {
|
||||
ReadHeaderTimeout: 5 * time.Second,
|
||||
},
|
||||
gin.Default(),
|
||||
cache.New(10*time.Minute, 30*time.Minute),
|
||||
NewAuthRequestManager(cache.New(10*time.Minute, 30*time.Minute)),
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,39 @@
|
||||
package traefik_github_oauth_server
|
||||
|
||||
import (
|
||||
"github.com/muxiu1997/traefik-github-oauth-plugin/internal/app/traefik-github-oauth-server/model"
|
||||
"github.com/patrickmn/go-cache"
|
||||
"github.com/rs/xid"
|
||||
)
|
||||
|
||||
type AuthRequestManager struct {
|
||||
cache *cache.Cache
|
||||
}
|
||||
|
||||
func NewAuthRequestManager(cache *cache.Cache) *AuthRequestManager {
|
||||
return &AuthRequestManager{
|
||||
cache: cache,
|
||||
}
|
||||
}
|
||||
|
||||
func (m *AuthRequestManager) Insert(aq *model.AuthRequest) string {
|
||||
rid := xid.New().String()
|
||||
m.cache.SetDefault(rid, aq)
|
||||
return rid
|
||||
}
|
||||
|
||||
func (m *AuthRequestManager) Get(rid string) (*model.AuthRequest, bool) {
|
||||
authRequest, found := m.cache.Get(rid)
|
||||
if !found {
|
||||
return nil, false
|
||||
}
|
||||
return authRequest.(*model.AuthRequest), true
|
||||
}
|
||||
|
||||
func (m *AuthRequestManager) Pop(rid string) (*model.AuthRequest, bool) {
|
||||
aq, found := m.Get(rid)
|
||||
if found {
|
||||
m.cache.Delete(rid)
|
||||
}
|
||||
return aq, found
|
||||
}
|
||||
@@ -11,7 +11,6 @@ import (
|
||||
server "github.com/muxiu1997/traefik-github-oauth-plugin/internal/app/traefik-github-oauth-server"
|
||||
"github.com/muxiu1997/traefik-github-oauth-plugin/internal/app/traefik-github-oauth-server/model"
|
||||
"github.com/muxiu1997/traefik-github-oauth-plugin/internal/pkg/constant"
|
||||
"github.com/rs/xid"
|
||||
"github.com/spf13/cast"
|
||||
"golang.org/x/oauth2"
|
||||
)
|
||||
@@ -24,8 +23,7 @@ func generateOAuthPageURL(app *server.App) gin.HandlerFunc {
|
||||
return
|
||||
}
|
||||
|
||||
rid := xid.New().String()
|
||||
app.AuthRequestManager.SetDefault(rid, &model.AuthRequest{
|
||||
rid := app.AuthRequestManager.Insert(&model.AuthRequest{
|
||||
RedirectURI: body.RedirectURI,
|
||||
AuthURL: body.AuthURL,
|
||||
})
|
||||
@@ -57,22 +55,27 @@ func redirect(app *server.App) gin.HandlerFunc {
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
authRequestCache, found := app.AuthRequestManager.Get(query.RID)
|
||||
|
||||
authRequest, found := app.AuthRequestManager.Get(query.RID)
|
||||
if !found {
|
||||
c.String(http.StatusBadRequest, "invalid rid")
|
||||
return
|
||||
}
|
||||
authRequest := authRequestCache.(*model.AuthRequest)
|
||||
|
||||
user, err := oAuthCodeToUser(c.Request.Context(), app.GitHubOAuthConfig, query.Code)
|
||||
if err != nil {
|
||||
c.String(http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
authRequest.GitHubUserID = cast.ToString(user.GetID())
|
||||
authRequest.GitHubUserLogin = user.GetLogin()
|
||||
|
||||
authURL, _ := url.Parse(authRequest.AuthURL)
|
||||
authURL, err := url.Parse(authRequest.AuthURL)
|
||||
if err != nil {
|
||||
c.String(http.StatusInternalServerError, "invalid auth url: %s", authRequest.AuthURL)
|
||||
return
|
||||
}
|
||||
authURLQuery := authURL.Query()
|
||||
authURLQuery.Set(constant.QUERY_KEY_REQUEST_ID, query.RID)
|
||||
authURL.RawQuery = authURLQuery.Encode()
|
||||
@@ -88,13 +91,12 @@ func getAuthResult(app *server.App) gin.HandlerFunc {
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
authRequestCache, found := app.AuthRequestManager.Get(query.RID)
|
||||
|
||||
authRequest, found := app.AuthRequestManager.Pop(query.RID)
|
||||
if !found {
|
||||
c.String(http.StatusBadRequest, "invalid rid")
|
||||
return
|
||||
}
|
||||
defer app.AuthRequestManager.Delete(query.RID)
|
||||
authRequest := authRequestCache.(*model.AuthRequest)
|
||||
|
||||
c.JSON(
|
||||
http.StatusOK,
|
||||
|
||||
Reference in New Issue
Block a user