gva-pms/server/plugin/monitor/initialize/gorm.go

173 lines
4.3 KiB
Go

package initialize
import (
"context"
"fmt"
"reflect"
"strings"
"github.com/flipped-aurora/gin-vue-admin/server/global"
"github.com/flipped-aurora/gin-vue-admin/server/plugin/monitor/model"
"github.com/flipped-aurora/gin-vue-admin/server/plugin/monitor/service"
"github.com/pkg/errors"
"go.uber.org/zap"
"gorm.io/gorm"
)
// monitorService is the package-level monitor service shared by Gorm
// (InitMonitor) and the update callback installed by RegisterHooks
// (RefreshExipre, IsFieldMonitored). It is declared without an explicit
// initializer.
var monitorService service.MC
// Gorm sets up the monitor plugin's persistence. It migrates the MonitorConfig
// and ChangeLog tables on global.GVA_DB and installs the change-tracking update
// callback on the primary database and on every extra database in
// global.GVA_DBList. It then calls monitorService.InitMonitor.
//
// A migration failure is logged but does not stop hook registration. If the
// primary database has not been initialized yet (global.GVA_DB is nil), setup
// is skipped instead of panicking.
func Gorm(ctx context.Context) {
	if global.GVA_DB == nil {
		zap.L().Warn("monitor plugin: primary database not initialized, skipping setup")
		return
	}
	err := global.GVA_DB.WithContext(ctx).AutoMigrate(model.MonitorConfig{}, model.ChangeLog{})
	if err != nil {
		err = errors.Wrap(err, "注册表失败!")
		zap.L().Error(fmt.Sprintf("%+v", err))
	}
	RegisterHooks(ctx, global.GVA_DB)
	for _, db := range global.GVA_DBList {
		// Skip nil entries, and skip the primary DB if it is also listed:
		// registering the same callback twice on one *gorm.DB makes it run
		// twice per update, duplicating change-log rows.
		if db == nil || db == global.GVA_DB {
			continue
		}
		RegisterHooks(ctx, db)
	}
	monitorService.InitMonitor()
}
// RegisterHooks installs the "monitor:before_update" callback on db. The
// callback runs before gorm's "gorm:update" step. For tables that are
// currently monitored, it writes one model.ChangeLog row per monitored column
// whose value the pending update changes.
//
// ctx is not used; it is kept so existing callers keep compiling. A nil db is
// ignored, and a registration error is logged rather than returned.
func RegisterHooks(ctx context.Context, db *gorm.DB) {
	if db == nil {
		return
	}
	err := db.Callback().Update().Before("gorm:update").Register("monitor:before_update", recordFieldChanges)
	if err != nil {
		global.GVA_LOG.Error("register monitor:before_update callback failed", zap.Error(err))
	}
}

// recordFieldChanges is the body of the "monitor:before_update" callback.
// It loads the row being updated and compares each monitored column with the
// value the update assigns. It writes a model.ChangeLog entry for every
// difference.
//
// NOTE(review): entries are written through global.GVA_DB. That write is
// outside tx's transaction and happens before the UPDATE itself runs, so an
// update that later fails or is rolled back still leaves its change-log rows.
func recordFieldChanges(tx *gorm.DB) {
	stmt := tx.Statement
	if tx.Error != nil || stmt == nil || stmt.Schema == nil {
		return
	}
	// Table name as resolved by the model's schema.
	tableName := stmt.Schema.Table

	// Run the cheap static exclusion first, so system/example tables never
	// reach the monitor service.
	if strings.HasPrefix(tableName, "sys_") || strings.HasPrefix(tableName, "exa_") {
		return
	}
	if !monitorService.RefreshExipre()[tableName] {
		global.GVA_LOG.Debug("table is outside its monitoring period", zap.String("table", tableName))
		return
	}

	pkField := stmt.Schema.PrioritizedPrimaryField
	if pkField == nil {
		global.GVA_LOG.Warn("无法获取主键值,跳过记录变更", zap.String("table", tableName))
		return
	}
	pkValue := lookupPrimaryKey(stmt, pkField.Name, pkField.DBName)
	if pkValue == nil {
		global.GVA_LOG.Warn("无法获取主键值,跳过记录变更", zap.String("table", tableName))
		return
	}

	// Load the current row.
	//   - NewDB gives a clean statement with none of the pending update's
	//     WHERE/SELECT/... clauses, while still using tx's connection and
	//     context.
	//   - The map condition makes gorm bind pkValue as a parameter. A bare
	//     string passed to First would be parsed as raw SQL (e.g. UUID keys).
	oldModel := reflect.New(stmt.Schema.ModelType).Interface()
	lookup := tx.Session(&gorm.Session{NewDB: true})
	if err := lookup.First(oldModel, map[string]interface{}{pkField.DBName: pkValue}).Error; err != nil {
		return
	}

	for _, field := range stmt.Schema.Fields {
		columnName := field.DBName
		// Non-column fields have no DBName; non-updatable fields are never written.
		if columnName == "" || !field.Updatable || !monitorService.IsFieldMonitored(tableName, columnName) {
			continue
		}
		newVal, written := lookupUpdatedValue(stmt, field.Name, columnName)
		if !written {
			continue // this update does not touch the column
		}
		oldVal := getFieldValue(oldModel, field.Name)
		if reflect.DeepEqual(oldVal, newVal) {
			continue
		}
		oldStr, newStr := toStringPtr(oldVal), toStringPtr(newVal)
		// A map update often carries a different Go type than the model field
		// (e.g. an int literal for a uint column). Equal text means no change.
		if oldStr != nil && newStr != nil && *oldStr == *newStr {
			continue
		}
		entry := model.ChangeLog{
			Table:    &tableName,
			Column:   &columnName,
			OldValue: oldStr,
			NewValue: newStr,
		}
		if err := global.GVA_DB.WithContext(stmt.Context).Create(&entry).Error; err != nil {
			global.GVA_LOG.Error("write change log failed",
				zap.String("table", tableName), zap.String("column", columnName), zap.Error(err))
		}
	}
}

// lookupPrimaryKey returns the first non-zero primary-key value found in
// stmt.Dest, falling back to stmt.Model. For structs the key is read by Go
// field name; for maps it is read by field name or column name. It returns
// nil when neither holds one, e.g. for Model(&T{}).Where(...).Updates(...).
func lookupPrimaryKey(stmt *gorm.Statement, name, dbName string) interface{} {
	for _, src := range []interface{}{stmt.Dest, stmt.Model} {
		container := derefValue(reflect.ValueOf(src))
		var raw reflect.Value
		switch container.Kind() {
		case reflect.Map:
			raw, _ = stringMapEntry(container, name, dbName)
		case reflect.Struct:
			raw = container.FieldByName(name)
		}
		pk := derefValue(raw)
		if pk.IsValid() && pk.CanInterface() && !pk.IsZero() {
			return pk.Interface()
		}
	}
	return nil
}

// lookupUpdatedValue returns the value the statement assigns to a column,
// given by Go field name and DB column name. It also reports whether the
// statement writes that column at all, following gorm's update rules:
//   - Omit excludes a column.
//   - A Select without "*" restricts writes to the selected columns.
//   - A map Dest writes only its keys.
//   - A struct Dest without Select writes only its non-zero fields.
//
// A nil assigned value is reported as (nil, true).
func lookupUpdatedValue(stmt *gorm.Statement, name, dbName string) (interface{}, bool) {
	if containsAnyName(stmt.Omits, name, dbName) {
		return nil, false
	}
	selected := len(stmt.Selects) > 0
	if selected && !containsAnyName(stmt.Selects, "*", name, dbName) {
		return nil, false
	}

	dest := derefValue(reflect.ValueOf(stmt.Dest))
	var raw reflect.Value
	switch dest.Kind() {
	case reflect.Map:
		entry, ok := stringMapEntry(dest, name, dbName)
		if !ok {
			return nil, false
		}
		raw = entry
	case reflect.Struct:
		raw = dest.FieldByName(name)
		if !raw.IsValid() || !raw.CanInterface() || (!selected && raw.IsZero()) {
			return nil, false
		}
	default:
		return nil, false
	}

	value := derefValue(raw)
	if !value.IsValid() || !value.CanInterface() {
		return nil, true
	}
	return value.Interface(), true
}

// derefValue follows pointers and interfaces. A nil pointer or interface
// yields the invalid (zero) reflect.Value.
func derefValue(v reflect.Value) reflect.Value {
	for v.Kind() == reflect.Ptr || v.Kind() == reflect.Interface {
		if v.IsNil() {
			return reflect.Value{}
		}
		v = v.Elem()
	}
	return v
}

// stringMapEntry returns the entry stored under the first of keys present in
// the string-keyed map m. It reports false for other key types or when no
// key is present.
func stringMapEntry(m reflect.Value, keys ...string) (reflect.Value, bool) {
	keyType := m.Type().Key()
	if keyType.Kind() != reflect.String {
		return reflect.Value{}, false
	}
	for _, key := range keys {
		if entry := m.MapIndex(reflect.ValueOf(key).Convert(keyType)); entry.IsValid() {
			return entry, true
		}
	}
	return reflect.Value{}, false
}

// containsAnyName reports whether list contains any of names.
func containsAnyName(list []string, names ...string) bool {
	for _, item := range list {
		for _, name := range names {
			if item == name {
				return true
			}
		}
	}
	return false
}
// getFieldValue returns the value stored under field in model, dereferencing
// pointers and interfaces along the way (helper: read a field value, handling
// pointers).
//
// model may be a struct, a string-keyed map, or a (possibly nested) pointer to
// either. For structs, field is the Go field name; for maps, it is the map key.
// It returns nil in any of these cases:
//   - model is nil or a nil pointer, or is neither a struct nor a
//     string-keyed map.
//   - the field or key does not exist.
//   - the stored value is a nil pointer or interface.
//   - the field is unexported.
func getFieldValue(model interface{}, field string) interface{} {
	container := reflect.ValueOf(model)
	for container.Kind() == reflect.Ptr {
		if container.IsNil() {
			return nil
		}
		container = container.Elem()
	}

	var fieldVal reflect.Value
	switch container.Kind() {
	case reflect.Map:
		keyType := container.Type().Key()
		if keyType.Kind() != reflect.String {
			return nil
		}
		// Missing keys yield an invalid Value, handled below.
		fieldVal = container.MapIndex(reflect.ValueOf(field).Convert(keyType))
	case reflect.Struct:
		fieldVal = container.FieldByName(field)
	default:
		return nil
	}

	for fieldVal.Kind() == reflect.Ptr || fieldVal.Kind() == reflect.Interface {
		if fieldVal.IsNil() {
			return nil
		}
		fieldVal = fieldVal.Elem()
	}
	// Unexported struct fields cannot be returned via Interface() (it panics).
	if !fieldVal.IsValid() || !fieldVal.CanInterface() {
		return nil
	}
	return fieldVal.Interface()
}
// toStringPtr converts v to a string pointer for storage in a nullable text
// column (helper: safe conversion to *string).
//
// nil, nil pointers and nil interfaces yield nil. Pointers are followed to the
// value they reference. The value is then formatted as follows:
//   - strings are returned unchanged.
//   - reflect.Uint values are printed in decimal.
//   - everything else uses fmt's %v formatting.
func toStringPtr(v interface{}) *string {
	val := reflect.ValueOf(v)
	for val.Kind() == reflect.Ptr || val.Kind() == reflect.Interface {
		if val.IsNil() {
			return nil
		}
		val = val.Elem()
	}
	if !val.IsValid() {
		return nil
	}

	var str string
	switch val.Kind() {
	case reflect.String:
		str = val.String()
	case reflect.Uint:
		str = fmt.Sprintf("%d", val.Uint())
	default:
		str = fmt.Sprintf("%v", val.Interface())
	}
	return &str
}