devroxy

VHost Proxy Server for localhost
git clone git://git.lair.cx/devroxy
Log | Files | Refs | README

commit aa20968d1fc8374c36f08a2093d2fca1db830482
parent 09aa9db87601fbdcf6cdf27068de65e448ab05b3
Author: Yongbin Kim <iam@yongbin.kim>
Date:   Sat, 24 Sep 2022 00:54:58 +0900

feature: Added certificate generator

Now the server generates new tls certificate when client touches
with unknown domain. since client can send https request for
every domain, plain http mode is removed.

Signed-off-by: Yongbin Kim <iam@yongbin.kim>

Diffstat:
M.gitignore | 5-----
MMakefile | 2+-
Ainternal/certificates/ca.go | 125+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
Ainternal/certificates/utils.go | 199+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
Minternal/devroxy/api.go | 8--------
Minternal/devroxy/devroxy.go | 37++++---------------------------------
Ainternal/devroxy/paths.go | 19+++++++++++++++++++
Ainternal/devroxy/proxy.go | 57+++++++++++++++++++++++++++++++++++++++++++++++++++++++++
Mmain.go | 48++++++++++++++++++++++++------------------------
9 files changed, 429 insertions(+), 71 deletions(-)

diff --git a/.gitignore b/.gitignore @@ -1,8 +1,3 @@ -# TLS related files -*.pem -*.cert -*.key - # Created by https://www.toptal.com/developers/gitignore/api/macos,windows,linux,goland # Edit at https://www.toptal.com/developers/gitignore?templates=macos,windows,linux,goland diff --git a/Makefile b/Makefile @@ -10,4 +10,4 @@ build: go build -o ${BUILD_OUT} . run: - go run . -binds binds.yaml -cert cert.pem -key key.pem + go run . -binds binds.yaml diff --git a/internal/certificates/ca.go b/internal/certificates/ca.go @@ -0,0 +1,125 @@ +package certificates + +import ( + "crypto/rsa" + "crypto/tls" + "crypto/x509" + "fmt" + "os" + "path/filepath" + + "github.com/rs/zerolog/log" +) + +const ( + certFile = "cert.pem" + keyFile = "key.pem" +) + +// CA contains root certificate and key, and generates certificate. +type CA struct { + wd string + rootCert *x509.Certificate + rootKey *rsa.PrivateKey +} + +func NewCA(wd, cert, key string) (*CA, error) { + var ( + rootKey *rsa.PrivateKey + rootCert *x509.Certificate + err error + ) + + if _, err = os.Stat(key); err != nil { + if !os.IsNotExist(err) { + return nil, err + } + + rootKey, err = generateKey(key) + if err != nil { + return nil, fmt.Errorf("failed to generate private key: %w", err) + } + } else { + rootKey, err = readKey(key) + if err != nil { + return nil, fmt.Errorf("failed to load private key: %w", err) + } + } + + if _, err = os.Stat(cert); err != nil { + if !os.IsNotExist(err) { + return nil, err + } + + rootCert, err = generateRootCertificate(cert, rootKey) + if err != nil { + return nil, fmt.Errorf("failed to generate private key: %w", err) + } + } else { + rootCert, err = readCertificate(cert) + if err != nil { + return nil, fmt.Errorf("certificate load error: %w", err) + } + } + + if err := checkCertificate(rootCert, rootKey); err != nil { + return nil, fmt.Errorf("invalid certificate: %w", err) + } + + return &CA{ + wd: wd, + rootCert: rootCert, + rootKey: rootKey, + }, nil +} + +// GenerateCertificate generates new certificate, writes to `(wd)/certs/(host)/` +func (ca *CA) GenerateCertificate(name string) (tls.Certificate, error) { + err := os.MkdirAll(filepath.Join(ca.wd, name), 0755) + if err != nil { + return tls.Certificate{}, err + } + + key, err := generateKey(filepath.Join(ca.wd, name, keyFile)) + if err != nil { + return tls.Certificate{}, err + } + + cert, err := generateCertificate( + filepath.Join(ca.wd, name, certFile), + name, + key, + ca.rootCert, + ca.rootKey, + ) + if err != nil { + return tls.Certificate{}, err + } + + return tls.Certificate{ + Certificate: [][]byte{cert}, + PrivateKey: key, + }, nil +} + +// GetCertificate returns SSL certificate if exists, otherwise generate new one. +func (ca *CA) GetCertificate(info *tls.ClientHelloInfo) (*tls.Certificate, error) { + cert, err := tls.LoadX509KeyPair( + filepath.Join(ca.wd, info.ServerName, certFile), + filepath.Join(ca.wd, info.ServerName, keyFile), + ) + if os.IsNotExist(err) { + log.Info(). + Str("host", info.ServerName). + Msg("certificate not found; generating new one...") + cert, err = ca.GenerateCertificate(info.ServerName) + } + if err != nil { + log.Err(err). + Str("host", info.ServerName). + Msg("failed to get certificate") + return nil, err + } + + return &cert, nil +} diff --git a/internal/certificates/utils.go b/internal/certificates/utils.go @@ -0,0 +1,199 @@ +package certificates + +import ( + "bytes" + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "errors" + "fmt" + "math" + "math/big" + "os" + "time" + + "github.com/valyala/bytebufferpool" +) + +func readCertificate(filename string) (*x509.Certificate, error) { + data, err := os.ReadFile(filename) + if err != nil { + return nil, err + } + + block, _ := pem.Decode(data) + if block == nil || block.Type != "CERTIFICATE" { + return nil, errors.New("invalid certificate") + } + + cert, err := x509.ParseCertificate(block.Bytes) + if err != nil { + return nil, err + } + + return cert, nil +} + +func readKey(file string) (*rsa.PrivateKey, error) { + data, err := os.ReadFile(file) + if err != nil { + return nil, err + } + + block, _ := pem.Decode(data) + if block.Type != "RSA PRIVATE KEY" { + return nil, errors.New("only rsa private key is supported") + } + + return x509.ParsePKCS1PrivateKey(block.Bytes) +} + +func generateCertificateImpl( + filename string, + template *x509.Certificate, + key *rsa.PrivateKey, + parent *x509.Certificate, + rootKey *rsa.PrivateKey, +) ([]byte, error) { + if parent == nil { + parent = template + } + + data, err := x509.CreateCertificate( + rand.Reader, + template, + parent, + key.Public(), + rootKey, + ) + if err != nil { + return nil, err + } + + buf := bytebufferpool.Get() + defer bytebufferpool.Put(buf) + + err = pem.Encode(buf, &pem.Block{ + Type: "CERTIFICATE", + Bytes: data, + }) + if err != nil { + return nil, fmt.Errorf("failed to encode certificate: %w", err) + } + + err = os.WriteFile(filename, buf.B, 0600) + if err != nil { + return nil, err + } + + return data, nil +} + +func generateCertificate( + filename string, + domain string, + key *rsa.PrivateKey, + parent *x509.Certificate, + rootKey *rsa.PrivateKey, +) ([]byte, error) { + serial, err := rand.Int(rand.Reader, big.NewInt(math.MaxInt64)) + if err != nil { + return nil, err + } + + return generateCertificateImpl( + filename, + &x509.Certificate{ + DNSNames: []string{domain}, + Subject: pkix.Name{CommonName: domain}, + SerialNumber: serial, + NotBefore: time.Now(), + NotAfter: time.Now().AddDate(2, 0, 30), + KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth}, + BasicConstraintsValid: true, + }, + key, + parent, + rootKey, + ) +} + +func generateRootCertificate(filename string, key *rsa.PrivateKey) (*x509.Certificate, error) { + serial, err := rand.Int(rand.Reader, big.NewInt(math.MaxInt64)) + if err != nil { + return nil, err + } + + template := &x509.Certificate{ + Subject: pkix.Name{ + CommonName: "Devroxy Internal CA ", + }, + SerialNumber: serial, + NotBefore: time.Now(), + NotAfter: time.Now().AddDate(2, 0, 30), + KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth}, + BasicConstraintsValid: true, + IsCA: true, + } + + data, err := generateCertificateImpl( + filename, + template, + key, + nil, + key, + ) + if err != nil { + return nil, err + } + + return x509.ParseCertificate(data) +} + +func generateKey(filename string) (*rsa.PrivateKey, error) { + key, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + return nil, err + } + err = writeKey(filename, key) + if err != nil { + return nil, err + } + return key, nil +} + +func writeKey(file string, key *rsa.PrivateKey) error { + data := x509.MarshalPKCS1PrivateKey(key) + + buf := bytebufferpool.Get() + defer bytebufferpool.Put(buf) + + err := pem.Encode(buf, &pem.Block{ + Type: "RSA PRIVATE KEY", + Bytes: data, + }) + if err != nil { + return err + } + + return os.WriteFile(file, buf.Bytes(), 0600) +} + +func checkCertificate(cert *x509.Certificate, key *rsa.PrivateKey) error { + if cert.PublicKeyAlgorithm != x509.RSA { + return errors.New("only rsa certificate is supported") + } + + dataA := x509.MarshalPKCS1PublicKey(cert.PublicKey.(*rsa.PublicKey)) + dataB := x509.MarshalPKCS1PublicKey(&key.PublicKey) + + if !bytes.Equal(dataA, dataB) { + return errors.New("public key not match") + } + + return nil +} diff --git a/internal/devroxy/api.go b/internal/devroxy/api.go @@ -45,11 +45,3 @@ func (d *Devroxy) handleSaveBinds(w http.ResponseWriter, _ *http.Request) { func (d *Devroxy) handleNotFound(w http.ResponseWriter, _ *http.Request) { sendMessage(w, http.StatusNotFound, "not found") } - -func (d *Devroxy) handleProxyError(w http.ResponseWriter, r *http.Request, err error) { - log.Err(err). - Str("host", r.Host). - Str("path", r.URL.Path). - Msg("proxy error") - sendError(w, http.StatusBadGateway, err) -} diff --git a/internal/devroxy/devroxy.go b/internal/devroxy/devroxy.go @@ -12,8 +12,8 @@ import ( type Devroxy struct { mutex sync.RWMutex - proxyHandler *httputil.ReverseProxy - mux *http.ServeMux + proxy *httputil.ReverseProxy + mux *http.ServeMux binds map[string]int bindsDest string @@ -37,13 +37,11 @@ func New(addr string) *Devroxy { } func (d *Devroxy) setup() *Devroxy { - d.proxyHandler = &httputil.ReverseProxy{ + d.proxy = &httputil.ReverseProxy{ Director: d.director, ErrorHandler: d.handleProxyError, } - // select * from table where cond=1 order by asc desc; - d.mux = http.NewServeMux() d.mux.HandleFunc("/devroxy/binds", d.handleListBinds) d.mux.HandleFunc("/devroxy/binds/register", d.handleRegisterBind) @@ -51,7 +49,7 @@ func (d *Devroxy) setup() *Devroxy { d.mux.HandleFunc("/devroxy/binds/save", d.handleSaveBinds) // d.mux.HandleFunc("/devroxy/not-found", d.handleNotFound) d.mux.HandleFunc("/devroxy/", d.handleNotFound) - d.mux.Handle("/", d.proxyHandler) + d.mux.HandleFunc("/", d.handleProxy) return d } @@ -66,30 +64,3 @@ func (d *Devroxy) LoadBinds(bindsFile string) error { d.bindsDest = bindsFile return loadBindImpl(d.binds, bindsFile) } - -func (d *Devroxy) director(r *http.Request) { - d.mutex.RLock() - defer d.mutex.RUnlock() - - var dest string - if port, ok := d.binds[r.Host]; ok { - dest = fmt.Sprintf("%s:%d", d.internalIP, port) - } else { - dest = d.internalAddr - r.URL.Path = "/devroxy/not-found" - r.URL.RawPath = "/devroxy/not-found" - r.URL.RawQuery = "" - } - - log.Info(). - Str("host", r.Host). - Str("path", r.URL.Path). - Msg("proxy requested") - - r.URL.Scheme = "http" - r.URL.Host = dest - - if _, ok := r.Header["User-Agent"]; !ok { - r.Header.Set("User-Agent", "") - } -} diff --git a/internal/devroxy/paths.go b/internal/devroxy/paths.go @@ -0,0 +1,19 @@ +package devroxy + +import ( + "os" + "path" + "path/filepath" + + "github.com/rs/zerolog/log" +) + +const ConfDir = ".config/devroxy" + +func GetConfDir(p ...string) string { + homeDir, err := os.UserHomeDir() + if err != nil { + log.Fatal().Err(err).Msg("Failed to get home directory") + } + return filepath.Join(homeDir, ConfDir, path.Join(p...)) +} diff --git a/internal/devroxy/proxy.go b/internal/devroxy/proxy.go @@ -0,0 +1,57 @@ +package devroxy + +import ( + "fmt" + "net/http" + + "github.com/rs/zerolog/log" +) + +func (d *Devroxy) handleProxy(w http.ResponseWriter, r *http.Request) { + d.mutex.RLock() + _, ok := d.binds[r.Host] + d.mutex.RUnlock() + + if !ok { + d.handleNotFound(w, r) + return + } + + d.proxy.ServeHTTP(w, r) +} + +func (d *Devroxy) director(r *http.Request) { + var dest string + + d.mutex.RLock() + if port, ok := d.binds[r.Host]; ok { + dest = fmt.Sprintf("%s:%d", d.internalIP, port) + } else { + dest = d.internalAddr + r.URL.Scheme = "https" + r.URL.Path = "/devroxy/not-found" + r.URL.RawPath = "/devroxy/not-found" + r.URL.RawQuery = "" + } + d.mutex.RUnlock() + + log.Info(). + Str("host", r.Host). + Str("path", r.URL.Path). + Msg("proxy requested") + + r.URL.Scheme = "http" + r.URL.Host = dest + + if _, ok := r.Header["User-Agent"]; !ok { + r.Header.Set("User-Agent", "") + } +} + +func (d *Devroxy) handleProxyError(w http.ResponseWriter, r *http.Request, err error) { + log.Err(err). + Str("host", r.Host). + Str("path", r.URL.Path). + Msg("proxy error") + sendError(w, http.StatusBadGateway, err) +} diff --git a/main.go b/main.go @@ -2,13 +2,16 @@ package main import ( "context" + "crypto/tls" "flag" "net/http" "os" "os/signal" + "path" "syscall" "time" + "devroxy/internal/certificates" "devroxy/internal/devroxy" "github.com/rs/zerolog" @@ -16,10 +19,10 @@ import ( ) var ( - flagAddr = flag.String("addr", "", "Address that server listening for. if omitted, server uses default port (:80 or :443)") - flagBinds = flag.String("binds", "", "Bind file location. if omitted, devroxy runs in-memory mode.") - flagCert = flag.String("cert", "", "SSL certification. if omitted, server will listen for http request.") - flagKey = flag.String("key", "", "SSL key. if omitted, server will listen for http request.") + flagAddr = flag.String("addr", ":443", "Address that server listening for.") + flagBinds = flag.String("binds", "", "Bind file location. if omitted, devroxy runs in-memory mode.") + flagRootCert = flag.String("ca-cert", "root.pem", "Root CA certificate. if omitted, server will listen for http request.") + flagRootKey = flag.String("ca-key", "root-key.pem", "Root CA private key. if omitted, server will listen for http request.") ) const ( @@ -31,17 +34,9 @@ func main() { flag.Parse() zerolog.TimeFieldFormat = zerolog.TimeFormatUnix - useTLS := len(*flagCert) > 0 && len(*flagKey) > 0 - if !useTLS && (len(*flagCert) > 0 && len(*flagKey) > 0) { - - } - if len(*flagAddr) == 0 { - if useTLS { - *flagAddr = ":80" - } else { - *flagAddr = ":443" - } - } + log.Debug(). + Str("addr", *flagAddr). + Msg("address") d := devroxy.New(*flagAddr) @@ -52,27 +47,32 @@ func main() { } } + caRoot := devroxy.GetConfDir("certs") + ca, err := certificates.NewCA( + caRoot, + path.Join(caRoot, *flagRootCert), + path.Join(caRoot, *flagRootKey), + ) + if err != nil { + log.Fatal(). + Err(err). + Msg("failed to initialize ca") + } + server := &http.Server{ Addr: *flagAddr, Handler: d.Handler(), ReadTimeout: ReadTimeout, WriteTimeout: WriteTimeout, + TLSConfig: &tls.Config{GetCertificate: ca.GetCertificate}, } errChan := make(chan error, 1) go func() { log.Info(). Str("addr", *flagAddr). - Bool("tls", useTLS). Msg("server started") - - var err error - if useTLS { - err = server.ListenAndServe() - } else { - err = server.ListenAndServeTLS(*flagCert, *flagKey) - } - + err := server.ListenAndServeTLS("", "") if err != http.ErrServerClosed { errChan <- err }