新注册的用户请输入邮箱并保存,随后登录邮箱激活账号。后续可直接使用邮箱登录!

unsafe_test.go 4.44 KB
package qtls

import (
	"bytes"
	"errors"
	"fmt"
	"reflect"
	"testing"
	"time"

	cmx509 "chainmaker.org/chainmaker/common/v2/crypto/x509"
)

func TestUnsafeConversionIsSafe(t *testing.T) {
	type target struct {
		Name    string
		Version string

		callback func(label string, length int) error
	}

	type renamedField struct {
		NewName string
		Version string

		callback func(label string, length int) error
	}

	type renamedPrivateField struct {
		Name    string
		Version string

		cb func(label string, length int) error
	}

	type additionalField struct {
		Name    string
		Version string

		callback func(label string, length int) error
		secret   []byte
	}

	type interchangedFields struct {
		Version string
		Name    string

		callback func(label string, length int) error
	}

	type renamedCallbackFunctionParams struct { // should be equivalent
		Name    string
		Version string

		callback func(newLabel string, length int) error
	}

	testCases := []struct {
		name string
		from interface{}
		to   interface{}
		safe bool
	}{
		{"same struct", &target{}, &target{}, true},
		{"struct with a renamed field", &target{}, &renamedField{}, false},
		{"struct with a renamed private field", &target{}, &renamedPrivateField{}, false},
		{"struct with an additional field", &target{}, &additionalField{}, false},
		{"struct with interchanged fields", &target{}, &interchangedFields{}, false},
		{"struct with a renamed callback parameter", &target{}, &renamedCallbackFunctionParams{}, true},
	}

	for _, testCase := range testCases {
		t.Run(fmt.Sprintf("unsafe conversion: %s", testCase.name), func(t *testing.T) {
			if structsEqual(testCase.from, testCase.to) != testCase.safe {
				t.Errorf("invalid unsafe conversion")
			}
		})
	}
}

func TestConnectionStateReinterpretCast(t *testing.T) {
	var ekmLabel string
	var ekmContext []byte
	var ekmLength int
	state := connectionState{
		Version:            1234,
		HandshakeComplete:  true,
		DidResume:          true,
		CipherSuite:        4321,
		NegotiatedProtocol: "foobar",
		ServerName:         "server",
		PeerCertificates:   []*cmx509.Certificate{{Raw: []byte("foobar")}},
		OCSPResponse:       []byte("foo"),
		TLSUnique:          []byte("bar"),
		ekm: func(label string, context []byte, length int) ([]byte, error) {
			ekmLabel = label
			ekmContext = append(ekmContext, context...)
			ekmLength = length
			return []byte("ekm"), errors.New("ekm error")
		},
	}
	tlsState := toConnectionState(state)
	if tlsState.Version != 1234 {
		t.Error("Version doesn't match")
	}
	if !tlsState.HandshakeComplete {
		t.Error("HandshakeComplete doesn't match")
	}
	if !tlsState.DidResume {
		t.Error("DidResume doesn't match")
	}
	if tlsState.CipherSuite != 4321 {
		t.Error("CipherSuite doesn't match")
	}
	if tlsState.NegotiatedProtocol != "foobar" {
		t.Error("NegotiatedProtocol doesn't match")
	}
	if tlsState.ServerName != "server" {
		t.Error("ServerName doesn't match")
	}
	if len(tlsState.PeerCertificates) != 1 || !bytes.Equal(tlsState.PeerCertificates[0].Raw, []byte("foobar")) {
		t.Error("PeerCertificates don't match")
	}
	if !bytes.Equal(tlsState.OCSPResponse, []byte("foo")) {
		t.Error("OSCPResponse doesn't match")
	}
	if !bytes.Equal(tlsState.TLSUnique, []byte("bar")) {
		t.Error("TLSUnique doesn't match")
	}

	key, err := tlsState.ExportKeyingMaterial("label", []byte("context"), 42)
	if !bytes.Equal(key, []byte("ekm")) {
		t.Error("exported key doesn't match")
	}
	if err == nil || err.Error() != "ekm error" {
		t.Error("key export error doesn't match")
	}
	if ekmLabel != "label" {
		t.Error("key export label doesn't match")
	}
	if !bytes.Equal(ekmContext, []byte("context")) {
		t.Error("key export context doesn't match")
	}
	if ekmLength != 42 {
		t.Error("key export length doesn't match")
	}
}

func TestClientSessionStateReinterpretCast(t *testing.T) {
	state := &clientSessionState{
		sessionTicket: []byte("foobar"),
		receivedAt:    time.Now(),
		nonce:         []byte("foo"),
		useBy:         time.Now().Add(time.Hour),
		ageAdd:        1234,
	}
	if !reflect.DeepEqual(fromClientSessionState(toClientSessionState(state)), state) {
		t.Fatal("failed")
	}
}

// func TestClientSessionStateReinterpretCast(t *testing.T) {
// 	state := &clientSessionState{
// 		sessionTicket: []byte("foobar"),
// 		receivedAt:    time.Now(),
// 		nonce:         []byte("foo"),
// 		useBy:         time.Now().Add(time.Hour),
// 		ageAdd:        1234,
// 	}
// 	if !reflect.DeepEqual(fromClientSessionState(toClientSessionState(state)), state) {
// 		t.Fatal("failed")
// 	}
// }