Initial Commit

This commit is contained in:
2021-02-11 15:47:42 -05:00
commit fec590b16e
249 changed files with 42571 additions and 0 deletions

27
internal/api/api.go Normal file
View File

@@ -0,0 +1,27 @@
package api
import (
"net/http"
"reichard.io/imagini/internal/db"
"reichard.io/imagini/internal/auth"
"reichard.io/imagini/internal/config"
)
type API struct {
Router *http.ServeMux
Config *config.Config
Auth *auth.AuthManager
DB *db.DBManager
}
func NewApi(db *db.DBManager, c *config.Config, auth *auth.AuthManager) *API {
api := &API{
Router: http.NewServeMux(),
Config: c,
Auth: auth,
DB: db,
}
api.registerRoutes()
return api
}

103
internal/api/auth.go Normal file
View File

@@ -0,0 +1,103 @@
package api
import (
"errors"
"fmt"
"net/http"
"github.com/google/uuid"
"github.com/lestrrat-go/jwx/jwt"
"reichard.io/imagini/graph/model"
)
func (api *API) refreshTokens(refreshToken jwt.Token) (string, string, error) {
// Acquire User & Device
did, ok := refreshToken.Get("did")
if !ok {
return "", "", errors.New("Missing DID")
}
uid, ok := refreshToken.Get(jwt.SubjectKey)
if !ok {
return "", "", errors.New("Missing UID")
}
deviceUUID, err := uuid.Parse(fmt.Sprintf("%v", did))
if err != nil {
return "", "", errors.New("Invalid DID")
}
userUUID, err := uuid.Parse(fmt.Sprintf("%v", uid))
if err != nil {
return "", "", errors.New("Invalid UID")
}
// Device & User Skeleton
user := model.User{ID: userUUID.String()}
device := model.Device{ID: deviceUUID.String()}
// Find User
_, err = api.DB.User(&user)
if err != nil {
return "", "", err
}
// Update Access Token
accessTokenCookie, err := api.Auth.CreateJWTAccessToken(user, device)
if err != nil {
return "", "", err
}
return accessTokenCookie, "", err
}
func (api *API) validateTokens(w *http.ResponseWriter, r *http.Request) (jwt.Token, error) {
// TODO: Check from X-Imagini-AccessToken
// TODO: Check from X-Imagini-RefreshToken
// Validate Access Token
accessCookie, _ := r.Cookie("AccessToken")
if accessCookie != nil {
accessToken, err := api.Auth.ValidateJWTAccessToken(accessCookie.Value)
if err == nil {
return accessToken, nil
}
}
// Validate Refresh Cookie Exists
refreshCookie, _ := r.Cookie("RefreshToken")
if refreshCookie == nil {
return nil, errors.New("Tokens Invalid")
}
// Validate Refresh Token
refreshToken, err := api.Auth.ValidateJWTRefreshToken(refreshCookie.Value)
if err != nil {
return nil, errors.New("Tokens Invalid")
}
// Refresh Access Token & Generate New Refresh Token
newAccessToken, newRefreshToken, err := api.refreshTokens(refreshToken)
if err != nil {
return nil, err
}
// TODO: Actually Refresh Refresh Token
newRefreshToken = refreshCookie.Value
// Set appropriate cookies (TODO: Only for web!)
// Update Access & Refresh Cookies
http.SetCookie(*w, &http.Cookie{
Name: "AccessToken",
Value: newAccessToken,
})
http.SetCookie(*w, &http.Cookie{
Name: "RefreshToken",
Value: newRefreshToken,
})
// Only for iOS & Android (TODO: Remove for web! Only cause affected by CORS during development)
(*w).Header().Set("X-Imagini-AccessToken", newAccessToken)
(*w).Header().Set("X-Imagini-RefreshToken", newRefreshToken)
return jwt.ParseBytes([]byte(newAccessToken))
}

View File

@@ -0,0 +1,50 @@
package api
import (
"context"
"errors"
"github.com/99designs/gqlgen/graphql"
"reichard.io/imagini/graph/model"
)
/**
* This is used to validate whether the users role is adequate for the requested resource.
**/
func (api *API) hasMinRoleDirective(ctx context.Context, obj interface{}, next graphql.Resolver, role model.Role) (res interface{}, err error) {
authContext := ctx.Value("auth").(*model.AuthContext)
accessToken, err := api.validateTokens(authContext.AuthResponse, authContext.AuthRequest)
if err != nil {
return nil, errors.New("Access Denied")
}
authContext.AccessToken = &accessToken
userRole, ok := accessToken.Get("role")
if !ok {
return nil, errors.New("Access Denied")
}
if userRole == model.RoleAdmin.String() {
return next(ctx)
}
if userRole == role.String() {
return next(ctx)
}
return nil, errors.New("Role Not Authenticated")
}
/**
* This is needed but not used. Meta is used for Gorm.
**/
func (api *API) metaDirective(ctx context.Context, obj interface{}, next graphql.Resolver, gorm *string) (res interface{}, err error) {
return next(ctx)
}
/**
* This overrides the response so fields with an @isPrivate directive are always nil.
**/
func (api *API) isPrivateDirective(ctx context.Context, obj interface{}, next graphql.Resolver) (res interface{}, err error) {
return nil, errors.New("Private Field")
}

55
internal/api/media.go Normal file
View File

@@ -0,0 +1,55 @@
package api
import (
"net/http"
"os"
"path"
"reichard.io/imagini/graph/model"
)
/**
* Responsible for serving up static images / videos
**/
func (api *API) mediaHandler(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet {
w.WriteHeader(http.StatusMethodNotAllowed)
return
}
if path.Dir(r.URL.Path) != "/media" {
w.WriteHeader(http.StatusMethodNotAllowed)
return
}
// Acquire Width & Height Parameters
query := r.URL.Query()
width := query["width"]
height := query["height"]
_ = width
_ = height
// TODO: Caching & Resizing
// - If both, force resize with new scale
// - If one, scale resize proportionally
// Pull out userID
authContext := r.Context().Value("auth").(*model.AuthContext)
rawUserID, _ := (*authContext.AccessToken).Get("sub")
userID := rawUserID.(string)
// Derive Path
fileName := path.Base(r.URL.Path)
folderPath := path.Join("/" + api.Config.DataPath + "/media/" + userID)
mediaPath := path.Join(folderPath + "/" + fileName)
// Check if File Exists
_, err := os.Stat(mediaPath)
if os.IsNotExist(err) {
// TODO: Different HTTP Response Code?
w.WriteHeader(http.StatusMethodNotAllowed)
return
}
http.ServeFile(w, r, mediaPath)
}

View File

@@ -0,0 +1 @@
package api

View File

@@ -0,0 +1,71 @@
package api
import (
"context"
"net/http"
"reichard.io/imagini/graph/model"
)
type Middleware func(http.Handler) http.HandlerFunc
func multipleMiddleware(h http.HandlerFunc, m ...Middleware) http.HandlerFunc {
if len(m) < 1 {
return h
}
wrapped := h
for i := len(m) - 1; i >= 0; i-- {
wrapped = m[i](wrapped)
}
return wrapped
}
/**
* This is used for the graphQL endpoints that may require access to the
* Request and ResponseWriter variables. These are used to get / set cookies.
**/
func (api *API) queryMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// TODO: REMOVE (SOME OF) THIS!! Only for developement due to CORS
w.Header().Set("Access-Control-Allow-Credentials", "true")
w.Header().Set("Access-Control-Allow-Origin", "*")
w.Header().Set("Access-Control-Expose-Headers", "*")
w.Header().Set("Access-Control-Allow-Headers", "*")
w.Header().Set("Access-Control-Allow-Methods", "POST, GET, OPTIONS, PUT, DELETE")
authContext := &model.AuthContext{
AuthResponse: &w,
AuthRequest: r,
}
// Add context
ctx := context.WithValue(r.Context(), "auth", authContext)
r = r.WithContext(ctx)
next.ServeHTTP(w, r)
})
}
/**
* This is used for non graphQL endpoints that require authentication.
**/
func (api *API) authMiddleware(next http.Handler) http.HandlerFunc {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Validate Tokens
accessToken, err := api.validateTokens(&w, r)
if err != nil {
w.WriteHeader(http.StatusUnauthorized)
return
}
// Create Context
authContext := &model.AuthContext{
AccessToken: &accessToken,
}
ctx := context.WithValue(r.Context(), "auth", authContext)
r = r.WithContext(ctx)
next.ServeHTTP(w, r)
})
}

36
internal/api/routes.go Normal file
View File

@@ -0,0 +1,36 @@
package api
import (
"github.com/99designs/gqlgen/graphql/handler"
"github.com/99designs/gqlgen/graphql/playground"
"reichard.io/imagini/graph"
"reichard.io/imagini/graph/generated"
)
func (api *API) registerRoutes() {
// Set up Directives
graphConfig := generated.Config{
Resolvers: &graph.Resolver{
DB: api.DB,
Auth: api.Auth,
Config: api.Config,
},
Directives: generated.DirectiveRoot{
Meta: api.metaDirective,
IsPrivate: api.isPrivateDirective,
HasMinRole: api.hasMinRoleDirective,
},
}
srv := handler.NewDefaultServer(generated.NewExecutableSchema(graphConfig))
// Handle GraphQL
api.Router.Handle("/playground", playground.Handler("GraphQL playground", "/query"))
api.Router.Handle("/query", api.queryMiddleware(srv))
// Handle Resource Route
api.Router.HandleFunc("/media/", multipleMiddleware(
api.mediaHandler,
api.authMiddleware,
))
}

57
internal/auth/auth.go Normal file
View File

@@ -0,0 +1,57 @@
package auth
import (
"errors"
log "github.com/sirupsen/logrus"
"gorm.io/gorm"
"reichard.io/imagini/graph/model"
"reichard.io/imagini/internal/config"
"reichard.io/imagini/internal/db"
)
type AuthManager struct {
DB *db.DBManager
Config *config.Config
}
func NewMgr(db *db.DBManager, c *config.Config) *AuthManager {
return &AuthManager{
DB: db,
Config: c,
}
}
func (auth *AuthManager) AuthenticateUser(user, password string) (model.User, bool) {
// Find User by Username / Email
foundUser := &model.User{Username: user}
_, err := auth.DB.User(foundUser)
// By Username
if errors.Is(err, gorm.ErrRecordNotFound) {
foundUser = &model.User{Email: user}
_, err = auth.DB.User(foundUser)
}
// By Email
if errors.Is(err, gorm.ErrRecordNotFound) {
log.Warn("[auth] User not found: ", user)
return *foundUser, false
} else if err != nil {
log.Error(err)
return *foundUser, false
}
log.Info("[auth] Authenticating user: ", foundUser.Username)
// Determine Type
switch foundUser.AuthType {
case "Local":
return *foundUser, authenticateLocalUser(*foundUser, password)
case "LDAP":
return *foundUser, authenticateLDAPUser(*foundUser, password)
default:
return *foundUser, false
}
}

126
internal/auth/jwt.go Normal file
View File

@@ -0,0 +1,126 @@
package auth
import (
"encoding/json"
"errors"
"fmt"
"time"
"github.com/google/uuid"
"github.com/lestrrat-go/jwx/jwa"
"github.com/lestrrat-go/jwx/jwt"
log "github.com/sirupsen/logrus"
"reichard.io/imagini/graph/model"
)
func (auth *AuthManager) ValidateJWTRefreshToken(refreshJWT string) (jwt.Token, error) {
byteRefreshJWT := []byte(refreshJWT)
// Acquire Relevant Device
unverifiedToken, err := jwt.ParseBytes(byteRefreshJWT)
did, ok := unverifiedToken.Get("did")
if !ok {
return nil, errors.New("did does not exist")
}
deviceID, err := uuid.Parse(fmt.Sprintf("%v", did))
if err != nil {
return nil, errors.New("did does not parse")
}
device := &model.Device{ID: deviceID.String()}
_, err = auth.DB.Device(device)
if err != nil {
return nil, err
}
// Verify & Validate Token
verifiedToken, err := jwt.ParseBytes(byteRefreshJWT,
jwt.WithValidate(true),
jwt.WithVerify(jwa.HS256, []byte(*device.RefreshKey)),
)
if err != nil {
fmt.Println("failed to parse payload: ", err)
return nil, err
}
return verifiedToken, nil
}
func (auth *AuthManager) ValidateJWTAccessToken(accessJWT string) (jwt.Token, error) {
byteAccessJWT := []byte(accessJWT)
verifiedToken, err := jwt.ParseBytes(byteAccessJWT,
jwt.WithValidate(true),
jwt.WithVerify(jwa.HS256, []byte(auth.Config.JWTSecret)),
)
if err != nil {
return nil, err
}
return verifiedToken, nil
}
func (auth *AuthManager) CreateJWTRefreshToken(user model.User, device model.Device) (string, error) {
// Acquire Refresh Key
byteKey := []byte(*device.RefreshKey)
// Create New Token
tm := time.Now()
t := jwt.New()
t.Set(`did`, device.ID) // Device ID
t.Set(jwt.SubjectKey, user.ID) // User ID
t.Set(jwt.AudienceKey, `imagini`) // App ID
t.Set(jwt.IssuedAtKey, tm) // Issued At
// iOS & Android = Never Expiring Refresh Token
if device.Type != "iOS" && device.Type != "Android" {
t.Set(jwt.ExpirationKey, tm.Add(time.Hour*24)) // 1 Day Access Key
}
// Validate Token Creation
_, err := json.MarshalIndent(t, "", " ")
if err != nil {
fmt.Printf("failed to generate JSON: %s\n", err)
return "", err
}
// Sign Token
signed, err := jwt.Sign(t, jwa.HS256, byteKey)
if err != nil {
log.Printf("failed to sign token: %s", err)
return "", err
}
// Return Token
return string(signed), nil
}
func (auth *AuthManager) CreateJWTAccessToken(user model.User, device model.Device) (string, error) {
// Create New Token
tm := time.Now()
t := jwt.New()
t.Set(`did`, device.ID) // Device ID
t.Set(`role`, user.Role.String()) // User Role (Admin / User)
t.Set(jwt.SubjectKey, user.ID) // User ID
t.Set(jwt.AudienceKey, `imagini`) // App ID
t.Set(jwt.IssuedAtKey, tm) // Issued At
t.Set(jwt.ExpirationKey, tm.Add(time.Hour*2)) // 2 Hour Access Key
// Validate Token Creation
_, err := json.MarshalIndent(t, "", " ")
if err != nil {
fmt.Printf("failed to generate JSON: %s\n", err)
return "", err
}
// Use Server Key
byteKey := []byte(auth.Config.JWTSecret)
// Sign Token
signed, err := jwt.Sign(t, jwa.HS256, byteKey)
if err != nil {
log.Printf("failed to sign token: %s", err)
return "", err
}
// Return Token
return string(signed), nil
}

9
internal/auth/ldap.go Normal file
View File

@@ -0,0 +1,9 @@
package auth
import (
"reichard.io/imagini/graph/model"
)
func authenticateLDAPUser(user model.User, pw string) bool {
return false
}

18
internal/auth/local.go Normal file
View File

@@ -0,0 +1,18 @@
package auth
import (
log "github.com/sirupsen/logrus"
"golang.org/x/crypto/bcrypt"
"reichard.io/imagini/graph/model"
)
func authenticateLocalUser(user model.User, pw string) bool {
bPassword := []byte(pw)
err := bcrypt.CompareHashAndPassword([]byte(*user.Password), bPassword)
if err == nil {
log.Info("[auth] Authentication successfull: ", user.Username)
return true
}
log.Warn("[auth] Authentication failed: ", user.Username)
return false
}

34
internal/config/config.go Normal file
View File

@@ -0,0 +1,34 @@
package config
import (
"os"
)
type Config struct {
DBType string
DBName string
DBPassword string
DataPath string
ConfigPath string
JWTSecret string
ListenPort string
}
func Load() *Config {
return &Config{
DBType: getEnv("DATABASE_TYPE", "SQLite"),
DBName: getEnv("DATABASE_NAME", "imagini"),
DBPassword: getEnv("DATABASE_PASSWORD", ""),
ConfigPath: getEnv("CONFIG_PATH", "/config"),
DataPath: getEnv("DATA_PATH", "/data"),
JWTSecret: getEnv("JWT_SECRET", "58b9340c0472cf045db226bc445966524e780cd38bc3dd707afce80c95d4de6f"),
ListenPort: getEnv("LISTEN_PORT", "8484"),
}
}
func getEnv(key, fallback string) string {
if value, ok := os.LookupEnv(key); ok {
return value
}
return fallback
}

37
internal/db/albums.go Normal file
View File

@@ -0,0 +1,37 @@
package db
import (
log "github.com/sirupsen/logrus"
"gorm.io/gorm"
"reichard.io/imagini/graph/model"
)
func (dbm *DBManager) CreateAlbum(album *model.Album) error {
log.Debug("[db] Creating album: ", album.Name)
err := dbm.db.Create(album).Error
return err
}
func (dbm *DBManager) Album(album *model.Album) (int64, error) {
var count int64
err := dbm.db.Where(album).First(album).Count(&count).Error
return count, err
}
func (dbm *DBManager) Albums(userID string, filters *model.AlbumFilter, page *model.Page, order *model.Order) ([]*model.Album, model.PageResponse, error) {
// Initial User Filter
tx := dbm.db.Session(&gorm.Session{}).Model(&model.Album{}).Where("user_id == ?", userID)
// Dynamically Generate Base Query
tx, pageResponse := dbm.generateBaseQuery(tx, filters, page, order)
// Acquire Results
var foundAlbums []*model.Album
err := tx.Find(&foundAlbums).Error
return foundAlbums, pageResponse, err
}
func (dbm *DBManager) DeleteAlbum(album *model.Album) error {
return nil
}

155
internal/db/db.go Normal file
View File

@@ -0,0 +1,155 @@
package db
import (
"fmt"
"path"
"reflect"
"github.com/iancoleman/strcase"
log "github.com/sirupsen/logrus"
// "gorm.io/gorm/logger"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
"reichard.io/imagini/graph/model"
"reichard.io/imagini/internal/config"
)
type DBManager struct {
db *gorm.DB
}
func NewMgr(c *config.Config) *DBManager {
gormConfig := &gorm.Config{
PrepareStmt: true,
// Logger: logger.Default.LogMode(logger.Silent),
}
// Create manager
dbm := &DBManager{}
if c.DBType == "SQLite" {
dbLocation := path.Join(c.ConfigPath, "imagini.db")
dbm.db, _ = gorm.Open(sqlite.Open(dbLocation), gormConfig)
dbm.db = dbm.db.Debug()
} else {
log.Fatal("Unsupported Database")
}
// Initialize database
dbm.db.AutoMigrate(&model.Device{})
dbm.db.AutoMigrate(&model.User{})
dbm.db.AutoMigrate(&model.MediaItem{})
dbm.db.AutoMigrate(&model.Tag{})
dbm.db.AutoMigrate(&model.Album{})
// Determine whether to bootstrap
var count int64
dbm.db.Model(&model.User{}).Count(&count)
if count == 0 {
dbm.bootstrapDatabase()
}
return dbm
}
func (dbm *DBManager) bootstrapDatabase() {
log.Info("[query] Bootstrapping database.")
password := "admin"
user := &model.User{
Username: "admin",
AuthType: "Local",
Password: &password,
Role: model.RoleAdmin,
}
err := dbm.CreateUser(user)
if err != nil {
log.Fatal("[query] Unable to bootstrap database.")
}
}
func (dbm *DBManager) generateBaseQuery(tx *gorm.DB, filter interface{}, page *model.Page, order *model.Order) (*gorm.DB, model.PageResponse) {
tx = dbm.generateFilter(tx, filter)
tx = dbm.generateOrder(tx, order, filter)
tx, pageResponse := dbm.generatePage(tx, page)
return tx, pageResponse
}
func (dbm *DBManager) generateOrder(tx *gorm.DB, order *model.Order, filter interface{}) *gorm.DB {
// Set Defaults
orderBy := "created_at"
orderDirection := model.OrderDirectionDesc
if order == nil {
order = &model.Order{
By: &orderBy,
Direction: &orderDirection,
}
}
if order.By == nil {
order.By = &orderBy
}
if order.Direction == nil {
order.Direction = &orderDirection
}
// Get Possible Values
ptr := reflect.New(reflect.TypeOf(filter).Elem())
v := reflect.Indirect(ptr)
isValid := false
for i := 0; i < v.NumField(); i++ {
fieldName := v.Type().Field(i).Name
if strcase.ToSnake(*order.By) == strcase.ToSnake(fieldName) {
isValid = true
break
}
}
if isValid {
tx = tx.Order(fmt.Sprintf("%s %s", strcase.ToSnake(*order.By), order.Direction.String()))
}
return tx
}
func (dbm *DBManager) generatePage(tx *gorm.DB, page *model.Page) (*gorm.DB, model.PageResponse) {
// Set Defaults
var count int64
pageSize := 50
pageNum := 1
if page == nil {
page = &model.Page{
Size: &pageSize,
Page: &pageNum,
}
}
if page.Size == nil {
page.Size = &pageSize
}
if page.Page == nil {
page.Page = &pageNum
}
// Acquire Counts Before Pagination
tx.Count(&count)
// Calculate Offset
calculatedOffset := (*page.Page - 1) * *page.Size
tx = tx.Limit(*page.Size).Offset(calculatedOffset)
return tx, model.PageResponse{
Page: *page.Page,
Size: *page.Size,
Total: int(count),
}
}

44
internal/db/devices.go Normal file
View File

@@ -0,0 +1,44 @@
package db
import (
"github.com/google/uuid"
log "github.com/sirupsen/logrus"
"gorm.io/gorm"
"reichard.io/imagini/graph/model"
)
func (dbm *DBManager) CreateDevice(device *model.Device) error {
log.Debug("[db] Creating device: ", device.Name)
refreshKey := uuid.New().String()
device.RefreshKey = &refreshKey
err := dbm.db.Create(device).Error
return err
}
func (dbm *DBManager) Device(device *model.Device) (int64, error) {
var count int64
err := dbm.db.Where(device).First(device).Count(&count).Error
return count, err
}
func (dbm *DBManager) Devices(userID string, filters *model.DeviceFilter, page *model.Page, order *model.Order) ([]*model.Device, model.PageResponse, error) {
// Initial User Filter
tx := dbm.db.Session(&gorm.Session{}).Model(&model.Device{}).Where("user_id == ?", userID)
// Dynamically Generate Base Query
tx, pageResponse := dbm.generateBaseQuery(tx, filters, page, order)
// Acquire Results
var foundDevices []*model.Device
err := tx.Find(&foundDevices).Error
return foundDevices, pageResponse, err
}
func (dbm *DBManager) DeleteDevice(user *model.Device) error {
return nil
}
func (dbm *DBManager) UpdateRefreshToken(device *model.Device, refreshToken string) error {
return nil
}

196
internal/db/filters.go Normal file
View File

@@ -0,0 +1,196 @@
package db
import (
"fmt"
"reflect"
"github.com/iancoleman/strcase"
"gorm.io/gorm"
"reichard.io/imagini/graph/model"
)
// Generic function used to generate filters for the DB
func (dbm *DBManager) generateFilter(tx *gorm.DB, filter interface{}) *gorm.DB {
ptr := reflect.ValueOf(filter)
v := reflect.Indirect(ptr)
if v == reflect.ValueOf(nil) {
return tx
}
for i := 0; i < v.NumField(); i++ {
fieldName := strcase.ToSnake(v.Type().Field(i).Name)
fieldVal := v.Field(i)
if fieldVal.IsNil() {
continue
}
switch valType := fieldVal.Type(); valType {
case reflect.TypeOf(&model.StringFilter{}):
tx = generateStringFilter(tx, fieldName, fieldVal.Interface().(*model.StringFilter))
case reflect.TypeOf(&model.BooleanFilter{}):
tx = generateBooleanFilter(tx, fieldName, fieldVal.Interface().(*model.BooleanFilter))
case reflect.TypeOf(&model.FloatFilter{}):
tx = generateFloatFilter(tx, fieldName, fieldVal.Interface().(*model.FloatFilter))
case reflect.TypeOf(&model.IntFilter{}):
tx = generateIntFilter(tx, fieldName, fieldVal.Interface().(*model.IntFilter))
case reflect.TypeOf(&model.IDFilter{}):
tx = generateIDFilter(tx, fieldName, fieldVal.Interface().(*model.IDFilter))
case reflect.TypeOf(&model.TimeFilter{}):
tx = generateTimeFilter(tx, fieldName, fieldVal.Interface().(*model.TimeFilter))
case reflect.TypeOf(&model.RoleFilter{}):
tx = generateRoleFilter(tx, fieldName, fieldVal.Interface().(*model.RoleFilter))
case reflect.TypeOf(&model.DeviceTypeFilter{}):
tx = generateDeviceTypeFilter(tx, fieldName, fieldVal.Interface().(*model.DeviceTypeFilter))
case reflect.TypeOf(&model.AuthTypeFilter{}):
tx = generateAuthTypeFilter(tx, fieldName, fieldVal.Interface().(*model.AuthTypeFilter))
}
}
return tx
}
func generateStringFilter(tx *gorm.DB, fieldName string, filter *model.StringFilter) *gorm.DB {
if filter.EqualTo != nil {
tx = tx.Where(fmt.Sprintf("%s == ?", fieldName), *filter.EqualTo)
}
if filter.NotEqualTo != nil {
tx = tx.Where(fmt.Sprintf("%s != ?", fieldName), *filter.NotEqualTo)
}
if filter.StartsWith != nil {
tx = tx.Where(fmt.Sprintf("%s LIKE ?", fieldName), fmt.Sprintf("%s%%", *filter.StartsWith))
}
if filter.NotStartsWith != nil {
tx = tx.Where(fmt.Sprintf("%s NOT LIKE ?", fieldName), fmt.Sprintf("%s%%", *filter.NotStartsWith))
}
if filter.EndsWith != nil {
tx = tx.Where(fmt.Sprintf("%s LIKE ?", fieldName), fmt.Sprintf("%%%s", *filter.EndsWith))
}
if filter.NotEndsWith != nil {
tx = tx.Where(fmt.Sprintf("%s NOT LIKE ?", fieldName), fmt.Sprintf("%%%s", *filter.NotEndsWith))
}
if filter.Contains != nil {
tx = tx.Where(fmt.Sprintf("%s LIKE ?", fieldName), fmt.Sprintf("%%%s%%", *filter.Contains))
}
if filter.NotContains != nil {
tx = tx.Where(fmt.Sprintf("%s NOT LIKE ?", fieldName), fmt.Sprintf("%%%s%%", *filter.NotContains))
}
return tx
}
func generateBooleanFilter(tx *gorm.DB, fieldName string, filter *model.BooleanFilter) *gorm.DB {
if filter.EqualTo != nil {
tx = tx.Where(fmt.Sprintf("%s == ?", fieldName), *filter.EqualTo)
}
if filter.NotEqualTo != nil {
tx = tx.Where(fmt.Sprintf("%s != ?", fieldName), *filter.NotEqualTo)
}
return tx
}
func generateFloatFilter(tx *gorm.DB, fieldName string, filter *model.FloatFilter) *gorm.DB {
if filter.EqualTo != nil {
tx = tx.Where(fmt.Sprintf("%s == ?", fieldName), *filter.EqualTo)
}
if filter.NotEqualTo != nil {
tx = tx.Where(fmt.Sprintf("%s != ?", fieldName), *filter.NotEqualTo)
}
if filter.GreaterThan != nil {
tx = tx.Where(fmt.Sprintf("%s > ?", fieldName), *filter.GreaterThan)
}
if filter.GreaterThanOrEqualTo != nil {
tx = tx.Where(fmt.Sprintf("%s >= ?", fieldName), *filter.GreaterThanOrEqualTo)
}
if filter.LessThan != nil {
tx = tx.Where(fmt.Sprintf("%s < ?", fieldName), *filter.LessThan)
}
if filter.LessThanOrEqualTo != nil {
tx = tx.Where(fmt.Sprintf("%s <= ?", fieldName), *filter.LessThanOrEqualTo)
}
return tx
}
func generateIntFilter(tx *gorm.DB, fieldName string, filter *model.IntFilter) *gorm.DB {
if filter.EqualTo != nil {
tx = tx.Where(fmt.Sprintf("%s == ?", fieldName), *filter.EqualTo)
}
if filter.NotEqualTo != nil {
tx = tx.Where(fmt.Sprintf("%s != ?", fieldName), *filter.NotEqualTo)
}
if filter.GreaterThan != nil {
tx = tx.Where(fmt.Sprintf("%s > ?", fieldName), *filter.GreaterThan)
}
if filter.GreaterThanOrEqualTo != nil {
tx = tx.Where(fmt.Sprintf("%s >= ?", fieldName), *filter.GreaterThanOrEqualTo)
}
if filter.LessThan != nil {
tx = tx.Where(fmt.Sprintf("%s < ?", fieldName), *filter.LessThan)
}
if filter.LessThanOrEqualTo != nil {
tx = tx.Where(fmt.Sprintf("%s <= ?", fieldName), *filter.LessThanOrEqualTo)
}
return tx
}
func generateIDFilter(tx *gorm.DB, fieldName string, filter *model.IDFilter) *gorm.DB {
if filter.EqualTo != nil {
tx = tx.Where(fmt.Sprintf("%s == ?", fieldName), *filter.EqualTo)
}
if filter.NotEqualTo != nil {
tx = tx.Where(fmt.Sprintf("%s != ?", fieldName), *filter.NotEqualTo)
}
return tx
}
func generateTimeFilter(tx *gorm.DB, fieldName string, filter *model.TimeFilter) *gorm.DB {
if filter.EqualTo != nil {
tx = tx.Where(fmt.Sprintf("%s == ?", fieldName), *filter.EqualTo)
}
if filter.NotEqualTo != nil {
tx = tx.Where(fmt.Sprintf("%s != ?", fieldName), *filter.NotEqualTo)
}
if filter.GreaterThan != nil {
tx = tx.Where(fmt.Sprintf("%s > ?", fieldName), *filter.GreaterThan)
}
if filter.GreaterThanOrEqualTo != nil {
tx = tx.Where(fmt.Sprintf("%s >= ?", fieldName), *filter.GreaterThanOrEqualTo)
}
if filter.LessThan != nil {
tx = tx.Where(fmt.Sprintf("%s < ?", fieldName), *filter.LessThan)
}
if filter.LessThanOrEqualTo != nil {
tx = tx.Where(fmt.Sprintf("%s <= ?", fieldName), *filter.LessThanOrEqualTo)
}
return tx
}
func generateRoleFilter(tx *gorm.DB, fieldName string, filter *model.RoleFilter) *gorm.DB {
if filter.EqualTo != nil {
tx = tx.Where(fmt.Sprintf("%s == ?", fieldName), *filter.EqualTo)
}
if filter.NotEqualTo != nil {
tx = tx.Where(fmt.Sprintf("%s != ?", fieldName), *filter.NotEqualTo)
}
return tx
}
func generateDeviceTypeFilter(tx *gorm.DB, fieldName string, filter *model.DeviceTypeFilter) *gorm.DB {
if filter.EqualTo != nil {
tx = tx.Where(fmt.Sprintf("%s == ?", fieldName), *filter.EqualTo)
}
if filter.NotEqualTo != nil {
tx = tx.Where(fmt.Sprintf("%s != ?", fieldName), *filter.NotEqualTo)
}
return tx
}
func generateAuthTypeFilter(tx *gorm.DB, fieldName string, filter *model.AuthTypeFilter) *gorm.DB {
if filter.EqualTo != nil {
tx = tx.Where(fmt.Sprintf("%s == ?", fieldName), *filter.EqualTo)
}
if filter.NotEqualTo != nil {
tx = tx.Where(fmt.Sprintf("%s != ?", fieldName), *filter.NotEqualTo)
}
return tx
}

View File

@@ -0,0 +1,34 @@
package db
import (
log "github.com/sirupsen/logrus"
"gorm.io/gorm"
"reichard.io/imagini/graph/model"
)
func (dbm *DBManager) CreateMediaItem(mediaItem *model.MediaItem) error {
log.Debug("[db] Creating media item: ", mediaItem.FileName)
err := dbm.db.Create(mediaItem).Error
return err
}
func (dbm *DBManager) MediaItem(mediaItem *model.MediaItem) (int64, error) {
var count int64
err := dbm.db.Where(mediaItem).First(mediaItem).Count(&count).Error
return count, err
}
// UserID, Filters, Sort, Page, Delete
func (dbm *DBManager) MediaItems(userID string, filters *model.MediaItemFilter, page *model.Page, order *model.Order) ([]*model.MediaItem, model.PageResponse, error) {
// Initial User Filter
tx := dbm.db.Session(&gorm.Session{}).Model(&model.MediaItem{}).Where("user_id == ?", userID)
// Dynamically Generate Base Query
tx, pageResponse := dbm.generateBaseQuery(tx, filters, page, order)
// Acquire Results
var mediaItems []*model.MediaItem
err := tx.Find(&mediaItems).Error
return mediaItems, pageResponse, err
}

37
internal/db/tags.go Normal file
View File

@@ -0,0 +1,37 @@
package db
import (
log "github.com/sirupsen/logrus"
"gorm.io/gorm"
"reichard.io/imagini/graph/model"
)
func (dbm *DBManager) CreateTag(tag *model.Tag) error {
log.Debug("[db] Creating tag: ", tag.Name)
err := dbm.db.Create(tag).Error
return err
}
func (dbm *DBManager) Tag(tag *model.Tag) (int64, error) {
var count int64
err := dbm.db.Where(tag).First(tag).Count(&count).Error
return count, err
}
func (dbm *DBManager) Tags(userID string, filters *model.TagFilter, page *model.Page, order *model.Order) ([]*model.Tag, model.PageResponse, error) {
// Initial User Filter
tx := dbm.db.Session(&gorm.Session{}).Model(&model.Tag{}).Where("user_id == ?", userID)
// Dynamically Generate Base Query
tx, pageResponse := dbm.generateBaseQuery(tx, filters, page, order)
// Acquire Results
var foundTags []*model.Tag
err := tx.Find(&foundTags).Error
return foundTags, pageResponse, err
}
func (dbm *DBManager) DeleteTag(tag *model.Tag) error {
return nil
}

48
internal/db/users.go Normal file
View File

@@ -0,0 +1,48 @@
package db
import (
log "github.com/sirupsen/logrus"
"golang.org/x/crypto/bcrypt"
"gorm.io/gorm"
"reichard.io/imagini/graph/model"
)
func (dbm *DBManager) CreateUser(user *model.User) error {
log.Info("[db] Creating user: ", user.Username)
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(*user.Password), bcrypt.DefaultCost)
if err != nil {
log.Error(err)
return err
}
stringHashedPassword := string(hashedPassword)
user.Password = &stringHashedPassword
return dbm.db.Create(user).Error
}
func (dbm *DBManager) User(user *model.User) (int64, error) {
var count int64
err := dbm.db.Where(user).First(user).Count(&count).Error
return count, err
}
func (dbm *DBManager) Users(filters *model.UserFilter, page *model.Page, order *model.Order) ([]*model.User, model.PageResponse, error) {
// Initial User Filter
tx := dbm.db.Session(&gorm.Session{}).Model(&model.Tag{})
// Dynamically Generate Base Query
tx, pageResponse := dbm.generateBaseQuery(tx, filters, page, order)
// Acquire Results
var foundUsers []*model.User
err := tx.Find(&foundUsers).Error
return foundUsers, pageResponse, err
}
func (dbm *DBManager) DeleteUser(user model.User) error {
return nil
}
func (dbm *DBManager) UpdatePassword(user model.User, pw string) {
}

View File

@@ -0,0 +1,40 @@
package session
import (
"sync"
)
// Used to maintain a cache of user specific jwt secrets
// This will prevent DB lookups on every request
// May not actually be needed. Refresh Token is the only
// token that will require proactive DB lookups.
type SessionManager struct {
mutex sync.Mutex
values map[string]string
}
func NewMgr() *SessionManager {
return &SessionManager{}
}
func (sm *SessionManager) Set(key, value string) {
sm.mutex.Lock()
sm.values[key] = value
sm.mutex.Unlock()
}
func (sm *SessionManager) Get(key string) string {
sm.mutex.Lock()
defer sm.mutex.Unlock()
return sm.values[key]
}
func (sm *SessionManager) Delete(key string) {
sm.mutex.Lock()
defer sm.mutex.Unlock()
_, exists := sm.values[key]
if !exists {
return
}
delete(sm.values, key)
}