lib/connections, vendor: Change KCP mux to SMUX

Closes #4032
Authored by Audrius Butkevicius on 2017-03-09 14:01:07 +01:00, committed by Jakob Borg
parent f35e1ac0c5, commit ceea5ebeb3
22 changed files with 908 additions and 1889 deletions
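The hunks below only touch the vendored kcp-go and smux packages; the lib/connections change (not shown in this excerpt) boils down to wrapping the raw KCP session in an smux session instead of kcp-go's built-in stream multiplexer. A minimal sketch of what the dialing side looks like under that scheme, assuming plain kcp.Dial without crypto or FEC, default smux settings, and an illustrative address — not Syncthing's actual options:

package main

import (
    kcp "github.com/xtaci/kcp-go"
    "github.com/xtaci/smux"
)

// dialKCP is a sketch, not the code from lib/connections: it opens a KCP
// session and multiplexes it with smux, returning a single stream.
func dialKCP(addr string) (*smux.Stream, error) {
    conn, err := kcp.Dial(addr) // plain KCP session; crypto/FEC omitted
    if err != nil {
        return nil, err
    }
    session, err := smux.Client(conn, smux.DefaultConfig())
    if err != nil {
        conn.Close()
        return nil, err
    }
    return session.OpenStream() // one stream carries the protocol connection
}

func main() {
    stream, err := dialKCP("198.51.100.1:22020") // address is illustrative only
    if err != nil {
        panic(err)
    }
    defer stream.Close()
}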

View File

@@ -7,17 +7,17 @@ import (
var defaultEmitter Emitter
const emitQueue = 8192
func init() {
defaultEmitter.init()
}
type (
// packet emit request
emitPacket struct {
conn net.PacketConn
to net.Addr
data []byte
conn net.PacketConn
to net.Addr
data []byte
// marks whether this packet should be recycled to the global xmitBuf
recycle bool
}
@@ -28,7 +28,7 @@ type (
)
func (e *Emitter) init() {
e.ch = make(chan emitPacket, emitQueue)
e.ch = make(chan emitPacket)
go e.emitTask()
}
@@ -36,7 +36,7 @@ func (e *Emitter) init() {
func (e *Emitter) emitTask() {
for p := range e.ch {
if n, err := p.conn.WriteTo(p.data, p.to); err == nil {
atomic.AddUint64(&DefaultSnmp.OutSegs, 1)
atomic.AddUint64(&DefaultSnmp.OutPkts, 1)
atomic.AddUint64(&DefaultSnmp.OutBytes, uint64(n))
}
if p.recycle {

View File

@@ -117,6 +117,7 @@ func (seg *Segment) encode(ptr []byte) []byte {
ptr = ikcp_encode32u(ptr, seg.sn)
ptr = ikcp_encode32u(ptr, seg.una)
ptr = ikcp_encode32u(ptr, uint32(len(seg.data)))
atomic.AddUint64(&DefaultSnmp.OutSegs, 1)
return ptr
}
@@ -484,9 +485,10 @@ func (kcp *KCP) Input(data []byte, regular, ackNoDelay bool) int {
}
var maxack uint32
var lastackts uint32
var flag int
var inSegs uint64
current := currentMs()
for {
var ts, sn, length, una, conv uint32
var wnd uint16
@@ -525,10 +527,6 @@ func (kcp *KCP) Input(data []byte, regular, ackNoDelay bool) int {
kcp.shrink_buf()
if cmd == IKCP_CMD_ACK {
if _itimediff(current, ts) >= 0 {
kcp.update_ack(_itimediff(current, ts))
}
kcp.parse_ack(sn)
kcp.shrink_buf()
if flag == 0 {
@@ -537,6 +535,7 @@ func (kcp *KCP) Input(data []byte, regular, ackNoDelay bool) int {
} else if _itimediff(sn, maxack) > 0 {
maxack = sn
}
lastackts = ts
} else if cmd == IKCP_CMD_PUSH {
if _itimediff(sn, kcp.rcv_nxt+kcp.rcv_wnd) < 0 {
kcp.ack_push(sn, ts)
@@ -567,11 +566,17 @@ func (kcp *KCP) Input(data []byte, regular, ackNoDelay bool) int {
return -3
}
inSegs++
data = data[length:]
}
atomic.AddUint64(&DefaultSnmp.InSegs, inSegs)
if flag != 0 && regular {
kcp.parse_fastack(maxack)
current := currentMs()
if _itimediff(current, lastackts) >= 0 {
kcp.update_ack(_itimediff(current, lastackts))
}
}
if _itimediff(kcp.snd_una, una) > 0 {
@@ -660,9 +665,9 @@ func (kcp *KCP) flush(ackOnly bool) {
return
}
current := currentMs()
// probe window size (if remote window size equals zero)
if kcp.rmt_wnd == 0 {
current := currentMs()
if kcp.probe_wait == 0 {
kcp.probe_wait = IKCP_PROBE_INIT
kcp.ts_probe = current + kcp.probe_wait
@@ -742,6 +747,7 @@ func (kcp *KCP) flush(ackOnly bool) {
// send new segments
for k := len(kcp.snd_buf) - newSegsCount; k < len(kcp.snd_buf); k++ {
current := currentMs()
segment := &kcp.snd_buf[k]
segment.xmit++
segment.rto = kcp.rx_rto
@@ -765,6 +771,7 @@ func (kcp *KCP) flush(ackOnly bool) {
// check for retransmissions
for k := 0; k < len(kcp.snd_buf)-newSegsCount; k++ {
current := currentMs()
segment := &kcp.snd_buf[k]
needsend := false
if _itimediff(current, segment.resendts) >= 0 { // RTO

View File

@@ -23,13 +23,23 @@ func (errTimeout) Temporary() bool { return true }
func (errTimeout) Error() string { return "i/o timeout" }
const (
defaultWndSize = 128 // default window size, in packet
nonceSize = 16 // magic number
crcSize = 4 // 4bytes packet checksum
// 16-bytes magic number for each packet
nonceSize = 16
// 4-bytes packet checksum
crcSize = 4
// overall crypto header size
cryptHeaderSize = nonceSize + crcSize
mtuLimit = 2048
rxQueueLimit = 8192
rxFECMulti = 3 // FEC keeps rxFECMulti* (dataShard+parityShard) ordered packets in memory
// maximum packet size
mtuLimit = 2048
// packet receiving channel limit
rxQueueLimit = 2048
// FEC keeps rxFECMulti* (dataShard+parityShard) ordered packets in memory
rxFECMulti = 3
)
const (
@@ -38,8 +48,12 @@ const (
)
var (
// global packet buffer
// shared among sending/receiving/FEC
xmitBuf sync.Pool
sid uint32
// monotonic session id
sid uint32
)
func init() {
@@ -51,36 +65,39 @@ func init() {
type (
// UDPSession defines a KCP session implemented by UDP
UDPSession struct {
// core
sid uint32
conn net.PacketConn // the underlying packet socket
kcp *KCP // the core ARQ
l *Listener // point to server listener if it's a server socket
block BlockCrypt // encryption
sockbuff []byte // kcp receiving is based on packet, I turn it into stream
sid uint32 // session id(monotonic)
conn net.PacketConn // the underlying packet connection
kcp *KCP // KCP ARQ protocol
l *Listener // point to the Listener if it's accepted by Listener
block BlockCrypt // block encryption
// forward error correction
fec *FEC
fecDataShards [][]byte
fecHeaderOffset int
fecPayloadOffset int
fecCnt int // count datashard
fecMaxSize int // record maximum data length in datashard
// kcp receiving is based on packets
// sockbuff turns packets into stream
sockbuff []byte
fec *FEC // forward error correction
fecDataShards [][]byte // data shards cache
fecShardCount int // count the number of datashards collected
fecMaxSize int // record maximum data length in datashard
fecHeaderOffset int // FEC header offset in packet
fecPayloadOffset int // FEC payload offset in packet
// settings
remote net.Addr
remote net.Addr // remote peer address
rd time.Time // read deadline
wd time.Time // write deadline
headerSize int
updateInterval int32
ackNoDelay bool
headerSize int // the overall header size added before KCP frame
updateInterval int32 // interval in milliseconds to call kcp.flush()
ackNoDelay bool // send ack immediately for each incoming packet
// notifications
die chan struct{}
chReadEvent chan struct{}
chWriteEvent chan struct{}
isClosed bool
mu sync.Mutex
die chan struct{} // notify session has Closed
chReadEvent chan struct{} // notify Read() can be called without blocking
chWriteEvent chan struct{} // notify Write() can be called without blocking
isClosed bool // flag the session has Closed
mu sync.Mutex
}
setReadBuffer interface {
@@ -132,7 +149,6 @@ func newUDPSession(conv uint32, dataShards, parityShards int, l *Listener, conn
sess.output(buf[:size])
}
})
sess.kcp.WndSize(defaultWndSize, defaultWndSize)
sess.kcp.SetMtu(IKCP_MTU_DEF - sess.headerSize)
sess.kcp.setFEC(dataShards, parityShards)
@@ -324,11 +340,16 @@ func (s *UDPSession) SetWindowSize(sndwnd, rcvwnd int) {
s.kcp.WndSize(sndwnd, rcvwnd)
}
// SetMtu sets the maximum transmission unit
func (s *UDPSession) SetMtu(mtu int) {
// SetMtu sets the maximum transmission unit (not including the UDP header)
func (s *UDPSession) SetMtu(mtu int) bool {
if mtu > mtuLimit {
return false
}
s.mu.Lock()
defer s.mu.Unlock()
s.kcp.SetMtu(mtu - s.headerSize)
return true
}
// SetStreamMode toggles the stream mode on/off
@@ -416,9 +437,9 @@ func (s *UDPSession) output(buf []byte) {
// copy data to fec datashards
sz := len(ext)
s.fecDataShards[s.fecCnt] = s.fecDataShards[s.fecCnt][:sz]
copy(s.fecDataShards[s.fecCnt], ext)
s.fecCnt++
s.fecDataShards[s.fecShardCount] = s.fecDataShards[s.fecShardCount][:sz]
copy(s.fecDataShards[s.fecShardCount], ext)
s.fecShardCount++
// record max datashard length
if sz > s.fecMaxSize {
@@ -426,7 +447,7 @@ func (s *UDPSession) output(buf []byte) {
}
// calculate Reed-Solomon Erasure Code
if s.fecCnt == s.fec.dataShards {
if s.fecShardCount == s.fec.dataShards {
// bzero each datashard's tail
for i := 0; i < s.fec.dataShards; i++ {
shard := s.fecDataShards[i]
@@ -442,7 +463,7 @@ func (s *UDPSession) output(buf []byte) {
}
// reset counters to zero
s.fecCnt = 0
s.fecShardCount = 0
s.fecMaxSize = 0
}
}
@@ -557,7 +578,7 @@ func (s *UDPSession) kcpInput(data []byte) {
s.mu.Unlock()
}
atomic.AddUint64(&DefaultSnmp.InSegs, 1)
atomic.AddUint64(&DefaultSnmp.InPkts, 1)
atomic.AddUint64(&DefaultSnmp.InBytes, uint64(len(data)))
if fecParityShards > 0 {
atomic.AddUint64(&DefaultSnmp.FECParityShards, fecParityShards)
@@ -626,21 +647,23 @@ func (s *UDPSession) readLoop() {
type (
// Listener defines a server listening for connections
Listener struct {
block BlockCrypt
dataShards, parityShards int
fec *FEC // for fec init test
conn net.PacketConn
sessions map[string]*UDPSession
chAccepts chan *UDPSession
chDeadlinks chan net.Addr
headerSize int
die chan struct{}
rxbuf sync.Pool
rd atomic.Value
wd atomic.Value
block BlockCrypt // block encryption
dataShards int // FEC data shard
parityShards int // FEC parity shard
fec *FEC // FEC mock initialization
conn net.PacketConn // the underlying packet connection
sessions map[string]*UDPSession // all sessions accepted by this Listener
chAccepts chan *UDPSession // Listen() backlog
chDeadlinks chan net.Addr // session close queue
headerSize int // the overall header size added before KCP frame
die chan struct{} // notify the listener has closed
rd atomic.Value // read deadline for Accept()
wd atomic.Value
}
packet struct {
// incoming packet
inPacket struct {
from net.Addr
data []byte
}
@@ -648,7 +671,7 @@ type (
// monitor incoming data for all connections of the server
func (l *Listener) monitor() {
chPacket := make(chan packet, rxQueueLimit)
chPacket := make(chan inPacket, rxQueueLimit)
go l.receiver(chPacket)
for {
select {
@@ -699,7 +722,7 @@ func (l *Listener) monitor() {
}
}
l.rxbuf.Put(raw)
xmitBuf.Put(raw)
case deadlink := <-l.chDeadlinks:
delete(l.sessions, deadlink.String())
case <-l.die:
@@ -708,11 +731,11 @@ func (l *Listener) monitor() {
}
}
func (l *Listener) receiver(ch chan packet) {
func (l *Listener) receiver(ch chan inPacket) {
for {
data := l.rxbuf.Get().([]byte)[:mtuLimit]
data := xmitBuf.Get().([]byte)[:mtuLimit]
if n, from, err := l.conn.ReadFrom(data); err == nil && n >= l.headerSize+IKCP_OVERHEAD {
ch <- packet{from, data[:n]}
ch <- inPacket{from, data[:n]}
} else if err != nil {
return
} else {
@@ -829,9 +852,6 @@ func ServeConn(block BlockCrypt, dataShards, parityShards int, conn net.PacketCo
l.parityShards = parityShards
l.block = block
l.fec = newFEC(rxFECMulti*(dataShards+parityShards), dataShards, parityShards)
l.rxbuf.New = func() interface{} {
return make([]byte, mtuLimit)
}
// calculate header size
if l.block != nil {

View File

@@ -7,22 +7,24 @@ import (
// Snmp defines network statistics indicators
type Snmp struct {
BytesSent uint64 // raw bytes sent
BytesReceived uint64
MaxConn uint64
ActiveOpens uint64
PassiveOpens uint64
CurrEstab uint64 // count of connections for now
InErrs uint64 // udp read errors
BytesSent uint64 // bytes sent from upper level
BytesReceived uint64 // bytes received to upper level
MaxConn uint64 // max number of connections ever reached
ActiveOpens uint64 // accumulated active open connections
PassiveOpens uint64 // accumulated passive open connections
CurrEstab uint64 // current number of established connections
InErrs uint64 // UDP read errors reported from net.PacketConn
InCsumErrors uint64 // checksum errors from CRC32
KCPInErrors uint64 // packet iput errors from kcp
InSegs uint64
OutSegs uint64
InBytes uint64 // udp bytes received
OutBytes uint64 // udp bytes sent
RetransSegs uint64
FastRetransSegs uint64
EarlyRetransSegs uint64
KCPInErrors uint64 // packet input errors reported from KCP
InPkts uint64 // incoming packets count
OutPkts uint64 // outgoing packets count
InSegs uint64 // incoming KCP segments
OutSegs uint64 // outgoing KCP segments
InBytes uint64 // UDP bytes received
OutBytes uint64 // UDP bytes sent
RetransSegs uint64 // accumulated retransmitted segments
FastRetransSegs uint64 // accumulated fast retransmitted segments
EarlyRetransSegs uint64 // accumulated early retransmitted segments
LostSegs uint64 // number of segs inferred as lost
RepeatSegs uint64 // number of segs duplicated
FECRecovered uint64 // correct packets recovered from FEC
@@ -47,6 +49,8 @@ func (s *Snmp) Header() []string {
"InErrs",
"InCsumErrors",
"KCPInErrors",
"InPkts",
"OutPkts",
"InSegs",
"OutSegs",
"InBytes",
@@ -76,6 +80,8 @@ func (s *Snmp) ToSlice() []string {
fmt.Sprint(snmp.InErrs),
fmt.Sprint(snmp.InCsumErrors),
fmt.Sprint(snmp.KCPInErrors),
fmt.Sprint(snmp.InPkts),
fmt.Sprint(snmp.OutPkts),
fmt.Sprint(snmp.InSegs),
fmt.Sprint(snmp.OutSegs),
fmt.Sprint(snmp.InBytes),
@@ -104,6 +110,8 @@ func (s *Snmp) Copy() *Snmp {
d.InErrs = atomic.LoadUint64(&s.InErrs)
d.InCsumErrors = atomic.LoadUint64(&s.InCsumErrors)
d.KCPInErrors = atomic.LoadUint64(&s.KCPInErrors)
d.InPkts = atomic.LoadUint64(&s.InPkts)
d.OutPkts = atomic.LoadUint64(&s.OutPkts)
d.InSegs = atomic.LoadUint64(&s.InSegs)
d.OutSegs = atomic.LoadUint64(&s.OutSegs)
d.InBytes = atomic.LoadUint64(&s.InBytes)
@@ -131,6 +139,8 @@ func (s *Snmp) Reset() {
atomic.StoreUint64(&s.InErrs, 0)
atomic.StoreUint64(&s.InCsumErrors, 0)
atomic.StoreUint64(&s.KCPInErrors, 0)
atomic.StoreUint64(&s.InPkts, 0)
atomic.StoreUint64(&s.OutPkts, 0)
atomic.StoreUint64(&s.InSegs, 0)
atomic.StoreUint64(&s.OutSegs, 0)
atomic.StoreUint64(&s.InBytes, 0)
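For reference, all of these counters, including the new InPkts/OutPkts pair, are reachable through the exported DefaultSnmp. A small sketch using only the methods shown in this hunk (Copy, Header, ToSlice):

package main

import (
    "fmt"

    kcp "github.com/xtaci/kcp-go"
)

// dumpSnmp prints an atomic snapshot of the KCP statistics; Header() and
// ToSlice() emit their entries in the same order, so they can be zipped.
func dumpSnmp() {
    snap := kcp.DefaultSnmp.Copy()
    headers := snap.Header()
    values := snap.ToSlice()
    for i := range headers {
        fmt.Printf("%-18s %s\n", headers[i], values[i])
    }
}

func main() {
    dumpSnmp()
}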

View File

@@ -13,12 +13,14 @@ func init() {
go updater.updateTask()
}
// entry contains update info for a session
type entry struct {
sid uint32
ts time.Time
s *UDPSession
}
// a global heap-managed kcp.flush() caller
type updateHeap struct {
entries []entry
indices map[uint32]int

vendor/github.com/xtaci/smux/LICENSE generated vendored Normal file
View File

@@ -0,0 +1,21 @@
MIT License
Copyright (c) 2016-2017 Daniel Fu
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

vendor/github.com/xtaci/smux/frame.go generated vendored Normal file
View File

@@ -0,0 +1,60 @@
package smux
import (
"encoding/binary"
"fmt"
)
const (
version = 1
)
const ( // cmds
cmdSYN byte = iota // stream open
cmdFIN // stream close, a.k.a EOF mark
cmdPSH // data push
cmdNOP // no operation
)
const (
sizeOfVer = 1
sizeOfCmd = 1
sizeOfLength = 2
sizeOfSid = 4
headerSize = sizeOfVer + sizeOfCmd + sizeOfSid + sizeOfLength
)
// Frame defines a packet multiplexed onto, or demultiplexed from, a single connection
type Frame struct {
ver byte
cmd byte
sid uint32
data []byte
}
func newFrame(cmd byte, sid uint32) Frame {
return Frame{ver: version, cmd: cmd, sid: sid}
}
type rawHeader []byte
func (h rawHeader) Version() byte {
return h[0]
}
func (h rawHeader) Cmd() byte {
return h[1]
}
func (h rawHeader) Length() uint16 {
return binary.LittleEndian.Uint16(h[2:])
}
func (h rawHeader) StreamID() uint32 {
return binary.LittleEndian.Uint32(h[4:])
}
func (h rawHeader) String() string {
return fmt.Sprintf("Version:%d Cmd:%d StreamID:%d Length:%d",
h.Version(), h.Cmd(), h.StreamID(), h.Length())
}
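Taken together, the accessors above define an 8-byte header: a 1-byte version, a 1-byte command, a 2-byte little-endian length and a 4-byte little-endian stream ID, followed by the payload. A sketch of that layout, written as a hypothetical in-package test (rawHeader and the cmd constants are unexported), encoding the header the same way sendLoop does:

package smux

import (
    "encoding/binary"
    "testing"
)

// TestHeaderLayout encodes a header by hand and checks that rawHeader
// decodes it back.
func TestHeaderLayout(t *testing.T) {
    payload := []byte("hello")
    buf := make([]byte, headerSize+len(payload))
    buf[0] = version // 1-byte version
    buf[1] = cmdPSH  // 1-byte command
    binary.LittleEndian.PutUint16(buf[2:], uint16(len(payload))) // 2-byte length
    binary.LittleEndian.PutUint32(buf[4:], 42)                   // 4-byte stream ID
    copy(buf[headerSize:], payload)

    h := rawHeader(buf)
    if h.Version() != version || h.Cmd() != cmdPSH || h.StreamID() != 42 {
        t.Fatalf("unexpected header: %s", h)
    }
    if int(h.Length()) != len(payload) {
        t.Fatalf("unexpected length: %d", h.Length())
    }
}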

vendor/github.com/xtaci/smux/mux.go generated vendored Normal file
View File

@@ -0,0 +1,80 @@
package smux
import (
"fmt"
"io"
"time"
"github.com/pkg/errors"
)
// Config is used to tune the Smux session
type Config struct {
// KeepAliveInterval is how often to send a NOP command to the remote
KeepAliveInterval time.Duration
// KeepAliveTimeout is how long to wait for data
// before the session is closed
KeepAliveTimeout time.Duration
// MaxFrameSize is used to control the maximum
// frame size to send to the remote
MaxFrameSize int
// MaxReceiveBuffer is used to control the maximum
// amount of data buffered in the receive pool
MaxReceiveBuffer int
}
// DefaultConfig is used to return a default configuration
func DefaultConfig() *Config {
return &Config{
KeepAliveInterval: 10 * time.Second,
KeepAliveTimeout: 30 * time.Second,
MaxFrameSize: 4096,
MaxReceiveBuffer: 4194304,
}
}
// VerifyConfig is used to verify the sanity of configuration
func VerifyConfig(config *Config) error {
if config.KeepAliveInterval == 0 {
return errors.New("keep-alive interval must be positive")
}
if config.KeepAliveTimeout < config.KeepAliveInterval {
return fmt.Errorf("keep-alive timeout must be larger than keep-alive interval")
}
if config.MaxFrameSize <= 0 {
return errors.New("max frame size must be positive")
}
if config.MaxFrameSize > 65535 {
return errors.New("max frame size must not be larger than 65535")
}
if config.MaxReceiveBuffer <= 0 {
return errors.New("max receive buffer must be positive")
}
return nil
}
// Server is used to initialize a new server-side connection.
func Server(conn io.ReadWriteCloser, config *Config) (*Session, error) {
if config == nil {
config = DefaultConfig()
}
if err := VerifyConfig(config); err != nil {
return nil, err
}
return newSession(config, conn, false), nil
}
// Client is used to initialize a new client-side connection.
func Client(conn io.ReadWriteCloser, config *Config) (*Session, error) {
if config == nil {
config = DefaultConfig()
}
if err := VerifyConfig(config); err != nil {
return nil, err
}
return newSession(config, conn, true), nil
}
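Server and Client build the same Session; the only difference is the stream-ID parity they allocate (odd for the client, even for the server), so the two ends can be paired over any io.ReadWriteCloser. A sketch over an in-memory net.Pipe, passing nil to fall back to DefaultConfig:

package main

import (
    "fmt"
    "net"

    "github.com/xtaci/smux"
)

func main() {
    c1, c2 := net.Pipe()
    done := make(chan struct{})

    go func() { // the "server" end
        defer close(done)
        sess, err := smux.Server(c2, nil) // nil falls back to DefaultConfig()
        if err != nil {
            return
        }
        stream, err := sess.AcceptStream() // blocks until the client opens a stream
        if err != nil {
            return
        }
        buf := make([]byte, 16)
        n, _ := stream.Read(buf)
        fmt.Printf("server got: %s\n", buf[:n])
    }()

    sess, err := smux.Client(c1, nil)
    if err != nil {
        panic(err)
    }
    stream, err := sess.OpenStream()
    if err != nil {
        panic(err)
    }
    stream.Write([]byte("ping"))
    <-done
    stream.Close()
    sess.Close()
}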

vendor/github.com/xtaci/smux/session.go generated vendored Normal file
View File

@@ -0,0 +1,335 @@
package smux
import (
"encoding/binary"
"io"
"sync"
"sync/atomic"
"time"
"github.com/pkg/errors"
)
const (
defaultAcceptBacklog = 1024
)
const (
errBrokenPipe = "broken pipe"
errInvalidProtocol = "invalid protocol version"
)
type writeRequest struct {
frame Frame
result chan writeResult
}
type writeResult struct {
n int
err error
}
// Session defines a multiplexed connection for streams
type Session struct {
conn io.ReadWriteCloser
config *Config
nextStreamID uint32 // next stream identifier
bucket int32 // token bucket
bucketCond *sync.Cond // used for waiting for tokens
streams map[uint32]*Stream // all streams in this session
streamLock sync.Mutex // locks streams
die chan struct{} // flag session has died
dieLock sync.Mutex
chAccepts chan *Stream
dataReady int32 // flag data has arrived
deadline atomic.Value
writes chan writeRequest
}
func newSession(config *Config, conn io.ReadWriteCloser, client bool) *Session {
s := new(Session)
s.die = make(chan struct{})
s.conn = conn
s.config = config
s.streams = make(map[uint32]*Stream)
s.chAccepts = make(chan *Stream, defaultAcceptBacklog)
s.bucket = int32(config.MaxReceiveBuffer)
s.bucketCond = sync.NewCond(&sync.Mutex{})
s.writes = make(chan writeRequest)
if client {
s.nextStreamID = 1
} else {
s.nextStreamID = 2
}
go s.recvLoop()
go s.sendLoop()
go s.keepalive()
return s
}
// OpenStream is used to create a new stream
func (s *Session) OpenStream() (*Stream, error) {
if s.IsClosed() {
return nil, errors.New(errBrokenPipe)
}
sid := atomic.AddUint32(&s.nextStreamID, 2)
stream := newStream(sid, s.config.MaxFrameSize, s)
if _, err := s.writeFrame(newFrame(cmdSYN, sid)); err != nil {
return nil, errors.Wrap(err, "writeFrame")
}
s.streamLock.Lock()
s.streams[sid] = stream
s.streamLock.Unlock()
return stream, nil
}
// AcceptStream is used to block until the next available stream
// is ready to be accepted.
func (s *Session) AcceptStream() (*Stream, error) {
var deadline <-chan time.Time
if d, ok := s.deadline.Load().(time.Time); ok && !d.IsZero() {
timer := time.NewTimer(d.Sub(time.Now()))
defer timer.Stop()
deadline = timer.C
}
select {
case stream := <-s.chAccepts:
return stream, nil
case <-deadline:
return nil, errTimeout
case <-s.die:
return nil, errors.New(errBrokenPipe)
}
}
// Close is used to close the session and all streams.
func (s *Session) Close() (err error) {
s.dieLock.Lock()
select {
case <-s.die:
s.dieLock.Unlock()
return errors.New(errBrokenPipe)
default:
close(s.die)
s.dieLock.Unlock()
s.streamLock.Lock()
for k := range s.streams {
s.streams[k].sessionClose()
}
s.streamLock.Unlock()
s.bucketCond.Signal()
return s.conn.Close()
}
}
// IsClosed does a safe check to see if we have shutdown
func (s *Session) IsClosed() bool {
select {
case <-s.die:
return true
default:
return false
}
}
// NumStreams returns the number of currently open streams
func (s *Session) NumStreams() int {
if s.IsClosed() {
return 0
}
s.streamLock.Lock()
defer s.streamLock.Unlock()
return len(s.streams)
}
// SetDeadline sets a deadline used by Accept* calls.
// A zero time value disables the deadline.
func (s *Session) SetDeadline(t time.Time) error {
s.deadline.Store(t)
return nil
}
// notify the session that a stream has closed
func (s *Session) streamClosed(sid uint32) {
s.streamLock.Lock()
if n := s.streams[sid].recycleTokens(); n > 0 { // return remaining tokens to the bucket
if atomic.AddInt32(&s.bucket, int32(n)) > 0 {
s.bucketCond.Signal()
}
}
delete(s.streams, sid)
s.streamLock.Unlock()
}
// returnTokens is called by stream to return token after read
func (s *Session) returnTokens(n int) {
oldvalue := atomic.LoadInt32(&s.bucket)
newvalue := atomic.AddInt32(&s.bucket, int32(n))
if oldvalue <= 0 && newvalue > 0 {
s.bucketCond.Signal()
}
}
// readFrame reads a frame from the underlying connection;
// its data points into the provided buffer
func (s *Session) readFrame(buffer []byte) (f Frame, err error) {
if _, err := io.ReadFull(s.conn, buffer[:headerSize]); err != nil {
return f, errors.Wrap(err, "readFrame")
}
dec := rawHeader(buffer)
if dec.Version() != version {
return f, errors.New(errInvalidProtocol)
}
f.ver = dec.Version()
f.cmd = dec.Cmd()
f.sid = dec.StreamID()
if length := dec.Length(); length > 0 {
if _, err := io.ReadFull(s.conn, buffer[headerSize:headerSize+length]); err != nil {
return f, errors.Wrap(err, "readFrame")
}
f.data = buffer[headerSize : headerSize+length]
}
return f, nil
}
// recvLoop keeps on reading from underlying connection if tokens are available
func (s *Session) recvLoop() {
buffer := make([]byte, (1<<16)+headerSize)
for {
s.bucketCond.L.Lock()
for atomic.LoadInt32(&s.bucket) <= 0 && !s.IsClosed() {
s.bucketCond.Wait()
}
s.bucketCond.L.Unlock()
if s.IsClosed() {
return
}
if f, err := s.readFrame(buffer); err == nil {
atomic.StoreInt32(&s.dataReady, 1)
switch f.cmd {
case cmdNOP:
case cmdSYN:
s.streamLock.Lock()
if _, ok := s.streams[f.sid]; !ok {
stream := newStream(f.sid, s.config.MaxFrameSize, s)
s.streams[f.sid] = stream
select {
case s.chAccepts <- stream:
case <-s.die:
}
}
s.streamLock.Unlock()
case cmdFIN:
s.streamLock.Lock()
if stream, ok := s.streams[f.sid]; ok {
stream.markRST()
stream.notifyReadEvent()
}
s.streamLock.Unlock()
case cmdPSH:
s.streamLock.Lock()
if stream, ok := s.streams[f.sid]; ok {
atomic.AddInt32(&s.bucket, -int32(len(f.data)))
stream.pushBytes(f.data)
stream.notifyReadEvent()
}
s.streamLock.Unlock()
default:
s.Close()
return
}
} else {
s.Close()
return
}
}
}
func (s *Session) keepalive() {
tickerPing := time.NewTicker(s.config.KeepAliveInterval)
tickerTimeout := time.NewTicker(s.config.KeepAliveTimeout)
defer tickerPing.Stop()
defer tickerTimeout.Stop()
for {
select {
case <-tickerPing.C:
s.writeFrame(newFrame(cmdNOP, 0))
s.bucketCond.Signal() // force a signal to the recvLoop
case <-tickerTimeout.C:
if !atomic.CompareAndSwapInt32(&s.dataReady, 1, 0) {
s.Close()
return
}
case <-s.die:
return
}
}
}
func (s *Session) sendLoop() {
buf := make([]byte, (1<<16)+headerSize)
for {
select {
case <-s.die:
return
case request, ok := <-s.writes:
if !ok {
continue
}
buf[0] = request.frame.ver
buf[1] = request.frame.cmd
binary.LittleEndian.PutUint16(buf[2:], uint16(len(request.frame.data)))
binary.LittleEndian.PutUint32(buf[4:], request.frame.sid)
copy(buf[headerSize:], request.frame.data)
n, err := s.conn.Write(buf[:headerSize+len(request.frame.data)])
n -= headerSize
if n < 0 {
n = 0
}
result := writeResult{
n: n,
err: err,
}
request.result <- result
close(request.result)
}
}
}
// writeFrame writes the frame to the underlying connection
// and returns the number of bytes written if successful
func (s *Session) writeFrame(f Frame) (n int, err error) {
req := writeRequest{
frame: f,
result: make(chan writeResult, 1),
}
select {
case <-s.die:
return 0, errors.New(errBrokenPipe)
case s.writes <- req:
}
result := <-req.result
return result.n, result.err
}
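The token bucket above is what MaxReceiveBuffer actually controls: recvLoop stops pulling frames off the connection once the bucket is empty, and Stream.Read returns tokens as data is consumed, which is how a slow reader back-pressures the remote writer. A sketch that makes this visible by shrinking the budget (the 8 KiB value and the sleep are purely for demonstration):

package main

import (
    "fmt"
    "io"
    "net"
    "time"

    "github.com/xtaci/smux"
)

func main() {
    cfg := smux.DefaultConfig()
    cfg.MaxReceiveBuffer = 8192 // tiny per-session receive budget, demo only

    c1, c2 := net.Pipe()
    server, _ := smux.Server(c2, cfg)
    client, _ := smux.Client(c1, cfg)

    done := make(chan struct{})
    go func() {
        defer close(done)
        s, _ := client.OpenStream()
        start := time.Now()
        s.Write(make([]byte, 64*1024)) // stalls once ~8 KiB sits unread at the receiver
        fmt.Println("write completed after", time.Since(start))
    }()

    s, _ := server.AcceptStream()
    time.Sleep(200 * time.Millisecond) // let the writer run into the limit
    buf := make([]byte, 64*1024)
    io.ReadFull(s, buf) // each Read returns tokens, letting the writer continue
    <-done
}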

vendor/github.com/xtaci/smux/stream.go generated vendored Normal file
View File

@@ -0,0 +1,261 @@
package smux
import (
"bytes"
"io"
"net"
"sync"
"sync/atomic"
"time"
"github.com/pkg/errors"
)
// Stream implements net.Conn
type Stream struct {
id uint32
rstflag int32
sess *Session
buffer bytes.Buffer
bufferLock sync.Mutex
frameSize int
chReadEvent chan struct{} // notify a read event
die chan struct{} // flag the stream has closed
dieLock sync.Mutex
readDeadline atomic.Value
writeDeadline atomic.Value
}
// newStream initiates a Stream struct
func newStream(id uint32, frameSize int, sess *Session) *Stream {
s := new(Stream)
s.id = id
s.chReadEvent = make(chan struct{}, 1)
s.frameSize = frameSize
s.sess = sess
s.die = make(chan struct{})
return s
}
// ID returns the unique stream ID.
func (s *Stream) ID() uint32 {
return s.id
}
// Read implements net.Conn
func (s *Stream) Read(b []byte) (n int, err error) {
var deadline <-chan time.Time
if d, ok := s.readDeadline.Load().(time.Time); ok && !d.IsZero() {
timer := time.NewTimer(d.Sub(time.Now()))
defer timer.Stop()
deadline = timer.C
}
READ:
select {
case <-s.die:
return 0, errors.New(errBrokenPipe)
case <-deadline:
return n, errTimeout
default:
}
s.bufferLock.Lock()
n, err = s.buffer.Read(b)
s.bufferLock.Unlock()
if n > 0 {
s.sess.returnTokens(n)
return n, nil
} else if atomic.LoadInt32(&s.rstflag) == 1 {
_ = s.Close()
return 0, io.EOF
}
select {
case <-s.chReadEvent:
goto READ
case <-deadline:
return n, errTimeout
case <-s.die:
return 0, errors.New(errBrokenPipe)
}
}
// Write implements net.Conn
func (s *Stream) Write(b []byte) (n int, err error) {
var deadline <-chan time.Time
if d, ok := s.writeDeadline.Load().(time.Time); ok && !d.IsZero() {
timer := time.NewTimer(d.Sub(time.Now()))
defer timer.Stop()
deadline = timer.C
}
select {
case <-s.die:
return 0, errors.New(errBrokenPipe)
default:
}
frames := s.split(b, cmdPSH, s.id)
sent := 0
for k := range frames {
req := writeRequest{
frame: frames[k],
result: make(chan writeResult, 1),
}
select {
case s.sess.writes <- req:
case <-s.die:
return sent, errors.New(errBrokenPipe)
case <-deadline:
return sent, errTimeout
}
select {
case result := <-req.result:
sent += result.n
if result.err != nil {
return sent, result.err
}
case <-s.die:
return sent, errors.New(errBrokenPipe)
case <-deadline:
return sent, errTimeout
}
}
return sent, nil
}
// Close implements net.Conn
func (s *Stream) Close() error {
s.dieLock.Lock()
select {
case <-s.die:
s.dieLock.Unlock()
return errors.New(errBrokenPipe)
default:
close(s.die)
s.dieLock.Unlock()
s.sess.streamClosed(s.id)
_, err := s.sess.writeFrame(newFrame(cmdFIN, s.id))
return err
}
}
// SetReadDeadline sets the read deadline as defined by
// net.Conn.SetReadDeadline.
// A zero time value disables the deadline.
func (s *Stream) SetReadDeadline(t time.Time) error {
s.readDeadline.Store(t)
return nil
}
// SetWriteDeadline sets the write deadline as defined by
// net.Conn.SetWriteDeadline.
// A zero time value disables the deadline.
func (s *Stream) SetWriteDeadline(t time.Time) error {
s.writeDeadline.Store(t)
return nil
}
// SetDeadline sets both read and write deadlines as defined by
// net.Conn.SetDeadline.
// A zero time value disables the deadlines.
func (s *Stream) SetDeadline(t time.Time) error {
if err := s.SetReadDeadline(t); err != nil {
return err
}
if err := s.SetWriteDeadline(t); err != nil {
return err
}
return nil
}
// sessionClose closes the stream from the session side, without writing a FIN
func (s *Stream) sessionClose() {
s.dieLock.Lock()
defer s.dieLock.Unlock()
select {
case <-s.die:
default:
close(s.die)
}
}
// LocalAddr satisfies net.Conn interface
func (s *Stream) LocalAddr() net.Addr {
if ts, ok := s.sess.conn.(interface {
LocalAddr() net.Addr
}); ok {
return ts.LocalAddr()
}
return nil
}
// RemoteAddr satisfies net.Conn interface
func (s *Stream) RemoteAddr() net.Addr {
if ts, ok := s.sess.conn.(interface {
RemoteAddr() net.Addr
}); ok {
return ts.RemoteAddr()
}
return nil
}
// pushBytes appends a slice to the stream's receive buffer
func (s *Stream) pushBytes(p []byte) {
s.bufferLock.Lock()
s.buffer.Write(p)
s.bufferLock.Unlock()
}
// recycleTokens transforms the remaining buffered bytes into tokens (truncates the buffer)
func (s *Stream) recycleTokens() (n int) {
s.bufferLock.Lock()
n = s.buffer.Len()
s.buffer.Reset()
s.bufferLock.Unlock()
return
}
// split slices a large byte buffer into frames of at most frameSize bytes;
// the frames reference the original slice rather than copying it
func (s *Stream) split(bts []byte, cmd byte, sid uint32) []Frame {
var frames []Frame
for len(bts) > s.frameSize {
frame := newFrame(cmd, sid)
frame.data = bts[:s.frameSize]
bts = bts[s.frameSize:]
frames = append(frames, frame)
}
if len(bts) > 0 {
frame := newFrame(cmd, sid)
frame.data = bts
frames = append(frames, frame)
}
return frames
}
// notify read event
func (s *Stream) notifyReadEvent() {
select {
case s.chReadEvent <- struct{}{}:
default:
}
}
// markRST marks this stream as reset by the remote
func (s *Stream) markRST() {
atomic.StoreInt32(&s.rstflag, 1)
}
var errTimeout error = &timeoutError{}
type timeoutError struct{}
func (e *timeoutError) Error() string { return "i/o timeout" }
func (e *timeoutError) Timeout() bool { return true }
func (e *timeoutError) Temporary() bool { return true }
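Because Stream carries the full net.Conn method set (Read/Write/Close, addresses, deadlines), it can be handed to any code that expects a net.Conn. A sketch layering TLS on top of a stream; the function name is mine, and the certificate setup and whether Syncthing does exactly this are outside the scope of this diff:

package example

import (
    "crypto/tls"

    "github.com/xtaci/smux"
)

// secureStream wraps an smux stream in TLS. Which side runs tls.Server vs
// tls.Client is up to the caller; *smux.Stream is accepted as a net.Conn.
func secureStream(stream *smux.Stream, cfg *tls.Config, isServer bool) *tls.Conn {
    if isServer {
        return tls.Server(stream, cfg)
    }
    return tls.Client(stream, cfg)
}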