Files
kubernetes/test/e2e_node/testdeviceplugin/device-plugin.go
2024-07-10 20:14:59 +00:00

238 lines
6.6 KiB
Go

/*
Copyright 2024 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package testdeviceplugin
import (
"context"
"fmt"
"net"
"os"
"sync"
"time"
"github.com/onsi/ginkgo/v2"
"github.com/onsi/gomega"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
kubeletdevicepluginv1beta1 "k8s.io/kubelet/pkg/apis/deviceplugin/v1beta1"
)
// DevicePlugin is a test implementation of the kubelet v1beta1 device plugin
// gRPC service. It logs the names of the GetDevicePluginOptions, ListAndWatch
// and Allocate RPCs it receives and can be given an error injector that makes
// GetDevicePluginOptions and ListAndWatch fail on demand.
type DevicePlugin struct {
	// server is the gRPC server created by RegisterDevicePlugin; Stop resets
	// it to nil.
	server *grpc.Server
	// uniqueName is set by RegisterDevicePlugin. It is part of the plugin
	// socket name and prefixes the temp files created by Allocate.
	uniqueName string
	// errorInjector, when non-nil, is called with the RPC name and its
	// result is returned as that RPC's error.
	errorInjector func(string) error

	// devicesSync guards devices.
	devicesSync sync.Mutex
	devices     []kubeletdevicepluginv1beta1.Device
	// devicesUpdateCh is signalled by UpdateDevices and drained by
	// ListAndWatch, which then re-sends the device list.
	devicesUpdateCh chan struct{}

	// callsSync guards calls.
	callsSync sync.Mutex
	calls     []string
}
// NewDevicePlugin returns a DevicePlugin with an empty (non-nil) call log and
// an unbuffered device-update channel.
//
// errorInjector may be nil. When it is non-nil, GetDevicePluginOptions and
// ListAndWatch call it with their method name and return its result as
// their error.
func NewDevicePlugin(errorInjector func(string) error) *DevicePlugin {
	plugin := new(DevicePlugin)
	plugin.errorInjector = errorInjector
	plugin.devicesUpdateCh = make(chan struct{})
	plugin.calls = make([]string, 0)
	return plugin
}
// GetDevicePluginOptions records the call and returns an empty
// DevicePluginOptions. When an error injector is configured, its result for
// "GetDevicePluginOptions" is returned as the error, alongside the still
// non-nil empty options.
func (dp *DevicePlugin) GetDevicePluginOptions(context.Context, *kubeletdevicepluginv1beta1.Empty) (*kubeletdevicepluginv1beta1.DevicePluginOptions, error) {
	const method = "GetDevicePluginOptions"

	dp.callsSync.Lock()
	dp.calls = append(dp.calls, method)
	dp.callsSync.Unlock()

	var injected error
	if inject := dp.errorInjector; inject != nil {
		injected = inject(method)
	}
	return &kubeletdevicepluginv1beta1.DevicePluginOptions{}, injected
}
// sendDevices sends the current device list on stream as a single
// ListAndWatchResponse and returns the error from stream.Send.
//
// The list is copied while devicesSync is held; the lock is released before
// the send, so the response never aliases dp.devices.
func (dp *DevicePlugin) sendDevices(stream kubeletdevicepluginv1beta1.DevicePlugin_ListAndWatchServer) error {
	resp := new(kubeletdevicepluginv1beta1.ListAndWatchResponse)

	dp.devicesSync.Lock()
	for i := range dp.devices {
		// Give every entry its own copy. Taking the address of a range
		// variable makes all entries point at the same (last) device on
		// toolchains that predate per-iteration loop variables (Go < 1.22).
		device := dp.devices[i]
		resp.Devices = append(resp.Devices, &device)
	}
	dp.devicesSync.Unlock()

	return stream.Send(resp)
}
// ListAndWatch records the call, sends the current device list on stream and
// then keeps the stream open, re-sending the list every time UpdateDevices
// signals a change.
//
// When an error injector is configured it is consulted with "ListAndWatch"
// once before the first send and then once per second; the first non-nil
// result ends the RPC with that error, simulating a device plugin failure.
// The RPC also ends with the stream context's error once that context is
// done, and with the send error if sending the device list fails.
func (dp *DevicePlugin) ListAndWatch(empty *kubeletdevicepluginv1beta1.Empty, stream kubeletdevicepluginv1beta1.DevicePlugin_ListAndWatchServer) error {
	dp.callsSync.Lock()
	dp.calls = append(dp.calls, "ListAndWatch")
	dp.callsSync.Unlock()

	if dp.errorInjector != nil {
		if err := dp.errorInjector("ListAndWatch"); err != nil {
			return err
		}
	}

	if err := dp.sendDevices(stream); err != nil {
		return err
	}

	ticker := time.NewTicker(time.Second)
	defer ticker.Stop()

	streamCtx := stream.Context()
	for {
		select {
		case <-streamCtx.Done():
			// Without this case a handler whose client went away would
			// keep running and could consume an update signal intended
			// for a newer, live ListAndWatch stream.
			return streamCtx.Err()
		case <-dp.devicesUpdateCh:
			// The device list changed: push the new list to the client.
			if err := dp.sendDevices(stream); err != nil {
				return err
			}
		case <-ticker.C:
			// Periodically give the injector a chance to fail the stream.
			if dp.errorInjector != nil {
				if err := dp.errorInjector("ListAndWatch"); err != nil {
					return err
				}
			}
		}
	}
}
// Allocate records the call and, for every device ID of every container
// request, creates a temp file under /tmp whose name starts with
// "<uniqueName>-<deviceID>" and returns it as a mount with identical host and
// container paths. The files are left on disk; this function does not remove
// them.
//
// A failure to create or close a file is returned as the RPC error rather
// than asserted with gomega: this method runs on a gRPC handler goroutine,
// where a failed assertion would panic outside of Ginkgo's control.
func (dp *DevicePlugin) Allocate(ctx context.Context, request *kubeletdevicepluginv1beta1.AllocateRequest) (*kubeletdevicepluginv1beta1.AllocateResponse, error) {
	dp.callsSync.Lock()
	dp.calls = append(dp.calls, "Allocate")
	dp.callsSync.Unlock()

	result := new(kubeletdevicepluginv1beta1.AllocateResponse)
	for _, r := range request.ContainerRequests {
		response := &kubeletdevicepluginv1beta1.ContainerAllocateResponse{}
		for _, id := range r.DevicesIDs {
			f, err := os.CreateTemp("/tmp", fmt.Sprintf("%s-%s", dp.uniqueName, id))
			if err != nil {
				return nil, fmt.Errorf("creating mount file for device %q: %w", id, err)
			}
			// Only the path is needed; close the handle right away so no
			// file descriptor is leaked per allocated device.
			path := f.Name()
			if err := f.Close(); err != nil {
				return nil, fmt.Errorf("closing mount file %q: %w", path, err)
			}
			response.Mounts = append(response.Mounts, &kubeletdevicepluginv1beta1.Mount{
				ContainerPath: path,
				HostPath:      path,
			})
		}
		result.ContainerResponses = append(result.ContainerResponses, response)
	}
	return result, nil
}
// PreStartContainer is a no-op: it ignores its arguments and returns a nil
// response with a nil error. The call is not recorded and the error injector
// is not consulted.
func (dp *DevicePlugin) PreStartContainer(ctx context.Context, request *kubeletdevicepluginv1beta1.PreStartContainerRequest) (*kubeletdevicepluginv1beta1.PreStartContainerResponse, error) {
	var noResponse *kubeletdevicepluginv1beta1.PreStartContainerResponse
	return noResponse, nil
}
// GetPreferredAllocation is a no-op: it ignores its arguments and returns a
// nil response with a nil error. The call is not recorded and the error
// injector is not consulted.
func (dp *DevicePlugin) GetPreferredAllocation(ctx context.Context, request *kubeletdevicepluginv1beta1.PreferredAllocationRequest) (*kubeletdevicepluginv1beta1.PreferredAllocationResponse, error) {
	var noResponse *kubeletdevicepluginv1beta1.PreferredAllocationResponse
	return noResponse, nil
}
// RegisterDevicePlugin starts serving this plugin on a unix socket and
// registers it with the kubelet under resourceName.
//
// devices becomes the list reported by ListAndWatch. uniqueName is stored on
// the plugin and used to build the socket file name
// "test-device-plugin-<uniqueName>.sock" inside DevicePluginPath. The gRPC
// server is stored on the plugin so that Stop can shut it down.
//
// Errors from creating the listener, creating the kubelet client and the
// Register call itself are returned unchanged. The result of serving and of
// closing the kubelet connection is asserted with gomega.
func (dp *DevicePlugin) RegisterDevicePlugin(ctx context.Context, uniqueName, resourceName string, devices []kubeletdevicepluginv1beta1.Device) error {
	ginkgo.GinkgoHelper()

	dp.devicesSync.Lock()
	dp.devices = devices
	dp.devicesSync.Unlock()

	devicePluginEndpoint := fmt.Sprintf("%s-%s.sock", "test-device-plugin", uniqueName)
	dp.uniqueName = uniqueName

	// Create the gRPC server and attach this plugin as its DevicePlugin
	// service implementation.
	server := grpc.NewServer()
	dp.server = server
	kubeletdevicepluginv1beta1.RegisterDevicePluginServer(server, dp)

	// Listen on the plugin's unix socket.
	lis, err := net.Listen("unix", kubeletdevicepluginv1beta1.DevicePluginPath+devicePluginEndpoint)
	if err != nil {
		return err
	}

	// Serve in the background. The goroutine uses the local server value
	// because Stop resets dp.server to nil, and it installs GinkgoRecover
	// because a failed gomega assertion panics and must be caught in the
	// goroutine that raised it.
	go func() {
		defer ginkgo.GinkgoRecover()
		err := server.Serve(lis)
		gomega.Expect(err).To(gomega.Succeed())
	}()

	// Create a client connection to the kubelet socket; it is only needed
	// for the Register call below.
	conn, err := grpc.NewClient("unix://"+kubeletdevicepluginv1beta1.KubeletSocket,
		grpc.WithTransportCredentials(insecure.NewCredentials()),
	)
	if err != nil {
		return err
	}
	defer func() {
		err := conn.Close()
		gomega.Expect(err).To(gomega.Succeed())
	}()

	// Announce the plugin's endpoint and resource name to the kubelet.
	client := kubeletdevicepluginv1beta1.NewRegistrationClient(conn)
	_, err = client.Register(ctx, &kubeletdevicepluginv1beta1.RegisterRequest{
		Version:      kubeletdevicepluginv1beta1.Version,
		Endpoint:     devicePluginEndpoint,
		ResourceName: resourceName,
	})
	return err
}
// Stop stops the plugin's gRPC server, if one is set, and clears the
// reference so that a repeated call does nothing.
func (dp *DevicePlugin) Stop() {
	if dp.server == nil {
		return
	}
	dp.server.Stop()
	dp.server = nil
}
// WasCalled reports whether method appears in the log of recorded calls.
// The log is read while callsSync is held.
func (dp *DevicePlugin) WasCalled(method string) bool {
	dp.callsSync.Lock()
	defer dp.callsSync.Unlock()

	found := false
	for i := 0; i < len(dp.calls) && !found; i++ {
		found = dp.calls[i] == method
	}
	return found
}
// Calls returns a copy of the recorded call log, taken while callsSync is
// held. The result is never nil, and changing it does not affect the
// plugin's own log.
func (dp *DevicePlugin) Calls() []string {
	dp.callsSync.Lock()
	defer dp.callsSync.Unlock()

	snapshot := make([]string, 0, len(dp.calls))
	return append(snapshot, dp.calls...)
}
// UpdateDevices replaces the plugin's device list and then signals
// ListAndWatch to send the new list.
//
// The signal is a send on the unbuffered devicesUpdateCh, so this call
// blocks until a running ListAndWatch receives it.
func (dp *DevicePlugin) UpdateDevices(devices []kubeletdevicepluginv1beta1.Device) {
	dp.devicesSync.Lock()
	dp.devices = devices
	dp.devicesSync.Unlock()

	// Signal only after releasing devicesSync. sendDevices takes the same
	// lock, so blocking on the send while holding it would deadlock against
	// a ListAndWatch that is still in sendDevices and has therefore not yet
	// reached the receive on this channel.
	dp.devicesUpdateCh <- struct{}{}
}