You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

340 lines
8.2 KiB
Go

5 years ago
package service
import (
"bufio"
"crypto/md5"
"crypto/rand"
"crypto/rsa"
"crypto/sha256"
"crypto/x509"
"crypto/x509/pkix"
"encoding/hex"
"encoding/pem"
"io/ioutil"
"log"
"net/http"
"net/url"
"os"
"os/signal"
"strconv"
"strings"
"syscall"
)
// Manager holds the server-side state used by the methods in this file:
// the configuration, the RSA private key, the parsed CA certificate and
// the OpenVPN handle.
type Manager struct {
// Config is the configuration handed to NewManager. The methods in this
// file read Config.Zerooo.Location (directory for key/certificate files)
// and Config.Zerooo.Endpoint (base URL for downloads and registration).
Config Config
// privateKey is set by GenerateKey or LoadKey.
privateKey *rsa.PrivateKey
// openVPN is assigned in Daemon from NewOpenVPN.
openVPN *OpenVPN
// CA is the certificate parsed by LoadCA.
CA *x509.Certificate
// NOTE(review): HasCertificate and HasOpenVPNConfig are not written by any
// method visible in this part of the file — confirm where they are maintained.
HasCertificate bool
HasOpenVPNConfig bool
}
// NewManager returns a Manager carrying cfg. All other fields are left at
// their zero values.
func NewManager(cfg Config) *Manager {
	manager := new(Manager)
	manager.Config = cfg
	return manager
}
// fileExists reports whether os.Stat succeeds for s. A "does not exist"
// error yields false; any other stat error is handed to Check, after which
// false is returned.
func fileExists(s string) bool {
	_, statErr := os.Stat(s)
	switch {
	case statErr == nil:
		return true
	case os.IsNotExist(statErr):
		return false
	}
	// Neither present nor cleanly absent (e.g. permission problems).
	Check(statErr, "Failed to check %q, err: %s", s, statErr)
	return false
}
// GetOpenVPNConfigLocation returns the path of server.conf inside
// Config.Zerooo.Location.
func (it *Manager) GetOpenVPNConfigLocation() string {
	const fileName = "/server.conf"
	return it.Config.Zerooo.Location + fileName
}
// GetCALocation returns the path of ca.crt inside Config.Zerooo.Location.
func (it *Manager) GetCALocation() string {
	const fileName = "/ca.crt"
	return it.Config.Zerooo.Location + fileName
}
// GetCertificateLocation returns the path of server.crt inside
// Config.Zerooo.Location.
func (it *Manager) GetCertificateLocation() string {
	const fileName = "/server.crt"
	return it.Config.Zerooo.Location + fileName
}
// GetKeyLocation returns the path of server.key inside
// Config.Zerooo.Location.
func (it *Manager) GetKeyLocation() string {
	const fileName = "/server.key"
	return it.Config.Zerooo.Location + fileName
}
// GetDHLocation returns the path of dh2048.pem inside
// Config.Zerooo.Location.
func (it *Manager) GetDHLocation() string {
	const fileName = "/dh2048.pem"
	return it.Config.Zerooo.Location + fileName
}
// GetCRLLocation returns the path of crl.pem inside Config.Zerooo.Location.
func (it *Manager) GetCRLLocation() string {
	const fileName = "/crl.pem"
	return it.Config.Zerooo.Location + fileName
}
// EnsureFile runs generator when fileExists reports that path is absent.
// It returns true when generator was invoked and false when the file was
// already there.
func (it *Manager) EnsureFile(path string, generator func()) bool {
	if fileExists(path) {
		return false
	}
	generator()
	return true
}
// Init prepares the files kept under Config.Zerooo.Location.
//
// It creates the directory when missing, generates the private key when no
// key file exists (loading the existing one otherwise), downloads the CA
// certificate when missing and then parses it, and finally downloads the
// CRL and DH parameters when their files are missing.
//
// A failure to create the directory is reported through Check rather than
// being dropped, so later file operations do not fail with a less helpful
// message.
func (it *Manager) Init() {
	if !fileExists(it.Config.Zerooo.Location) {
		err := os.MkdirAll(it.Config.Zerooo.Location, 0755)
		Check(err, "Failed to create %q, err: %s", it.Config.Zerooo.Location, err)
	}

	// A freshly generated key is already held in memory; only an existing
	// key file has to be read back.
	generated := it.EnsureFile(it.GetKeyLocation(), func() {
		it.GenerateKey()
	})
	if !generated {
		it.LoadKey()
	}

	it.EnsureFile(it.GetCALocation(), func() {
		it.DownloadCA()
	})
	it.LoadCA()

	it.EnsureFile(it.GetCRLLocation(), func() {
		it.DownloadCRL()
	})
	it.EnsureFile(it.GetDHLocation(), func() {
		it.DownloadDH()
	})
}
// GetServerFingerprint returns, as bytes, the colon-separated fingerprint
// produced by GetOpenSSLLikeFingerprint for the DER-encoded public key.
func (it *Manager) GetServerFingerprint() []byte {
	return []byte(it.GetOpenSSLLikeFingerprint(it.ExportPublicKeyDER()))
}
// GetOpenSSLLikeFingerprint returns the MD5 digest of input rendered as
// sixteen lowercase two-digit hex values joined by ':'.
func (it *Manager) GetOpenSSLLikeFingerprint(input []byte) string {
	digest := md5.Sum(input)
	// hex.EncodeToString emits lowercase digits, two per byte.
	encoded := hex.EncodeToString(digest[:])
	pairs := make([]string, 0, len(digest))
	for pos := 0; pos < len(encoded); pos += 2 {
		pairs = append(pairs, encoded[pos:pos+2])
	}
	return strings.Join(pairs, ":")
}
// Register announces this server to the endpoint at
// Config.Zerooo.Endpoint + "/server/register".
//
// It signs the server fingerprint with the private key, hex-encodes the
// SHA-256 of that signature and posts it together with the PEM public key
// as the form fields "signature" and "publicKey". The response body is
// expected to start with a numeric status line; 0 means success and the
// rest of the body is treated as the server's message.
//
// Failures are reported through Check/Error, which are defined elsewhere
// (presumably they abort the program — confirm).
func (it *Manager) Register() {
	fp := it.GetServerFingerprint()
	log.Printf("%s\n", fp)

	// Hash 0 makes SignPKCS1v15 sign fp directly instead of a digest of it.
	signed, err := rsa.SignPKCS1v15(rand.Reader, it.privateKey, 0, fp)
	Check(err, "Failed signing fingerprint, err: %s", err)
	signedHash := sha256.Sum256(signed)
	publicKey := string(it.ExportPublicKey())
	hexSignature := hex.EncodeToString(signedHash[:])
	log.Printf("Registering with\n\nsignature: %s\npublic key:\n\n%s", hexSignature, publicKey)

	resp, err := http.PostForm(it.Config.Zerooo.Endpoint+"/server/register", url.Values{
		"publicKey": []string{publicKey},
		"signature": []string{hexSignature},
	})
	Check(err, "Failed to sent register post to endpoint (%s/server/register), err: %s", it.Config.Zerooo.Endpoint, err)
	// Release the connection once the response has been consumed.
	defer resp.Body.Close()
	if resp.StatusCode != 200 {
		Error("Endpoint responded with %d, may not be actual zerooo endpoint", resp.StatusCode)
	}

	body, err := ioutil.ReadAll(resp.Body)
	Check(err, "Failed to read response, err: %s", err)
	items := strings.SplitN(string(body), "\n", 2)
	code, err := strconv.ParseInt(items[0], 10, 64)
	Check(err, "Failed reading response success code, err: %s", err)

	// A response consisting of the status line only has no message part;
	// indexing items[1] unconditionally would panic in that case.
	message := ""
	if len(items) > 1 {
		message = items[1]
	}
	if code != 0 {
		Error("Register failed, server responded with:\n\n%s", message)
	}
	log.Printf("Register succesful, server responded with: %s", message)
}
// ExportPublicKeyDER returns the PKIX (DER) encoding of the public half of
// the loaded private key. A missing key is reported through Error and a
// marshalling failure through Check.
func (it *Manager) ExportPublicKeyDER() []byte {
	if it.privateKey == nil {
		Error("No private key is loaded, can't export public key")
	}
	publicKey := &it.privateKey.PublicKey
	encoded, marshalErr := x509.MarshalPKIXPublicKey(publicKey)
	Check(marshalErr, "Failed converting public key to DER, err: %s", marshalErr)
	return encoded
}
// ExportPublicKey returns the public key as a PEM block of type
// "PUBLIC KEY" wrapping the output of ExportPublicKeyDER.
func (it *Manager) ExportPublicKey() []byte {
	return pem.EncodeToMemory(&pem.Block{
		Type:  "PUBLIC KEY",
		Bytes: it.ExportPublicKeyDER(),
	})
}
// GenerateKey creates a 4096-bit RSA key, keeps it on the Manager and
// writes it as a PKCS#1 "RSA PRIVATE KEY" PEM block to GetKeyLocation().
// The file is created exclusively (O_EXCL) with mode 0600, so an existing
// key file is never overwritten. Failures are reported through Check.
func (it *Manager) GenerateKey() {
	log.Printf("No key exists, generating new 4096 RSA key")
	generated, genErr := rsa.GenerateKey(rand.Reader, 4096)
	Check(genErr, "Couldn't generate key, err: %s", genErr)
	it.privateKey = generated

	block := pem.Block{
		Type:  "RSA PRIVATE KEY",
		Bytes: x509.MarshalPKCS1PrivateKey(generated),
	}

	out, openErr := os.OpenFile(it.GetKeyLocation(), os.O_EXCL|os.O_CREATE|os.O_RDWR, 0600)
	Check(openErr, "Failed creating file at %q for private key, err: %s", it.GetKeyLocation(), openErr)
	defer out.Close()

	encodeErr := pem.Encode(out, &block)
	Check(encodeErr, "Failed encoding private key, err: %s", encodeErr)
}
// LoadKey reads the PEM file at GetKeyLocation(), parses its first block as
// a PKCS#1 RSA private key and stores the key on the Manager.
//
// The whole file is read until EOF. This avoids sizing a buffer from an
// unchecked Stat call and relying on a single Read to fill it. Failures are
// reported through Check/Error, which are defined elsewhere (presumably
// they abort the program — confirm).
func (it *Manager) LoadKey() {
	log.Printf("Loading generated key")
	keyFile, err := os.Open(it.GetKeyLocation())
	Check(err, "Failed opening private key at %q for reading, err: %s", it.GetKeyLocation(), err)
	defer keyFile.Close()

	pemBytes, err := ioutil.ReadAll(bufio.NewReader(keyFile))
	Check(err, "Failed to read private key at %q, err: %s", it.GetKeyLocation(), err)

	// Only the first PEM block is considered; trailing data is ignored.
	data, _ := pem.Decode(pemBytes)
	if data == nil {
		Error("No valid key found at %q", it.GetKeyLocation())
	}
	importedKey, err := x509.ParsePKCS1PrivateKey(data.Bytes)
	Check(err, "Failed importing private key at %q, err: %s", it.GetKeyLocation(), err)
	it.privateKey = importedKey
}
// Daemon starts the HTTP server in a goroutine, starts OpenVPN, and then
// blocks until SIGTERM or SIGINT arrives, at which point it logs the signal
// and exits the process with status 0.
func (it *Manager) Daemon() {
	server := NewHttpServer(it)
	go server.Start()
	it.openVPN = NewOpenVPN(it)
	it.openVPN.Start()

	// signal.Notify does not block when sending, so the channel needs a
	// buffer or a signal arriving before the receive below could be lost.
	// SIGKILL and SIGSTOP are not listed because they cannot be caught.
	ch := make(chan os.Signal, 1)
	signal.Notify(ch, syscall.SIGTERM, syscall.SIGINT)

	sig := <-ch
	log.Printf("Received signal %s, quitting.", sig)
	os.Exit(0)
}
// LoadCA reads the PEM file at GetCALocation(), parses its first block as
// an X.509 certificate and stores it in it.CA.
//
// All diagnostics report the CA certificate path; they previously printed
// the private key path. The whole file is read until EOF instead of
// relying on an unchecked Stat and a single Read. Failures are reported
// through Check/Error, which are defined elsewhere (presumably they abort
// the program — confirm).
func (it *Manager) LoadCA() {
	log.Printf("Loading CA certificate")
	caFile, err := os.Open(it.GetCALocation())
	Check(err, "Failed opening CA certificate at %q for reading, err: %s", it.GetCALocation(), err)
	defer caFile.Close()

	pemBytes, err := ioutil.ReadAll(bufio.NewReader(caFile))
	Check(err, "Failed to read CA certificate at %q, err: %s", it.GetCALocation(), err)

	// Only the first PEM block is considered; trailing data is ignored.
	data, _ := pem.Decode(pemBytes)
	if data == nil {
		Error("No valid CA certificate found at %q", it.GetCALocation())
	}
	ca, err := x509.ParseCertificate(data.Bytes)
	Check(err, "Failed importing CA certificate at %q, err: %s", it.GetCALocation(), err)
	it.CA = ca
}
// DownloadCA fetches Config.Zerooo.Endpoint + "/ca" and stores it at
// GetCALocation() via download.
func (it *Manager) DownloadCA() {
	download("CA", it.GetCALocation(), it.Config.Zerooo.Endpoint+"/ca")
}
// CAPublicKey returns the public key of the loaded CA certificate as an
// *rsa.PublicKey. When the certificate's key is of another type, Error is
// called and the (nil) result of the failed type assertion is returned.
func (it *Manager) CAPublicKey() *rsa.PublicKey {
	rsaKey, isRSA := it.CA.PublicKey.(*rsa.PublicKey)
	if isRSA {
		return rsaKey
	}
	Error("Can't read CA Certificate as public key")
	return rsaKey
}
// CreateCSR builds a certificate signing request whose subject common name
// is hostname, signs it with the Manager's private key and returns it as a
// PEM block of type "CERTIFICATE REQUEST". The error from
// x509.CreateCertificateRequest is returned unchanged.
func (it *Manager) CreateCSR(hostname string) ([]byte, error) {
	template := &x509.CertificateRequest{
		Subject: pkix.Name{CommonName: hostname},
	}
	der, err := x509.CreateCertificateRequest(rand.Reader, template, it.privateKey)
	if err != nil {
		return nil, err
	}
	encoded := pem.EncodeToMemory(&pem.Block{
		Type:  "CERTIFICATE REQUEST",
		Bytes: der,
	})
	return encoded, nil
}
// UpdateCertificate replaces the contents of the file at
// GetCertificateLocation() with s, creating it with mode 0644 when missing.
//
// The file is truncated before writing. Without truncation, a certificate
// shorter than the one already on disk would leave trailing bytes of the old
// certificate behind. Any write or close error is returned.
func (it *Manager) UpdateCertificate(s string) error {
	return ioutil.WriteFile(it.GetCertificateLocation(), []byte(s), 0644)
}
// download fetches url with an HTTP GET and writes the response body to
// path, replacing any previous contents. name is used only to label
// diagnostics.
//
// The response body is always closed. A non-200 status is reported through
// Error so an error page is not stored as the requested artifact. Other
// failures go through Check. Both helpers are defined elsewhere (presumably
// they abort the program — confirm).
func download(name string, path string, url string) {
	resp, err := http.Get(url)
	Check(err, "Failed fetching %s at %q, err: %s", name, url, err)
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		Error("Failed fetching %s at %q, endpoint responded with %d", name, url, resp.StatusCode)
	}

	cert, err := ioutil.ReadAll(resp.Body)
	Check(err, "Failed fetching %s at %q, err: %s", name, url, err)

	// O_TRUNC keeps stale bytes from surviving when the target already exists.
	f, err := os.OpenFile(path, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644)
	Check(err, "Failed open %s for writing at %q, err: %s", name, path, err)
	defer f.Close()

	_, err = f.Write(cert)
	Check(err, "Failed writing %s at %q, err: %s", name, path, err)
}
// DownloadCRL fetches Config.Zerooo.Endpoint + "/crl" and stores it at
// GetCRLLocation() via download.
func (it *Manager) DownloadCRL() {
	download("CRL", it.GetCRLLocation(), it.Config.Zerooo.Endpoint+"/crl")
}
// DownloadDH fetches Config.Zerooo.Endpoint + "/dh.pem" and stores it at
// GetDHLocation() via download.
func (it *Manager) DownloadDH() {
	download("DH", it.GetDHLocation(), it.Config.Zerooo.Endpoint+"/dh.pem")
}
// IsOpenVPNReady reports whether both the OpenVPN configuration file and
// the server certificate file exist.
func (it *Manager) IsOpenVPNReady() bool {
	// Both paths are always checked, even when the first one is missing.
	configPresent := fileExists(it.GetOpenVPNConfigLocation())
	certPresent := fileExists(it.GetCertificateLocation())
	if !configPresent {
		return false
	}
	return certPresent
}