diff --git a/app/helpers.go b/app/helpers.go index 10331f3..2ab7cf7 100644 --- a/app/helpers.go +++ b/app/helpers.go @@ -2,6 +2,7 @@ package app import ( "fmt" + "net/http" "os" "strings" @@ -10,7 +11,9 @@ import ( "github.com/mitchellh/mapstructure" ) -func getTokenFromHeader(authHeader string) (string, error) { +func getTokenFromHeader(r *http.Request) (string, error) { + + authHeader := r.Header.Get("Authorization") if authHeader == "" { return "", fmt.Errorf("there is no authorization header provided") @@ -25,6 +28,17 @@ func getTokenFromHeader(authHeader string) (string, error) { return strings.TrimSpace(authSlice[1]), nil } +func getTokenFromCookie(r *http.Request) (string, error) { + + cookie, err := r.Cookie("app-access-token") + + if err != nil { + return "", fmt.Errorf("there is no authorization cookie provided") + } + + return strings.TrimSpace(cookie.Value), nil +} + func getUserFromToken(tokenString string) (*models.UserJWT, error) { token, err := jwt.Parse(tokenString, func(t *jwt.Token) (interface{}, error) { diff --git a/app/middlewares.go b/app/middlewares.go index 6bacc9f..c74d4de 100644 --- a/app/middlewares.go +++ b/app/middlewares.go @@ -20,8 +20,8 @@ func ExpiryMiddleware(ctx context.Context, next graphql.ResponseHandler) *graphq // AuthMiddleware parses JWT token and injects user context for HTTP requests func AuthMiddleware(next http.Handler) http.Handler { return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { - authHeader := r.Header.Get("Authorization") - tokenStr, err := getTokenFromHeader(authHeader) + + tokenStr, err := getTokenFromHeader(r) if err != nil { ctx := SetStatus(r.Context(), err.Error()) next.ServeHTTP(rw, r.WithContext(ctx)) diff --git a/helpers/cache.go b/helpers/cache.go new file mode 100644 index 0000000..e2875f8 --- /dev/null +++ b/helpers/cache.go @@ -0,0 +1,33 @@ +package helpers + +import ( + "context" + "log" + "time" + + "github.com/redis/go-redis/v9" +) + +type Cache struct { + client redis.UniversalClient + ttl time.Duration + prefix string +} + +func NewCache(client redis.UniversalClient, ttl time.Duration, prefix string) (*Cache, error) { + return &Cache{client: client, ttl: ttl, prefix: prefix}, nil +} + +func (c *Cache) Add(ctx context.Context, key string, value string) { + log.Printf("Adding key %s to cache with value %s", key, value) + c.client.Set(ctx, c.prefix+key, value, c.ttl) +} + +func (c *Cache) Get(ctx context.Context, key string) (string, bool) { + s, err := c.client.Get(ctx, c.prefix+key).Result() + if err != nil { + return "", false + } + log.Printf("Retrieved key %s from cache with value %s", key, s) + return s, true +} diff --git a/models/todo.go b/models/todo.go index ada9eca..18a769f 100644 --- a/models/todo.go +++ b/models/todo.go @@ -1,7 +1,5 @@ package models -import "go.mongodb.org/mongo-driver/bson/primitive" - func (m Todo) Validate() []error { errors := []error{} @@ -13,10 +11,6 @@ func (m Todo) Validate() []error { return errors } -func (m Todo) getId() primitive.ObjectID { - return m.ID -} - func (t *Todo) validate() []error { return nil diff --git a/router.go b/router.go index 03d495f..6e6e831 100644 --- a/router.go +++ b/router.go @@ -1,6 +1,7 @@ package main import ( + "log" "net/http" "os" "time" @@ -9,6 +10,7 @@ import ( "git.farahty.com/nimer/go-mongo/controllers" "git.farahty.com/nimer/go-mongo/directives" "git.farahty.com/nimer/go-mongo/generated" + "git.farahty.com/nimer/go-mongo/helpers" "github.com/99designs/gqlgen/graphql/handler" "github.com/99designs/gqlgen/graphql/handler/extension" "github.com/99designs/gqlgen/graphql/handler/transport" @@ -50,6 +52,13 @@ func createRouter(graphqlServer http.Handler) chi.Router { } func createGraphqlServer(redisClient *redis.Client) http.Handler { + + cache, err := helpers.NewCache(redisClient, 24*time.Hour, "apq:") + + if err != nil { + log.Fatalf("cannot create APQ redis cache: %v", err) + } + // Setup gqlgen with resolvers and Redis client schema := generated.Config{ Resolvers: &resolvers.Resolver{ @@ -78,14 +87,7 @@ func createGraphqlServer(redisClient *redis.Client) http.Handler { srv.AddTransport(transport.POST{}) srv.AddTransport(transport.MultipartForm{}) - // Optional: Enable persisted queries or caching - // srv.Use(extension.AutomaticPersistedQuery{ - // Cache: lru.New(100), - // }) - // srv.SetQueryCache(lru.New(1000)) - - // Enable introspection for Playground - srv.Use(extension.Introspection{}) + srv.Use(extension.AutomaticPersistedQuery{Cache: cache}) // Apply global middleware srv.AroundRootFields(app.RootFieldsAuthorizer) // Check for @auth at root fields