mirror of
				https://github.com/optim-enterprises-bv/kubernetes.git
				synced 2025-11-03 03:38:15 +00:00 
			
		
		
		
	
		
			
				
	
	
		
			331 lines
		
	
	
		
			7.2 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			331 lines
		
	
	
		
			7.2 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
// Copyright 2013 The Go Authors. All rights reserved.
 | 
						|
// Use of this source code is governed by a BSD-style
 | 
						|
// license that can be found in the LICENSE file.
 | 
						|
 | 
						|
package ssh
 | 
						|
 | 
						|
import (
 | 
						|
	"encoding/binary"
 | 
						|
	"fmt"
 | 
						|
	"io"
 | 
						|
	"log"
 | 
						|
	"sync"
 | 
						|
	"sync/atomic"
 | 
						|
)
 | 
						|
 | 
						|
const (
	// debugMux is a compile-time switch: when true, the mux logs the
	// connection-protocol messages it sends and the packets it decodes,
	// and gives each mux a distinct channel-ID offset.
	debugMux = false
)
 | 
						|
 | 
						|
// chanList is a thread safe channel list.
 | 
						|
type chanList struct {
 | 
						|
	// protects concurrent access to chans
 | 
						|
	sync.Mutex
 | 
						|
 | 
						|
	// chans are indexed by the local id of the channel, which the
 | 
						|
	// other side should send in the PeersId field.
 | 
						|
	chans []*channel
 | 
						|
 | 
						|
	// This is a debugging aid: it offsets all IDs by this
 | 
						|
	// amount. This helps distinguish otherwise identical
 | 
						|
	// server/client muxes
 | 
						|
	offset uint32
 | 
						|
}
 | 
						|
 | 
						|
// Assigns a channel ID to the given channel.
 | 
						|
func (c *chanList) add(ch *channel) uint32 {
 | 
						|
	c.Lock()
 | 
						|
	defer c.Unlock()
 | 
						|
	for i := range c.chans {
 | 
						|
		if c.chans[i] == nil {
 | 
						|
			c.chans[i] = ch
 | 
						|
			return uint32(i) + c.offset
 | 
						|
		}
 | 
						|
	}
 | 
						|
	c.chans = append(c.chans, ch)
 | 
						|
	return uint32(len(c.chans)-1) + c.offset
 | 
						|
}
 | 
						|
 | 
						|
// getChan returns the channel for the given ID.
 | 
						|
func (c *chanList) getChan(id uint32) *channel {
 | 
						|
	id -= c.offset
 | 
						|
 | 
						|
	c.Lock()
 | 
						|
	defer c.Unlock()
 | 
						|
	if id < uint32(len(c.chans)) {
 | 
						|
		return c.chans[id]
 | 
						|
	}
 | 
						|
	return nil
 | 
						|
}
 | 
						|
 | 
						|
func (c *chanList) remove(id uint32) {
 | 
						|
	id -= c.offset
 | 
						|
	c.Lock()
 | 
						|
	if id < uint32(len(c.chans)) {
 | 
						|
		c.chans[id] = nil
 | 
						|
	}
 | 
						|
	c.Unlock()
 | 
						|
}
 | 
						|
 | 
						|
// dropAll forgets all channels it knows, returning them in a slice.
 | 
						|
func (c *chanList) dropAll() []*channel {
 | 
						|
	c.Lock()
 | 
						|
	defer c.Unlock()
 | 
						|
	var r []*channel
 | 
						|
 | 
						|
	for _, ch := range c.chans {
 | 
						|
		if ch == nil {
 | 
						|
			continue
 | 
						|
		}
 | 
						|
		r = append(r, ch)
 | 
						|
	}
 | 
						|
	c.chans = nil
 | 
						|
	return r
 | 
						|
}
 | 
						|
 | 
						|
// mux represents the state for the SSH connection protocol, which
 | 
						|
// multiplexes many channels onto a single packet transport.
 | 
						|
type mux struct {
 | 
						|
	conn     packetConn
 | 
						|
	chanList chanList
 | 
						|
 | 
						|
	incomingChannels chan NewChannel
 | 
						|
 | 
						|
	globalSentMu     sync.Mutex
 | 
						|
	globalResponses  chan interface{}
 | 
						|
	incomingRequests chan *Request
 | 
						|
 | 
						|
	errCond *sync.Cond
 | 
						|
	err     error
 | 
						|
}
 | 
						|
 | 
						|
var (
	// globalOff is a process-wide counter consulted only when debugMux is
	// set: each new mux atomically bumps it and uses the result as its
	// chanList offset, so different muxes hand out different channel IDs.
	globalOff uint32
)
 | 
						|
 | 
						|
func (m *mux) Wait() error {
 | 
						|
	m.errCond.L.Lock()
 | 
						|
	defer m.errCond.L.Unlock()
 | 
						|
	for m.err == nil {
 | 
						|
		m.errCond.Wait()
 | 
						|
	}
 | 
						|
	return m.err
 | 
						|
}
 | 
						|
 | 
						|
// newMux returns a mux that runs over the given connection.
 | 
						|
func newMux(p packetConn) *mux {
 | 
						|
	m := &mux{
 | 
						|
		conn:             p,
 | 
						|
		incomingChannels: make(chan NewChannel, 16),
 | 
						|
		globalResponses:  make(chan interface{}, 1),
 | 
						|
		incomingRequests: make(chan *Request, 16),
 | 
						|
		errCond:          newCond(),
 | 
						|
	}
 | 
						|
	if debugMux {
 | 
						|
		m.chanList.offset = atomic.AddUint32(&globalOff, 1)
 | 
						|
	}
 | 
						|
 | 
						|
	go m.loop()
 | 
						|
	return m
 | 
						|
}
 | 
						|
 | 
						|
func (m *mux) sendMessage(msg interface{}) error {
 | 
						|
	p := Marshal(msg)
 | 
						|
	if debugMux {
 | 
						|
		log.Printf("send global(%d): %#v", m.chanList.offset, msg)
 | 
						|
	}
 | 
						|
	return m.conn.writePacket(p)
 | 
						|
}
 | 
						|
 | 
						|
func (m *mux) SendRequest(name string, wantReply bool, payload []byte) (bool, []byte, error) {
 | 
						|
	if wantReply {
 | 
						|
		m.globalSentMu.Lock()
 | 
						|
		defer m.globalSentMu.Unlock()
 | 
						|
	}
 | 
						|
 | 
						|
	if err := m.sendMessage(globalRequestMsg{
 | 
						|
		Type:      name,
 | 
						|
		WantReply: wantReply,
 | 
						|
		Data:      payload,
 | 
						|
	}); err != nil {
 | 
						|
		return false, nil, err
 | 
						|
	}
 | 
						|
 | 
						|
	if !wantReply {
 | 
						|
		return false, nil, nil
 | 
						|
	}
 | 
						|
 | 
						|
	msg, ok := <-m.globalResponses
 | 
						|
	if !ok {
 | 
						|
		return false, nil, io.EOF
 | 
						|
	}
 | 
						|
	switch msg := msg.(type) {
 | 
						|
	case *globalRequestFailureMsg:
 | 
						|
		return false, msg.Data, nil
 | 
						|
	case *globalRequestSuccessMsg:
 | 
						|
		return true, msg.Data, nil
 | 
						|
	default:
 | 
						|
		return false, nil, fmt.Errorf("ssh: unexpected response to request: %#v", msg)
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
// ackRequest must be called after processing a global request that
 | 
						|
// has WantReply set.
 | 
						|
func (m *mux) ackRequest(ok bool, data []byte) error {
 | 
						|
	if ok {
 | 
						|
		return m.sendMessage(globalRequestSuccessMsg{Data: data})
 | 
						|
	}
 | 
						|
	return m.sendMessage(globalRequestFailureMsg{Data: data})
 | 
						|
}
 | 
						|
 | 
						|
func (m *mux) Close() error {
 | 
						|
	return m.conn.Close()
 | 
						|
}
 | 
						|
 | 
						|
// loop runs the connection machine. It will process packets until an
 | 
						|
// error is encountered. To synchronize on loop exit, use mux.Wait.
 | 
						|
func (m *mux) loop() {
 | 
						|
	var err error
 | 
						|
	for err == nil {
 | 
						|
		err = m.onePacket()
 | 
						|
	}
 | 
						|
 | 
						|
	for _, ch := range m.chanList.dropAll() {
 | 
						|
		ch.close()
 | 
						|
	}
 | 
						|
 | 
						|
	close(m.incomingChannels)
 | 
						|
	close(m.incomingRequests)
 | 
						|
	close(m.globalResponses)
 | 
						|
 | 
						|
	m.conn.Close()
 | 
						|
 | 
						|
	m.errCond.L.Lock()
 | 
						|
	m.err = err
 | 
						|
	m.errCond.Broadcast()
 | 
						|
	m.errCond.L.Unlock()
 | 
						|
 | 
						|
	if debugMux {
 | 
						|
		log.Println("loop exit", err)
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
// onePacket reads and processes one packet.
 | 
						|
func (m *mux) onePacket() error {
 | 
						|
	packet, err := m.conn.readPacket()
 | 
						|
	if err != nil {
 | 
						|
		return err
 | 
						|
	}
 | 
						|
 | 
						|
	if debugMux {
 | 
						|
		if packet[0] == msgChannelData || packet[0] == msgChannelExtendedData {
 | 
						|
			log.Printf("decoding(%d): data packet - %d bytes", m.chanList.offset, len(packet))
 | 
						|
		} else {
 | 
						|
			p, _ := decode(packet)
 | 
						|
			log.Printf("decoding(%d): %d %#v - %d bytes", m.chanList.offset, packet[0], p, len(packet))
 | 
						|
		}
 | 
						|
	}
 | 
						|
 | 
						|
	switch packet[0] {
 | 
						|
	case msgChannelOpen:
 | 
						|
		return m.handleChannelOpen(packet)
 | 
						|
	case msgGlobalRequest, msgRequestSuccess, msgRequestFailure:
 | 
						|
		return m.handleGlobalPacket(packet)
 | 
						|
	}
 | 
						|
 | 
						|
	// assume a channel packet.
 | 
						|
	if len(packet) < 5 {
 | 
						|
		return parseError(packet[0])
 | 
						|
	}
 | 
						|
	id := binary.BigEndian.Uint32(packet[1:])
 | 
						|
	ch := m.chanList.getChan(id)
 | 
						|
	if ch == nil {
 | 
						|
		return fmt.Errorf("ssh: invalid channel %d", id)
 | 
						|
	}
 | 
						|
 | 
						|
	return ch.handlePacket(packet)
 | 
						|
}
 | 
						|
 | 
						|
func (m *mux) handleGlobalPacket(packet []byte) error {
 | 
						|
	msg, err := decode(packet)
 | 
						|
	if err != nil {
 | 
						|
		return err
 | 
						|
	}
 | 
						|
 | 
						|
	switch msg := msg.(type) {
 | 
						|
	case *globalRequestMsg:
 | 
						|
		m.incomingRequests <- &Request{
 | 
						|
			Type:      msg.Type,
 | 
						|
			WantReply: msg.WantReply,
 | 
						|
			Payload:   msg.Data,
 | 
						|
			mux:       m,
 | 
						|
		}
 | 
						|
	case *globalRequestSuccessMsg, *globalRequestFailureMsg:
 | 
						|
		m.globalResponses <- msg
 | 
						|
	default:
 | 
						|
		panic(fmt.Sprintf("not a global message %#v", msg))
 | 
						|
	}
 | 
						|
 | 
						|
	return nil
 | 
						|
}
 | 
						|
 | 
						|
// handleChannelOpen schedules a channel to be Accept()ed.
 | 
						|
func (m *mux) handleChannelOpen(packet []byte) error {
 | 
						|
	var msg channelOpenMsg
 | 
						|
	if err := Unmarshal(packet, &msg); err != nil {
 | 
						|
		return err
 | 
						|
	}
 | 
						|
 | 
						|
	if msg.MaxPacketSize < minPacketLength || msg.MaxPacketSize > 1<<31 {
 | 
						|
		failMsg := channelOpenFailureMsg{
 | 
						|
			PeersId:  msg.PeersId,
 | 
						|
			Reason:   ConnectionFailed,
 | 
						|
			Message:  "invalid request",
 | 
						|
			Language: "en_US.UTF-8",
 | 
						|
		}
 | 
						|
		return m.sendMessage(failMsg)
 | 
						|
	}
 | 
						|
 | 
						|
	c := m.newChannel(msg.ChanType, channelInbound, msg.TypeSpecificData)
 | 
						|
	c.remoteId = msg.PeersId
 | 
						|
	c.maxRemotePayload = msg.MaxPacketSize
 | 
						|
	c.remoteWin.add(msg.PeersWindow)
 | 
						|
	m.incomingChannels <- c
 | 
						|
	return nil
 | 
						|
}
 | 
						|
 | 
						|
func (m *mux) OpenChannel(chanType string, extra []byte) (Channel, <-chan *Request, error) {
 | 
						|
	ch, err := m.openChannel(chanType, extra)
 | 
						|
	if err != nil {
 | 
						|
		return nil, nil, err
 | 
						|
	}
 | 
						|
 | 
						|
	return ch, ch.incomingRequests, nil
 | 
						|
}
 | 
						|
 | 
						|
func (m *mux) openChannel(chanType string, extra []byte) (*channel, error) {
 | 
						|
	ch := m.newChannel(chanType, channelOutbound, extra)
 | 
						|
 | 
						|
	ch.maxIncomingPayload = channelMaxPacket
 | 
						|
 | 
						|
	open := channelOpenMsg{
 | 
						|
		ChanType:         chanType,
 | 
						|
		PeersWindow:      ch.myWindow,
 | 
						|
		MaxPacketSize:    ch.maxIncomingPayload,
 | 
						|
		TypeSpecificData: extra,
 | 
						|
		PeersId:          ch.localId,
 | 
						|
	}
 | 
						|
	if err := m.sendMessage(open); err != nil {
 | 
						|
		return nil, err
 | 
						|
	}
 | 
						|
 | 
						|
	switch msg := (<-ch.msg).(type) {
 | 
						|
	case *channelOpenConfirmMsg:
 | 
						|
		return ch, nil
 | 
						|
	case *channelOpenFailureMsg:
 | 
						|
		return nil, &OpenChannelError{msg.Reason, msg.Message}
 | 
						|
	default:
 | 
						|
		return nil, fmt.Errorf("ssh: unexpected packet in response to channel open: %T", msg)
 | 
						|
	}
 | 
						|
}
 |