diff --git a/cache/cache_test.go b/cache/cache_test.go index 38088ea6..b8c7ee16 100644 --- a/cache/cache_test.go +++ b/cache/cache_test.go @@ -1,6 +1,7 @@ package cache import ( + "sync" "testing" "encoding/json" @@ -49,6 +50,7 @@ func TestRowCache_Row(t *testing.T) { t.Run(tt.name, func(t *testing.T) { r := &RowCache{ cache: tt.fields.cache, + mutex: sync.RWMutex{}, } got := r.Row(tt.args.uuid) assert.Equal(t, tt.want, got) @@ -75,6 +77,7 @@ func TestRowCache_Rows(t *testing.T) { t.Run(tt.name, func(t *testing.T) { r := &RowCache{ cache: tt.fields.cache, + mutex: sync.RWMutex{}, } got := r.Rows() assert.ElementsMatch(t, tt.want, got) diff --git a/client/api.go b/client/api.go index 7a957608..5321cd9d 100644 --- a/client/api.go +++ b/client/api.go @@ -304,6 +304,9 @@ func (a api) Mutate(model model.Model, mutationObjs ...model.Mutation) ([]ovsdb. } tableName := a.cache.DBModel().FindTable(reflect.ValueOf(model).Type()) + if tableName == "" { + return nil, fmt.Errorf("table not found for object") + } table := a.cache.Mapper().Schema.Table(tableName) if table == nil { return nil, fmt.Errorf("schema error: table not found in Database Model for type %s", reflect.TypeOf(model)) diff --git a/client/client.go b/client/client.go index 5ceba616..8d5e792e 100644 --- a/client/client.go +++ b/client/client.go @@ -14,6 +14,7 @@ import ( "github.com/cenkalti/rpc2" "github.com/cenkalti/rpc2/jsonrpc" "github.com/ovn-org/libovsdb/cache" + "github.com/ovn-org/libovsdb/mapper" "github.com/ovn-org/libovsdb/model" "github.com/ovn-org/libovsdb/ovsdb" ) @@ -21,6 +22,7 @@ import ( // OvsdbClient is an OVSDB client type OvsdbClient struct { rpcClient *rpc2.Client + dbModel *model.DBModel Schema ovsdb.DatabaseSchema handlers []ovsdb.NotificationHandler handlersMutex *sync.Mutex @@ -29,9 +31,10 @@ type OvsdbClient struct { api API } -func newOvsdbClient() *OvsdbClient { +func newOvsdbClient(dbModel *model.DBModel) *OvsdbClient { // Cache initialization is delayed because we first need to obtain the schema ovs := &OvsdbClient{ + dbModel: dbModel, handlersMutex: &sync.Mutex{}, stopCh: make(chan struct{}), } @@ -85,7 +88,7 @@ func Connect(ctx context.Context, database *model.DBModel, opts ...Option) (*Ovs } func newRPC2Client(conn net.Conn, database *model.DBModel) (*OvsdbClient, error) { - ovs := newOvsdbClient() + ovs := newOvsdbClient(database) ovs.rpcClient = rpc2.NewClientWithCodec(jsonrpc.NewJSONCodec(conn)) ovs.rpcClient.SetBlocking(true) ovs.rpcClient.Handle("echo", func(_ *rpc2.Client, args []interface{}, reply *[]interface{}) error { @@ -254,18 +257,11 @@ func (ovs OvsdbClient) Transact(operation ...ovsdb.Operation) ([]ovsdb.Operation // MonitorAll is a convenience method to monitor every table/column func (ovs OvsdbClient) MonitorAll(jsonContext interface{}) error { - requests := make(map[string]ovsdb.MonitorRequest) - for table, tableSchema := range ovs.Schema.Tables { - var columns []string - for column := range tableSchema.Columns { - columns = append(columns, column) - } - requests[table] = ovsdb.MonitorRequest{ - Columns: columns, - Select: ovsdb.NewDefaultMonitorSelect(), - } + var options []TableMonitor + for name := range ovs.dbModel.Types() { + options = append(options, TableMonitor{Table: name}) } - return ovs.Monitor(jsonContext, requests) + return ovs.Monitor(jsonContext, options...) } // MonitorCancel will request cancel a previously issued monitor request @@ -285,13 +281,55 @@ func (ovs OvsdbClient) MonitorCancel(jsonContext interface{}) error { return nil } +// TableMonitor is a table to be monitored +type TableMonitor struct { + // Table is the table to be monitored + Table string + // Fields are the fields in the model to monitor + // If none are supplied, all fields will be used + Fields []interface{} +} + +func (o *OvsdbClient) NewTableMonitor(m model.Model, fields ...interface{}) TableMonitor { + tableName := o.dbModel.FindTable(reflect.TypeOf(m)) + if tableName == "" { + panic(fmt.Sprintf("Object of type %s is not part of the DBModel", reflect.TypeOf(m))) + } + return TableMonitor{ + Table: tableName, + Fields: fields, + } +} + // Monitor will provide updates for a given table/column // and populate the cache with them. Subsequent updates will be processed // by the Update Notifications // RFC 7047 : monitor -func (ovs OvsdbClient) Monitor(jsonContext interface{}, requests map[string]ovsdb.MonitorRequest) error { +func (ovs OvsdbClient) Monitor(jsonContext interface{}, options ...TableMonitor) error { var reply ovsdb.TableUpdates - + mapper := mapper.NewMapper(&ovs.Schema) + typeMap := ovs.dbModel.Types() + requests := make(map[string]ovsdb.MonitorRequest) + if len(options) == 0 { + return fmt.Errorf("no monitor options provided") + } + for _, o := range options { + var fields []interface{} + if len(o.Fields) > 0 { + fields = o.Fields + } else { + fields = nil + } + m, ok := typeMap[o.Table] + if !ok { + return fmt.Errorf("type for table %s does not exist in dbModel", o.Table) + } + request, err := mapper.NewMonitorRequest(o.Table, m, fields) + if err != nil { + return err + } + requests[o.Table] = *request + } args := ovsdb.NewMonitorArgs(ovs.Schema.Name, jsonContext, requests) err := ovs.rpcClient.Call("monitor", args, &reply) if err != nil { diff --git a/client/ovs_integration_test.go b/client/ovs_integration_test.go index e81c0764..c6ee311e 100644 --- a/client/ovs_integration_test.go +++ b/client/ovs_integration_test.go @@ -486,7 +486,10 @@ func TestMonitorCancelIntegration(t *testing.T) { Select: ovsdb.NewDefaultMonitorSelect(), } - err = ovs.Monitor(monitorID, requests) + err = ovs.Monitor(monitorID, + ovs.NewTableMonitor(&ovsType{}), + ovs.NewTableMonitor(&bridgeType{}), + ) if err != nil { t.Fatal(err) } diff --git a/example/play_with_ovs/play_with_ovs.go b/example/play_with_ovs/play_with_ovs.go index 273764eb..a1e9868a 100644 --- a/example/play_with_ovs/play_with_ovs.go +++ b/example/play_with_ovs/play_with_ovs.go @@ -114,8 +114,10 @@ func main() { } }, }) - - err = ovs.MonitorAll("") + err = ovs.Monitor("play_with_ovs", + ovs.NewTableMonitor(&OpenvSwitch{}), + ovs.NewTableMonitor(&Bridge{}), + ) if err != nil { log.Fatal(err) } diff --git a/mapper/info.go b/mapper/info.go index 934963d3..fd1ceb1e 100644 --- a/mapper/info.go +++ b/mapper/info.go @@ -47,11 +47,11 @@ func (i *Info) SetField(column string, value interface{}) error { return nil } -// ColumnByPtr returns the column name that corresponds to the field by the field's pminter +// ColumnByPtr returns the column name that corresponds to the field by the field's pointer func (i *Info) ColumnByPtr(fieldPtr interface{}) (string, error) { fieldPtrVal := reflect.ValueOf(fieldPtr) if fieldPtrVal.Kind() != reflect.Ptr { - return "", ovsdb.NewErrWrongType("ColumnByPminter", "pminter to a field in the struct", fieldPtr) + return "", ovsdb.NewErrWrongType("ColumnByPointer", "pointer to a field in the struct", fieldPtr) } offset := fieldPtrVal.Pointer() - reflect.ValueOf(i.obj).Pointer() objType := reflect.TypeOf(i.obj).Elem() @@ -64,7 +64,7 @@ func (i *Info) ColumnByPtr(fieldPtr interface{}) (string, error) { return column, nil } } - return "", fmt.Errorf("field pminter does not correspond to orm struct") + return "", fmt.Errorf("field pointer does not correspond to orm struct") } // getValidIndexes inspects the object and returns the a list of indexes (set of columns) for witch @@ -104,11 +104,11 @@ OUTER: func NewInfo(table *ovsdb.TableSchema, obj interface{}) (*Info, error) { objPtrVal := reflect.ValueOf(obj) if objPtrVal.Type().Kind() != reflect.Ptr { - return nil, ovsdb.NewErrWrongType("NewMapperInfo", "pminter to a struct", obj) + return nil, ovsdb.NewErrWrongType("NewMapperInfo", "pointer to a struct", obj) } objVal := reflect.Indirect(objPtrVal) if objVal.Kind() != reflect.Struct { - return nil, ovsdb.NewErrWrongType("NewMapperInfo", "pminter to a struct", obj) + return nil, ovsdb.NewErrWrongType("NewMapperInfo", "pointer to a struct", obj) } objType := objVal.Type() diff --git a/mapper/mapper.go b/mapper/mapper.go index 5578d0fa..05e54253 100644 --- a/mapper/mapper.go +++ b/mapper/mapper.go @@ -394,3 +394,30 @@ func (m Mapper) equalIndexes(table *ovsdb.TableSchema, one, other interface{}, i } return false, nil } + +// NewMonitorRequest returns a monitor request for the provided tableName +// If fields is provided, the request will be constrained to the provided columns +// If no fields are provided, all columns will be used +func (m *Mapper) NewMonitorRequest(tableName string, data interface{}, fields []interface{}) (*ovsdb.MonitorRequest, error) { + var columns []string + schema := m.Schema.Tables[tableName] + info, err := NewInfo(&schema, data) + if err != nil { + return nil, err + } + if len(fields) > 0 { + for _, f := range fields { + column, err := info.ColumnByPtr(f) + if err != nil { + return nil, err + } + columns = append(columns, column) + } + } else { + columns = append(columns, "_uuid") + for c := range info.table.Columns { + columns = append(columns, c) + } + } + return &ovsdb.MonitorRequest{Columns: columns, Select: ovsdb.NewDefaultMonitorSelect()}, nil +} diff --git a/mapper/mapper_test.go b/mapper/mapper_test.go index c8839925..e3a45f6d 100644 --- a/mapper/mapper_test.go +++ b/mapper/mapper_test.go @@ -7,6 +7,7 @@ import ( "github.com/ovn-org/libovsdb/ovsdb" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) var ( @@ -1027,3 +1028,68 @@ func testOvsMap(t *testing.T, set interface{}) *ovsdb.OvsMap { assert.Nil(t, err) return oMap } + +func TestNewMonitorRequest(t *testing.T) { + var testSchema = []byte(`{ + "cksum": "223619766 22548", + "name": "TestSchema", + "tables": { + "TestTable": { + "indexes": [["name"],["composed_1","composed_2"]], + "columns": { + "name": { + "type": "string" + }, + "composed_1": { + "type": { + "key": "string" + } + }, + "composed_2": { + "type": { + "key": "string" + } + }, + "int1": { + "type": { + "key": "integer" + } + }, + "int2": { + "type": { + "key": "integer" + } + }, + "config": { + "type": { + "key": "string", + "max": "unlimited", + "min": 0, + "value": "string" + } + } + } + } + } +}`) + type testType struct { + ID string `ovsdb:"_uuid"` + MyName string `ovsdb:"name"` + Config map[string]string `ovsdb:"config"` + Comp1 string `ovsdb:"composed_1"` + Comp2 string `ovsdb:"composed_2"` + Int1 int `ovsdb:"int1"` + Int2 int `ovsdb:"int2"` + } + var schema ovsdb.DatabaseSchema + err := json.Unmarshal(testSchema, &schema) + require.NoError(t, err) + mapper := NewMapper(&schema) + testTable := &testType{} + mr, err := mapper.NewMonitorRequest("TestTable", testTable, nil) + require.NoError(t, err) + assert.ElementsMatch(t, mr.Columns, []string{"_uuid", "name", "config", "composed_1", "composed_2", "int1", "int2"}) + mr2, err := mapper.NewMonitorRequest("TestTable", testTable, []interface{}{&testTable.ID, &testTable.MyName}) + require.NoError(t, err) + assert.ElementsMatch(t, mr2.Columns, []string{"_uuid", "name"}) +}