Compare commits

..

6 Commits

33 changed files with 831 additions and 338 deletions

View File

@ -3,6 +3,7 @@ package app
import (
"context"
"fmt"
"net/http"
"git.farahty.com/nimer/go-mongo/models"
)
@ -15,9 +16,33 @@ var (
UserKey = &contextKey{"user"}
StatusKey = &contextKey{"status"}
ExpiryKey = &contextKey{"expiry"}
LoadersKey = &contextKey{"dataloaders"}
LoadersKey = &contextKey{"dataLoaders"}
WriterKye = &contextKey{"writer"}
RequestKey = &contextKey{"request"}
)
// LoaderFor retrieves the dataLoaders from context
func LoaderFor(ctx context.Context) *Loaders {
if loaders, ok := ctx.Value(LoadersKey).(*Loaders); ok {
return loaders
}
panic("dataloader not found in context")
}
func WriterFor(ctx context.Context) *http.ResponseWriter {
if writer, ok := ctx.Value(WriterKye).(*http.ResponseWriter); ok {
return writer
}
panic("no writer found in context")
}
func RequestFor(ctx context.Context) *http.Request {
if req, ok := ctx.Value(RequestKey).(*http.Request); ok {
return req
}
panic("no request found in context")
}
// Retrieves the current user from the context
func CurrentUser(ctx context.Context) (*models.UserJWT, error) {
user, _ := ctx.Value(UserKey).(*models.UserJWT)
@ -32,6 +57,7 @@ func CurrentUser(ctx context.Context) (*models.UserJWT, error) {
// Check if the token was marked as expired
func IsTokenExpired(ctx context.Context) bool {
if expired, ok := ctx.Value(ExpiryKey).(bool); ok {
return expired
}

View File

@ -5,9 +5,9 @@ import (
"errors"
"fmt"
"log"
"os"
"git.farahty.com/nimer/go-mongo/models"
"git.farahty.com/nimer/go-mongo/utils"
"github.com/99designs/gqlgen/graphql"
"github.com/99designs/gqlgen/graphql/handler/transport"
"github.com/casbin/casbin/v2"
@ -22,8 +22,7 @@ var (
)
func LoadAuthorizer(ctx context.Context) error {
a, err := mongodbadapter.NewAdapterWithClientOption(options.Client().ApplyURI(os.Getenv("MONGO_URI")), os.Getenv("MONGO_DB"))
a, err := mongodbadapter.NewAdapterWithClientOption(options.Client().ApplyURI(Config.MongoURI), Config.MongoDB)
if err != nil {
return err
}
@ -32,18 +31,15 @@ func LoadAuthorizer(ctx context.Context) error {
return err
} else {
Authorizer = enforcer
}
a.AddPolicy("mongodb", "p", []string{"*", "login", "mutation"})
a.AddPolicy("mongodb", "p", []string{"todos-admin", "createTodo", "mutation"})
a.AddPolicy("mongodb", "p", []string{"todos-admin", "todos", "query"})
a.AddPolicy("mongodb", "p", []string{"todos-admin", "todo", "query"})
a.AddPolicy("mongodb", "p", []string{"category-admin", "createCategory", "mutation"})
a.AddPolicy("mongodb", "p", []string{"category-admin", "categories", "query"})
a.AddPolicy("mongodb", "p", []string{"category-admin", "category", "query"})
a.AddPolicy("mongodb", "p", []string{"users-admin", "users", "query"})
a.AddPolicy("mongodb", "p", []string{"users-admin", "createUser", "mutation"})
@ -51,8 +47,8 @@ func LoadAuthorizer(ctx context.Context) error {
Authorizer.AddGroupingPolicy("admin", "todos-admin")
Authorizer.AddGroupingPolicy("admin", "category-admin")
email := os.Getenv("ADMIN_EMAIL")
password, _ := models.MakeHash(os.Getenv("ADMIN_PASSWORD"))
email := Config.AdminEmail
password, _ := models.MakeHash(Config.AdminPass)
if _, err := FindOne[models.User](ctx, "users", bson.M{"email": email}); err != nil {
log.Println("Creating admin user")
@ -67,7 +63,6 @@ func LoadAuthorizer(ctx context.Context) error {
}
Authorizer.AddRoleForUser(admin.ID.Hex(), "admin")
}
return nil
@ -82,7 +77,7 @@ func AuthorizeWebSocket(ctx context.Context, initPayload transport.InitPayload)
return ctx, &initPayload, nil
}
user, err := getUserFromToken(token)
user, err := utils.GetUserFromToken[models.UserJWT](token, Config.AccessSecret)
if err != nil {
ctx = SetStatus(ctx, err.Error())
@ -119,6 +114,10 @@ func AuthorizeOperation(ctx context.Context) error {
return nil
}
if IsTokenExpired(ctx) && object != "login" {
return fmt.Errorf("token expired")
}
if obj, err := CurrentUser(ctx); err == nil {
user = string(obj.ID)
}

90
app/config.go Normal file
View File

@ -0,0 +1,90 @@
package app
import (
"fmt"
"log"
"os"
"github.com/joho/godotenv"
)
var (
Config *ConfigStruct
)
type ConfigStruct struct {
Port string
MongoURI string
MongoDB string
RedisHost string
RedisPort string
RedisPass string
AdminEmail string
AdminPass string
RefreshSecret string
RefreshExpiry string
AccessSecret string
AccessExpiry string
Env string
}
// LoadConfig loads environment variables into Config struct and validates required fields.
func LoadConfig() error {
if _, exists := os.LookupEnv("MONGO_URI"); !exists {
if err := godotenv.Load(); err != nil {
log.Fatal("🔴 Failed to load .env file: ", err)
}
}
Config = &ConfigStruct{
Port: getEnv("PORT", "8080"),
MongoURI: os.Getenv("MONGO_URI"),
MongoDB: getEnv("MONGO_DB", "test"),
RedisHost: os.Getenv("REDIS_HOST"),
RedisPort: os.Getenv("REDIS_PORT"),
RedisPass: os.Getenv("REDIS_PASSWORD"),
AdminEmail: os.Getenv("ADMIN_EMAIL"),
AdminPass: os.Getenv("ADMIN_PASSWORD"),
RefreshSecret: os.Getenv("REFRESH_SECRET"),
RefreshExpiry: os.Getenv("REFRESH_EXPIRY"),
AccessSecret: os.Getenv("ACCESS_SECRET"),
AccessExpiry: os.Getenv("ACCESS_EXPIRY"),
Env: os.Getenv("ENV"),
}
missing := ""
if Config.Port == "" {
missing += "PORT "
}
if Config.MongoURI == "" {
missing += "MONGO_URI "
}
if Config.RedisHost == "" {
missing += "REDIS_HOST "
}
if Config.RedisPort == "" {
missing += "REDIS_PORT "
}
if Config.AccessSecret == "" {
missing += "ACCESS_SECRET "
}
if Config.RefreshSecret == "" {
missing += "REFRESH_SECRET "
}
if missing != "" {
return fmt.Errorf("missing required environment variables: %s", missing)
}
return nil
}
// GetConfig returns the global config
func GetConfig() *ConfigStruct {
return Config
}
func getEnv(key, def string) string {
if v := os.Getenv(key); v != "" {
return v
}
return def
}

View File

@ -3,7 +3,6 @@ package app
import (
"context"
"fmt"
"os"
"time"
"github.com/fatih/color"
@ -22,7 +21,7 @@ func Connect() (context.CancelFunc, error) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
var err error
Mongo, err = mongo.Connect(ctx, options.Client().ApplyURI(os.Getenv("MONGO_URI")))
Mongo, err = mongo.Connect(ctx, options.Client().ApplyURI(Config.MongoURI))
if err != nil {
color.Red("❌ Database connection failed\n" + err.Error())
return cancel, err
@ -45,7 +44,7 @@ func Disconnect(ctx context.Context) error {
// Collection returns a MongoDB collection
func Collection(name string) *mongo.Collection {
return Mongo.Database(os.Getenv("MONGO_DB")).Collection(name)
return Mongo.Database(Config.MongoDB).Collection(name)
}
// Find returns all matching documents from a collection

View File

@ -3,7 +3,6 @@ package app
import (
"context"
"fmt"
"net/http"
"git.farahty.com/nimer/go-mongo/models"
"github.com/graph-gophers/dataloader"
@ -27,22 +26,6 @@ func NewLoaders() *Loaders {
}
}
// Middleware injects dataloaders into context for each HTTP request
func LoaderMiddleware(loaders *Loaders, next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctxWithLoaders := context.WithValue(r.Context(), LoadersKey, loaders)
next.ServeHTTP(w, r.WithContext(ctxWithLoaders))
})
}
// LoaderFor retrieves the dataloaders from context
func LoaderFor(ctx context.Context) *Loaders {
if loaders, ok := ctx.Value(LoadersKey).(*Loaders); ok {
return loaders
}
panic("dataloader not found in context")
}
// CreateBatch creates a generic batched loader for a MongoDB collection
func CreateBatch[T models.Identifiable](coll string) dataloader.BatchFunc {
return func(ctx context.Context, keys dataloader.Keys) []*dataloader.Result {

View File

@ -3,32 +3,67 @@ package app
import (
"context"
"errors"
"net/http"
"git.farahty.com/nimer/go-mongo/models"
"git.farahty.com/nimer/go-mongo/utils"
"github.com/99designs/gqlgen/graphql"
"github.com/golang-jwt/jwt/v4"
)
// Middleware injects dataLoaders into context for each HTTP request
func LoaderMiddleware(loaders *Loaders, next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctxWithLoaders := context.WithValue(r.Context(), LoadersKey, loaders)
next.ServeHTTP(w, r.WithContext(ctxWithLoaders))
})
}
// ExpiryMiddleware checks for expired tokens in GraphQL resolvers
func ExpiryMiddleware(ctx context.Context, next graphql.ResponseHandler) *graphql.Response {
if IsTokenExpired(ctx) {
return graphql.ErrorResponse(ctx, "token expired")
}
return next(ctx)
}
// add response writer to context for GraphQL resolvers
func WriterMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
ctx := context.WithValue(r.Context(), WriterKye, &rw)
next.ServeHTTP(rw, r.WithContext(ctx))
})
}
// add request to context for GraphQL resolvers
func RequestMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
ctx := context.WithValue(r.Context(), RequestKey, r)
next.ServeHTTP(rw, r.WithContext(ctx))
})
}
// AuthMiddleware parses JWT token and injects user context for HTTP requests
func AuthMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
tokenStr, err := getTokenFromHeader(r)
if err != nil {
ctx := SetStatus(r.Context(), err.Error())
headerToken, headerErr := utils.GetTokenFromHeader(r)
cookieToken, cookieErr := utils.GetTokenFromCookie(r)
if headerErr != nil && cookieErr != nil {
ctx := SetStatus(r.Context(), headerErr.Error())
next.ServeHTTP(rw, r.WithContext(ctx))
return
}
user, err := getUserFromToken(tokenStr)
token := headerToken
if token == "" {
token = cookieToken
}
user, err := utils.GetUserFromToken[models.UserJWT](token, Config.AccessSecret)
ctx := r.Context()
if err != nil {

24
app/redis-client.go Normal file
View File

@ -0,0 +1,24 @@
package app
import (
"context"
"github.com/redis/go-redis/v9"
)
var (
RedisClient redis.UniversalClient
)
// InitRedis initializes the Redis client and assigns it to RedisClient
func InitRedis(ctx context.Context) error {
RedisClient = redis.NewClient(&redis.Options{
Addr: Config.RedisHost + ":" + Config.RedisPort,
Password: Config.RedisPass,
DB: 0,
})
if _, err := RedisClient.Ping(ctx).Result(); err != nil {
return err
}
return nil
}

View File

@ -8,7 +8,7 @@ import (
"github.com/99designs/gqlgen/graphql"
)
func Auth(ctx context.Context, obj interface{}, next graphql.Resolver) (res interface{}, err error) {
func Auth(ctx context.Context, obj any, next graphql.Resolver) (res any, err error) {
if _, err := app.CurrentUser(ctx); err != nil {
return nil, fmt.Errorf("access denied, %s", err.Error())

20
directives/has-role.go Normal file
View File

@ -0,0 +1,20 @@
package directives
import (
"context"
"fmt"
"git.farahty.com/nimer/go-mongo/app"
"git.farahty.com/nimer/go-mongo/models"
"github.com/99designs/gqlgen/graphql"
)
func HasRole(ctx context.Context, obj any, next graphql.Resolver, role models.Role) (res any, err error) {
if _, err := app.CurrentUser(ctx); err != nil {
return nil, fmt.Errorf("access denied, %s", err.Error())
}
return next(ctx)
}

View File

@ -7,6 +7,7 @@ import (
"os"
"strings"
"git.farahty.com/nimer/go-mongo/utils"
"github.com/99designs/gqlgen/api"
"github.com/99designs/gqlgen/codegen/config"
"github.com/99designs/gqlgen/plugin/modelgen"
@ -29,7 +30,7 @@ func mutateHook(b *modelgen.ModelBuild) *modelgen.ModelBuild {
field.Tag = strings.TrimSuffix(field.Tag, `"`) + `,omitempty" bson:"-"`
}
} else {
field.Tag = strings.TrimSuffix(field.Tag, `"`) + `" bson:"` + field.Name + `,omitempty"`
field.Tag = strings.TrimSuffix(field.Tag, `"`) + `" bson:"` + utils.Camelize(field.Name) + `,omitempty"`
}
}
}

View File

@ -13,8 +13,8 @@ import (
"sync/atomic"
"time"
"git.farahty.com/nimer/go-mongo/app"
"git.farahty.com/nimer/go-mongo/models"
"git.farahty.com/nimer/go-mongo/utils"
"github.com/99designs/gqlgen/graphql"
"github.com/99designs/gqlgen/graphql/introspection"
gqlparser "github.com/vektah/gqlparser/v2"
@ -51,12 +51,14 @@ type ResolverRoot interface {
}
type DirectiveRoot struct {
Auth func(ctx context.Context, obj any, next graphql.Resolver) (res any, err error)
Auth func(ctx context.Context, obj any, next graphql.Resolver) (res any, err error)
HasRole func(ctx context.Context, obj any, next graphql.Resolver, role models.Role) (res any, err error)
}
type ComplexityRoot struct {
Category struct {
Body func(childComplexity int) int
Children func(childComplexity int) int
CreatedAt func(childComplexity int) int
CreatedBy func(childComplexity int) int
CreatedByID func(childComplexity int) int
@ -119,10 +121,8 @@ type ComplexityRoot struct {
User struct {
Email func(childComplexity int) int
ID func(childComplexity int) int
Password func(childComplexity int) int
Phone func(childComplexity int) int
Status func(childComplexity int) int
Token func(childComplexity int) int
Type func(childComplexity int) int
Verified func(childComplexity int) int
}
@ -131,6 +131,8 @@ type ComplexityRoot struct {
type CategoryResolver interface {
Parent(ctx context.Context, obj *models.Category) (*models.Category, error)
Children(ctx context.Context, obj *models.Category) ([]*models.Category, error)
CreatedBy(ctx context.Context, obj *models.Category) (*models.User, error)
UpdatedBy(ctx context.Context, obj *models.Category) (*models.User, error)
@ -190,6 +192,13 @@ func (e *executableSchema) Complexity(ctx context.Context, typeName, field strin
return e.complexity.Category.Body(childComplexity), true
case "Category.children":
if e.complexity.Category.Children == nil {
break
}
return e.complexity.Category.Children(childComplexity), true
case "Category.createdAt":
if e.complexity.Category.CreatedAt == nil {
break
@ -507,13 +516,6 @@ func (e *executableSchema) Complexity(ctx context.Context, typeName, field strin
return e.complexity.User.ID(childComplexity), true
case "User.password":
if e.complexity.User.Password == nil {
break
}
return e.complexity.User.Password(childComplexity), true
case "User.phone":
if e.complexity.User.Phone == nil {
break
@ -528,13 +530,6 @@ func (e *executableSchema) Complexity(ctx context.Context, typeName, field strin
return e.complexity.User.Status(childComplexity), true
case "User.token":
if e.complexity.User.Token == nil {
break
}
return e.complexity.User.Token(childComplexity), true
case "User.type":
if e.complexity.User.Type == nil {
break
@ -678,6 +673,13 @@ func (ec *executionContext) introspectType(name string) (*introspection.Type, er
var sources = []*ast.Source{
{Name: "../gql/auth.gql", Input: `directive @auth on FIELD_DEFINITION
directive @hasRole(role: Role!) on FIELD_DEFINITION
enum Role {
ADMIN
USER
}
type LoginResponse {
user: User!
accessToken: String!
@ -693,15 +695,31 @@ extend type Mutation {
login(input: LoginInput!): LoginResponse!
}
`, BuiltIn: false},
{Name: "../gql/base.gql", Input: `directive @goField(
{Name: "../gql/base.gql", Input: `directive @goModel(
model: String
models: [String!]
forceGenerate: Boolean
) on OBJECT | INPUT_OBJECT | SCALAR | ENUM | INTERFACE | UNION
directive @goField(
forceResolver: Boolean
name: String
omittable: Boolean
type: String
) on INPUT_FIELD_DEFINITION | FIELD_DEFINITION
directive @goTag(
key: String!
value: String
) on INPUT_FIELD_DEFINITION | FIELD_DEFINITION
directive @goExtraField(
name: String
type: String!
overrideTags: String
description: String
) repeatable on OBJECT | INPUT_OBJECT
scalar Time
interface Base {
@ -744,6 +762,9 @@ input TranslatedInput {
parent: Category @goField(forceResolver: true)
parentId: ID
"#bson:ignore"
children: [Category] @goField(forceResolver: true)
createdAt: Time!
updatedAt: Time!
@ -813,15 +834,15 @@ extend type Subscription {
onTodo: Todo
}
`, BuiltIn: false},
{Name: "../gql/user.gql", Input: `type User {
{Name: "../gql/user.gql", Input: `type User
@goExtraField(name: "Password", type: "*string")
@goExtraField(name: "Token", type: "*string") {
id: ID!
phone: String
email: String
type: String
status: String
verified: Boolean @goField(forceResolver: true)
password: String
token: String
}
input CreateUserInput {
@ -846,6 +867,34 @@ var parsedSchema = gqlparser.MustLoadSchema(sources...)
// region ***************************** args.gotpl *****************************
func (ec *executionContext) dir_hasRole_args(ctx context.Context, rawArgs map[string]any) (map[string]any, error) {
var err error
args := map[string]any{}
arg0, err := ec.dir_hasRole_argsRole(ctx, rawArgs)
if err != nil {
return nil, err
}
args["role"] = arg0
return args, nil
}
func (ec *executionContext) dir_hasRole_argsRole(
ctx context.Context,
rawArgs map[string]any,
) (models.Role, error) {
if _, ok := rawArgs["role"]; !ok {
var zeroVal models.Role
return zeroVal, nil
}
ctx = graphql.WithPathContext(ctx, graphql.NewPathWithField("role"))
if tmp, ok := rawArgs["role"]; ok {
return ec.unmarshalNRole2gitᚗfarahtyᚗcomᚋnimerᚋgoᚑmongoᚋmodelsᚐRole(ctx, tmp)
}
var zeroVal models.Role
return zeroVal, nil
}
func (ec *executionContext) field_Mutation_createCategory_args(ctx context.Context, rawArgs map[string]any) (map[string]any, error) {
var err error
args := map[string]any{}
@ -1353,6 +1402,8 @@ func (ec *executionContext) fieldContext_Category_parent(_ context.Context, fiel
return ec.fieldContext_Category_parent(ctx, field)
case "parentId":
return ec.fieldContext_Category_parentId(ctx, field)
case "children":
return ec.fieldContext_Category_children(ctx, field)
case "createdAt":
return ec.fieldContext_Category_createdAt(ctx, field)
case "updatedAt":
@ -1417,6 +1468,77 @@ func (ec *executionContext) fieldContext_Category_parentId(_ context.Context, fi
return fc, nil
}
func (ec *executionContext) _Category_children(ctx context.Context, field graphql.CollectedField, obj *models.Category) (ret graphql.Marshaler) {
fc, err := ec.fieldContext_Category_children(ctx, field)
if err != nil {
return graphql.Null
}
ctx = graphql.WithFieldContext(ctx, fc)
defer func() {
if r := recover(); r != nil {
ec.Error(ctx, ec.Recover(ctx, r))
ret = graphql.Null
}
}()
resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (any, error) {
ctx = rctx // use context from middleware stack in children
return ec.resolvers.Category().Children(rctx, obj)
})
if err != nil {
ec.Error(ctx, err)
return graphql.Null
}
if resTmp == nil {
return graphql.Null
}
res := resTmp.([]*models.Category)
fc.Result = res
return ec.marshalOCategory2ᚕᚖgitᚗfarahtyᚗcomᚋnimerᚋgoᚑmongoᚋmodelsᚐCategory(ctx, field.Selections, res)
}
func (ec *executionContext) fieldContext_Category_children(_ context.Context, field graphql.CollectedField) (fc *graphql.FieldContext, err error) {
fc = &graphql.FieldContext{
Object: "Category",
Field: field,
IsMethod: true,
IsResolver: true,
Child: func(ctx context.Context, field graphql.CollectedField) (*graphql.FieldContext, error) {
switch field.Name {
case "id":
return ec.fieldContext_Category_id(ctx, field)
case "title":
return ec.fieldContext_Category_title(ctx, field)
case "body":
return ec.fieldContext_Category_body(ctx, field)
case "parent":
return ec.fieldContext_Category_parent(ctx, field)
case "parentId":
return ec.fieldContext_Category_parentId(ctx, field)
case "children":
return ec.fieldContext_Category_children(ctx, field)
case "createdAt":
return ec.fieldContext_Category_createdAt(ctx, field)
case "updatedAt":
return ec.fieldContext_Category_updatedAt(ctx, field)
case "createdBy":
return ec.fieldContext_Category_createdBy(ctx, field)
case "createdById":
return ec.fieldContext_Category_createdById(ctx, field)
case "updatedBy":
return ec.fieldContext_Category_updatedBy(ctx, field)
case "updatedById":
return ec.fieldContext_Category_updatedById(ctx, field)
case "owner":
return ec.fieldContext_Category_owner(ctx, field)
case "ownerId":
return ec.fieldContext_Category_ownerId(ctx, field)
}
return nil, fmt.Errorf("no field named %q was found under type Category", field.Name)
},
}
return fc, nil
}
func (ec *executionContext) _Category_createdAt(ctx context.Context, field graphql.CollectedField, obj *models.Category) (ret graphql.Marshaler) {
fc, err := ec.fieldContext_Category_createdAt(ctx, field)
if err != nil {
@ -1556,10 +1678,6 @@ func (ec *executionContext) fieldContext_Category_createdBy(_ context.Context, f
return ec.fieldContext_User_status(ctx, field)
case "verified":
return ec.fieldContext_User_verified(ctx, field)
case "password":
return ec.fieldContext_User_password(ctx, field)
case "token":
return ec.fieldContext_User_token(ctx, field)
}
return nil, fmt.Errorf("no field named %q was found under type User", field.Name)
},
@ -1662,10 +1780,6 @@ func (ec *executionContext) fieldContext_Category_updatedBy(_ context.Context, f
return ec.fieldContext_User_status(ctx, field)
case "verified":
return ec.fieldContext_User_verified(ctx, field)
case "password":
return ec.fieldContext_User_password(ctx, field)
case "token":
return ec.fieldContext_User_token(ctx, field)
}
return nil, fmt.Errorf("no field named %q was found under type User", field.Name)
},
@ -1765,10 +1879,6 @@ func (ec *executionContext) fieldContext_Category_owner(_ context.Context, field
return ec.fieldContext_User_status(ctx, field)
case "verified":
return ec.fieldContext_User_verified(ctx, field)
case "password":
return ec.fieldContext_User_password(ctx, field)
case "token":
return ec.fieldContext_User_token(ctx, field)
}
return nil, fmt.Errorf("no field named %q was found under type User", field.Name)
},
@ -1871,10 +1981,6 @@ func (ec *executionContext) fieldContext_LoginResponse_user(_ context.Context, f
return ec.fieldContext_User_status(ctx, field)
case "verified":
return ec.fieldContext_User_verified(ctx, field)
case "password":
return ec.fieldContext_User_password(ctx, field)
case "token":
return ec.fieldContext_User_token(ctx, field)
}
return nil, fmt.Errorf("no field named %q was found under type User", field.Name)
},
@ -2082,6 +2188,8 @@ func (ec *executionContext) fieldContext_Mutation_createCategory(ctx context.Con
return ec.fieldContext_Category_parent(ctx, field)
case "parentId":
return ec.fieldContext_Category_parentId(ctx, field)
case "children":
return ec.fieldContext_Category_children(ctx, field)
case "createdAt":
return ec.fieldContext_Category_createdAt(ctx, field)
case "updatedAt":
@ -2240,10 +2348,6 @@ func (ec *executionContext) fieldContext_Mutation_createUser(ctx context.Context
return ec.fieldContext_User_status(ctx, field)
case "verified":
return ec.fieldContext_User_verified(ctx, field)
case "password":
return ec.fieldContext_User_password(ctx, field)
case "token":
return ec.fieldContext_User_token(ctx, field)
}
return nil, fmt.Errorf("no field named %q was found under type User", field.Name)
},
@ -2311,6 +2415,8 @@ func (ec *executionContext) fieldContext_Query_categories(_ context.Context, fie
return ec.fieldContext_Category_parent(ctx, field)
case "parentId":
return ec.fieldContext_Category_parentId(ctx, field)
case "children":
return ec.fieldContext_Category_children(ctx, field)
case "createdAt":
return ec.fieldContext_Category_createdAt(ctx, field)
case "updatedAt":
@ -2383,6 +2489,8 @@ func (ec *executionContext) fieldContext_Query_category(ctx context.Context, fie
return ec.fieldContext_Category_parent(ctx, field)
case "parentId":
return ec.fieldContext_Category_parentId(ctx, field)
case "children":
return ec.fieldContext_Category_children(ctx, field)
case "createdAt":
return ec.fieldContext_Category_createdAt(ctx, field)
case "updatedAt":
@ -2606,10 +2714,6 @@ func (ec *executionContext) fieldContext_Query_users(_ context.Context, field gr
return ec.fieldContext_User_status(ctx, field)
case "verified":
return ec.fieldContext_User_verified(ctx, field)
case "password":
return ec.fieldContext_User_password(ctx, field)
case "token":
return ec.fieldContext_User_token(ctx, field)
}
return nil, fmt.Errorf("no field named %q was found under type User", field.Name)
},
@ -3092,10 +3196,6 @@ func (ec *executionContext) fieldContext_Todo_createdBy(_ context.Context, field
return ec.fieldContext_User_status(ctx, field)
case "verified":
return ec.fieldContext_User_verified(ctx, field)
case "password":
return ec.fieldContext_User_password(ctx, field)
case "token":
return ec.fieldContext_User_token(ctx, field)
}
return nil, fmt.Errorf("no field named %q was found under type User", field.Name)
},
@ -3198,10 +3298,6 @@ func (ec *executionContext) fieldContext_Todo_updatedBy(_ context.Context, field
return ec.fieldContext_User_status(ctx, field)
case "verified":
return ec.fieldContext_User_verified(ctx, field)
case "password":
return ec.fieldContext_User_password(ctx, field)
case "token":
return ec.fieldContext_User_token(ctx, field)
}
return nil, fmt.Errorf("no field named %q was found under type User", field.Name)
},
@ -3301,10 +3397,6 @@ func (ec *executionContext) fieldContext_Todo_owner(_ context.Context, field gra
return ec.fieldContext_User_status(ctx, field)
case "verified":
return ec.fieldContext_User_verified(ctx, field)
case "password":
return ec.fieldContext_User_password(ctx, field)
case "token":
return ec.fieldContext_User_token(ctx, field)
}
return nil, fmt.Errorf("no field named %q was found under type User", field.Name)
},
@ -3737,88 +3829,6 @@ func (ec *executionContext) fieldContext_User_verified(_ context.Context, field
return fc, nil
}
func (ec *executionContext) _User_password(ctx context.Context, field graphql.CollectedField, obj *models.User) (ret graphql.Marshaler) {
fc, err := ec.fieldContext_User_password(ctx, field)
if err != nil {
return graphql.Null
}
ctx = graphql.WithFieldContext(ctx, fc)
defer func() {
if r := recover(); r != nil {
ec.Error(ctx, ec.Recover(ctx, r))
ret = graphql.Null
}
}()
resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (any, error) {
ctx = rctx // use context from middleware stack in children
return obj.Password, nil
})
if err != nil {
ec.Error(ctx, err)
return graphql.Null
}
if resTmp == nil {
return graphql.Null
}
res := resTmp.(*string)
fc.Result = res
return ec.marshalOString2ᚖstring(ctx, field.Selections, res)
}
func (ec *executionContext) fieldContext_User_password(_ context.Context, field graphql.CollectedField) (fc *graphql.FieldContext, err error) {
fc = &graphql.FieldContext{
Object: "User",
Field: field,
IsMethod: false,
IsResolver: false,
Child: func(ctx context.Context, field graphql.CollectedField) (*graphql.FieldContext, error) {
return nil, errors.New("field of type String does not have child fields")
},
}
return fc, nil
}
func (ec *executionContext) _User_token(ctx context.Context, field graphql.CollectedField, obj *models.User) (ret graphql.Marshaler) {
fc, err := ec.fieldContext_User_token(ctx, field)
if err != nil {
return graphql.Null
}
ctx = graphql.WithFieldContext(ctx, fc)
defer func() {
if r := recover(); r != nil {
ec.Error(ctx, ec.Recover(ctx, r))
ret = graphql.Null
}
}()
resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (any, error) {
ctx = rctx // use context from middleware stack in children
return obj.Token, nil
})
if err != nil {
ec.Error(ctx, err)
return graphql.Null
}
if resTmp == nil {
return graphql.Null
}
res := resTmp.(*string)
fc.Result = res
return ec.marshalOString2ᚖstring(ctx, field.Selections, res)
}
func (ec *executionContext) fieldContext_User_token(_ context.Context, field graphql.CollectedField) (fc *graphql.FieldContext, err error) {
fc = &graphql.FieldContext{
Object: "User",
Field: field,
IsMethod: false,
IsResolver: false,
Child: func(ctx context.Context, field graphql.CollectedField) (*graphql.FieldContext, error) {
return nil, errors.New("field of type String does not have child fields")
},
}
return fc, nil
}
func (ec *executionContext) ___Directive_name(ctx context.Context, field graphql.CollectedField, obj *introspection.Directive) (ret graphql.Marshaler) {
fc, err := ec.fieldContext___Directive_name(ctx, field)
if err != nil {
@ -6061,6 +6071,39 @@ func (ec *executionContext) _Category(ctx context.Context, sel ast.SelectionSet,
out.Concurrently(i, func(ctx context.Context) graphql.Marshaler { return innerFunc(ctx, out) })
case "parentId":
out.Values[i] = ec._Category_parentId(ctx, field, obj)
case "children":
field := field
innerFunc := func(ctx context.Context, _ *graphql.FieldSet) (res graphql.Marshaler) {
defer func() {
if r := recover(); r != nil {
ec.Error(ctx, ec.Recover(ctx, r))
}
}()
res = ec._Category_children(ctx, field, obj)
return res
}
if field.Deferrable != nil {
dfs, ok := deferred[field.Deferrable.Label]
di := 0
if ok {
dfs.AddField(field)
di = len(dfs.Values) - 1
} else {
dfs = graphql.NewFieldSet([]graphql.CollectedField{field})
deferred[field.Deferrable.Label] = dfs
}
dfs.Concurrently(di, func(ctx context.Context) graphql.Marshaler {
return innerFunc(ctx, dfs)
})
// don't run the out.Concurrently() call below
out.Values[i] = graphql.Null
continue
}
out.Concurrently(i, func(ctx context.Context) graphql.Marshaler { return innerFunc(ctx, out) })
case "createdAt":
out.Values[i] = ec._Category_createdAt(ctx, field, obj)
if out.Values[i] == graphql.Null {
@ -6777,10 +6820,6 @@ func (ec *executionContext) _User(ctx context.Context, sel ast.SelectionSet, obj
}
out.Concurrently(i, func(ctx context.Context) graphql.Marshaler { return innerFunc(ctx, out) })
case "password":
out.Values[i] = ec._User_password(ctx, field, obj)
case "token":
out.Values[i] = ec._User_token(ctx, field, obj)
default:
panic("unknown field " + strconv.Quote(field.Name))
}
@ -7223,13 +7262,13 @@ func (ec *executionContext) unmarshalNCreateUserInput2gitᚗfarahtyᚗcomᚋnime
}
func (ec *executionContext) unmarshalNID2goᚗmongodbᚗorgᚋmongoᚑdriverᚋbsonᚋprimitiveᚐObjectID(ctx context.Context, v any) (primitive.ObjectID, error) {
res, err := app.UnmarshalObjectID(v)
res, err := utils.UnmarshalObjectID(v)
return res, graphql.ErrorOnPath(ctx, err)
}
func (ec *executionContext) marshalNID2goᚗmongodbᚗorgᚋmongoᚑdriverᚋbsonᚋprimitiveᚐObjectID(ctx context.Context, sel ast.SelectionSet, v primitive.ObjectID) graphql.Marshaler {
_ = sel
res := app.MarshalObjectID(v)
res := utils.MarshalObjectID(v)
if res == graphql.Null {
if !graphql.HasFieldError(ctx, graphql.GetFieldContext(ctx)) {
ec.Errorf(ctx, "the requested element is null which the schema does not allow")
@ -7257,6 +7296,16 @@ func (ec *executionContext) marshalNLoginResponse2ᚖgitᚗfarahtyᚗcomᚋnimer
return ec._LoginResponse(ctx, sel, v)
}
func (ec *executionContext) unmarshalNRole2gitᚗfarahtyᚗcomᚋnimerᚋgoᚑmongoᚋmodelsᚐRole(ctx context.Context, v any) (models.Role, error) {
var res models.Role
err := res.UnmarshalGQL(v)
return res, graphql.ErrorOnPath(ctx, err)
}
func (ec *executionContext) marshalNRole2gitᚗfarahtyᚗcomᚋnimerᚋgoᚑmongoᚋmodelsᚐRole(ctx context.Context, sel ast.SelectionSet, v models.Role) graphql.Marshaler {
return v
}
func (ec *executionContext) unmarshalNString2string(ctx context.Context, v any) (string, error) {
res, err := graphql.UnmarshalString(v)
return res, graphql.ErrorOnPath(ctx, err)
@ -7660,6 +7709,47 @@ func (ec *executionContext) marshalOBoolean2ᚖbool(ctx context.Context, sel ast
return res
}
func (ec *executionContext) marshalOCategory2ᚕᚖgitᚗfarahtyᚗcomᚋnimerᚋgoᚑmongoᚋmodelsᚐCategory(ctx context.Context, sel ast.SelectionSet, v []*models.Category) graphql.Marshaler {
if v == nil {
return graphql.Null
}
ret := make(graphql.Array, len(v))
var wg sync.WaitGroup
isLen1 := len(v) == 1
if !isLen1 {
wg.Add(len(v))
}
for i := range v {
i := i
fc := &graphql.FieldContext{
Index: &i,
Result: &v[i],
}
ctx := graphql.WithFieldContext(ctx, fc)
f := func(i int) {
defer func() {
if r := recover(); r != nil {
ec.Error(ctx, ec.Recover(ctx, r))
ret = nil
}
}()
if !isLen1 {
defer wg.Done()
}
ret[i] = ec.marshalOCategory2ᚖgitᚗfarahtyᚗcomᚋnimerᚋgoᚑmongoᚋmodelsᚐCategory(ctx, sel, v[i])
}
if isLen1 {
f(i)
} else {
go f(i)
}
}
wg.Wait()
return ret
}
func (ec *executionContext) marshalOCategory2ᚖgitᚗfarahtyᚗcomᚋnimerᚋgoᚑmongoᚋmodelsᚐCategory(ctx context.Context, sel ast.SelectionSet, v *models.Category) graphql.Marshaler {
if v == nil {
return graphql.Null
@ -7671,7 +7761,7 @@ func (ec *executionContext) unmarshalOID2ᚖgoᚗmongodbᚗorgᚋmongoᚑdriver
if v == nil {
return nil, nil
}
res, err := app.UnmarshalObjectID(v)
res, err := utils.UnmarshalObjectID(v)
return &res, graphql.ErrorOnPath(ctx, err)
}
@ -7681,7 +7771,7 @@ func (ec *executionContext) marshalOID2ᚖgoᚗmongodbᚗorgᚋmongoᚑdriverᚋ
}
_ = sel
_ = ctx
res := app.MarshalObjectID(*v)
res := utils.MarshalObjectID(*v)
return res
}
@ -7701,6 +7791,42 @@ func (ec *executionContext) marshalOStatus2ᚖgitᚗfarahtyᚗcomᚋnimerᚋgo
return v
}
func (ec *executionContext) unmarshalOString2ᚕstringᚄ(ctx context.Context, v any) ([]string, error) {
if v == nil {
return nil, nil
}
var vSlice []any
vSlice = graphql.CoerceList(v)
var err error
res := make([]string, len(vSlice))
for i := range vSlice {
ctx := graphql.WithPathContext(ctx, graphql.NewPathWithIndex(i))
res[i], err = ec.unmarshalNString2string(ctx, vSlice[i])
if err != nil {
return nil, err
}
}
return res, nil
}
func (ec *executionContext) marshalOString2ᚕstringᚄ(ctx context.Context, sel ast.SelectionSet, v []string) graphql.Marshaler {
if v == nil {
return graphql.Null
}
ret := make(graphql.Array, len(v))
for i := range v {
ret[i] = ec.marshalNString2string(ctx, sel, v[i])
}
for _, e := range ret {
if e == graphql.Null {
return graphql.Null
}
}
return ret
}
func (ec *executionContext) unmarshalOString2ᚖstring(ctx context.Context, v any) (*string, error) {
if v == nil {
return nil, nil

View File

@ -1,5 +1,12 @@
directive @auth on FIELD_DEFINITION
directive @hasRole(role: Role!) on FIELD_DEFINITION
enum Role {
ADMIN
USER
}
type LoginResponse {
user: User!
accessToken: String!

View File

@ -1,12 +1,28 @@
directive @goModel(
model: String
models: [String!]
forceGenerate: Boolean
) on OBJECT | INPUT_OBJECT | SCALAR | ENUM | INTERFACE | UNION
directive @goField(
forceResolver: Boolean
name: String
omittable: Boolean
type: String
) on INPUT_FIELD_DEFINITION | FIELD_DEFINITION
directive @goTag(
key: String!
value: String
) on INPUT_FIELD_DEFINITION | FIELD_DEFINITION
directive @goExtraField(
name: String
type: String!
overrideTags: String
description: String
) repeatable on OBJECT | INPUT_OBJECT
scalar Time
interface Base {

View File

@ -8,6 +8,9 @@ type Category implements Base {
parent: Category @goField(forceResolver: true)
parentId: ID
"#bson:ignore"
children: [Category] @goField(forceResolver: true)
createdAt: Time!
updatedAt: Time!

View File

@ -1,12 +1,12 @@
type User {
type User
@goExtraField(name: "Password", type: "*string")
@goExtraField(name: "Token", type: "*string") {
id: ID!
phone: String
email: String
type: String
status: String
verified: Boolean @goField(forceResolver: true)
password: String
token: String
}
input CreateUserInput {

View File

@ -1,12 +1,10 @@
schema:
- gql/*.gql
exec:
filename: generated/generated.go
package: generated
model:
filename: models/models_gen.go
package: models
@ -26,11 +24,11 @@ autobind:
# your liking
models:
ID:
model:
- git.farahty.com/nimer/go-mongo/app.ObjectID
model:
- git.farahty.com/nimer/go-mongo/utils.ObjectID
- github.com/99designs/gqlgen/graphql.ID
Int:
model:
- github.com/99designs/gqlgen/graphql.Int
- github.com/99designs/gqlgen/graphql.Int64
- github.com/99designs/gqlgen/graphql.Int32
- github.com/99designs/gqlgen/graphql.Int32

125
main.go
View File

@ -2,98 +2,101 @@ package main
import (
"context"
"log"
"net/http"
"os"
"os/signal"
"time"
"log"
"net/http"
"git.farahty.com/nimer/go-mongo/app"
"github.com/fatih/color"
"github.com/joho/godotenv"
"github.com/redis/go-redis/v9"
)
func main() {
// Setup cancelable root context
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
color.Yellow("Starting farahty server ...\n")
if _, exists := os.LookupEnv("MONGO_URI"); !exists {
err := godotenv.Load()
if err != nil {
log.Fatal("Error loading .env file\n")
}
color.Green("✅ .env loaded\n")
}
port := os.Getenv("PORT")
if cancel, err := app.Connect(); err != nil {
cancel()
log.Fatal(err)
} else {
defer func() {
color.Red("❌ Database Connection Closed\n")
cancel()
}()
}
// Panic recovery
defer func() {
if err := app.Mongo.Disconnect(context.Background()); err != nil {
log.Fatal("MogoDB Errors" + err.Error())
if r := recover(); r != nil {
color.Red("🔴 Panic occurred: %v\n", r)
os.Exit(1)
}
}()
color.Green("✅ Connected to Database successfully\n")
if err := app.LoadAuthorizer(context.Background()); err != nil {
log.Fatal("Authorizer Errors : " + err.Error())
color.Yellow("🚀 Starting server ...\n")
// Load and validate config
err := app.LoadConfig()
if err != nil {
log.Fatal("🔴 Config error: ", err)
}
// Connect to Mongo
dbCancel, err := app.Connect()
if err != nil {
log.Fatalf("🔴 MongoDB connection error: %v", err)
}
defer func() {
color.Red("❌ Closing MongoDB connection...\n")
dbCancel()
if err := app.Mongo.Disconnect(ctx); err != nil {
log.Fatal("🔴 MongoDB disconnection error: ", err)
}
}()
color.Green("✅ Connected to MongoDB successfully\n")
// Load authorization policies using root context
if err := app.LoadAuthorizer(ctx); err != nil {
log.Fatal("🔴 Authorizer error: ", err)
}
color.Green("✅ Authorization policies loaded successfully\n")
redisClient := redis.NewClient(&redis.Options{
Addr: os.Getenv("REDIS_HOST") + ":" + os.Getenv("REDIS_PORT"),
Password: os.Getenv("REDIS_PASSWORD"), // no password set
})
if _, err := redisClient.Ping(context.Background()).Result(); err != nil {
log.Fatal("Redis Error : " + err.Error())
// Redis
if err := app.InitRedis(ctx); err != nil {
log.Fatal("🔴 Redis connection error: ", err)
}
defer redisClient.Close()
defer func() {
color.Red("❌ Closing Redis connection...\n")
_ = app.RedisClient.Close()
}()
color.Green("✅ Connected to Redis cache successfully\n")
graphqlServer := createGraphqlServer(redisClient)
color.Green("🚀 Server Started at http://localhost:" + port + "\n")
//http.ListenAndServe(":"+port, createRouter(graphqlServer))
// Create GraphQL server
graphqlServer := createGraphqlServer()
// Start HTTP server
server := &http.Server{
Addr: ":" + port,
WriteTimeout: time.Second * 30,
ReadTimeout: time.Second * 30,
IdleTimeout: time.Second * 30,
Addr: ":" + app.Config.Port,
Handler: createRouter(graphqlServer),
ReadTimeout: 30 * time.Second,
WriteTimeout: 30 * time.Second,
IdleTimeout: 30 * time.Second,
}
go server.ListenAndServe()
go func() {
color.Green("🌐 Server listening at http://localhost:%s\n", app.Config.Port)
if err := server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
log.Fatalf("🔴 Server failed: %v", err)
}
}()
// Wait for interrupt signal to gracefully shut down the server with
// a timeout of 15 seconds.
// Graceful shutdown
quit := make(chan os.Signal, 1)
signal.Notify(quit, os.Interrupt)
<-quit
color.Yellow(" 🎬 Start Shutdown Signal ... ")
ctx, cancelShutdown := context.WithTimeout(context.Background(), 15*time.Second)
defer cancelShutdown()
if err := server.Shutdown(ctx); err != nil {
log.Fatal("Server Shutdown:", err)
color.Yellow("🟡 Shutdown signal received, initiating cleanup...")
// Cancel root context and wait for graceful shutdown
shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 15*time.Second)
defer shutdownCancel()
if err := server.Shutdown(shutdownCtx); err != nil {
log.Fatalf("🔴 Server forced to shutdown: %v", err)
}
color.Red("❌ Server Exiting")
color.Green("✅ Server shutdown completed gracefully")
}

View File

@ -27,10 +27,12 @@ type Category struct {
Title []*Translated `json:"title" bson:"title,omitempty"`
Body []*Translated `json:"body,omitempty" bson:"body,omitempty"`
// #bson:ignore
Parent *Category `json:"parent,omitempty,omitempty" bson:"-"`
ParentID *primitive.ObjectID `json:"parentId,omitempty" bson:"parentId,omitempty"`
CreatedAt time.Time `json:"createdAt" bson:"createdAt,omitempty"`
UpdatedAt time.Time `json:"updatedAt" bson:"updatedAt,omitempty"`
Parent *Category `json:"parent,omitempty,omitempty" bson:"-"`
ParentID *primitive.ObjectID `json:"parentId,omitempty" bson:"parentId,omitempty"`
// #bson:ignore
Children []*Category `json:"children,omitempty,omitempty" bson:"-"`
CreatedAt time.Time `json:"createdAt" bson:"createdAt,omitempty"`
UpdatedAt time.Time `json:"updatedAt" bson:"updatedAt,omitempty"`
// #bson:ignore
CreatedBy *User `json:"createdBy,omitempty" bson:"-"`
CreatedByID primitive.ObjectID `json:"createdById" bson:"createdById,omitempty"`
@ -132,8 +134,63 @@ type User struct {
Type *string `json:"type,omitempty" bson:"type,omitempty"`
Status *string `json:"status,omitempty" bson:"status,omitempty"`
Verified *bool `json:"verified,omitempty" bson:"verified,omitempty"`
Password *string `json:"password,omitempty" bson:"password,omitempty"`
Token *string `json:"token,omitempty" bson:"token,omitempty"`
Password *string `json:"-" bson:"Password,omitempty"`
Token *string `json:"-" bson:"Token,omitempty"`
}
type Role string
const (
RoleAdmin Role = "ADMIN"
RoleUser Role = "USER"
)
var AllRole = []Role{
RoleAdmin,
RoleUser,
}
func (e Role) IsValid() bool {
switch e {
case RoleAdmin, RoleUser:
return true
}
return false
}
func (e Role) String() string {
return string(e)
}
func (e *Role) UnmarshalGQL(v any) error {
str, ok := v.(string)
if !ok {
return fmt.Errorf("enums must be strings")
}
*e = Role(str)
if !e.IsValid() {
return fmt.Errorf("%s is not a valid Role", str)
}
return nil
}
func (e Role) MarshalGQL(w io.Writer) {
fmt.Fprint(w, strconv.Quote(e.String()))
}
func (e *Role) UnmarshalJSON(b []byte) error {
s, err := strconv.Unquote(string(b))
if err != nil {
return err
}
return e.UnmarshalGQL(s)
}
func (e Role) MarshalJSON() ([]byte, error) {
var buf bytes.Buffer
e.MarshalGQL(&buf)
return buf.Bytes(), nil
}
type Status string

View File

@ -11,6 +11,10 @@ func MakeHash(password string) (string, error) {
}
func (u *User) CheckPassword(password string) bool {
if u.Password == nil {
return false
}
err := bcrypt.CompareHashAndPassword([]byte(*u.Password), []byte(password))
return err == nil
}

View File

@ -7,6 +7,7 @@ package resolvers
import (
"context"
"git.farahty.com/nimer/go-mongo/app"
"git.farahty.com/nimer/go-mongo/generated"
"git.farahty.com/nimer/go-mongo/models"
authService "git.farahty.com/nimer/go-mongo/services/auth"
@ -14,7 +15,7 @@ import (
// Login is the resolver for the login field.
func (r *mutationResolver) Login(ctx context.Context, input models.LoginInput) (*models.LoginResponse, error) {
return authService.Login(ctx, &input)
return authService.Login(ctx, &input, app.RedisClient)
}
// Mutation returns generated.MutationResolver implementation.

View File

@ -22,6 +22,11 @@ func (r *categoryResolver) Parent(ctx context.Context, obj *models.Category) (*m
return categoryService.FindByID(ctx, *obj.ParentID)
}
// Children is the resolver for the children field.
func (r *categoryResolver) Children(ctx context.Context, obj *models.Category) ([]*models.Category, error) {
return categoryService.FindChildren(ctx, obj.ID)
}
// CreatedBy is the resolver for the createdBy field.
func (r *categoryResolver) CreatedBy(ctx context.Context, obj *models.Category) (*models.User, error) {
return userService.FindById(ctx, obj.CreatedByID)

View File

@ -15,15 +15,19 @@ import (
// It serves as dependency injection for your app, add any dependencies you require here.
type Resolver struct {
Redis *redis.Client
Redis redis.UniversalClient
}
func Subscribe[T any](ctx context.Context, redis *redis.Client, event string) (<-chan *T, error) {
func Subscribe[T any](ctx context.Context, redisClient redis.UniversalClient, event string) (<-chan *T, error) {
client, ok := redisClient.(*redis.Client)
if !ok {
return nil, nil // or handle error
}
clientChannel := make(chan *T, 1)
go func() {
sub := redis.Subscribe(ctx, event)
sub := client.Subscribe(ctx, event)
if _, err := sub.Receive(ctx); err != nil {
return
@ -52,5 +56,4 @@ func Subscribe[T any](ctx context.Context, redis *redis.Client, event string) (<
}()
return clientChannel, nil
}

View File

@ -12,6 +12,7 @@ import (
"git.farahty.com/nimer/go-mongo/models"
todoService "git.farahty.com/nimer/go-mongo/services/todo"
userService "git.farahty.com/nimer/go-mongo/services/user"
redis "github.com/redis/go-redis/v9"
"go.mongodb.org/mongo-driver/bson/primitive"
)
@ -24,7 +25,9 @@ func (r *mutationResolver) CreateTodo(ctx context.Context, input models.CreateTo
}
if objJson, err := json.Marshal(obj); err == nil {
r.Redis.Publish(ctx, "NEW_TODO_EVENT", objJson)
if client, ok := r.Redis.(*redis.Client); ok {
client.Publish(ctx, "NEW_TODO_EVENT", objJson)
}
}
return obj, nil
@ -42,7 +45,10 @@ func (r *queryResolver) Todo(ctx context.Context, id primitive.ObjectID) (*model
// OnTodo is the resolver for the onTodo field.
func (r *subscriptionResolver) OnTodo(ctx context.Context) (<-chan *models.Todo, error) {
return Subscribe[models.Todo](ctx, r.Redis, "NEW_TODO_EVENT")
if client, ok := r.Redis.(*redis.Client); ok {
return Subscribe[models.Todo](ctx, client, "NEW_TODO_EVENT")
}
return nil, nil
}
// CreatedBy is the resolver for the createdBy field.

View File

@ -3,14 +3,13 @@ package main
import (
"log"
"net/http"
"os"
"time"
"git.farahty.com/nimer/go-mongo/app"
"git.farahty.com/nimer/go-mongo/controllers"
"git.farahty.com/nimer/go-mongo/directives"
"git.farahty.com/nimer/go-mongo/generated"
"git.farahty.com/nimer/go-mongo/helpers"
"git.farahty.com/nimer/go-mongo/utils"
"github.com/99designs/gqlgen/graphql/handler"
"github.com/99designs/gqlgen/graphql/handler/extension"
"github.com/99designs/gqlgen/graphql/handler/transport"
@ -22,14 +21,13 @@ import (
"github.com/go-chi/httprate"
"github.com/gorilla/websocket"
"github.com/redis/go-redis/v9"
)
func createRouter(graphqlServer http.Handler) chi.Router {
router := chi.NewRouter()
// Apply middleware
if os.Getenv("ENV") == "production" {
if app.Config.Env == "production" {
router.Use(httprate.LimitByIP(100, 1*time.Minute)) // 100 requests/minute/IP
}
@ -41,6 +39,9 @@ func createRouter(graphqlServer http.Handler) chi.Router {
// Custom middleware for Auth
router.Use(app.AuthMiddleware)
router.Use(app.WriterMiddleware)
router.Use(app.RequestMiddleware)
// REST routes
router.Mount("/users", controllers.UserRouter())
@ -51,10 +52,8 @@ func createRouter(graphqlServer http.Handler) chi.Router {
return router
}
func createGraphqlServer(redisClient *redis.Client) http.Handler {
cache, err := helpers.NewCache(redisClient, 24*time.Hour, "apq:")
func createGraphqlServer() http.Handler {
cache, err := utils.NewCache(app.RedisClient, 24*time.Hour, "apq:")
if err != nil {
log.Fatalf("cannot create APQ redis cache: %v", err)
}
@ -62,7 +61,7 @@ func createGraphqlServer(redisClient *redis.Client) http.Handler {
// Setup gqlgen with resolvers and Redis client
schema := generated.Config{
Resolvers: &resolvers.Resolver{
Redis: redisClient,
Redis: app.RedisClient,
},
}
@ -88,10 +87,11 @@ func createGraphqlServer(redisClient *redis.Client) http.Handler {
srv.AddTransport(transport.MultipartForm{})
srv.Use(extension.AutomaticPersistedQuery{Cache: cache})
srv.Use(extension.Introspection{})
// Apply global middleware
srv.AroundRootFields(app.RootFieldsAuthorizer) // Check for @auth at root fields
srv.AroundResponses(app.ExpiryMiddleware) // Token expiry validation
//srv.AroundResponses(app.ExpiryMiddleware) // Token expiry validation
// Inject DataLoaders into request context
return app.LoaderMiddleware(app.NewLoaders(), srv)
@ -99,4 +99,6 @@ func createGraphqlServer(redisClient *redis.Client) http.Handler {
func mapDirectives(config *generated.Config) {
config.Directives.Auth = directives.Auth
config.Directives.HasRole = directives.HasRole
}

View File

@ -3,15 +3,23 @@ package authService
import (
"context"
"fmt"
"time"
"git.farahty.com/nimer/go-mongo/app"
"git.farahty.com/nimer/go-mongo/models"
"github.com/redis/go-redis/v9"
"go.mongodb.org/mongo-driver/bson"
)
func Login(ctx context.Context, loginInput *models.LoginInput) (*models.LoginResponse, error) {
const maxAttempts = 5
const lockoutDuration = 15 * time.Minute
func Login(ctx context.Context, loginInput *models.LoginInput, redisClient redis.UniversalClient) (*models.LoginResponse, error) {
if locked, _ := isLockedOut(loginInput.Identity, redisClient); locked {
return nil, fmt.Errorf("too many failed attempts, please try again later")
}
// todo : fix the security threats here
filter := bson.D{
{
Key: "$or",
@ -25,12 +33,37 @@ func Login(ctx context.Context, loginInput *models.LoginInput) (*models.LoginRes
user, err := app.FindOne[models.User](ctx, "users", filter)
if err != nil {
logFailedLoginAttempt(ctx, loginInput.Identity, redisClient) // optional
return nil, err
}
if !user.CheckPassword(loginInput.Password) {
return nil, fmt.Errorf("incorrect password")
logFailedLoginAttempt(ctx, loginInput.Identity, redisClient) // optional
return nil, fmt.Errorf("invalid identity or password")
}
clearFailedAttempts(ctx, loginInput.Identity, redisClient)
return successLogin(ctx, user)
}
func isLockedOut(identity string, redisClient redis.UniversalClient) (bool, error) {
key := fmt.Sprintf("login:fail:%s", identity)
attempts, err := redisClient.Get(context.Background(), key).Int()
if err != nil && err != redis.Nil {
return false, err
}
return attempts >= maxAttempts, nil
}
func logFailedLoginAttempt(ctx context.Context, identity string, redisClient redis.UniversalClient) {
key := fmt.Sprintf("login:fail:%s", identity)
redisClient.Incr(ctx, key)
redisClient.Expire(ctx, key, lockoutDuration)
}
func clearFailedAttempts(ctx context.Context, identity string, redisClient redis.UniversalClient) {
key := fmt.Sprintf("login:fail:%s", identity)
redisClient.Del(ctx, key)
}

View File

@ -1,29 +0,0 @@
package authService
import (
"time"
"github.com/golang-jwt/jwt/v4"
)
func createToken(sub, secret, expiry string, payload interface{}) (*string, error) {
duration, err := time.ParseDuration(expiry)
if err != nil {
return nil, err
}
token := jwt.New(jwt.SigningMethodHS256)
claims := token.Claims.(jwt.MapClaims)
claims["sub"] = sub
claims["exp"] = time.Now().Add(duration).Unix()
claims["data"] = payload
signedToken, err := token.SignedString([]byte(secret))
if err != nil {
return nil, err
}
return &signedToken, nil
}

View File

@ -3,35 +3,41 @@ package authService
import (
"context"
"encoding/hex"
"os"
"fmt"
"net/http"
"git.farahty.com/nimer/go-mongo/app"
"git.farahty.com/nimer/go-mongo/models"
"git.farahty.com/nimer/go-mongo/utils"
"github.com/google/uuid"
"go.mongodb.org/mongo-driver/bson"
)
func successLogin(ctx context.Context, user *models.User) (*models.LoginResponse, error) {
refresh_secret := os.Getenv("REFRESH_SECRET")
refresh_expiry := os.Getenv("REFRESH_EXPIRY")
refresh_secret := app.Config.RefreshSecret
refresh_expiry := app.Config.RefreshExpiry
access_secret := os.Getenv("ACCESS_SECRET")
access_expiry := os.Getenv("ACCESS_EXPIRY")
access_secret := app.Config.AccessSecret
access_expiry := app.Config.AccessExpiry
identity := user.Email
if identity == nil {
identity = user.Phone
var identity string
if user.Email != nil {
identity = *user.Email
} else if user.Phone != nil {
identity = *user.Phone
} else {
return nil, fmt.Errorf("user identity not found")
}
refreshHandle := hex.EncodeToString([]byte(uuid.NewString()))
refreshToken, err := createToken(
refreshToken, err := utils.CreateToken(
refreshHandle,
refresh_secret,
refresh_expiry,
models.UserJWT{
ID: user.ID.Hex(),
Identity: *identity,
Identity: identity,
},
)
@ -39,13 +45,13 @@ func successLogin(ctx context.Context, user *models.User) (*models.LoginResponse
return nil, err
}
accessToken, err := createToken(
accessToken, err := utils.CreateToken(
user.ID.Hex(),
access_secret,
access_expiry,
models.UserJWT{
ID: user.ID.Hex(),
Identity: *identity,
Identity: identity,
},
)
@ -53,14 +59,27 @@ func successLogin(ctx context.Context, user *models.User) (*models.LoginResponse
return nil, err
}
user.Token = &refreshHandle
_, err = app.Collection("users").UpdateByID(ctx, user.ID, bson.D{{Key: "$set", Value: user}})
_, err = app.Collection("users").UpdateByID(ctx, user.ID, bson.D{
{Key: "$set", Value: bson.D{
{Key: "token", Value: refreshHandle},
}},
})
if err != nil {
return nil, err
}
w := app.WriterFor(ctx)
http.SetCookie(*w, &http.Cookie{
Name: "access_token",
Value: *accessToken,
Path: "/",
HttpOnly: true,
Secure: true,
SameSite: http.SameSiteLaxMode,
})
return &models.LoginResponse{
AccessToken: *accessToken,
RefreshToken: *refreshToken,

View File

@ -16,6 +16,10 @@ func Find(ctx context.Context) ([]*models.Category, error) {
return app.Find[models.Category](ctx, coll, bson.D{})
}
func FindChildren(ctx context.Context, id primitive.ObjectID) ([]*models.Category, error) {
return app.Find[models.Category](ctx, coll, bson.M{"parentId": id})
}
func Create(ctx context.Context, input models.CreateCategoryInput) (*models.Category, error) {
return app.InsertOne[models.Category](ctx, coll, input)
}

View File

@ -1,4 +1,4 @@
package helpers
package utils
import (
"context"

37
utils/camelize.go Normal file
View File

@ -0,0 +1,37 @@
package utils
import (
"strings"
"unicode"
)
func Camelize(s string) string {
// If input has no delimiters, assume it's already in camelCase or PascalCase
if !strings.ContainsAny(s, "_- ") {
return s[:1] + s[1:] // No change (optionally lowercase first letter: strings.ToLower(s[:1]) + s[1:])
}
// Otherwise, split and camelCase it
parts := strings.FieldsFunc(s, func(r rune) bool {
return r == '_' || r == '-' || unicode.IsSpace(r)
})
if len(parts) == 0 {
return ""
}
var b strings.Builder
b.WriteString(strings.ToLower(parts[0]))
for _, part := range parts[1:] {
if len(part) == 0 {
continue
}
b.WriteString(strings.ToUpper(part[:1]))
if len(part) > 1 {
b.WriteString(part[1:])
}
}
return b.String()
}

View File

@ -1,4 +1,4 @@
package helpers
package utils
import (
"crypto/aes"

View File

@ -1,4 +1,4 @@
package app
package utils
import (
"errors"
@ -18,7 +18,7 @@ func MarshalObjectID(value primitive.ObjectID) graphql.Marshaler {
}
func UnmarshalObjectID(value interface{}) (primitive.ObjectID, error) {
func UnmarshalObjectID(value any) (primitive.ObjectID, error) {
if str, ok := value.(string); ok {

View File

@ -1,17 +1,16 @@
package app
package utils
import (
"fmt"
"net/http"
"os"
"strings"
"time"
"git.farahty.com/nimer/go-mongo/models"
"github.com/golang-jwt/jwt/v4"
"github.com/mitchellh/mapstructure"
)
func getTokenFromHeader(r *http.Request) (string, error) {
func GetTokenFromHeader(r *http.Request) (string, error) {
authHeader := r.Header.Get("Authorization")
@ -28,9 +27,9 @@ func getTokenFromHeader(r *http.Request) (string, error) {
return strings.TrimSpace(authSlice[1]), nil
}
func getTokenFromCookie(r *http.Request) (string, error) {
func GetTokenFromCookie(r *http.Request) (string, error) {
cookie, err := r.Cookie("app-access-token")
cookie, err := r.Cookie("access_token")
if err != nil {
return "", fmt.Errorf("there is no authorization cookie provided")
@ -39,13 +38,13 @@ func getTokenFromCookie(r *http.Request) (string, error) {
return strings.TrimSpace(cookie.Value), nil
}
func getUserFromToken(tokenString string) (*models.UserJWT, error) {
token, err := jwt.Parse(tokenString, func(t *jwt.Token) (interface{}, error) {
func GetUserFromToken[T any](tokenString string, secret string) (*T, error) {
token, err := jwt.Parse(tokenString, func(t *jwt.Token) (any, error) {
if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, fmt.Errorf("wrong token format ")
}
return []byte(os.Getenv("ACCESS_SECRET")), nil
return []byte(secret), nil
})
if err != nil {
@ -56,7 +55,7 @@ func getUserFromToken(tokenString string) (*models.UserJWT, error) {
return nil, fmt.Errorf("token is not valid")
}
var user *models.UserJWT
var user *T
claims := token.Claims.(jwt.MapClaims)
if err := mapstructure.Decode(claims["data"], &user); err != nil {
@ -65,3 +64,25 @@ func getUserFromToken(tokenString string) (*models.UserJWT, error) {
return user, nil
}
func CreateToken(sub, secret, expiry string, payload any) (*string, error) {
duration, err := time.ParseDuration(expiry)
if err != nil {
return nil, err
}
token := jwt.New(jwt.SigningMethodHS256)
claims := token.Claims.(jwt.MapClaims)
claims["sub"] = sub
claims["exp"] = time.Now().Add(duration).Unix()
claims["data"] = payload
signedToken, err := token.SignedString([]byte(secret))
if err != nil {
return nil, err
}
return &signedToken, nil
}