diff --git a/cache/cache.go b/cache/cache.go index 63d21520..be3c4801 100644 --- a/cache/cache.go +++ b/cache/cache.go @@ -90,7 +90,7 @@ func (r *RowCache) RowByModel(m model.Model) model.Model { if reflect.TypeOf(m) != r.dataType { return nil } - info, _ := mapper.NewInfo(&r.schema, m) + info, _ := mapper.NewInfo(r.name, &r.schema, m) uuid, err := info.FieldByColumn("_uuid") if err != nil { return nil @@ -120,7 +120,7 @@ func (r *RowCache) Create(uuid string, m model.Model, checkIndexes bool) error { if reflect.TypeOf(m) != r.dataType { return fmt.Errorf("expected data of type %s, but got %s", r.dataType.String(), reflect.TypeOf(m).String()) } - info, err := mapper.NewInfo(&r.schema, m) + info, err := mapper.NewInfo(r.name, &r.schema, m) if err != nil { return err } @@ -156,11 +156,11 @@ func (r *RowCache) Update(uuid string, m model.Model, checkIndexes bool) error { return fmt.Errorf("row %s does not exist", uuid) } oldRow := model.Clone(r.cache[uuid]) - oldInfo, err := mapper.NewInfo(&r.schema, oldRow) + oldInfo, err := mapper.NewInfo(r.name, &r.schema, oldRow) if err != nil { return err } - newInfo, err := mapper.NewInfo(&r.schema, m) + newInfo, err := mapper.NewInfo(r.name, &r.schema, m) if err != nil { return err } @@ -218,7 +218,7 @@ func (r *RowCache) Update(uuid string, m model.Model, checkIndexes bool) error { } func (r *RowCache) IndexExists(row model.Model) error { - info, err := mapper.NewInfo(&r.schema, row) + info, err := mapper.NewInfo(r.name, &r.schema, row) if err != nil { return err } @@ -252,7 +252,7 @@ func (r *RowCache) Delete(uuid string) error { return fmt.Errorf("row %s does not exist", uuid) } oldRow := r.cache[uuid] - oldInfo, err := mapper.NewInfo(&r.schema, oldRow) + oldInfo, err := mapper.NewInfo(r.name, &r.schema, oldRow) if err != nil { return err } @@ -325,7 +325,7 @@ func (r *RowCache) RowsByCondition(conditions []ovsdb.Condition) ([]model.Model, } else { for _, uuid := range r.Rows() { row := r.Row(uuid) - info, err := mapper.NewInfo(&r.schema, row) + info, err := mapper.NewInfo(r.name, &r.schema, row) if err != nil { return nil, err } @@ -761,18 +761,17 @@ func (t *TableCache) CreateModel(tableName string, row *ovsdb.Row, uuid string) if err != nil { return nil, err } - - err = t.dbModel.Mapper().GetRowData(tableName, row, model) + info, err := mapper.NewInfo(tableName, table, model) + if err != nil { + return nil, err + } + err = t.dbModel.Mapper().GetRowData(row, info) if err != nil { return nil, err } if uuid != "" { - mapperInfo, err := mapper.NewInfo(table, model) - if err != nil { - return nil, err - } - if err := mapperInfo.SetField("_uuid", uuid); err != nil { + if err := info.SetField("_uuid", uuid); err != nil { return nil, err } } @@ -791,7 +790,7 @@ func (t *TableCache) ApplyModifications(tableName string, base model.Model, upda if schema == nil { return fmt.Errorf("no schema for table %s", tableName) } - info, err := mapper.NewInfo(schema, base) + info, err := mapper.NewInfo(tableName, schema, base) if err != nil { return err } diff --git a/cache/cache_test.go b/cache/cache_test.go index cc5f7e21..b2d1db4b 100644 --- a/cache/cache_test.go +++ b/cache/cache_test.go @@ -253,7 +253,7 @@ func TestRowCacheCreateMultiIndex(t *testing.T) { } } else { assert.Nil(t, err) - mapperInfo, err := mapper.NewInfo(tSchema, tt.model) + mapperInfo, err := mapper.NewInfo("Open_vSwitch", tSchema, tt.model) require.Nil(t, err) h, err := valueFromIndex(mapperInfo, newIndex("foo", "bar")) require.Nil(t, err) @@ -419,7 +419,7 @@ func TestRowCacheUpdateMultiIndex(t *testing.T) { assert.Error(t, err) } else { assert.Nil(t, err) - mapperInfo, err := mapper.NewInfo(tSchema, tt.model) + mapperInfo, err := mapper.NewInfo("Open_vSwitch", tSchema, tt.model) require.Nil(t, err) h, err := valueFromIndex(mapperInfo, newIndex("foo", "bar")) require.Nil(t, err) @@ -982,7 +982,7 @@ func TestIndex(t *testing.T) { t.Run("Index by single column", func(t *testing.T) { idx, err := table.Index("foo") assert.Nil(t, err) - info, err := mapper.NewInfo(schema.Table("Open_vSwitch"), obj) + info, err := mapper.NewInfo("Open_vSwitch", schema.Table("Open_vSwitch"), obj) assert.Nil(t, err) v, err := valueFromIndex(info, newIndex("foo")) assert.Nil(t, err) @@ -994,7 +994,7 @@ func TestIndex(t *testing.T) { obj2 := obj obj2.Foo = "wrong" assert.Nil(t, err) - info, err := mapper.NewInfo(schema.Table("Open_vSwitch"), obj2) + info, err := mapper.NewInfo("Open_vSwitch", schema.Table("Open_vSwitch"), obj2) assert.Nil(t, err) v, err := valueFromIndex(info, newIndex("foo")) assert.Nil(t, err) @@ -1012,7 +1012,7 @@ func TestIndex(t *testing.T) { t.Run("Index by multi-column", func(t *testing.T) { idx, err := table.Index("bar", "baz") assert.Nil(t, err) - info, err := mapper.NewInfo(schema.Table("Open_vSwitch"), obj) + info, err := mapper.NewInfo("Open_vSwitch", schema.Table("Open_vSwitch"), obj) assert.Nil(t, err) v, err := valueFromIndex(info, newIndex("bar", "baz")) assert.Nil(t, err) @@ -1023,7 +1023,7 @@ func TestIndex(t *testing.T) { assert.Nil(t, err) obj2 := obj obj2.Baz++ - info, err := mapper.NewInfo(schema.Table("Open_vSwitch"), obj) + info, err := mapper.NewInfo("Open_vSwitch", schema.Table("Open_vSwitch"), obj) assert.Nil(t, err) v, err := valueFromIndex(info, newIndex("bar", "baz")) assert.Nil(t, err) diff --git a/client/api.go b/client/api.go index 8a0323a8..b92933c3 100644 --- a/client/api.go +++ b/client/api.go @@ -246,7 +246,7 @@ func (a api) Create(models ...model.Model) ([]ovsdb.Operation, error) { table := a.cache.Mapper().Schema.Table(tableName) // Read _uuid field, and use it as named-uuid - info, err := mapper.NewInfo(table, model) + info, err := mapper.NewInfo(tableName, table, model) if err != nil { return nil, err } @@ -256,7 +256,7 @@ func (a api) Create(models ...model.Model) ([]ovsdb.Operation, error) { return nil, err } - row, err := a.cache.Mapper().NewRow(tableName, model) + row, err := a.cache.Mapper().NewRow(info) if err != nil { return nil, err } @@ -294,7 +294,7 @@ func (a api) Mutate(model model.Model, mutationObjs ...model.Mutation) ([]ovsdb. return nil, err } - info, err := mapper.NewInfo(table, model) + info, err := mapper.NewInfo(tableName, table, model) if err != nil { return nil, err } @@ -305,7 +305,7 @@ func (a api) Mutate(model model.Model, mutationObjs ...model.Mutation) ([]ovsdb. return nil, err } - mutation, err := a.cache.Mapper().NewMutation(tableName, model, col, mobj.Mutator, mobj.Value) + mutation, err := a.cache.Mapper().NewMutation(info, col, mobj.Mutator, mobj.Value) if err != nil { return nil, err } @@ -335,12 +335,12 @@ func (a api) Update(model model.Model, fields ...interface{}) ([]ovsdb.Operation return nil, err } tableSchema := a.cache.Mapper().Schema.Table(table) + info, err := mapper.NewInfo(table, tableSchema, model) + if err != nil { + return nil, err + } if len(fields) > 0 { - info, err := mapper.NewInfo(tableSchema, model) - if err != nil { - return nil, err - } for _, f := range fields { colName, err := info.ColumnByPtr(f) if err != nil { @@ -357,7 +357,7 @@ func (a api) Update(model model.Model, fields ...interface{}) ([]ovsdb.Operation return nil, err } - row, err := a.cache.Mapper().NewRow(table, model, fields...) + row, err := a.cache.Mapper().NewRow(info, fields...) if err != nil { return nil, err } diff --git a/client/client.go b/client/client.go index 96876374..d8d45e9f 100644 --- a/client/client.go +++ b/client/client.go @@ -17,6 +17,7 @@ import ( "github.com/cenkalti/rpc2" "github.com/cenkalti/rpc2/jsonrpc" "github.com/ovn-org/libovsdb/cache" + "github.com/ovn-org/libovsdb/mapper" "github.com/ovn-org/libovsdb/model" "github.com/ovn-org/libovsdb/ovsdb" "github.com/ovn-org/libovsdb/ovsdb/serverdb" @@ -134,7 +135,7 @@ func newOVSDBClient(databaseModelRequest *model.DatabaseModelRequest, opts ...Op monitors: make(map[string]*Monitor), } } - ovs.metrics.init(databaseModel.Name()) + ovs.metrics.init(databaseModelRequest.Name()) return ovs, nil } @@ -146,7 +147,7 @@ func newOVSDBClient(databaseModelRequest *model.DatabaseModelRequest, opts ...Op func (o *ovsdbClient) Connect(ctx context.Context) error { // add the "model" value to the structured logger // to make it easier to tell between different DBs (e.g. ovn nbdb vs. sbdb) - l := o.options.logger.WithValues("model", o.primaryDB().model.Name()) + l := o.options.logger.WithValues("model", o.primaryDB().model.Request().Name()) o.options.logger = &l o.registerMetrics() @@ -701,7 +702,8 @@ func (o *ovsdbClient) monitor(ctx context.Context, cookie MonitorCookie, reconne dbName := cookie.DatabaseName db := o.databases[dbName] db.schemaMutex.RLock() - mapper := db.model.Mapper() + mmapper := db.model.Mapper() + schema := db.model.Schema() db.schemaMutex.RUnlock() typeMap := o.databases[dbName].model.Types() requests := make(map[string]ovsdb.MonitorRequest) @@ -710,7 +712,11 @@ func (o *ovsdbClient) monitor(ctx context.Context, cookie MonitorCookie, reconne if !ok { return fmt.Errorf("type for table %s does not exist in model", o.Table) } - request, err := mapper.NewMonitorRequest(o.Table, m, o.Fields) + info, err := mapper.NewInfo(o.Table, schema.Table(o.Table), m) + if err != nil { + return err + } + request, err := mmapper.NewMonitorRequest(info, o.Fields) if err != nil { return err } diff --git a/client/condition.go b/client/condition.go index cba7a7f7..ec396dce 100644 --- a/client/condition.go +++ b/client/condition.go @@ -28,12 +28,16 @@ type Conditional interface { type equalityConditional struct { mapper *mapper.Mapper tableName string - model model.Model + info *mapper.Info singleOp bool } func (c *equalityConditional) Matches(m model.Model) (bool, error) { - return c.mapper.EqualFields(c.tableName, c.model, m) + info, err := mapper.NewInfo(c.tableName, c.mapper.Schema.Table(c.tableName), m) + if err != nil { + return false, err + } + return c.mapper.EqualFields(c.info, info) } func (c *equalityConditional) Table() string { @@ -44,7 +48,7 @@ func (c *equalityConditional) Table() string { func (c *equalityConditional) Generate() ([][]ovsdb.Condition, error) { var result [][]ovsdb.Condition - conds, err := c.mapper.NewEqualityCondition(c.tableName, c.model) + conds, err := c.mapper.NewEqualityCondition(c.info) if err != nil { return nil, err } @@ -59,11 +63,15 @@ func (c *equalityConditional) Generate() ([][]ovsdb.Condition, error) { } // NewEqualityCondition creates a new equalityConditional -func newEqualityConditional(mapper *mapper.Mapper, table string, all bool, model model.Model, fields ...interface{}) (Conditional, error) { +func newEqualityConditional(m *mapper.Mapper, table string, all bool, model model.Model, fields ...interface{}) (Conditional, error) { + info, err := mapper.NewInfo(table, m.Schema.Table(table), model) + if err != nil { + return nil, err + } return &equalityConditional{ - mapper: mapper, + mapper: m, tableName: table, - model: model, + info: info, singleOp: all, }, nil } @@ -72,7 +80,7 @@ func newEqualityConditional(mapper *mapper.Mapper, table string, all bool, model type explicitConditional struct { mapper *mapper.Mapper tableName string - model model.Model + info *mapper.Info conditions []model.Condition singleOp bool } @@ -91,7 +99,7 @@ func (c *explicitConditional) Generate() ([][]ovsdb.Condition, error) { var conds []ovsdb.Condition for _, cond := range c.conditions { - ovsdbCond, err := c.mapper.NewCondition(c.tableName, c.model, cond.Field, cond.Function, cond.Value) + ovsdbCond, err := c.mapper.NewCondition(c.info, cond.Field, cond.Function, cond.Value) if err != nil { return nil, err } @@ -109,11 +117,15 @@ func (c *explicitConditional) Generate() ([][]ovsdb.Condition, error) { } // newIndexCondition creates a new equalityConditional -func newExplicitConditional(mapper *mapper.Mapper, table string, all bool, model model.Model, cond ...model.Condition) (Conditional, error) { +func newExplicitConditional(m *mapper.Mapper, table string, all bool, model model.Model, cond ...model.Condition) (Conditional, error) { + info, err := mapper.NewInfo(table, m.Schema.Table(table), model) + if err != nil { + return nil, err + } return &explicitConditional{ - mapper: mapper, + mapper: m, tableName: table, - model: model, + info: info, conditions: cond, singleOp: all, }, nil @@ -153,7 +165,11 @@ func (c *predicateConditional) Generate() ([][]ovsdb.Condition, error) { return nil, err } if match { - elemCond, err := c.cache.Mapper().NewEqualityCondition(c.tableName, elem) + info, err := mapper.NewInfo(c.tableName, c.cache.Mapper().Schema.Table(c.tableName), elem) + if err != nil { + return nil, err + } + elemCond, err := c.cache.Mapper().NewEqualityCondition(info) if err != nil { return nil, err } diff --git a/mapper/info.go b/mapper/info.go index 1ad981b6..e364ddd2 100644 --- a/mapper/info.go +++ b/mapper/info.go @@ -7,37 +7,42 @@ import ( "github.com/ovn-org/libovsdb/ovsdb" ) -// Info is a struct that handles the type map of an object -// The object must have exported tagged fields with the 'ovs' +// Info is a struct that wraps an object with its metadata type Info struct { // FieldName indexed by column - fields map[string]string - obj interface{} - table *ovsdb.TableSchema + Obj interface{} + Metadata *Metadata +} + +// Metadata represents the information needed to know how to map OVSDB columns into an objet's fields +type Metadata struct { + Fields map[string]string // Map of ColumnName -> FieldName + TableSchema *ovsdb.TableSchema // TableSchema associated + TableName string // Table name } // FieldByColumn returns the field value that corresponds to a column func (i *Info) FieldByColumn(column string) (interface{}, error) { - fieldName, ok := i.fields[column] + fieldName, ok := i.Metadata.Fields[column] if !ok { return nil, fmt.Errorf("FieldByColumn: column %s not found in orm info", column) } - return reflect.ValueOf(i.obj).Elem().FieldByName(fieldName).Interface(), nil + return reflect.ValueOf(i.Obj).Elem().FieldByName(fieldName).Interface(), nil } // FieldByColumn returns the field value that corresponds to a column func (i *Info) hasColumn(column string) bool { - _, ok := i.fields[column] + _, ok := i.Metadata.Fields[column] return ok } // SetField sets the field in the column to the specified value func (i *Info) SetField(column string, value interface{}) error { - fieldName, ok := i.fields[column] + fieldName, ok := i.Metadata.Fields[column] if !ok { return fmt.Errorf("SetField: column %s not found in orm info", column) } - fieldValue := reflect.ValueOf(i.obj).Elem().FieldByName(fieldName) + fieldValue := reflect.ValueOf(i.Obj).Elem().FieldByName(fieldName) if !fieldValue.Type().AssignableTo(reflect.TypeOf(value)) { return fmt.Errorf("column %s: native value %v (%s) is not assignable to field %s (%s)", @@ -53,12 +58,12 @@ func (i *Info) ColumnByPtr(fieldPtr interface{}) (string, error) { if fieldPtrVal.Kind() != reflect.Ptr { return "", ovsdb.NewErrWrongType("ColumnByPointer", "pointer to a field in the struct", fieldPtr) } - offset := fieldPtrVal.Pointer() - reflect.ValueOf(i.obj).Pointer() - objType := reflect.TypeOf(i.obj).Elem() + offset := fieldPtrVal.Pointer() - reflect.ValueOf(i.Obj).Pointer() + objType := reflect.TypeOf(i.Obj).Elem() for j := 0; j < objType.NumField(); j++ { if objType.Field(j).Offset == offset { column := objType.Field(j).Tag.Get("ovsdb") - if _, ok := i.fields[column]; !ok { + if _, ok := i.Metadata.Fields[column]; !ok { return "", fmt.Errorf("field does not have orm column information") } return column, nil @@ -74,7 +79,7 @@ func (i *Info) getValidIndexes() ([][]string, error) { var possibleIndexes [][]string possibleIndexes = append(possibleIndexes, []string{"_uuid"}) - possibleIndexes = append(possibleIndexes, i.table.Indexes...) + possibleIndexes = append(possibleIndexes, i.Metadata.TableSchema.Indexes...) // Iterate through indexes and validate them OUTER: @@ -83,7 +88,7 @@ OUTER: if !i.hasColumn(col) { continue OUTER } - columnSchema := i.table.Column(col) + columnSchema := i.Metadata.TableSchema.Column(col) if columnSchema == nil { continue OUTER } @@ -101,7 +106,7 @@ OUTER: } // NewInfo creates a MapperInfo structure around an object based on a given table schema -func NewInfo(table *ovsdb.TableSchema, obj interface{}) (*Info, error) { +func NewInfo(tableName string, table *ovsdb.TableSchema, obj interface{}) (*Info, error) { objPtrVal := reflect.ValueOf(obj) if objPtrVal.Type().Kind() != reflect.Ptr { return nil, ovsdb.NewErrWrongType("NewMapperInfo", "pointer to a struct", obj) @@ -146,8 +151,11 @@ func NewInfo(table *ovsdb.TableSchema, obj interface{}) (*Info, error) { } return &Info{ - fields: fields, - obj: obj, - table: table, + Obj: obj, + Metadata: &Metadata{ + Fields: fields, + TableSchema: table, + TableName: tableName, + }, }, nil } diff --git a/mapper/info_test.go b/mapper/info_test.go index b50633f3..cf5ad960 100644 --- a/mapper/info_test.go +++ b/mapper/info_test.go @@ -58,7 +58,7 @@ func TestNewMapperInfo(t *testing.T) { err := json.Unmarshal(tt.table, &table) assert.Nil(t, err) - info, err := NewInfo(&table, tt.obj) + info, err := NewInfo("Test", &table, tt.obj) if tt.err { assert.NotNil(t, err) } else { @@ -67,6 +67,7 @@ func TestNewMapperInfo(t *testing.T) { for _, col := range tt.expectedCols { assert.Truef(t, info.hasColumn(col), "Expected column should be present in Mapper Info") } + assert.Equal(t, "Test", info.Metadata.TableName) }) } @@ -141,7 +142,7 @@ func TestMapperInfoSet(t *testing.T) { err := json.Unmarshal(tt.table, &table) assert.Nil(t, err) - info, err := NewInfo(&table, tt.obj) + info, err := NewInfo("Test", &table, tt.obj) assert.Nil(t, err) err = info.SetField(tt.column, tt.field) @@ -222,7 +223,7 @@ func TestMapperInfoColByPtr(t *testing.T) { err := json.Unmarshal(tt.table, &table) assert.Nil(t, err) - info, err := NewInfo(&table, tt.obj) + info, err := NewInfo("Test", &table, tt.obj) assert.Nil(t, err) col, err := info.ColumnByPtr(tt.field) @@ -354,7 +355,7 @@ func TestOrmGetIndex(t *testing.T) { } for _, tt := range tests { t.Run(fmt.Sprintf("GetValidIndexes_%s", tt.name), func(t *testing.T) { - info, err := NewInfo(&table, tt.obj) + info, err := NewInfo("Test", &table, tt.obj) assert.Nil(t, err) indexes, err := info.getValidIndexes() diff --git a/mapper/mapper.go b/mapper/mapper.go index 2a66680c..ae01e4b2 100644 --- a/mapper/mapper.go +++ b/mapper/mapper.go @@ -36,20 +36,20 @@ func (e *ErrMapper) Error() string { e.objType, e.field, e.fieldType, e.fieldTag, e.reason) } -// ErrNoTable describes a error in the provided table information -type ErrNoTable struct { - table string -} +//// ErrNoTable describes a error in the provided table information +//type ErrNoTable struct { +// table string +//} +// +//func (e *ErrNoTable) Error() string { +// return fmt.Sprintf("Table not found: %s", e.table) +//} -func (e *ErrNoTable) Error() string { - return fmt.Sprintf("Table not found: %s", e.table) -} - -func newErrNoTable(table string) error { - return &ErrNoTable{ - table: table, - } -} +//func newErrNoTable(table string) error { +// return &ErrNoTable{ +// table: table, +// } +//} // NewMapper returns a new mapper func NewMapper(schema *ovsdb.DatabaseSchema) *Mapper { @@ -60,29 +60,19 @@ func NewMapper(schema *ovsdb.DatabaseSchema) *Mapper { // GetRowData transforms a Row to a struct based on its tags // The result object must be given as pointer to an object with the right tags -func (m Mapper) GetRowData(tableName string, row *ovsdb.Row, result interface{}) error { +func (m Mapper) GetRowData(row *ovsdb.Row, result *Info) error { if row == nil { return nil } - return m.getData(tableName, *row, result) + return m.getData(*row, result) } // getData transforms a map[string]interface{} containing OvS types (e.g: a ResultRow // has this format) to orm struct // The result object must be given as pointer to an object with the right tags -func (m Mapper) getData(tableName string, ovsData ovsdb.Row, result interface{}) error { - table := m.Schema.Table(tableName) - if table == nil { - return newErrNoTable(tableName) - } - - mapperInfo, err := NewInfo(table, result) - if err != nil { - return err - } - - for name, column := range table.Columns { - if !mapperInfo.hasColumn(name) { +func (m Mapper) getData(ovsData ovsdb.Row, result *Info) error { + for name, column := range result.Metadata.TableSchema.Columns { + if !result.hasColumn(name) { // If provided struct does not have a field to hold this value, skip it continue } @@ -96,10 +86,10 @@ func (m Mapper) getData(tableName string, ovsData ovsdb.Row, result interface{}) nativeElem, err := ovsdb.OvsToNative(column, ovsElem) if err != nil { return fmt.Errorf("table %s, column %s: failed to extract native element: %s", - tableName, name, err.Error()) + result.Metadata.TableName, name, err.Error()) } - if err := mapperInfo.SetField(name, nativeElem); err != nil { + if err := result.SetField(name, nativeElem); err != nil { return err } } @@ -109,24 +99,15 @@ func (m Mapper) getData(tableName string, ovsData ovsdb.Row, result interface{}) // NewRow transforms an orm struct to a map[string] interface{} that can be used as libovsdb.Row // By default, default or null values are skipped. This behavior can be modified by specifying // a list of fields (pointers to fields in the struct) to be added to the row -func (m Mapper) NewRow(tableName string, data interface{}, fields ...interface{}) (ovsdb.Row, error) { - table := m.Schema.Table(tableName) - if table == nil { - return nil, newErrNoTable(tableName) - } - mapperInfo, err := NewInfo(table, data) - if err != nil { - return nil, err - } - +func (m Mapper) NewRow(data *Info, fields ...interface{}) (ovsdb.Row, error) { columns := make(map[string]*ovsdb.ColumnSchema) - for k, v := range table.Columns { + for k, v := range data.Metadata.TableSchema.Columns { columns[k] = v } columns["_uuid"] = &ovsdb.UUIDColumn ovsRow := make(map[string]interface{}, len(columns)) for name, column := range columns { - nativeElem, err := mapperInfo.FieldByColumn(name) + nativeElem, err := data.FieldByColumn(name) if err != nil { // If provided struct does not have a field to hold this value, skip it continue @@ -136,7 +117,7 @@ func (m Mapper) NewRow(tableName string, data interface{}, fields ...interface{} if len(fields) > 0 { found := false for _, f := range fields { - col, err := mapperInfo.ColumnByPtr(f) + col, err := data.ColumnByPtr(f) if err != nil { return nil, err } @@ -154,7 +135,7 @@ func (m Mapper) NewRow(tableName string, data interface{}, fields ...interface{} } ovsElem, err := ovsdb.NativeToOvs(column, nativeElem) if err != nil { - return nil, fmt.Errorf("table %s, column %s: failed to generate ovs element. %s", tableName, name, err.Error()) + return nil, fmt.Errorf("table %s, column %s: failed to generate ovs element. %s", data.Metadata.TableName, name, err.Error()) } ovsRow[name] = ovsElem } @@ -169,25 +150,15 @@ func (m Mapper) NewRow(tableName string, data interface{}, fields ...interface{} // object has valid data. The order in which they are traversed matches the order defined // in the schema. // By `valid data` we mean non-default data. -func (m Mapper) NewEqualityCondition(tableName string, data interface{}, fields ...interface{}) ([]ovsdb.Condition, error) { +func (m Mapper) NewEqualityCondition(data *Info, fields ...interface{}) ([]ovsdb.Condition, error) { var conditions []ovsdb.Condition var condIndex [][]string - table := m.Schema.Table(tableName) - if table == nil { - return nil, newErrNoTable(tableName) - } - - mapperInfo, err := NewInfo(table, data) - if err != nil { - return nil, err - } - // If index is provided, use it. If not, obtain the valid indexes from the mapper info if len(fields) > 0 { providedIndex := []string{} for i := range fields { - if col, err := mapperInfo.ColumnByPtr(fields[i]); err == nil { + if col, err := data.ColumnByPtr(fields[i]); err == nil { providedIndex = append(providedIndex, col) } else { return nil, err @@ -196,7 +167,7 @@ func (m Mapper) NewEqualityCondition(tableName string, data interface{}, fields condIndex = append(condIndex, providedIndex) } else { var err error - condIndex, err = mapperInfo.getValidIndexes() + condIndex, err = data.getValidIndexes() if err != nil { return nil, err } @@ -208,12 +179,12 @@ func (m Mapper) NewEqualityCondition(tableName string, data interface{}, fields // Pick the first valid index for _, col := range condIndex[0] { - field, err := mapperInfo.FieldByColumn(col) + field, err := data.FieldByColumn(col) if err != nil { return nil, err } - column := table.Column(col) + column := data.Metadata.TableSchema.Column(col) if column == nil { return nil, fmt.Errorf("column %s not found", col) } @@ -229,47 +200,27 @@ func (m Mapper) NewEqualityCondition(tableName string, data interface{}, fields // EqualFields compares two mapped objects. // The indexes to use for comparison are, the _uuid, the table indexes and the columns that correspond // to the mapped fields pointed to by 'fields'. They must be pointers to fields on the first mapped element (i.e: one) -func (m Mapper) EqualFields(tableName string, one, other interface{}, fields ...interface{}) (bool, error) { +func (m Mapper) EqualFields(one, other *Info, fields ...interface{}) (bool, error) { indexes := []string{} - - table := m.Schema.Table(tableName) - if table == nil { - return false, newErrNoTable(tableName) - } - - info, err := NewInfo(table, one) - if err != nil { - return false, err - } for _, f := range fields { - col, err := info.ColumnByPtr(f) + col, err := one.ColumnByPtr(f) if err != nil { return false, err } indexes = append(indexes, col) } - return m.equalIndexes(table, one, other, indexes...) + return m.equalIndexes(one, other, indexes...) } // NewCondition returns a ovsdb.Condition based on the model -func (m Mapper) NewCondition(tableName string, data interface{}, field interface{}, function ovsdb.ConditionFunction, value interface{}) (*ovsdb.Condition, error) { - table := m.Schema.Table(tableName) - if table == nil { - return nil, newErrNoTable(tableName) - } - - info, err := NewInfo(table, data) - if err != nil { - return nil, err - } - - column, err := info.ColumnByPtr(field) +func (m Mapper) NewCondition(data *Info, field interface{}, function ovsdb.ConditionFunction, value interface{}) (*ovsdb.Condition, error) { + column, err := data.ColumnByPtr(field) if err != nil { return nil, err } // Check that the condition is valid - columnSchema := table.Column(column) + columnSchema := data.Metadata.TableSchema.Column(column) if columnSchema == nil { return nil, fmt.Errorf("column %s not found", column) } @@ -290,23 +241,13 @@ func (m Mapper) NewCondition(tableName string, data interface{}, field interface // NewMutation creates a RFC7047 mutation object based on an ORM object and the mutation fields (in native format) // It takes care of field validation against the column type -func (m Mapper) NewMutation(tableName string, data interface{}, column string, mutator ovsdb.Mutator, value interface{}) (*ovsdb.Mutation, error) { - table := m.Schema.Table(tableName) - if table == nil { - return nil, newErrNoTable(tableName) - } - - mapperInfo, err := NewInfo(table, data) - if err != nil { - return nil, err - } - +func (m Mapper) NewMutation(data *Info, column string, mutator ovsdb.Mutator, value interface{}) (*ovsdb.Mutation, error) { // Check the column exists in the object - if !mapperInfo.hasColumn(column) { + if !data.hasColumn(column) { return nil, fmt.Errorf("mutation contains column %s that does not exist in object %v", column, data) } // Check that the mutation is valid - columnSchema := table.Column(column) + columnSchema := data.Metadata.TableSchema.Column(column) if columnSchema == nil { return nil, fmt.Errorf("column %s not found", column) } @@ -315,6 +256,7 @@ func (m Mapper) NewMutation(tableName string, data interface{}, column string, m } var ovsValue interface{} + var err error // Usually a mutation value is of the same type of the value being mutated // except for delete mutation of maps where it can also be a list of same type of // keys (rfc7047 5.1). Handle this special case here. @@ -341,24 +283,15 @@ func (m Mapper) NewMutation(tableName string, data interface{}, column string, m // For any of the indexes defined in the Table Schema, the values all of its columns are simultaneously equal // (as per RFC7047) // The values of all of the optional indexes passed as variadic parameter to this function are equal. -func (m Mapper) equalIndexes(table *ovsdb.TableSchema, one, other interface{}, indexes ...string) (bool, error) { +func (m Mapper) equalIndexes(one, other *Info, indexes ...string) (bool, error) { match := false - oneMapperInfo, err := NewInfo(table, one) - if err != nil { - return false, err - } - otherMapperInfo, err := NewInfo(table, other) - if err != nil { - return false, err - } - - oneIndexes, err := oneMapperInfo.getValidIndexes() + oneIndexes, err := one.getValidIndexes() if err != nil { return false, err } - otherIndexes, err := otherMapperInfo.getValidIndexes() + otherIndexes, err := other.getValidIndexes() if err != nil { return false, err } @@ -371,14 +304,14 @@ func (m Mapper) equalIndexes(table *ovsdb.TableSchema, one, other interface{}, i if reflect.DeepEqual(ridx, lidx) { // All columns in an index must be simultaneously equal for _, col := range lidx { - if !oneMapperInfo.hasColumn(col) || !otherMapperInfo.hasColumn(col) { + if !one.hasColumn(col) || !other.hasColumn(col) { break } - lfield, err := oneMapperInfo.FieldByColumn(col) + lfield, err := one.FieldByColumn(col) if err != nil { return false, err } - rfield, err := otherMapperInfo.FieldByColumn(col) + rfield, err := other.FieldByColumn(col) if err != nil { return false, err } @@ -401,23 +334,18 @@ func (m Mapper) equalIndexes(table *ovsdb.TableSchema, one, other interface{}, i // NewMonitorRequest returns a monitor request for the provided tableName // If fields is provided, the request will be constrained to the provided columns // If no fields are provided, all columns will be used -func (m *Mapper) NewMonitorRequest(tableName string, data interface{}, fields []interface{}) (*ovsdb.MonitorRequest, error) { +func (m *Mapper) NewMonitorRequest(data *Info, fields []interface{}) (*ovsdb.MonitorRequest, error) { var columns []string - schema := m.Schema.Tables[tableName] - info, err := NewInfo(&schema, data) - if err != nil { - return nil, err - } if len(fields) > 0 { for _, f := range fields { - column, err := info.ColumnByPtr(f) + column, err := data.ColumnByPtr(f) if err != nil { return nil, err } columns = append(columns, column) } } else { - for c := range info.table.Columns { + for c := range data.Metadata.TableSchema.Columns { columns = append(columns, c) } } diff --git a/mapper/mapper_test.go b/mapper/mapper_test.go index 5cb50e8b..7cb3aa80 100644 --- a/mapper/mapper_test.go +++ b/mapper/mapper_test.go @@ -226,7 +226,11 @@ func TestMapperGetData(t *testing.T) { test := ormTestType{ NonTagged: "something", } - err := mapper.GetRowData("TestTable", &ovsRow, &test) + testInfo, err := NewInfo("TestTable", schema.Table("TestTable"), &test) + assert.NoError(t, err) + + err = mapper.GetRowData(&ovsRow, testInfo) + assert.NoError(t, err) /*End code under test*/ if err != nil { @@ -341,7 +345,9 @@ func TestMapperNewRow(t *testing.T) { for _, test := range tests { t.Run(fmt.Sprintf("NewRow: %s", test.name), func(t *testing.T) { mapper := NewMapper(&schema) - row, err := mapper.NewRow("TestTable", test.objInput) + info, err := NewInfo("TestTable", schema.Table("TestTable"), test.objInput) + assert.NoError(t, err) + row, err := mapper.NewRow(info) if test.shoulderr { assert.NotNil(t, err) } else { @@ -432,7 +438,9 @@ func TestMapperNewRowFields(t *testing.T) { testObj.MyFloat = 0 test.prepare(&testObj) - row, err := mapper.NewRow("TestTable", &testObj, test.fields...) + info, err := NewInfo("TestTable", schema.Table("TestTable"), &testObj) + assert.NoError(t, err) + row, err := mapper.NewRow(info, test.fields...) if test.err { assert.NotNil(t, err) } else { @@ -584,7 +592,10 @@ func TestMapperCondition(t *testing.T) { for _, tt := range tests { t.Run(fmt.Sprintf("newEqualityCondition_%s", tt.name), func(t *testing.T) { tt.prepare(&testObj) - conds, err := mapper.NewEqualityCondition("TestTable", &testObj, tt.index...) + info, err := NewInfo("TestTable", schema.Table("TestTable"), &testObj) + assert.NoError(t, err) + + conds, err := mapper.NewEqualityCondition(info, tt.index...) if tt.err { if err == nil { t.Errorf("expected an error but got none") @@ -835,7 +846,11 @@ func TestMapperEqualIndexes(t *testing.T) { } for _, test := range tests { t.Run(fmt.Sprintf("Equal %s", test.name), func(t *testing.T) { - eq, err := mapper.equalIndexes(mapper.Schema.Table("TestTable"), &test.obj1, &test.obj2, test.indexes...) + info1, err := NewInfo("TestTable", schema.Table("TestTable"), &test.obj1) + assert.NoError(t, err) + info2, err := NewInfo("TestTable", schema.Table("TestTable"), &test.obj2) + assert.NoError(t, err) + eq, err := mapper.equalIndexes(info1, info2, test.indexes...) assert.Nil(t, err) assert.Equalf(t, test.expected, eq, "equal value should match expected") }) @@ -858,11 +873,15 @@ func TestMapperEqualIndexes(t *testing.T) { Int1: 42, Int2: 25, } - eq, err := mapper.EqualFields("TestTable", &obj1, &obj2, &obj1.Int1, &obj1.Int2) + info1, err := NewInfo("TestTable", schema.Table("TestTable"), &obj1) + assert.NoError(t, err) + info2, err := NewInfo("TestTable", schema.Table("TestTable"), &obj2) + assert.NoError(t, err) + eq, err := mapper.EqualFields(info1, info2, &obj1.Int1, &obj1.Int2) assert.Nil(t, err) assert.True(t, eq) // Using pointers to second value is not supported - _, err = mapper.EqualFields("TestTable", &obj1, &obj2, &obj2.Int1, &obj2.Int2) + _, err = mapper.EqualFields(info1, info2, &obj2.Int1, &obj2.Int2) assert.NotNil(t, err) } @@ -1012,7 +1031,10 @@ func TestMapperMutation(t *testing.T) { } for _, test := range tests { t.Run(fmt.Sprintf("newMutation%s", test.name), func(t *testing.T) { - mutation, err := mapper.NewMutation("TestTable", &test.obj, test.column, test.mutator, test.value) + info, err := NewInfo("TestTable", schema.Table("TestTable"), &test.obj) + assert.NoError(t, err) + + mutation, err := mapper.NewMutation(info, test.column, test.mutator, test.value) if test.err { if err == nil { t.Errorf("expected an error but got none") @@ -1097,10 +1119,12 @@ func TestNewMonitorRequest(t *testing.T) { require.NoError(t, err) mapper := NewMapper(&schema) testTable := &testType{} - mr, err := mapper.NewMonitorRequest("TestTable", testTable, nil) + info, err := NewInfo("TestTable", schema.Table("TestTable"), testTable) + assert.NoError(t, err) + mr, err := mapper.NewMonitorRequest(info, nil) require.NoError(t, err) assert.ElementsMatch(t, mr.Columns, []string{"name", "config", "composed_1", "composed_2", "int1", "int2"}) - mr2, err := mapper.NewMonitorRequest("TestTable", testTable, []interface{}{&testTable.Int1, &testTable.MyName}) + mr2, err := mapper.NewMonitorRequest(info, []interface{}{&testTable.Int1, &testTable.MyName}) require.NoError(t, err) assert.ElementsMatch(t, mr2.Columns, []string{"int1", "name"}) } diff --git a/model/request.go b/model/request.go index e1e2ce44..e8b82501 100644 --- a/model/request.go +++ b/model/request.go @@ -49,7 +49,7 @@ func (db DatabaseModelRequest) validate(schema *ovsdb.DatabaseSchema) []error { errors = append(errors, err) continue } - if _, err := mapper.NewInfo(tableSchema, model); err != nil { + if _, err := mapper.NewInfo(tableName, tableSchema, model); err != nil { errors = append(errors, err) } } diff --git a/server/server.go b/server/server.go index 3a198c6c..6cf292df 100644 --- a/server/server.go +++ b/server/server.go @@ -141,8 +141,8 @@ type Transaction struct { Cache *cache.TableCache } -func NewTransaction(schema *ovsdb.DatabaseSchema, model *model.DatabaseModelRequest) Transaction { - cache, err := cache.NewTableCache(schema, model, nil) +func NewTransaction(schema *ovsdb.DatabaseSchema, model *model.DatabaseModel) Transaction { + cache, err := cache.NewTableCache(model, nil) if err != nil { panic(err) } diff --git a/server/transact.go b/server/transact.go index d64d1f50..b0ba5341 100644 --- a/server/transact.go +++ b/server/transact.go @@ -89,7 +89,13 @@ func (o *OvsdbServer) Insert(database string, table string, rowUUID string, row }, nil } - err = m.GetRowData(table, &row, model) + mapperInfo, err := mapper.NewInfo(table, tSchema, model) + if err != nil { + return ovsdb.OperationResult{ + Error: err.Error(), + }, nil + } + err = m.GetRowData(&row, mapperInfo) if err != nil { return ovsdb.OperationResult{ Error: err.Error(), @@ -97,12 +103,6 @@ func (o *OvsdbServer) Insert(database string, table string, rowUUID string, row } if rowUUID != "" { - mapperInfo, err := mapper.NewInfo(tSchema, model) - if err != nil { - return ovsdb.OperationResult{ - Error: err.Error(), - }, nil - } if err := mapperInfo.SetField("_uuid", rowUUID); err != nil { return ovsdb.OperationResult{ Error: err.Error(), @@ -110,7 +110,7 @@ func (o *OvsdbServer) Insert(database string, table string, rowUUID string, row } } - resultRow, err := m.NewRow(table, model) + resultRow, err := m.NewRow(mapperInfo) if err != nil { return ovsdb.OperationResult{ Error: err.Error(), @@ -163,7 +163,11 @@ func (o *OvsdbServer) Select(database string, table string, where []ovsdb.Condit panic(err) } for _, row := range rows { - resultRow, err := m.NewRow(table, row) + info, err := mapper.NewInfo(table, dbModel.Schema().Table(table), row) + if err != nil { + panic(err) + } + resultRow, err := m.NewRow(info) if err != nil { panic(err) } @@ -194,10 +198,10 @@ func (o *OvsdbServer) Update(database, table string, where []ovsdb.Condition, ro }, nil } for _, old := range rows { - info, _ := mapper.NewInfo(schema, old) - uuid, _ := info.FieldByColumn("_uuid") + oldInfo, _ := mapper.NewInfo(table, schema, old) + uuid, _ := oldInfo.FieldByColumn("_uuid") - oldRow, err := m.NewRow(table, old) + oldRow, err := m.NewRow(oldInfo) if err != nil { panic(err) } @@ -205,15 +209,15 @@ func (o *OvsdbServer) Update(database, table string, where []ovsdb.Condition, ro if err != nil { panic(err) } - err = m.GetRowData(table, &oldRow, new) + newInfo, err := mapper.NewInfo(table, schema, new) if err != nil { panic(err) } - info, err = mapper.NewInfo(schema, new) + err = m.GetRowData(&oldRow, newInfo) if err != nil { panic(err) } - err = info.SetField("_uuid", uuid) + err = newInfo.SetField("_uuid", uuid) if err != nil { panic(err) } @@ -235,7 +239,7 @@ func (o *OvsdbServer) Update(database, table string, where []ovsdb.Condition, ro Details: fmt.Sprintf("column %s is of table %s not mutable", column, table), }, nil } - old, err := info.FieldByColumn(column) + old, err := newInfo.FieldByColumn(column) if err != nil { panic(err) } @@ -254,7 +258,7 @@ func (o *OvsdbServer) Update(database, table string, where []ovsdb.Condition, ro continue } - err = info.SetField(column, native) + err = newInfo.SetField(column, native) if err != nil { panic(err) } @@ -270,7 +274,7 @@ func (o *OvsdbServer) Update(database, table string, where []ovsdb.Condition, ro } } - newRow, err := m.NewRow(table, new) + newRow, err := m.NewRow(newInfo) if err != nil { panic(err) } @@ -324,12 +328,12 @@ func (o *OvsdbServer) Mutate(database, table string, where []ovsdb.Condition, mu } for _, old := range rows { - oldInfo, err := mapper.NewInfo(schema, old) + oldInfo, err := mapper.NewInfo(table, schema, old) if err != nil { panic(err) } uuid, _ := oldInfo.FieldByColumn("_uuid") - oldRow, err := m.NewRow(table, old) + oldRow, err := m.NewRow(oldInfo) if err != nil { panic(err) } @@ -337,11 +341,11 @@ func (o *OvsdbServer) Mutate(database, table string, where []ovsdb.Condition, mu if err != nil { panic(err) } - err = m.GetRowData(table, &oldRow, new) + newInfo, err := mapper.NewInfo(table, schema, new) if err != nil { panic(err) } - newInfo, err := mapper.NewInfo(schema, new) + err = m.GetRowData(&oldRow, newInfo) if err != nil { panic(err) } @@ -424,7 +428,7 @@ func (o *OvsdbServer) Mutate(database, table string, where []ovsdb.Condition, mu }, nil } - newRow, err := m.NewRow(table, new) + newRow, err := m.NewRow(newInfo) if err != nil { panic(err) } @@ -460,9 +464,9 @@ func (o *OvsdbServer) Delete(database, table string, where []ovsdb.Condition) (o panic(err) } for _, row := range rows { - info, _ := mapper.NewInfo(schema, row) + info, _ := mapper.NewInfo(table, schema, row) uuid, _ := info.FieldByColumn("_uuid") - oldRow, err := m.NewRow(table, row) + oldRow, err := m.NewRow(info) if err != nil { panic(err) } diff --git a/server/transact_test.go b/server/transact_test.go index fbfa8032..c2849654 100644 --- a/server/transact_test.go +++ b/server/transact_test.go @@ -32,7 +32,9 @@ func TestMutateOp(t *testing.T) { m := mapper.NewMapper(schema) ovs := ovsType{} - ovsRow, err := m.NewRow("Open_vSwitch", &ovs) + info, err := mapper.NewInfo("Open_vSwitch", schema.Table("Open_vSwitch"), &ovs) + require.NoError(t, err) + ovsRow, err := m.NewRow(info) require.Nil(t, err) bridge := bridgeType{ @@ -43,7 +45,9 @@ func TestMutateOp(t *testing.T) { "waldo": "fred", }, } - bridgeRow, err := m.NewRow("Bridge", &bridge) + bridgeInfo, err := mapper.NewInfo("Bridge", schema.Table("Bridge"), &bridge) + require.NoError(t, err) + bridgeRow, err := m.NewRow(bridgeInfo) require.Nil(t, err) res, updates := o.Insert("Open_vSwitch", "Open_vSwitch", ovsUUID, ovsRow) @@ -214,8 +218,7 @@ func TestOvsdbServerInsert(t *testing.T) { t.Fatal(err) } ovsDB := NewInMemoryDatabase(map[string]*model.DatabaseModelRequest{"Open_vSwitch": defDB}) - o, err := NewOvsdbServer(ovsDB, DatabaseModel{ - Model: defDB, Schema: schema}) + o, err := NewOvsdbServer(ovsDB, *model.NewDatabaseModel(schema, defDB)) require.Nil(t, err) m := mapper.NewMapper(schema) @@ -231,7 +234,9 @@ func TestOvsdbServerInsert(t *testing.T) { }, } bridgeUUID := uuid.NewString() - bridgeRow, err := m.NewRow("Bridge", &bridge) + bridgeInfo, err := mapper.NewInfo("Bridge", schema.Table("Bridge"), &bridge) + require.NoError(t, err) + bridgeRow, err := m.NewRow(bridgeInfo) require.Nil(t, err) res, updates := o.Insert("Open_vSwitch", "Bridge", bridgeUUID, bridgeRow) @@ -267,8 +272,7 @@ func TestOvsdbServerUpdate(t *testing.T) { t.Fatal(err) } ovsDB := NewInMemoryDatabase(map[string]*model.DatabaseModelRequest{"Open_vSwitch": defDB}) - o, err := NewOvsdbServer(ovsDB, DatabaseModel{ - Model: defDB, Schema: schema}) + o, err := NewOvsdbServer(ovsDB, *model.NewDatabaseModel(schema, defDB)) require.Nil(t, err) m := mapper.NewMapper(schema) @@ -281,7 +285,9 @@ func TestOvsdbServerUpdate(t *testing.T) { }, } bridgeUUID := uuid.NewString() - bridgeRow, err := m.NewRow("Bridge", &bridge) + bridgeInfo, err := mapper.NewInfo("Bridge", schema.Table("Bridge"), &bridge) + require.NoError(t, err) + bridgeRow, err := m.NewRow(bridgeInfo) require.Nil(t, err) res, updates := o.Insert("Open_vSwitch", "Bridge", bridgeUUID, bridgeRow)