mail/server/server.go

307 lines
6.7 KiB
Go

package server
import (
"bytes"
"crypto/rand"
"errors"
"fmt"
"io"
"net"
"strings"
"sync"
"time"
"github.com/emersion/go-message/mail"
"github.com/emersion/go-milter"
"github.com/emersion/go-smtp"
"github.com/oklog/ulid"
"github.com/stretchr/readcaster"
"go.uber.org/zap"
"golang.org/x/time/rate"
)
// Timeouts used by the SMTP front end and its LMTP forwarding.
const (
	// ConnectionTimeout is a per-connection time limit.
	// NOTE(review): not referenced in this file section; presumably applied
	// to the smtp.Server read/write timeouts — confirm at the call site.
	ConnectionTimeout = 3 * time.Minute

	// defaultLMTPTimeout bounds how long dialling the LMTP server may take.
	defaultLMTPTimeout = 30 * time.Second
)
// SMTP replies shared by the backend and its sessions.
var (
	// ErrRateLimit indicates that the message is rejected due to a rate limit.
	// It is a transient (4xx) reply.
	ErrRateLimit = &smtp.SMTPError{
		Code:         450,
		EnhancedCode: [3]int{4, 2, 0},
		Message:      "too many requests",
	}

	// ErrBadGateway indicates a downstream system failure. It is a transient
	// (4xx) reply.
	ErrBadGateway = &smtp.SMTPError{
		Code:         451,
		EnhancedCode: [3]int{4, 4, 3},
		Message:      "bad gateway",
	}
)
// Backend implements smtp.Backend. It holds the configuration and shared
// resources from which every connection's Session is created.
type Backend struct {
	// AllowedDomains lists the recipient domains accepted in RCPT TO.
	AllowedDomains []string
	// LMTPAddress is the TCP host:port that accepted messages are forwarded to.
	LMTPAddress string
	// Log is the root logger; each session derives a tagged logger from it.
	Log *zap.Logger
	// RSpam opens one rspamd milter session per incoming connection.
	RSpam *milter.Client
	// OverallLimiter is consulted once per new connection, shared by all clients.
	OverallLimiter *rate.Limiter
}
// NewSession implements smtp.Backend. It tags the connection with a unique
// trace ID, applies the backend-wide rate limit, and opens a dedicated rspamd
// milter session for the connection.
//
// It returns ErrRateLimit when OverallLimiter denies the connection,
// ErrBadGateway when the milter session cannot be opened, and a generic
// internal error when no trace ID can be generated.
func (b *Backend) NewSession(conn *smtp.Conn) (smtp.Session, error) {
	remote := conn.Conn().RemoteAddr().String()

	traceID, err := ulid.New(ulid.Now(), rand.Reader)
	if err != nil {
		b.Log.Error("failed to create traceID for new connection",
			zap.String("remote", remote),
			zap.Error(err),
		)
		return nil, errors.New("internal error")
	}

	// "remote" is attached once here and inherited by every later entry; it
	// must not be repeated on individual calls or zap emits the key twice.
	log := b.Log.With(zap.String("traceid", traceID.String()), zap.String("remote", remote))
	log.Info("received new connection")

	if !b.OverallLimiter.Allow() {
		log.Info("rate limited connection")
		return nil, ErrRateLimit
	}

	rs, err := b.RSpam.Session()
	if err != nil {
		log.Error("failed to create new rspam session", zap.Error(err))
		return nil, ErrBadGateway
	}

	return &Session{
		b:     b,
		conn:  conn,
		log:   log,
		rspam: rs,
	}, nil
}
// Session is the per-connection SMTP state created by Backend.NewSession.
type Session struct {
	b     *Backend              // backend the session was created from
	conn  *smtp.Conn            // underlying SMTP connection
	rspam *milter.ClientSession // milter session for this connection, closed in Logout
	log   *zap.Logger           // logger tagged with the trace ID and remote address

	// Envelope of the message currently being processed.
	from string   // MAIL FROM address
	to   []string // accepted RCPT TO addresses, in order
}
// Reset discards the envelope of the message currently being processed.
//
// The SMTP server invokes Reset between transactions (RSET and the end of a
// message), so the sender and recipient list must be cleared here. Otherwise
// recipients of an earlier message on the same connection would be carried
// over and receive the next message as well.
//
// NOTE(review): the rspamd milter session and the log fields added by Mail
// and Rcpt are not reset here — confirm whether the milter needs an explicit
// abort on RSET.
func (session *Session) Reset() {
	session.from = ""
	session.to = nil
}
// Logout releases the resources held by the session. A failure to close the
// rspamd milter session is logged but never reported to the caller.
func (session *Session) Logout() error {
	rs := session.rspam
	if rs == nil {
		return nil
	}
	if closeErr := rs.Close(); closeErr != nil {
		session.log.Error("failed to close rspam session", zap.Error(closeErr))
	}
	return nil
}
// AuthPlain rejects every SASL PLAIN attempt with smtp.ErrAuthUnsupported;
// this server does not authenticate clients, so the credentials are ignored.
func (session *Session) AuthPlain(username, password string) error {
	return smtp.ErrAuthUnsupported
}
// Mail records the envelope sender (MAIL FROM) once the rspamd milter has
// accepted it. opts is not used.
//
// A milter transport failure is reported as ErrBadGateway, matching
// NewSession, so internal error details are not echoed to the remote client.
// A milter verdict is returned as produced by parseAction.
func (session *Session) Mail(from string, opts *smtp.MailOptions) error {
	session.log = session.log.With(zap.String("from", from))
	session.log.Debug("received MAIL FROM")

	action, err := session.rspam.Mail(from, nil)
	if err != nil {
		session.log.Error("failed to process from with rspam", zap.Error(err))
		return ErrBadGateway
	}
	if err := parseAction(action); err != nil {
		session.log.Info("rspam rejected from", zap.Error(err))
		return err
	}

	session.from = from
	session.log.Debug("accepted MAIL FROM")
	return nil
}
// Rcpt adds a recipient (RCPT TO) to the message currently being processed.
//
// The recipient must belong to one of the backend's AllowedDomains; otherwise
// a 550 5.1.2 reply carrying the validation error is returned. It must then be
// accepted by the rspamd milter. A milter transport failure is reported as
// ErrBadGateway, matching NewSession, so internal details are not echoed to
// the client.
func (session *Session) Rcpt(to string) error {
	// Only the first recipient of a message is added to the log context.
	if len(session.to) == 0 {
		session.log = session.log.With(zap.String("to", to))
	}
	session.log.Sugar().Debugf("received RCPT TO: %q", to)

	if err := session.validateDomain(to); err != nil {
		session.log.Info("rejected RCPT TO", zap.String("rcpt", to), zap.Error(err))
		return &smtp.SMTPError{
			Code:         550,
			EnhancedCode: [3]int{5, 1, 2},
			Message:      err.Error(),
		}
	}

	action, err := session.rspam.Rcpt(to, nil)
	if err != nil {
		session.log.Error("failed to process RCPT TO with rspam", zap.Error(err))
		return ErrBadGateway
	}
	if err := parseAction(action); err != nil {
		session.log.Info("rspam rejected RCPT TO", zap.Error(err))
		return err
	}

	session.to = append(session.to, to)
	session.log.Debug("accepted RCPT TO")
	return nil
}
// Data receives the message body, has it checked by the rspamd milter and,
// if accepted, forwards it to the LMTP server at Backend.LMTPAddress.
//
// The body is read from r exactly once and fanned out through a readcaster.
// One copy streams to the milter; the other is buffered in memory for
// delivery. r must be consumed before Data returns, so every path first
// drains the milter's copy and waits for the buffering goroutine.
//
// Errors returned to the SMTP layer:
//   - an *smtp.SMTPError found in the body read error, passed through
//     unchanged, or a generic error when the body cannot be read;
//   - ErrBadGateway when the milter or the LMTP server cannot be used;
//   - the milter's rejection, as built by parseAction;
//   - the LMTP server's error when it refuses the message.
func (session *Session) Data(r io.Reader) error {
	session.log.Debug("receiving DATA")

	caster := readcaster.New(r)
	// Create both readers before anything reads from r, so neither consumer
	// can start after part of the stream has already been handed out.
	filterBody := caster.NewReader()
	bufferBody := caster.NewReader()

	msg := new(bytes.Buffer)
	var readErr error
	wg := new(sync.WaitGroup)
	wg.Add(1)
	go func() {
		defer wg.Done()
		n, err := msg.ReadFrom(bufferBody)
		if err != nil {
			readErr = err
			return
		}
		if n == 0 {
			session.log.Warn("received empty message")
		}
	}()

	_, action, filterErr := session.rspam.BodyReadFrom(filterBody)
	// Drain whatever the milter left unread (e.g. after an error) so the
	// buffering goroutine can reach EOF. The drain's own error is ignored
	// because a failing source is reported through readErr instead.
	_, _ = io.Copy(io.Discard, filterBody)
	wg.Wait() // readErr and msg are safe to use from here on

	if readErr != nil {
		session.log.Error("failed to read message data into message", zap.Error(readErr))
		var smtpErr *smtp.SMTPError
		if errors.As(readErr, &smtpErr) {
			// Already a client-facing reply; pass it through unchanged.
			return smtpErr
		}
		return errors.New("failed to read message data")
	}
	if filterErr != nil {
		session.log.Error("rspamd failed to process message body", zap.Error(filterErr))
		return ErrBadGateway
	}
	if err := parseAction(action); err != nil {
		session.log.Info("rspamd rejected mail data", zap.Error(err))
		return err
	}
	session.log.Debug("message accepted for delivery")

	return session.forwardLMTP(msg)
}

// forwardLMTP delivers msg over LMTP to Backend.LMTPAddress using the
// session's envelope (from, to). The connection is released on every path.
func (session *Session) forwardLMTP(msg io.Reader) error {
	addr := session.b.LMTPAddress
	conn, err := net.DialTimeout("tcp", addr, defaultLMTPTimeout)
	if err != nil {
		session.log.Error("failed to dial LMTP server",
			zap.String("lmtp", addr),
			zap.Error(err),
		)
		return ErrBadGateway
	}

	// addr was just dialled over TCP, so it is in host:port form and the
	// split error is deliberately ignored.
	host, _, _ := net.SplitHostPort(addr)
	lc, err := smtp.NewClientLMTP(conn, host)
	if err != nil {
		_ = conn.Close() // best effort; the construction error is what matters
		session.log.Error("failed to construct LMTP client", zap.Error(err))
		return ErrBadGateway
	}
	// Close on every path. Its error is ignored because the delivery outcome
	// is already decided, and SendMail may have ended the session itself.
	defer func() { _ = lc.Close() }()

	if err := lc.SendMail(session.from, session.to, msg); err != nil {
		session.log.Error("failed to forward message to LMTP server", zap.Error(err))
		return err
	}
	session.log.Debug("message forwarded to LMTP server")
	return nil
}
// validateDomain checks that the recipient address to belongs to one of the
// backend's AllowedDomains and returns nil if it does.
//
// The domain is the text after the last '@', because a quoted local part may
// itself contain '@'. It is compared case-insensitively, since domain names
// are case-insensitive. The returned error text is used as the message of the
// 550 reply built by Rcpt.
func (session *Session) validateDomain(to string) error {
	toAddr, err := mail.ParseAddress(to)
	if err != nil {
		return fmt.Errorf("failed to parse address: %w", err)
	}

	addr := toAddr.Address
	at := strings.LastIndexByte(addr, '@')
	// Require a non-empty local part and a non-empty domain.
	if at <= 0 || at == len(addr)-1 {
		return errors.New("invalid address")
	}
	domain := addr[at+1:]

	for _, allowed := range session.b.AllowedDomains {
		if strings.EqualFold(domain, allowed) {
			return nil
		}
	}
	return errors.New("invalid domain")
}
// parseAction maps a milter verdict to the result of the current SMTP
// command. nil lets the transaction continue; any other value is returned to
// the client.
//
// ActAccept and ActContinue pass. Every other known code is turned into an
// SMTP rejection by milterReject. A nil action or an unknown code yields a
// plain error.
func parseAction(action *milter.Action) error {
	if action == nil {
		return errors.New("milter returned no action")
	}
	switch action.Code {
	case milter.ActAccept, milter.ActContinue:
		return nil
	case milter.ActDiscard, milter.ActReject, milter.ActTempFail,
		milter.ActReplyCode, milter.ActSkip:
		// NOTE(review): in the milter protocol ActDiscard means "accept, then
		// drop silently" and ActSkip means "stop sending body chunks". Both
		// are treated as rejections here to stay conservative — confirm.
		return milterReject(action)
	}
	return fmt.Errorf("unhandled action code %q", action.Code)
}

// milterReject builds the SMTP reply for a rejecting milter verdict.
//
// action.SMTPCode is used when it is a 4xx or 5xx code. It is presumably
// populated only for ActReplyCode; confirm against go-milter. Otherwise the
// default is 550 for ActReject and ActDiscard, and 451 for every other action.
// An empty action.SMTPText is replaced by a generic message. The class of the
// enhanced status code always matches the class of the reply code
// (RFC 3463).
func milterReject(action *milter.Action) error {
	code := action.SMTPCode
	if code < 400 || code > 599 {
		switch action.Code {
		case milter.ActReject, milter.ActDiscard:
			code = 550
		default:
			code = 451
		}
	}

	text := action.SMTPText
	if text == "" {
		if code >= 500 {
			text = "rejected by filter"
		} else {
			text = "temporarily rejected by filter"
		}
	}

	return &smtp.SMTPError{
		Code:         code,
		EnhancedCode: [3]int{code / 100, 2, 0},
		Message:      text,
	}
}