Interface fuckery

master
eater 5 years ago
parent 3ee5beac97
commit ced1df310f
Signed by: eater
GPG Key ID: 656785D50BE51C0A

@ -7,7 +7,6 @@ import (
"crypto/sha256" "crypto/sha256"
"encoding/hex" "encoding/hex"
"encoding/json" "encoding/json"
"errors"
"golang.org/x/crypto/blowfish" "golang.org/x/crypto/blowfish"
"io" "io"
) )
@ -16,7 +15,20 @@ type BaseMessage struct {
Signature string Signature string
} }
func DecryptAndVerify(input []byte, key *rsa.PrivateKey, pub *rsa.PublicKey, signature []byte, v interface{}) error { func (it *BaseMessage) GetSignature() string {
return it.Signature
}
func (it *BaseMessage) SetSignature(sig string) {
it.Signature = sig
}
type Message interface {
SetSignature(sig string)
GetSignature() string
}
func DecryptAndVerify(input []byte, key *rsa.PrivateKey, pub *rsa.PublicKey, signature []byte, v Message) error {
all, err := hex.DecodeString(string(input)) all, err := hex.DecodeString(string(input))
if err != nil { if err != nil {
return err return err
@ -49,12 +61,7 @@ func DecryptAndVerify(input []byte, key *rsa.PrivateKey, pub *rsa.PublicKey, sig
return err return err
} }
baseReq, ok := v.(*BaseMessage) recvSignature, err := hex.DecodeString(v.GetSignature())
if !ok {
return errors.New("non-base request")
}
recvSignature, err := hex.DecodeString(baseReq.Signature)
if err != nil { if err != nil {
return err return err
} }
@ -65,7 +72,7 @@ func DecryptAndVerify(input []byte, key *rsa.PrivateKey, pub *rsa.PublicKey, sig
return err return err
} }
func EncryptAndSign(v interface{}, key *rsa.PrivateKey, pub *rsa.PublicKey, signature []byte, writer io.Writer) error { func EncryptAndSign(v Message, key *rsa.PrivateKey, pub *rsa.PublicKey, signature []byte, writer io.Writer) error {
passwordAndIV := make([]byte, 40) passwordAndIV := make([]byte, 40)
_, err := rand.Read(passwordAndIV) _, err := rand.Read(passwordAndIV)
if err != nil { if err != nil {
@ -96,11 +103,7 @@ func EncryptAndSign(v interface{}, key *rsa.PrivateKey, pub *rsa.PublicKey, sign
return err return err
} }
sign, ok := v.(*BaseMessage) v.SetSignature(hex.EncodeToString(plainSignature))
if !ok {
return errors.New("given message can't be signed")
}
sign.Signature = hex.EncodeToString(plainSignature)
body, err := json.Marshal(v) body, err := json.Marshal(v)

@ -28,7 +28,10 @@ func TestSingleRound(t *testing.T) {
keyA, _ := rsa.GenerateKey(rand.Reader, 4096) keyA, _ := rsa.GenerateKey(rand.Reader, 4096)
keyB, _ := rsa.GenerateKey(rand.Reader, 4096) keyB, _ := rsa.GenerateKey(rand.Reader, 4096)
signature := []byte("Hello world!") signature := []byte("Hello world!")
v := &service.BaseMessage{} v := &service.UpdateOpenVPNConfigRequest{
BaseMessage: &service.BaseMessage{""},
Config: "Hello world!",
}
x := make([]byte, 0) x := make([]byte, 0)
b := bytes.NewBuffer(x) b := bytes.NewBuffer(x)
err := service.EncryptAndSign(v, keyA, &keyB.PublicKey, signature, b) err := service.EncryptAndSign(v, keyA, &keyB.PublicKey, signature, b)
@ -37,9 +40,16 @@ func TestSingleRound(t *testing.T) {
t.Errorf("Failed encrypting: %s", err) t.Errorf("Failed encrypting: %s", err)
} }
err = service.DecryptAndVerify(b.Bytes(), keyB, &keyA.PublicKey, signature, &service.BaseMessage{}) ov := &service.UpdateOpenVPNConfigRequest{
BaseMessage: &service.BaseMessage{""},
}
err = service.DecryptAndVerify(b.Bytes(), keyB, &keyA.PublicKey, signature, ov)
if err != nil { if err != nil {
t.Errorf("Failed decrypting: %s", err) t.Errorf("Failed decrypting: %s", err)
} }
if ov.Config != "Hello world!" {
t.Errorf("Config scrambled: %s", ov.Config)
}
} }

@ -21,22 +21,22 @@ type HttpServer struct {
} }
type CreateCSRRequest struct { type CreateCSRRequest struct {
BaseMessage *BaseMessage
Hostname string Hostname string
} }
type CreateCSRResponse struct { type CreateCSRResponse struct {
BaseMessage *BaseMessage
CSR string `json:"csr"` CSR string `json:"csr"`
} }
type UpdateOpenVPNConfigRequest struct { type UpdateOpenVPNConfigRequest struct {
BaseMessage *BaseMessage
Config string Config string
} }
type DeliverCertificateRequest struct { type DeliverCertificateRequest struct {
BaseMessage *BaseMessage
Certificate string Certificate string
} }
@ -53,7 +53,9 @@ func (it *HttpServer) Start() {
http.HandleFunc("/create-csr", func(writer http.ResponseWriter, request *http.Request) { http.HandleFunc("/create-csr", func(writer http.ResponseWriter, request *http.Request) {
log.Printf("%s /create-csr", strings.ToUpper(request.Method)) log.Printf("%s /create-csr", strings.ToUpper(request.Method))
req := &CreateCSRRequest{} req := &CreateCSRRequest{
BaseMessage: &BaseMessage{""},
}
err := it.verifyRequest(request, req) err := it.verifyRequest(request, req)
if err != nil { if err != nil {
@ -69,13 +71,16 @@ func (it *HttpServer) Start() {
} }
it.writeResponse(writer, CreateCSRResponse{ it.writeResponse(writer, CreateCSRResponse{
CSR: string(csr), BaseMessage: &BaseMessage{""},
CSR: string(csr),
}) })
}) })
http.HandleFunc("/deliver-crt", func(writer http.ResponseWriter, request *http.Request) { http.HandleFunc("/deliver-crt", func(writer http.ResponseWriter, request *http.Request) {
log.Printf("%s /deliver-crt", strings.ToUpper(request.Method)) log.Printf("%s /deliver-crt", strings.ToUpper(request.Method))
req := &DeliverCertificateRequest{} req := &DeliverCertificateRequest{
BaseMessage: &BaseMessage{""},
}
err := it.verifyRequest(request, req) err := it.verifyRequest(request, req)
if err != nil { if err != nil {
log.Printf("Error on %s %s: %s", request.Method, request.URL.Path, err) log.Printf("Error on %s %s: %s", request.Method, request.URL.Path, err)
@ -94,7 +99,9 @@ func (it *HttpServer) Start() {
http.HandleFunc("/update-openvpn-config", func(writer http.ResponseWriter, request *http.Request) { http.HandleFunc("/update-openvpn-config", func(writer http.ResponseWriter, request *http.Request) {
log.Printf("%s /update-openvpn-config", strings.ToUpper(request.Method)) log.Printf("%s /update-openvpn-config", strings.ToUpper(request.Method))
req := &UpdateOpenVPNConfigRequest{} req := &UpdateOpenVPNConfigRequest{
BaseMessage: &BaseMessage{""},
}
err := it.verifyRequest(request, req) err := it.verifyRequest(request, req)
if err != nil { if err != nil {
log.Printf("Error on %s %s: %s", request.Method, request.URL.Path, err) log.Printf("Error on %s %s: %s", request.Method, request.URL.Path, err)
@ -114,11 +121,11 @@ func (it *HttpServer) Start() {
http.ListenAndServe(":7864", nil) http.ListenAndServe(":7864", nil)
} }
func (it *HttpServer) writeResponse(writer http.ResponseWriter, v interface{}) error { func (it *HttpServer) writeResponse(writer http.ResponseWriter, v Message) error {
return EncryptAndSign(v, it.manager.privateKey, it.manager.CAPublicKey(), it.manager.GetServerFingerprint(), writer) return EncryptAndSign(v, it.manager.privateKey, it.manager.CAPublicKey(), it.manager.GetServerFingerprint(), writer)
} }
func (it *HttpServer) verifyRequest(r *http.Request, v interface{}) (error) { func (it *HttpServer) verifyRequest(r *http.Request, v Message) (error) {
hexBody, err := ioutil.ReadAll(r.Body) hexBody, err := ioutil.ReadAll(r.Body)
if err != nil { if err != nil {
return err return err

Loading…
Cancel
Save