From fe762797067ebd78c9ca80f91d5703e1d86361e7 Mon Sep 17 00:00:00 2001 From: Nimer Farahty Date: Sun, 8 Jun 2025 16:21:29 +0300 Subject: [PATCH] structure enchantment --- app/app-context.go | 26 +++- app/auth.go | 15 +- app/config.go | 90 +++++++++++ app/database.go | 5 +- app/loaders.go | 17 -- app/middlewares.go | 32 +++- app/redis-client.go | 24 +++ directives/auth.go | 2 +- directives/has-role.go | 20 +++ generated/generated.go | 256 +++++++++++++------------------ gql/auth.gql | 7 + gql/base.gql | 16 ++ gql/user.gql | 6 +- gqlgen.yml | 2 +- main.go | 36 ++--- models/models_gen.go | 59 ++++++- resolvers/auth.resolvers.go | 4 +- resolvers/resolver.go | 11 +- resolvers/todo.resolvers.go | 10 +- router.go | 18 +-- services/auth/auth.go | 38 ++++- services/auth/create_token.go | 29 ---- services/auth/sucess_login.go | 44 +++--- {helpers => utils}/cache.go | 2 +- {helpers => utils}/encrypt.go | 2 +- {app => utils}/scalers.go | 4 +- app/helpers.go => utils/token.go | 39 +++-- 27 files changed, 511 insertions(+), 303 deletions(-) create mode 100644 app/config.go create mode 100644 app/redis-client.go create mode 100644 directives/has-role.go delete mode 100644 services/auth/create_token.go rename {helpers => utils}/cache.go (97%) rename {helpers => utils}/encrypt.go (98%) rename {app => utils}/scalers.go (84%) rename app/helpers.go => utils/token.go (56%) diff --git a/app/app-context.go b/app/app-context.go index 71f8e8a..1001607 100644 --- a/app/app-context.go +++ b/app/app-context.go @@ -16,15 +16,31 @@ var ( UserKey = &contextKey{"user"} StatusKey = &contextKey{"status"} ExpiryKey = &contextKey{"expiry"} - LoadersKey = &contextKey{"dataloaders"} + LoadersKey = &contextKey{"dataLoaders"} WriterKye = &contextKey{"writer"} + RequestKey = &contextKey{"request"} ) -func WriterFor(ctx context.Context) (*http.ResponseWriter, error) { - if writer, ok := ctx.Value(WriterKye).(*http.ResponseWriter); ok { - return writer, nil +// LoaderFor retrieves the dataLoaders from context +func LoaderFor(ctx context.Context) *Loaders { + if loaders, ok := ctx.Value(LoadersKey).(*Loaders); ok { + return loaders } - return nil, fmt.Errorf("no writer found in context") + panic("dataloader not found in context") +} + +func WriterFor(ctx context.Context) *http.ResponseWriter { + if writer, ok := ctx.Value(WriterKye).(*http.ResponseWriter); ok { + return writer + } + panic("no writer found in context") +} + +func RequestFor(ctx context.Context) *http.Request { + if req, ok := ctx.Value(RequestKey).(*http.Request); ok { + return req + } + panic("no request found in context") } // Retrieves the current user from the context diff --git a/app/auth.go b/app/auth.go index 1746d64..ae3a3ed 100644 --- a/app/auth.go +++ b/app/auth.go @@ -5,9 +5,9 @@ import ( "errors" "fmt" "log" - "os" "git.farahty.com/nimer/go-mongo/models" + "git.farahty.com/nimer/go-mongo/utils" "github.com/99designs/gqlgen/graphql" "github.com/99designs/gqlgen/graphql/handler/transport" "github.com/casbin/casbin/v2" @@ -22,8 +22,7 @@ var ( ) func LoadAuthorizer(ctx context.Context) error { - a, err := mongodbadapter.NewAdapterWithClientOption(options.Client().ApplyURI(os.Getenv("MONGO_URI")), os.Getenv("MONGO_DB")) - + a, err := mongodbadapter.NewAdapterWithClientOption(options.Client().ApplyURI(Config.MongoURI), Config.MongoDB) if err != nil { return err } @@ -32,18 +31,15 @@ func LoadAuthorizer(ctx context.Context) error { return err } else { Authorizer = enforcer - } a.AddPolicy("mongodb", "p", []string{"*", "login", "mutation"}) a.AddPolicy("mongodb", "p", []string{"todos-admin", "createTodo", "mutation"}) a.AddPolicy("mongodb", "p", []string{"todos-admin", "todos", "query"}) a.AddPolicy("mongodb", "p", []string{"todos-admin", "todo", "query"}) - a.AddPolicy("mongodb", "p", []string{"category-admin", "createCategory", "mutation"}) a.AddPolicy("mongodb", "p", []string{"category-admin", "categories", "query"}) a.AddPolicy("mongodb", "p", []string{"category-admin", "category", "query"}) - a.AddPolicy("mongodb", "p", []string{"users-admin", "users", "query"}) a.AddPolicy("mongodb", "p", []string{"users-admin", "createUser", "mutation"}) @@ -51,8 +47,8 @@ func LoadAuthorizer(ctx context.Context) error { Authorizer.AddGroupingPolicy("admin", "todos-admin") Authorizer.AddGroupingPolicy("admin", "category-admin") - email := os.Getenv("ADMIN_EMAIL") - password, _ := models.MakeHash(os.Getenv("ADMIN_PASSWORD")) + email := Config.AdminEmail + password, _ := models.MakeHash(Config.AdminPass) if _, err := FindOne[models.User](ctx, "users", bson.M{"email": email}); err != nil { log.Println("Creating admin user") @@ -67,7 +63,6 @@ func LoadAuthorizer(ctx context.Context) error { } Authorizer.AddRoleForUser(admin.ID.Hex(), "admin") - } return nil @@ -82,7 +77,7 @@ func AuthorizeWebSocket(ctx context.Context, initPayload transport.InitPayload) return ctx, &initPayload, nil } - user, err := getUserFromToken(token) + user, err := utils.GetUserFromToken[models.UserJWT](token, Config.AccessSecret) if err != nil { ctx = SetStatus(ctx, err.Error()) diff --git a/app/config.go b/app/config.go new file mode 100644 index 0000000..fb885ae --- /dev/null +++ b/app/config.go @@ -0,0 +1,90 @@ +package app + +import ( + "fmt" + "log" + "os" + + "github.com/joho/godotenv" +) + +var ( + Config *ConfigStruct +) + +type ConfigStruct struct { + Port string + MongoURI string + MongoDB string + RedisHost string + RedisPort string + RedisPass string + AdminEmail string + AdminPass string + RefreshSecret string + RefreshExpiry string + AccessSecret string + AccessExpiry string + Env string +} + +// LoadConfig loads environment variables into Config struct and validates required fields. +func LoadConfig() error { + if _, exists := os.LookupEnv("MONGO_URI"); !exists { + if err := godotenv.Load(); err != nil { + log.Fatal("🔴 Failed to load .env file: ", err) + } + } + + Config = &ConfigStruct{ + Port: getEnv("PORT", "8080"), + MongoURI: os.Getenv("MONGO_URI"), + MongoDB: getEnv("MONGO_DB", "test"), + RedisHost: os.Getenv("REDIS_HOST"), + RedisPort: os.Getenv("REDIS_PORT"), + RedisPass: os.Getenv("REDIS_PASSWORD"), + AdminEmail: os.Getenv("ADMIN_EMAIL"), + AdminPass: os.Getenv("ADMIN_PASSWORD"), + RefreshSecret: os.Getenv("REFRESH_SECRET"), + RefreshExpiry: os.Getenv("REFRESH_EXPIRY"), + AccessSecret: os.Getenv("ACCESS_SECRET"), + AccessExpiry: os.Getenv("ACCESS_EXPIRY"), + Env: os.Getenv("ENV"), + } + + missing := "" + if Config.Port == "" { + missing += "PORT " + } + if Config.MongoURI == "" { + missing += "MONGO_URI " + } + if Config.RedisHost == "" { + missing += "REDIS_HOST " + } + if Config.RedisPort == "" { + missing += "REDIS_PORT " + } + if Config.AccessSecret == "" { + missing += "ACCESS_SECRET " + } + if Config.RefreshSecret == "" { + missing += "REFRESH_SECRET " + } + if missing != "" { + return fmt.Errorf("missing required environment variables: %s", missing) + } + return nil +} + +// GetConfig returns the global config +func GetConfig() *ConfigStruct { + return Config +} + +func getEnv(key, def string) string { + if v := os.Getenv(key); v != "" { + return v + } + return def +} diff --git a/app/database.go b/app/database.go index 8ad92e8..c4fd70f 100644 --- a/app/database.go +++ b/app/database.go @@ -3,7 +3,6 @@ package app import ( "context" "fmt" - "os" "time" "github.com/fatih/color" @@ -22,7 +21,7 @@ func Connect() (context.CancelFunc, error) { ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) var err error - Mongo, err = mongo.Connect(ctx, options.Client().ApplyURI(os.Getenv("MONGO_URI"))) + Mongo, err = mongo.Connect(ctx, options.Client().ApplyURI(Config.MongoURI)) if err != nil { color.Red("❌ Database connection failed\n" + err.Error()) return cancel, err @@ -45,7 +44,7 @@ func Disconnect(ctx context.Context) error { // Collection returns a MongoDB collection func Collection(name string) *mongo.Collection { - return Mongo.Database(os.Getenv("MONGO_DB")).Collection(name) + return Mongo.Database(Config.MongoDB).Collection(name) } // Find returns all matching documents from a collection diff --git a/app/loaders.go b/app/loaders.go index 226f39e..bf9b54f 100644 --- a/app/loaders.go +++ b/app/loaders.go @@ -3,7 +3,6 @@ package app import ( "context" "fmt" - "net/http" "git.farahty.com/nimer/go-mongo/models" "github.com/graph-gophers/dataloader" @@ -27,22 +26,6 @@ func NewLoaders() *Loaders { } } -// Middleware injects dataloaders into context for each HTTP request -func LoaderMiddleware(loaders *Loaders, next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - ctxWithLoaders := context.WithValue(r.Context(), LoadersKey, loaders) - next.ServeHTTP(w, r.WithContext(ctxWithLoaders)) - }) -} - -// LoaderFor retrieves the dataloaders from context -func LoaderFor(ctx context.Context) *Loaders { - if loaders, ok := ctx.Value(LoadersKey).(*Loaders); ok { - return loaders - } - panic("dataloader not found in context") -} - // CreateBatch creates a generic batched loader for a MongoDB collection func CreateBatch[T models.Identifiable](coll string) dataloader.BatchFunc { return func(ctx context.Context, keys dataloader.Keys) []*dataloader.Result { diff --git a/app/middlewares.go b/app/middlewares.go index 84cdd00..7e8b91e 100644 --- a/app/middlewares.go +++ b/app/middlewares.go @@ -6,12 +6,30 @@ import ( "net/http" + "git.farahty.com/nimer/go-mongo/models" + "git.farahty.com/nimer/go-mongo/utils" "github.com/99designs/gqlgen/graphql" "github.com/golang-jwt/jwt/v4" ) +// Middleware injects dataloaders into context for each HTTP request +func LoaderMiddleware(loaders *Loaders, next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ctxWithLoaders := context.WithValue(r.Context(), LoadersKey, loaders) + next.ServeHTTP(w, r.WithContext(ctxWithLoaders)) + }) +} + // ExpiryMiddleware checks for expired tokens in GraphQL resolvers func ExpiryMiddleware(ctx context.Context, next graphql.ResponseHandler) *graphql.Response { + + object := graphql.GetRootFieldContext(ctx) + if object != nil { + println("ExpiryMiddleware: checking token expiry in GraphQL resolver") + } else { + println("ExpiryMiddleware: checking token expiry in HTTP request") + } + if IsTokenExpired(ctx) { return graphql.ErrorResponse(ctx, "token expired") } @@ -26,12 +44,20 @@ func WriterMiddleware(next http.Handler) http.Handler { }) } +// add request to context for GraphQL resolvers +func RequestMiddleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + ctx := context.WithValue(r.Context(), RequestKey, r) + next.ServeHTTP(rw, r.WithContext(ctx)) + }) +} + // AuthMiddleware parses JWT token and injects user context for HTTP requests func AuthMiddleware(next http.Handler) http.Handler { return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { - headerToken, headerErr := getTokenFromHeader(r) - cookieToken, cookieErr := getTokenFromCookie(r) + headerToken, headerErr := utils.GetTokenFromHeader(r) + cookieToken, cookieErr := utils.GetTokenFromCookie(r) if headerErr != nil && cookieErr != nil { ctx := SetStatus(r.Context(), headerErr.Error()) @@ -44,7 +70,7 @@ func AuthMiddleware(next http.Handler) http.Handler { token = cookieToken } - user, err := getUserFromToken(token) + user, err := utils.GetUserFromToken[models.UserJWT](token, Config.AccessSecret) ctx := r.Context() if err != nil { diff --git a/app/redis-client.go b/app/redis-client.go new file mode 100644 index 0000000..9a31e3c --- /dev/null +++ b/app/redis-client.go @@ -0,0 +1,24 @@ +package app + +import ( + "context" + + "github.com/redis/go-redis/v9" +) + +var ( + RedisClient redis.UniversalClient +) + +// InitRedis initializes the Redis client and assigns it to RedisClient +func InitRedis(ctx context.Context) error { + RedisClient = redis.NewClient(&redis.Options{ + Addr: Config.RedisHost + ":" + Config.RedisPort, + Password: Config.RedisPass, + DB: 0, + }) + if _, err := RedisClient.Ping(ctx).Result(); err != nil { + return err + } + return nil +} diff --git a/directives/auth.go b/directives/auth.go index 5e89ab9..506ca8c 100644 --- a/directives/auth.go +++ b/directives/auth.go @@ -8,7 +8,7 @@ import ( "github.com/99designs/gqlgen/graphql" ) -func Auth(ctx context.Context, obj interface{}, next graphql.Resolver) (res interface{}, err error) { +func Auth(ctx context.Context, obj any, next graphql.Resolver) (res any, err error) { if _, err := app.CurrentUser(ctx); err != nil { return nil, fmt.Errorf("access denied, %s", err.Error()) diff --git a/directives/has-role.go b/directives/has-role.go new file mode 100644 index 0000000..ad574e5 --- /dev/null +++ b/directives/has-role.go @@ -0,0 +1,20 @@ +package directives + +import ( + "context" + "fmt" + + "git.farahty.com/nimer/go-mongo/app" + "git.farahty.com/nimer/go-mongo/models" + "github.com/99designs/gqlgen/graphql" +) + +func HasRole(ctx context.Context, obj any, next graphql.Resolver, role models.Role) (res any, err error) { + + if _, err := app.CurrentUser(ctx); err != nil { + return nil, fmt.Errorf("access denied, %s", err.Error()) + } + + return next(ctx) + +} diff --git a/generated/generated.go b/generated/generated.go index b68c014..766e743 100644 --- a/generated/generated.go +++ b/generated/generated.go @@ -13,8 +13,8 @@ import ( "sync/atomic" "time" - "git.farahty.com/nimer/go-mongo/app" "git.farahty.com/nimer/go-mongo/models" + "git.farahty.com/nimer/go-mongo/utils" "github.com/99designs/gqlgen/graphql" "github.com/99designs/gqlgen/graphql/introspection" gqlparser "github.com/vektah/gqlparser/v2" @@ -51,7 +51,8 @@ type ResolverRoot interface { } type DirectiveRoot struct { - Auth func(ctx context.Context, obj any, next graphql.Resolver) (res any, err error) + Auth func(ctx context.Context, obj any, next graphql.Resolver) (res any, err error) + HasRole func(ctx context.Context, obj any, next graphql.Resolver, role models.Role) (res any, err error) } type ComplexityRoot struct { @@ -119,10 +120,8 @@ type ComplexityRoot struct { User struct { Email func(childComplexity int) int ID func(childComplexity int) int - Password func(childComplexity int) int Phone func(childComplexity int) int Status func(childComplexity int) int - Token func(childComplexity int) int Type func(childComplexity int) int Verified func(childComplexity int) int } @@ -507,13 +506,6 @@ func (e *executableSchema) Complexity(ctx context.Context, typeName, field strin return e.complexity.User.ID(childComplexity), true - case "User.password": - if e.complexity.User.Password == nil { - break - } - - return e.complexity.User.Password(childComplexity), true - case "User.phone": if e.complexity.User.Phone == nil { break @@ -528,13 +520,6 @@ func (e *executableSchema) Complexity(ctx context.Context, typeName, field strin return e.complexity.User.Status(childComplexity), true - case "User.token": - if e.complexity.User.Token == nil { - break - } - - return e.complexity.User.Token(childComplexity), true - case "User.type": if e.complexity.User.Type == nil { break @@ -678,6 +663,13 @@ func (ec *executionContext) introspectType(name string) (*introspection.Type, er var sources = []*ast.Source{ {Name: "../gql/auth.gql", Input: `directive @auth on FIELD_DEFINITION +directive @hasRole(role: Role!) on FIELD_DEFINITION + +enum Role { + ADMIN + USER +} + type LoginResponse { user: User! accessToken: String! @@ -693,15 +685,31 @@ extend type Mutation { login(input: LoginInput!): LoginResponse! } `, BuiltIn: false}, - {Name: "../gql/base.gql", Input: `directive @goField( + {Name: "../gql/base.gql", Input: `directive @goModel( + model: String + models: [String!] + forceGenerate: Boolean +) on OBJECT | INPUT_OBJECT | SCALAR | ENUM | INTERFACE | UNION + +directive @goField( forceResolver: Boolean name: String + omittable: Boolean + type: String ) on INPUT_FIELD_DEFINITION | FIELD_DEFINITION + directive @goTag( key: String! value: String ) on INPUT_FIELD_DEFINITION | FIELD_DEFINITION +directive @goExtraField( + name: String + type: String! + overrideTags: String + description: String +) repeatable on OBJECT | INPUT_OBJECT + scalar Time interface Base { @@ -813,15 +821,15 @@ extend type Subscription { onTodo: Todo } `, BuiltIn: false}, - {Name: "../gql/user.gql", Input: `type User { + {Name: "../gql/user.gql", Input: `type User + @goExtraField(name: "Password", type: "*string") + @goExtraField(name: "Token", type: "*string") { id: ID! phone: String email: String type: String status: String verified: Boolean @goField(forceResolver: true) - password: String - token: String } input CreateUserInput { @@ -846,6 +854,34 @@ var parsedSchema = gqlparser.MustLoadSchema(sources...) // region ***************************** args.gotpl ***************************** +func (ec *executionContext) dir_hasRole_args(ctx context.Context, rawArgs map[string]any) (map[string]any, error) { + var err error + args := map[string]any{} + arg0, err := ec.dir_hasRole_argsRole(ctx, rawArgs) + if err != nil { + return nil, err + } + args["role"] = arg0 + return args, nil +} +func (ec *executionContext) dir_hasRole_argsRole( + ctx context.Context, + rawArgs map[string]any, +) (models.Role, error) { + if _, ok := rawArgs["role"]; !ok { + var zeroVal models.Role + return zeroVal, nil + } + + ctx = graphql.WithPathContext(ctx, graphql.NewPathWithField("role")) + if tmp, ok := rawArgs["role"]; ok { + return ec.unmarshalNRole2gitᚗfarahtyᚗcomᚋnimerᚋgoᚑmongoᚋmodelsᚐRole(ctx, tmp) + } + + var zeroVal models.Role + return zeroVal, nil +} + func (ec *executionContext) field_Mutation_createCategory_args(ctx context.Context, rawArgs map[string]any) (map[string]any, error) { var err error args := map[string]any{} @@ -1556,10 +1592,6 @@ func (ec *executionContext) fieldContext_Category_createdBy(_ context.Context, f return ec.fieldContext_User_status(ctx, field) case "verified": return ec.fieldContext_User_verified(ctx, field) - case "password": - return ec.fieldContext_User_password(ctx, field) - case "token": - return ec.fieldContext_User_token(ctx, field) } return nil, fmt.Errorf("no field named %q was found under type User", field.Name) }, @@ -1662,10 +1694,6 @@ func (ec *executionContext) fieldContext_Category_updatedBy(_ context.Context, f return ec.fieldContext_User_status(ctx, field) case "verified": return ec.fieldContext_User_verified(ctx, field) - case "password": - return ec.fieldContext_User_password(ctx, field) - case "token": - return ec.fieldContext_User_token(ctx, field) } return nil, fmt.Errorf("no field named %q was found under type User", field.Name) }, @@ -1765,10 +1793,6 @@ func (ec *executionContext) fieldContext_Category_owner(_ context.Context, field return ec.fieldContext_User_status(ctx, field) case "verified": return ec.fieldContext_User_verified(ctx, field) - case "password": - return ec.fieldContext_User_password(ctx, field) - case "token": - return ec.fieldContext_User_token(ctx, field) } return nil, fmt.Errorf("no field named %q was found under type User", field.Name) }, @@ -1871,10 +1895,6 @@ func (ec *executionContext) fieldContext_LoginResponse_user(_ context.Context, f return ec.fieldContext_User_status(ctx, field) case "verified": return ec.fieldContext_User_verified(ctx, field) - case "password": - return ec.fieldContext_User_password(ctx, field) - case "token": - return ec.fieldContext_User_token(ctx, field) } return nil, fmt.Errorf("no field named %q was found under type User", field.Name) }, @@ -2240,10 +2260,6 @@ func (ec *executionContext) fieldContext_Mutation_createUser(ctx context.Context return ec.fieldContext_User_status(ctx, field) case "verified": return ec.fieldContext_User_verified(ctx, field) - case "password": - return ec.fieldContext_User_password(ctx, field) - case "token": - return ec.fieldContext_User_token(ctx, field) } return nil, fmt.Errorf("no field named %q was found under type User", field.Name) }, @@ -2606,10 +2622,6 @@ func (ec *executionContext) fieldContext_Query_users(_ context.Context, field gr return ec.fieldContext_User_status(ctx, field) case "verified": return ec.fieldContext_User_verified(ctx, field) - case "password": - return ec.fieldContext_User_password(ctx, field) - case "token": - return ec.fieldContext_User_token(ctx, field) } return nil, fmt.Errorf("no field named %q was found under type User", field.Name) }, @@ -3092,10 +3104,6 @@ func (ec *executionContext) fieldContext_Todo_createdBy(_ context.Context, field return ec.fieldContext_User_status(ctx, field) case "verified": return ec.fieldContext_User_verified(ctx, field) - case "password": - return ec.fieldContext_User_password(ctx, field) - case "token": - return ec.fieldContext_User_token(ctx, field) } return nil, fmt.Errorf("no field named %q was found under type User", field.Name) }, @@ -3198,10 +3206,6 @@ func (ec *executionContext) fieldContext_Todo_updatedBy(_ context.Context, field return ec.fieldContext_User_status(ctx, field) case "verified": return ec.fieldContext_User_verified(ctx, field) - case "password": - return ec.fieldContext_User_password(ctx, field) - case "token": - return ec.fieldContext_User_token(ctx, field) } return nil, fmt.Errorf("no field named %q was found under type User", field.Name) }, @@ -3301,10 +3305,6 @@ func (ec *executionContext) fieldContext_Todo_owner(_ context.Context, field gra return ec.fieldContext_User_status(ctx, field) case "verified": return ec.fieldContext_User_verified(ctx, field) - case "password": - return ec.fieldContext_User_password(ctx, field) - case "token": - return ec.fieldContext_User_token(ctx, field) } return nil, fmt.Errorf("no field named %q was found under type User", field.Name) }, @@ -3737,88 +3737,6 @@ func (ec *executionContext) fieldContext_User_verified(_ context.Context, field return fc, nil } -func (ec *executionContext) _User_password(ctx context.Context, field graphql.CollectedField, obj *models.User) (ret graphql.Marshaler) { - fc, err := ec.fieldContext_User_password(ctx, field) - if err != nil { - return graphql.Null - } - ctx = graphql.WithFieldContext(ctx, fc) - defer func() { - if r := recover(); r != nil { - ec.Error(ctx, ec.Recover(ctx, r)) - ret = graphql.Null - } - }() - resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (any, error) { - ctx = rctx // use context from middleware stack in children - return obj.Password, nil - }) - if err != nil { - ec.Error(ctx, err) - return graphql.Null - } - if resTmp == nil { - return graphql.Null - } - res := resTmp.(*string) - fc.Result = res - return ec.marshalOString2ᚖstring(ctx, field.Selections, res) -} - -func (ec *executionContext) fieldContext_User_password(_ context.Context, field graphql.CollectedField) (fc *graphql.FieldContext, err error) { - fc = &graphql.FieldContext{ - Object: "User", - Field: field, - IsMethod: false, - IsResolver: false, - Child: func(ctx context.Context, field graphql.CollectedField) (*graphql.FieldContext, error) { - return nil, errors.New("field of type String does not have child fields") - }, - } - return fc, nil -} - -func (ec *executionContext) _User_token(ctx context.Context, field graphql.CollectedField, obj *models.User) (ret graphql.Marshaler) { - fc, err := ec.fieldContext_User_token(ctx, field) - if err != nil { - return graphql.Null - } - ctx = graphql.WithFieldContext(ctx, fc) - defer func() { - if r := recover(); r != nil { - ec.Error(ctx, ec.Recover(ctx, r)) - ret = graphql.Null - } - }() - resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (any, error) { - ctx = rctx // use context from middleware stack in children - return obj.Token, nil - }) - if err != nil { - ec.Error(ctx, err) - return graphql.Null - } - if resTmp == nil { - return graphql.Null - } - res := resTmp.(*string) - fc.Result = res - return ec.marshalOString2ᚖstring(ctx, field.Selections, res) -} - -func (ec *executionContext) fieldContext_User_token(_ context.Context, field graphql.CollectedField) (fc *graphql.FieldContext, err error) { - fc = &graphql.FieldContext{ - Object: "User", - Field: field, - IsMethod: false, - IsResolver: false, - Child: func(ctx context.Context, field graphql.CollectedField) (*graphql.FieldContext, error) { - return nil, errors.New("field of type String does not have child fields") - }, - } - return fc, nil -} - func (ec *executionContext) ___Directive_name(ctx context.Context, field graphql.CollectedField, obj *introspection.Directive) (ret graphql.Marshaler) { fc, err := ec.fieldContext___Directive_name(ctx, field) if err != nil { @@ -6777,10 +6695,6 @@ func (ec *executionContext) _User(ctx context.Context, sel ast.SelectionSet, obj } out.Concurrently(i, func(ctx context.Context) graphql.Marshaler { return innerFunc(ctx, out) }) - case "password": - out.Values[i] = ec._User_password(ctx, field, obj) - case "token": - out.Values[i] = ec._User_token(ctx, field, obj) default: panic("unknown field " + strconv.Quote(field.Name)) } @@ -7223,13 +7137,13 @@ func (ec *executionContext) unmarshalNCreateUserInput2gitᚗfarahtyᚗcomᚋnime } func (ec *executionContext) unmarshalNID2goᚗmongodbᚗorgᚋmongoᚑdriverᚋbsonᚋprimitiveᚐObjectID(ctx context.Context, v any) (primitive.ObjectID, error) { - res, err := app.UnmarshalObjectID(v) + res, err := utils.UnmarshalObjectID(v) return res, graphql.ErrorOnPath(ctx, err) } func (ec *executionContext) marshalNID2goᚗmongodbᚗorgᚋmongoᚑdriverᚋbsonᚋprimitiveᚐObjectID(ctx context.Context, sel ast.SelectionSet, v primitive.ObjectID) graphql.Marshaler { _ = sel - res := app.MarshalObjectID(v) + res := utils.MarshalObjectID(v) if res == graphql.Null { if !graphql.HasFieldError(ctx, graphql.GetFieldContext(ctx)) { ec.Errorf(ctx, "the requested element is null which the schema does not allow") @@ -7257,6 +7171,16 @@ func (ec *executionContext) marshalNLoginResponse2ᚖgitᚗfarahtyᚗcomᚋnimer return ec._LoginResponse(ctx, sel, v) } +func (ec *executionContext) unmarshalNRole2gitᚗfarahtyᚗcomᚋnimerᚋgoᚑmongoᚋmodelsᚐRole(ctx context.Context, v any) (models.Role, error) { + var res models.Role + err := res.UnmarshalGQL(v) + return res, graphql.ErrorOnPath(ctx, err) +} + +func (ec *executionContext) marshalNRole2gitᚗfarahtyᚗcomᚋnimerᚋgoᚑmongoᚋmodelsᚐRole(ctx context.Context, sel ast.SelectionSet, v models.Role) graphql.Marshaler { + return v +} + func (ec *executionContext) unmarshalNString2string(ctx context.Context, v any) (string, error) { res, err := graphql.UnmarshalString(v) return res, graphql.ErrorOnPath(ctx, err) @@ -7671,7 +7595,7 @@ func (ec *executionContext) unmarshalOID2ᚖgoᚗmongodbᚗorgᚋmongoᚑdriver if v == nil { return nil, nil } - res, err := app.UnmarshalObjectID(v) + res, err := utils.UnmarshalObjectID(v) return &res, graphql.ErrorOnPath(ctx, err) } @@ -7681,7 +7605,7 @@ func (ec *executionContext) marshalOID2ᚖgoᚗmongodbᚗorgᚋmongoᚑdriverᚋ } _ = sel _ = ctx - res := app.MarshalObjectID(*v) + res := utils.MarshalObjectID(*v) return res } @@ -7701,6 +7625,42 @@ func (ec *executionContext) marshalOStatus2ᚖgitᚗfarahtyᚗcomᚋnimerᚋgo return v } +func (ec *executionContext) unmarshalOString2ᚕstringᚄ(ctx context.Context, v any) ([]string, error) { + if v == nil { + return nil, nil + } + var vSlice []any + vSlice = graphql.CoerceList(v) + var err error + res := make([]string, len(vSlice)) + for i := range vSlice { + ctx := graphql.WithPathContext(ctx, graphql.NewPathWithIndex(i)) + res[i], err = ec.unmarshalNString2string(ctx, vSlice[i]) + if err != nil { + return nil, err + } + } + return res, nil +} + +func (ec *executionContext) marshalOString2ᚕstringᚄ(ctx context.Context, sel ast.SelectionSet, v []string) graphql.Marshaler { + if v == nil { + return graphql.Null + } + ret := make(graphql.Array, len(v)) + for i := range v { + ret[i] = ec.marshalNString2string(ctx, sel, v[i]) + } + + for _, e := range ret { + if e == graphql.Null { + return graphql.Null + } + } + + return ret +} + func (ec *executionContext) unmarshalOString2ᚖstring(ctx context.Context, v any) (*string, error) { if v == nil { return nil, nil diff --git a/gql/auth.gql b/gql/auth.gql index 30ae5c5..19dc16a 100644 --- a/gql/auth.gql +++ b/gql/auth.gql @@ -1,5 +1,12 @@ directive @auth on FIELD_DEFINITION +directive @hasRole(role: Role!) on FIELD_DEFINITION + +enum Role { + ADMIN + USER +} + type LoginResponse { user: User! accessToken: String! diff --git a/gql/base.gql b/gql/base.gql index 76a9af1..8f823c4 100644 --- a/gql/base.gql +++ b/gql/base.gql @@ -1,12 +1,28 @@ +directive @goModel( + model: String + models: [String!] + forceGenerate: Boolean +) on OBJECT | INPUT_OBJECT | SCALAR | ENUM | INTERFACE | UNION + directive @goField( forceResolver: Boolean name: String + omittable: Boolean + type: String ) on INPUT_FIELD_DEFINITION | FIELD_DEFINITION + directive @goTag( key: String! value: String ) on INPUT_FIELD_DEFINITION | FIELD_DEFINITION +directive @goExtraField( + name: String + type: String! + overrideTags: String + description: String +) repeatable on OBJECT | INPUT_OBJECT + scalar Time interface Base { diff --git a/gql/user.gql b/gql/user.gql index 7bae2d4..ee83677 100644 --- a/gql/user.gql +++ b/gql/user.gql @@ -1,12 +1,12 @@ -type User { +type User + @goExtraField(name: "Password", type: "*string") + @goExtraField(name: "Token", type: "*string") { id: ID! phone: String email: String type: String status: String verified: Boolean @goField(forceResolver: true) - password: String - token: String } input CreateUserInput { diff --git a/gqlgen.yml b/gqlgen.yml index efcaeb2..6e2850a 100644 --- a/gqlgen.yml +++ b/gqlgen.yml @@ -27,7 +27,7 @@ autobind: models: ID: model: - - git.farahty.com/nimer/go-mongo/app.ObjectID + - git.farahty.com/nimer/go-mongo/utils.ObjectID - github.com/99designs/gqlgen/graphql.ID Int: model: diff --git a/main.go b/main.go index e86ad8c..15943e1 100644 --- a/main.go +++ b/main.go @@ -11,8 +11,6 @@ import ( "git.farahty.com/nimer/go-mongo/app" "github.com/fatih/color" - "github.com/joho/godotenv" - "github.com/redis/go-redis/v9" ) func main() { @@ -30,23 +28,12 @@ func main() { color.Yellow("🚀 Starting server ...\n") - // Load .env if needed - if _, exists := os.LookupEnv("MONGO_URI"); !exists { - if err := godotenv.Load(); err != nil { - log.Fatal("🔴 Failed to load .env file: ", err) - } - color.Green("✅ .env file loaded\n") + // Load and validate config + err := app.LoadConfig() + if err != nil { + log.Fatal("🔴 Config error: ", err) } - // Validate environment variables - requiredEnvs := []string{"PORT", "MONGO_URI", "REDIS_HOST", "REDIS_PORT"} - for _, key := range requiredEnvs { - if os.Getenv(key) == "" { - log.Fatalf("🔴 Required environment variable %s is missing", key) - } - } - port := os.Getenv("PORT") - // Connect to Mongo dbCancel, err := app.Connect() if err != nil { @@ -68,26 +55,21 @@ func main() { color.Green("✅ Authorization policies loaded successfully\n") // Redis - redisClient := redis.NewClient(&redis.Options{ - Addr: os.Getenv("REDIS_HOST") + ":" + os.Getenv("REDIS_PORT"), - Password: os.Getenv("REDIS_PASSWORD"), - DB: 0, - }) - if _, err := redisClient.Ping(ctx).Result(); err != nil { + if err := app.InitRedis(ctx); err != nil { log.Fatal("🔴 Redis connection error: ", err) } defer func() { color.Red("❌ Closing Redis connection...\n") - _ = redisClient.Close() + _ = app.RedisClient.Close() }() color.Green("✅ Connected to Redis cache successfully\n") // Create GraphQL server - graphqlServer := createGraphqlServer(redisClient) + graphqlServer := createGraphqlServer() // Start HTTP server server := &http.Server{ - Addr: ":" + port, + Addr: ":" + app.Config.Port, Handler: createRouter(graphqlServer), ReadTimeout: 30 * time.Second, WriteTimeout: 30 * time.Second, @@ -95,7 +77,7 @@ func main() { } go func() { - color.Green("🌐 Server listening at http://localhost:%s\n", port) + color.Green("🌐 Server listening at http://localhost:%s\n", app.Config.Port) if err := server.ListenAndServe(); err != nil && err != http.ErrServerClosed { log.Fatalf("🔴 Server failed: %v", err) } diff --git a/models/models_gen.go b/models/models_gen.go index 9605cd3..5d0fb02 100644 --- a/models/models_gen.go +++ b/models/models_gen.go @@ -132,8 +132,63 @@ type User struct { Type *string `json:"type,omitempty" bson:"type,omitempty"` Status *string `json:"status,omitempty" bson:"status,omitempty"` Verified *bool `json:"verified,omitempty" bson:"verified,omitempty"` - Password *string `json:"password,omitempty" bson:"password,omitempty"` - Token *string `json:"token,omitempty" bson:"token,omitempty"` + Password *string `json:"-" bson:"Password,omitempty"` + Token *string `json:"-" bson:"Token,omitempty"` +} + +type Role string + +const ( + RoleAdmin Role = "ADMIN" + RoleUser Role = "USER" +) + +var AllRole = []Role{ + RoleAdmin, + RoleUser, +} + +func (e Role) IsValid() bool { + switch e { + case RoleAdmin, RoleUser: + return true + } + return false +} + +func (e Role) String() string { + return string(e) +} + +func (e *Role) UnmarshalGQL(v any) error { + str, ok := v.(string) + if !ok { + return fmt.Errorf("enums must be strings") + } + + *e = Role(str) + if !e.IsValid() { + return fmt.Errorf("%s is not a valid Role", str) + } + return nil +} + +func (e Role) MarshalGQL(w io.Writer) { + fmt.Fprint(w, strconv.Quote(e.String())) +} + +func (e *Role) UnmarshalJSON(b []byte) error { + s, err := strconv.Unquote(string(b)) + if err != nil { + return err + } + return e.UnmarshalGQL(s) +} + +func (e Role) MarshalJSON() ([]byte, error) { + var buf bytes.Buffer + e.MarshalGQL(&buf) + return buf.Bytes(), nil } type Status string diff --git a/resolvers/auth.resolvers.go b/resolvers/auth.resolvers.go index 8cfd797..7cfa2fb 100644 --- a/resolvers/auth.resolvers.go +++ b/resolvers/auth.resolvers.go @@ -7,6 +7,7 @@ package resolvers import ( "context" + "git.farahty.com/nimer/go-mongo/app" "git.farahty.com/nimer/go-mongo/generated" "git.farahty.com/nimer/go-mongo/models" authService "git.farahty.com/nimer/go-mongo/services/auth" @@ -14,8 +15,7 @@ import ( // Login is the resolver for the login field. func (r *mutationResolver) Login(ctx context.Context, input models.LoginInput) (*models.LoginResponse, error) { - - return authService.Login(ctx, &input) + return authService.Login(ctx, &input, app.RedisClient) } // Mutation returns generated.MutationResolver implementation. diff --git a/resolvers/resolver.go b/resolvers/resolver.go index 9957529..d2fdee4 100644 --- a/resolvers/resolver.go +++ b/resolvers/resolver.go @@ -15,15 +15,19 @@ import ( // It serves as dependency injection for your app, add any dependencies you require here. type Resolver struct { - Redis *redis.Client + Redis redis.UniversalClient } -func Subscribe[T any](ctx context.Context, redis *redis.Client, event string) (<-chan *T, error) { +func Subscribe[T any](ctx context.Context, redisClient redis.UniversalClient, event string) (<-chan *T, error) { + client, ok := redisClient.(*redis.Client) + if !ok { + return nil, nil // or handle error + } clientChannel := make(chan *T, 1) go func() { - sub := redis.Subscribe(ctx, event) + sub := client.Subscribe(ctx, event) if _, err := sub.Receive(ctx); err != nil { return @@ -52,5 +56,4 @@ func Subscribe[T any](ctx context.Context, redis *redis.Client, event string) (< }() return clientChannel, nil - } diff --git a/resolvers/todo.resolvers.go b/resolvers/todo.resolvers.go index 50d4940..b5766f0 100644 --- a/resolvers/todo.resolvers.go +++ b/resolvers/todo.resolvers.go @@ -12,6 +12,7 @@ import ( "git.farahty.com/nimer/go-mongo/models" todoService "git.farahty.com/nimer/go-mongo/services/todo" userService "git.farahty.com/nimer/go-mongo/services/user" + redis "github.com/redis/go-redis/v9" "go.mongodb.org/mongo-driver/bson/primitive" ) @@ -24,7 +25,9 @@ func (r *mutationResolver) CreateTodo(ctx context.Context, input models.CreateTo } if objJson, err := json.Marshal(obj); err == nil { - r.Redis.Publish(ctx, "NEW_TODO_EVENT", objJson) + if client, ok := r.Redis.(*redis.Client); ok { + client.Publish(ctx, "NEW_TODO_EVENT", objJson) + } } return obj, nil @@ -42,7 +45,10 @@ func (r *queryResolver) Todo(ctx context.Context, id primitive.ObjectID) (*model // OnTodo is the resolver for the onTodo field. func (r *subscriptionResolver) OnTodo(ctx context.Context) (<-chan *models.Todo, error) { - return Subscribe[models.Todo](ctx, r.Redis, "NEW_TODO_EVENT") + if client, ok := r.Redis.(*redis.Client); ok { + return Subscribe[models.Todo](ctx, client, "NEW_TODO_EVENT") + } + return nil, nil } // CreatedBy is the resolver for the createdBy field. diff --git a/router.go b/router.go index 6f4f81c..e7cfb1f 100644 --- a/router.go +++ b/router.go @@ -3,14 +3,13 @@ package main import ( "log" "net/http" - "os" "time" "git.farahty.com/nimer/go-mongo/app" "git.farahty.com/nimer/go-mongo/controllers" "git.farahty.com/nimer/go-mongo/directives" "git.farahty.com/nimer/go-mongo/generated" - "git.farahty.com/nimer/go-mongo/helpers" + "git.farahty.com/nimer/go-mongo/utils" "github.com/99designs/gqlgen/graphql/handler" "github.com/99designs/gqlgen/graphql/handler/extension" "github.com/99designs/gqlgen/graphql/handler/transport" @@ -22,14 +21,13 @@ import ( "github.com/go-chi/httprate" "github.com/gorilla/websocket" - "github.com/redis/go-redis/v9" ) func createRouter(graphqlServer http.Handler) chi.Router { router := chi.NewRouter() // Apply middleware - if os.Getenv("ENV") == "production" { + if app.Config.Env == "production" { router.Use(httprate.LimitByIP(100, 1*time.Minute)) // 100 requests/minute/IP } @@ -42,6 +40,7 @@ func createRouter(graphqlServer http.Handler) chi.Router { router.Use(app.AuthMiddleware) router.Use(app.WriterMiddleware) + router.Use(app.RequestMiddleware) // REST routes router.Mount("/users", controllers.UserRouter()) @@ -53,10 +52,8 @@ func createRouter(graphqlServer http.Handler) chi.Router { return router } -func createGraphqlServer(redisClient *redis.Client) http.Handler { - - cache, err := helpers.NewCache(redisClient, 24*time.Hour, "apq:") - +func createGraphqlServer() http.Handler { + cache, err := utils.NewCache(app.RedisClient, 24*time.Hour, "apq:") if err != nil { log.Fatalf("cannot create APQ redis cache: %v", err) } @@ -64,7 +61,7 @@ func createGraphqlServer(redisClient *redis.Client) http.Handler { // Setup gqlgen with resolvers and Redis client schema := generated.Config{ Resolvers: &resolvers.Resolver{ - Redis: redisClient, + Redis: app.RedisClient, }, } @@ -91,6 +88,7 @@ func createGraphqlServer(redisClient *redis.Client) http.Handler { srv.Use(extension.AutomaticPersistedQuery{Cache: cache}) srv.Use(extension.Introspection{}) + // Apply global middleware srv.AroundRootFields(app.RootFieldsAuthorizer) // Check for @auth at root fields srv.AroundResponses(app.ExpiryMiddleware) // Token expiry validation @@ -101,4 +99,6 @@ func createGraphqlServer(redisClient *redis.Client) http.Handler { func mapDirectives(config *generated.Config) { config.Directives.Auth = directives.Auth + config.Directives.HasRole = directives.HasRole + } diff --git a/services/auth/auth.go b/services/auth/auth.go index fedf850..80b56a2 100644 --- a/services/auth/auth.go +++ b/services/auth/auth.go @@ -3,15 +3,23 @@ package authService import ( "context" "fmt" + "time" "git.farahty.com/nimer/go-mongo/app" "git.farahty.com/nimer/go-mongo/models" + "github.com/redis/go-redis/v9" "go.mongodb.org/mongo-driver/bson" ) -func Login(ctx context.Context, loginInput *models.LoginInput) (*models.LoginResponse, error) { +const maxAttempts = 5 +const lockoutDuration = 15 * time.Minute + +func Login(ctx context.Context, loginInput *models.LoginInput, redisClient redis.UniversalClient) (*models.LoginResponse, error) { + + if locked, _ := isLockedOut(loginInput.Identity, redisClient); locked { + return nil, fmt.Errorf("too many failed attempts, please try again later") + } - // todo : fix the security threats here filter := bson.D{ { Key: "$or", @@ -29,8 +37,32 @@ func Login(ctx context.Context, loginInput *models.LoginInput) (*models.LoginRes } if !user.CheckPassword(loginInput.Password) { - return nil, fmt.Errorf("incorrect password") + logFailedLoginAttempt(ctx, loginInput.Identity, redisClient) // optional + return nil, fmt.Errorf("invalid identity or password") + } + clearFailedAttempts(ctx, loginInput.Identity, redisClient) + return successLogin(ctx, user) } + +func isLockedOut(identity string, redisClient redis.UniversalClient) (bool, error) { + key := fmt.Sprintf("login:fail:%s", identity) + attempts, err := redisClient.Get(context.Background(), key).Int() + if err != nil && err != redis.Nil { + return false, err + } + return attempts >= maxAttempts, nil +} + +func logFailedLoginAttempt(ctx context.Context, identity string, redisClient redis.UniversalClient) { + key := fmt.Sprintf("login:fail:%s", identity) + redisClient.Incr(ctx, key) + redisClient.Expire(ctx, key, lockoutDuration) +} + +func clearFailedAttempts(ctx context.Context, identity string, redisClient redis.UniversalClient) { + key := fmt.Sprintf("login:fail:%s", identity) + redisClient.Del(ctx, key) +} diff --git a/services/auth/create_token.go b/services/auth/create_token.go deleted file mode 100644 index b6a1ff8..0000000 --- a/services/auth/create_token.go +++ /dev/null @@ -1,29 +0,0 @@ -package authService - -import ( - "time" - - "github.com/golang-jwt/jwt/v4" -) - -func createToken(sub, secret, expiry string, payload interface{}) (*string, error) { - - duration, err := time.ParseDuration(expiry) - - if err != nil { - return nil, err - } - - token := jwt.New(jwt.SigningMethodHS256) - claims := token.Claims.(jwt.MapClaims) - claims["sub"] = sub - claims["exp"] = time.Now().Add(duration).Unix() - claims["data"] = payload - - signedToken, err := token.SignedString([]byte(secret)) - if err != nil { - return nil, err - } - - return &signedToken, nil -} diff --git a/services/auth/sucess_login.go b/services/auth/sucess_login.go index 524d6a6..57e1da2 100644 --- a/services/auth/sucess_login.go +++ b/services/auth/sucess_login.go @@ -3,37 +3,41 @@ package authService import ( "context" "encoding/hex" + "fmt" "net/http" - "os" - "git.farahty.com/nimer/go-mongo/app" "git.farahty.com/nimer/go-mongo/models" + "git.farahty.com/nimer/go-mongo/utils" "github.com/google/uuid" "go.mongodb.org/mongo-driver/bson" ) func successLogin(ctx context.Context, user *models.User) (*models.LoginResponse, error) { - refresh_secret := os.Getenv("REFRESH_SECRET") - refresh_expiry := os.Getenv("REFRESH_EXPIRY") + refresh_secret := app.Config.RefreshSecret + refresh_expiry := app.Config.RefreshExpiry - access_secret := os.Getenv("ACCESS_SECRET") - access_expiry := os.Getenv("ACCESS_EXPIRY") + access_secret := app.Config.AccessSecret + access_expiry := app.Config.AccessExpiry - identity := user.Email - if identity == nil { - identity = user.Phone + var identity string + if user.Email != nil { + identity = *user.Email + } else if user.Phone != nil { + identity = *user.Phone + } else { + return nil, fmt.Errorf("user identity not found") } refreshHandle := hex.EncodeToString([]byte(uuid.NewString())) - refreshToken, err := createToken( + refreshToken, err := utils.CreateToken( refreshHandle, refresh_secret, refresh_expiry, models.UserJWT{ ID: user.ID.Hex(), - Identity: *identity, + Identity: identity, }, ) @@ -41,13 +45,13 @@ func successLogin(ctx context.Context, user *models.User) (*models.LoginResponse return nil, err } - accessToken, err := createToken( + accessToken, err := utils.CreateToken( user.ID.Hex(), access_secret, access_expiry, models.UserJWT{ ID: user.ID.Hex(), - Identity: *identity, + Identity: identity, }, ) @@ -55,19 +59,17 @@ func successLogin(ctx context.Context, user *models.User) (*models.LoginResponse return nil, err } - user.Token = &refreshHandle - - _, err = app.Collection("users").UpdateByID(ctx, user.ID, bson.D{{Key: "$set", Value: user}}) + _, err = app.Collection("users").UpdateByID(ctx, user.ID, bson.D{ + {Key: "$set", Value: bson.D{ + {Key: "token", Value: refreshHandle}, + }}, + }) if err != nil { return nil, err } - w, err := app.WriterFor(ctx) - - if err != nil { - return nil, err - } + w := app.WriterFor(ctx) http.SetCookie(*w, &http.Cookie{ Name: "access_token", diff --git a/helpers/cache.go b/utils/cache.go similarity index 97% rename from helpers/cache.go rename to utils/cache.go index e2875f8..1f0039c 100644 --- a/helpers/cache.go +++ b/utils/cache.go @@ -1,4 +1,4 @@ -package helpers +package utils import ( "context" diff --git a/helpers/encrypt.go b/utils/encrypt.go similarity index 98% rename from helpers/encrypt.go rename to utils/encrypt.go index c153b60..1f0d251 100644 --- a/helpers/encrypt.go +++ b/utils/encrypt.go @@ -1,4 +1,4 @@ -package helpers +package utils import ( "crypto/aes" diff --git a/app/scalers.go b/utils/scalers.go similarity index 84% rename from app/scalers.go rename to utils/scalers.go index 3d42043..b28cfe0 100644 --- a/app/scalers.go +++ b/utils/scalers.go @@ -1,4 +1,4 @@ -package app +package utils import ( "errors" @@ -18,7 +18,7 @@ func MarshalObjectID(value primitive.ObjectID) graphql.Marshaler { } -func UnmarshalObjectID(value interface{}) (primitive.ObjectID, error) { +func UnmarshalObjectID(value any) (primitive.ObjectID, error) { if str, ok := value.(string); ok { diff --git a/app/helpers.go b/utils/token.go similarity index 56% rename from app/helpers.go rename to utils/token.go index b21e5ed..19fb8ed 100644 --- a/app/helpers.go +++ b/utils/token.go @@ -1,17 +1,16 @@ -package app +package utils import ( "fmt" "net/http" - "os" "strings" + "time" - "git.farahty.com/nimer/go-mongo/models" "github.com/golang-jwt/jwt/v4" "github.com/mitchellh/mapstructure" ) -func getTokenFromHeader(r *http.Request) (string, error) { +func GetTokenFromHeader(r *http.Request) (string, error) { authHeader := r.Header.Get("Authorization") @@ -28,7 +27,7 @@ func getTokenFromHeader(r *http.Request) (string, error) { return strings.TrimSpace(authSlice[1]), nil } -func getTokenFromCookie(r *http.Request) (string, error) { +func GetTokenFromCookie(r *http.Request) (string, error) { cookie, err := r.Cookie("access_token") @@ -39,13 +38,13 @@ func getTokenFromCookie(r *http.Request) (string, error) { return strings.TrimSpace(cookie.Value), nil } -func getUserFromToken(tokenString string) (*models.UserJWT, error) { - token, err := jwt.Parse(tokenString, func(t *jwt.Token) (interface{}, error) { +func GetUserFromToken[T any](tokenString string, secret string) (*T, error) { + token, err := jwt.Parse(tokenString, func(t *jwt.Token) (any, error) { if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok { return nil, fmt.Errorf("wrong token format ") } - return []byte(os.Getenv("ACCESS_SECRET")), nil + return []byte(secret), nil }) if err != nil { @@ -56,7 +55,7 @@ func getUserFromToken(tokenString string) (*models.UserJWT, error) { return nil, fmt.Errorf("token is not valid") } - var user *models.UserJWT + var user *T claims := token.Claims.(jwt.MapClaims) if err := mapstructure.Decode(claims["data"], &user); err != nil { @@ -65,3 +64,25 @@ func getUserFromToken(tokenString string) (*models.UserJWT, error) { return user, nil } + +func CreateToken(sub, secret, expiry string, payload any) (*string, error) { + + duration, err := time.ParseDuration(expiry) + + if err != nil { + return nil, err + } + + token := jwt.New(jwt.SigningMethodHS256) + claims := token.Claims.(jwt.MapClaims) + claims["sub"] = sub + claims["exp"] = time.Now().Add(duration).Unix() + claims["data"] = payload + + signedToken, err := token.SignedString([]byte(secret)) + if err != nil { + return nil, err + } + + return &signedToken, nil +}