From b7de71f9ce74e99dde61ee138608df8edc5486bd Mon Sep 17 00:00:00 2001 From: John-Paul Sassine Date: Tue, 17 Jun 2025 20:49:12 +0000 Subject: [PATCH] feat(kubelet): Add ResourceHealthStatus for DRA pods This change introduces the ability for the Kubelet to monitor and report the health of devices allocated via Dynamic Resource Allocation (DRA). This addresses a key part of KEP-4680 by providing visibility into device failures, which helps users and controllers diagnose pod failures. The implementation includes: - A new `v1alpha1.NodeHealth` gRPC service with a `WatchResources` stream that DRA plugins can optionally implement. - A health information cache within the Kubelet's DRA manager to track the last known health of each device and handle plugin disconnections. - An asynchronous update mechanism that triggers a pod sync when a device's health changes. - A new `allocatedResourcesStatus` field in `v1.ContainerStatus` to expose the device health information to users via the Pod API. Update vendor KEP-4680: Fix lint, boilerplate, and codegen issues Add another e2e test, add TODO for KEP4680 & update test infra helpers Add Feature Gate e2e test Fixing presubmits Fix var names, feature gating, and nits Fix DRA Health gRPC API according to review feedback --- hack/update-codegen.sh | 1 + pkg/kubelet/cm/container_manager_linux.go | 45 +- pkg/kubelet/cm/dra/healthinfo.go | 212 +++++++ pkg/kubelet/cm/dra/healthinfo_test.go | 453 ++++++++++++++ pkg/kubelet/cm/dra/manager.go | 239 +++++++- pkg/kubelet/cm/dra/manager_test.go | 574 ++++++++++++++++-- pkg/kubelet/cm/dra/plugin/dra_plugin.go | 109 ++++ .../cm/dra/plugin/dra_plugin_manager.go | 43 +- .../cm/dra/plugin/dra_plugin_manager_test.go | 11 +- pkg/kubelet/cm/dra/plugin/dra_plugin_test.go | 92 ++- .../cm/dra/plugin/registration_test.go | 6 +- .../cm/dra/plugin/testing_helpers_test.go | 34 ++ pkg/kubelet/cm/dra/plugin/types.go | 31 + pkg/kubelet/cm/dra/state/state.go | 39 ++ pkg/kubelet/kubelet.go | 2 - .../kubeletplugin/draplugin.go | 18 +- .../pkg/apis/dra-health/v1alpha1/api.pb.go | 414 +++++++++++++ .../pkg/apis/dra-health/v1alpha1/api.proto | 67 ++ .../apis/dra-health/v1alpha1/api_grpc.pb.go | 160 +++++ test/e2e/dra/test-driver/app/kubeletplugin.go | 146 ++++- test/e2e/dra/test-driver/app/server.go | 1 + test/e2e_node/dra_test.go | 335 +++++++++- 22 files changed, 2932 insertions(+), 100 deletions(-) create mode 100644 pkg/kubelet/cm/dra/healthinfo.go create mode 100644 pkg/kubelet/cm/dra/healthinfo_test.go create mode 100644 pkg/kubelet/cm/dra/plugin/testing_helpers_test.go create mode 100644 pkg/kubelet/cm/dra/plugin/types.go create mode 100644 staging/src/k8s.io/kubelet/pkg/apis/dra-health/v1alpha1/api.pb.go create mode 100644 staging/src/k8s.io/kubelet/pkg/apis/dra-health/v1alpha1/api.proto create mode 100644 staging/src/k8s.io/kubelet/pkg/apis/dra-health/v1alpha1/api_grpc.pb.go diff --git a/hack/update-codegen.sh b/hack/update-codegen.sh index de31c785cc4..ea899266851 100755 --- a/hack/update-codegen.sh +++ b/hack/update-codegen.sh @@ -1032,6 +1032,7 @@ function codegen::protobindings() { "pkg/kubelet/pluginmanager/pluginwatcher/example_plugin_apis" "staging/src/k8s.io/cri-api/pkg/apis/runtime" "staging/src/k8s.io/externaljwt/apis" + "staging/src/k8s.io/kubelet/pkg/apis/dra-health" ) local apis=("${apis_using_gogo[@]}" "${apis_using_protoc[@]}") diff --git a/pkg/kubelet/cm/container_manager_linux.go b/pkg/kubelet/cm/container_manager_linux.go index 74636398701..c08aa56f244 100644 --- a/pkg/kubelet/cm/container_manager_linux.go +++ 
b/pkg/kubelet/cm/container_manager_linux.go @@ -136,6 +136,8 @@ type containerManagerImpl struct { draManager *dra.Manager // kubeClient is the interface to the Kubernetes API server. May be nil if the kubelet is running in standalone mode. kubeClient clientset.Interface + // resourceUpdates is a channel that provides resource updates. + resourceUpdates chan resourceupdates.Update } type features struct { @@ -351,6 +353,39 @@ func NewContainerManager(mountUtil mount.Interface, cadvisorInterface cadvisor.I } cm.topologyManager.AddHintProvider(cm.memoryManager) + // Create a single channel for all resource updates. This channel is consumed + // by the Kubelet's main sync loop. + cm.resourceUpdates = make(chan resourceupdates.Update, 10) + + // Start goroutines to fan-in updates from the various sub-managers + // (e.g., device manager, DRA manager) into the single updates channel. + var wg sync.WaitGroup + sources := map[string]<-chan resourceupdates.Update{} + if cm.deviceManager != nil { + sources["deviceManager"] = cm.deviceManager.Updates() + } + if utilfeature.DefaultFeatureGate.Enabled(kubefeatures.DynamicResourceAllocation) && cm.draManager != nil { + if utilfeature.DefaultFeatureGate.Enabled(kubefeatures.ResourceHealthStatus) { + sources["draManager"] = cm.draManager.Updates() + } + } + + for name, ch := range sources { + wg.Add(1) + go func(name string, c <-chan resourceupdates.Update) { + defer wg.Done() + for v := range c { + klog.V(4).InfoS("Container Manager: forwarding resource update", "source", name, "pods", v.PodUIDs) + cm.resourceUpdates <- v + } + }(name, ch) + } + + go func() { + wg.Wait() + close(cm.resourceUpdates) + }() + return cm, nil } @@ -1055,10 +1090,14 @@ func (cm *containerManagerImpl) UpdateAllocatedResourcesStatus(pod *v1.Pod, stat // For now we only support Device Plugin cm.deviceManager.UpdateAllocatedResourcesStatus(pod, status) - // TODO(SergeyKanzhelev, https://kep.k8s.io/4680): add support for DRA resources which is planned for the next iteration of a KEP. + // Update DRA resources if the feature is enabled and the manager exists + if utilfeature.DefaultFeatureGate.Enabled(kubefeatures.DynamicResourceAllocation) && cm.draManager != nil { + if utilfeature.DefaultFeatureGate.Enabled(kubefeatures.ResourceHealthStatus) { + cm.draManager.UpdateAllocatedResourcesStatus(pod, status) + } + } } func (cm *containerManagerImpl) Updates() <-chan resourceupdates.Update { - // TODO(SergeyKanzhelev, https://kep.k8s.io/4680): add support for DRA resources, for now only use device plugin updates. DRA support is planned for the next iteration of a KEP. - return cm.deviceManager.Updates() + return cm.resourceUpdates } diff --git a/pkg/kubelet/cm/dra/healthinfo.go b/pkg/kubelet/cm/dra/healthinfo.go new file mode 100644 index 00000000000..b389192c3b0 --- /dev/null +++ b/pkg/kubelet/cm/dra/healthinfo.go @@ -0,0 +1,212 @@ +/* +Copyright 2025 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +package dra + +import ( + "encoding/json" + "fmt" + "os" + "path/filepath" + "sync" + "time" + + "k8s.io/klog/v2" + "k8s.io/kubernetes/pkg/kubelet/cm/dra/state" +) + +// TODO(#133118): Make health timeout configurable. +const ( + healthTimeout = 30 * time.Second +) + +// healthInfoCache is a cache of known device health. +type healthInfoCache struct { + sync.RWMutex + HealthInfo *state.DevicesHealthMap + stateFile string +} + +// newHealthInfoCache creates a new cache, loading from a checkpoint if present. +func newHealthInfoCache(stateFile string) (*healthInfoCache, error) { + cache := &healthInfoCache{ + HealthInfo: &state.DevicesHealthMap{}, + stateFile: stateFile, + } + if err := cache.loadFromCheckpoint(); err != nil { + klog.Background().Error(err, "Failed to load health checkpoint, proceeding with empty cache") + } + return cache, nil +} + +// loadFromCheckpoint loads the cache from the state file. +func (cache *healthInfoCache) loadFromCheckpoint() error { + if cache.stateFile == "" { + return nil + } + data, err := os.ReadFile(cache.stateFile) + if err != nil { + if os.IsNotExist(err) { + cache.HealthInfo = &state.DevicesHealthMap{} + return nil + } + return err + } + return json.Unmarshal(data, cache.HealthInfo) +} + +// withLock runs a function while holding the healthInfoCache lock. +func (cache *healthInfoCache) withLock(f func() error) error { + cache.Lock() + defer cache.Unlock() + return f() +} + +// withRLock runs a function while holding the healthInfoCache rlock. +func (cache *healthInfoCache) withRLock(f func() error) error { + cache.RLock() + defer cache.RUnlock() + return f() +} + +// saveToCheckpointInternal does the actual saving without locking. +// Assumes the caller holds the necessary lock. +func (cache *healthInfoCache) saveToCheckpointInternal() error { + if cache.stateFile == "" { + return nil + } + data, err := json.Marshal(cache.HealthInfo) + if err != nil { + return fmt.Errorf("failed to marshal health info: %w", err) + } + + tempFile, err := os.CreateTemp(filepath.Dir(cache.stateFile), filepath.Base(cache.stateFile)+".tmp") + if err != nil { + return fmt.Errorf("failed to create temp checkpoint file: %w", err) + } + + defer func() { + if err := os.Remove(tempFile.Name()); err != nil && !os.IsNotExist(err) { + klog.Background().Error(err, "Failed to remove temporary checkpoint file", "path", tempFile.Name()) + } + }() + + if _, err := tempFile.Write(data); err != nil { + _ = tempFile.Close() + return fmt.Errorf("failed to write to temporary file: %w", err) + } + + if err := tempFile.Close(); err != nil { + return fmt.Errorf("failed to close temporary file: %w", err) + } + + if err := os.Rename(tempFile.Name(), cache.stateFile); err != nil { + return fmt.Errorf("failed to rename temporary file to state file: %w", err) + } + + return nil +} + +// getHealthInfo returns the current health info, adjusting for timeouts. +func (cache *healthInfoCache) getHealthInfo(driverName, poolName, deviceName string) state.DeviceHealthStatus { + res := state.DeviceHealthStatusUnknown + + _ = cache.withRLock(func() error { + now := time.Now() + if driver, ok := (*cache.HealthInfo)[driverName]; ok { + key := poolName + "/" + deviceName + if device, ok := driver.Devices[key]; ok { + if now.Sub(device.LastUpdated) > healthTimeout { + res = state.DeviceHealthStatusUnknown + } else { + res = device.Health + } + } + } + return nil + }) + return res +} + +// updateHealthInfo reconciles the cache with a fresh list of device health states +// from a plugin. 
It identifies which devices have changed state and handles devices +// that are no longer being reported by the plugin. +func (cache *healthInfoCache) updateHealthInfo(driverName string, devices []state.DeviceHealth) ([]state.DeviceHealth, error) { + changedDevices := []state.DeviceHealth{} + err := cache.withLock(func() error { + now := time.Now() + currentDriver, exists := (*cache.HealthInfo)[driverName] + if !exists { + currentDriver = state.DriverHealthState{Devices: make(map[string]state.DeviceHealth)} + (*cache.HealthInfo)[driverName] = currentDriver + } + + reportedKeys := make(map[string]struct{}) + + // Phase 1: Process the incoming report from the plugin. + // Update existing devices, add new ones, and record all devices + // present in this report. + for _, reportedDevice := range devices { + reportedDevice.LastUpdated = now + key := reportedDevice.PoolName + "/" + reportedDevice.DeviceName + reportedKeys[key] = struct{}{} + + existingDevice, ok := currentDriver.Devices[key] + + if !ok || existingDevice.Health != reportedDevice.Health { + changedDevices = append(changedDevices, reportedDevice) + } + + currentDriver.Devices[key] = reportedDevice + } + + // Phase 2: Handle devices that are in the cache but were not in the report. + // These devices may have been removed or the plugin may have stopped monitoring + // them. Mark them as "Unknown" if their status has timed out. + for key, existingDevice := range currentDriver.Devices { + if _, wasReported := reportedKeys[key]; !wasReported { + if existingDevice.Health != state.DeviceHealthStatusUnknown && now.Sub(existingDevice.LastUpdated) > healthTimeout { + existingDevice.Health = state.DeviceHealthStatusUnknown + existingDevice.LastUpdated = now + currentDriver.Devices[key] = existingDevice + + changedDevices = append(changedDevices, existingDevice) + } + } + } + + // Phase 3: Persist changes to the checkpoint file if any state changed. + if len(changedDevices) > 0 { + if err := cache.saveToCheckpointInternal(); err != nil { + klog.Background().Error(err, "Failed to save health checkpoint after update. Kubelet restart may lose the device health information.") + } + } + return nil + }) + + if err != nil { + return nil, err + } + return changedDevices, nil +} + +// clearDriver clears all health data for a specific driver. +func (cache *healthInfoCache) clearDriver(driverName string) error { + return cache.withLock(func() error { + delete(*cache.HealthInfo, driverName) + return cache.saveToCheckpointInternal() + }) +} diff --git a/pkg/kubelet/cm/dra/healthinfo_test.go b/pkg/kubelet/cm/dra/healthinfo_test.go new file mode 100644 index 00000000000..3100b642d1b --- /dev/null +++ b/pkg/kubelet/cm/dra/healthinfo_test.go @@ -0,0 +1,453 @@ +/* +Copyright 2024 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +package dra + +import ( + "errors" + "os" + "path" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "k8s.io/kubernetes/pkg/kubelet/cm/dra/state" +) + +const ( + testDriver = "test-driver" + testPool = "test-pool" + testDevice = "test-device" + testNamespace = "test-namespace" + testClaim = "test-claim" +) + +var ( + testDeviceHealth = state.DeviceHealth{ + PoolName: testPool, + DeviceName: testDevice, + Health: state.DeviceHealthStatusHealthy, + } +) + +// `TestNewHealthInfoCache tests cache creation and checkpoint loading. +func TestNewHealthInfoCache(t *testing.T) { + tests := []struct { + description string + stateFile string + wantErr bool + }{ + { + description: "successfully created cache", + stateFile: path.Join(t.TempDir(), "health_checkpoint"), + }, + { + description: "empty state file", + stateFile: "", + }, + } + for _, test := range tests { + t.Run(test.description, func(t *testing.T) { + if test.stateFile != "" { + f, err := os.Create(test.stateFile) + require.NoError(t, err) + require.NoError(t, f.Close()) + } + cache, err := newHealthInfoCache(test.stateFile) + if test.wantErr { + assert.Error(t, err) + return + } + require.NoError(t, err) + assert.NotNil(t, cache) + if test.stateFile != "" { + require.NoError(t, os.Remove(test.stateFile)) + } + }) + } +} + +// Helper function to compare DeviceHealth slices ignoring LastUpdated time +func assertDeviceHealthElementsMatchIgnoreTime(t *testing.T, expected, actual []state.DeviceHealth) { + require.Len(t, actual, len(expected), "Number of changed devices mismatch") + + // Create comparable versions without LastUpdated + normalize := func(dh state.DeviceHealth) state.DeviceHealth { + // Zero out time for comparison + dh.LastUpdated = time.Time{} + return dh + } + + expectedNormalized := make([]state.DeviceHealth, len(expected)) + actualNormalized := make([]state.DeviceHealth, len(actual)) + + for i := range expected { + expectedNormalized[i] = normalize(expected[i]) + } + for i := range actual { + actualNormalized[i] = normalize(actual[i]) + } + + assert.ElementsMatch(t, expectedNormalized, actualNormalized, "Changed device elements mismatch (ignoring time)") +} + +// TestWithLock tests the withLock method’s behavior. +func TestWithLock(t *testing.T) { + cache, err := newHealthInfoCache("") + require.NoError(t, err) + tests := []struct { + description string + f func() error + wantErr bool + }{ + { + description: "lock prevents concurrent lock", + f: func() error { + if cache.TryLock() { + defer cache.Unlock() + return errors.New("Lock succeeded") + } + return nil + }, + }, + { + description: "erroring function", + f: func() error { + return errors.New("test error") + }, + wantErr: true, + }, + } + for _, test := range tests { + t.Run(test.description, func(t *testing.T) { + err := cache.withLock(test.f) + if test.wantErr { + assert.Error(t, err) + } else { + require.NoError(t, err) + } + }) + } +} + +// TestWithRLock tests the withRLock method’s behavior. 
+func TestWithRLock(t *testing.T) { + cache, err := newHealthInfoCache("") + require.NoError(t, err) + tests := []struct { + description string + f func() error + wantErr bool + }{ + { + description: "rlock allows concurrent rlock", + f: func() error { + if !cache.TryRLock() { + return errors.New("Concurrent RLock failed") + } + defer cache.RUnlock() + return nil + }, + wantErr: false, + }, + { + description: "rlock prevents lock", + f: func() error { + if cache.TryLock() { + defer cache.Unlock() + return errors.New("Write Lock succeeded: Bad") + } + return nil + }, + wantErr: false, + }, + } + for _, test := range tests { + t.Run(test.description, func(t *testing.T) { + err := cache.withRLock(test.f) + if test.wantErr { + assert.Error(t, err) + } else { + require.NoError(t, err) + } + }) + } +} + +// TestGetHealthInfo tests retrieving health status. +func TestGetHealthInfo(t *testing.T) { + cache, err := newHealthInfoCache("") + require.NoError(t, err) + + // Initial state + assert.Equal(t, state.DeviceHealthStatusUnknown, cache.getHealthInfo(testDriver, testPool, testDevice)) + + // Add a device + _, err = cache.updateHealthInfo(testDriver, []state.DeviceHealth{testDeviceHealth}) + require.NoError(t, err) + assert.Equal(t, state.DeviceHealthStatusHealthy, cache.getHealthInfo(testDriver, testPool, testDevice)) + + // Test timeout (simulated with old LastUpdated) + err = cache.withLock(func() error { + driverState := (*cache.HealthInfo)[testDriver] + deviceKey := testPool + "/" + testDevice + device := driverState.Devices[deviceKey] + device.LastUpdated = time.Now().Add((-healthTimeout) - time.Second) + driverState.Devices[deviceKey] = device + (*cache.HealthInfo)[testDriver] = driverState + return nil + }) + require.NoError(t, err) + assert.Equal(t, state.DeviceHealthStatusUnknown, cache.getHealthInfo(testDriver, testPool, testDevice)) +} + +// TestGetHealthInfoRobust tests retrieving health status logic solely & against many cases. 
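+// As a concrete illustration of the timeout semantics exercised below (using the
+// healthTimeout of 30s defined in healthinfo.go): a device last reported Healthy 31
+// seconds ago is returned as Unknown, while one reported 29 seconds ago is still
+// returned as Healthy; see the "timed out" and "just within timeout" cases.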
+func TestGetHealthInfoRobust(t *testing.T) { + tests := []struct { + name string + initialState *state.DevicesHealthMap + driverName string + poolName string + deviceName string + expectedHealth state.DeviceHealthStatus + }{ + { + name: "empty cache", + initialState: &state.DevicesHealthMap{}, + driverName: testDriver, + poolName: testPool, + deviceName: testDevice, + expectedHealth: state.DeviceHealthStatusUnknown, + }, + { + name: "device exists and is healthy", + initialState: &state.DevicesHealthMap{ + testDriver: {Devices: map[string]state.DeviceHealth{ + testPool + "/" + testDevice: {PoolName: testPool, DeviceName: testDevice, Health: state.DeviceHealthStatusHealthy, LastUpdated: time.Now()}, + }}, + }, + driverName: testDriver, + poolName: testPool, + deviceName: testDevice, + expectedHealth: "Healthy", + }, + { + name: "device exists and is unhealthy", + initialState: &state.DevicesHealthMap{ + testDriver: {Devices: map[string]state.DeviceHealth{ + testPool + "/" + testDevice: {PoolName: testPool, DeviceName: testDevice, Health: state.DeviceHealthStatusUnhealthy, LastUpdated: time.Now()}, + }}, + }, + driverName: testDriver, + poolName: testPool, + deviceName: testDevice, + expectedHealth: "Unhealthy", + }, + { + name: "device exists but timed out", + initialState: &state.DevicesHealthMap{ + testDriver: {Devices: map[string]state.DeviceHealth{ + testPool + "/" + testDevice: {PoolName: testPool, DeviceName: testDevice, Health: state.DeviceHealthStatusHealthy, LastUpdated: time.Now().Add((-1 * healthTimeout) - time.Second)}, + }}, + }, + driverName: testDriver, + poolName: testPool, + deviceName: testDevice, + expectedHealth: "Unknown", + }, + { + name: "device exists, just within timeout", + initialState: &state.DevicesHealthMap{ + testDriver: {Devices: map[string]state.DeviceHealth{ + testPool + "/" + testDevice: {PoolName: testPool, DeviceName: testDevice, Health: state.DeviceHealthStatusHealthy, LastUpdated: time.Now().Add((-1 * healthTimeout) + time.Second)}, + }}, + }, + driverName: testDriver, + poolName: testPool, + deviceName: testDevice, + expectedHealth: "Healthy", + }, + { + name: "device does not exist, just outside of timeout", + initialState: &state.DevicesHealthMap{ + testDriver: {Devices: map[string]state.DeviceHealth{ + testPool + "/" + testDevice: {PoolName: testPool, DeviceName: testDevice, Health: state.DeviceHealthStatusHealthy, LastUpdated: time.Now().Add((-1 * healthTimeout) - time.Second)}, + }}, + }, + driverName: testDriver, + poolName: testPool, + deviceName: "device2", + expectedHealth: "Unknown", + }, + { + name: "device does not exist", + initialState: &state.DevicesHealthMap{ + testDriver: {Devices: map[string]state.DeviceHealth{ + testPool + "/" + testDevice: {PoolName: testPool, DeviceName: testDevice, Health: state.DeviceHealthStatusHealthy, LastUpdated: time.Now()}, + }}, + }, + driverName: testDriver, + poolName: testPool, + deviceName: "device2", + expectedHealth: "Unknown", + }, + { + name: "driver does not exist", + initialState: &state.DevicesHealthMap{ + testDriver: {Devices: map[string]state.DeviceHealth{ + testPool + "/" + testDevice: {PoolName: testPool, DeviceName: testDevice, Health: state.DeviceHealthStatusHealthy, LastUpdated: time.Now()}, + }}, + }, + driverName: "driver2", + poolName: testPool, + deviceName: testDevice, + expectedHealth: "Unknown", + }, + { + name: "pool does not exist", + initialState: &state.DevicesHealthMap{ + testDriver: {Devices: map[string]state.DeviceHealth{ + testPool + "/" + testDevice: {PoolName: testPool, 
DeviceName: testDevice, Health: state.DeviceHealthStatusHealthy, LastUpdated: time.Now()}, + }}, + }, + driverName: testDriver, + poolName: "pool2", + deviceName: testDevice, + expectedHealth: "Unknown", + }, + { + name: "multiple devices", + initialState: &state.DevicesHealthMap{ + testDriver: {Devices: map[string]state.DeviceHealth{ + testPool + "/" + testDevice: {PoolName: testPool, DeviceName: testDevice, Health: state.DeviceHealthStatusHealthy, LastUpdated: time.Now()}, + testPool + "/device2": {PoolName: testPool, DeviceName: "device2", Health: state.DeviceHealthStatusUnhealthy, LastUpdated: time.Now()}, + }}, + }, + driverName: testDriver, + poolName: testPool, + deviceName: "device2", + expectedHealth: "Unhealthy", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cache := &healthInfoCache{HealthInfo: tt.initialState} + health := cache.getHealthInfo(tt.driverName, tt.poolName, tt.deviceName) + assert.Equal(t, tt.expectedHealth, health) + }) + } +} + +// TestUpdateHealthInfo tests adding, updating, and reconciling device health. +func TestUpdateHealthInfo(t *testing.T) { + tmpFile := path.Join(t.TempDir(), "health_checkpoint_test") + cache, err := newHealthInfoCache(tmpFile) + require.NoError(t, err) + + // 1 -- Add new device + deviceToAdd := testDeviceHealth + expectedChanged1 := []state.DeviceHealth{deviceToAdd} + changedDevices, err := cache.updateHealthInfo(testDriver, []state.DeviceHealth{testDeviceHealth}) + require.NoError(t, err) + assertDeviceHealthElementsMatchIgnoreTime(t, expectedChanged1, changedDevices) + assert.Equal(t, state.DeviceHealthStatusHealthy, cache.getHealthInfo(testDriver, testPool, testDevice)) + + // 2 -- Update with no change + changedDevices, err = cache.updateHealthInfo(testDriver, []state.DeviceHealth{testDeviceHealth}) + require.NoError(t, err) + assert.Empty(t, changedDevices, "Scenario 2: Changed devices list should be empty") + + // 3 -- Update with new health + newHealth := testDeviceHealth + newHealth.Health = state.DeviceHealthStatusUnhealthy + expectedChanged3 := []state.DeviceHealth{newHealth} + changedDevices, err = cache.updateHealthInfo(testDriver, []state.DeviceHealth{newHealth}) + require.NoError(t, err) + assertDeviceHealthElementsMatchIgnoreTime(t, expectedChanged3, changedDevices) + assert.Equal(t, state.DeviceHealthStatusUnhealthy, cache.getHealthInfo(testDriver, testPool, testDevice)) + + // 4 -- Add second device, omit first + secondDevice := state.DeviceHealth{PoolName: testPool, DeviceName: "device2", Health: state.DeviceHealthStatusHealthy} + // When the first device is omitted, it should be marked as "Unknown" after a timeout. + // For this test, we simulate the timeout by not reporting it. 
+ firstDeviceAsUnknown := newHealth + firstDeviceAsUnknown.Health = state.DeviceHealthStatusUnknown + expectedChanged4 := []state.DeviceHealth{secondDevice, firstDeviceAsUnknown} + // Manually set the time of the first device to be outside the timeout window + err = cache.withLock(func() error { + deviceKey := testPool + "/" + testDevice + device := (*cache.HealthInfo)[testDriver].Devices[deviceKey] + device.LastUpdated = time.Now().Add(-healthTimeout * 2) + (*cache.HealthInfo)[testDriver].Devices[deviceKey] = device + return nil + }) + require.NoError(t, err) + + changedDevices, err = cache.updateHealthInfo(testDriver, []state.DeviceHealth{secondDevice}) + require.NoError(t, err) + assertDeviceHealthElementsMatchIgnoreTime(t, expectedChanged4, changedDevices) + assert.Equal(t, state.DeviceHealthStatusHealthy, cache.getHealthInfo(testDriver, testPool, "device2")) + assert.Equal(t, state.DeviceHealthStatusUnknown, cache.getHealthInfo(testDriver, testPool, testDevice)) + + // 5 -- Test persistence + cache2, err := newHealthInfoCache(tmpFile) + require.NoError(t, err) + assert.Equal(t, state.DeviceHealthStatusHealthy, cache2.getHealthInfo(testDriver, testPool, "device2")) + assert.Equal(t, state.DeviceHealthStatusUnknown, cache2.getHealthInfo(testDriver, testPool, testDevice)) + + // 6 -- Test how updateHealthInfo handles device timeouts + timeoutDevice := state.DeviceHealth{PoolName: testPool, DeviceName: "timeoutDevice", Health: "Unhealthy"} + _, err = cache.updateHealthInfo(testDriver, []state.DeviceHealth{timeoutDevice}) + require.NoError(t, err) + + // Manually manipulate the last updated time of timeoutDevice to seem like it surpassed healthtimeout. + err = cache.withLock(func() error { + driverState := (*cache.HealthInfo)[testDriver] + deviceKey := testPool + "/timeoutDevice" + device := driverState.Devices[deviceKey] + device.LastUpdated = time.Now().Add((-healthTimeout) - time.Second) + driverState.Devices[deviceKey] = device + (*cache.HealthInfo)[testDriver] = driverState + return nil + }) + require.NoError(t, err) + + expectedTimeoutDeviceUnknown := state.DeviceHealth{PoolName: testPool, DeviceName: "timeoutDevice", Health: state.DeviceHealthStatusUnknown} + expectedChanged6 := []state.DeviceHealth{expectedTimeoutDeviceUnknown} + changedDevices, err = cache.updateHealthInfo(testDriver, []state.DeviceHealth{}) + require.NoError(t, err) + assertDeviceHealthElementsMatchIgnoreTime(t, expectedChanged6, changedDevices) + + driverState := (*cache.HealthInfo)[testDriver] + device := driverState.Devices[testPool+"/timeoutDevice"] + assert.Equal(t, state.DeviceHealthStatusUnknown, device.Health, "Health status should be Unknown after timeout in updateHealthInfo") +} + +// TestClearDriver tests clearing a driver’s health data. 
+func TestClearDriver(t *testing.T) { + cache, err := newHealthInfoCache("") + require.NoError(t, err) + + _, err = cache.updateHealthInfo(testDriver, []state.DeviceHealth{testDeviceHealth}) + require.NoError(t, err) + assert.Equal(t, state.DeviceHealthStatusHealthy, cache.getHealthInfo(testDriver, testPool, testDevice)) + + err = cache.clearDriver(testDriver) + require.NoError(t, err) + assert.Equal(t, state.DeviceHealthStatusUnknown, cache.getHealthInfo(testDriver, testPool, testDevice)) +} diff --git a/pkg/kubelet/cm/dra/manager.go b/pkg/kubelet/cm/dra/manager.go index e7112b49783..3a57cc1d238 100644 --- a/pkg/kubelet/cm/dra/manager.go +++ b/pkg/kubelet/cm/dra/manager.go @@ -18,7 +18,10 @@ package dra import ( "context" + "errors" "fmt" + "io" + "path/filepath" "strconv" "time" @@ -32,9 +35,12 @@ import ( "k8s.io/component-base/metrics" "k8s.io/dynamic-resource-allocation/resourceclaim" "k8s.io/klog/v2" + + drahealthv1alpha1 "k8s.io/kubelet/pkg/apis/dra-health/v1alpha1" drapb "k8s.io/kubelet/pkg/apis/dra/v1" draplugin "k8s.io/kubernetes/pkg/kubelet/cm/dra/plugin" "k8s.io/kubernetes/pkg/kubelet/cm/dra/state" + "k8s.io/kubernetes/pkg/kubelet/cm/resourceupdates" "k8s.io/kubernetes/pkg/kubelet/config" kubecontainer "k8s.io/kubernetes/pkg/kubelet/container" kubeletmetrics "k8s.io/kubernetes/pkg/kubelet/metrics" @@ -91,6 +97,12 @@ type Manager struct { // KubeClient reference kubeClient clientset.Interface + + // healthInfoCache contains cached health info + healthInfoCache *healthInfoCache + + // update channel for resource updates + update chan resourceupdates.Update } // NewManager creates a new DRA manager. @@ -108,6 +120,11 @@ func NewManager(logger klog.Logger, kubeClient clientset.Interface, stateFileDir return nil, fmt.Errorf("create ResourceClaim cache: %w", err) } + healthInfoCache, err := newHealthInfoCache(filepath.Join(stateFileDirectory, "dra_health_state")) + if err != nil { + return nil, fmt.Errorf("failed to create healthInfo cache: %w", err) + } + // TODO: for now the reconcile period is not configurable. // We should consider making it configurable in the future. reconcilePeriod := defaultReconcilePeriod @@ -118,6 +135,8 @@ func NewManager(logger klog.Logger, kubeClient clientset.Interface, stateFileDir reconcilePeriod: reconcilePeriod, activePods: nil, sourcesReady: nil, + healthInfoCache: healthInfoCache, + update: make(chan resourceupdates.Update, 100), } return manager, nil @@ -145,7 +164,7 @@ func (m *Manager) Start(ctx context.Context, activePods ActivePodsFunc, getNode // initPluginManager can be used instead of Start to make the manager useable // for calls to prepare/unprepare. It exists primarily for testing purposes. func (m *Manager) initDRAPluginManager(ctx context.Context, getNode GetNodeFunc, wipingDelay time.Duration) { - m.draPlugins = draplugin.NewDRAPluginManager(ctx, m.kubeClient, getNode, wipingDelay) + m.draPlugins = draplugin.NewDRAPluginManager(ctx, m.kubeClient, getNode, m, wipingDelay) } // reconcileLoop ensures that any stale state in the manager's claimInfoCache gets periodically reconciled. @@ -595,6 +614,12 @@ func (m *Manager) unprepareResources(ctx context.Context, podUID types.UID, name // Atomically perform some operations on the claimInfo cache. err := m.cache.withLock(func() error { + // TODO(#132978): Re-evaluate this logic to support post-mortem health updates. + // As of the initial implementation, we immediately delete the claim info upon + // unprepare. 
This means a late-arriving health update for a terminated pod + // will be missed. A future enhancement could be to "tombstone" this entry for + // a grace period instead of deleting it. + // Delete all claimInfos from the cache that have just been unprepared. for _, claimName := range claimNamesMap { claimInfo, _ := m.cache.get(claimName, namespace) @@ -663,3 +688,215 @@ func (m *Manager) GetContainerClaimInfos(pod *v1.Pod, container *v1.Container) ( } return claimInfos, nil } + +// UpdateAllocatedResourcesStatus updates the health status of allocated DRA resources in the pod's container statuses. +func (m *Manager) UpdateAllocatedResourcesStatus(pod *v1.Pod, status *v1.PodStatus) { + logger := klog.FromContext(context.Background()) + for _, container := range pod.Spec.Containers { + // Get all the DRA claim details associated with this specific container. + claimInfos, err := m.GetContainerClaimInfos(pod, &container) + if err != nil { + logger.Error(err, "Failed to get claim infos for container", "pod", klog.KObj(pod), "container", container.Name) + continue + } + + // Find the corresponding container status + for i, containerStatus := range status.ContainerStatuses { + if containerStatus.Name != container.Name { + continue + } + + // Ensure the slice exists. Use a map for efficient updates by resource name. + resourceStatusMap := make(map[v1.ResourceName]*v1.ResourceStatus) + if status.ContainerStatuses[i].AllocatedResourcesStatus != nil { + for idx := range status.ContainerStatuses[i].AllocatedResourcesStatus { + // Store pointers to modify in place + resourceStatusMap[status.ContainerStatuses[i].AllocatedResourcesStatus[idx].Name] = &status.ContainerStatuses[i].AllocatedResourcesStatus[idx] + } + } else { + status.ContainerStatuses[i].AllocatedResourcesStatus = []v1.ResourceStatus{} + } + + // Loop through each claim associated with the container + for _, claimInfo := range claimInfos { + var resourceName v1.ResourceName + foundClaimInSpec := false + for _, cClaim := range container.Resources.Claims { + if cClaim.Name == claimInfo.ClaimName { + if cClaim.Request == "" { + resourceName = v1.ResourceName(fmt.Sprintf("claim:%s", cClaim.Name)) + } else { + resourceName = v1.ResourceName(fmt.Sprintf("claim:%s/%s", cClaim.Name, cClaim.Request)) + } + foundClaimInSpec = true + break + } + } + if !foundClaimInSpec { + logger.V(4).Info("Could not find matching resource claim in container spec", "pod", klog.KObj(pod), "container", container.Name, "claimName", claimInfo.ClaimName) + continue + } + + // Get or create the ResourceStatus entry for this claim + resStatus, ok := resourceStatusMap[resourceName] + + if !ok { + // Create a new entry and add it to the map and the slice + newStatus := v1.ResourceStatus{ + Name: resourceName, + Resources: []v1.ResourceHealth{}, + } + status.ContainerStatuses[i].AllocatedResourcesStatus = append(status.ContainerStatuses[i].AllocatedResourcesStatus, newStatus) + // Get pointer to the newly added element *after* appending + resStatus = &status.ContainerStatuses[i].AllocatedResourcesStatus[len(status.ContainerStatuses[i].AllocatedResourcesStatus)-1] + resourceStatusMap[resourceName] = resStatus + } + + // Clear previous health entries for this resource before adding current ones + // Ensures we only report current health for allocated devices. 
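+				// Illustrative sketch of the end result (claim, request, and CDI ID names here
+				// are hypothetical examples): a container referencing claim "gpu-claim" with
+				// request "gpu" whose allocated device is reported Healthy would end up with
+				// an entry roughly like:
+				//
+				//	AllocatedResourcesStatus: []v1.ResourceStatus{{
+				//		Name: "claim:gpu-claim/gpu",
+				//		Resources: []v1.ResourceHealth{{
+				//			ResourceID: "vendor.example.com/gpu=gpu-0",
+				//			Health:     v1.ResourceHealthStatusHealthy,
+				//		}},
+				//	}}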
+ resStatus.Resources = []v1.ResourceHealth{} + + // Iterate through the map holding the state specific to each driver + for driverName, driverState := range claimInfo.DriverState { + // Iterate through each specific device allocated by this driver + for _, device := range driverState.Devices { + + healthStr := m.healthInfoCache.getHealthInfo(driverName, device.PoolName, device.DeviceName) + + // Convert internal health string to API type + var health v1.ResourceHealthStatus + switch healthStr { + case "Healthy": + health = v1.ResourceHealthStatusHealthy + case "Unhealthy": + health = v1.ResourceHealthStatusUnhealthy + default: // Catches "Unknown" or any other case + health = v1.ResourceHealthStatusUnknown + } + + // Create the ResourceHealth entry + resourceHealth := v1.ResourceHealth{ + Health: health, + } + + // Use first CDI device ID as ResourceID, with fallback + if len(device.CDIDeviceIDs) > 0 { + resourceHealth.ResourceID = v1.ResourceID(device.CDIDeviceIDs[0]) + } else { + // Fallback ID if no CDI ID is present + resourceHealth.ResourceID = v1.ResourceID(fmt.Sprintf("%s/%s/%s", driverName, device.PoolName, device.DeviceName)) + } + + // Append the health status for this specific device/resource ID + resStatus.Resources = append(resStatus.Resources, resourceHealth) + } + } + } + // Rebuild the slice from the map values to ensure correctness + finalStatuses := make([]v1.ResourceStatus, 0, len(resourceStatusMap)) + for _, rs := range resourceStatusMap { + // Only add if it actually has resource health entries populated + if len(rs.Resources) > 0 { + finalStatuses = append(finalStatuses, *rs) + } + } + status.ContainerStatuses[i].AllocatedResourcesStatus = finalStatuses + } + } +} + +// HandleWatchResourcesStream processes health updates from the DRA plugin. +func (m *Manager) HandleWatchResourcesStream(ctx context.Context, stream drahealthv1alpha1.DRAResourceHealth_NodeWatchResourcesClient, pluginName string) error { + logger := klog.FromContext(ctx) + + defer func() { + logger.V(4).Info("Clearing health cache for driver upon stream exit", "pluginName", pluginName) + // Use a separate context for clearDriver if needed, though background should be fine. + if err := m.healthInfoCache.clearDriver(pluginName); err != nil { + logger.Error(err, "Failed to clear health info cache for driver", "pluginName", pluginName) + } + }() + + for { + resp, err := stream.Recv() + if err != nil { + // Context canceled, normal shutdown. + if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { + logger.V(4).Info("Stopping health monitoring due to context cancellation", "pluginName", pluginName, "reason", err) + return err + } + // Stream closed cleanly by the server, get normal EOF. + if errors.Is(err, io.EOF) { + logger.V(4).Info("Stream ended with EOF", "pluginName", pluginName) + return nil + } + // Other errors are unexpected, log & return. 
+ logger.Error(err, "Error receiving from WatchResources stream", "pluginName", pluginName) + return err + } + + // Convert drahealthv1alpha1.DeviceHealth to state.DeviceHealth + devices := make([]state.DeviceHealth, len(resp.GetDevices())) + for i, d := range resp.GetDevices() { + var health state.DeviceHealthStatus + switch d.GetHealth() { + case drahealthv1alpha1.HealthStatus_HEALTHY: + health = state.DeviceHealthStatusHealthy + case drahealthv1alpha1.HealthStatus_UNHEALTHY: + health = state.DeviceHealthStatusUnhealthy + default: + health = state.DeviceHealthStatusUnknown + } + devices[i] = state.DeviceHealth{ + PoolName: d.GetDevice().GetPoolName(), + DeviceName: d.GetDevice().GetDeviceName(), + Health: health, + LastUpdated: time.Unix(d.GetLastUpdatedTime(), 0), + } + } + + changedDevices, updateErr := m.healthInfoCache.updateHealthInfo(pluginName, devices) + if updateErr != nil { + logger.Error(updateErr, "Failed to update health info cache", "pluginName", pluginName) + } + if len(changedDevices) > 0 { + logger.V(4).Info("Health info changed, checking affected pods", "pluginName", pluginName, "changedDevicesCount", len(changedDevices)) + + podsToUpdate := sets.New[string]() + + m.cache.RLock() + for _, dev := range changedDevices { + for _, cInfo := range m.cache.claimInfo { + if driverState, ok := cInfo.DriverState[pluginName]; ok { + for _, allocatedDevice := range driverState.Devices { + if allocatedDevice.PoolName == dev.PoolName && allocatedDevice.DeviceName == dev.DeviceName { + podsToUpdate.Insert(cInfo.PodUIDs.UnsortedList()...) + break + } + } + } + } + } + m.cache.RUnlock() + + if podsToUpdate.Len() > 0 { + podUIDs := podsToUpdate.UnsortedList() + logger.Info("Sending health update notification for pods", "pluginName", pluginName, "pods", podUIDs) + select { + case m.update <- resourceupdates.Update{PodUIDs: podUIDs}: + default: + logger.Error(nil, "DRA health update channel is full, discarding pod update notification", "pluginName", pluginName, "pods", podUIDs) + } + } else { + logger.V(4).Info("Health info changed, but no active pods found using the affected devices", "pluginName", pluginName) + } + } + + } +} + +// Updates returns the channel that provides resource updates. +func (m *Manager) Updates() <-chan resourceupdates.Update { + // Return the internal channel that HandleWatchResourcesStream writes to. 
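+	// Illustrative sketch of the expected consumer (the actual wiring lives in the
+	// container manager, which forwards these updates to the kubelet sync loop);
+	// triggerPodSync is a hypothetical helper used only for illustration:
+	//
+	//	for upd := range draManager.Updates() {
+	//		for _, uid := range upd.PodUIDs {
+	//			triggerPodSync(uid)
+	//		}
+	//	}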
+ return m.update +} diff --git a/pkg/kubelet/cm/dra/manager_test.go b/pkg/kubelet/cm/dra/manager_test.go index bb8efacfb26..c1a60a9d5d4 100644 --- a/pkg/kubelet/cm/dra/manager_test.go +++ b/pkg/kubelet/cm/dra/manager_test.go @@ -18,7 +18,9 @@ package dra import ( "context" + "errors" "fmt" + "io" "net" "os" "path/filepath" @@ -28,8 +30,10 @@ import ( "time" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" "google.golang.org/grpc" + "google.golang.org/grpc/metadata" v1 "k8s.io/api/core/v1" resourceapi "k8s.io/api/resource/v1" @@ -39,8 +43,11 @@ import ( "k8s.io/client-go/kubernetes/fake" "k8s.io/dynamic-resource-allocation/resourceclaim" "k8s.io/klog/v2" + + drahealthv1alpha1 "k8s.io/kubelet/pkg/apis/dra-health/v1alpha1" drapb "k8s.io/kubelet/pkg/apis/dra/v1beta1" "k8s.io/kubernetes/pkg/kubelet/cm/dra/state" + "k8s.io/kubernetes/pkg/kubelet/cm/resourceupdates" "k8s.io/kubernetes/test/utils/ktesting" ) @@ -52,12 +59,16 @@ const ( type fakeDRADriverGRPCServer struct { drapb.UnimplementedDRAPluginServer + drahealthv1alpha1.UnimplementedDRAResourceHealthServer driverName string timeout *time.Duration prepareResourceCalls atomic.Uint32 unprepareResourceCalls atomic.Uint32 + watchResourcesCalls atomic.Uint32 prepareResourcesResponse *drapb.NodePrepareResourcesResponse unprepareResourcesResponse *drapb.NodeUnprepareResourcesResponse + watchResourcesResponses chan *drahealthv1alpha1.NodeWatchResourcesResponse + watchResourcesError error } func (s *fakeDRADriverGRPCServer) NodePrepareResources(ctx context.Context, req *drapb.NodePrepareResourcesRequest) (*drapb.NodePrepareResourcesResponse, error) { @@ -107,6 +118,79 @@ func (s *fakeDRADriverGRPCServer) NodeUnprepareResources(ctx context.Context, re return s.unprepareResourcesResponse, nil } +func (s *fakeDRADriverGRPCServer) NodeWatchResources(req *drahealthv1alpha1.NodeWatchResourcesRequest, stream drahealthv1alpha1.DRAResourceHealth_NodeWatchResourcesServer) error { + s.watchResourcesCalls.Add(1) + logger := klog.FromContext(stream.Context()) + logger.V(4).Info("Fake Server: WatchResources stream started") + + if s.watchResourcesError != nil { + logger.Error(s.watchResourcesError, "Fake Server: Returning predefined stream error") + return s.watchResourcesError + } + + go func() { + for { + select { + case <-stream.Context().Done(): + logger.Info("Fake Server: WatchResources stream context canceled") + return + case resp, ok := <-s.watchResourcesResponses: + if !ok { + logger.Info("Fake Server: WatchResources response channel closed") + return + } + logger.V(5).Info("Fake Server: Sending health response", "response", resp) + // Use the stream argument to send + if err := stream.Send(resp); err != nil { + logger.Error(err, "Fake Server: Error sending response on stream") + return + } + } + } + }() + + logger.V(4).Info("Fake Server: WatchResources RPC call returning control to client.") + return nil +} + +type mockWatchResourcesClient struct { + mock.Mock + RecvChan chan struct { + Resp *drahealthv1alpha1.NodeWatchResourcesResponse + Err error + } + Ctx context.Context +} + +func (m *mockWatchResourcesClient) Recv() (*drahealthv1alpha1.NodeWatchResourcesResponse, error) { + logger := klog.FromContext(m.Ctx) + select { + case <-m.Ctx.Done(): + logger.V(6).Info("mockWatchClient.Recv: Context done", "err", m.Ctx.Err()) + return nil, m.Ctx.Err() + case item, ok := <-m.RecvChan: + if !ok { + logger.V(6).Info("mockWatchClient.Recv: RecvChan closed, returning io.EOF") + return nil, io.EOF + } 
+ return item.Resp, item.Err + } +} + +func (m *mockWatchResourcesClient) Context() context.Context { + return m.Ctx +} + +func (m *mockWatchResourcesClient) Header() (metadata.MD, error) { return nil, nil } +func (m *mockWatchResourcesClient) Trailer() metadata.MD { return nil } +func (m *mockWatchResourcesClient) CloseSend() error { return nil } +func (m *mockWatchResourcesClient) RecvMsg(v interface{}) error { + return fmt.Errorf("RecvMsg not implemented") +} +func (m *mockWatchResourcesClient) SendMsg(v interface{}) error { + return fmt.Errorf("SendMsg not implemented") +} + type tearDown func() type fakeDRAServerInfo struct { @@ -118,7 +202,7 @@ type fakeDRAServerInfo struct { teardownFn tearDown } -func setupFakeDRADriverGRPCServer(ctx context.Context, shouldTimeout bool, pluginClientTimeout *time.Duration, prepareResourcesResponse *drapb.NodePrepareResourcesResponse, unprepareResourcesResponse *drapb.NodeUnprepareResourcesResponse) (fakeDRAServerInfo, error) { +func setupFakeDRADriverGRPCServer(ctx context.Context, shouldTimeout bool, pluginClientTimeout *time.Duration, prepareResourcesResponse *drapb.NodePrepareResourcesResponse, unprepareResourcesResponse *drapb.NodeUnprepareResourcesResponse, watchResourcesError error) (fakeDRAServerInfo, error) { socketDir, err := os.MkdirTemp("", "dra") if err != nil { return fakeDRAServerInfo{ @@ -154,24 +238,32 @@ func setupFakeDRADriverGRPCServer(ctx context.Context, shouldTimeout bool, plugi driverName: driverName, prepareResourcesResponse: prepareResourcesResponse, unprepareResourcesResponse: unprepareResourcesResponse, + watchResourcesResponses: make(chan *drahealthv1alpha1.NodeWatchResourcesResponse, 10), + watchResourcesError: watchResourcesError, } if shouldTimeout { timeout := *pluginClientTimeout * 2 fakeDRADriverGRPCServer.timeout = &timeout } + drahealthv1alpha1.RegisterDRAResourceHealthServer(s, fakeDRADriverGRPCServer) drapb.RegisterDRAPluginServer(s, fakeDRADriverGRPCServer) - go func(ctx context.Context) { + go func() { go func() { - if err := s.Serve(l); err != nil { - logger := klog.FromContext(ctx) + logger := klog.FromContext(ctx) + logger.V(4).Info("Starting fake gRPC server", "address", socketName) + if err := s.Serve(l); err != nil && !errors.Is(err, grpc.ErrServerStopped) { logger.Error(err, "failed to serve gRPC") } + logger.V(4).Info("Fake gRPC server stopped serving", "address", socketName) }() <-stopCh + logger := klog.FromContext(ctx) + logger.V(4).Info("Stopping fake gRPC server", "address", socketName) s.GracefulStop() - }(ctx) + logger.V(4).Info("Fake gRPC server stopped", "address", socketName) + }() return fakeDRAServerInfo{ server: fakeDRADriverGRPCServer, @@ -243,7 +335,7 @@ func genTestPod() *v1.Pod { } } -// getTestClaim generates resource claim object +// genTestClaim generates resource claim object func genTestClaim(name, driver, device, podUID string) *resourceapi.ResourceClaim { return &resourceapi.ResourceClaim{ ObjectMeta: metav1.ObjectMeta{ @@ -416,8 +508,8 @@ func TestPrepareResources(t *testing.T) { { description: "unknown driver", pod: genTestPod(), - claim: genTestClaim(claimName, "unknown.driver", deviceName, podUID), - expectedErrMsg: "DRA driver unknown.driver is not registered", + claim: genTestClaim(claimName, "unknown driver", deviceName, podUID), + expectedErrMsg: "prepare dynamic resources: DRA driver unknown driver is not registered", }, { description: "should prepare resources, driver returns nil value", @@ -560,11 +652,15 @@ func TestPrepareResources(t *testing.T) { }, } { 
t.Run(test.description, func(t *testing.T) { + backgroundCtx, cancel := context.WithCancel(context.Background()) + defer cancel() + tCtx := ktesting.Init(t) + backgroundCtx = klog.NewContext(backgroundCtx, tCtx.Logger()) manager, err := NewManager(tCtx.Logger(), fakeKubeClient, t.TempDir()) require.NoError(t, err, "create DRA manager") - manager.initDRAPluginManager(tCtx, getFakeNode, time.Second /* very short wiping delay for testing */) + manager.initDRAPluginManager(backgroundCtx, getFakeNode, time.Second /* very short wiping delay for testing */) if test.claim != nil { if _, err := fakeKubeClient.ResourceV1().ResourceClaims(test.pod.Namespace).Create(tCtx, test.claim, metav1.CreateOptions{}); err != nil { @@ -581,7 +677,7 @@ func TestPrepareResources(t *testing.T) { pluginClientTimeout = &timeout } - draServerInfo, err := setupFakeDRADriverGRPCServer(tCtx, test.wantTimeout, pluginClientTimeout, test.resp, nil) + draServerInfo, err := setupFakeDRADriverGRPCServer(backgroundCtx, test.wantTimeout, pluginClientTimeout, test.resp, nil, nil) if err != nil { t.Fatal(err) } @@ -595,7 +691,7 @@ func TestPrepareResources(t *testing.T) { manager.cache.add(test.claimInfo) } - err = manager.PrepareResources(tCtx, test.pod) + err = manager.PrepareResources(backgroundCtx, test.pod) assert.Equal(t, test.expectedPrepareCalls, draServerInfo.server.prepareResourceCalls.Load()) @@ -614,19 +710,16 @@ func TestPrepareResources(t *testing.T) { } // check the cache contains the expected claim info - claimName, _, err := resourceclaim.Name(test.pod, &test.pod.Spec.ResourceClaims[0]) - if err != nil { - t.Fatal(err) - } - claimInfo, ok := manager.cache.get(*claimName, test.pod.Namespace) - if !ok { - t.Fatalf("claimInfo not found in cache for claim %s", *claimName) - } - if len(claimInfo.PodUIDs) != 1 || !claimInfo.PodUIDs.Has(string(test.pod.UID)) { - t.Fatalf("podUIDs mismatch: expected [%s], got %v", test.pod.UID, claimInfo.PodUIDs) - } - - assert.Equal(t, test.expectedClaimInfoState, claimInfo.ClaimInfoState) + podClaimName, _, err := resourceclaim.Name(test.pod, &test.pod.Spec.ResourceClaims[0]) + require.NoError(t, err) + claimInfoResult, ok := manager.cache.get(*podClaimName, test.pod.Namespace) + require.True(t, ok, "claimInfo not found in cache") + require.True(t, claimInfoResult.PodUIDs.Has(string(test.pod.UID)), "podUIDs mismatch") + assert.Equal(t, test.expectedClaimInfoState.ClaimUID, claimInfoResult.ClaimUID) + assert.Equal(t, test.expectedClaimInfoState.ClaimName, claimInfoResult.ClaimName) + assert.Equal(t, test.expectedClaimInfoState.Namespace, claimInfoResult.Namespace) + assert.Equal(t, test.expectedClaimInfoState.DriverState, claimInfoResult.DriverState) + assert.True(t, claimInfoResult.prepared, "ClaimInfo should be marked as prepared") }) } } @@ -647,11 +740,30 @@ func TestUnprepareResources(t *testing.T) { expectedErrMsg string }{ { - description: "unknown driver", - pod: genTestPod(), - claim: genTestClaim(claimName, "unknown driver", deviceName, podUID), - claimInfo: genTestClaimInfo(claimUID, []string{podUID}, true), - expectedErrMsg: "DRA driver test-driver is not registered", + description: "unknown driver", + driverName: driverName, + pod: genTestPod(), + claimInfo: &ClaimInfo{ + ClaimInfoState: state.ClaimInfoState{ + ClaimUID: claimUID, + ClaimName: claimName, + Namespace: namespace, + PodUIDs: sets.New[string](string(podUID)), + DriverState: map[string]state.DriverState{ + "unknown-driver": { + Devices: []state.Device{{ + PoolName: poolName, + DeviceName: deviceName, + 
RequestNames: []string{requestName}, + CDIDeviceIDs: []string{"random-cdi-id"}, + }}, + }, + }, + }, + prepared: true, + }, + expectedErrMsg: "unprepare dynamic resources: DRA driver unknown-driver is not registered", + expectedUnprepareCalls: 0, }, { description: "resource claim referenced by other pod(s)", @@ -678,14 +790,6 @@ func TestUnprepareResources(t *testing.T) { expectedUnprepareCalls: 1, expectedErrMsg: "NodeUnprepareResources skipped 1 ResourceClaims", }, - { - description: "should unprepare resource", - driverName: driverName, - pod: genTestPod(), - claim: genTestClaim(claimName, driverName, deviceName, podUID), - claimInfo: genTestClaimInfo(claimUID, []string{podUID}, false), - expectedUnprepareCalls: 1, - }, { description: "should unprepare already prepared resource", driverName: driverName, @@ -704,7 +808,11 @@ func TestUnprepareResources(t *testing.T) { }, } { t.Run(test.description, func(t *testing.T) { + backgroundCtx, cancel := context.WithCancel(context.Background()) + defer cancel() + tCtx := ktesting.Init(t) + backgroundCtx = klog.NewContext(backgroundCtx, tCtx.Logger()) var pluginClientTimeout *time.Duration if test.wantTimeout { @@ -712,7 +820,7 @@ func TestUnprepareResources(t *testing.T) { pluginClientTimeout = &timeout } - draServerInfo, err := setupFakeDRADriverGRPCServer(tCtx, test.wantTimeout, pluginClientTimeout, nil, test.resp) + draServerInfo, err := setupFakeDRADriverGRPCServer(backgroundCtx, test.wantTimeout, pluginClientTimeout, nil, test.resp, nil) if err != nil { t.Fatal(err) } @@ -720,7 +828,7 @@ func TestUnprepareResources(t *testing.T) { manager, err := NewManager(tCtx.Logger(), fakeKubeClient, t.TempDir()) require.NoError(t, err, "create DRA manager") - manager.initDRAPluginManager(tCtx, getFakeNode, time.Second /* very short wiping delay for testing */) + manager.initDRAPluginManager(backgroundCtx, getFakeNode, time.Second /* very short wiping delay for testing */) plg := manager.GetWatcherHandler() if err := plg.RegisterPlugin(test.driverName, draServerInfo.socketName, []string{drapb.DRAPluginService}, pluginClientTimeout); err != nil { @@ -731,7 +839,7 @@ func TestUnprepareResources(t *testing.T) { manager.cache.add(test.claimInfo) } - err = manager.UnprepareResources(tCtx, test.pod) + err = manager.UnprepareResources(backgroundCtx, test.pod) assert.Equal(t, test.expectedUnprepareCalls, draServerInfo.server.unprepareResourceCalls.Load()) @@ -746,17 +854,18 @@ func TestUnprepareResources(t *testing.T) { require.NoError(t, err) if test.wantResourceSkipped { + if test.claimInfo != nil && len(test.claimInfo.PodUIDs) > 1 { + cachedClaim, exists := manager.cache.get(test.claimInfo.ClaimName, test.claimInfo.Namespace) + require.True(t, exists, "ClaimInfo should still exist if skipped") + assert.False(t, cachedClaim.PodUIDs.Has(string(test.pod.UID)), "Pod UID should be removed from skipped claim") + } return // resource skipped so no need to continue } - // Check that the cache has been updated correctly - claimName, _, err := resourceclaim.Name(test.pod, &test.pod.Spec.ResourceClaims[0]) - if err != nil { - t.Fatal(err) - } - if manager.cache.contains(*claimName, test.pod.Namespace) { - t.Fatalf("claimInfo still found in cache after calling UnprepareResources") - } + // Check cache was cleared only on successful unprepare + podClaimName, _, err := resourceclaim.Name(test.pod, &test.pod.Spec.ResourceClaims[0]) + require.NoError(t, err) + assert.False(t, manager.cache.contains(*podClaimName, test.pod.Namespace), "claimInfo should not be found after 
successful unprepare") }) } } @@ -869,7 +978,7 @@ func TestParallelPrepareUnprepareResources(t *testing.T) { tCtx := ktesting.Init(t) // Setup and register fake DRA driver - draServerInfo, err := setupFakeDRADriverGRPCServer(tCtx, false, nil, nil, nil) + draServerInfo, err := setupFakeDRADriverGRPCServer(tCtx, false, nil, nil, nil, nil) if err != nil { t.Fatal(err) } @@ -917,6 +1026,7 @@ func TestParallelPrepareUnprepareResources(t *testing.T) { }, Containers: []v1.Container{ { + Name: fmt.Sprintf("container-%d", goRoutineNum), Resources: v1.ResourceRequirements{ Claims: []v1.ResourceClaim{ { @@ -935,13 +1045,17 @@ func TestParallelPrepareUnprepareResources(t *testing.T) { return } + defer func() { + _ = fakeKubeClient.ResourceV1().ResourceClaims(pod.Namespace).Delete(tCtx, claim.Name, metav1.DeleteOptions{}) + }() + if err = manager.PrepareResources(tCtx, pod); err != nil { - t.Errorf("pod: %s: PrepareResources failed: %+v", pod.Name, err) + t.Errorf("GoRoutine %d: pod: %s: PrepareResources failed: %+v", goRoutineNum, pod.Name, err) return } if err = manager.UnprepareResources(tCtx, pod); err != nil { - t.Errorf("pod: %s: UnprepareResources failed: %+v", pod.Name, err) + t.Errorf("GoRoutine %d: pod: %s: UnprepareResources failed: %+v", goRoutineNum, pod.Name, err) return } @@ -950,3 +1064,361 @@ func TestParallelPrepareUnprepareResources(t *testing.T) { wgStart.Done() // Start executing goroutines wgSync.Wait() // Wait for all goroutines to finish } + +// TestHandleWatchResourcesStream verifies the manager's ability to process health updates +// received from a DRA plugin's WatchResources stream. It checks if the internal health cache +// is updated correctly, if affected pods are identified, and if update notifications are sent +// through the manager's update channel. It covers various scenarios including health changes, stream errors, and context cancellation. 
+func TestHandleWatchResourcesStream(t *testing.T) { + overallTestCtx, overallTestCancel := context.WithCancel(ktesting.Init(t)) + defer overallTestCancel() + + // Helper to create and setup a new manager for each sub-test + setupNewManagerAndRunStreamTest := func( + st *testing.T, + testSpecificCtx context.Context, + initialClaimInfos ...*ClaimInfo, + ) ( + managerInstance *Manager, + runTestStreamFunc func(context.Context, chan struct { + Resp *drahealthv1alpha1.NodeWatchResourcesResponse + Err error + }) (<-chan resourceupdates.Update, chan struct{}, chan error), + ) { + tCtx := ktesting.Init(t) + // Fresh manager for each sub-test + manager, err := NewManager(tCtx.Logger(), nil, st.TempDir()) + require.NoError(st, err) + + for _, ci := range initialClaimInfos { + manager.cache.add(ci) + } + + managerInstance = manager + + runTestStreamFunc = func( + streamCtx context.Context, + responses chan struct { + Resp *drahealthv1alpha1.NodeWatchResourcesResponse + Err error + }, + ) (<-chan resourceupdates.Update, chan struct{}, chan error) { + mockStream := &mockWatchResourcesClient{ + RecvChan: responses, + Ctx: streamCtx, + } + done := make(chan struct{}) + errChan := make(chan error, 1) + go func() { + defer close(done) + // Use a logger that includes sub-test name for clarity + logger := klog.FromContext(streamCtx).WithName(st.Name()) + hdlCtx := klog.NewContext(streamCtx, logger) + + err := managerInstance.HandleWatchResourcesStream(hdlCtx, mockStream, driverName) + if err != nil { + if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) || errors.Is(err, io.EOF) { + logger.V(4).Info("HandleWatchResourcesStream (test goroutine) exited as expected", "error", err) + } else { + // This is an application/stream error, not a standard exit. + // The sub-test ("StreamError") will assert this specific error. + logger.V(2).Info("HandleWatchResourcesStream (test goroutine) exited with application/stream error", "error", err) + } + } else { + logger.V(4).Info("HandleWatchResourcesStream (test goroutine) exited cleanly (nil error, likely from EOF)") + } + errChan <- err + close(errChan) + }() + return managerInstance.update, done, errChan + } + return managerInstance, runTestStreamFunc + } + + // Test Case 1: Health change for an allocated device + t.Run("HealthChangeForAllocatedDevice", func(t *testing.T) { + stCtx, stCancel := context.WithCancel(overallTestCtx) + defer stCancel() + + // Setup: Create a manager with a relevant claim already in its cache. 
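+		// Note: genTestClaimInfo is assumed to allocate a device under driverName in
+		// poolName/deviceName, which is why the UNHEALTHY report sent below matches this
+		// claim and should surface podUID on the manager's update channel.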
+ initialClaim := genTestClaimInfo(claimUID, []string{string(podUID)}, true) + manager, runStreamTest := setupNewManagerAndRunStreamTest(t, stCtx, initialClaim) + + t.Log("HealthChangeForAllocatedDevice: Test Case Started") + + responses := make(chan struct { + Resp *drahealthv1alpha1.NodeWatchResourcesResponse + Err error + }, 1) + updateChan, done, streamErrChan := runStreamTest(stCtx, responses) + + // Send the health update message + unhealthyDeviceMsg := &drahealthv1alpha1.DeviceHealth{ + Device: &drahealthv1alpha1.DeviceIdentifier{ + PoolName: poolName, + DeviceName: deviceName, + }, + Health: drahealthv1alpha1.HealthStatus_UNHEALTHY, + LastUpdatedTime: time.Now().Unix(), + } + t.Logf("HealthChangeForAllocatedDevice: Sending health update: %+v", unhealthyDeviceMsg) + responses <- struct { + Resp *drahealthv1alpha1.NodeWatchResourcesResponse + Err error + }{ + Resp: &drahealthv1alpha1.NodeWatchResourcesResponse{Devices: []*drahealthv1alpha1.DeviceHealth{unhealthyDeviceMsg}}, + } + + t.Log("HealthChangeForAllocatedDevice: Waiting for update on manager channel") + select { + case upd := <-updateChan: + t.Logf("HealthChangeForAllocatedDevice: Received update: %+v", upd) + assert.ElementsMatch(t, []string{string(podUID)}, upd.PodUIDs, "Expected pod UID in update") + case <-time.After(2 * time.Second): + t.Fatal("HealthChangeForAllocatedDevice: Timeout waiting for pod update on manager.update channel") + } + + // Check cache state + cachedHealth := manager.healthInfoCache.getHealthInfo(driverName, poolName, deviceName) + assert.Equal(t, state.DeviceHealthStatus("Unhealthy"), cachedHealth, "Cache update check failed") + + t.Log("HealthChangeForAllocatedDevice: Closing responses channel to signal EOF") + close(responses) + + t.Log("HealthChangeForAllocatedDevice: Waiting on done channel") + var finalErr error + select { + case <-done: + finalErr = <-streamErrChan + t.Log("HealthChangeForAllocatedDevice: done channel closed, stream goroutine finished.") + case <-time.After(1 * time.Second): + t.Fatal("HealthChangeForAllocatedDevice: Timed out waiting for HandleWatchResourcesStream to finish after EOF signal") + } + // Expect nil (if HandleWatchResourcesStream returns nil on EOF) or io.EOF + assert.True(t, finalErr == nil || errors.Is(finalErr, io.EOF), "Expected nil or io.EOF, got %v", finalErr) + }) + + // Test Case 2: Health change for a non-allocated device + t.Run("NonAllocatedDeviceChange", func(t *testing.T) { + stCtx, stCancel := context.WithCancel(overallTestCtx) + defer stCancel() + + // Setup: Manager with no specific claims, or claims that don't use "other-device" + manager, runStreamTest := setupNewManagerAndRunStreamTest(t, stCtx) + + t.Log("NonAllocatedDeviceChange: Test Case Started") + responses := make(chan struct { + Resp *drahealthv1alpha1.NodeWatchResourcesResponse + Err error + }, 1) + updateChan, done, streamErrChan := runStreamTest(stCtx, responses) + + otherDeviceMsg := &drahealthv1alpha1.DeviceHealth{ + Device: &drahealthv1alpha1.DeviceIdentifier{ + PoolName: poolName, + DeviceName: "other-device", + }, + Health: drahealthv1alpha1.HealthStatus_UNHEALTHY, + LastUpdatedTime: time.Now().Unix(), + } + responses <- struct { + Resp *drahealthv1alpha1.NodeWatchResourcesResponse + Err error + }{ + Resp: &drahealthv1alpha1.NodeWatchResourcesResponse{Devices: []*drahealthv1alpha1.DeviceHealth{otherDeviceMsg}}, + } + + select { + case upd := <-updateChan: + t.Fatalf("NonAllocatedDeviceChange: Unexpected update on manager.update channel: %+v", upd) + // OK, no update expected on 
manager.update for this device + case <-time.After(200 * time.Millisecond): + t.Log("NonAllocatedDeviceChange: Correctly received no update on manager channel.") + } + + // Check health cache for the "other-device" + cachedHealthOther := manager.healthInfoCache.getHealthInfo(driverName, poolName, "other-device") + assert.Equal(t, state.DeviceHealthStatus("Unhealthy"), cachedHealthOther, "Cache update for other-device failed") + + close(responses) + var finalErr error + select { + case <-done: + finalErr = <-streamErrChan + t.Log("NonAllocatedDeviceChange: Stream handler goroutine finished.") + case <-time.After(1 * time.Second): + t.Fatal("NonAllocatedDeviceChange: Timeout waiting for stream handler to finish after EOF") + } + assert.True(t, finalErr == nil || errors.Is(finalErr, io.EOF), "Expected nil or io.EOF, got %v", finalErr) + }) + + // Test Case 3: No actual health state change (idempotency) + t.Run("NoActualStateChange", func(t *testing.T) { + stCtx, stCancel := context.WithCancel(overallTestCtx) + defer stCancel() + + // Setup: Manager with a claim and the device already marked Unhealthy in health cache + initialClaim := genTestClaimInfo(claimUID, []string{string(podUID)}, true) + manager, runStreamTest := setupNewManagerAndRunStreamTest(t, stCtx, initialClaim) + + // Pre-populate health cache + initialHealth := state.DeviceHealth{PoolName: poolName, DeviceName: deviceName, Health: "Unhealthy", LastUpdated: time.Now().Add(-5 * time.Millisecond)} // Ensure LastUpdated is slightly in past + _, err := manager.healthInfoCache.updateHealthInfo(driverName, []state.DeviceHealth{initialHealth}) + require.NoError(t, err, "Failed to pre-populate health cache") + + t.Log("NoActualStateChange: Test Case Started") + responses := make(chan struct { + Resp *drahealthv1alpha1.NodeWatchResourcesResponse + Err error + }, 1) + updateChan, done, streamErrChan := runStreamTest(stCtx, responses) + + // Send the same "Unhealthy" state again + unhealthyDeviceMsg := &drahealthv1alpha1.DeviceHealth{ + Device: &drahealthv1alpha1.DeviceIdentifier{ + PoolName: poolName, + DeviceName: deviceName, + }, + Health: drahealthv1alpha1.HealthStatus_UNHEALTHY, + LastUpdatedTime: time.Now().Unix(), + } + responses <- struct { + Resp *drahealthv1alpha1.NodeWatchResourcesResponse + Err error + }{ + Resp: &drahealthv1alpha1.NodeWatchResourcesResponse{Devices: []*drahealthv1alpha1.DeviceHealth{unhealthyDeviceMsg}}, + } + + select { + case upd := <-updateChan: + t.Fatalf("NoActualStateChange: Unexpected update on manager.update channel: %+v", upd) + case <-time.After(200 * time.Millisecond): + t.Log("NoActualStateChange: Correctly received no update on manager channel.") + } + + close(responses) + var finalErr error + select { + case <-done: + finalErr = <-streamErrChan + t.Log("NoActualStateChange: Stream handler goroutine finished.") + case <-time.After(1 * time.Second): + t.Fatal("NoActualStateChange: Timeout waiting for stream handler to finish after EOF") + } + assert.True(t, finalErr == nil || errors.Is(finalErr, io.EOF), "Expected nil or io.EOF, got %v", finalErr) + }) + + // Test Case 4: Stream error + t.Run("StreamError", func(t *testing.T) { + stCtx, stCancel := context.WithCancel(overallTestCtx) + defer stCancel() + + // Get a new manager and the scoped runStreamTest helper + _, runStreamTest := setupNewManagerAndRunStreamTest(t, stCtx) + t.Log("StreamError: Test Case Started") + + responses := make(chan struct { + Resp *drahealthv1alpha1.NodeWatchResourcesResponse + Err error + }, 1) + _, done, streamErrChan := 
runStreamTest(stCtx, responses) + + expectedStreamErr := errors.New("simulated mock stream error") + responses <- struct { + Resp *drahealthv1alpha1.NodeWatchResourcesResponse + Err error + }{Err: expectedStreamErr} + + t.Log("StreamError: Waiting on done channel") + var actualErr error + select { + case <-done: + // Read the error propagated from the HandleWatchResourcesStream goroutine + actualErr = <-streamErrChan + t.Logf("StreamError: done channel closed. Stream handler returned: %v", actualErr) + case <-time.After(2 * time.Second): + t.Fatal("StreamError: Timeout waiting for stream handler to finish after error signal") + } + + require.Error(t, actualErr, "HandleWatchResourcesStream should have returned an error") + assert.ErrorIs(t, actualErr, expectedStreamErr) + }) + + // Test Case 5: Context cancellation + t.Run("ContextCanceled", func(t *testing.T) { + stCtx, stCancel := context.WithCancel(overallTestCtx) + // Deliberately do not `defer stCancel()` for this specific test case + + _, runStreamTest := setupNewManagerAndRunStreamTest(t, stCtx) + t.Log("ContextCanceled: Test Case Started") + + responses := make(chan struct { + Resp *drahealthv1alpha1.NodeWatchResourcesResponse + Err error + }) + _, done, streamErrChan := runStreamTest(stCtx, responses) + + t.Log("ContextCanceled: Intentionally canceling context for stream handler after a short delay.") + time.Sleep(50 * time.Millisecond) + stCancel() + + t.Log("ContextCanceled: Waiting on done channel") + var finalErr error + select { + case <-done: + finalErr = <-streamErrChan + t.Log("ContextCanceled: done channel closed. Stream handler finished after context cancellation.") + case <-time.After(1 * time.Second): + t.Fatal("ContextCanceled: Timeout waiting for stream handler to finish after context cancellation") + } + require.Error(t, finalErr) + assert.True(t, errors.Is(finalErr, context.Canceled) || errors.Is(finalErr, context.DeadlineExceeded)) + }) +} + +// TestUpdateAllocatedResourcesStatus checks if the manager correctly updates the +// PodStatus with the health information of allocated DRA resources. It populates +// the caches with known claim and health data, then calls the function and verifies the resulting PodStatus. 
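//
// For the single healthy device set up here, the container status produced is expected to have
// roughly this shape (a sketch only; claimName and cdiID are package-level test fixtures):
//
//	AllocatedResourcesStatus: []v1.ResourceStatus{{
//		Name: v1.ResourceName("claim:" + claimName),
//		Resources: []v1.ResourceHealth{{
//			ResourceID: v1.ResourceID(cdiID),
//			Health:     v1.ResourceHealthStatusHealthy,
//		}},
//	}}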
+func TestUpdateAllocatedResourcesStatus(t *testing.T) { + tCtx := ktesting.Init(t) + + // Setup Manager with caches + manager, err := NewManager(tCtx.Logger(), nil, t.TempDir()) + require.NoError(t, err) + + // Populate claimInfoCache + claimInfo := genTestClaimInfo(claimUID, []string{podUID}, true) + manager.cache.add(claimInfo) + + // Populate healthInfoCache + healthyDevice := state.DeviceHealth{PoolName: poolName, DeviceName: deviceName, Health: "Healthy", LastUpdated: time.Now()} + _, err = manager.healthInfoCache.updateHealthInfo(driverName, []state.DeviceHealth{healthyDevice}) + require.NoError(t, err) + + // Create Pod and Status objects + pod := genTestPod() + require.NotEmpty(t, pod.Spec.Containers, "genTestPod should create at least one container") + // Ensure the container has a name for matching + pod.Spec.Containers[0].Name = containerName + podStatus := &v1.PodStatus{ + ContainerStatuses: []v1.ContainerStatus{ + {Name: containerName}, + }, + } + + // Call the function under test + manager.UpdateAllocatedResourcesStatus(pod, podStatus) + + require.Len(t, podStatus.ContainerStatuses, 1) + contStatus := podStatus.ContainerStatuses[0] + require.NotNil(t, contStatus.AllocatedResourcesStatus) + require.Len(t, contStatus.AllocatedResourcesStatus, 1, "Should have status for one resource claim") + + resourceStatus := contStatus.AllocatedResourcesStatus[0] + assert.Equal(t, v1.ResourceName("claim:"+claimName), resourceStatus.Name, "ResourceStatus Name mismatch") + // Check the Resources slice + require.Len(t, resourceStatus.Resources, 1, "Should have health info for one device") + resourceHealth := resourceStatus.Resources[0] + assert.Equal(t, v1.ResourceID(cdiID), resourceHealth.ResourceID, "ResourceHealth ResourceID mismatch") + assert.Equal(t, v1.ResourceHealthStatusHealthy, resourceHealth.Health, "ResourceHealth Health status mismatch") +} diff --git a/pkg/kubelet/cm/dra/plugin/dra_plugin.go b/pkg/kubelet/cm/dra/plugin/dra_plugin.go index b1f4a1f659f..1529bdcc3af 100644 --- a/pkg/kubelet/cm/dra/plugin/dra_plugin.go +++ b/pkg/kubelet/cm/dra/plugin/dra_plugin.go @@ -18,13 +18,19 @@ package plugin import ( "context" + "errors" "fmt" + "net" + "sync" "time" "google.golang.org/grpc" + "google.golang.org/grpc/connectivity" + "google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/status" "k8s.io/klog/v2" + drahealthv1alpha1 "k8s.io/kubelet/pkg/apis/dra-health/v1alpha1" drapbv1 "k8s.io/kubelet/pkg/apis/dra/v1" drapbv1beta1 "k8s.io/kubelet/pkg/apis/dra/v1beta1" "k8s.io/kubernetes/pkg/kubelet/metrics" @@ -55,6 +61,72 @@ type DRAPlugin struct { endpoint string chosenService string // e.g. drapbv1.DRAPluginService clientCallTimeout time.Duration + + mutex sync.Mutex + backgroundCtx context.Context + + healthClient drahealthv1alpha1.DRAResourceHealthClient + healthStreamCtx context.Context + healthStreamCancel context.CancelFunc +} + +func (p *DRAPlugin) getOrCreateGRPCConn() (*grpc.ClientConn, error) { + p.mutex.Lock() + defer p.mutex.Unlock() + + // If connection exists and is ready, return it. + if p.conn != nil && p.conn.GetState() != connectivity.Shutdown { + // Initialize health client if connection exists but client is nil + // This allows lazy init if connection was established before health was added. 
+ if p.healthClient == nil { + p.healthClient = drahealthv1alpha1.NewDRAResourceHealthClient(p.conn) + klog.FromContext(p.backgroundCtx).V(4).Info("Initialized DRAResourceHealthClient lazily") + } + return p.conn, nil + } + + // If the connection is dead, clean it up before creating a new one. + if p.conn != nil { + if err := p.conn.Close(); err != nil { + return nil, fmt.Errorf("failed to close stale gRPC connection to %s: %w", p.endpoint, err) + } + p.conn = nil + p.healthClient = nil + } + + ctx := p.backgroundCtx + logger := klog.FromContext(ctx) + + network := "unix" + logger.V(4).Info("Creating new gRPC connection", "protocol", network, "endpoint", p.endpoint) + // grpc.Dial is deprecated. grpc.NewClient should be used instead. + // For now this gets ignored because this function is meant to establish + // the connection, with the one second timeout below. Perhaps that + // approach should be reconsidered? + //nolint:staticcheck + conn, err := grpc.Dial( + p.endpoint, + grpc.WithTransportCredentials(insecure.NewCredentials()), + grpc.WithContextDialer(func(ctx context.Context, target string) (net.Conn, error) { + return (&net.Dialer{}).DialContext(ctx, network, target) + }), + grpc.WithChainUnaryInterceptor(newMetricsInterceptor(p.driverName)), + ) + if err != nil { + return nil, err + } + + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + + if ok := conn.WaitForStateChange(ctx, connectivity.Connecting); !ok { + return nil, errors.New("timed out waiting for gRPC connection to be ready") + } + + p.conn = conn + p.healthClient = drahealthv1alpha1.NewDRAResourceHealthClient(p.conn) + + return p.conn, nil } func (p *DRAPlugin) DriverName() string { @@ -131,3 +203,40 @@ func newMetricsInterceptor(driverName string) grpc.UnaryClientInterceptor { return err } } + +// SetHealthStream stores the context and cancel function for the active health stream. +func (p *DRAPlugin) SetHealthStream(ctx context.Context, cancel context.CancelFunc) { + p.mutex.Lock() + defer p.mutex.Unlock() + p.healthStreamCtx = ctx + p.healthStreamCancel = cancel +} + +// HealthStreamCancel returns the cancel function for the current health stream, if any. +func (p *DRAPlugin) HealthStreamCancel() context.CancelFunc { + p.mutex.Lock() + defer p.mutex.Unlock() + return p.healthStreamCancel +} + +// NodeWatchResources establishes a stream to receive health updates from the DRA plugin. +func (p *DRAPlugin) NodeWatchResources(ctx context.Context) (drahealthv1alpha1.DRAResourceHealth_NodeWatchResourcesClient, error) { + // Ensure a connection and the health client exist before proceeding. + // This call is idempotent and will create them if they don't exist. 
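	// The returned value is a server-streaming client: callers loop on Recv() until it returns
	// an error (io.EOF on a clean close). In the kubelet this loop is driven by
	// HandleWatchResourcesStream, which the DRAPluginManager restarts every 5 seconds via
	// wait.UntilWithContext if the stream ends (see add()).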
+ _, err := p.getOrCreateGRPCConn() + if err != nil { + klog.FromContext(p.backgroundCtx).Error(err, "Failed to get gRPC connection for health client") + return nil, err + } + + logger := klog.FromContext(ctx).WithValues("pluginName", p.driverName) + logger.V(4).Info("Starting WatchResources stream") + stream, err := p.healthClient.NodeWatchResources(ctx, &drahealthv1alpha1.NodeWatchResourcesRequest{}) + if err != nil { + logger.Error(err, "NodeWatchResources RPC call failed") + return nil, err + } + + logger.V(4).Info("NodeWatchResources stream initiated successfully") + return stream, nil +} diff --git a/pkg/kubelet/cm/dra/plugin/dra_plugin_manager.go b/pkg/kubelet/cm/dra/plugin/dra_plugin_manager.go index e4f7a5932a0..2b76d3bd348 100644 --- a/pkg/kubelet/cm/dra/plugin/dra_plugin_manager.go +++ b/pkg/kubelet/cm/dra/plugin/dra_plugin_manager.go @@ -34,9 +34,11 @@ import ( metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/fields" "k8s.io/apimachinery/pkg/util/wait" + utilfeature "k8s.io/apiserver/pkg/util/feature" "k8s.io/client-go/kubernetes" "k8s.io/klog/v2" timedworkers "k8s.io/kubernetes/pkg/controller/tainteviction" // TODO (?): move this common helper somewhere else? + "k8s.io/kubernetes/pkg/features" "k8s.io/kubernetes/pkg/kubelet/pluginmanager/cache" "k8s.io/utils/ptr" ) @@ -58,6 +60,7 @@ type DRAPluginManager struct { kubeClient kubernetes.Interface getNode func() (*v1.Node, error) wipingDelay time.Duration + streamHandler StreamHandler wg sync.WaitGroup mutex sync.RWMutex @@ -134,7 +137,7 @@ func (m *monitoredPlugin) HandleConn(_ context.Context, stats grpcstats.ConnStat // The context can be used to cancel all background activities. // If desired, Stop can be called in addition or instead of canceling // the context. It then also waits for background activities to stop. -func NewDRAPluginManager(ctx context.Context, kubeClient kubernetes.Interface, getNode func() (*v1.Node, error), wipingDelay time.Duration) *DRAPluginManager { +func NewDRAPluginManager(ctx context.Context, kubeClient kubernetes.Interface, getNode func() (*v1.Node, error), streamHandler StreamHandler, wipingDelay time.Duration) *DRAPluginManager { ctx, cancel := context.WithCancelCause(ctx) pm := &DRAPluginManager{ backgroundCtx: klog.NewContext(ctx, klog.LoggerWithName(klog.FromContext(ctx), "DRA registration handler")), @@ -142,6 +145,7 @@ func NewDRAPluginManager(ctx context.Context, kubeClient kubernetes.Interface, g kubeClient: kubeClient, getNode: getNode, wipingDelay: wipingDelay, + streamHandler: streamHandler, } pm.pendingWipes = timedworkers.CreateWorkerQueue(func(ctx context.Context, fireAt time.Time, args *timedworkers.WorkArgs) error { pm.wipeResourceSlices(ctx, args.Object.Name) @@ -238,6 +242,9 @@ func (pm *DRAPluginManager) wipeResourceSlices(ctx context.Context, driver strin // its credentials. logger.V(5).Info("Deleting ResourceSlice failed, retrying", "fieldSelector", fieldSelector, "err", err) return false, nil + case apierrors.IsNotFound(err): + logger.V(5).Info("ResourceSlices not found, nothing to delete.", "fieldSelector", fieldSelector) + return true, nil default: // Log and retry for other errors. 
logger.V(3).Info("Deleting ResourceSlice failed, retrying", "fieldSelector", fieldSelector, "err", err) @@ -332,6 +339,7 @@ func (pm *DRAPluginManager) add(driverName string, endpoint string, chosenServic endpoint: endpoint, chosenService: chosenService, clientCallTimeout: clientCallTimeout, + backgroundCtx: pm.backgroundCtx, } if pm.store == nil { pm.store = make(map[string][]*monitoredPlugin) @@ -364,6 +372,30 @@ func (pm *DRAPluginManager) add(driverName string, endpoint string, chosenServic } p.conn = conn + if utilfeature.DefaultFeatureGate.Enabled(features.ResourceHealthStatus) { + pm.wg.Add(1) + go func() { + defer pm.wg.Done() + streamCtx, streamCancel := context.WithCancel(p.backgroundCtx) + p.SetHealthStream(streamCtx, streamCancel) + + wait.UntilWithContext(streamCtx, func(ctx context.Context) { + logger.V(4).Info("Attempting to start WatchResources health stream") + stream, err := p.NodeWatchResources(ctx) + if err != nil { + logger.V(3).Error(err, "Failed to establish WatchResources stream, will retry") + return + } + + logger.V(2).Info("Successfully started WatchResources health stream") + + err = pm.streamHandler.HandleWatchResourcesStream(ctx, stream, driverName) + logger.V(2).Info("WatchResources health stream has ended", "error", err) + + }, 5*time.Second) + }() + } + // Ensure that gRPC tries to connect even if we don't call any gRPC method. // This is necessary to detect early whether a plugin is really available. // This is currently an experimental gRPC method. Should it be removed we @@ -418,7 +450,14 @@ func (pm *DRAPluginManager) remove(driverName, endpoint string) { pm.store[driverName] = slices.Delete(plugins, i, i+1) } - logger.V(3).Info("Unregistered DRA plugin", "driverName", p.driverName, "endpoint", p.endpoint, "numPlugins", len(pm.store[driverName])) + // Cancel the plugin's health stream if it was active. 
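	// Canceling the stream context also stops the wait.UntilWithContext retry loop that add()
	// started for this plugin, so no new WatchResources stream is opened for this endpoint
	// after deregistration.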
+ healthCancel := p.HealthStreamCancel() + if healthCancel != nil { + logger.V(4).Info("Canceling health stream during deregistration") + healthCancel() + } + + logger.V(3).Info("Unregistered DRA plugin", "driverName", driverName, "endpoint", endpoint, "numPlugins", len(pm.store[driverName])) pm.sync(driverName) } diff --git a/pkg/kubelet/cm/dra/plugin/dra_plugin_manager_test.go b/pkg/kubelet/cm/dra/plugin/dra_plugin_manager_test.go index c104c9a88fb..207c479e4a9 100644 --- a/pkg/kubelet/cm/dra/plugin/dra_plugin_manager_test.go +++ b/pkg/kubelet/cm/dra/plugin/dra_plugin_manager_test.go @@ -31,7 +31,7 @@ func TestAddSameName(t *testing.T) { driverName := fmt.Sprintf("dummy-driver-%d", rand.IntN(10000)) // ensure the plugin we are using is registered - draPlugins := NewDRAPluginManager(tCtx, nil, nil, 0) + draPlugins := NewDRAPluginManager(tCtx, nil, nil, nil, 0) tCtx.ExpectNoError(draPlugins.add(driverName, "old.sock", "", defaultClientCallTimeout), "add first plugin") p, err := draPlugins.GetPlugin(driverName) tCtx.ExpectNoError(err, "get first plugin") @@ -60,9 +60,14 @@ func TestAddSameName(t *testing.T) { func TestDelete(t *testing.T) { tCtx := ktesting.Init(t) driverName := fmt.Sprintf("dummy-driver-%d", rand.IntN(10000)) + socketFile := "dra.sock" // ensure the plugin we are using is registered - draPlugins := NewDRAPluginManager(tCtx, nil, nil, 0) + draPlugins := NewDRAPluginManager(tCtx, nil, nil, &mockStreamHandler{}, 0) tCtx.ExpectNoError(draPlugins.add(driverName, "dra.sock", "", defaultClientCallTimeout), "add plugin") - draPlugins.remove(driverName, "") + + draPlugins.remove(driverName, socketFile) + + _, err := draPlugins.GetPlugin(driverName) + require.Error(t, err, "plugin should not exist after being removed") } diff --git a/pkg/kubelet/cm/dra/plugin/dra_plugin_test.go b/pkg/kubelet/cm/dra/plugin/dra_plugin_test.go index f76ae74a3f5..b3af8c75b32 100644 --- a/pkg/kubelet/cm/dra/plugin/dra_plugin_test.go +++ b/pkg/kubelet/cm/dra/plugin/dra_plugin_test.go @@ -19,23 +19,27 @@ package plugin import ( "context" "errors" - "fmt" + "io" "net" "path" "strings" "sync" "testing" + "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "google.golang.org/grpc" + drahealthv1alpha1 "k8s.io/kubelet/pkg/apis/dra-health/v1alpha1" drapbv1 "k8s.io/kubelet/pkg/apis/dra/v1" drapbv1beta1 "k8s.io/kubelet/pkg/apis/dra/v1beta1" "k8s.io/kubernetes/test/utils/ktesting" ) type fakeGRPCServer struct { + drapbv1beta1.UnimplementedDRAPluginServer + drahealthv1alpha1.UnimplementedDRAResourceHealthServer } var _ drapbv1.DRAPluginServer = &fakeGRPCServer{} @@ -56,6 +60,24 @@ func (f *fakeGRPCServer) NodeUnprepareResources(ctx context.Context, in *drapbv1 return &drapbv1.NodeUnprepareResourcesResponse{}, nil } +func (f *fakeGRPCServer) NodeWatchResources(in *drahealthv1alpha1.NodeWatchResourcesRequest, srv drahealthv1alpha1.DRAResourceHealth_NodeWatchResourcesServer) error { + resp := &drahealthv1alpha1.NodeWatchResourcesResponse{ + Devices: []*drahealthv1alpha1.DeviceHealth{ + { + Device: &drahealthv1alpha1.DeviceIdentifier{ + PoolName: "pool1", + DeviceName: "dev1", + }, + Health: drahealthv1alpha1.HealthStatus_HEALTHY, + }, + }, + } + if err := srv.Send(resp); err != nil { + return err + } + return nil +} + // tearDown is an idempotent cleanup function. 
type tearDown func() @@ -73,13 +95,22 @@ func setupFakeGRPCServer(service, addr string) (tearDown, error) { s := grpc.NewServer() fakeGRPCServer := &fakeGRPCServer{} + switch service { case drapbv1.DRAPluginService: drapbv1.RegisterDRAPluginServer(s, fakeGRPCServer) case drapbv1beta1.DRAPluginService: drapbv1beta1.RegisterDRAPluginServer(s, drapbv1beta1.V1ServerWrapper{DRAPluginServer: fakeGRPCServer}) + case drahealthv1alpha1.DRAResourceHealth_ServiceDesc.ServiceName: + drahealthv1alpha1.RegisterDRAResourceHealthServer(s, fakeGRPCServer) default: - return nil, fmt.Errorf("unsupported gRPC service: %s", service) + if service == "" { + drapbv1.RegisterDRAPluginServer(s, fakeGRPCServer) + drapbv1beta1.RegisterDRAPluginServer(s, drapbv1beta1.V1ServerWrapper{DRAPluginServer: fakeGRPCServer}) + drahealthv1alpha1.RegisterDRAResourceHealthServer(s, fakeGRPCServer) + } else { + return nil, err + } } go func() { @@ -100,9 +131,7 @@ func TestGRPCConnIsReused(t *testing.T) { service := drapbv1.DRAPluginService addr := path.Join(t.TempDir(), "dra.sock") teardown, err := setupFakeGRPCServer(service, addr) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) defer teardown() reusedConns := make(map[*grpc.ClientConn]int) @@ -112,7 +141,7 @@ func TestGRPCConnIsReused(t *testing.T) { driverName := "dummy-driver" // ensure the plugin we are using is registered - draPlugins := NewDRAPluginManager(tCtx, nil, nil, 0) + draPlugins := NewDRAPluginManager(tCtx, nil, nil, &mockStreamHandler{}, 0) tCtx.ExpectNoError(draPlugins.add(driverName, addr, service, defaultClientCallTimeout), "add plugin") plugin, err := draPlugins.GetPlugin(driverName) tCtx.ExpectNoError(err, "get plugin") @@ -152,12 +181,8 @@ func TestGRPCConnIsReused(t *testing.T) { wg.Wait() // We should have only one entry otherwise it means another gRPC connection has been created - if len(reusedConns) != 1 { - t.Errorf("expected length to be 1 but got %d", len(reusedConns)) - } - if counter, ok := reusedConns[conn]; ok && counter != 2 { - t.Errorf("expected counter to be 2 but got %d", counter) - } + require.Len(t, reusedConns, 1, "expected length to be 1 but got %d", len(reusedConns)) + require.Equal(t, 2, reusedConns[conn], "expected counter to be 2 but got %d", reusedConns[conn]) } func TestGetDRAPlugin(t *testing.T) { @@ -186,7 +211,7 @@ func TestGetDRAPlugin(t *testing.T) { } { t.Run(test.description, func(t *testing.T) { tCtx := ktesting.Init(t) - draPlugins := NewDRAPluginManager(tCtx, nil, nil, 0) + draPlugins := NewDRAPluginManager(tCtx, nil, nil, &mockStreamHandler{}, 0) if test.setup != nil { require.NoError(t, test.setup(draPlugins), "setup plugin") } @@ -244,8 +269,9 @@ func TestGRPCMethods(t *testing.T) { defer teardown() driverName := "dummy-driver" - draPlugins := NewDRAPluginManager(tCtx, nil, nil, 0) + draPlugins := NewDRAPluginManager(tCtx, nil, nil, &mockStreamHandler{}, 0) tCtx.ExpectNoError(draPlugins.add(driverName, addr, test.chosenService, defaultClientCallTimeout)) + plugin, err := draPlugins.GetPlugin(driverName) if err != nil { t.Fatal(err) @@ -271,3 +297,41 @@ func assertError(t *testing.T, expectError string, err error) { t.Errorf("Expected error %q, got: %v", expectError, err) } } + +func TestPlugin_WatchResources(t *testing.T) { + tCtx := ktesting.Init(t) + ctx, cancel := context.WithCancel(tCtx) + defer cancel() + + driverName := "test-driver" + addr := path.Join(t.TempDir(), "dra.sock") + + teardown, err := setupFakeGRPCServer("", addr) + require.NoError(t, err) + defer teardown() + + draPlugins := 
NewDRAPluginManager(tCtx, nil, nil, &mockStreamHandler{}, 0) + err = draPlugins.add(driverName, addr, drapbv1beta1.DRAPluginService, 5*time.Second) + require.NoError(t, err) + defer draPlugins.remove(driverName, addr) + + p, err := draPlugins.GetPlugin(driverName) + require.NoError(t, err) + + stream, err := p.NodeWatchResources(ctx) + require.NoError(t, err) + require.NotNil(t, stream) + + // 1. Receive the first message that our fake server sends. + resp, err := stream.Recv() + require.NoError(t, err, "The first Recv() should succeed with the message from the server") + require.NotNil(t, resp) + require.Len(t, resp.Devices, 1) + assert.Equal(t, "pool1", resp.Devices[0].GetDevice().GetPoolName()) + assert.Equal(t, drahealthv1alpha1.HealthStatus_HEALTHY, resp.Devices[0].GetHealth()) + + // 2. The second receive should fail with io.EOF because the server + // closed the stream by returning nil. This confirms the stream ended cleanly. + _, err = stream.Recv() + require.ErrorIs(t, err, io.EOF, "The second Recv() should return an io.EOF error to signal a clean stream closure") +} diff --git a/pkg/kubelet/cm/dra/plugin/registration_test.go b/pkg/kubelet/cm/dra/plugin/registration_test.go index 70a33c7a276..acf674162ef 100644 --- a/pkg/kubelet/cm/dra/plugin/registration_test.go +++ b/pkg/kubelet/cm/dra/plugin/registration_test.go @@ -151,7 +151,7 @@ func TestRegistrationHandler(t *testing.T) { description: "two-services", driverName: pluginB, socketFile: socketFileB, - supportedServices: []string{drapb.DRAPluginService /* TODO: add v1 here once we have it */}, + supportedServices: []string{drapb.DRAPluginService, "v1alpha1.NodeHealth" /* TODO: add v1 here once we have it */}, chosenService: drapb.DRAPluginService, }, // TODO: use v1beta1 here once we have v1 @@ -221,7 +221,7 @@ func TestRegistrationHandler(t *testing.T) { } // The DRAPluginManager wipes all slices at startup. - draPlugins := NewDRAPluginManager(tCtx, client, getFakeNode, time.Second /* very short wiping delay for testing */) + draPlugins := NewDRAPluginManager(tCtx, client, getFakeNode, &mockStreamHandler{}, time.Second /* very short wiping delay for testing */) tCtx.Cleanup(draPlugins.Stop) if test.withClient { requireNoSlices(tCtx) @@ -309,7 +309,7 @@ func TestConnectionHandling(t *testing.T) { tCtx = ktesting.WithClients(tCtx, nil, nil, client, nil, nil) // The handler wipes all slices at startup. - draPlugins := NewDRAPluginManager(tCtx, client, getFakeNode, test.delay) + draPlugins := NewDRAPluginManager(tCtx, client, getFakeNode, &mockStreamHandler{}, test.delay) tCtx.Cleanup(draPlugins.Stop) requireNoSlices(tCtx) diff --git a/pkg/kubelet/cm/dra/plugin/testing_helpers_test.go b/pkg/kubelet/cm/dra/plugin/testing_helpers_test.go new file mode 100644 index 00000000000..63d5b1f4c98 --- /dev/null +++ b/pkg/kubelet/cm/dra/plugin/testing_helpers_test.go @@ -0,0 +1,34 @@ +/* +Copyright 2025 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +package plugin + +import ( + "context" + + drahealthv1alpha1 "k8s.io/kubelet/pkg/apis/dra-health/v1alpha1" +) + +// mockStreamHandler is a mock implementation of the StreamHandler interface, +// shared across all tests in this package. +type mockStreamHandler struct{} + +var _ StreamHandler = &mockStreamHandler{} + +func (m *mockStreamHandler) HandleWatchResourcesStream(ctx context.Context, stream drahealthv1alpha1.DRAResourceHealth_NodeWatchResourcesClient, resourceName string) error { + <-ctx.Done() + return nil +} diff --git a/pkg/kubelet/cm/dra/plugin/types.go b/pkg/kubelet/cm/dra/plugin/types.go new file mode 100644 index 00000000000..8d8acabd0c2 --- /dev/null +++ b/pkg/kubelet/cm/dra/plugin/types.go @@ -0,0 +1,31 @@ +/* +Copyright 2025 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package plugin + +import ( + "context" + + drahealthv1alpha1 "k8s.io/kubelet/pkg/apis/dra-health/v1alpha1" +) + +// StreamHandler defines the interface for handling DRA health streams. +// This interface is implemented by the DRA Manager to decouple the plugin +// package from the manager package, breaking the import cycle. +type StreamHandler interface { + // HandleWatchResourcesStream processes health updates from a specific DRA plugin stream. + HandleWatchResourcesStream(ctx context.Context, stream drahealthv1alpha1.DRAResourceHealth_NodeWatchResourcesClient, resourceName string) error +} diff --git a/pkg/kubelet/cm/dra/state/state.go b/pkg/kubelet/cm/dra/state/state.go index 045a9b6b2f4..3f15d3c38e9 100644 --- a/pkg/kubelet/cm/dra/state/state.go +++ b/pkg/kubelet/cm/dra/state/state.go @@ -17,6 +17,8 @@ limitations under the License. package state import ( + "time" + "k8s.io/apimachinery/pkg/types" "k8s.io/apimachinery/pkg/util/sets" ) @@ -57,3 +59,40 @@ type Device struct { RequestNames []string CDIDeviceIDs []string } + +// DevicesHealthMap is a map between driver names and the list of the device's health. +type DevicesHealthMap map[string]DriverHealthState + +// DriverHealthState is used to store health information of all devices of a driver. +type DriverHealthState struct { + // Devices maps a device's unique key ("/") to its health state. + Devices map[string]DeviceHealth +} + +type DeviceHealthStatus string + +const ( + // DeviceHealthStatusHealthy represents a healthy device. + DeviceHealthStatusHealthy DeviceHealthStatus = "Healthy" + // DeviceHealthStatusUnhealthy represents an unhealthy device. + DeviceHealthStatusUnhealthy DeviceHealthStatus = "Unhealthy" + // DeviceHealthStatusUnknown represents a device with unknown health status. + DeviceHealthStatusUnknown DeviceHealthStatus = "Unknown" +) + +// DeviceHealth is used to store health information of a device. +type DeviceHealth struct { + // PoolName is the name of the pool where the device is allocated. + PoolName string + + // DeviceName is the name of the device. + // The full identifier is '//' across the system. + DeviceName string + + // Health is the health status of the device. 
+ // Statuses: "Healthy", "Unhealthy", "Unknown". + Health DeviceHealthStatus + + // LastUpdated keeps track of the last health status update of this device. + LastUpdated time.Time +} diff --git a/pkg/kubelet/kubelet.go b/pkg/kubelet/kubelet.go index 6c6c230db40..b507f679519 100644 --- a/pkg/kubelet/kubelet.go +++ b/pkg/kubelet/kubelet.go @@ -2522,7 +2522,6 @@ func (kl *Kubelet) syncLoopIteration(ctx context.Context, configCh <-chan kubety } kl.sourcesReady.AddSource(u.Source) - case e := <-plegCh: if isSyncPodWorthy(e) { // PLEG event for a pod; sync it. @@ -2586,7 +2585,6 @@ func (kl *Kubelet) syncLoopIteration(ctx context.Context, configCh <-chan kubety // We do not apply the optimization by updating the status directly, but can do it later handler.HandlePodSyncs(pods) } - case <-housekeepingCh: if !kl.sourcesReady.AllReady() { // If the sources aren't ready or volume manager has not yet synced the states, diff --git a/staging/src/k8s.io/dynamic-resource-allocation/kubeletplugin/draplugin.go b/staging/src/k8s.io/dynamic-resource-allocation/kubeletplugin/draplugin.go index b273c82d0c1..33f31aaa834 100644 --- a/staging/src/k8s.io/dynamic-resource-allocation/kubeletplugin/draplugin.go +++ b/staging/src/k8s.io/dynamic-resource-allocation/kubeletplugin/draplugin.go @@ -37,6 +37,7 @@ import ( draclient "k8s.io/dynamic-resource-allocation/client" "k8s.io/dynamic-resource-allocation/resourceclaim" "k8s.io/dynamic-resource-allocation/resourceslice" + drahealthv1alpha1 "k8s.io/kubelet/pkg/apis/dra-health/v1alpha1" drapbv1 "k8s.io/kubelet/pkg/apis/dra/v1" drapbv1beta1 "k8s.io/kubelet/pkg/apis/dra/v1beta1" registerapi "k8s.io/kubelet/pkg/apis/pluginregistration/v1" @@ -487,6 +488,7 @@ type options struct { nodeV1 bool registrationService bool draService bool + healthService *bool } // Helper combines the kubelet registration service and the DRA node plugin @@ -618,6 +620,12 @@ func Start(ctx context.Context, plugin DRAPlugin, opts ...Option) (result *Helpe if o.nodeV1beta1 { supportedServices = append(supportedServices, drapbv1beta1.DRAPluginService) } + // Check if the plugin implements the DRAResourceHealth service. + if _, ok := plugin.(drahealthv1alpha1.DRAResourceHealthServer); ok { + // If it does, add it to the list of services this plugin supports. 
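	// A driver opts in by implementing drahealthv1alpha1.DRAResourceHealthServer (which
	// requires embedding drahealthv1alpha1.UnimplementedDRAResourceHealthServer) on the same
	// object it passes in as the DRAPlugin; the service is then advertised here and registered
	// further below unless the healthService option is explicitly set to false. A minimal
	// sketch of such a method; the healthUpdates channel is a hypothetical driver detail:
	//
	//	func (d *myDriver) NodeWatchResources(_ *drahealthv1alpha1.NodeWatchResourcesRequest,
	//		srv drahealthv1alpha1.DRAResourceHealth_NodeWatchResourcesServer) error {
	//		for {
	//			select {
	//			case <-srv.Context().Done():
	//				return nil
	//			case devices := <-d.healthUpdates: // hypothetical: driver-provided []*drahealthv1alpha1.DeviceHealth
	//				if err := srv.Send(&drahealthv1alpha1.NodeWatchResourcesResponse{Devices: devices}); err != nil {
	//					return err
	//				}
	//			}
	//		}
	//	}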
+ logger.V(5).Info("detected v1alpha1.DRAResourceHealth gRPC service") + supportedServices = append(supportedServices, drahealthv1alpha1.DRAResourceHealth_ServiceDesc.ServiceName) + } if len(supportedServices) == 0 { return nil, errors.New("no supported DRA gRPC API is implemented and enabled") } @@ -635,7 +643,7 @@ func Start(ctx context.Context, plugin DRAPlugin, opts ...Option) (result *Helpe o.unaryInterceptors, o.streamInterceptors, draEndpoint, - func(ctx context.Context, err error) { + func(ctx context.Context, err error) { // This error handler is REQUIRED plugin.HandleError(ctx, err, "DRA gRPC server failed") }, func(grpcServer *grpc.Server) { @@ -643,10 +651,18 @@ func Start(ctx context.Context, plugin DRAPlugin, opts ...Option) (result *Helpe logger.V(5).Info("registering v1.DRAPlugin gRPC service") drapbv1.RegisterDRAPluginServer(grpcServer, &nodePluginImplementation{Helper: d}) } + if o.nodeV1beta1 { logger.V(5).Info("registering v1beta1.DRAPlugin gRPC service") drapbv1beta1.RegisterDRAPluginServer(grpcServer, drapbv1beta1.V1ServerWrapper{DRAPluginServer: &nodePluginImplementation{Helper: d}}) } + + if heatlhServer, ok := d.plugin.(drahealthv1alpha1.DRAResourceHealthServer); ok { + if o.healthService == nil || *o.healthService { + logger.V(5).Info("registering v1alpha1.DRAResourceHealth gRPC service") + drahealthv1alpha1.RegisterDRAResourceHealthServer(grpcServer, heatlhServer) + } + } }, ) if err != nil { diff --git a/staging/src/k8s.io/kubelet/pkg/apis/dra-health/v1alpha1/api.pb.go b/staging/src/k8s.io/kubelet/pkg/apis/dra-health/v1alpha1/api.pb.go new file mode 100644 index 00000000000..dc78d8a0ade --- /dev/null +++ b/staging/src/k8s.io/kubelet/pkg/apis/dra-health/v1alpha1/api.pb.go @@ -0,0 +1,414 @@ +/* +Copyright The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +// Copyright 2025 The Kubernetes Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by protoc-gen-go. DO NOT EDIT. +// versions: +// protoc-gen-go v1.36.4 +// protoc v4.23.4 +// source: staging/src/k8s.io/kubelet/pkg/apis/dra-health/v1alpha1/api.proto + +package v1alpha1 + +import ( + protoreflect "google.golang.org/protobuf/reflect/protoreflect" + protoimpl "google.golang.org/protobuf/runtime/protoimpl" + reflect "reflect" + sync "sync" + unsafe "unsafe" +) + +const ( + // Verify that this generated code is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) + // Verify that runtime/protoimpl is sufficiently up-to-date. 
+ _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) +) + +// HealthStatus defines the possible health states of a device. +type HealthStatus int32 + +const ( + // UNKNOWN indicates that the health of the device cannot be determined. + HealthStatus_UNKNOWN HealthStatus = 0 + // HEALTHY indicates that the device is operating normally. + HealthStatus_HEALTHY HealthStatus = 1 + // UNHEALTHY indicates that the device has reported a problem. + HealthStatus_UNHEALTHY HealthStatus = 2 +) + +// Enum value maps for HealthStatus. +var ( + HealthStatus_name = map[int32]string{ + 0: "UNKNOWN", + 1: "HEALTHY", + 2: "UNHEALTHY", + } + HealthStatus_value = map[string]int32{ + "UNKNOWN": 0, + "HEALTHY": 1, + "UNHEALTHY": 2, + } +) + +func (x HealthStatus) Enum() *HealthStatus { + p := new(HealthStatus) + *p = x + return p +} + +func (x HealthStatus) String() string { + return protoimpl.X.EnumStringOf(x.Descriptor(), protoreflect.EnumNumber(x)) +} + +func (HealthStatus) Descriptor() protoreflect.EnumDescriptor { + return file_staging_src_k8s_io_kubelet_pkg_apis_dra_health_v1alpha1_api_proto_enumTypes[0].Descriptor() +} + +func (HealthStatus) Type() protoreflect.EnumType { + return &file_staging_src_k8s_io_kubelet_pkg_apis_dra_health_v1alpha1_api_proto_enumTypes[0] +} + +func (x HealthStatus) Number() protoreflect.EnumNumber { + return protoreflect.EnumNumber(x) +} + +// Deprecated: Use HealthStatus.Descriptor instead. +func (HealthStatus) EnumDescriptor() ([]byte, []int) { + return file_staging_src_k8s_io_kubelet_pkg_apis_dra_health_v1alpha1_api_proto_rawDescGZIP(), []int{0} +} + +type NodeWatchResourcesRequest struct { + state protoimpl.MessageState `protogen:"open.v1"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *NodeWatchResourcesRequest) Reset() { + *x = NodeWatchResourcesRequest{} + mi := &file_staging_src_k8s_io_kubelet_pkg_apis_dra_health_v1alpha1_api_proto_msgTypes[0] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *NodeWatchResourcesRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*NodeWatchResourcesRequest) ProtoMessage() {} + +func (x *NodeWatchResourcesRequest) ProtoReflect() protoreflect.Message { + mi := &file_staging_src_k8s_io_kubelet_pkg_apis_dra_health_v1alpha1_api_proto_msgTypes[0] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use NodeWatchResourcesRequest.ProtoReflect.Descriptor instead. +func (*NodeWatchResourcesRequest) Descriptor() ([]byte, []int) { + return file_staging_src_k8s_io_kubelet_pkg_apis_dra_health_v1alpha1_api_proto_rawDescGZIP(), []int{0} +} + +// DeviceIdentifier uniquely identifies a device within the scope of a driver. +type DeviceIdentifier struct { + state protoimpl.MessageState `protogen:"open.v1"` + // The pool which contains the device. + PoolName string `protobuf:"bytes,1,opt,name=pool_name,json=poolName,proto3" json:"pool_name,omitempty"` + // The unique name of the device within the pool. 
+ DeviceName string `protobuf:"bytes,2,opt,name=device_name,json=deviceName,proto3" json:"device_name,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *DeviceIdentifier) Reset() { + *x = DeviceIdentifier{} + mi := &file_staging_src_k8s_io_kubelet_pkg_apis_dra_health_v1alpha1_api_proto_msgTypes[1] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *DeviceIdentifier) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*DeviceIdentifier) ProtoMessage() {} + +func (x *DeviceIdentifier) ProtoReflect() protoreflect.Message { + mi := &file_staging_src_k8s_io_kubelet_pkg_apis_dra_health_v1alpha1_api_proto_msgTypes[1] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use DeviceIdentifier.ProtoReflect.Descriptor instead. +func (*DeviceIdentifier) Descriptor() ([]byte, []int) { + return file_staging_src_k8s_io_kubelet_pkg_apis_dra_health_v1alpha1_api_proto_rawDescGZIP(), []int{1} +} + +func (x *DeviceIdentifier) GetPoolName() string { + if x != nil { + return x.PoolName + } + return "" +} + +func (x *DeviceIdentifier) GetDeviceName() string { + if x != nil { + return x.DeviceName + } + return "" +} + +// DeviceHealth represents the health of a single device. +type DeviceHealth struct { + state protoimpl.MessageState `protogen:"open.v1"` + // The identifier for the device. + Device *DeviceIdentifier `protobuf:"bytes,1,opt,name=device,proto3" json:"device,omitempty"` + // The health status of the device. + Health HealthStatus `protobuf:"varint,2,opt,name=health,proto3,enum=v1alpha1.HealthStatus" json:"health,omitempty"` + // The Unix time (in seconds) of when this health status was last determined by the plugin. + LastUpdatedTime int64 `protobuf:"varint,3,opt,name=last_updated_time,json=lastUpdatedTime,proto3" json:"last_updated_time,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *DeviceHealth) Reset() { + *x = DeviceHealth{} + mi := &file_staging_src_k8s_io_kubelet_pkg_apis_dra_health_v1alpha1_api_proto_msgTypes[2] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *DeviceHealth) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*DeviceHealth) ProtoMessage() {} + +func (x *DeviceHealth) ProtoReflect() protoreflect.Message { + mi := &file_staging_src_k8s_io_kubelet_pkg_apis_dra_health_v1alpha1_api_proto_msgTypes[2] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use DeviceHealth.ProtoReflect.Descriptor instead. +func (*DeviceHealth) Descriptor() ([]byte, []int) { + return file_staging_src_k8s_io_kubelet_pkg_apis_dra_health_v1alpha1_api_proto_rawDescGZIP(), []int{2} +} + +func (x *DeviceHealth) GetDevice() *DeviceIdentifier { + if x != nil { + return x.Device + } + return nil +} + +func (x *DeviceHealth) GetHealth() HealthStatus { + if x != nil { + return x.Health + } + return HealthStatus_UNKNOWN +} + +func (x *DeviceHealth) GetLastUpdatedTime() int64 { + if x != nil { + return x.LastUpdatedTime + } + return 0 +} + +// NodeWatchResourcesResponse contains a list of devices and their current health. 
+// This should be a complete list for the driver; Kubelet will reconcile this +// state with its internal cache. Any devices managed by the driver that are +// not in this list will be considered to have an "Unknown" health status after a timeout. +type NodeWatchResourcesResponse struct { + state protoimpl.MessageState `protogen:"open.v1"` + Devices []*DeviceHealth `protobuf:"bytes,1,rep,name=devices,proto3" json:"devices,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *NodeWatchResourcesResponse) Reset() { + *x = NodeWatchResourcesResponse{} + mi := &file_staging_src_k8s_io_kubelet_pkg_apis_dra_health_v1alpha1_api_proto_msgTypes[3] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *NodeWatchResourcesResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*NodeWatchResourcesResponse) ProtoMessage() {} + +func (x *NodeWatchResourcesResponse) ProtoReflect() protoreflect.Message { + mi := &file_staging_src_k8s_io_kubelet_pkg_apis_dra_health_v1alpha1_api_proto_msgTypes[3] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use NodeWatchResourcesResponse.ProtoReflect.Descriptor instead. +func (*NodeWatchResourcesResponse) Descriptor() ([]byte, []int) { + return file_staging_src_k8s_io_kubelet_pkg_apis_dra_health_v1alpha1_api_proto_rawDescGZIP(), []int{3} +} + +func (x *NodeWatchResourcesResponse) GetDevices() []*DeviceHealth { + if x != nil { + return x.Devices + } + return nil +} + +var File_staging_src_k8s_io_kubelet_pkg_apis_dra_health_v1alpha1_api_proto protoreflect.FileDescriptor + +var file_staging_src_k8s_io_kubelet_pkg_apis_dra_health_v1alpha1_api_proto_rawDesc = string([]byte{ + 0x0a, 0x41, 0x73, 0x74, 0x61, 0x67, 0x69, 0x6e, 0x67, 0x2f, 0x73, 0x72, 0x63, 0x2f, 0x6b, 0x38, + 0x73, 0x2e, 0x69, 0x6f, 0x2f, 0x6b, 0x75, 0x62, 0x65, 0x6c, 0x65, 0x74, 0x2f, 0x70, 0x6b, 0x67, + 0x2f, 0x61, 0x70, 0x69, 0x73, 0x2f, 0x64, 0x72, 0x61, 0x2d, 0x68, 0x65, 0x61, 0x6c, 0x74, 0x68, + 0x2f, 0x76, 0x31, 0x61, 0x6c, 0x70, 0x68, 0x61, 0x31, 0x2f, 0x61, 0x70, 0x69, 0x2e, 0x70, 0x72, + 0x6f, 0x74, 0x6f, 0x12, 0x08, 0x76, 0x31, 0x61, 0x6c, 0x70, 0x68, 0x61, 0x31, 0x22, 0x1b, 0x0a, + 0x19, 0x4e, 0x6f, 0x64, 0x65, 0x57, 0x61, 0x74, 0x63, 0x68, 0x52, 0x65, 0x73, 0x6f, 0x75, 0x72, + 0x63, 0x65, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x22, 0x50, 0x0a, 0x10, 0x44, 0x65, + 0x76, 0x69, 0x63, 0x65, 0x49, 0x64, 0x65, 0x6e, 0x74, 0x69, 0x66, 0x69, 0x65, 0x72, 0x12, 0x1b, + 0x0a, 0x09, 0x70, 0x6f, 0x6f, 0x6c, 0x5f, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, + 0x09, 0x52, 0x08, 0x70, 0x6f, 0x6f, 0x6c, 0x4e, 0x61, 0x6d, 0x65, 0x12, 0x1f, 0x0a, 0x0b, 0x64, + 0x65, 0x76, 0x69, 0x63, 0x65, 0x5f, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, + 0x52, 0x0a, 0x64, 0x65, 0x76, 0x69, 0x63, 0x65, 0x4e, 0x61, 0x6d, 0x65, 0x22, 0x9e, 0x01, 0x0a, + 0x0c, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x48, 0x65, 0x61, 0x6c, 0x74, 0x68, 0x12, 0x32, 0x0a, + 0x06, 0x64, 0x65, 0x76, 0x69, 0x63, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, + 0x76, 0x31, 0x61, 0x6c, 0x70, 0x68, 0x61, 0x31, 0x2e, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x49, + 0x64, 0x65, 0x6e, 0x74, 0x69, 0x66, 0x69, 0x65, 0x72, 0x52, 0x06, 0x64, 0x65, 0x76, 0x69, 0x63, + 0x65, 0x12, 0x2e, 0x0a, 0x06, 0x68, 0x65, 0x61, 0x6c, 0x74, 0x68, 0x18, 0x02, 0x20, 0x01, 0x28, + 0x0e, 
0x32, 0x16, 0x2e, 0x76, 0x31, 0x61, 0x6c, 0x70, 0x68, 0x61, 0x31, 0x2e, 0x48, 0x65, 0x61, + 0x6c, 0x74, 0x68, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x52, 0x06, 0x68, 0x65, 0x61, 0x6c, 0x74, + 0x68, 0x12, 0x2a, 0x0a, 0x11, 0x6c, 0x61, 0x73, 0x74, 0x5f, 0x75, 0x70, 0x64, 0x61, 0x74, 0x65, + 0x64, 0x5f, 0x74, 0x69, 0x6d, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x03, 0x52, 0x0f, 0x6c, 0x61, + 0x73, 0x74, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x64, 0x54, 0x69, 0x6d, 0x65, 0x22, 0x4e, 0x0a, + 0x1a, 0x4e, 0x6f, 0x64, 0x65, 0x57, 0x61, 0x74, 0x63, 0x68, 0x52, 0x65, 0x73, 0x6f, 0x75, 0x72, + 0x63, 0x65, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x30, 0x0a, 0x07, 0x64, + 0x65, 0x76, 0x69, 0x63, 0x65, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x76, + 0x31, 0x61, 0x6c, 0x70, 0x68, 0x61, 0x31, 0x2e, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x48, 0x65, + 0x61, 0x6c, 0x74, 0x68, 0x52, 0x07, 0x64, 0x65, 0x76, 0x69, 0x63, 0x65, 0x73, 0x2a, 0x37, 0x0a, + 0x0c, 0x48, 0x65, 0x61, 0x6c, 0x74, 0x68, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x12, 0x0b, 0x0a, + 0x07, 0x55, 0x4e, 0x4b, 0x4e, 0x4f, 0x57, 0x4e, 0x10, 0x00, 0x12, 0x0b, 0x0a, 0x07, 0x48, 0x45, + 0x41, 0x4c, 0x54, 0x48, 0x59, 0x10, 0x01, 0x12, 0x0d, 0x0a, 0x09, 0x55, 0x4e, 0x48, 0x45, 0x41, + 0x4c, 0x54, 0x48, 0x59, 0x10, 0x02, 0x32, 0x78, 0x0a, 0x11, 0x44, 0x52, 0x41, 0x52, 0x65, 0x73, + 0x6f, 0x75, 0x72, 0x63, 0x65, 0x48, 0x65, 0x61, 0x6c, 0x74, 0x68, 0x12, 0x63, 0x0a, 0x12, 0x4e, + 0x6f, 0x64, 0x65, 0x57, 0x61, 0x74, 0x63, 0x68, 0x52, 0x65, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, + 0x73, 0x12, 0x23, 0x2e, 0x76, 0x31, 0x61, 0x6c, 0x70, 0x68, 0x61, 0x31, 0x2e, 0x4e, 0x6f, 0x64, + 0x65, 0x57, 0x61, 0x74, 0x63, 0x68, 0x52, 0x65, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x73, 0x52, + 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x24, 0x2e, 0x76, 0x31, 0x61, 0x6c, 0x70, 0x68, 0x61, + 0x31, 0x2e, 0x4e, 0x6f, 0x64, 0x65, 0x57, 0x61, 0x74, 0x63, 0x68, 0x52, 0x65, 0x73, 0x6f, 0x75, + 0x72, 0x63, 0x65, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x30, 0x01, + 0x42, 0x2d, 0x5a, 0x2b, 0x6b, 0x38, 0x73, 0x2e, 0x69, 0x6f, 0x2f, 0x6b, 0x75, 0x62, 0x65, 0x6c, + 0x65, 0x74, 0x2f, 0x70, 0x6b, 0x67, 0x2f, 0x61, 0x70, 0x69, 0x73, 0x2f, 0x64, 0x72, 0x61, 0x2d, + 0x68, 0x65, 0x61, 0x6c, 0x74, 0x68, 0x2f, 0x76, 0x31, 0x61, 0x6c, 0x70, 0x68, 0x61, 0x31, 0x62, + 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, +}) + +var ( + file_staging_src_k8s_io_kubelet_pkg_apis_dra_health_v1alpha1_api_proto_rawDescOnce sync.Once + file_staging_src_k8s_io_kubelet_pkg_apis_dra_health_v1alpha1_api_proto_rawDescData []byte +) + +func file_staging_src_k8s_io_kubelet_pkg_apis_dra_health_v1alpha1_api_proto_rawDescGZIP() []byte { + file_staging_src_k8s_io_kubelet_pkg_apis_dra_health_v1alpha1_api_proto_rawDescOnce.Do(func() { + file_staging_src_k8s_io_kubelet_pkg_apis_dra_health_v1alpha1_api_proto_rawDescData = protoimpl.X.CompressGZIP(unsafe.Slice(unsafe.StringData(file_staging_src_k8s_io_kubelet_pkg_apis_dra_health_v1alpha1_api_proto_rawDesc), len(file_staging_src_k8s_io_kubelet_pkg_apis_dra_health_v1alpha1_api_proto_rawDesc))) + }) + return file_staging_src_k8s_io_kubelet_pkg_apis_dra_health_v1alpha1_api_proto_rawDescData +} + +var file_staging_src_k8s_io_kubelet_pkg_apis_dra_health_v1alpha1_api_proto_enumTypes = make([]protoimpl.EnumInfo, 1) +var file_staging_src_k8s_io_kubelet_pkg_apis_dra_health_v1alpha1_api_proto_msgTypes = make([]protoimpl.MessageInfo, 4) +var file_staging_src_k8s_io_kubelet_pkg_apis_dra_health_v1alpha1_api_proto_goTypes = []any{ + (HealthStatus)(0), 
// 0: v1alpha1.HealthStatus + (*NodeWatchResourcesRequest)(nil), // 1: v1alpha1.NodeWatchResourcesRequest + (*DeviceIdentifier)(nil), // 2: v1alpha1.DeviceIdentifier + (*DeviceHealth)(nil), // 3: v1alpha1.DeviceHealth + (*NodeWatchResourcesResponse)(nil), // 4: v1alpha1.NodeWatchResourcesResponse +} +var file_staging_src_k8s_io_kubelet_pkg_apis_dra_health_v1alpha1_api_proto_depIdxs = []int32{ + 2, // 0: v1alpha1.DeviceHealth.device:type_name -> v1alpha1.DeviceIdentifier + 0, // 1: v1alpha1.DeviceHealth.health:type_name -> v1alpha1.HealthStatus + 3, // 2: v1alpha1.NodeWatchResourcesResponse.devices:type_name -> v1alpha1.DeviceHealth + 1, // 3: v1alpha1.DRAResourceHealth.NodeWatchResources:input_type -> v1alpha1.NodeWatchResourcesRequest + 4, // 4: v1alpha1.DRAResourceHealth.NodeWatchResources:output_type -> v1alpha1.NodeWatchResourcesResponse + 4, // [4:5] is the sub-list for method output_type + 3, // [3:4] is the sub-list for method input_type + 3, // [3:3] is the sub-list for extension type_name + 3, // [3:3] is the sub-list for extension extendee + 0, // [0:3] is the sub-list for field type_name +} + +func init() { file_staging_src_k8s_io_kubelet_pkg_apis_dra_health_v1alpha1_api_proto_init() } +func file_staging_src_k8s_io_kubelet_pkg_apis_dra_health_v1alpha1_api_proto_init() { + if File_staging_src_k8s_io_kubelet_pkg_apis_dra_health_v1alpha1_api_proto != nil { + return + } + type x struct{} + out := protoimpl.TypeBuilder{ + File: protoimpl.DescBuilder{ + GoPackagePath: reflect.TypeOf(x{}).PkgPath(), + RawDescriptor: unsafe.Slice(unsafe.StringData(file_staging_src_k8s_io_kubelet_pkg_apis_dra_health_v1alpha1_api_proto_rawDesc), len(file_staging_src_k8s_io_kubelet_pkg_apis_dra_health_v1alpha1_api_proto_rawDesc)), + NumEnums: 1, + NumMessages: 4, + NumExtensions: 0, + NumServices: 1, + }, + GoTypes: file_staging_src_k8s_io_kubelet_pkg_apis_dra_health_v1alpha1_api_proto_goTypes, + DependencyIndexes: file_staging_src_k8s_io_kubelet_pkg_apis_dra_health_v1alpha1_api_proto_depIdxs, + EnumInfos: file_staging_src_k8s_io_kubelet_pkg_apis_dra_health_v1alpha1_api_proto_enumTypes, + MessageInfos: file_staging_src_k8s_io_kubelet_pkg_apis_dra_health_v1alpha1_api_proto_msgTypes, + }.Build() + File_staging_src_k8s_io_kubelet_pkg_apis_dra_health_v1alpha1_api_proto = out.File + file_staging_src_k8s_io_kubelet_pkg_apis_dra_health_v1alpha1_api_proto_goTypes = nil + file_staging_src_k8s_io_kubelet_pkg_apis_dra_health_v1alpha1_api_proto_depIdxs = nil +} diff --git a/staging/src/k8s.io/kubelet/pkg/apis/dra-health/v1alpha1/api.proto b/staging/src/k8s.io/kubelet/pkg/apis/dra-health/v1alpha1/api.proto new file mode 100644 index 00000000000..08d83d89ea1 --- /dev/null +++ b/staging/src/k8s.io/kubelet/pkg/apis/dra-health/v1alpha1/api.proto @@ -0,0 +1,67 @@ +// Copyright 2025 The Kubernetes Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +syntax = "proto3"; +package v1alpha1; +option go_package = "k8s.io/kubelet/pkg/apis/dra-health/v1alpha1"; + + +// DRAResourceHealth service is implemented by DRA plugins and called by Kubelet. +service DRAResourceHealth { + // NodeWatchResources allows a DRA plugin to stream health updates for its devices to Kubelet. + rpc NodeWatchResources(NodeWatchResourcesRequest) returns (stream NodeWatchResourcesResponse) {} +} + +message NodeWatchResourcesRequest { + // Reserved for future use. +} + +// HealthStatus defines the possible health states of a device. +enum HealthStatus { + // UNKNOWN indicates that the health of the device cannot be determined. + UNKNOWN = 0; + // HEALTHY indicates that the device is operating normally. + HEALTHY = 1; + // UNHEALTHY indicates that the device has reported a problem. + UNHEALTHY = 2; +} + +// DeviceIdentifier uniquely identifies a device within the scope of a driver. +message DeviceIdentifier { + // The pool which contains the device. + string pool_name = 1; + + // The unique name of the device within the pool. + string device_name = 2; +} + +// DeviceHealth represents the health of a single device. +message DeviceHealth { + // The identifier for the device. + DeviceIdentifier device = 1; + + // The health status of the device. + HealthStatus health = 2; + + // The Unix time (in seconds) of when this health status was last determined by the plugin. + int64 last_updated_time = 3; +} + +// NodeWatchResourcesResponse contains a list of devices and their current health. +// This should be a complete list for the driver; Kubelet will reconcile this +// state with its internal cache. Any devices managed by the driver that are +// not in this list will be considered to have an "Unknown" health status after a timeout. +message NodeWatchResourcesResponse { + repeated DeviceHealth devices = 1; +} \ No newline at end of file diff --git a/staging/src/k8s.io/kubelet/pkg/apis/dra-health/v1alpha1/api_grpc.pb.go b/staging/src/k8s.io/kubelet/pkg/apis/dra-health/v1alpha1/api_grpc.pb.go new file mode 100644 index 00000000000..bc38e04203c --- /dev/null +++ b/staging/src/k8s.io/kubelet/pkg/apis/dra-health/v1alpha1/api_grpc.pb.go @@ -0,0 +1,160 @@ +/* +Copyright The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +// Copyright 2025 The Kubernetes Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by protoc-gen-go-grpc. DO NOT EDIT. 
+// versions: +// - protoc-gen-go-grpc v1.5.1 +// - protoc v4.23.4 +// source: staging/src/k8s.io/kubelet/pkg/apis/dra-health/v1alpha1/api.proto + +package v1alpha1 + +import ( + context "context" + grpc "google.golang.org/grpc" + codes "google.golang.org/grpc/codes" + status "google.golang.org/grpc/status" +) + +// This is a compile-time assertion to ensure that this generated file +// is compatible with the grpc package it is being compiled against. +// Requires gRPC-Go v1.64.0 or later. +const _ = grpc.SupportPackageIsVersion9 + +const ( + DRAResourceHealth_NodeWatchResources_FullMethodName = "/v1alpha1.DRAResourceHealth/NodeWatchResources" +) + +// DRAResourceHealthClient is the client API for DRAResourceHealth service. +// +// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream. +// +// DRAResourceHealth service is implemented by DRA plugins and called by Kubelet. +type DRAResourceHealthClient interface { + // NodeWatchResources allows a DRA plugin to stream health updates for its devices to Kubelet. + NodeWatchResources(ctx context.Context, in *NodeWatchResourcesRequest, opts ...grpc.CallOption) (grpc.ServerStreamingClient[NodeWatchResourcesResponse], error) +} + +type dRAResourceHealthClient struct { + cc grpc.ClientConnInterface +} + +func NewDRAResourceHealthClient(cc grpc.ClientConnInterface) DRAResourceHealthClient { + return &dRAResourceHealthClient{cc} +} + +func (c *dRAResourceHealthClient) NodeWatchResources(ctx context.Context, in *NodeWatchResourcesRequest, opts ...grpc.CallOption) (grpc.ServerStreamingClient[NodeWatchResourcesResponse], error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) + stream, err := c.cc.NewStream(ctx, &DRAResourceHealth_ServiceDesc.Streams[0], DRAResourceHealth_NodeWatchResources_FullMethodName, cOpts...) + if err != nil { + return nil, err + } + x := &grpc.GenericClientStream[NodeWatchResourcesRequest, NodeWatchResourcesResponse]{ClientStream: stream} + if err := x.ClientStream.SendMsg(in); err != nil { + return nil, err + } + if err := x.ClientStream.CloseSend(); err != nil { + return nil, err + } + return x, nil +} + +// This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name. +type DRAResourceHealth_NodeWatchResourcesClient = grpc.ServerStreamingClient[NodeWatchResourcesResponse] + +// DRAResourceHealthServer is the server API for DRAResourceHealth service. +// All implementations must embed UnimplementedDRAResourceHealthServer +// for forward compatibility. +// +// DRAResourceHealth service is implemented by DRA plugins and called by Kubelet. +type DRAResourceHealthServer interface { + // NodeWatchResources allows a DRA plugin to stream health updates for its devices to Kubelet. + NodeWatchResources(*NodeWatchResourcesRequest, grpc.ServerStreamingServer[NodeWatchResourcesResponse]) error + mustEmbedUnimplementedDRAResourceHealthServer() +} + +// UnimplementedDRAResourceHealthServer must be embedded to have +// forward compatible implementations. +// +// NOTE: this should be embedded by value instead of pointer to avoid a nil +// pointer dereference when methods are called. 
+type UnimplementedDRAResourceHealthServer struct{} + +func (UnimplementedDRAResourceHealthServer) NodeWatchResources(*NodeWatchResourcesRequest, grpc.ServerStreamingServer[NodeWatchResourcesResponse]) error { + return status.Errorf(codes.Unimplemented, "method NodeWatchResources not implemented") +} +func (UnimplementedDRAResourceHealthServer) mustEmbedUnimplementedDRAResourceHealthServer() {} +func (UnimplementedDRAResourceHealthServer) testEmbeddedByValue() {} + +// UnsafeDRAResourceHealthServer may be embedded to opt out of forward compatibility for this service. +// Use of this interface is not recommended, as added methods to DRAResourceHealthServer will +// result in compilation errors. +type UnsafeDRAResourceHealthServer interface { + mustEmbedUnimplementedDRAResourceHealthServer() +} + +func RegisterDRAResourceHealthServer(s grpc.ServiceRegistrar, srv DRAResourceHealthServer) { + // If the following call pancis, it indicates UnimplementedDRAResourceHealthServer was + // embedded by pointer and is nil. This will cause panics if an + // unimplemented method is ever invoked, so we test this at initialization + // time to prevent it from happening at runtime later due to I/O. + if t, ok := srv.(interface{ testEmbeddedByValue() }); ok { + t.testEmbeddedByValue() + } + s.RegisterService(&DRAResourceHealth_ServiceDesc, srv) +} + +func _DRAResourceHealth_NodeWatchResources_Handler(srv interface{}, stream grpc.ServerStream) error { + m := new(NodeWatchResourcesRequest) + if err := stream.RecvMsg(m); err != nil { + return err + } + return srv.(DRAResourceHealthServer).NodeWatchResources(m, &grpc.GenericServerStream[NodeWatchResourcesRequest, NodeWatchResourcesResponse]{ServerStream: stream}) +} + +// This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name. +type DRAResourceHealth_NodeWatchResourcesServer = grpc.ServerStreamingServer[NodeWatchResourcesResponse] + +// DRAResourceHealth_ServiceDesc is the grpc.ServiceDesc for DRAResourceHealth service. 
+// It's only intended for direct use with grpc.RegisterService, +// and not to be introspected or modified (even as a copy) +var DRAResourceHealth_ServiceDesc = grpc.ServiceDesc{ + ServiceName: "v1alpha1.DRAResourceHealth", + HandlerType: (*DRAResourceHealthServer)(nil), + Methods: []grpc.MethodDesc{}, + Streams: []grpc.StreamDesc{ + { + StreamName: "NodeWatchResources", + Handler: _DRAResourceHealth_NodeWatchResources_Handler, + ServerStreams: true, + }, + }, + Metadata: "staging/src/k8s.io/kubelet/pkg/apis/dra-health/v1alpha1/api.proto", +} diff --git a/test/e2e/dra/test-driver/app/kubeletplugin.go b/test/e2e/dra/test-driver/app/kubeletplugin.go index 9327596916d..59eb71397ab 100644 --- a/test/e2e/dra/test-driver/app/kubeletplugin.go +++ b/test/e2e/dra/test-driver/app/kubeletplugin.go @@ -27,6 +27,7 @@ import ( "sort" "strings" "sync" + "time" "google.golang.org/grpc" @@ -42,9 +43,21 @@ import ( "k8s.io/dynamic-resource-allocation/resourceclaim" "k8s.io/dynamic-resource-allocation/resourceslice" "k8s.io/klog/v2" + drahealthv1alpha1 "k8s.io/kubelet/pkg/apis/dra-health/v1alpha1" ) +type Options struct { + EnableHealthService bool +} + +type DeviceHealthUpdate struct { + PoolName string + DeviceName string + Health string +} + type ExamplePlugin struct { + drahealthv1alpha1.UnimplementedDRAResourceHealthServer stopCh <-chan struct{} logger klog.Logger resourceClient cgoresource.ResourceV1Interface @@ -62,6 +75,10 @@ type ExamplePlugin struct { prepared map[ClaimID][]kubeletplugin.Device // prepared claims -> result of nodePrepareResource gRPCCalls []GRPCCall + healthMutex sync.Mutex + deviceHealth map[string]string + HealthControlChan chan DeviceHealthUpdate + blockPrepareResourcesMutex sync.Mutex blockUnprepareResourcesMutex sync.Mutex @@ -76,6 +93,12 @@ type ExamplePlugin struct { cancelMainContext context.CancelCauseFunc } +var _ kubeletplugin.DRAPlugin = &ExamplePlugin{} +var _ drahealthv1alpha1.DRAResourceHealthServer = &ExamplePlugin{} + +//nolint:unused +func (ex *ExamplePlugin) mustEmbedUnimplementedDRAResourceHealthServer() {} + type GRPCCall struct { // FullMethod is the fully qualified, e.g. /package.service/method. FullMethod string @@ -156,8 +179,11 @@ func StartPlugin(ctx context.Context, cdiDir, driverName string, kubeClient kube } testOpts := &options{} + pluginOpts := &Options{} for _, opt := range opts { switch typedOpt := opt.(type) { + case Options: + *pluginOpts = typedOpt case TestOption: if err := typedOpt(testOpts); err != nil { return nil, fmt.Errorf("apply test option: %w", err) @@ -179,6 +205,8 @@ func StartPlugin(ctx context.Context, cdiDir, driverName string, kubeClient kube nodeName: nodeName, prepared: make(map[ClaimID][]kubeletplugin.Device), cancelMainContext: testOpts.cancelMainContext, + deviceHealth: make(map[string]string), + HealthControlChan: make(chan DeviceHealthUpdate, 10), } publicOpts = append(publicOpts, @@ -500,23 +528,24 @@ func (ex *ExamplePlugin) recordGRPCCall(ctx context.Context, req interface{}, in return call.Response, call.Err } -func (ex *ExamplePlugin) recordGRPCStream(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error { - call := GRPCCall{ - FullMethod: info.FullMethod, - } +func (ex *ExamplePlugin) recordGRPCStream(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) (err error) { ex.mutex.Lock() - ex.gRPCCalls = append(ex.gRPCCalls, call) - index := len(ex.gRPCCalls) - 1 + // Append a new empty GRPCCall struct to get its index. 
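+	// Taking a pointer to that entry lets the deferred function below record the
+	// handler's error once the stream handler returns, without holding the mutex
+	// while the (potentially long-lived) stream is being served.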
+ ex.gRPCCalls = append(ex.gRPCCalls, GRPCCall{}) + + pCall := &ex.gRPCCalls[len(ex.gRPCCalls)-1] + + pCall.FullMethod = info.FullMethod ex.mutex.Unlock() - // We don't hold the mutex here to allow concurrent calls. - call.Err = handler(srv, stream) + defer func() { + ex.mutex.Lock() + defer ex.mutex.Unlock() + pCall.Err = err + }() - ex.mutex.Lock() - ex.gRPCCalls[index] = call - ex.mutex.Unlock() - - return call.Err + err = handler(srv, stream) + return err } func (ex *ExamplePlugin) GetGRPCCalls() []GRPCCall { @@ -564,3 +593,94 @@ func (ex *ExamplePlugin) UpdateStatus(ctx context.Context, resourceClaim *resour func (ex *ExamplePlugin) SetGetInfoError(err error) { ex.d.SetGetInfoError(err) } + +func (ex *ExamplePlugin) NodeWatchResources(req *drahealthv1alpha1.NodeWatchResourcesRequest, srv drahealthv1alpha1.DRAResourceHealth_NodeWatchResourcesServer) error { + logger := klog.FromContext(srv.Context()) + logger.V(3).Info("Starting dynamic NodeWatchResources stream") + + ticker := time.NewTicker(2 * time.Second) + defer ticker.Stop() + + // Send an initial update immediately to report on pre-configured devices. + if err := ex.sendHealthUpdate(srv); err != nil { + logger.Error(err, "Failed to send initial health update") + } + + for { + select { + case <-srv.Context().Done(): + logger.V(3).Info("NodeWatchResources stream canceled by kubelet") + return nil + case update, ok := <-ex.HealthControlChan: + if !ok { + logger.V(3).Info("HealthControlChan closed, exiting NodeWatchResources stream.") + return nil + } + logger.V(3).Info("Received health update from control channel", "update", update) + ex.healthMutex.Lock() + key := update.PoolName + "/" + update.DeviceName + ex.deviceHealth[key] = update.Health + ex.healthMutex.Unlock() + + if err := ex.sendHealthUpdate(srv); err != nil { + logger.Error(err, "Failed to send health update after control message") + } + case <-ticker.C: + if err := ex.sendHealthUpdate(srv); err != nil { + if srv.Context().Err() != nil { + logger.V(3).Info("NodeWatchResources stream closed during periodic update, exiting.") + return nil + } + logger.Error(err, "Failed to send periodic health update") + } + } + } +} + +// sendHealthUpdate dynamically builds the health report from the current state of the deviceHealth map. +func (ex *ExamplePlugin) sendHealthUpdate(srv drahealthv1alpha1.DRAResourceHealth_NodeWatchResourcesServer) error { + logger := klog.FromContext(srv.Context()) + healthUpdates := []*drahealthv1alpha1.DeviceHealth{} + + ex.healthMutex.Lock() + for key, health := range ex.deviceHealth { + parts := strings.SplitN(key, "/", 2) + if len(parts) != 2 { + continue + } + poolName := parts[0] + deviceName := parts[1] + + var healthEnum drahealthv1alpha1.HealthStatus + switch health { + case "Healthy": + healthEnum = drahealthv1alpha1.HealthStatus_HEALTHY + case "Unhealthy": + healthEnum = drahealthv1alpha1.HealthStatus_UNHEALTHY + default: + healthEnum = drahealthv1alpha1.HealthStatus_UNKNOWN + } + + healthUpdates = append(healthUpdates, &drahealthv1alpha1.DeviceHealth{ + Device: &drahealthv1alpha1.DeviceIdentifier{ + PoolName: poolName, + DeviceName: deviceName, + }, + Health: healthEnum, + LastUpdatedTime: time.Now().Unix(), + }) + } + ex.healthMutex.Unlock() + + // Sorting slice to ensure consistent ordering in tests. 
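+	// Go map iteration order is randomized, so sort by pool and then device name
+	// to make the reported device list deterministic for test assertions.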
+ sort.Slice(healthUpdates, func(i, j int) bool { + if healthUpdates[i].GetDevice().GetPoolName() != healthUpdates[j].GetDevice().GetPoolName() { + return healthUpdates[i].GetDevice().GetPoolName() < healthUpdates[j].GetDevice().GetPoolName() + } + return healthUpdates[i].GetDevice().GetDeviceName() < healthUpdates[j].GetDevice().GetDeviceName() + }) + + resp := &drahealthv1alpha1.NodeWatchResourcesResponse{Devices: healthUpdates} + logger.V(5).Info("Test driver sending health update", "response", resp) + return srv.Send(resp) +} diff --git a/test/e2e/dra/test-driver/app/server.go b/test/e2e/dra/test-driver/app/server.go index 62c4bbcfaff..44b4483df86 100644 --- a/test/e2e/dra/test-driver/app/server.go +++ b/test/e2e/dra/test-driver/app/server.go @@ -223,6 +223,7 @@ func NewCommand() *cobra.Command { } plugin, err := StartPlugin(cmd.Context(), *cdiDir, *driverName, clientset, *nodeName, FileOperations{DriverResources: &driverResources}, + Options{EnableHealthService: true}, kubeletplugin.PluginDataDirectoryPath(datadir), kubeletplugin.RegistrarDirectoryPath(*kubeletRegistryDir), ) diff --git a/test/e2e_node/dra_test.go b/test/e2e_node/dra_test.go index f10f2af1849..ce47b9daca7 100644 --- a/test/e2e_node/dra_test.go +++ b/test/e2e_node/dra_test.go @@ -18,8 +18,8 @@ limitations under the License. E2E Node test for DRA (Dynamic Resource Allocation) This test covers node-specific aspects of DRA The test can be run locally on Linux this way: - make test-e2e-node FOCUS='\[Feature:DynamicResourceAllocation\]' SKIP='\[Flaky\]' PARALLELISM=1 \ - TEST_ARGS='--feature-gates="DynamicResourceAllocation=true" --service-feature-gates="DynamicResourceAllocation=true" --runtime-config=api/all=true' + make test-e2e-node FOCUS='\[Feature:DynamicResourceAllocation\]' SKIP='\[Flaky\]' PARALLELISM=1 \ + TEST_ARGS='--feature-gates="DynamicResourceAllocation=true,ResourceHealthStatus=true" --service-feature-gates="DynamicResourceAllocation=true,ResourceHealthStatus=true" --runtime-config=api/all=true' */ package e2enode @@ -819,10 +819,169 @@ var _ = framework.SIGDescribe("node")(framework.WithLabel("DRA"), feature.Dynami ), ) }) + + f.Context("Resource Health", framework.WithFeatureGate(features.ResourceHealthStatus), f.WithSerial(), func() { + + // Verifies that device health transitions (Healthy -> Unhealthy -> Healthy) + // reported by a DRA plugin are correctly reflected in the Pod's status. 
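+		// The test driver pushes updates on HealthControlChan, streams them to the
+		// Kubelet over NodeWatchResources, and the Kubelet exposes the result in the
+		// container's allocatedResourcesStatus.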
+ ginkgo.It("should reflect device health changes in the Pod's status", func(ctx context.Context) { + ginkgo.By("Starting the test driver with channel-based control") + kubeletPlugin := newKubeletPlugin(ctx, f.ClientSet, getNodeName(ctx, f), driverName) + + className := "health-test-class" + claimName := "health-test-claim" + podName := "health-test-pod" + poolNameForTest := "pool-a" + deviceNameForTest := "dev-0" + + pod := createHealthTestPodAndClaim(ctx, f, driverName, podName, claimName, className, poolNameForTest, deviceNameForTest) + + ginkgo.By("Waiting for the pod to be running") + framework.ExpectNoError(e2epod.WaitForPodRunningInNamespace(ctx, f.ClientSet, pod)) + + ginkgo.By("Forcing a 'Healthy' status update to establish a baseline") + kubeletPlugin.HealthControlChan <- testdriver.DeviceHealthUpdate{ + PoolName: poolNameForTest, + DeviceName: deviceNameForTest, + Health: "Healthy", + } + + ginkgo.By("Verifying device health is now Healthy in the pod status") + gomega.Eventually(ctx, func(ctx context.Context) (string, error) { + return getDeviceHealthFromAPIServer(f, pod.Namespace, pod.Name, driverName, claimName, poolNameForTest, deviceNameForTest) + }).WithTimeout(30*time.Second).WithPolling(1*time.Second).Should(gomega.Equal("Healthy"), "Device health should be Healthy after explicit update") + + ginkgo.By("Setting device health to Unhealthy via control channel") + kubeletPlugin.HealthControlChan <- testdriver.DeviceHealthUpdate{ + PoolName: poolNameForTest, + DeviceName: deviceNameForTest, + Health: "Unhealthy", + } + + ginkgo.By("Verifying device health is now Unhealthy") + gomega.Eventually(ctx, func(ctx context.Context) (string, error) { + return getDeviceHealthFromAPIServer(f, pod.Namespace, pod.Name, driverName, claimName, poolNameForTest, deviceNameForTest) + }).WithTimeout(60*time.Second).WithPolling(2*time.Second).Should(gomega.Equal("Unhealthy"), "Device health should update to Unhealthy") + + ginkgo.By("Setting device health back to Healthy via control channel") + kubeletPlugin.HealthControlChan <- testdriver.DeviceHealthUpdate{ + PoolName: poolNameForTest, + DeviceName: deviceNameForTest, + Health: "Healthy", + } + + ginkgo.By("Verifying device health has recovered to Healthy") + gomega.Eventually(ctx, func(ctx context.Context) (string, error) { + return getDeviceHealthFromAPIServer(f, pod.Namespace, pod.Name, driverName, claimName, poolNameForTest, deviceNameForTest) + }).WithTimeout(60*time.Second).WithPolling(2*time.Second).Should(gomega.Equal("Healthy"), "Device health should recover and update to Healthy") + }) + + // Verifies that device health transitions to "Unknown" when a DRA plugin + // stops and recovers to "Healthy" upon plugin restart. 
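+		// After the stream drops, the Kubelet is expected to report the device as
+		// Unknown; a restarted plugin re-registers and streams fresh health data.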
+ ginkgo.It("should update health to Unknown when plugin stops and recover upon restart", func(ctx context.Context) { + ginkgo.By("Starting the test driver") + kubeletPlugin := newKubeletPlugin(ctx, f.ClientSet, getNodeName(ctx, f), driverName) + + className := "unknown-test-class" + claimName := "unknown-test-claim" + podName := "unknown-test-pod" + poolNameForTest := "pool-b" + deviceNameForTest := "dev-1" + + pod := createHealthTestPodAndClaim(ctx, f, driverName, podName, claimName, className, poolNameForTest, deviceNameForTest) + + ginkgo.By("Waiting for the pod to be running") + framework.ExpectNoError(e2epod.WaitForPodRunningInNamespace(ctx, f.ClientSet, pod)) + + ginkgo.By("Establishing a baseline 'Healthy' status") + kubeletPlugin.HealthControlChan <- testdriver.DeviceHealthUpdate{ + PoolName: poolNameForTest, + DeviceName: deviceNameForTest, + Health: "Healthy", + } + gomega.Eventually(ctx, func(ctx context.Context) (string, error) { + return getDeviceHealthFromAPIServer(f, pod.Namespace, pod.Name, driverName, claimName, poolNameForTest, deviceNameForTest) + }).WithTimeout(30*time.Second).WithPolling(1*time.Second).Should(gomega.Equal("Healthy"), "Device health should be Healthy initially") + + ginkgo.By("Stopping the DRA plugin to simulate a crash") + kubeletPlugin.Stop() + + ginkgo.By("Verifying device health transitions to 'Unknown'") + gomega.Eventually(ctx, func(ctx context.Context) (string, error) { + return getDeviceHealthFromAPIServer(f, pod.Namespace, pod.Name, driverName, claimName, poolNameForTest, deviceNameForTest) + }).WithTimeout(2*time.Minute).WithPolling(5*time.Second).Should(gomega.Equal("Unknown"), "Device health should become Unknown after plugin stops") + + ginkgo.By("Restarting the DRA plugin to simulate recovery") + // Re-initialize the plugin, which will re-register with the Kubelet. + kubeletPlugin = newKubeletPlugin(ctx, f.ClientSet, getNodeName(ctx, f), driverName) + + ginkgo.By("Forcing a 'Healthy' status update after restart") + kubeletPlugin.HealthControlChan <- testdriver.DeviceHealthUpdate{ + PoolName: poolNameForTest, + DeviceName: deviceNameForTest, + Health: "Healthy", + } + + ginkgo.By("Verifying device health recovers to 'Healthy'") + gomega.Eventually(ctx, func(ctx context.Context) (string, error) { + return getDeviceHealthFromAPIServer(f, pod.Namespace, pod.Name, driverName, claimName, poolNameForTest, deviceNameForTest) + }).WithTimeout(60*time.Second).WithPolling(2*time.Second).Should(gomega.Equal("Healthy"), "Device health should recover to Healthy after plugin restarts") + }) + + }) + + f.Context("Resource Health with Feature Gate Disabled", framework.WithLabel("[FeatureGate:ResourceHealthStatus:Disabled]"), f.WithSerial(), func() { + + // Verifies that the Kubelet adds no health status to the Pod when the + // ResourceHealthStatus feature gate is disabled. 
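+		// The driver is started without its health service and the feature gate is
+		// off, so the Kubelet must not populate allocatedResourcesStatus for the
+		// container.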
+ ginkgo.It("should not add health status to Pod when feature gate is disabled", func(ctx context.Context) { + + ginkgo.By("Starting a test driver") + newKubeletPlugin(ctx, f.ClientSet, getNodeName(ctx, f), driverName, withHealthService(false)) + + className := "gate-disabled-class" + claimName := "gate-disabled-claim" + podName := "gate-disabled-pod" + poolNameForTest := "pool-d" + deviceNameForTest := "dev-3" + + pod := createHealthTestPodAndClaim(ctx, f, driverName, podName, claimName, className, poolNameForTest, deviceNameForTest) + + ginkgo.By("Waiting for the pod to be running") + framework.ExpectNoError(e2epod.WaitForPodRunningInNamespace(ctx, f.ClientSet, pod)) + + ginkgo.By("Consistently verifying that the allocatedResourcesStatus field remains absent") + gomega.Consistently(func(ctx context.Context) error { + p, err := f.ClientSet.CoreV1().Pods(pod.Namespace).Get(ctx, pod.Name, metav1.GetOptions{}) + if err != nil { + return err + } + for _, containerStatus := range p.Status.ContainerStatuses { + if containerStatus.Name == "testcontainer" { + if len(containerStatus.AllocatedResourcesStatus) != 0 { + return fmt.Errorf("expected allocatedResourcesStatus to be absent, but found %d entries", len(containerStatus.AllocatedResourcesStatus)) + } + return nil + } + } + return fmt.Errorf("could not find container 'testcontainer' in pod status") + }).WithContext(ctx).WithTimeout(30*time.Second).WithPolling(2*time.Second).Should(gomega.Succeed(), "The allocatedResourcesStatus field should be absent when the feature gate is disabled") + }) + }) }) +// pluginOption defines a functional option for configuring the test driver. +type pluginOption func(*testdriver.Options) + +// withHealthService is a pluginOption to explicitly enable or disable the health service. +func withHealthService(enabled bool) pluginOption { + return func(o *testdriver.Options) { + o.EnableHealthService = enabled + } +} + // Run Kubelet plugin and wait until it's registered -func newKubeletPlugin(ctx context.Context, clientSet kubernetes.Interface, nodeName, driverName string) *testdriver.ExamplePlugin { +func newKubeletPlugin(ctx context.Context, clientSet kubernetes.Interface, nodeName, driverName string, options ...pluginOption) *testdriver.ExamplePlugin { ginkgo.By("start Kubelet plugin") logger := klog.LoggerWithValues(klog.LoggerWithName(klog.Background(), "DRA kubelet plugin "+driverName), "node", nodeName) ctx = klog.NewContext(ctx, logger) @@ -836,6 +995,13 @@ func newKubeletPlugin(ctx context.Context, clientSet kubernetes.Interface, nodeN err = os.MkdirAll(datadir, 0750) framework.ExpectNoError(err, "create DRA socket directory") + pluginOpts := testdriver.Options{ + EnableHealthService: true, + } + for _, option := range options { + option(&pluginOpts) + } + plugin, err := testdriver.StartPlugin( ctx, cdiDir, @@ -857,6 +1023,7 @@ func newKubeletPlugin(ctx context.Context, clientSet kubernetes.Interface, nodeN }, }, }, + pluginOpts, ) framework.ExpectNoError(err) @@ -877,7 +1044,14 @@ func newRegistrar(ctx context.Context, clientSet kubernetes.Interface, nodeName, ginkgo.By("start only Kubelet plugin registrar") logger := klog.LoggerWithValues(klog.LoggerWithName(klog.Background(), "kubelet plugin registrar "+driverName)) ctx = klog.NewContext(ctx, logger) - opts = append(opts, kubeletplugin.DRAService(false)) + + allOpts := []any{ + testdriver.Options{EnableHealthService: false}, + kubeletplugin.DRAService(false), + } + + allOpts = append(allOpts, opts...) 
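+	// This instance only handles plugin registration; the DRA and health services
+	// are intentionally disabled, and any caller-supplied opts are applied after
+	// these defaults.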
+ registrar, err := testdriver.StartPlugin( ctx, cdiDir, @@ -885,7 +1059,7 @@ func newRegistrar(ctx context.Context, clientSet kubernetes.Interface, nodeName, clientSet, nodeName, testdriver.FileOperations{}, - opts..., + allOpts..., ) framework.ExpectNoError(err, "start only Kubelet plugin registrar") return registrar @@ -911,16 +1085,23 @@ func newDRAService(ctx context.Context, clientSet kubernetes.Interface, nodeName logger := klog.LoggerWithValues(klog.LoggerWithName(klog.Background(), "kubelet plugin "+driverName), "node", nodeName) ctx = klog.NewContext(ctx, logger) - opts = append(opts, kubeletplugin.RegistrationService(false)) + allOpts := []any{ + testdriver.Options{EnableHealthService: true}, + kubeletplugin.RegistrationService(false), + } + allOpts = append(allOpts, opts...) // Ensure that directories exist, creating them if necessary. We want // to know early if there is a setup problem that would prevent // creating those directories. err := os.MkdirAll(cdiDir, os.FileMode(0750)) framework.ExpectNoError(err, "create CDI directory") + + // If datadir is not provided, set it to the default and ensure it exists. if datadir == "" { datadir = path.Join(kubeletplugin.KubeletPluginsDir, driverName) } + err = os.MkdirAll(datadir, 0750) framework.ExpectNoError(err, "create DRA socket directory") @@ -945,7 +1126,7 @@ func newDRAService(ctx context.Context, clientSet kubernetes.Interface, nodeName }, }, }, - opts..., + allOpts..., ) framework.ExpectNoError(err) @@ -1055,6 +1236,7 @@ func createTestObjects(ctx context.Context, clientSet kubernetes.Interface, node // NOTE: This is usually done by the DRA controller or the scheduler. results := make([]resourceapi.DeviceRequestAllocationResult, len(driverNames)) config := make([]resourceapi.DeviceAllocationConfiguration, len(driverNames)) + for i, driverName := range driverNames { results[i] = resourceapi.DeviceRequestAllocationResult{ Driver: driverName, @@ -1143,6 +1325,145 @@ func matchResourcesByNodeName(nodeName string) types.GomegaMatcher { return gomega.HaveField("Spec.NodeName", gstruct.PointTo(gomega.Equal(nodeName))) } +// This helper function queries the main API server for the pod's status. +func getDeviceHealthFromAPIServer(f *framework.Framework, namespace, podName, driverName, claimName, poolName, deviceName string) (string, error) { + // Get the Pod object from the API server + pod, err := f.ClientSet.CoreV1().Pods(namespace).Get(context.Background(), podName, metav1.GetOptions{}) + if err != nil { + if apierrors.IsNotFound(err) { + return "NotFound", nil + } + return "", fmt.Errorf("failed to get pod %s/%s: %w", namespace, podName, err) + } + + // This is the unique ID for the device based on how Kubelet manager code constructs it. + expectedResourceID := v1.ResourceID(fmt.Sprintf("%s/%s/%s", driverName, poolName, deviceName)) + + expectedResourceStatusNameSimple := v1.ResourceName(fmt.Sprintf("claim:%s", claimName)) + expectedResourceStatusNameWithRequest := v1.ResourceName(fmt.Sprintf("claim:%s/%s", claimName, "my-request")) + + // Loop through container statuses. 
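+	// Match either the bare claim name or the claim/request form, then return the
+	// health of the device identified by its full ID (or, as a fallback, by its
+	// driver prefix).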
+	for _, containerStatus := range pod.Status.ContainerStatuses {
+		if containerStatus.AllocatedResourcesStatus != nil {
+			for _, resourceStatus := range containerStatus.AllocatedResourcesStatus {
+				if resourceStatus.Name != expectedResourceStatusNameSimple && resourceStatus.Name != expectedResourceStatusNameWithRequest {
+					continue
+				}
+				for _, resourceHealth := range resourceStatus.Resources {
+					if resourceHealth.ResourceID == expectedResourceID || strings.HasPrefix(string(resourceHealth.ResourceID), driverName) {
+						return string(resourceHealth.Health), nil
+					}
+				}
+			}
+		}
+	}
+
+	return "NotFound", nil
+}
+
+// createHealthTestPodAndClaim is a specialized helper for the Resource Health test.
+// It creates all necessary objects (DeviceClass, ResourceClaim, Pod) and ensures
+// the pod is long-running and the claim is allocated from the specified pool.
+func createHealthTestPodAndClaim(ctx context.Context, f *framework.Framework, driverName, podName, claimName, className, poolName, deviceName string) *v1.Pod {
+	ginkgo.By(fmt.Sprintf("Creating DeviceClass %q", className))
+	dc := &resourceapi.DeviceClass{
+		ObjectMeta: metav1.ObjectMeta{
+			Name: className,
+		},
+	}
+	_, err := f.ClientSet.ResourceV1().DeviceClasses().Create(ctx, dc, metav1.CreateOptions{})
+	framework.ExpectNoError(err, "failed to create DeviceClass "+className)
+	ginkgo.DeferCleanup(func() {
+		err := f.ClientSet.ResourceV1().DeviceClasses().Delete(context.Background(), className, metav1.DeleteOptions{})
+		if err != nil && !apierrors.IsNotFound(err) {
+			framework.Failf("Failed to delete DeviceClass %s: %v", className, err)
+		}
+	})
+	ginkgo.By(fmt.Sprintf("Creating ResourceClaim %q", claimName))
+	claim := &resourceapi.ResourceClaim{
+		ObjectMeta: metav1.ObjectMeta{
+			Name: claimName,
+		},
+		Spec: resourceapi.ResourceClaimSpec{
+			Devices: resourceapi.DeviceClaim{
+				Requests: []resourceapi.DeviceRequest{{
+					Name: "my-request",
+					Exactly: &resourceapi.ExactDeviceRequest{
+						DeviceClassName: className,
+					},
+				}},
+			},
+		},
+	}
+
+	_, err = f.ClientSet.ResourceV1().ResourceClaims(f.Namespace.Name).Create(ctx, claim, metav1.CreateOptions{})
+	framework.ExpectNoError(err, "failed to create ResourceClaim "+claimName)
+	ginkgo.DeferCleanup(func() {
+		err := f.ClientSet.ResourceV1().ResourceClaims(f.Namespace.Name).Delete(context.Background(), claimName, metav1.DeleteOptions{})
+		if err != nil && !apierrors.IsNotFound(err) {
+			framework.Failf("Failed to delete ResourceClaim %s: %v", claimName, err)
+		}
+	})
+	ginkgo.By(fmt.Sprintf("Creating long-running Pod %q (without claim allocation yet)", podName))
+	pod := &v1.Pod{
+		ObjectMeta: metav1.ObjectMeta{
+			Name:      podName,
+			Namespace: f.Namespace.Name,
+		},
+		Spec: v1.PodSpec{
+			NodeName:      getNodeName(ctx, f),
+			RestartPolicy: v1.RestartPolicyNever,
+			ResourceClaims: []v1.PodResourceClaim{
+				{Name: claimName, ResourceClaimName: &claimName},
+			},
+			Containers: []v1.Container{
+				{
+					Name:    "testcontainer",
+					Image:   e2epod.GetDefaultTestImage(),
+					Command: []string{"/bin/sh", "-c", "sleep 600"},
+					Resources: v1.ResourceRequirements{
+						Claims: []v1.ResourceClaim{{Name: claimName, Request: "my-request"}},
+					},
+				},
+			},
+		},
+	}
+	// Create the pod on the API server to assign the real UID.
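+	// The ReservedFor entry in the claim status below must reference this UID, so
+	// the Pod has to exist before the claim can be reserved for it.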
+ createdPod, err := f.ClientSet.CoreV1().Pods(f.Namespace.Name).Create(ctx, pod, metav1.CreateOptions{}) + framework.ExpectNoError(err, "failed to create Pod "+podName) + ginkgo.DeferCleanup(func() { + e2epod.DeletePodOrFail(context.Background(), f.ClientSet, createdPod.Namespace, createdPod.Name) + }) + + ginkgo.By(fmt.Sprintf("Allocating claim %q to pod %q with its real UID", claimName, podName)) + // Get the created claim to ensure the latest version before updating. + claimToUpdate, err := f.ClientSet.ResourceV1().ResourceClaims(f.Namespace.Name).Get(ctx, claimName, metav1.GetOptions{}) + framework.ExpectNoError(err, "failed to get latest version of ResourceClaim "+claimName) + + // Update the claims status to reserve it for the *real* pod UID. + claimToUpdate.Status = resourceapi.ResourceClaimStatus{ + ReservedFor: []resourceapi.ResourceClaimConsumerReference{ + {Resource: "pods", Name: createdPod.Name, UID: createdPod.UID}, + }, + Allocation: &resourceapi.AllocationResult{ + Devices: resourceapi.DeviceAllocationResult{ + Results: []resourceapi.DeviceRequestAllocationResult{ + { + Driver: driverName, + Pool: poolName, + Device: deviceName, + Request: "my-request", + }, + }, + }, + }, + } + _, err = f.ClientSet.ResourceV1().ResourceClaims(f.Namespace.Name).UpdateStatus(ctx, claimToUpdate, metav1.UpdateOptions{}) + framework.ExpectNoError(err, "failed to update ResourceClaim status for test") + + return createdPod +} + // errorOnCloseListener is a mock net.Listener that blocks on Accept() // until Close() is called, at which point Accept() returns a predefined error. //