diff --git a/app/app-context.go b/app/app-context.go index cc1f03e..71f8e8a 100644 --- a/app/app-context.go +++ b/app/app-context.go @@ -3,6 +3,7 @@ package app import ( "context" "fmt" + "net/http" "git.farahty.com/nimer/go-mongo/models" ) @@ -16,8 +17,16 @@ var ( StatusKey = &contextKey{"status"} ExpiryKey = &contextKey{"expiry"} LoadersKey = &contextKey{"dataloaders"} + WriterKye = &contextKey{"writer"} ) +func WriterFor(ctx context.Context) (*http.ResponseWriter, error) { + if writer, ok := ctx.Value(WriterKye).(*http.ResponseWriter); ok { + return writer, nil + } + return nil, fmt.Errorf("no writer found in context") +} + // Retrieves the current user from the context func CurrentUser(ctx context.Context) (*models.UserJWT, error) { user, _ := ctx.Value(UserKey).(*models.UserJWT) diff --git a/app/helpers.go b/app/helpers.go index 2ab7cf7..b21e5ed 100644 --- a/app/helpers.go +++ b/app/helpers.go @@ -30,7 +30,7 @@ func getTokenFromHeader(r *http.Request) (string, error) { func getTokenFromCookie(r *http.Request) (string, error) { - cookie, err := r.Cookie("app-access-token") + cookie, err := r.Cookie("access_token") if err != nil { return "", fmt.Errorf("there is no authorization cookie provided") diff --git a/app/middlewares.go b/app/middlewares.go index c74d4de..84cdd00 100644 --- a/app/middlewares.go +++ b/app/middlewares.go @@ -3,6 +3,7 @@ package app import ( "context" "errors" + "net/http" "github.com/99designs/gqlgen/graphql" @@ -17,18 +18,33 @@ func ExpiryMiddleware(ctx context.Context, next graphql.ResponseHandler) *graphq return next(ctx) } +// add response writer to context for GraphQL resolvers +func WriterMiddleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + ctx := context.WithValue(r.Context(), WriterKye, &rw) + next.ServeHTTP(rw, r.WithContext(ctx)) + }) +} + // AuthMiddleware parses JWT token and injects user context for HTTP requests func AuthMiddleware(next http.Handler) http.Handler { return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { - tokenStr, err := getTokenFromHeader(r) - if err != nil { - ctx := SetStatus(r.Context(), err.Error()) + headerToken, headerErr := getTokenFromHeader(r) + cookieToken, cookieErr := getTokenFromCookie(r) + + if headerErr != nil && cookieErr != nil { + ctx := SetStatus(r.Context(), headerErr.Error()) next.ServeHTTP(rw, r.WithContext(ctx)) return } - user, err := getUserFromToken(tokenStr) + token := headerToken + if token == "" { + token = cookieToken + } + + user, err := getUserFromToken(token) ctx := r.Context() if err != nil { diff --git a/main.go b/main.go index 42e08e5..e86ad8c 100644 --- a/main.go +++ b/main.go @@ -2,98 +2,119 @@ package main import ( "context" + "log" + "net/http" "os" "os/signal" "time" - "log" - "net/http" - "git.farahty.com/nimer/go-mongo/app" + "github.com/fatih/color" "github.com/joho/godotenv" "github.com/redis/go-redis/v9" ) func main() { + // Setup cancelable root context + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() - color.Yellow("Starting server ...\n") - - if _, exists := os.LookupEnv("MONGO_URI"); !exists { - err := godotenv.Load() - if err != nil { - log.Fatal("Error loading .env file\n") - } - - color.Green("✅ .env loaded\n") - } - - port := os.Getenv("PORT") - - if cancel, err := app.Connect(); err != nil { - cancel() - log.Fatal(err) - } else { - defer func() { - color.Red("❌ Database Connection Closed\n") - cancel() - }() - } - + // Panic recovery defer func() { - if err := app.Mongo.Disconnect(context.Background()); err != nil { - log.Fatal("MogoDB Errors" + err.Error()) + if r := recover(); r != nil { + color.Red("🔴 Panic occurred: %v\n", r) + os.Exit(1) } }() - color.Green("✅ Connected to Database successfully\n") - if err := app.LoadAuthorizer(context.Background()); err != nil { - log.Fatal("Authorizer Errors : " + err.Error()) + color.Yellow("🚀 Starting server ...\n") + + // Load .env if needed + if _, exists := os.LookupEnv("MONGO_URI"); !exists { + if err := godotenv.Load(); err != nil { + log.Fatal("🔴 Failed to load .env file: ", err) + } + color.Green("✅ .env file loaded\n") } + // Validate environment variables + requiredEnvs := []string{"PORT", "MONGO_URI", "REDIS_HOST", "REDIS_PORT"} + for _, key := range requiredEnvs { + if os.Getenv(key) == "" { + log.Fatalf("🔴 Required environment variable %s is missing", key) + } + } + port := os.Getenv("PORT") + + // Connect to Mongo + dbCancel, err := app.Connect() + if err != nil { + log.Fatalf("🔴 MongoDB connection error: %v", err) + } + defer func() { + color.Red("❌ Closing MongoDB connection...\n") + dbCancel() + if err := app.Mongo.Disconnect(ctx); err != nil { + log.Fatal("🔴 MongoDB disconnection error: ", err) + } + }() + color.Green("✅ Connected to MongoDB successfully\n") + + // Load authorization policies using root context + if err := app.LoadAuthorizer(ctx); err != nil { + log.Fatal("🔴 Authorizer error: ", err) + } color.Green("✅ Authorization policies loaded successfully\n") + // Redis redisClient := redis.NewClient(&redis.Options{ Addr: os.Getenv("REDIS_HOST") + ":" + os.Getenv("REDIS_PORT"), - Password: os.Getenv("REDIS_PASSWORD"), // no password set + Password: os.Getenv("REDIS_PASSWORD"), + DB: 0, }) - if _, err := redisClient.Ping(context.Background()).Result(); err != nil { - log.Fatal("Redis Error : " + err.Error()) + if _, err := redisClient.Ping(ctx).Result(); err != nil { + log.Fatal("🔴 Redis connection error: ", err) } - - defer redisClient.Close() - + defer func() { + color.Red("❌ Closing Redis connection...\n") + _ = redisClient.Close() + }() color.Green("✅ Connected to Redis cache successfully\n") + // Create GraphQL server graphqlServer := createGraphqlServer(redisClient) - color.Green("🚀 Server Started at http://localhost:" + port + "\n") - - //http.ListenAndServe(":"+port, createRouter(graphqlServer)) - + // Start HTTP server server := &http.Server{ Addr: ":" + port, - WriteTimeout: time.Second * 30, - ReadTimeout: time.Second * 30, - IdleTimeout: time.Second * 30, Handler: createRouter(graphqlServer), + ReadTimeout: 30 * time.Second, + WriteTimeout: 30 * time.Second, + IdleTimeout: 30 * time.Second, } - go server.ListenAndServe() + go func() { + color.Green("🌐 Server listening at http://localhost:%s\n", port) + if err := server.ListenAndServe(); err != nil && err != http.ErrServerClosed { + log.Fatalf("🔴 Server failed: %v", err) + } + }() - // Wait for interrupt signal to gracefully shut down the server with - // a timeout of 15 seconds. + // Graceful shutdown quit := make(chan os.Signal, 1) signal.Notify(quit, os.Interrupt) - <-quit - color.Yellow(" 🎬 Start Shutdown Signal ... ") - ctx, cancelShutdown := context.WithTimeout(context.Background(), 15*time.Second) - defer cancelShutdown() - if err := server.Shutdown(ctx); err != nil { - log.Fatal("Server Shutdown:", err) + color.Yellow("🟡 Shutdown signal received, initiating cleanup...") + + // Cancel root context and wait for graceful shutdown + shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 15*time.Second) + defer shutdownCancel() + + if err := server.Shutdown(shutdownCtx); err != nil { + log.Fatalf("🔴 Server forced to shutdown: %v", err) } - color.Red("❌ Server Exiting") + color.Green("✅ Server shutdown completed gracefully") } diff --git a/resolvers/auth.resolvers.go b/resolvers/auth.resolvers.go index 751f5c5..8cfd797 100644 --- a/resolvers/auth.resolvers.go +++ b/resolvers/auth.resolvers.go @@ -14,6 +14,7 @@ import ( // Login is the resolver for the login field. func (r *mutationResolver) Login(ctx context.Context, input models.LoginInput) (*models.LoginResponse, error) { + return authService.Login(ctx, &input) } diff --git a/router.go b/router.go index 2c16324..6f4f81c 100644 --- a/router.go +++ b/router.go @@ -41,6 +41,8 @@ func createRouter(graphqlServer http.Handler) chi.Router { // Custom middleware for Auth router.Use(app.AuthMiddleware) + router.Use(app.WriterMiddleware) + // REST routes router.Mount("/users", controllers.UserRouter()) diff --git a/services/auth/sucess_login.go b/services/auth/sucess_login.go index 37fc73e..524d6a6 100644 --- a/services/auth/sucess_login.go +++ b/services/auth/sucess_login.go @@ -3,6 +3,8 @@ package authService import ( "context" "encoding/hex" + "net/http" + "os" "git.farahty.com/nimer/go-mongo/app" @@ -61,6 +63,21 @@ func successLogin(ctx context.Context, user *models.User) (*models.LoginResponse return nil, err } + w, err := app.WriterFor(ctx) + + if err != nil { + return nil, err + } + + http.SetCookie(*w, &http.Cookie{ + Name: "access_token", + Value: *accessToken, + Path: "/", + HttpOnly: true, + Secure: true, + SameSite: http.SameSiteLaxMode, + }) + return &models.LoginResponse{ AccessToken: *accessToken, RefreshToken: *refreshToken,