316 lines
5.9 KiB
Go
316 lines
5.9 KiB
Go
package tableDB
|
|
|
|
import (
|
|
"os"
|
|
"testing"
|
|
|
|
"apigo.cc/go/log"
|
|
)
|
|
|
|
func TestSQLInjection(t *testing.T) {
|
|
logger := log.DefaultLogger
|
|
dbFile := "test_injection.db"
|
|
os.Remove(dbFile)
|
|
defer os.Remove(dbFile)
|
|
|
|
dbInst := GetDB("sqlite://"+dbFile, logger)
|
|
|
|
schema := `
|
|
== InjectionGroup ==
|
|
_Table SD
|
|
id c10 PK
|
|
name v64 U
|
|
memo t
|
|
createTime bi
|
|
creator v64
|
|
|
|
_Field SD
|
|
id c10 PK
|
|
tableId c10 I
|
|
name v64
|
|
type v32
|
|
isIndex b
|
|
memo t
|
|
createTime bi
|
|
|
|
_Policy SD
|
|
subject v64 I
|
|
action v32 I
|
|
resource v128 I
|
|
effect v16
|
|
|
|
== Test ==
|
|
users_inj SD
|
|
id c10 PK
|
|
name v50 U
|
|
secret t
|
|
`
|
|
err := dbInst.SyncSchema(schema)
|
|
if err != nil {
|
|
t.Fatalf("Failed to sync schema: %v", err)
|
|
}
|
|
|
|
appAdmin := dbInst.Auth("admin")
|
|
table := appAdmin.Table("users_inj")
|
|
_ = table.Set(map[string]any{"name": "Alice", "secret": "top-secret-123"})
|
|
|
|
// Attempt SQL injection via Table name
|
|
req1 := QueryRequest{
|
|
Table: "users_inj` --",
|
|
}
|
|
_, _, err = dbInst.BuildQuery(req1)
|
|
if err == nil {
|
|
t.Errorf("Expected error for invalid table name with injection")
|
|
}
|
|
|
|
// Attempt SQL injection via Field name
|
|
req2 := QueryRequest{
|
|
Table: "users_inj",
|
|
Select: []string{"name`, secret AS name `"},
|
|
}
|
|
_, _, err = dbInst.BuildQuery(req2)
|
|
if err == nil {
|
|
t.Errorf("Expected error for invalid field name with injection")
|
|
}
|
|
|
|
// Attempt SQL injection via Join Table
|
|
req3 := QueryRequest{
|
|
Table: "users_inj",
|
|
Joins: []JoinConfig{
|
|
{Table: "users_inj` --", On: "1=1"},
|
|
},
|
|
}
|
|
_, _, err = dbInst.BuildQuery(req3)
|
|
if err == nil {
|
|
t.Errorf("Expected error for invalid join table name")
|
|
}
|
|
|
|
// Attempt SQL injection via OrderBy
|
|
req4 := QueryRequest{
|
|
Table: "users_inj",
|
|
OrderBy: "name; DROP TABLE users_inj; --",
|
|
}
|
|
_, _, err = dbInst.BuildQuery(req4)
|
|
if err == nil {
|
|
t.Errorf("Expected error for invalid order by with injection")
|
|
}
|
|
}
|
|
|
|
func TestTableOperationsAndHooks(t *testing.T) {
|
|
logger := log.DefaultLogger
|
|
logger.SetLevel(log.ERROR)
|
|
os.Remove("test_ops.db")
|
|
defer os.Remove("test_ops.db")
|
|
|
|
dbInst := GetDB("sqlite://test_ops.db", logger)
|
|
|
|
var hookUpdatedRowsCount int
|
|
var hookRemovedRowsCount int
|
|
var hookUpdatingRowCalled bool
|
|
|
|
dbInst.Hooks.OnUpdatingRow = func(tableName string, row map[string]any) error {
|
|
hookUpdatingRowCalled = true
|
|
if tableName == "users_ops" {
|
|
row["memo"] = "hooked"
|
|
}
|
|
return nil
|
|
}
|
|
dbInst.Hooks.OnUpdatedRows = func(tableName string, count int) {
|
|
hookUpdatedRowsCount += count
|
|
}
|
|
dbInst.Hooks.OnRemovedRows = func(tableName string, ids []string) {
|
|
hookRemovedRowsCount += len(ids)
|
|
}
|
|
|
|
schema := `
|
|
== TestGroup ==
|
|
_Table SD
|
|
id c10 PK
|
|
name v64 U
|
|
memo t
|
|
createTime bi
|
|
creator v64
|
|
|
|
_Field SD
|
|
id c10 PK
|
|
tableId c10 I
|
|
name v64
|
|
type v32
|
|
isIndex b
|
|
memo t
|
|
createTime bi
|
|
|
|
_Policy SD
|
|
subject v64 I
|
|
action v32 I
|
|
resource v128 I
|
|
effect v16
|
|
|
|
== Test ==
|
|
users_ops SD
|
|
id c10 PK
|
|
name v50 U
|
|
age i
|
|
status ti
|
|
memo t
|
|
`
|
|
err := dbInst.SyncSchema(schema)
|
|
if err != nil {
|
|
t.Fatalf("Failed to sync schema: %v", err)
|
|
}
|
|
|
|
appAdmin := dbInst.Auth("admin")
|
|
table := appAdmin.Table("users_ops")
|
|
|
|
// Test Set (Insert)
|
|
err = table.Set(map[string]any{
|
|
"name": "Alice",
|
|
"age": 30,
|
|
"status": 1,
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("Set failed: %v", err)
|
|
}
|
|
|
|
if !hookUpdatingRowCalled {
|
|
t.Errorf("Expected OnUpdatingRow to be called")
|
|
}
|
|
|
|
if hookUpdatedRowsCount != 1 {
|
|
t.Errorf("Expected OnUpdatedRows to be 1, got %d", hookUpdatedRowsCount)
|
|
}
|
|
|
|
// Test Set with explicit ID (Insert)
|
|
err = table.Set(map[string]any{
|
|
"id": "100",
|
|
"name": "Bob",
|
|
"age": 25,
|
|
"status": 0,
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("Set with ID failed: %v", err)
|
|
}
|
|
|
|
// Test Get
|
|
record, err := table.Get("100")
|
|
if err != nil || record == nil {
|
|
t.Fatalf("Get failed: %v", err)
|
|
}
|
|
if record["name"] != "Bob" {
|
|
t.Fatalf("Expected name Bob, got %v", record["name"])
|
|
}
|
|
if record["memo"] != "hooked" {
|
|
t.Fatalf("Expected memo hooked, got %v", record["memo"])
|
|
}
|
|
|
|
// Test QueryRequest
|
|
queryReq := QueryRequest{
|
|
Table: "users_ops",
|
|
Where: "age > ?",
|
|
Args: []any{20},
|
|
Limit: 10,
|
|
}
|
|
res, err := dbInst.Query(queryReq)
|
|
if err != nil {
|
|
t.Fatalf("QueryRequest failed: %v", err)
|
|
}
|
|
if len(res) != 2 {
|
|
t.Fatalf("Expected 2 results from QueryRequest, got %d", len(res))
|
|
}
|
|
|
|
// Test cache and _Field
|
|
fields, err := table.Fields()
|
|
if err != nil {
|
|
t.Fatalf("Fields() failed: %v", err)
|
|
}
|
|
if len(fields) == 0 {
|
|
t.Fatalf("Expected fields metadata, got empty")
|
|
}
|
|
hasAge := false
|
|
for _, f := range fields {
|
|
if f.Name == "age" {
|
|
hasAge = true
|
|
break
|
|
}
|
|
}
|
|
if !hasAge {
|
|
t.Fatalf("Field 'age' not found in metadata")
|
|
}
|
|
|
|
// Test List
|
|
list, err := table.List(map[string]any{"age >": 20})
|
|
if err != nil {
|
|
t.Fatalf("List failed: %v", err)
|
|
}
|
|
if len(list) != 2 {
|
|
t.Fatalf("Expected 2 results from List, got %d", len(list))
|
|
}
|
|
|
|
// Test Count
|
|
count, err := table.Count(map[string]any{"age >": 20})
|
|
if err != nil {
|
|
t.Fatalf("Count failed: %v", err)
|
|
}
|
|
if count != 2 {
|
|
t.Fatalf("Expected count 2, got %d", count)
|
|
}
|
|
|
|
// Test Remove
|
|
err = table.Remove("100")
|
|
if err != nil {
|
|
t.Fatalf("Remove failed: %v", err)
|
|
}
|
|
record, err = table.Get("100")
|
|
if record != nil {
|
|
t.Fatalf("Expected nil after removal, got %v", record)
|
|
}
|
|
|
|
if hookRemovedRowsCount != 1 {
|
|
t.Errorf("Expected hookRemovedRowsCount to be 1, got %d", hookRemovedRowsCount)
|
|
}
|
|
}
|
|
|
|
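// BenchmarkTableSet is commented out: it calls cast.String, but this file does
// not import a cast package. Re-enabling it requires adding that import (or
// switching to strconv.Itoa).
/*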
|
|
func BenchmarkTableSet(b *testing.B) {
|
|
logger := log.DefaultLogger
|
|
logger.SetLevel(log.ERROR)
|
|
os.Remove("bench_ops.db")
|
|
defer os.Remove("bench_ops.db")
|
|
|
|
dbInst := GetDB("sqlite://bench_ops.db", logger)
|
|
schema := `
|
|
== TestGroup ==
|
|
_Table SD
|
|
id c10 PK
|
|
name v64 U
|
|
|
|
_Field SD
|
|
id c10 PK
|
|
tableId c10 I
|
|
name v64
|
|
type v32
|
|
|
|
_Policy SD
|
|
subject v64 I
|
|
|
|
== Test ==
|
|
bench_ops SD
|
|
id c10 PK
|
|
name v50 U
|
|
val i
|
|
`
|
|
_ = dbInst.SyncSchema(schema)
|
|
|
|
appAdmin := dbInst.Auth("admin")
|
|
table := appAdmin.Table("bench_ops")
|
|
|
|
b.ResetTimer()
|
|
for i := 0; i < b.N; i++ {
|
|
table.Set(map[string]any{
|
|
"name": cast.String(i),
|
|
"val": i,
|
|
})
|
|
}
|
|
}
|
|
*/
|