logging and cert handling

This commit is contained in:
Seán C McCord 2023-10-01 20:45:59 -04:00
parent ccb9e7eb51
commit 8f346786f5
Signed by: scm
GPG key ID: 4AF67648FB0336A6
3 changed files with 61 additions and 24 deletions

View file

@ -24,12 +24,15 @@ import (
var ( var (
listenPort int listenPort int
healthPort int healthPort int
enableTLS bool
) )
var allowedDomains []string var allowedDomains []string
func init() { func init() {
flag.IntVar(&listenPort, "smtp", 2525, "port on which to listen for incoming emails") flag.IntVar(&listenPort, "smtp", 2525, "port on which to listen for incoming emails")
flag.IntVar(&listenPort, "smtps", 2526, "port on which to listen for incoming secure emails")
flag.IntVar(&healthPort, "health", 8080, "port on which to listen for health checks") flag.IntVar(&healthPort, "health", 8080, "port on which to listen for health checks")
} }
@ -37,6 +40,8 @@ func main() {
var err error var err error
var log *zap.Logger var log *zap.Logger
flag.Parse()
ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM) ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
defer cancel() defer cancel()
@ -60,7 +65,12 @@ func main() {
log.Fatal("no allowed domains") log.Fatal("no allowed domains")
} }
cm, err := localtls.NewCertManager(os.Getenv("CERTIFICATE_NAME"), os.Getenv("CERTIFICATE_NAMESPACE")) var cm *localtls.CertManager
if tlsCertFile, ok := os.LookupEnv("TLS_CERT_FILE"); ok {
enableTLS = true
cm, err = localtls.NewCertManager(tlsCertFile, os.Getenv("TLS_KEY_FILE"))
if err != nil { if err != nil {
log.Fatal("failed to create certificate manager", zap.Error(err)) log.Fatal("failed to create certificate manager", zap.Error(err))
} }
@ -74,12 +84,12 @@ func main() {
cancel() cancel()
} }
}() }()
}
be := &server.Backend{ be := &server.Backend{
AllowedDomains: allowedDomains, AllowedDomains: allowedDomains,
Log: log, Log: log,
LMTPAddress: os.Getenv("LMTP_ADDRESS"), LMTPAddress: os.Getenv("LMTP_ADDRESS"),
RootContext: ctx,
RSpam: milter.NewDefaultClient("tcp", "rspamd:11334"), RSpam: milter.NewDefaultClient("tcp", "rspamd:11334"),
OverallLimiter: rate.NewLimiter(rate.Limit(10), 20), OverallLimiter: rate.NewLimiter(rate.Limit(10), 20),
} }
@ -91,24 +101,37 @@ func main() {
s.Addr = ":2525" s.Addr = ":2525"
s.Domain = "cycore.io" s.Domain = "cycore.io"
s.Debug = os.Stderr
s.WriteTimeout = 10 * time.Second s.WriteTimeout = 10 * time.Second
s.WriteTimeout = 10 * time.Second s.WriteTimeout = 10 * time.Second
s.MaxMessageBytes = 300 * 1024 * 1024 // 300 MiB s.MaxMessageBytes = 300 * 1024 * 1024 // 300 MiB
s.MaxRecipients = 10 s.MaxRecipients = 10
s.TLSConfig = &tls.Config{ s.AuthDisabled = true
GetCertificate: cm.Get,
}
log.Info("starting health service") log.Info("starting health service")
go runHealthService(cm) go runHealthService(cm)
log.Fatal("server exited", zap.Error(s.ListenAndServeTLS())) if enableTLS {
s.TLSConfig = &tls.Config{
GetCertificate: cm.Get,
}
go func() {
defer cancel()
if err := s.ListenAndServeTLS(); err != nil {
log.Error("TLS listener failed", zap.Error(err))
}
}()
}
log.Fatal("server exited", zap.Error(s.ListenAndServe()))
} }
func runHealthService(cm *localtls.CertManager) { func runHealthService(cm *localtls.CertManager) {
http.HandleFunc("/health", func(w http.ResponseWriter, req *http.Request) { http.HandleFunc("/health", func(w http.ResponseWriter, req *http.Request) {
if !cm.Ready() { if cm != nil && !cm.Ready() {
http.Error(w, "no TLS cert", http.StatusInternalServerError) http.Error(w, "no TLS cert", http.StatusInternalServerError)
return return

View file

@ -41,10 +41,16 @@ func NewCertManager(certFile, keyFile string) (*CertManager, error) {
} }
func (cm *CertManager) Get(_ *tls.ClientHelloInfo) (*tls.Certificate, error) { func (cm *CertManager) Get(_ *tls.ClientHelloInfo) (*tls.Certificate, error) {
cm.Log.Info("received request for certificate")
if cm.cert == nil { if cm.cert == nil {
cm.Log.Error("no certificate available")
return nil, errors.New("no certificate") return nil, errors.New("no certificate")
} }
cm.Log.Info("returning certificate to requester")
return cm.cert, nil return cm.cert, nil
} }
@ -66,12 +72,16 @@ func (cm *CertManager) Ready() bool {
func (cm *CertManager) Watch(ctx context.Context, log *zap.Logger) error { func (cm *CertManager) Watch(ctx context.Context, log *zap.Logger) error {
cm.Log = log cm.Log = log
log.Info("loading certificate")
if err := cm.Load(); err != nil { if err := cm.Load(); err != nil {
log.Error("failed to load initial certificate", zap.Error(err)) log.Error("failed to load initial certificate", zap.Error(err))
return err return err
} }
log.Info("certificate loaded")
for { for {
select { select {
case <-time.After(time.Hour): case <-time.After(time.Hour):
@ -79,10 +89,14 @@ func (cm *CertManager) Watch(ctx context.Context, log *zap.Logger) error {
return ctx.Err() return ctx.Err()
} }
log.Info("reloading certificate")
if err := cm.Load(); err != nil { if err := cm.Load(); err != nil {
log.Error("failed to reload new certificate", zap.Error(err)) log.Error("failed to reload new certificate", zap.Error(err))
continue continue
} }
log.Info("certificate reloaded")
} }
} }

View file

@ -2,7 +2,6 @@ package server
import ( import (
"bytes" "bytes"
"context"
"crypto/rand" "crypto/rand"
"errors" "errors"
"fmt" "fmt"
@ -47,8 +46,6 @@ type Backend struct {
Log *zap.Logger Log *zap.Logger
RootContext context.Context
RSpam *milter.Client RSpam *milter.Client
OverallLimiter *rate.Limiter OverallLimiter *rate.Limiter
@ -56,8 +53,7 @@ type Backend struct {
// NewSession implements smtp.Backend // NewSession implements smtp.Backend
func (b *Backend) NewSession(conn *smtp.Conn) (smtp.Session, error) { func (b *Backend) NewSession(conn *smtp.Conn) (smtp.Session, error) {
ctx, cancel := context.WithTimeout(b.RootContext, ConnectionTimeout) b.Log.Info("new connection")
defer cancel()
traceID, err := ulid.New(ulid.Now(), rand.Reader) traceID, err := ulid.New(ulid.Now(), rand.Reader)
if err != nil { if err != nil {
@ -68,14 +64,18 @@ func (b *Backend) NewSession(conn *smtp.Conn) (smtp.Session, error) {
log := b.Log.With(zap.String("traceid", traceID.String()), zap.String("remote", conn.Conn().RemoteAddr().String())) log := b.Log.With(zap.String("traceid", traceID.String()), zap.String("remote", conn.Conn().RemoteAddr().String()))
log.Debug("received new connection", zap.String("remote", conn.Conn().RemoteAddr().String())) log.Info("received new connection", zap.String("remote", conn.Conn().RemoteAddr().String()))
if !b.OverallLimiter.Allow() {
log.Info("rate limited connection")
if err := b.OverallLimiter.Wait(ctx); err != nil {
return nil, ErrRateLimit return nil, ErrRateLimit
} }
rs, err := b.RSpam.Session() rs, err := b.RSpam.Session()
if err != nil { if err != nil {
log.Error("failed to create new rspam session", zap.Error(err))
return nil, ErrBadGateway return nil, ErrBadGateway
} }