use the middlewares

This commit is contained in:
Hamidreza Ghavami 2023-05-31 01:24:18 +04:30
parent 8170b65db4
commit ea7fe09c27
No known key found for this signature in database
GPG Key ID: 402C6797325182D9
2 changed files with 16 additions and 37 deletions

View File

@ -7,10 +7,10 @@ import (
"net" "net"
"net/http" "net/http"
"strconv" "strconv"
"strings"
"x-ui/config" "x-ui/config"
"x-ui/logger" "x-ui/logger"
"x-ui/util/common" "x-ui/util/common"
"x-ui/web/middleware"
"x-ui/web/network" "x-ui/web/network"
"x-ui/web/service" "x-ui/web/service"
@ -58,18 +58,7 @@ func (s *Server) initRouter() (*gin.Engine, error) {
} }
if subDomain != "" { if subDomain != "" {
validateDomain := func(c *gin.Context) { engine.Use(middleware.DomainValidatorMiddleware(subDomain))
host := strings.Split(c.Request.Host, ":")[0]
if host != subDomain {
c.AbortWithStatus(http.StatusForbidden)
return
}
c.Next()
}
engine.Use(validateDomain)
} }
g := engine.Group(subPath) g := engine.Group(subPath)
@ -116,11 +105,13 @@ func (s *Server) Start() (err error) {
if err != nil { if err != nil {
return err return err
} }
listenAddr := net.JoinHostPort(listen, strconv.Itoa(port)) listenAddr := net.JoinHostPort(listen, strconv.Itoa(port))
listener, err := net.Listen("tcp", listenAddr) listener, err := net.Listen("tcp", listenAddr)
if err != nil { if err != nil {
return err return err
} }
if certFile != "" || keyFile != "" { if certFile != "" || keyFile != "" {
cert, err := tls.LoadX509KeyPair(certFile, keyFile) cert, err := tls.LoadX509KeyPair(certFile, keyFile)
if err != nil { if err != nil {
@ -168,4 +159,4 @@ func (s *Server) Stop() error {
func (s *Server) GetCtx() context.Context { func (s *Server) GetCtx() context.Context {
return s.ctx return s.ctx
} }

View File

@ -19,6 +19,7 @@ import (
"x-ui/web/controller" "x-ui/web/controller"
"x-ui/web/job" "x-ui/web/job"
"x-ui/web/locale" "x-ui/web/locale"
"x-ui/web/middleware"
"x-ui/web/network" "x-ui/web/network"
"x-ui/web/service" "x-ui/web/service"
@ -144,28 +145,6 @@ func (s *Server) getHtmlTemplate(funcMap template.FuncMap) (*template.Template,
return t, nil return t, nil
} }
func redirectMiddleware(basePath string) gin.HandlerFunc {
return func(c *gin.Context) {
// Redirect from old '/xui' path to '/panel'
path := c.Request.URL.Path
redirects := map[string]string{
"panel/API": "panel/api",
"xui/API": "panel/api",
"xui": "panel",
}
for from, to := range redirects {
from, to = basePath+from, basePath+to
if strings.HasPrefix(path, from) {
newPath := to + path[len(from):]
c.Redirect(http.StatusMovedPermanently, newPath)
c.Abort()
return
}
}
c.Next()
}
}
func (s *Server) initRouter() (*gin.Engine, error) { func (s *Server) initRouter() (*gin.Engine, error) {
if config.IsDebug() { if config.IsDebug() {
gin.SetMode(gin.DebugMode) gin.SetMode(gin.DebugMode)
@ -177,6 +156,15 @@ func (s *Server) initRouter() (*gin.Engine, error) {
engine := gin.Default() engine := gin.Default()
webDomain, err := s.settingService.GetWebDomain()
if err != nil {
return nil, err
}
if webDomain != "" {
engine.Use(middleware.DomainValidatorMiddleware(webDomain))
}
secret, err := s.settingService.GetSecret() secret, err := s.settingService.GetSecret()
if err != nil { if err != nil {
return nil, err return nil, err
@ -233,7 +221,7 @@ func (s *Server) initRouter() (*gin.Engine, error) {
} }
// Apply the redirect middleware (`/xui` to `/panel`) // Apply the redirect middleware (`/xui` to `/panel`)
engine.Use(redirectMiddleware(basePath)) engine.Use(middleware.RedirectMiddleware(basePath))
g := engine.Group(basePath) g := engine.Group(basePath)