Initial Commit
This commit is contained in:
27
internal/api/api.go
Normal file
27
internal/api/api.go
Normal file
@@ -0,0 +1,27 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"reichard.io/imagini/internal/db"
|
||||
"reichard.io/imagini/internal/auth"
|
||||
"reichard.io/imagini/internal/config"
|
||||
)
|
||||
|
||||
type API struct {
|
||||
Router *http.ServeMux
|
||||
Config *config.Config
|
||||
Auth *auth.AuthManager
|
||||
DB *db.DBManager
|
||||
}
|
||||
|
||||
func NewApi(db *db.DBManager, c *config.Config, auth *auth.AuthManager) *API {
|
||||
api := &API{
|
||||
Router: http.NewServeMux(),
|
||||
Config: c,
|
||||
Auth: auth,
|
||||
DB: db,
|
||||
}
|
||||
api.registerRoutes()
|
||||
return api
|
||||
}
|
||||
103
internal/api/auth.go
Normal file
103
internal/api/auth.go
Normal file
@@ -0,0 +1,103 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/lestrrat-go/jwx/jwt"
|
||||
|
||||
"reichard.io/imagini/graph/model"
|
||||
)
|
||||
|
||||
func (api *API) refreshTokens(refreshToken jwt.Token) (string, string, error) {
|
||||
// Acquire User & Device
|
||||
did, ok := refreshToken.Get("did")
|
||||
if !ok {
|
||||
return "", "", errors.New("Missing DID")
|
||||
}
|
||||
uid, ok := refreshToken.Get(jwt.SubjectKey)
|
||||
if !ok {
|
||||
return "", "", errors.New("Missing UID")
|
||||
}
|
||||
deviceUUID, err := uuid.Parse(fmt.Sprintf("%v", did))
|
||||
if err != nil {
|
||||
return "", "", errors.New("Invalid DID")
|
||||
}
|
||||
userUUID, err := uuid.Parse(fmt.Sprintf("%v", uid))
|
||||
if err != nil {
|
||||
return "", "", errors.New("Invalid UID")
|
||||
}
|
||||
|
||||
// Device & User Skeleton
|
||||
user := model.User{ID: userUUID.String()}
|
||||
device := model.Device{ID: deviceUUID.String()}
|
||||
|
||||
// Find User
|
||||
_, err = api.DB.User(&user)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
// Update Access Token
|
||||
accessTokenCookie, err := api.Auth.CreateJWTAccessToken(user, device)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
return accessTokenCookie, "", err
|
||||
}
|
||||
|
||||
func (api *API) validateTokens(w *http.ResponseWriter, r *http.Request) (jwt.Token, error) {
|
||||
// TODO: Check from X-Imagini-AccessToken
|
||||
// TODO: Check from X-Imagini-RefreshToken
|
||||
|
||||
// Validate Access Token
|
||||
accessCookie, _ := r.Cookie("AccessToken")
|
||||
if accessCookie != nil {
|
||||
accessToken, err := api.Auth.ValidateJWTAccessToken(accessCookie.Value)
|
||||
if err == nil {
|
||||
return accessToken, nil
|
||||
}
|
||||
}
|
||||
|
||||
// Validate Refresh Cookie Exists
|
||||
refreshCookie, _ := r.Cookie("RefreshToken")
|
||||
if refreshCookie == nil {
|
||||
return nil, errors.New("Tokens Invalid")
|
||||
}
|
||||
|
||||
// Validate Refresh Token
|
||||
refreshToken, err := api.Auth.ValidateJWTRefreshToken(refreshCookie.Value)
|
||||
if err != nil {
|
||||
return nil, errors.New("Tokens Invalid")
|
||||
}
|
||||
|
||||
// Refresh Access Token & Generate New Refresh Token
|
||||
newAccessToken, newRefreshToken, err := api.refreshTokens(refreshToken)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// TODO: Actually Refresh Refresh Token
|
||||
newRefreshToken = refreshCookie.Value
|
||||
|
||||
// Set appropriate cookies (TODO: Only for web!)
|
||||
|
||||
// Update Access & Refresh Cookies
|
||||
http.SetCookie(*w, &http.Cookie{
|
||||
Name: "AccessToken",
|
||||
Value: newAccessToken,
|
||||
})
|
||||
http.SetCookie(*w, &http.Cookie{
|
||||
Name: "RefreshToken",
|
||||
Value: newRefreshToken,
|
||||
})
|
||||
|
||||
// Only for iOS & Android (TODO: Remove for web! Only cause affected by CORS during development)
|
||||
(*w).Header().Set("X-Imagini-AccessToken", newAccessToken)
|
||||
(*w).Header().Set("X-Imagini-RefreshToken", newRefreshToken)
|
||||
|
||||
return jwt.ParseBytes([]byte(newAccessToken))
|
||||
}
|
||||
50
internal/api/directives.go
Normal file
50
internal/api/directives.go
Normal file
@@ -0,0 +1,50 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
"github.com/99designs/gqlgen/graphql"
|
||||
"reichard.io/imagini/graph/model"
|
||||
)
|
||||
|
||||
/**
|
||||
* This is used to validate whether the users role is adequate for the requested resource.
|
||||
**/
|
||||
func (api *API) hasMinRoleDirective(ctx context.Context, obj interface{}, next graphql.Resolver, role model.Role) (res interface{}, err error) {
|
||||
authContext := ctx.Value("auth").(*model.AuthContext)
|
||||
accessToken, err := api.validateTokens(authContext.AuthResponse, authContext.AuthRequest)
|
||||
if err != nil {
|
||||
return nil, errors.New("Access Denied")
|
||||
}
|
||||
authContext.AccessToken = &accessToken
|
||||
|
||||
userRole, ok := accessToken.Get("role")
|
||||
if !ok {
|
||||
return nil, errors.New("Access Denied")
|
||||
}
|
||||
|
||||
if userRole == model.RoleAdmin.String() {
|
||||
return next(ctx)
|
||||
}
|
||||
|
||||
if userRole == role.String() {
|
||||
return next(ctx)
|
||||
}
|
||||
|
||||
return nil, errors.New("Role Not Authenticated")
|
||||
}
|
||||
|
||||
/**
|
||||
* This is needed but not used. Meta is used for Gorm.
|
||||
**/
|
||||
func (api *API) metaDirective(ctx context.Context, obj interface{}, next graphql.Resolver, gorm *string) (res interface{}, err error) {
|
||||
return next(ctx)
|
||||
}
|
||||
|
||||
/**
|
||||
* This overrides the response so fields with an @isPrivate directive are always nil.
|
||||
**/
|
||||
func (api *API) isPrivateDirective(ctx context.Context, obj interface{}, next graphql.Resolver) (res interface{}, err error) {
|
||||
return nil, errors.New("Private Field")
|
||||
}
|
||||
55
internal/api/media.go
Normal file
55
internal/api/media.go
Normal file
@@ -0,0 +1,55 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"os"
|
||||
"path"
|
||||
|
||||
"reichard.io/imagini/graph/model"
|
||||
)
|
||||
|
||||
/**
|
||||
* Responsible for serving up static images / videos
|
||||
**/
|
||||
func (api *API) mediaHandler(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodGet {
|
||||
w.WriteHeader(http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
if path.Dir(r.URL.Path) != "/media" {
|
||||
w.WriteHeader(http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
// Acquire Width & Height Parameters
|
||||
query := r.URL.Query()
|
||||
width := query["width"]
|
||||
height := query["height"]
|
||||
_ = width
|
||||
_ = height
|
||||
|
||||
// TODO: Caching & Resizing
|
||||
// - If both, force resize with new scale
|
||||
// - If one, scale resize proportionally
|
||||
|
||||
// Pull out userID
|
||||
authContext := r.Context().Value("auth").(*model.AuthContext)
|
||||
rawUserID, _ := (*authContext.AccessToken).Get("sub")
|
||||
userID := rawUserID.(string)
|
||||
|
||||
// Derive Path
|
||||
fileName := path.Base(r.URL.Path)
|
||||
folderPath := path.Join("/" + api.Config.DataPath + "/media/" + userID)
|
||||
mediaPath := path.Join(folderPath + "/" + fileName)
|
||||
|
||||
// Check if File Exists
|
||||
_, err := os.Stat(mediaPath)
|
||||
if os.IsNotExist(err) {
|
||||
// TODO: Different HTTP Response Code?
|
||||
w.WriteHeader(http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
http.ServeFile(w, r, mediaPath)
|
||||
}
|
||||
1
internal/api/media_item.go
Normal file
1
internal/api/media_item.go
Normal file
@@ -0,0 +1 @@
|
||||
package api
|
||||
71
internal/api/middlewares.go
Normal file
71
internal/api/middlewares.go
Normal file
@@ -0,0 +1,71 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
|
||||
"reichard.io/imagini/graph/model"
|
||||
)
|
||||
|
||||
type Middleware func(http.Handler) http.HandlerFunc
|
||||
|
||||
func multipleMiddleware(h http.HandlerFunc, m ...Middleware) http.HandlerFunc {
|
||||
if len(m) < 1 {
|
||||
return h
|
||||
}
|
||||
wrapped := h
|
||||
for i := len(m) - 1; i >= 0; i-- {
|
||||
wrapped = m[i](wrapped)
|
||||
}
|
||||
return wrapped
|
||||
}
|
||||
|
||||
/**
|
||||
* This is used for the graphQL endpoints that may require access to the
|
||||
* Request and ResponseWriter variables. These are used to get / set cookies.
|
||||
**/
|
||||
func (api *API) queryMiddleware(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
// TODO: REMOVE (SOME OF) THIS!! Only for developement due to CORS
|
||||
w.Header().Set("Access-Control-Allow-Credentials", "true")
|
||||
w.Header().Set("Access-Control-Allow-Origin", "*")
|
||||
w.Header().Set("Access-Control-Expose-Headers", "*")
|
||||
w.Header().Set("Access-Control-Allow-Headers", "*")
|
||||
w.Header().Set("Access-Control-Allow-Methods", "POST, GET, OPTIONS, PUT, DELETE")
|
||||
|
||||
authContext := &model.AuthContext{
|
||||
AuthResponse: &w,
|
||||
AuthRequest: r,
|
||||
}
|
||||
|
||||
// Add context
|
||||
ctx := context.WithValue(r.Context(), "auth", authContext)
|
||||
r = r.WithContext(ctx)
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
/**
|
||||
* This is used for non graphQL endpoints that require authentication.
|
||||
**/
|
||||
func (api *API) authMiddleware(next http.Handler) http.HandlerFunc {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Validate Tokens
|
||||
accessToken, err := api.validateTokens(&w, r)
|
||||
if err != nil {
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
// Create Context
|
||||
authContext := &model.AuthContext{
|
||||
AccessToken: &accessToken,
|
||||
}
|
||||
ctx := context.WithValue(r.Context(), "auth", authContext)
|
||||
r = r.WithContext(ctx)
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
36
internal/api/routes.go
Normal file
36
internal/api/routes.go
Normal file
@@ -0,0 +1,36 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"github.com/99designs/gqlgen/graphql/handler"
|
||||
"github.com/99designs/gqlgen/graphql/playground"
|
||||
|
||||
"reichard.io/imagini/graph"
|
||||
"reichard.io/imagini/graph/generated"
|
||||
)
|
||||
|
||||
func (api *API) registerRoutes() {
|
||||
// Set up Directives
|
||||
graphConfig := generated.Config{
|
||||
Resolvers: &graph.Resolver{
|
||||
DB: api.DB,
|
||||
Auth: api.Auth,
|
||||
Config: api.Config,
|
||||
},
|
||||
Directives: generated.DirectiveRoot{
|
||||
Meta: api.metaDirective,
|
||||
IsPrivate: api.isPrivateDirective,
|
||||
HasMinRole: api.hasMinRoleDirective,
|
||||
},
|
||||
}
|
||||
srv := handler.NewDefaultServer(generated.NewExecutableSchema(graphConfig))
|
||||
|
||||
// Handle GraphQL
|
||||
api.Router.Handle("/playground", playground.Handler("GraphQL playground", "/query"))
|
||||
api.Router.Handle("/query", api.queryMiddleware(srv))
|
||||
|
||||
// Handle Resource Route
|
||||
api.Router.HandleFunc("/media/", multipleMiddleware(
|
||||
api.mediaHandler,
|
||||
api.authMiddleware,
|
||||
))
|
||||
}
|
||||
57
internal/auth/auth.go
Normal file
57
internal/auth/auth.go
Normal file
@@ -0,0 +1,57 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"errors"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"reichard.io/imagini/graph/model"
|
||||
"reichard.io/imagini/internal/config"
|
||||
"reichard.io/imagini/internal/db"
|
||||
)
|
||||
|
||||
type AuthManager struct {
|
||||
DB *db.DBManager
|
||||
Config *config.Config
|
||||
}
|
||||
|
||||
func NewMgr(db *db.DBManager, c *config.Config) *AuthManager {
|
||||
return &AuthManager{
|
||||
DB: db,
|
||||
Config: c,
|
||||
}
|
||||
}
|
||||
|
||||
func (auth *AuthManager) AuthenticateUser(user, password string) (model.User, bool) {
|
||||
// Find User by Username / Email
|
||||
foundUser := &model.User{Username: user}
|
||||
_, err := auth.DB.User(foundUser)
|
||||
|
||||
// By Username
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
foundUser = &model.User{Email: user}
|
||||
_, err = auth.DB.User(foundUser)
|
||||
}
|
||||
|
||||
// By Email
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
log.Warn("[auth] User not found: ", user)
|
||||
return *foundUser, false
|
||||
} else if err != nil {
|
||||
log.Error(err)
|
||||
return *foundUser, false
|
||||
}
|
||||
|
||||
log.Info("[auth] Authenticating user: ", foundUser.Username)
|
||||
|
||||
// Determine Type
|
||||
switch foundUser.AuthType {
|
||||
case "Local":
|
||||
return *foundUser, authenticateLocalUser(*foundUser, password)
|
||||
case "LDAP":
|
||||
return *foundUser, authenticateLDAPUser(*foundUser, password)
|
||||
default:
|
||||
return *foundUser, false
|
||||
}
|
||||
}
|
||||
126
internal/auth/jwt.go
Normal file
126
internal/auth/jwt.go
Normal file
@@ -0,0 +1,126 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/lestrrat-go/jwx/jwa"
|
||||
"github.com/lestrrat-go/jwx/jwt"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"reichard.io/imagini/graph/model"
|
||||
)
|
||||
|
||||
func (auth *AuthManager) ValidateJWTRefreshToken(refreshJWT string) (jwt.Token, error) {
|
||||
byteRefreshJWT := []byte(refreshJWT)
|
||||
|
||||
// Acquire Relevant Device
|
||||
unverifiedToken, err := jwt.ParseBytes(byteRefreshJWT)
|
||||
did, ok := unverifiedToken.Get("did")
|
||||
if !ok {
|
||||
return nil, errors.New("did does not exist")
|
||||
}
|
||||
deviceID, err := uuid.Parse(fmt.Sprintf("%v", did))
|
||||
if err != nil {
|
||||
return nil, errors.New("did does not parse")
|
||||
}
|
||||
device := &model.Device{ID: deviceID.String()}
|
||||
_, err = auth.DB.Device(device)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Verify & Validate Token
|
||||
verifiedToken, err := jwt.ParseBytes(byteRefreshJWT,
|
||||
jwt.WithValidate(true),
|
||||
jwt.WithVerify(jwa.HS256, []byte(*device.RefreshKey)),
|
||||
)
|
||||
if err != nil {
|
||||
fmt.Println("failed to parse payload: ", err)
|
||||
return nil, err
|
||||
}
|
||||
return verifiedToken, nil
|
||||
}
|
||||
|
||||
func (auth *AuthManager) ValidateJWTAccessToken(accessJWT string) (jwt.Token, error) {
|
||||
byteAccessJWT := []byte(accessJWT)
|
||||
verifiedToken, err := jwt.ParseBytes(byteAccessJWT,
|
||||
jwt.WithValidate(true),
|
||||
jwt.WithVerify(jwa.HS256, []byte(auth.Config.JWTSecret)),
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return verifiedToken, nil
|
||||
}
|
||||
|
||||
func (auth *AuthManager) CreateJWTRefreshToken(user model.User, device model.Device) (string, error) {
|
||||
// Acquire Refresh Key
|
||||
byteKey := []byte(*device.RefreshKey)
|
||||
|
||||
// Create New Token
|
||||
tm := time.Now()
|
||||
t := jwt.New()
|
||||
t.Set(`did`, device.ID) // Device ID
|
||||
t.Set(jwt.SubjectKey, user.ID) // User ID
|
||||
t.Set(jwt.AudienceKey, `imagini`) // App ID
|
||||
t.Set(jwt.IssuedAtKey, tm) // Issued At
|
||||
|
||||
// iOS & Android = Never Expiring Refresh Token
|
||||
if device.Type != "iOS" && device.Type != "Android" {
|
||||
t.Set(jwt.ExpirationKey, tm.Add(time.Hour*24)) // 1 Day Access Key
|
||||
}
|
||||
|
||||
// Validate Token Creation
|
||||
_, err := json.MarshalIndent(t, "", " ")
|
||||
if err != nil {
|
||||
fmt.Printf("failed to generate JSON: %s\n", err)
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Sign Token
|
||||
signed, err := jwt.Sign(t, jwa.HS256, byteKey)
|
||||
if err != nil {
|
||||
log.Printf("failed to sign token: %s", err)
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Return Token
|
||||
return string(signed), nil
|
||||
}
|
||||
|
||||
func (auth *AuthManager) CreateJWTAccessToken(user model.User, device model.Device) (string, error) {
|
||||
// Create New Token
|
||||
tm := time.Now()
|
||||
t := jwt.New()
|
||||
t.Set(`did`, device.ID) // Device ID
|
||||
t.Set(`role`, user.Role.String()) // User Role (Admin / User)
|
||||
t.Set(jwt.SubjectKey, user.ID) // User ID
|
||||
t.Set(jwt.AudienceKey, `imagini`) // App ID
|
||||
t.Set(jwt.IssuedAtKey, tm) // Issued At
|
||||
t.Set(jwt.ExpirationKey, tm.Add(time.Hour*2)) // 2 Hour Access Key
|
||||
|
||||
// Validate Token Creation
|
||||
_, err := json.MarshalIndent(t, "", " ")
|
||||
if err != nil {
|
||||
fmt.Printf("failed to generate JSON: %s\n", err)
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Use Server Key
|
||||
byteKey := []byte(auth.Config.JWTSecret)
|
||||
|
||||
// Sign Token
|
||||
signed, err := jwt.Sign(t, jwa.HS256, byteKey)
|
||||
if err != nil {
|
||||
log.Printf("failed to sign token: %s", err)
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Return Token
|
||||
return string(signed), nil
|
||||
}
|
||||
9
internal/auth/ldap.go
Normal file
9
internal/auth/ldap.go
Normal file
@@ -0,0 +1,9 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"reichard.io/imagini/graph/model"
|
||||
)
|
||||
|
||||
func authenticateLDAPUser(user model.User, pw string) bool {
|
||||
return false
|
||||
}
|
||||
18
internal/auth/local.go
Normal file
18
internal/auth/local.go
Normal file
@@ -0,0 +1,18 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
"reichard.io/imagini/graph/model"
|
||||
)
|
||||
|
||||
func authenticateLocalUser(user model.User, pw string) bool {
|
||||
bPassword := []byte(pw)
|
||||
err := bcrypt.CompareHashAndPassword([]byte(*user.Password), bPassword)
|
||||
if err == nil {
|
||||
log.Info("[auth] Authentication successfull: ", user.Username)
|
||||
return true
|
||||
}
|
||||
log.Warn("[auth] Authentication failed: ", user.Username)
|
||||
return false
|
||||
}
|
||||
34
internal/config/config.go
Normal file
34
internal/config/config.go
Normal file
@@ -0,0 +1,34 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"os"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
DBType string
|
||||
DBName string
|
||||
DBPassword string
|
||||
DataPath string
|
||||
ConfigPath string
|
||||
JWTSecret string
|
||||
ListenPort string
|
||||
}
|
||||
|
||||
func Load() *Config {
|
||||
return &Config{
|
||||
DBType: getEnv("DATABASE_TYPE", "SQLite"),
|
||||
DBName: getEnv("DATABASE_NAME", "imagini"),
|
||||
DBPassword: getEnv("DATABASE_PASSWORD", ""),
|
||||
ConfigPath: getEnv("CONFIG_PATH", "/config"),
|
||||
DataPath: getEnv("DATA_PATH", "/data"),
|
||||
JWTSecret: getEnv("JWT_SECRET", "58b9340c0472cf045db226bc445966524e780cd38bc3dd707afce80c95d4de6f"),
|
||||
ListenPort: getEnv("LISTEN_PORT", "8484"),
|
||||
}
|
||||
}
|
||||
|
||||
func getEnv(key, fallback string) string {
|
||||
if value, ok := os.LookupEnv(key); ok {
|
||||
return value
|
||||
}
|
||||
return fallback
|
||||
}
|
||||
37
internal/db/albums.go
Normal file
37
internal/db/albums.go
Normal file
@@ -0,0 +1,37 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"reichard.io/imagini/graph/model"
|
||||
)
|
||||
|
||||
func (dbm *DBManager) CreateAlbum(album *model.Album) error {
|
||||
log.Debug("[db] Creating album: ", album.Name)
|
||||
err := dbm.db.Create(album).Error
|
||||
return err
|
||||
}
|
||||
|
||||
func (dbm *DBManager) Album(album *model.Album) (int64, error) {
|
||||
var count int64
|
||||
err := dbm.db.Where(album).First(album).Count(&count).Error
|
||||
return count, err
|
||||
}
|
||||
|
||||
func (dbm *DBManager) Albums(userID string, filters *model.AlbumFilter, page *model.Page, order *model.Order) ([]*model.Album, model.PageResponse, error) {
|
||||
// Initial User Filter
|
||||
tx := dbm.db.Session(&gorm.Session{}).Model(&model.Album{}).Where("user_id == ?", userID)
|
||||
|
||||
// Dynamically Generate Base Query
|
||||
tx, pageResponse := dbm.generateBaseQuery(tx, filters, page, order)
|
||||
|
||||
// Acquire Results
|
||||
var foundAlbums []*model.Album
|
||||
err := tx.Find(&foundAlbums).Error
|
||||
return foundAlbums, pageResponse, err
|
||||
}
|
||||
|
||||
func (dbm *DBManager) DeleteAlbum(album *model.Album) error {
|
||||
return nil
|
||||
}
|
||||
155
internal/db/db.go
Normal file
155
internal/db/db.go
Normal file
@@ -0,0 +1,155 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"path"
|
||||
"reflect"
|
||||
|
||||
"github.com/iancoleman/strcase"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
// "gorm.io/gorm/logger"
|
||||
"gorm.io/driver/sqlite"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"reichard.io/imagini/graph/model"
|
||||
"reichard.io/imagini/internal/config"
|
||||
)
|
||||
|
||||
type DBManager struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
func NewMgr(c *config.Config) *DBManager {
|
||||
gormConfig := &gorm.Config{
|
||||
PrepareStmt: true,
|
||||
// Logger: logger.Default.LogMode(logger.Silent),
|
||||
}
|
||||
|
||||
// Create manager
|
||||
dbm := &DBManager{}
|
||||
|
||||
if c.DBType == "SQLite" {
|
||||
dbLocation := path.Join(c.ConfigPath, "imagini.db")
|
||||
dbm.db, _ = gorm.Open(sqlite.Open(dbLocation), gormConfig)
|
||||
dbm.db = dbm.db.Debug()
|
||||
} else {
|
||||
log.Fatal("Unsupported Database")
|
||||
}
|
||||
|
||||
// Initialize database
|
||||
dbm.db.AutoMigrate(&model.Device{})
|
||||
dbm.db.AutoMigrate(&model.User{})
|
||||
dbm.db.AutoMigrate(&model.MediaItem{})
|
||||
dbm.db.AutoMigrate(&model.Tag{})
|
||||
dbm.db.AutoMigrate(&model.Album{})
|
||||
|
||||
// Determine whether to bootstrap
|
||||
var count int64
|
||||
dbm.db.Model(&model.User{}).Count(&count)
|
||||
if count == 0 {
|
||||
dbm.bootstrapDatabase()
|
||||
}
|
||||
|
||||
return dbm
|
||||
}
|
||||
|
||||
func (dbm *DBManager) bootstrapDatabase() {
|
||||
log.Info("[query] Bootstrapping database.")
|
||||
|
||||
password := "admin"
|
||||
user := &model.User{
|
||||
Username: "admin",
|
||||
AuthType: "Local",
|
||||
Password: &password,
|
||||
Role: model.RoleAdmin,
|
||||
}
|
||||
|
||||
err := dbm.CreateUser(user)
|
||||
|
||||
if err != nil {
|
||||
log.Fatal("[query] Unable to bootstrap database.")
|
||||
}
|
||||
}
|
||||
|
||||
func (dbm *DBManager) generateBaseQuery(tx *gorm.DB, filter interface{}, page *model.Page, order *model.Order) (*gorm.DB, model.PageResponse) {
|
||||
tx = dbm.generateFilter(tx, filter)
|
||||
tx = dbm.generateOrder(tx, order, filter)
|
||||
tx, pageResponse := dbm.generatePage(tx, page)
|
||||
return tx, pageResponse
|
||||
}
|
||||
|
||||
func (dbm *DBManager) generateOrder(tx *gorm.DB, order *model.Order, filter interface{}) *gorm.DB {
|
||||
// Set Defaults
|
||||
orderBy := "created_at"
|
||||
orderDirection := model.OrderDirectionDesc
|
||||
|
||||
if order == nil {
|
||||
order = &model.Order{
|
||||
By: &orderBy,
|
||||
Direction: &orderDirection,
|
||||
}
|
||||
}
|
||||
|
||||
if order.By == nil {
|
||||
order.By = &orderBy
|
||||
}
|
||||
|
||||
if order.Direction == nil {
|
||||
order.Direction = &orderDirection
|
||||
}
|
||||
|
||||
// Get Possible Values
|
||||
ptr := reflect.New(reflect.TypeOf(filter).Elem())
|
||||
v := reflect.Indirect(ptr)
|
||||
|
||||
isValid := false
|
||||
for i := 0; i < v.NumField(); i++ {
|
||||
fieldName := v.Type().Field(i).Name
|
||||
if strcase.ToSnake(*order.By) == strcase.ToSnake(fieldName) {
|
||||
isValid = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if isValid {
|
||||
tx = tx.Order(fmt.Sprintf("%s %s", strcase.ToSnake(*order.By), order.Direction.String()))
|
||||
}
|
||||
|
||||
return tx
|
||||
}
|
||||
|
||||
func (dbm *DBManager) generatePage(tx *gorm.DB, page *model.Page) (*gorm.DB, model.PageResponse) {
|
||||
// Set Defaults
|
||||
var count int64
|
||||
pageSize := 50
|
||||
pageNum := 1
|
||||
|
||||
if page == nil {
|
||||
page = &model.Page{
|
||||
Size: &pageSize,
|
||||
Page: &pageNum,
|
||||
}
|
||||
}
|
||||
|
||||
if page.Size == nil {
|
||||
page.Size = &pageSize
|
||||
}
|
||||
|
||||
if page.Page == nil {
|
||||
page.Page = &pageNum
|
||||
}
|
||||
|
||||
// Acquire Counts Before Pagination
|
||||
tx.Count(&count)
|
||||
|
||||
// Calculate Offset
|
||||
calculatedOffset := (*page.Page - 1) * *page.Size
|
||||
tx = tx.Limit(*page.Size).Offset(calculatedOffset)
|
||||
|
||||
return tx, model.PageResponse{
|
||||
Page: *page.Page,
|
||||
Size: *page.Size,
|
||||
Total: int(count),
|
||||
}
|
||||
}
|
||||
44
internal/db/devices.go
Normal file
44
internal/db/devices.go
Normal file
@@ -0,0 +1,44 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"github.com/google/uuid"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"reichard.io/imagini/graph/model"
|
||||
)
|
||||
|
||||
func (dbm *DBManager) CreateDevice(device *model.Device) error {
|
||||
log.Debug("[db] Creating device: ", device.Name)
|
||||
refreshKey := uuid.New().String()
|
||||
device.RefreshKey = &refreshKey
|
||||
err := dbm.db.Create(device).Error
|
||||
return err
|
||||
}
|
||||
|
||||
func (dbm *DBManager) Device(device *model.Device) (int64, error) {
|
||||
var count int64
|
||||
err := dbm.db.Where(device).First(device).Count(&count).Error
|
||||
return count, err
|
||||
}
|
||||
|
||||
func (dbm *DBManager) Devices(userID string, filters *model.DeviceFilter, page *model.Page, order *model.Order) ([]*model.Device, model.PageResponse, error) {
|
||||
// Initial User Filter
|
||||
tx := dbm.db.Session(&gorm.Session{}).Model(&model.Device{}).Where("user_id == ?", userID)
|
||||
|
||||
// Dynamically Generate Base Query
|
||||
tx, pageResponse := dbm.generateBaseQuery(tx, filters, page, order)
|
||||
|
||||
// Acquire Results
|
||||
var foundDevices []*model.Device
|
||||
err := tx.Find(&foundDevices).Error
|
||||
return foundDevices, pageResponse, err
|
||||
}
|
||||
|
||||
func (dbm *DBManager) DeleteDevice(user *model.Device) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (dbm *DBManager) UpdateRefreshToken(device *model.Device, refreshToken string) error {
|
||||
return nil
|
||||
}
|
||||
196
internal/db/filters.go
Normal file
196
internal/db/filters.go
Normal file
@@ -0,0 +1,196 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
|
||||
"github.com/iancoleman/strcase"
|
||||
"gorm.io/gorm"
|
||||
"reichard.io/imagini/graph/model"
|
||||
)
|
||||
|
||||
// Generic function used to generate filters for the DB
|
||||
func (dbm *DBManager) generateFilter(tx *gorm.DB, filter interface{}) *gorm.DB {
|
||||
ptr := reflect.ValueOf(filter)
|
||||
v := reflect.Indirect(ptr)
|
||||
|
||||
if v == reflect.ValueOf(nil) {
|
||||
return tx
|
||||
}
|
||||
|
||||
for i := 0; i < v.NumField(); i++ {
|
||||
fieldName := strcase.ToSnake(v.Type().Field(i).Name)
|
||||
fieldVal := v.Field(i)
|
||||
|
||||
if fieldVal.IsNil() {
|
||||
continue
|
||||
}
|
||||
|
||||
switch valType := fieldVal.Type(); valType {
|
||||
case reflect.TypeOf(&model.StringFilter{}):
|
||||
tx = generateStringFilter(tx, fieldName, fieldVal.Interface().(*model.StringFilter))
|
||||
case reflect.TypeOf(&model.BooleanFilter{}):
|
||||
tx = generateBooleanFilter(tx, fieldName, fieldVal.Interface().(*model.BooleanFilter))
|
||||
case reflect.TypeOf(&model.FloatFilter{}):
|
||||
tx = generateFloatFilter(tx, fieldName, fieldVal.Interface().(*model.FloatFilter))
|
||||
case reflect.TypeOf(&model.IntFilter{}):
|
||||
tx = generateIntFilter(tx, fieldName, fieldVal.Interface().(*model.IntFilter))
|
||||
case reflect.TypeOf(&model.IDFilter{}):
|
||||
tx = generateIDFilter(tx, fieldName, fieldVal.Interface().(*model.IDFilter))
|
||||
case reflect.TypeOf(&model.TimeFilter{}):
|
||||
tx = generateTimeFilter(tx, fieldName, fieldVal.Interface().(*model.TimeFilter))
|
||||
case reflect.TypeOf(&model.RoleFilter{}):
|
||||
tx = generateRoleFilter(tx, fieldName, fieldVal.Interface().(*model.RoleFilter))
|
||||
case reflect.TypeOf(&model.DeviceTypeFilter{}):
|
||||
tx = generateDeviceTypeFilter(tx, fieldName, fieldVal.Interface().(*model.DeviceTypeFilter))
|
||||
case reflect.TypeOf(&model.AuthTypeFilter{}):
|
||||
tx = generateAuthTypeFilter(tx, fieldName, fieldVal.Interface().(*model.AuthTypeFilter))
|
||||
}
|
||||
}
|
||||
|
||||
return tx
|
||||
}
|
||||
|
||||
func generateStringFilter(tx *gorm.DB, fieldName string, filter *model.StringFilter) *gorm.DB {
|
||||
if filter.EqualTo != nil {
|
||||
tx = tx.Where(fmt.Sprintf("%s == ?", fieldName), *filter.EqualTo)
|
||||
}
|
||||
if filter.NotEqualTo != nil {
|
||||
tx = tx.Where(fmt.Sprintf("%s != ?", fieldName), *filter.NotEqualTo)
|
||||
}
|
||||
if filter.StartsWith != nil {
|
||||
tx = tx.Where(fmt.Sprintf("%s LIKE ?", fieldName), fmt.Sprintf("%s%%", *filter.StartsWith))
|
||||
}
|
||||
if filter.NotStartsWith != nil {
|
||||
tx = tx.Where(fmt.Sprintf("%s NOT LIKE ?", fieldName), fmt.Sprintf("%s%%", *filter.NotStartsWith))
|
||||
}
|
||||
if filter.EndsWith != nil {
|
||||
tx = tx.Where(fmt.Sprintf("%s LIKE ?", fieldName), fmt.Sprintf("%%%s", *filter.EndsWith))
|
||||
}
|
||||
if filter.NotEndsWith != nil {
|
||||
tx = tx.Where(fmt.Sprintf("%s NOT LIKE ?", fieldName), fmt.Sprintf("%%%s", *filter.NotEndsWith))
|
||||
}
|
||||
if filter.Contains != nil {
|
||||
tx = tx.Where(fmt.Sprintf("%s LIKE ?", fieldName), fmt.Sprintf("%%%s%%", *filter.Contains))
|
||||
}
|
||||
if filter.NotContains != nil {
|
||||
tx = tx.Where(fmt.Sprintf("%s NOT LIKE ?", fieldName), fmt.Sprintf("%%%s%%", *filter.NotContains))
|
||||
}
|
||||
return tx
|
||||
}
|
||||
|
||||
func generateBooleanFilter(tx *gorm.DB, fieldName string, filter *model.BooleanFilter) *gorm.DB {
|
||||
if filter.EqualTo != nil {
|
||||
tx = tx.Where(fmt.Sprintf("%s == ?", fieldName), *filter.EqualTo)
|
||||
}
|
||||
if filter.NotEqualTo != nil {
|
||||
tx = tx.Where(fmt.Sprintf("%s != ?", fieldName), *filter.NotEqualTo)
|
||||
}
|
||||
return tx
|
||||
}
|
||||
|
||||
func generateFloatFilter(tx *gorm.DB, fieldName string, filter *model.FloatFilter) *gorm.DB {
|
||||
if filter.EqualTo != nil {
|
||||
tx = tx.Where(fmt.Sprintf("%s == ?", fieldName), *filter.EqualTo)
|
||||
}
|
||||
if filter.NotEqualTo != nil {
|
||||
tx = tx.Where(fmt.Sprintf("%s != ?", fieldName), *filter.NotEqualTo)
|
||||
}
|
||||
if filter.GreaterThan != nil {
|
||||
tx = tx.Where(fmt.Sprintf("%s > ?", fieldName), *filter.GreaterThan)
|
||||
}
|
||||
if filter.GreaterThanOrEqualTo != nil {
|
||||
tx = tx.Where(fmt.Sprintf("%s >= ?", fieldName), *filter.GreaterThanOrEqualTo)
|
||||
}
|
||||
if filter.LessThan != nil {
|
||||
tx = tx.Where(fmt.Sprintf("%s < ?", fieldName), *filter.LessThan)
|
||||
}
|
||||
if filter.LessThanOrEqualTo != nil {
|
||||
tx = tx.Where(fmt.Sprintf("%s <= ?", fieldName), *filter.LessThanOrEqualTo)
|
||||
}
|
||||
return tx
|
||||
}
|
||||
|
||||
func generateIntFilter(tx *gorm.DB, fieldName string, filter *model.IntFilter) *gorm.DB {
|
||||
if filter.EqualTo != nil {
|
||||
tx = tx.Where(fmt.Sprintf("%s == ?", fieldName), *filter.EqualTo)
|
||||
}
|
||||
if filter.NotEqualTo != nil {
|
||||
tx = tx.Where(fmt.Sprintf("%s != ?", fieldName), *filter.NotEqualTo)
|
||||
}
|
||||
if filter.GreaterThan != nil {
|
||||
tx = tx.Where(fmt.Sprintf("%s > ?", fieldName), *filter.GreaterThan)
|
||||
}
|
||||
if filter.GreaterThanOrEqualTo != nil {
|
||||
tx = tx.Where(fmt.Sprintf("%s >= ?", fieldName), *filter.GreaterThanOrEqualTo)
|
||||
}
|
||||
if filter.LessThan != nil {
|
||||
tx = tx.Where(fmt.Sprintf("%s < ?", fieldName), *filter.LessThan)
|
||||
}
|
||||
if filter.LessThanOrEqualTo != nil {
|
||||
tx = tx.Where(fmt.Sprintf("%s <= ?", fieldName), *filter.LessThanOrEqualTo)
|
||||
}
|
||||
return tx
|
||||
}
|
||||
|
||||
func generateIDFilter(tx *gorm.DB, fieldName string, filter *model.IDFilter) *gorm.DB {
|
||||
if filter.EqualTo != nil {
|
||||
tx = tx.Where(fmt.Sprintf("%s == ?", fieldName), *filter.EqualTo)
|
||||
}
|
||||
if filter.NotEqualTo != nil {
|
||||
tx = tx.Where(fmt.Sprintf("%s != ?", fieldName), *filter.NotEqualTo)
|
||||
}
|
||||
return tx
|
||||
}
|
||||
|
||||
func generateTimeFilter(tx *gorm.DB, fieldName string, filter *model.TimeFilter) *gorm.DB {
|
||||
if filter.EqualTo != nil {
|
||||
tx = tx.Where(fmt.Sprintf("%s == ?", fieldName), *filter.EqualTo)
|
||||
}
|
||||
if filter.NotEqualTo != nil {
|
||||
tx = tx.Where(fmt.Sprintf("%s != ?", fieldName), *filter.NotEqualTo)
|
||||
}
|
||||
if filter.GreaterThan != nil {
|
||||
tx = tx.Where(fmt.Sprintf("%s > ?", fieldName), *filter.GreaterThan)
|
||||
}
|
||||
if filter.GreaterThanOrEqualTo != nil {
|
||||
tx = tx.Where(fmt.Sprintf("%s >= ?", fieldName), *filter.GreaterThanOrEqualTo)
|
||||
}
|
||||
if filter.LessThan != nil {
|
||||
tx = tx.Where(fmt.Sprintf("%s < ?", fieldName), *filter.LessThan)
|
||||
}
|
||||
if filter.LessThanOrEqualTo != nil {
|
||||
tx = tx.Where(fmt.Sprintf("%s <= ?", fieldName), *filter.LessThanOrEqualTo)
|
||||
}
|
||||
return tx
|
||||
}
|
||||
|
||||
func generateRoleFilter(tx *gorm.DB, fieldName string, filter *model.RoleFilter) *gorm.DB {
|
||||
if filter.EqualTo != nil {
|
||||
tx = tx.Where(fmt.Sprintf("%s == ?", fieldName), *filter.EqualTo)
|
||||
}
|
||||
if filter.NotEqualTo != nil {
|
||||
tx = tx.Where(fmt.Sprintf("%s != ?", fieldName), *filter.NotEqualTo)
|
||||
}
|
||||
return tx
|
||||
}
|
||||
|
||||
func generateDeviceTypeFilter(tx *gorm.DB, fieldName string, filter *model.DeviceTypeFilter) *gorm.DB {
|
||||
if filter.EqualTo != nil {
|
||||
tx = tx.Where(fmt.Sprintf("%s == ?", fieldName), *filter.EqualTo)
|
||||
}
|
||||
if filter.NotEqualTo != nil {
|
||||
tx = tx.Where(fmt.Sprintf("%s != ?", fieldName), *filter.NotEqualTo)
|
||||
}
|
||||
return tx
|
||||
}
|
||||
|
||||
func generateAuthTypeFilter(tx *gorm.DB, fieldName string, filter *model.AuthTypeFilter) *gorm.DB {
|
||||
if filter.EqualTo != nil {
|
||||
tx = tx.Where(fmt.Sprintf("%s == ?", fieldName), *filter.EqualTo)
|
||||
}
|
||||
if filter.NotEqualTo != nil {
|
||||
tx = tx.Where(fmt.Sprintf("%s != ?", fieldName), *filter.NotEqualTo)
|
||||
}
|
||||
return tx
|
||||
}
|
||||
34
internal/db/media_items.go
Normal file
34
internal/db/media_items.go
Normal file
@@ -0,0 +1,34 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"reichard.io/imagini/graph/model"
|
||||
)
|
||||
|
||||
func (dbm *DBManager) CreateMediaItem(mediaItem *model.MediaItem) error {
|
||||
log.Debug("[db] Creating media item: ", mediaItem.FileName)
|
||||
err := dbm.db.Create(mediaItem).Error
|
||||
return err
|
||||
}
|
||||
|
||||
func (dbm *DBManager) MediaItem(mediaItem *model.MediaItem) (int64, error) {
|
||||
var count int64
|
||||
err := dbm.db.Where(mediaItem).First(mediaItem).Count(&count).Error
|
||||
return count, err
|
||||
}
|
||||
|
||||
// UserID, Filters, Sort, Page, Delete
|
||||
func (dbm *DBManager) MediaItems(userID string, filters *model.MediaItemFilter, page *model.Page, order *model.Order) ([]*model.MediaItem, model.PageResponse, error) {
|
||||
// Initial User Filter
|
||||
tx := dbm.db.Session(&gorm.Session{}).Model(&model.MediaItem{}).Where("user_id == ?", userID)
|
||||
|
||||
// Dynamically Generate Base Query
|
||||
tx, pageResponse := dbm.generateBaseQuery(tx, filters, page, order)
|
||||
|
||||
// Acquire Results
|
||||
var mediaItems []*model.MediaItem
|
||||
err := tx.Find(&mediaItems).Error
|
||||
return mediaItems, pageResponse, err
|
||||
}
|
||||
37
internal/db/tags.go
Normal file
37
internal/db/tags.go
Normal file
@@ -0,0 +1,37 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"reichard.io/imagini/graph/model"
|
||||
)
|
||||
|
||||
func (dbm *DBManager) CreateTag(tag *model.Tag) error {
|
||||
log.Debug("[db] Creating tag: ", tag.Name)
|
||||
err := dbm.db.Create(tag).Error
|
||||
return err
|
||||
}
|
||||
|
||||
func (dbm *DBManager) Tag(tag *model.Tag) (int64, error) {
|
||||
var count int64
|
||||
err := dbm.db.Where(tag).First(tag).Count(&count).Error
|
||||
return count, err
|
||||
}
|
||||
|
||||
func (dbm *DBManager) Tags(userID string, filters *model.TagFilter, page *model.Page, order *model.Order) ([]*model.Tag, model.PageResponse, error) {
|
||||
// Initial User Filter
|
||||
tx := dbm.db.Session(&gorm.Session{}).Model(&model.Tag{}).Where("user_id == ?", userID)
|
||||
|
||||
// Dynamically Generate Base Query
|
||||
tx, pageResponse := dbm.generateBaseQuery(tx, filters, page, order)
|
||||
|
||||
// Acquire Results
|
||||
var foundTags []*model.Tag
|
||||
err := tx.Find(&foundTags).Error
|
||||
return foundTags, pageResponse, err
|
||||
}
|
||||
|
||||
func (dbm *DBManager) DeleteTag(tag *model.Tag) error {
|
||||
return nil
|
||||
}
|
||||
48
internal/db/users.go
Normal file
48
internal/db/users.go
Normal file
@@ -0,0 +1,48 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"reichard.io/imagini/graph/model"
|
||||
)
|
||||
|
||||
func (dbm *DBManager) CreateUser(user *model.User) error {
|
||||
log.Info("[db] Creating user: ", user.Username)
|
||||
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(*user.Password), bcrypt.DefaultCost)
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
return err
|
||||
}
|
||||
stringHashedPassword := string(hashedPassword)
|
||||
user.Password = &stringHashedPassword
|
||||
return dbm.db.Create(user).Error
|
||||
}
|
||||
|
||||
func (dbm *DBManager) User(user *model.User) (int64, error) {
|
||||
var count int64
|
||||
err := dbm.db.Where(user).First(user).Count(&count).Error
|
||||
return count, err
|
||||
}
|
||||
|
||||
func (dbm *DBManager) Users(filters *model.UserFilter, page *model.Page, order *model.Order) ([]*model.User, model.PageResponse, error) {
|
||||
// Initial User Filter
|
||||
tx := dbm.db.Session(&gorm.Session{}).Model(&model.Tag{})
|
||||
|
||||
// Dynamically Generate Base Query
|
||||
tx, pageResponse := dbm.generateBaseQuery(tx, filters, page, order)
|
||||
|
||||
// Acquire Results
|
||||
var foundUsers []*model.User
|
||||
err := tx.Find(&foundUsers).Error
|
||||
return foundUsers, pageResponse, err
|
||||
}
|
||||
|
||||
func (dbm *DBManager) DeleteUser(user model.User) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (dbm *DBManager) UpdatePassword(user model.User, pw string) {
|
||||
|
||||
}
|
||||
40
internal/session/session.go
Normal file
40
internal/session/session.go
Normal file
@@ -0,0 +1,40 @@
|
||||
package session
|
||||
|
||||
import (
|
||||
"sync"
|
||||
)
|
||||
|
||||
// Used to maintain a cache of user specific jwt secrets
|
||||
// This will prevent DB lookups on every request
|
||||
// May not actually be needed. Refresh Token is the only
|
||||
// token that will require proactive DB lookups.
|
||||
type SessionManager struct {
|
||||
mutex sync.Mutex
|
||||
values map[string]string
|
||||
}
|
||||
|
||||
func NewMgr() *SessionManager {
|
||||
return &SessionManager{}
|
||||
}
|
||||
|
||||
func (sm *SessionManager) Set(key, value string) {
|
||||
sm.mutex.Lock()
|
||||
sm.values[key] = value
|
||||
sm.mutex.Unlock()
|
||||
}
|
||||
|
||||
func (sm *SessionManager) Get(key string) string {
|
||||
sm.mutex.Lock()
|
||||
defer sm.mutex.Unlock()
|
||||
return sm.values[key]
|
||||
}
|
||||
|
||||
func (sm *SessionManager) Delete(key string) {
|
||||
sm.mutex.Lock()
|
||||
defer sm.mutex.Unlock()
|
||||
_, exists := sm.values[key]
|
||||
if !exists {
|
||||
return
|
||||
}
|
||||
delete(sm.values, key)
|
||||
}
|
||||
Reference in New Issue
Block a user