logging and cert handling

This commit is contained in:
Seán C McCord 2023-10-01 20:45:59 -04:00
parent ccb9e7eb51
commit 8f346786f5
Signed by: scm
GPG key ID: 4AF67648FB0336A6
3 changed files with 61 additions and 24 deletions

View file

@ -24,12 +24,15 @@ import (
var (
listenPort int
healthPort int
enableTLS bool
)
var allowedDomains []string
func init() {
flag.IntVar(&listenPort, "smtp", 2525, "port on which to listen for incoming emails")
flag.IntVar(&listenPort, "smtps", 2526, "port on which to listen for incoming secure emails")
flag.IntVar(&healthPort, "health", 8080, "port on which to listen for health checks")
}
@ -37,6 +40,8 @@ func main() {
var err error
var log *zap.Logger
flag.Parse()
ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
defer cancel()
@ -60,7 +65,12 @@ func main() {
log.Fatal("no allowed domains")
}
cm, err := localtls.NewCertManager(os.Getenv("CERTIFICATE_NAME"), os.Getenv("CERTIFICATE_NAMESPACE"))
var cm *localtls.CertManager
if tlsCertFile, ok := os.LookupEnv("TLS_CERT_FILE"); ok {
enableTLS = true
cm, err = localtls.NewCertManager(tlsCertFile, os.Getenv("TLS_KEY_FILE"))
if err != nil {
log.Fatal("failed to create certificate manager", zap.Error(err))
}
@ -74,12 +84,12 @@ func main() {
cancel()
}
}()
}
be := &server.Backend{
AllowedDomains: allowedDomains,
Log: log,
LMTPAddress: os.Getenv("LMTP_ADDRESS"),
RootContext: ctx,
RSpam: milter.NewDefaultClient("tcp", "rspamd:11334"),
OverallLimiter: rate.NewLimiter(rate.Limit(10), 20),
}
@ -91,24 +101,37 @@ func main() {
s.Addr = ":2525"
s.Domain = "cycore.io"
s.Debug = os.Stderr
s.WriteTimeout = 10 * time.Second
s.WriteTimeout = 10 * time.Second
s.MaxMessageBytes = 300 * 1024 * 1024 // 300 MiB
s.MaxRecipients = 10
s.TLSConfig = &tls.Config{
GetCertificate: cm.Get,
}
s.AuthDisabled = true
log.Info("starting health service")
go runHealthService(cm)
log.Fatal("server exited", zap.Error(s.ListenAndServeTLS()))
if enableTLS {
s.TLSConfig = &tls.Config{
GetCertificate: cm.Get,
}
go func() {
defer cancel()
if err := s.ListenAndServeTLS(); err != nil {
log.Error("TLS listener failed", zap.Error(err))
}
}()
}
log.Fatal("server exited", zap.Error(s.ListenAndServe()))
}
func runHealthService(cm *localtls.CertManager) {
http.HandleFunc("/health", func(w http.ResponseWriter, req *http.Request) {
if !cm.Ready() {
if cm != nil && !cm.Ready() {
http.Error(w, "no TLS cert", http.StatusInternalServerError)
return

View file

@ -41,10 +41,16 @@ func NewCertManager(certFile, keyFile string) (*CertManager, error) {
}
func (cm *CertManager) Get(_ *tls.ClientHelloInfo) (*tls.Certificate, error) {
cm.Log.Info("received request for certificate")
if cm.cert == nil {
cm.Log.Error("no certificate available")
return nil, errors.New("no certificate")
}
cm.Log.Info("returning certificate to requester")
return cm.cert, nil
}
@ -66,12 +72,16 @@ func (cm *CertManager) Ready() bool {
func (cm *CertManager) Watch(ctx context.Context, log *zap.Logger) error {
cm.Log = log
log.Info("loading certificate")
if err := cm.Load(); err != nil {
log.Error("failed to load initial certificate", zap.Error(err))
return err
}
log.Info("certificate loaded")
for {
select {
case <-time.After(time.Hour):
@ -79,10 +89,14 @@ func (cm *CertManager) Watch(ctx context.Context, log *zap.Logger) error {
return ctx.Err()
}
log.Info("reloading certificate")
if err := cm.Load(); err != nil {
log.Error("failed to reload new certificate", zap.Error(err))
continue
}
log.Info("certificate reloaded")
}
}

View file

@ -2,7 +2,6 @@ package server
import (
"bytes"
"context"
"crypto/rand"
"errors"
"fmt"
@ -47,8 +46,6 @@ type Backend struct {
Log *zap.Logger
RootContext context.Context
RSpam *milter.Client
OverallLimiter *rate.Limiter
@ -56,8 +53,7 @@ type Backend struct {
// NewSession implements smtp.Backend
func (b *Backend) NewSession(conn *smtp.Conn) (smtp.Session, error) {
ctx, cancel := context.WithTimeout(b.RootContext, ConnectionTimeout)
defer cancel()
b.Log.Info("new connection")
traceID, err := ulid.New(ulid.Now(), rand.Reader)
if err != nil {
@ -68,14 +64,18 @@ func (b *Backend) NewSession(conn *smtp.Conn) (smtp.Session, error) {
log := b.Log.With(zap.String("traceid", traceID.String()), zap.String("remote", conn.Conn().RemoteAddr().String()))
log.Debug("received new connection", zap.String("remote", conn.Conn().RemoteAddr().String()))
log.Info("received new connection", zap.String("remote", conn.Conn().RemoteAddr().String()))
if !b.OverallLimiter.Allow() {
log.Info("rate limited connection")
if err := b.OverallLimiter.Wait(ctx); err != nil {
return nil, ErrRateLimit
}
rs, err := b.RSpam.Session()
if err != nil {
log.Error("failed to create new rspam session", zap.Error(err))
return nil, ErrBadGateway
}