package db import ( "fmt" "path" "reflect" "github.com/iancoleman/strcase" log "github.com/sirupsen/logrus" // "gorm.io/gorm/logger" "gorm.io/driver/sqlite" "gorm.io/gorm" "reichard.io/imagini/graph/model" "reichard.io/imagini/internal/config" ) type DBManager struct { db *gorm.DB } func NewMgr(c *config.Config) *DBManager { gormConfig := &gorm.Config{ PrepareStmt: true, // Logger: logger.Default.LogMode(logger.Silent), } // Create manager dbm := &DBManager{} if c.DBType == "SQLite" { dbLocation := path.Join(c.ConfigPath, "imagini.db") dbm.db, _ = gorm.Open(sqlite.Open(dbLocation), gormConfig) dbm.db = dbm.db.Debug() } else { log.Fatal("Unsupported Database") } // Initialize database dbm.db.AutoMigrate(&model.Device{}) dbm.db.AutoMigrate(&model.User{}) dbm.db.AutoMigrate(&model.MediaItem{}) dbm.db.AutoMigrate(&model.Tag{}) dbm.db.AutoMigrate(&model.Album{}) // Determine whether to bootstrap var count int64 dbm.db.Model(&model.User{}).Count(&count) if count == 0 { dbm.bootstrapDatabase() } return dbm } func (dbm *DBManager) bootstrapDatabase() { log.Info("[query] Bootstrapping database.") password := "admin" user := &model.User{ Username: "admin", AuthType: "Local", Password: &password, Role: model.RoleAdmin, } err := dbm.CreateUser(user) if err != nil { log.Fatal("[query] Unable to bootstrap database.") } } func (dbm *DBManager) generateBaseQuery(tx *gorm.DB, filter interface{}, page *model.Page, order *model.Order) (*gorm.DB, model.PageResponse) { tx = dbm.generateFilter(tx, filter) tx = dbm.generateOrder(tx, order, filter) tx, pageResponse := dbm.generatePage(tx, page) return tx, pageResponse } func (dbm *DBManager) generateOrder(tx *gorm.DB, order *model.Order, filter interface{}) *gorm.DB { // Set Defaults orderBy := "created_at" orderDirection := model.OrderDirectionDesc if order == nil { order = &model.Order{ By: &orderBy, Direction: &orderDirection, } } if order.By == nil { order.By = &orderBy } if order.Direction == nil { order.Direction = &orderDirection } // Get Possible Values ptr := reflect.New(reflect.TypeOf(filter).Elem()) v := reflect.Indirect(ptr) isValid := false for i := 0; i < v.NumField(); i++ { fieldName := v.Type().Field(i).Name if strcase.ToSnake(*order.By) == strcase.ToSnake(fieldName) { isValid = true break } } if isValid { tx = tx.Order(fmt.Sprintf("%s %s", strcase.ToSnake(*order.By), order.Direction.String())) } return tx } func (dbm *DBManager) generatePage(tx *gorm.DB, page *model.Page) (*gorm.DB, model.PageResponse) { // Set Defaults var count int64 pageSize := 50 pageNum := 1 if page == nil { page = &model.Page{ Size: &pageSize, Page: &pageNum, } } if page.Size == nil { page.Size = &pageSize } if page.Page == nil { page.Page = &pageNum } // Acquire Counts Before Pagination tx.Count(&count) // Calculate Offset calculatedOffset := (*page.Page - 1) * *page.Size tx = tx.Limit(*page.Size).Offset(calculatedOffset) return tx, model.PageResponse{ Page: *page.Page, Size: *page.Size, Total: int(count), } }