diff --git a/13.go b/13.go index 30ac200..b333d6a 100644 --- a/13.go +++ b/13.go @@ -191,7 +191,7 @@ func (hs *serverHandshakeState) readClientFinished13() error { // happens, even if the Read call might do more work. c.confirmMutex.Unlock() - return hs.sendSessionTicket13() + return hs.sendSessionTicket13() // TODO: do in a goroutine } func (hs *serverHandshakeState) sendCertificate13() error { @@ -460,12 +460,20 @@ func (hs *serverHandshakeState) checkPSK() (earlySecret []byte, alert alert) { hashSize := hash.Size() for i := range hs.clientHello.psks { sessionTicket := append([]uint8{}, hs.clientHello.psks[i].identity...) - serializedTicket, _ := hs.c.decryptTicket(sessionTicket) - if serializedTicket == nil { - continue + if hs.c.config.SessionTicketSealer != nil { + var ok bool + sessionTicket, ok = hs.c.config.SessionTicketSealer.Unseal(hs.clientHelloInfo(), sessionTicket) + if !ok { + continue + } + } else { + sessionTicket, _ = hs.c.decryptTicket(sessionTicket) + if sessionTicket == nil { + continue + } } s := &sessionState13{} - if s.unmarshal(serializedTicket) != alertSuccess { + if s.unmarshal(sessionTicket) != alertSuccess { continue } if s.vers != hs.c.vers { @@ -565,11 +573,21 @@ func (hs *serverHandshakeState) sendSessionTicket13() error { } for i := 0; i < numSessionTickets; i++ { - ticket, err := c.encryptTicket(sessionState.marshal()) + ticket := sessionState.marshal() + var err error + if c.config.SessionTicketSealer != nil { + cs := c.ConnectionState() + ticket, err = c.config.SessionTicketSealer.Seal(&cs, ticket) + } else { + ticket, err = c.encryptTicket(ticket) + } if err != nil { c.sendAlert(alertInternalError) return err } + if ticket == nil { + continue + } ticketMsg := &newSessionTicketMsg13{ lifetime: 24 * 3600, // TODO(filippo) maxEarlyDataLength: c.config.Max0RTTDataSize, diff --git a/common.go b/common.go index 96d3af9..b421296 100644 --- a/common.go +++ b/common.go @@ -592,6 +592,10 @@ type Config struct { // See https://tools.ietf.org/html/draft-ietf-tls-tls13-18#section-2.3. Accept0RTTData bool + // SessionTicketSealer, if not nil, is used to wrap and unwrap + // session tickets, instead of SessionTicketKey. + SessionTicketSealer SessionTicketSealer + serverInitOnce sync.Once // guards calling (*Config).serverInit // mutex protects sessionTicketKeys. @@ -668,6 +672,7 @@ func (c *Config) Clone() *Config { KeyLogWriter: c.KeyLogWriter, Accept0RTTData: c.Accept0RTTData, Max0RTTDataSize: c.Max0RTTDataSize, + SessionTicketSealer: c.SessionTicketSealer, sessionTicketKeys: sessionTicketKeys, } } @@ -676,7 +681,7 @@ func (c *Config) Clone() *Config { // returned by a GetConfigForClient callback then the argument should be the // Config that was passed to Server, otherwise it should be nil. func (c *Config) serverInit(originalConfig *Config) { - if c.SessionTicketsDisabled || len(c.ticketKeys()) != 0 { + if c.SessionTicketsDisabled || len(c.ticketKeys()) != 0 || c.SessionTicketSealer != nil { return } diff --git a/ticket.go b/ticket.go index eac9416..b8de132 100644 --- a/ticket.go +++ b/ticket.go @@ -15,6 +15,23 @@ import ( "io" ) +// A SessionTicketSealer provides a way to securely encapsulate +// session state for storage on the client. All methods are safe for +// concurrent use. +type SessionTicketSealer interface { + // Seal returns a session ticket value that can be later passed to Unseal + // to recover the content, usually by encrypting it. The ticket will be sent + // to the client to be stored, and will be sent back in plaintext, so it can + // be read and modified by an attacker. + Seal(cs *ConnectionState, content []byte) (ticket []byte, err error) + + // Unseal returns a session ticket contents. The ticket can't be safely + // assumed to have been generated by Seal. + // If unable to unseal the ticket, the connection will proceed with a + // complete handshake. + Unseal(chi *ClientHelloInfo, ticket []byte) (content []byte, success bool) +} + // sessionState contains the information that is serialized into a session // ticket in order to later resume a connection. type sessionState struct { diff --git a/tls_test.go b/tls_test.go index d352be8..b9b3bf9 100644 --- a/tls_test.go +++ b/tls_test.go @@ -656,6 +656,8 @@ func TestCloneNonFuncFields(t *testing.T) { f.Set(reflect.ValueOf(RenegotiateOnceAsClient)) case "Max0RTTDataSize": f.Set(reflect.ValueOf(uint32(0))) + case "SessionTicketSealer": + // TODO default: t.Errorf("all fields must be accounted for, but saw unknown field %q", fn) }