structure enchantment

This commit is contained in:
Nimer Farahty 2025-06-08 16:21:29 +03:00
parent 8296b813c9
commit fe76279706
27 changed files with 511 additions and 303 deletions

View File

@ -16,15 +16,31 @@ var (
UserKey = &contextKey{"user"} UserKey = &contextKey{"user"}
StatusKey = &contextKey{"status"} StatusKey = &contextKey{"status"}
ExpiryKey = &contextKey{"expiry"} ExpiryKey = &contextKey{"expiry"}
LoadersKey = &contextKey{"dataloaders"} LoadersKey = &contextKey{"dataLoaders"}
WriterKye = &contextKey{"writer"} WriterKye = &contextKey{"writer"}
RequestKey = &contextKey{"request"}
) )
func WriterFor(ctx context.Context) (*http.ResponseWriter, error) { // LoaderFor retrieves the dataLoaders from context
if writer, ok := ctx.Value(WriterKye).(*http.ResponseWriter); ok { func LoaderFor(ctx context.Context) *Loaders {
return writer, nil if loaders, ok := ctx.Value(LoadersKey).(*Loaders); ok {
return loaders
} }
return nil, fmt.Errorf("no writer found in context") panic("dataloader not found in context")
}
func WriterFor(ctx context.Context) *http.ResponseWriter {
if writer, ok := ctx.Value(WriterKye).(*http.ResponseWriter); ok {
return writer
}
panic("no writer found in context")
}
func RequestFor(ctx context.Context) *http.Request {
if req, ok := ctx.Value(RequestKey).(*http.Request); ok {
return req
}
panic("no request found in context")
} }
// Retrieves the current user from the context // Retrieves the current user from the context

View File

@ -5,9 +5,9 @@ import (
"errors" "errors"
"fmt" "fmt"
"log" "log"
"os"
"git.farahty.com/nimer/go-mongo/models" "git.farahty.com/nimer/go-mongo/models"
"git.farahty.com/nimer/go-mongo/utils"
"github.com/99designs/gqlgen/graphql" "github.com/99designs/gqlgen/graphql"
"github.com/99designs/gqlgen/graphql/handler/transport" "github.com/99designs/gqlgen/graphql/handler/transport"
"github.com/casbin/casbin/v2" "github.com/casbin/casbin/v2"
@ -22,8 +22,7 @@ var (
) )
func LoadAuthorizer(ctx context.Context) error { func LoadAuthorizer(ctx context.Context) error {
a, err := mongodbadapter.NewAdapterWithClientOption(options.Client().ApplyURI(os.Getenv("MONGO_URI")), os.Getenv("MONGO_DB")) a, err := mongodbadapter.NewAdapterWithClientOption(options.Client().ApplyURI(Config.MongoURI), Config.MongoDB)
if err != nil { if err != nil {
return err return err
} }
@ -32,18 +31,15 @@ func LoadAuthorizer(ctx context.Context) error {
return err return err
} else { } else {
Authorizer = enforcer Authorizer = enforcer
} }
a.AddPolicy("mongodb", "p", []string{"*", "login", "mutation"}) a.AddPolicy("mongodb", "p", []string{"*", "login", "mutation"})
a.AddPolicy("mongodb", "p", []string{"todos-admin", "createTodo", "mutation"}) a.AddPolicy("mongodb", "p", []string{"todos-admin", "createTodo", "mutation"})
a.AddPolicy("mongodb", "p", []string{"todos-admin", "todos", "query"}) a.AddPolicy("mongodb", "p", []string{"todos-admin", "todos", "query"})
a.AddPolicy("mongodb", "p", []string{"todos-admin", "todo", "query"}) a.AddPolicy("mongodb", "p", []string{"todos-admin", "todo", "query"})
a.AddPolicy("mongodb", "p", []string{"category-admin", "createCategory", "mutation"}) a.AddPolicy("mongodb", "p", []string{"category-admin", "createCategory", "mutation"})
a.AddPolicy("mongodb", "p", []string{"category-admin", "categories", "query"}) a.AddPolicy("mongodb", "p", []string{"category-admin", "categories", "query"})
a.AddPolicy("mongodb", "p", []string{"category-admin", "category", "query"}) a.AddPolicy("mongodb", "p", []string{"category-admin", "category", "query"})
a.AddPolicy("mongodb", "p", []string{"users-admin", "users", "query"}) a.AddPolicy("mongodb", "p", []string{"users-admin", "users", "query"})
a.AddPolicy("mongodb", "p", []string{"users-admin", "createUser", "mutation"}) a.AddPolicy("mongodb", "p", []string{"users-admin", "createUser", "mutation"})
@ -51,8 +47,8 @@ func LoadAuthorizer(ctx context.Context) error {
Authorizer.AddGroupingPolicy("admin", "todos-admin") Authorizer.AddGroupingPolicy("admin", "todos-admin")
Authorizer.AddGroupingPolicy("admin", "category-admin") Authorizer.AddGroupingPolicy("admin", "category-admin")
email := os.Getenv("ADMIN_EMAIL") email := Config.AdminEmail
password, _ := models.MakeHash(os.Getenv("ADMIN_PASSWORD")) password, _ := models.MakeHash(Config.AdminPass)
if _, err := FindOne[models.User](ctx, "users", bson.M{"email": email}); err != nil { if _, err := FindOne[models.User](ctx, "users", bson.M{"email": email}); err != nil {
log.Println("Creating admin user") log.Println("Creating admin user")
@ -67,7 +63,6 @@ func LoadAuthorizer(ctx context.Context) error {
} }
Authorizer.AddRoleForUser(admin.ID.Hex(), "admin") Authorizer.AddRoleForUser(admin.ID.Hex(), "admin")
} }
return nil return nil
@ -82,7 +77,7 @@ func AuthorizeWebSocket(ctx context.Context, initPayload transport.InitPayload)
return ctx, &initPayload, nil return ctx, &initPayload, nil
} }
user, err := getUserFromToken(token) user, err := utils.GetUserFromToken[models.UserJWT](token, Config.AccessSecret)
if err != nil { if err != nil {
ctx = SetStatus(ctx, err.Error()) ctx = SetStatus(ctx, err.Error())

90
app/config.go Normal file
View File

@ -0,0 +1,90 @@
package app
import (
"fmt"
"log"
"os"
"github.com/joho/godotenv"
)
var (
Config *ConfigStruct
)
type ConfigStruct struct {
Port string
MongoURI string
MongoDB string
RedisHost string
RedisPort string
RedisPass string
AdminEmail string
AdminPass string
RefreshSecret string
RefreshExpiry string
AccessSecret string
AccessExpiry string
Env string
}
// LoadConfig loads environment variables into Config struct and validates required fields.
func LoadConfig() error {
if _, exists := os.LookupEnv("MONGO_URI"); !exists {
if err := godotenv.Load(); err != nil {
log.Fatal("🔴 Failed to load .env file: ", err)
}
}
Config = &ConfigStruct{
Port: getEnv("PORT", "8080"),
MongoURI: os.Getenv("MONGO_URI"),
MongoDB: getEnv("MONGO_DB", "test"),
RedisHost: os.Getenv("REDIS_HOST"),
RedisPort: os.Getenv("REDIS_PORT"),
RedisPass: os.Getenv("REDIS_PASSWORD"),
AdminEmail: os.Getenv("ADMIN_EMAIL"),
AdminPass: os.Getenv("ADMIN_PASSWORD"),
RefreshSecret: os.Getenv("REFRESH_SECRET"),
RefreshExpiry: os.Getenv("REFRESH_EXPIRY"),
AccessSecret: os.Getenv("ACCESS_SECRET"),
AccessExpiry: os.Getenv("ACCESS_EXPIRY"),
Env: os.Getenv("ENV"),
}
missing := ""
if Config.Port == "" {
missing += "PORT "
}
if Config.MongoURI == "" {
missing += "MONGO_URI "
}
if Config.RedisHost == "" {
missing += "REDIS_HOST "
}
if Config.RedisPort == "" {
missing += "REDIS_PORT "
}
if Config.AccessSecret == "" {
missing += "ACCESS_SECRET "
}
if Config.RefreshSecret == "" {
missing += "REFRESH_SECRET "
}
if missing != "" {
return fmt.Errorf("missing required environment variables: %s", missing)
}
return nil
}
// GetConfig returns the global config
func GetConfig() *ConfigStruct {
return Config
}
func getEnv(key, def string) string {
if v := os.Getenv(key); v != "" {
return v
}
return def
}

View File

@ -3,7 +3,6 @@ package app
import ( import (
"context" "context"
"fmt" "fmt"
"os"
"time" "time"
"github.com/fatih/color" "github.com/fatih/color"
@ -22,7 +21,7 @@ func Connect() (context.CancelFunc, error) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
var err error var err error
Mongo, err = mongo.Connect(ctx, options.Client().ApplyURI(os.Getenv("MONGO_URI"))) Mongo, err = mongo.Connect(ctx, options.Client().ApplyURI(Config.MongoURI))
if err != nil { if err != nil {
color.Red("❌ Database connection failed\n" + err.Error()) color.Red("❌ Database connection failed\n" + err.Error())
return cancel, err return cancel, err
@ -45,7 +44,7 @@ func Disconnect(ctx context.Context) error {
// Collection returns a MongoDB collection // Collection returns a MongoDB collection
func Collection(name string) *mongo.Collection { func Collection(name string) *mongo.Collection {
return Mongo.Database(os.Getenv("MONGO_DB")).Collection(name) return Mongo.Database(Config.MongoDB).Collection(name)
} }
// Find returns all matching documents from a collection // Find returns all matching documents from a collection

View File

@ -3,7 +3,6 @@ package app
import ( import (
"context" "context"
"fmt" "fmt"
"net/http"
"git.farahty.com/nimer/go-mongo/models" "git.farahty.com/nimer/go-mongo/models"
"github.com/graph-gophers/dataloader" "github.com/graph-gophers/dataloader"
@ -27,22 +26,6 @@ func NewLoaders() *Loaders {
} }
} }
// Middleware injects dataloaders into context for each HTTP request
func LoaderMiddleware(loaders *Loaders, next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctxWithLoaders := context.WithValue(r.Context(), LoadersKey, loaders)
next.ServeHTTP(w, r.WithContext(ctxWithLoaders))
})
}
// LoaderFor retrieves the dataloaders from context
func LoaderFor(ctx context.Context) *Loaders {
if loaders, ok := ctx.Value(LoadersKey).(*Loaders); ok {
return loaders
}
panic("dataloader not found in context")
}
// CreateBatch creates a generic batched loader for a MongoDB collection // CreateBatch creates a generic batched loader for a MongoDB collection
func CreateBatch[T models.Identifiable](coll string) dataloader.BatchFunc { func CreateBatch[T models.Identifiable](coll string) dataloader.BatchFunc {
return func(ctx context.Context, keys dataloader.Keys) []*dataloader.Result { return func(ctx context.Context, keys dataloader.Keys) []*dataloader.Result {

View File

@ -6,12 +6,30 @@ import (
"net/http" "net/http"
"git.farahty.com/nimer/go-mongo/models"
"git.farahty.com/nimer/go-mongo/utils"
"github.com/99designs/gqlgen/graphql" "github.com/99designs/gqlgen/graphql"
"github.com/golang-jwt/jwt/v4" "github.com/golang-jwt/jwt/v4"
) )
// Middleware injects dataloaders into context for each HTTP request
func LoaderMiddleware(loaders *Loaders, next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctxWithLoaders := context.WithValue(r.Context(), LoadersKey, loaders)
next.ServeHTTP(w, r.WithContext(ctxWithLoaders))
})
}
// ExpiryMiddleware checks for expired tokens in GraphQL resolvers // ExpiryMiddleware checks for expired tokens in GraphQL resolvers
func ExpiryMiddleware(ctx context.Context, next graphql.ResponseHandler) *graphql.Response { func ExpiryMiddleware(ctx context.Context, next graphql.ResponseHandler) *graphql.Response {
object := graphql.GetRootFieldContext(ctx)
if object != nil {
println("ExpiryMiddleware: checking token expiry in GraphQL resolver")
} else {
println("ExpiryMiddleware: checking token expiry in HTTP request")
}
if IsTokenExpired(ctx) { if IsTokenExpired(ctx) {
return graphql.ErrorResponse(ctx, "token expired") return graphql.ErrorResponse(ctx, "token expired")
} }
@ -26,12 +44,20 @@ func WriterMiddleware(next http.Handler) http.Handler {
}) })
} }
// add request to context for GraphQL resolvers
func RequestMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
ctx := context.WithValue(r.Context(), RequestKey, r)
next.ServeHTTP(rw, r.WithContext(ctx))
})
}
// AuthMiddleware parses JWT token and injects user context for HTTP requests // AuthMiddleware parses JWT token and injects user context for HTTP requests
func AuthMiddleware(next http.Handler) http.Handler { func AuthMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
headerToken, headerErr := getTokenFromHeader(r) headerToken, headerErr := utils.GetTokenFromHeader(r)
cookieToken, cookieErr := getTokenFromCookie(r) cookieToken, cookieErr := utils.GetTokenFromCookie(r)
if headerErr != nil && cookieErr != nil { if headerErr != nil && cookieErr != nil {
ctx := SetStatus(r.Context(), headerErr.Error()) ctx := SetStatus(r.Context(), headerErr.Error())
@ -44,7 +70,7 @@ func AuthMiddleware(next http.Handler) http.Handler {
token = cookieToken token = cookieToken
} }
user, err := getUserFromToken(token) user, err := utils.GetUserFromToken[models.UserJWT](token, Config.AccessSecret)
ctx := r.Context() ctx := r.Context()
if err != nil { if err != nil {

24
app/redis-client.go Normal file
View File

@ -0,0 +1,24 @@
package app
import (
"context"
"github.com/redis/go-redis/v9"
)
var (
RedisClient redis.UniversalClient
)
// InitRedis initializes the Redis client and assigns it to RedisClient
func InitRedis(ctx context.Context) error {
RedisClient = redis.NewClient(&redis.Options{
Addr: Config.RedisHost + ":" + Config.RedisPort,
Password: Config.RedisPass,
DB: 0,
})
if _, err := RedisClient.Ping(ctx).Result(); err != nil {
return err
}
return nil
}

View File

@ -8,7 +8,7 @@ import (
"github.com/99designs/gqlgen/graphql" "github.com/99designs/gqlgen/graphql"
) )
func Auth(ctx context.Context, obj interface{}, next graphql.Resolver) (res interface{}, err error) { func Auth(ctx context.Context, obj any, next graphql.Resolver) (res any, err error) {
if _, err := app.CurrentUser(ctx); err != nil { if _, err := app.CurrentUser(ctx); err != nil {
return nil, fmt.Errorf("access denied, %s", err.Error()) return nil, fmt.Errorf("access denied, %s", err.Error())

20
directives/has-role.go Normal file
View File

@ -0,0 +1,20 @@
package directives
import (
"context"
"fmt"
"git.farahty.com/nimer/go-mongo/app"
"git.farahty.com/nimer/go-mongo/models"
"github.com/99designs/gqlgen/graphql"
)
func HasRole(ctx context.Context, obj any, next graphql.Resolver, role models.Role) (res any, err error) {
if _, err := app.CurrentUser(ctx); err != nil {
return nil, fmt.Errorf("access denied, %s", err.Error())
}
return next(ctx)
}

View File

@ -13,8 +13,8 @@ import (
"sync/atomic" "sync/atomic"
"time" "time"
"git.farahty.com/nimer/go-mongo/app"
"git.farahty.com/nimer/go-mongo/models" "git.farahty.com/nimer/go-mongo/models"
"git.farahty.com/nimer/go-mongo/utils"
"github.com/99designs/gqlgen/graphql" "github.com/99designs/gqlgen/graphql"
"github.com/99designs/gqlgen/graphql/introspection" "github.com/99designs/gqlgen/graphql/introspection"
gqlparser "github.com/vektah/gqlparser/v2" gqlparser "github.com/vektah/gqlparser/v2"
@ -51,7 +51,8 @@ type ResolverRoot interface {
} }
type DirectiveRoot struct { type DirectiveRoot struct {
Auth func(ctx context.Context, obj any, next graphql.Resolver) (res any, err error) Auth func(ctx context.Context, obj any, next graphql.Resolver) (res any, err error)
HasRole func(ctx context.Context, obj any, next graphql.Resolver, role models.Role) (res any, err error)
} }
type ComplexityRoot struct { type ComplexityRoot struct {
@ -119,10 +120,8 @@ type ComplexityRoot struct {
User struct { User struct {
Email func(childComplexity int) int Email func(childComplexity int) int
ID func(childComplexity int) int ID func(childComplexity int) int
Password func(childComplexity int) int
Phone func(childComplexity int) int Phone func(childComplexity int) int
Status func(childComplexity int) int Status func(childComplexity int) int
Token func(childComplexity int) int
Type func(childComplexity int) int Type func(childComplexity int) int
Verified func(childComplexity int) int Verified func(childComplexity int) int
} }
@ -507,13 +506,6 @@ func (e *executableSchema) Complexity(ctx context.Context, typeName, field strin
return e.complexity.User.ID(childComplexity), true return e.complexity.User.ID(childComplexity), true
case "User.password":
if e.complexity.User.Password == nil {
break
}
return e.complexity.User.Password(childComplexity), true
case "User.phone": case "User.phone":
if e.complexity.User.Phone == nil { if e.complexity.User.Phone == nil {
break break
@ -528,13 +520,6 @@ func (e *executableSchema) Complexity(ctx context.Context, typeName, field strin
return e.complexity.User.Status(childComplexity), true return e.complexity.User.Status(childComplexity), true
case "User.token":
if e.complexity.User.Token == nil {
break
}
return e.complexity.User.Token(childComplexity), true
case "User.type": case "User.type":
if e.complexity.User.Type == nil { if e.complexity.User.Type == nil {
break break
@ -678,6 +663,13 @@ func (ec *executionContext) introspectType(name string) (*introspection.Type, er
var sources = []*ast.Source{ var sources = []*ast.Source{
{Name: "../gql/auth.gql", Input: `directive @auth on FIELD_DEFINITION {Name: "../gql/auth.gql", Input: `directive @auth on FIELD_DEFINITION
directive @hasRole(role: Role!) on FIELD_DEFINITION
enum Role {
ADMIN
USER
}
type LoginResponse { type LoginResponse {
user: User! user: User!
accessToken: String! accessToken: String!
@ -693,15 +685,31 @@ extend type Mutation {
login(input: LoginInput!): LoginResponse! login(input: LoginInput!): LoginResponse!
} }
`, BuiltIn: false}, `, BuiltIn: false},
{Name: "../gql/base.gql", Input: `directive @goField( {Name: "../gql/base.gql", Input: `directive @goModel(
model: String
models: [String!]
forceGenerate: Boolean
) on OBJECT | INPUT_OBJECT | SCALAR | ENUM | INTERFACE | UNION
directive @goField(
forceResolver: Boolean forceResolver: Boolean
name: String name: String
omittable: Boolean
type: String
) on INPUT_FIELD_DEFINITION | FIELD_DEFINITION ) on INPUT_FIELD_DEFINITION | FIELD_DEFINITION
directive @goTag( directive @goTag(
key: String! key: String!
value: String value: String
) on INPUT_FIELD_DEFINITION | FIELD_DEFINITION ) on INPUT_FIELD_DEFINITION | FIELD_DEFINITION
directive @goExtraField(
name: String
type: String!
overrideTags: String
description: String
) repeatable on OBJECT | INPUT_OBJECT
scalar Time scalar Time
interface Base { interface Base {
@ -813,15 +821,15 @@ extend type Subscription {
onTodo: Todo onTodo: Todo
} }
`, BuiltIn: false}, `, BuiltIn: false},
{Name: "../gql/user.gql", Input: `type User { {Name: "../gql/user.gql", Input: `type User
@goExtraField(name: "Password", type: "*string")
@goExtraField(name: "Token", type: "*string") {
id: ID! id: ID!
phone: String phone: String
email: String email: String
type: String type: String
status: String status: String
verified: Boolean @goField(forceResolver: true) verified: Boolean @goField(forceResolver: true)
password: String
token: String
} }
input CreateUserInput { input CreateUserInput {
@ -846,6 +854,34 @@ var parsedSchema = gqlparser.MustLoadSchema(sources...)
// region ***************************** args.gotpl ***************************** // region ***************************** args.gotpl *****************************
func (ec *executionContext) dir_hasRole_args(ctx context.Context, rawArgs map[string]any) (map[string]any, error) {
var err error
args := map[string]any{}
arg0, err := ec.dir_hasRole_argsRole(ctx, rawArgs)
if err != nil {
return nil, err
}
args["role"] = arg0
return args, nil
}
func (ec *executionContext) dir_hasRole_argsRole(
ctx context.Context,
rawArgs map[string]any,
) (models.Role, error) {
if _, ok := rawArgs["role"]; !ok {
var zeroVal models.Role
return zeroVal, nil
}
ctx = graphql.WithPathContext(ctx, graphql.NewPathWithField("role"))
if tmp, ok := rawArgs["role"]; ok {
return ec.unmarshalNRole2gitᚗfarahtyᚗcomᚋnimerᚋgoᚑmongoᚋmodelsᚐRole(ctx, tmp)
}
var zeroVal models.Role
return zeroVal, nil
}
func (ec *executionContext) field_Mutation_createCategory_args(ctx context.Context, rawArgs map[string]any) (map[string]any, error) { func (ec *executionContext) field_Mutation_createCategory_args(ctx context.Context, rawArgs map[string]any) (map[string]any, error) {
var err error var err error
args := map[string]any{} args := map[string]any{}
@ -1556,10 +1592,6 @@ func (ec *executionContext) fieldContext_Category_createdBy(_ context.Context, f
return ec.fieldContext_User_status(ctx, field) return ec.fieldContext_User_status(ctx, field)
case "verified": case "verified":
return ec.fieldContext_User_verified(ctx, field) return ec.fieldContext_User_verified(ctx, field)
case "password":
return ec.fieldContext_User_password(ctx, field)
case "token":
return ec.fieldContext_User_token(ctx, field)
} }
return nil, fmt.Errorf("no field named %q was found under type User", field.Name) return nil, fmt.Errorf("no field named %q was found under type User", field.Name)
}, },
@ -1662,10 +1694,6 @@ func (ec *executionContext) fieldContext_Category_updatedBy(_ context.Context, f
return ec.fieldContext_User_status(ctx, field) return ec.fieldContext_User_status(ctx, field)
case "verified": case "verified":
return ec.fieldContext_User_verified(ctx, field) return ec.fieldContext_User_verified(ctx, field)
case "password":
return ec.fieldContext_User_password(ctx, field)
case "token":
return ec.fieldContext_User_token(ctx, field)
} }
return nil, fmt.Errorf("no field named %q was found under type User", field.Name) return nil, fmt.Errorf("no field named %q was found under type User", field.Name)
}, },
@ -1765,10 +1793,6 @@ func (ec *executionContext) fieldContext_Category_owner(_ context.Context, field
return ec.fieldContext_User_status(ctx, field) return ec.fieldContext_User_status(ctx, field)
case "verified": case "verified":
return ec.fieldContext_User_verified(ctx, field) return ec.fieldContext_User_verified(ctx, field)
case "password":
return ec.fieldContext_User_password(ctx, field)
case "token":
return ec.fieldContext_User_token(ctx, field)
} }
return nil, fmt.Errorf("no field named %q was found under type User", field.Name) return nil, fmt.Errorf("no field named %q was found under type User", field.Name)
}, },
@ -1871,10 +1895,6 @@ func (ec *executionContext) fieldContext_LoginResponse_user(_ context.Context, f
return ec.fieldContext_User_status(ctx, field) return ec.fieldContext_User_status(ctx, field)
case "verified": case "verified":
return ec.fieldContext_User_verified(ctx, field) return ec.fieldContext_User_verified(ctx, field)
case "password":
return ec.fieldContext_User_password(ctx, field)
case "token":
return ec.fieldContext_User_token(ctx, field)
} }
return nil, fmt.Errorf("no field named %q was found under type User", field.Name) return nil, fmt.Errorf("no field named %q was found under type User", field.Name)
}, },
@ -2240,10 +2260,6 @@ func (ec *executionContext) fieldContext_Mutation_createUser(ctx context.Context
return ec.fieldContext_User_status(ctx, field) return ec.fieldContext_User_status(ctx, field)
case "verified": case "verified":
return ec.fieldContext_User_verified(ctx, field) return ec.fieldContext_User_verified(ctx, field)
case "password":
return ec.fieldContext_User_password(ctx, field)
case "token":
return ec.fieldContext_User_token(ctx, field)
} }
return nil, fmt.Errorf("no field named %q was found under type User", field.Name) return nil, fmt.Errorf("no field named %q was found under type User", field.Name)
}, },
@ -2606,10 +2622,6 @@ func (ec *executionContext) fieldContext_Query_users(_ context.Context, field gr
return ec.fieldContext_User_status(ctx, field) return ec.fieldContext_User_status(ctx, field)
case "verified": case "verified":
return ec.fieldContext_User_verified(ctx, field) return ec.fieldContext_User_verified(ctx, field)
case "password":
return ec.fieldContext_User_password(ctx, field)
case "token":
return ec.fieldContext_User_token(ctx, field)
} }
return nil, fmt.Errorf("no field named %q was found under type User", field.Name) return nil, fmt.Errorf("no field named %q was found under type User", field.Name)
}, },
@ -3092,10 +3104,6 @@ func (ec *executionContext) fieldContext_Todo_createdBy(_ context.Context, field
return ec.fieldContext_User_status(ctx, field) return ec.fieldContext_User_status(ctx, field)
case "verified": case "verified":
return ec.fieldContext_User_verified(ctx, field) return ec.fieldContext_User_verified(ctx, field)
case "password":
return ec.fieldContext_User_password(ctx, field)
case "token":
return ec.fieldContext_User_token(ctx, field)
} }
return nil, fmt.Errorf("no field named %q was found under type User", field.Name) return nil, fmt.Errorf("no field named %q was found under type User", field.Name)
}, },
@ -3198,10 +3206,6 @@ func (ec *executionContext) fieldContext_Todo_updatedBy(_ context.Context, field
return ec.fieldContext_User_status(ctx, field) return ec.fieldContext_User_status(ctx, field)
case "verified": case "verified":
return ec.fieldContext_User_verified(ctx, field) return ec.fieldContext_User_verified(ctx, field)
case "password":
return ec.fieldContext_User_password(ctx, field)
case "token":
return ec.fieldContext_User_token(ctx, field)
} }
return nil, fmt.Errorf("no field named %q was found under type User", field.Name) return nil, fmt.Errorf("no field named %q was found under type User", field.Name)
}, },
@ -3301,10 +3305,6 @@ func (ec *executionContext) fieldContext_Todo_owner(_ context.Context, field gra
return ec.fieldContext_User_status(ctx, field) return ec.fieldContext_User_status(ctx, field)
case "verified": case "verified":
return ec.fieldContext_User_verified(ctx, field) return ec.fieldContext_User_verified(ctx, field)
case "password":
return ec.fieldContext_User_password(ctx, field)
case "token":
return ec.fieldContext_User_token(ctx, field)
} }
return nil, fmt.Errorf("no field named %q was found under type User", field.Name) return nil, fmt.Errorf("no field named %q was found under type User", field.Name)
}, },
@ -3737,88 +3737,6 @@ func (ec *executionContext) fieldContext_User_verified(_ context.Context, field
return fc, nil return fc, nil
} }
func (ec *executionContext) _User_password(ctx context.Context, field graphql.CollectedField, obj *models.User) (ret graphql.Marshaler) {
fc, err := ec.fieldContext_User_password(ctx, field)
if err != nil {
return graphql.Null
}
ctx = graphql.WithFieldContext(ctx, fc)
defer func() {
if r := recover(); r != nil {
ec.Error(ctx, ec.Recover(ctx, r))
ret = graphql.Null
}
}()
resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (any, error) {
ctx = rctx // use context from middleware stack in children
return obj.Password, nil
})
if err != nil {
ec.Error(ctx, err)
return graphql.Null
}
if resTmp == nil {
return graphql.Null
}
res := resTmp.(*string)
fc.Result = res
return ec.marshalOString2ᚖstring(ctx, field.Selections, res)
}
func (ec *executionContext) fieldContext_User_password(_ context.Context, field graphql.CollectedField) (fc *graphql.FieldContext, err error) {
fc = &graphql.FieldContext{
Object: "User",
Field: field,
IsMethod: false,
IsResolver: false,
Child: func(ctx context.Context, field graphql.CollectedField) (*graphql.FieldContext, error) {
return nil, errors.New("field of type String does not have child fields")
},
}
return fc, nil
}
func (ec *executionContext) _User_token(ctx context.Context, field graphql.CollectedField, obj *models.User) (ret graphql.Marshaler) {
fc, err := ec.fieldContext_User_token(ctx, field)
if err != nil {
return graphql.Null
}
ctx = graphql.WithFieldContext(ctx, fc)
defer func() {
if r := recover(); r != nil {
ec.Error(ctx, ec.Recover(ctx, r))
ret = graphql.Null
}
}()
resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (any, error) {
ctx = rctx // use context from middleware stack in children
return obj.Token, nil
})
if err != nil {
ec.Error(ctx, err)
return graphql.Null
}
if resTmp == nil {
return graphql.Null
}
res := resTmp.(*string)
fc.Result = res
return ec.marshalOString2ᚖstring(ctx, field.Selections, res)
}
func (ec *executionContext) fieldContext_User_token(_ context.Context, field graphql.CollectedField) (fc *graphql.FieldContext, err error) {
fc = &graphql.FieldContext{
Object: "User",
Field: field,
IsMethod: false,
IsResolver: false,
Child: func(ctx context.Context, field graphql.CollectedField) (*graphql.FieldContext, error) {
return nil, errors.New("field of type String does not have child fields")
},
}
return fc, nil
}
func (ec *executionContext) ___Directive_name(ctx context.Context, field graphql.CollectedField, obj *introspection.Directive) (ret graphql.Marshaler) { func (ec *executionContext) ___Directive_name(ctx context.Context, field graphql.CollectedField, obj *introspection.Directive) (ret graphql.Marshaler) {
fc, err := ec.fieldContext___Directive_name(ctx, field) fc, err := ec.fieldContext___Directive_name(ctx, field)
if err != nil { if err != nil {
@ -6777,10 +6695,6 @@ func (ec *executionContext) _User(ctx context.Context, sel ast.SelectionSet, obj
} }
out.Concurrently(i, func(ctx context.Context) graphql.Marshaler { return innerFunc(ctx, out) }) out.Concurrently(i, func(ctx context.Context) graphql.Marshaler { return innerFunc(ctx, out) })
case "password":
out.Values[i] = ec._User_password(ctx, field, obj)
case "token":
out.Values[i] = ec._User_token(ctx, field, obj)
default: default:
panic("unknown field " + strconv.Quote(field.Name)) panic("unknown field " + strconv.Quote(field.Name))
} }
@ -7223,13 +7137,13 @@ func (ec *executionContext) unmarshalNCreateUserInput2gitᚗfarahtyᚗcomᚋnime
} }
func (ec *executionContext) unmarshalNID2goᚗmongodbᚗorgᚋmongoᚑdriverᚋbsonᚋprimitiveᚐObjectID(ctx context.Context, v any) (primitive.ObjectID, error) { func (ec *executionContext) unmarshalNID2goᚗmongodbᚗorgᚋmongoᚑdriverᚋbsonᚋprimitiveᚐObjectID(ctx context.Context, v any) (primitive.ObjectID, error) {
res, err := app.UnmarshalObjectID(v) res, err := utils.UnmarshalObjectID(v)
return res, graphql.ErrorOnPath(ctx, err) return res, graphql.ErrorOnPath(ctx, err)
} }
func (ec *executionContext) marshalNID2goᚗmongodbᚗorgᚋmongoᚑdriverᚋbsonᚋprimitiveᚐObjectID(ctx context.Context, sel ast.SelectionSet, v primitive.ObjectID) graphql.Marshaler { func (ec *executionContext) marshalNID2goᚗmongodbᚗorgᚋmongoᚑdriverᚋbsonᚋprimitiveᚐObjectID(ctx context.Context, sel ast.SelectionSet, v primitive.ObjectID) graphql.Marshaler {
_ = sel _ = sel
res := app.MarshalObjectID(v) res := utils.MarshalObjectID(v)
if res == graphql.Null { if res == graphql.Null {
if !graphql.HasFieldError(ctx, graphql.GetFieldContext(ctx)) { if !graphql.HasFieldError(ctx, graphql.GetFieldContext(ctx)) {
ec.Errorf(ctx, "the requested element is null which the schema does not allow") ec.Errorf(ctx, "the requested element is null which the schema does not allow")
@ -7257,6 +7171,16 @@ func (ec *executionContext) marshalNLoginResponse2ᚖgitᚗfarahtyᚗcomᚋnimer
return ec._LoginResponse(ctx, sel, v) return ec._LoginResponse(ctx, sel, v)
} }
func (ec *executionContext) unmarshalNRole2gitᚗfarahtyᚗcomᚋnimerᚋgoᚑmongoᚋmodelsᚐRole(ctx context.Context, v any) (models.Role, error) {
var res models.Role
err := res.UnmarshalGQL(v)
return res, graphql.ErrorOnPath(ctx, err)
}
func (ec *executionContext) marshalNRole2gitᚗfarahtyᚗcomᚋnimerᚋgoᚑmongoᚋmodelsᚐRole(ctx context.Context, sel ast.SelectionSet, v models.Role) graphql.Marshaler {
return v
}
func (ec *executionContext) unmarshalNString2string(ctx context.Context, v any) (string, error) { func (ec *executionContext) unmarshalNString2string(ctx context.Context, v any) (string, error) {
res, err := graphql.UnmarshalString(v) res, err := graphql.UnmarshalString(v)
return res, graphql.ErrorOnPath(ctx, err) return res, graphql.ErrorOnPath(ctx, err)
@ -7671,7 +7595,7 @@ func (ec *executionContext) unmarshalOID2ᚖgoᚗmongodbᚗorgᚋmongoᚑdriver
if v == nil { if v == nil {
return nil, nil return nil, nil
} }
res, err := app.UnmarshalObjectID(v) res, err := utils.UnmarshalObjectID(v)
return &res, graphql.ErrorOnPath(ctx, err) return &res, graphql.ErrorOnPath(ctx, err)
} }
@ -7681,7 +7605,7 @@ func (ec *executionContext) marshalOID2ᚖgoᚗmongodbᚗorgᚋmongoᚑdriverᚋ
} }
_ = sel _ = sel
_ = ctx _ = ctx
res := app.MarshalObjectID(*v) res := utils.MarshalObjectID(*v)
return res return res
} }
@ -7701,6 +7625,42 @@ func (ec *executionContext) marshalOStatus2ᚖgitᚗfarahtyᚗcomᚋnimerᚋgo
return v return v
} }
func (ec *executionContext) unmarshalOString2ᚕstringᚄ(ctx context.Context, v any) ([]string, error) {
if v == nil {
return nil, nil
}
var vSlice []any
vSlice = graphql.CoerceList(v)
var err error
res := make([]string, len(vSlice))
for i := range vSlice {
ctx := graphql.WithPathContext(ctx, graphql.NewPathWithIndex(i))
res[i], err = ec.unmarshalNString2string(ctx, vSlice[i])
if err != nil {
return nil, err
}
}
return res, nil
}
func (ec *executionContext) marshalOString2ᚕstringᚄ(ctx context.Context, sel ast.SelectionSet, v []string) graphql.Marshaler {
if v == nil {
return graphql.Null
}
ret := make(graphql.Array, len(v))
for i := range v {
ret[i] = ec.marshalNString2string(ctx, sel, v[i])
}
for _, e := range ret {
if e == graphql.Null {
return graphql.Null
}
}
return ret
}
func (ec *executionContext) unmarshalOString2ᚖstring(ctx context.Context, v any) (*string, error) { func (ec *executionContext) unmarshalOString2ᚖstring(ctx context.Context, v any) (*string, error) {
if v == nil { if v == nil {
return nil, nil return nil, nil

View File

@ -1,5 +1,12 @@
directive @auth on FIELD_DEFINITION directive @auth on FIELD_DEFINITION
directive @hasRole(role: Role!) on FIELD_DEFINITION
enum Role {
ADMIN
USER
}
type LoginResponse { type LoginResponse {
user: User! user: User!
accessToken: String! accessToken: String!

View File

@ -1,12 +1,28 @@
directive @goModel(
model: String
models: [String!]
forceGenerate: Boolean
) on OBJECT | INPUT_OBJECT | SCALAR | ENUM | INTERFACE | UNION
directive @goField( directive @goField(
forceResolver: Boolean forceResolver: Boolean
name: String name: String
omittable: Boolean
type: String
) on INPUT_FIELD_DEFINITION | FIELD_DEFINITION ) on INPUT_FIELD_DEFINITION | FIELD_DEFINITION
directive @goTag( directive @goTag(
key: String! key: String!
value: String value: String
) on INPUT_FIELD_DEFINITION | FIELD_DEFINITION ) on INPUT_FIELD_DEFINITION | FIELD_DEFINITION
directive @goExtraField(
name: String
type: String!
overrideTags: String
description: String
) repeatable on OBJECT | INPUT_OBJECT
scalar Time scalar Time
interface Base { interface Base {

View File

@ -1,12 +1,12 @@
type User { type User
@goExtraField(name: "Password", type: "*string")
@goExtraField(name: "Token", type: "*string") {
id: ID! id: ID!
phone: String phone: String
email: String email: String
type: String type: String
status: String status: String
verified: Boolean @goField(forceResolver: true) verified: Boolean @goField(forceResolver: true)
password: String
token: String
} }
input CreateUserInput { input CreateUserInput {

View File

@ -27,7 +27,7 @@ autobind:
models: models:
ID: ID:
model: model:
- git.farahty.com/nimer/go-mongo/app.ObjectID - git.farahty.com/nimer/go-mongo/utils.ObjectID
- github.com/99designs/gqlgen/graphql.ID - github.com/99designs/gqlgen/graphql.ID
Int: Int:
model: model:

36
main.go
View File

@ -11,8 +11,6 @@ import (
"git.farahty.com/nimer/go-mongo/app" "git.farahty.com/nimer/go-mongo/app"
"github.com/fatih/color" "github.com/fatih/color"
"github.com/joho/godotenv"
"github.com/redis/go-redis/v9"
) )
func main() { func main() {
@ -30,23 +28,12 @@ func main() {
color.Yellow("🚀 Starting server ...\n") color.Yellow("🚀 Starting server ...\n")
// Load .env if needed // Load and validate config
if _, exists := os.LookupEnv("MONGO_URI"); !exists { err := app.LoadConfig()
if err := godotenv.Load(); err != nil { if err != nil {
log.Fatal("🔴 Failed to load .env file: ", err) log.Fatal("🔴 Config error: ", err)
}
color.Green("✅ .env file loaded\n")
} }
// Validate environment variables
requiredEnvs := []string{"PORT", "MONGO_URI", "REDIS_HOST", "REDIS_PORT"}
for _, key := range requiredEnvs {
if os.Getenv(key) == "" {
log.Fatalf("🔴 Required environment variable %s is missing", key)
}
}
port := os.Getenv("PORT")
// Connect to Mongo // Connect to Mongo
dbCancel, err := app.Connect() dbCancel, err := app.Connect()
if err != nil { if err != nil {
@ -68,26 +55,21 @@ func main() {
color.Green("✅ Authorization policies loaded successfully\n") color.Green("✅ Authorization policies loaded successfully\n")
// Redis // Redis
redisClient := redis.NewClient(&redis.Options{ if err := app.InitRedis(ctx); err != nil {
Addr: os.Getenv("REDIS_HOST") + ":" + os.Getenv("REDIS_PORT"),
Password: os.Getenv("REDIS_PASSWORD"),
DB: 0,
})
if _, err := redisClient.Ping(ctx).Result(); err != nil {
log.Fatal("🔴 Redis connection error: ", err) log.Fatal("🔴 Redis connection error: ", err)
} }
defer func() { defer func() {
color.Red("❌ Closing Redis connection...\n") color.Red("❌ Closing Redis connection...\n")
_ = redisClient.Close() _ = app.RedisClient.Close()
}() }()
color.Green("✅ Connected to Redis cache successfully\n") color.Green("✅ Connected to Redis cache successfully\n")
// Create GraphQL server // Create GraphQL server
graphqlServer := createGraphqlServer(redisClient) graphqlServer := createGraphqlServer()
// Start HTTP server // Start HTTP server
server := &http.Server{ server := &http.Server{
Addr: ":" + port, Addr: ":" + app.Config.Port,
Handler: createRouter(graphqlServer), Handler: createRouter(graphqlServer),
ReadTimeout: 30 * time.Second, ReadTimeout: 30 * time.Second,
WriteTimeout: 30 * time.Second, WriteTimeout: 30 * time.Second,
@ -95,7 +77,7 @@ func main() {
} }
go func() { go func() {
color.Green("🌐 Server listening at http://localhost:%s\n", port) color.Green("🌐 Server listening at http://localhost:%s\n", app.Config.Port)
if err := server.ListenAndServe(); err != nil && err != http.ErrServerClosed { if err := server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
log.Fatalf("🔴 Server failed: %v", err) log.Fatalf("🔴 Server failed: %v", err)
} }

View File

@ -132,8 +132,63 @@ type User struct {
Type *string `json:"type,omitempty" bson:"type,omitempty"` Type *string `json:"type,omitempty" bson:"type,omitempty"`
Status *string `json:"status,omitempty" bson:"status,omitempty"` Status *string `json:"status,omitempty" bson:"status,omitempty"`
Verified *bool `json:"verified,omitempty" bson:"verified,omitempty"` Verified *bool `json:"verified,omitempty" bson:"verified,omitempty"`
Password *string `json:"password,omitempty" bson:"password,omitempty"` Password *string `json:"-" bson:"Password,omitempty"`
Token *string `json:"token,omitempty" bson:"token,omitempty"` Token *string `json:"-" bson:"Token,omitempty"`
}
type Role string
const (
RoleAdmin Role = "ADMIN"
RoleUser Role = "USER"
)
var AllRole = []Role{
RoleAdmin,
RoleUser,
}
func (e Role) IsValid() bool {
switch e {
case RoleAdmin, RoleUser:
return true
}
return false
}
func (e Role) String() string {
return string(e)
}
func (e *Role) UnmarshalGQL(v any) error {
str, ok := v.(string)
if !ok {
return fmt.Errorf("enums must be strings")
}
*e = Role(str)
if !e.IsValid() {
return fmt.Errorf("%s is not a valid Role", str)
}
return nil
}
func (e Role) MarshalGQL(w io.Writer) {
fmt.Fprint(w, strconv.Quote(e.String()))
}
func (e *Role) UnmarshalJSON(b []byte) error {
s, err := strconv.Unquote(string(b))
if err != nil {
return err
}
return e.UnmarshalGQL(s)
}
func (e Role) MarshalJSON() ([]byte, error) {
var buf bytes.Buffer
e.MarshalGQL(&buf)
return buf.Bytes(), nil
} }
type Status string type Status string

View File

@ -7,6 +7,7 @@ package resolvers
import ( import (
"context" "context"
"git.farahty.com/nimer/go-mongo/app"
"git.farahty.com/nimer/go-mongo/generated" "git.farahty.com/nimer/go-mongo/generated"
"git.farahty.com/nimer/go-mongo/models" "git.farahty.com/nimer/go-mongo/models"
authService "git.farahty.com/nimer/go-mongo/services/auth" authService "git.farahty.com/nimer/go-mongo/services/auth"
@ -14,8 +15,7 @@ import (
// Login is the resolver for the login field. // Login is the resolver for the login field.
func (r *mutationResolver) Login(ctx context.Context, input models.LoginInput) (*models.LoginResponse, error) { func (r *mutationResolver) Login(ctx context.Context, input models.LoginInput) (*models.LoginResponse, error) {
return authService.Login(ctx, &input, app.RedisClient)
return authService.Login(ctx, &input)
} }
// Mutation returns generated.MutationResolver implementation. // Mutation returns generated.MutationResolver implementation.

View File

@ -15,15 +15,19 @@ import (
// It serves as dependency injection for your app, add any dependencies you require here. // It serves as dependency injection for your app, add any dependencies you require here.
type Resolver struct { type Resolver struct {
Redis *redis.Client Redis redis.UniversalClient
} }
func Subscribe[T any](ctx context.Context, redis *redis.Client, event string) (<-chan *T, error) { func Subscribe[T any](ctx context.Context, redisClient redis.UniversalClient, event string) (<-chan *T, error) {
client, ok := redisClient.(*redis.Client)
if !ok {
return nil, nil // or handle error
}
clientChannel := make(chan *T, 1) clientChannel := make(chan *T, 1)
go func() { go func() {
sub := redis.Subscribe(ctx, event) sub := client.Subscribe(ctx, event)
if _, err := sub.Receive(ctx); err != nil { if _, err := sub.Receive(ctx); err != nil {
return return
@ -52,5 +56,4 @@ func Subscribe[T any](ctx context.Context, redis *redis.Client, event string) (<
}() }()
return clientChannel, nil return clientChannel, nil
} }

View File

@ -12,6 +12,7 @@ import (
"git.farahty.com/nimer/go-mongo/models" "git.farahty.com/nimer/go-mongo/models"
todoService "git.farahty.com/nimer/go-mongo/services/todo" todoService "git.farahty.com/nimer/go-mongo/services/todo"
userService "git.farahty.com/nimer/go-mongo/services/user" userService "git.farahty.com/nimer/go-mongo/services/user"
redis "github.com/redis/go-redis/v9"
"go.mongodb.org/mongo-driver/bson/primitive" "go.mongodb.org/mongo-driver/bson/primitive"
) )
@ -24,7 +25,9 @@ func (r *mutationResolver) CreateTodo(ctx context.Context, input models.CreateTo
} }
if objJson, err := json.Marshal(obj); err == nil { if objJson, err := json.Marshal(obj); err == nil {
r.Redis.Publish(ctx, "NEW_TODO_EVENT", objJson) if client, ok := r.Redis.(*redis.Client); ok {
client.Publish(ctx, "NEW_TODO_EVENT", objJson)
}
} }
return obj, nil return obj, nil
@ -42,7 +45,10 @@ func (r *queryResolver) Todo(ctx context.Context, id primitive.ObjectID) (*model
// OnTodo is the resolver for the onTodo field. // OnTodo is the resolver for the onTodo field.
func (r *subscriptionResolver) OnTodo(ctx context.Context) (<-chan *models.Todo, error) { func (r *subscriptionResolver) OnTodo(ctx context.Context) (<-chan *models.Todo, error) {
return Subscribe[models.Todo](ctx, r.Redis, "NEW_TODO_EVENT") if client, ok := r.Redis.(*redis.Client); ok {
return Subscribe[models.Todo](ctx, client, "NEW_TODO_EVENT")
}
return nil, nil
} }
// CreatedBy is the resolver for the createdBy field. // CreatedBy is the resolver for the createdBy field.

View File

@ -3,14 +3,13 @@ package main
import ( import (
"log" "log"
"net/http" "net/http"
"os"
"time" "time"
"git.farahty.com/nimer/go-mongo/app" "git.farahty.com/nimer/go-mongo/app"
"git.farahty.com/nimer/go-mongo/controllers" "git.farahty.com/nimer/go-mongo/controllers"
"git.farahty.com/nimer/go-mongo/directives" "git.farahty.com/nimer/go-mongo/directives"
"git.farahty.com/nimer/go-mongo/generated" "git.farahty.com/nimer/go-mongo/generated"
"git.farahty.com/nimer/go-mongo/helpers" "git.farahty.com/nimer/go-mongo/utils"
"github.com/99designs/gqlgen/graphql/handler" "github.com/99designs/gqlgen/graphql/handler"
"github.com/99designs/gqlgen/graphql/handler/extension" "github.com/99designs/gqlgen/graphql/handler/extension"
"github.com/99designs/gqlgen/graphql/handler/transport" "github.com/99designs/gqlgen/graphql/handler/transport"
@ -22,14 +21,13 @@ import (
"github.com/go-chi/httprate" "github.com/go-chi/httprate"
"github.com/gorilla/websocket" "github.com/gorilla/websocket"
"github.com/redis/go-redis/v9"
) )
func createRouter(graphqlServer http.Handler) chi.Router { func createRouter(graphqlServer http.Handler) chi.Router {
router := chi.NewRouter() router := chi.NewRouter()
// Apply middleware // Apply middleware
if os.Getenv("ENV") == "production" { if app.Config.Env == "production" {
router.Use(httprate.LimitByIP(100, 1*time.Minute)) // 100 requests/minute/IP router.Use(httprate.LimitByIP(100, 1*time.Minute)) // 100 requests/minute/IP
} }
@ -42,6 +40,7 @@ func createRouter(graphqlServer http.Handler) chi.Router {
router.Use(app.AuthMiddleware) router.Use(app.AuthMiddleware)
router.Use(app.WriterMiddleware) router.Use(app.WriterMiddleware)
router.Use(app.RequestMiddleware)
// REST routes // REST routes
router.Mount("/users", controllers.UserRouter()) router.Mount("/users", controllers.UserRouter())
@ -53,10 +52,8 @@ func createRouter(graphqlServer http.Handler) chi.Router {
return router return router
} }
func createGraphqlServer(redisClient *redis.Client) http.Handler { func createGraphqlServer() http.Handler {
cache, err := utils.NewCache(app.RedisClient, 24*time.Hour, "apq:")
cache, err := helpers.NewCache(redisClient, 24*time.Hour, "apq:")
if err != nil { if err != nil {
log.Fatalf("cannot create APQ redis cache: %v", err) log.Fatalf("cannot create APQ redis cache: %v", err)
} }
@ -64,7 +61,7 @@ func createGraphqlServer(redisClient *redis.Client) http.Handler {
// Setup gqlgen with resolvers and Redis client // Setup gqlgen with resolvers and Redis client
schema := generated.Config{ schema := generated.Config{
Resolvers: &resolvers.Resolver{ Resolvers: &resolvers.Resolver{
Redis: redisClient, Redis: app.RedisClient,
}, },
} }
@ -91,6 +88,7 @@ func createGraphqlServer(redisClient *redis.Client) http.Handler {
srv.Use(extension.AutomaticPersistedQuery{Cache: cache}) srv.Use(extension.AutomaticPersistedQuery{Cache: cache})
srv.Use(extension.Introspection{}) srv.Use(extension.Introspection{})
// Apply global middleware // Apply global middleware
srv.AroundRootFields(app.RootFieldsAuthorizer) // Check for @auth at root fields srv.AroundRootFields(app.RootFieldsAuthorizer) // Check for @auth at root fields
srv.AroundResponses(app.ExpiryMiddleware) // Token expiry validation srv.AroundResponses(app.ExpiryMiddleware) // Token expiry validation
@ -101,4 +99,6 @@ func createGraphqlServer(redisClient *redis.Client) http.Handler {
func mapDirectives(config *generated.Config) { func mapDirectives(config *generated.Config) {
config.Directives.Auth = directives.Auth config.Directives.Auth = directives.Auth
config.Directives.HasRole = directives.HasRole
} }

View File

@ -3,15 +3,23 @@ package authService
import ( import (
"context" "context"
"fmt" "fmt"
"time"
"git.farahty.com/nimer/go-mongo/app" "git.farahty.com/nimer/go-mongo/app"
"git.farahty.com/nimer/go-mongo/models" "git.farahty.com/nimer/go-mongo/models"
"github.com/redis/go-redis/v9"
"go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/bson"
) )
func Login(ctx context.Context, loginInput *models.LoginInput) (*models.LoginResponse, error) { const maxAttempts = 5
const lockoutDuration = 15 * time.Minute
func Login(ctx context.Context, loginInput *models.LoginInput, redisClient redis.UniversalClient) (*models.LoginResponse, error) {
if locked, _ := isLockedOut(loginInput.Identity, redisClient); locked {
return nil, fmt.Errorf("too many failed attempts, please try again later")
}
// todo : fix the security threats here
filter := bson.D{ filter := bson.D{
{ {
Key: "$or", Key: "$or",
@ -29,8 +37,32 @@ func Login(ctx context.Context, loginInput *models.LoginInput) (*models.LoginRes
} }
if !user.CheckPassword(loginInput.Password) { if !user.CheckPassword(loginInput.Password) {
return nil, fmt.Errorf("incorrect password") logFailedLoginAttempt(ctx, loginInput.Identity, redisClient) // optional
return nil, fmt.Errorf("invalid identity or password")
} }
clearFailedAttempts(ctx, loginInput.Identity, redisClient)
return successLogin(ctx, user) return successLogin(ctx, user)
} }
func isLockedOut(identity string, redisClient redis.UniversalClient) (bool, error) {
key := fmt.Sprintf("login:fail:%s", identity)
attempts, err := redisClient.Get(context.Background(), key).Int()
if err != nil && err != redis.Nil {
return false, err
}
return attempts >= maxAttempts, nil
}
func logFailedLoginAttempt(ctx context.Context, identity string, redisClient redis.UniversalClient) {
key := fmt.Sprintf("login:fail:%s", identity)
redisClient.Incr(ctx, key)
redisClient.Expire(ctx, key, lockoutDuration)
}
func clearFailedAttempts(ctx context.Context, identity string, redisClient redis.UniversalClient) {
key := fmt.Sprintf("login:fail:%s", identity)
redisClient.Del(ctx, key)
}

View File

@ -1,29 +0,0 @@
package authService
import (
"time"
"github.com/golang-jwt/jwt/v4"
)
func createToken(sub, secret, expiry string, payload interface{}) (*string, error) {
duration, err := time.ParseDuration(expiry)
if err != nil {
return nil, err
}
token := jwt.New(jwt.SigningMethodHS256)
claims := token.Claims.(jwt.MapClaims)
claims["sub"] = sub
claims["exp"] = time.Now().Add(duration).Unix()
claims["data"] = payload
signedToken, err := token.SignedString([]byte(secret))
if err != nil {
return nil, err
}
return &signedToken, nil
}

View File

@ -3,37 +3,41 @@ package authService
import ( import (
"context" "context"
"encoding/hex" "encoding/hex"
"fmt"
"net/http" "net/http"
"os"
"git.farahty.com/nimer/go-mongo/app" "git.farahty.com/nimer/go-mongo/app"
"git.farahty.com/nimer/go-mongo/models" "git.farahty.com/nimer/go-mongo/models"
"git.farahty.com/nimer/go-mongo/utils"
"github.com/google/uuid" "github.com/google/uuid"
"go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/bson"
) )
func successLogin(ctx context.Context, user *models.User) (*models.LoginResponse, error) { func successLogin(ctx context.Context, user *models.User) (*models.LoginResponse, error) {
refresh_secret := os.Getenv("REFRESH_SECRET") refresh_secret := app.Config.RefreshSecret
refresh_expiry := os.Getenv("REFRESH_EXPIRY") refresh_expiry := app.Config.RefreshExpiry
access_secret := os.Getenv("ACCESS_SECRET") access_secret := app.Config.AccessSecret
access_expiry := os.Getenv("ACCESS_EXPIRY") access_expiry := app.Config.AccessExpiry
identity := user.Email var identity string
if identity == nil { if user.Email != nil {
identity = user.Phone identity = *user.Email
} else if user.Phone != nil {
identity = *user.Phone
} else {
return nil, fmt.Errorf("user identity not found")
} }
refreshHandle := hex.EncodeToString([]byte(uuid.NewString())) refreshHandle := hex.EncodeToString([]byte(uuid.NewString()))
refreshToken, err := createToken( refreshToken, err := utils.CreateToken(
refreshHandle, refreshHandle,
refresh_secret, refresh_secret,
refresh_expiry, refresh_expiry,
models.UserJWT{ models.UserJWT{
ID: user.ID.Hex(), ID: user.ID.Hex(),
Identity: *identity, Identity: identity,
}, },
) )
@ -41,13 +45,13 @@ func successLogin(ctx context.Context, user *models.User) (*models.LoginResponse
return nil, err return nil, err
} }
accessToken, err := createToken( accessToken, err := utils.CreateToken(
user.ID.Hex(), user.ID.Hex(),
access_secret, access_secret,
access_expiry, access_expiry,
models.UserJWT{ models.UserJWT{
ID: user.ID.Hex(), ID: user.ID.Hex(),
Identity: *identity, Identity: identity,
}, },
) )
@ -55,19 +59,17 @@ func successLogin(ctx context.Context, user *models.User) (*models.LoginResponse
return nil, err return nil, err
} }
user.Token = &refreshHandle _, err = app.Collection("users").UpdateByID(ctx, user.ID, bson.D{
{Key: "$set", Value: bson.D{
_, err = app.Collection("users").UpdateByID(ctx, user.ID, bson.D{{Key: "$set", Value: user}}) {Key: "token", Value: refreshHandle},
}},
})
if err != nil { if err != nil {
return nil, err return nil, err
} }
w, err := app.WriterFor(ctx) w := app.WriterFor(ctx)
if err != nil {
return nil, err
}
http.SetCookie(*w, &http.Cookie{ http.SetCookie(*w, &http.Cookie{
Name: "access_token", Name: "access_token",

View File

@ -1,4 +1,4 @@
package helpers package utils
import ( import (
"context" "context"

View File

@ -1,4 +1,4 @@
package helpers package utils
import ( import (
"crypto/aes" "crypto/aes"

View File

@ -1,4 +1,4 @@
package app package utils
import ( import (
"errors" "errors"
@ -18,7 +18,7 @@ func MarshalObjectID(value primitive.ObjectID) graphql.Marshaler {
} }
func UnmarshalObjectID(value interface{}) (primitive.ObjectID, error) { func UnmarshalObjectID(value any) (primitive.ObjectID, error) {
if str, ok := value.(string); ok { if str, ok := value.(string); ok {

View File

@ -1,17 +1,16 @@
package app package utils
import ( import (
"fmt" "fmt"
"net/http" "net/http"
"os"
"strings" "strings"
"time"
"git.farahty.com/nimer/go-mongo/models"
"github.com/golang-jwt/jwt/v4" "github.com/golang-jwt/jwt/v4"
"github.com/mitchellh/mapstructure" "github.com/mitchellh/mapstructure"
) )
func getTokenFromHeader(r *http.Request) (string, error) { func GetTokenFromHeader(r *http.Request) (string, error) {
authHeader := r.Header.Get("Authorization") authHeader := r.Header.Get("Authorization")
@ -28,7 +27,7 @@ func getTokenFromHeader(r *http.Request) (string, error) {
return strings.TrimSpace(authSlice[1]), nil return strings.TrimSpace(authSlice[1]), nil
} }
func getTokenFromCookie(r *http.Request) (string, error) { func GetTokenFromCookie(r *http.Request) (string, error) {
cookie, err := r.Cookie("access_token") cookie, err := r.Cookie("access_token")
@ -39,13 +38,13 @@ func getTokenFromCookie(r *http.Request) (string, error) {
return strings.TrimSpace(cookie.Value), nil return strings.TrimSpace(cookie.Value), nil
} }
func getUserFromToken(tokenString string) (*models.UserJWT, error) { func GetUserFromToken[T any](tokenString string, secret string) (*T, error) {
token, err := jwt.Parse(tokenString, func(t *jwt.Token) (interface{}, error) { token, err := jwt.Parse(tokenString, func(t *jwt.Token) (any, error) {
if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok { if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, fmt.Errorf("wrong token format ") return nil, fmt.Errorf("wrong token format ")
} }
return []byte(os.Getenv("ACCESS_SECRET")), nil return []byte(secret), nil
}) })
if err != nil { if err != nil {
@ -56,7 +55,7 @@ func getUserFromToken(tokenString string) (*models.UserJWT, error) {
return nil, fmt.Errorf("token is not valid") return nil, fmt.Errorf("token is not valid")
} }
var user *models.UserJWT var user *T
claims := token.Claims.(jwt.MapClaims) claims := token.Claims.(jwt.MapClaims)
if err := mapstructure.Decode(claims["data"], &user); err != nil { if err := mapstructure.Decode(claims["data"], &user); err != nil {
@ -65,3 +64,25 @@ func getUserFromToken(tokenString string) (*models.UserJWT, error) {
return user, nil return user, nil
} }
func CreateToken(sub, secret, expiry string, payload any) (*string, error) {
duration, err := time.ParseDuration(expiry)
if err != nil {
return nil, err
}
token := jwt.New(jwt.SigningMethodHS256)
claims := token.Claims.(jwt.MapClaims)
claims["sub"] = sub
claims["exp"] = time.Now().Add(duration).Unix()
claims["data"] = payload
signedToken, err := token.SignedString([]byte(secret))
if err != nil {
return nil, err
}
return &signedToken, nil
}