diff --git a/consts/consts.go b/consts/consts.go index bf5cc9f1..76bdb82e 100644 --- a/consts/consts.go +++ b/consts/consts.go @@ -20,4 +20,5 @@ const ( ComposeProjectName = "COMPOSE_PROJECT_NAME" ComposePathSeparator = "COMPOSE_PATH_SEPARATOR" ComposeFilePath = "COMPOSE_FILE" + ComposeProfiles = "COMPOSE_PROFILES" ) diff --git a/loader/loader.go b/loader/loader.go index 93b3c93c..eb09b313 100644 --- a/loader/loader.go +++ b/loader/loader.go @@ -246,9 +246,10 @@ func Load(configDetails types.ConfigDetails, options ...func(*Options)) (*types. } } - if len(opts.Profiles) > 0 { - project.ApplyProfiles(opts.Profiles) + if profiles, ok := project.Environment[consts.ComposeProfiles]; ok && len(opts.Profiles) == 0 { + opts.Profiles = strings.Split(profiles, ",") } + project.ApplyProfiles(opts.Profiles) err = project.ResolveServicesEnvironment(opts.discardEnvFiles) diff --git a/loader/merge_test.go b/loader/merge_test.go index 9b919224..96301d40 100644 --- a/loader/merge_test.go +++ b/loader/merge_test.go @@ -1223,7 +1223,7 @@ func TestMergeTopLevelExtensions(t *testing.T) { assert.DeepEqual(t, &types.Project{ Name: "", WorkingDir: "", - Services: types.Services{}, + Services: nil, Networks: types.Networks{}, Volumes: types.Volumes{}, Secrets: types.Secrets{}, diff --git a/types/project.go b/types/project.go index cf7526eb..59c06b41 100644 --- a/types/project.go +++ b/types/project.go @@ -45,6 +45,7 @@ type Project struct { // DisabledServices track services which have been disable as profile is not active DisabledServices Services `yaml:"-" json:"-"` + Profiles []string `yaml:"-" json:"-"` } // ServiceNames return names for all services in this Compose config @@ -119,6 +120,16 @@ func (p *Project) GetServices(names ...string) (Services, error) { return services, nil } +// GetDisabledService retrieve disabled service by name +func (p Project) GetDisabledService(name string) (ServiceConfig, error) { + for _, config := range p.DisabledServices { + if config.Name == name { + return config, nil + } + } + return ServiceConfig{}, fmt.Errorf("no such service: %s", name) +} + // GetService retrieve a specific service by name func (p *Project) GetService(name string) (ServiceConfig, error) { services, err := p.GetServices(name) @@ -247,7 +258,7 @@ func (p *Project) ApplyProfiles(profiles []string) { } } var enabled, disabled Services - for _, service := range p.Services { + for _, service := range p.AllServices() { if service.HasProfile(profiles) { enabled = append(enabled, service) } else { @@ -256,6 +267,41 @@ func (p *Project) ApplyProfiles(profiles []string) { } p.Services = enabled p.DisabledServices = disabled + p.Profiles = profiles +} + +// EnableServices ensure services are enabled and activate profiles accordingly +func (p *Project) EnableServices(names ...string) error { + if len(names) == 0 { + return nil + } + var enabled []string + for _, name := range names { + _, err := p.GetService(name) + if err == nil { + // already enabled + continue + } + def, err := p.GetDisabledService(name) + if err != nil { + return err + } + enabled = append(enabled, def.Profiles...) + } + + profiles := p.Profiles +PROFILES: + for _, profile := range enabled { + for _, p := range profiles { + if p == profile { + continue PROFILES + } + } + profiles = append(profiles, profile) + } + p.ApplyProfiles(profiles) + + return p.ResolveServicesEnvironment(true) } // WithoutUnnecessaryResources drops networks/volumes/secrets/configs that are not referenced by active services diff --git a/types/project_test.go b/types/project_test.go index 05b8e459..b8421bc4 100644 --- a/types/project_test.go +++ b/types/project_test.go @@ -31,8 +31,22 @@ func Test_ApplyProfiles(t *testing.T) { assert.Equal(t, len(p.Services), 2) assert.Equal(t, p.Services[0].Name, "service_1") assert.Equal(t, p.Services[1].Name, "service_2") + assert.Equal(t, len(p.DisabledServices), 3) + assert.Equal(t, p.DisabledServices[0].Name, "service_3") + assert.Equal(t, p.DisabledServices[1].Name, "service_4") + assert.Equal(t, p.DisabledServices[2].Name, "service_5") + + err := p.EnableServices("service_4") + assert.NilError(t, err) + + assert.Equal(t, len(p.Services), 4) + assert.Equal(t, p.Services[0].Name, "service_1") + assert.Equal(t, p.Services[1].Name, "service_2") + assert.Equal(t, p.Services[2].Name, "service_4") + assert.Equal(t, p.Services[3].Name, "service_5") assert.Equal(t, len(p.DisabledServices), 1) assert.Equal(t, p.DisabledServices[0].Name, "service_3") + } func Test_WithoutUnnecessaryResources(t *testing.T) { @@ -60,7 +74,7 @@ func Test_NoProfiles(t *testing.T) { p := makeProject() p.ApplyProfiles(nil) assert.Equal(t, len(p.Services), 1) - assert.Equal(t, len(p.DisabledServices), 2) + assert.Equal(t, len(p.DisabledServices), 4) assert.Equal(t, p.Services[0].Name, "service_1") } @@ -79,8 +93,10 @@ func Test_ForServices(t *testing.T) { err := p.ForServices([]string{"service_2"}) assert.NilError(t, err) - assert.Equal(t, len(p.DisabledServices), 1) + assert.Equal(t, len(p.DisabledServices), 3) assert.Equal(t, p.DisabledServices[0].Name, "service_3") + assert.Equal(t, p.DisabledServices[1].Name, "service_4") + assert.Equal(t, p.DisabledServices[2].Name, "service_5") } func Test_ForServicesCycle(t *testing.T) { @@ -103,6 +119,12 @@ func makeProject() Project { Name: "service_3", Profiles: []string{"bar"}, DependsOn: map[string]ServiceDependency{"service_2": {}}, + }, ServiceConfig{ + Name: "service_4", + Profiles: []string{"zot"}, + }, ServiceConfig{ + Name: "service_5", + Profiles: []string{"zot"}, }), Networks: Networks{}, Volumes: Volumes{},