diff --git a/crypto.go b/crypto.go index 1a2c590..cc2f92e 100644 --- a/crypto.go +++ b/crypto.go @@ -7,7 +7,6 @@ import ( "crypto/sha256" "encoding/hex" "encoding/json" - "errors" "golang.org/x/crypto/blowfish" "io" ) @@ -16,7 +15,20 @@ type BaseMessage struct { Signature string } -func DecryptAndVerify(input []byte, key *rsa.PrivateKey, pub *rsa.PublicKey, signature []byte, v interface{}) error { +func (it *BaseMessage) GetSignature() string { + return it.Signature +} + +func (it *BaseMessage) SetSignature(sig string) { + it.Signature = sig +} + +type Message interface { + SetSignature(sig string) + GetSignature() string +} + +func DecryptAndVerify(input []byte, key *rsa.PrivateKey, pub *rsa.PublicKey, signature []byte, v Message) error { all, err := hex.DecodeString(string(input)) if err != nil { return err @@ -49,12 +61,7 @@ func DecryptAndVerify(input []byte, key *rsa.PrivateKey, pub *rsa.PublicKey, sig return err } - baseReq, ok := v.(*BaseMessage) - if !ok { - return errors.New("non-base request") - } - - recvSignature, err := hex.DecodeString(baseReq.Signature) + recvSignature, err := hex.DecodeString(v.GetSignature()) if err != nil { return err } @@ -65,7 +72,7 @@ func DecryptAndVerify(input []byte, key *rsa.PrivateKey, pub *rsa.PublicKey, sig return err } -func EncryptAndSign(v interface{}, key *rsa.PrivateKey, pub *rsa.PublicKey, signature []byte, writer io.Writer) error { +func EncryptAndSign(v Message, key *rsa.PrivateKey, pub *rsa.PublicKey, signature []byte, writer io.Writer) error { passwordAndIV := make([]byte, 40) _, err := rand.Read(passwordAndIV) if err != nil { @@ -96,11 +103,7 @@ func EncryptAndSign(v interface{}, key *rsa.PrivateKey, pub *rsa.PublicKey, sign return err } - sign, ok := v.(*BaseMessage) - if !ok { - return errors.New("given message can't be signed") - } - sign.Signature = hex.EncodeToString(plainSignature) + v.SetSignature(hex.EncodeToString(plainSignature)) body, err := json.Marshal(v) diff --git a/crypto_test.go b/crypto_test.go index 71032cf..3dae43e 100644 --- a/crypto_test.go +++ b/crypto_test.go @@ -28,7 +28,10 @@ func TestSingleRound(t *testing.T) { keyA, _ := rsa.GenerateKey(rand.Reader, 4096) keyB, _ := rsa.GenerateKey(rand.Reader, 4096) signature := []byte("Hello world!") - v := &service.BaseMessage{} + v := &service.UpdateOpenVPNConfigRequest{ + BaseMessage: &service.BaseMessage{""}, + Config: "Hello world!", + } x := make([]byte, 0) b := bytes.NewBuffer(x) err := service.EncryptAndSign(v, keyA, &keyB.PublicKey, signature, b) @@ -37,9 +40,16 @@ func TestSingleRound(t *testing.T) { t.Errorf("Failed encrypting: %s", err) } - err = service.DecryptAndVerify(b.Bytes(), keyB, &keyA.PublicKey, signature, &service.BaseMessage{}) + ov := &service.UpdateOpenVPNConfigRequest{ + BaseMessage: &service.BaseMessage{""}, + } + err = service.DecryptAndVerify(b.Bytes(), keyB, &keyA.PublicKey, signature, ov) if err != nil { t.Errorf("Failed decrypting: %s", err) } + + if ov.Config != "Hello world!" { + t.Errorf("Config scrambled: %s", ov.Config) + } } diff --git a/http.go b/http.go index c25a463..057b9e5 100644 --- a/http.go +++ b/http.go @@ -21,22 +21,22 @@ type HttpServer struct { } type CreateCSRRequest struct { - BaseMessage + *BaseMessage Hostname string } type CreateCSRResponse struct { - BaseMessage + *BaseMessage CSR string `json:"csr"` } type UpdateOpenVPNConfigRequest struct { - BaseMessage + *BaseMessage Config string } type DeliverCertificateRequest struct { - BaseMessage + *BaseMessage Certificate string } @@ -53,7 +53,9 @@ func (it *HttpServer) Start() { http.HandleFunc("/create-csr", func(writer http.ResponseWriter, request *http.Request) { log.Printf("%s /create-csr", strings.ToUpper(request.Method)) - req := &CreateCSRRequest{} + req := &CreateCSRRequest{ + BaseMessage: &BaseMessage{""}, + } err := it.verifyRequest(request, req) if err != nil { @@ -69,13 +71,16 @@ func (it *HttpServer) Start() { } it.writeResponse(writer, CreateCSRResponse{ - CSR: string(csr), + BaseMessage: &BaseMessage{""}, + CSR: string(csr), }) }) http.HandleFunc("/deliver-crt", func(writer http.ResponseWriter, request *http.Request) { log.Printf("%s /deliver-crt", strings.ToUpper(request.Method)) - req := &DeliverCertificateRequest{} + req := &DeliverCertificateRequest{ + BaseMessage: &BaseMessage{""}, + } err := it.verifyRequest(request, req) if err != nil { log.Printf("Error on %s %s: %s", request.Method, request.URL.Path, err) @@ -94,7 +99,9 @@ func (it *HttpServer) Start() { http.HandleFunc("/update-openvpn-config", func(writer http.ResponseWriter, request *http.Request) { log.Printf("%s /update-openvpn-config", strings.ToUpper(request.Method)) - req := &UpdateOpenVPNConfigRequest{} + req := &UpdateOpenVPNConfigRequest{ + BaseMessage: &BaseMessage{""}, + } err := it.verifyRequest(request, req) if err != nil { log.Printf("Error on %s %s: %s", request.Method, request.URL.Path, err) @@ -114,11 +121,11 @@ func (it *HttpServer) Start() { http.ListenAndServe(":7864", nil) } -func (it *HttpServer) writeResponse(writer http.ResponseWriter, v interface{}) error { +func (it *HttpServer) writeResponse(writer http.ResponseWriter, v Message) error { return EncryptAndSign(v, it.manager.privateKey, it.manager.CAPublicKey(), it.manager.GetServerFingerprint(), writer) } -func (it *HttpServer) verifyRequest(r *http.Request, v interface{}) (error) { +func (it *HttpServer) verifyRequest(r *http.Request, v Message) (error) { hexBody, err := ioutil.ReadAll(r.Body) if err != nil { return err