package server import ( "bytes" "crypto/rand" "errors" "fmt" "io" "net" "strings" "sync" "time" "github.com/emersion/go-message/mail" "github.com/emersion/go-milter" "github.com/emersion/go-smtp" "github.com/oklog/ulid" "github.com/stretchr/readcaster" "go.uber.org/zap" "golang.org/x/time/rate" ) const ConnectionTimeout = 3 * time.Minute const defaultLMTPTimeout = 30 * time.Second // ErrRateLimit indicates that the message is rejected due to a rate limit. var ErrRateLimit = &smtp.SMTPError{ Code: 450, EnhancedCode: [3]int{4, 2, 0}, Message: "too many requests", } // ErrBadGateway indicates a downstream system failure. var ErrBadGateway = &smtp.SMTPError{ Code: 451, EnhancedCode: [3]int{4, 4, 3}, Message: "bad gateway", } // Backend provides a backend implementation of an SMTP server. type Backend struct { AllowedDomains []string LMTPAddress string Log *zap.Logger RSpam *milter.Client OverallLimiter *rate.Limiter } // NewSession implements smtp.Backend func (b *Backend) NewSession(conn *smtp.Conn) (smtp.Session, error) { b.Log.Info("new connection") traceID, err := ulid.New(ulid.Now(), rand.Reader) if err != nil { b.Log.Error("failed to create traceID for new connection", zap.Error(err)) return nil, errors.New("internal error") } log := b.Log.With(zap.String("traceid", traceID.String()), zap.String("remote", conn.Conn().RemoteAddr().String())) log.Info("received new connection", zap.String("remote", conn.Conn().RemoteAddr().String())) if !b.OverallLimiter.Allow() { log.Info("rate limited connection") return nil, ErrRateLimit } rs, err := b.RSpam.Session() if err != nil { log.Error("failed to create new rspam session", zap.Error(err)) return nil, ErrBadGateway } return &Session{ b: b, conn: conn, log: log, rspam: rs, }, nil } // Session provides an implementation of an SMTP session. type Session struct { b *Backend conn *smtp.Conn rspam *milter.ClientSession log *zap.Logger to []string from string } // Discard currently processed message. 
//
// Reset clears the envelope sender and recipients recorded by Mail and Rcpt,
// so they cannot carry over into the next message on the same connection.
func (session *Session) Reset() {
	session.from = ""
	session.to = nil
}

// Logout frees all resources associated with the session by closing its
// rspamd milter session. A close failure is logged, not returned.
func (session *Session) Logout() error {
	if session.rspam == nil {
		return nil
	}
	if err := session.rspam.Close(); err != nil {
		session.log.Error("failed to close rspam session", zap.Error(err))
	}
	// Make a repeated Logout a no-op instead of closing the session twice.
	session.rspam = nil
	return nil
}

// AuthPlain rejects SASL PLAIN authentication, which this server does not
// support.
func (session *Session) AuthPlain(username string, password string) error {
	return smtp.ErrAuthUnsupported
}

// Mail sets the return path for the currently processed message once rspamd
// has accepted the sender. opts is ignored.
func (session *Session) Mail(from string, opts *smtp.MailOptions) error {
	// NOTE(review): session.log gains another "from" field for every message
	// on a long-lived connection; consider per-message loggers derived from a
	// base logger.
	session.log = session.log.With(zap.String("from", from))
	session.log.Debug("received MAIL FROM")

	action, err := session.rspam.Mail(from, nil)
	if err != nil {
		session.log.Error("failed to process from with rspam", zap.Error(err))
		// Report a temporary failure without exposing internal error text to
		// the client, as NewSession does for rspamd failures.
		return ErrBadGateway
	}
	if err := parseAction(action); err != nil {
		session.log.Info("rspam rejected from", zap.Error(err))
		return err
	}

	session.from = from
	session.log.Debug("accepted MAIL FROM")
	return nil
}

// Rcpt adds a recipient for the currently processed message. The recipient
// must belong to one of Backend.AllowedDomains (otherwise 550 5.1.2) and be
// accepted by rspamd.
func (session *Session) Rcpt(to string) error {
	// Only the first recipient is attached to the session logger.
	if len(session.to) == 0 {
		session.log = session.log.With(zap.String("to", to))
	}
	session.log.Sugar().Debugf("received RCPT TO: %q", to)

	if err := session.validateDomain(to); err != nil {
		session.log.Info("rejected RCPT TO", zap.Error(err))
		return &smtp.SMTPError{
			Code:         550,
			EnhancedCode: [3]int{5, 1, 2},
			Message:      err.Error(),
		}
	}

	action, err := session.rspam.Rcpt(to, nil)
	if err != nil {
		session.log.Error("failed to process RCPT TO with rspam", zap.Error(err))
		// Temporary failure; internal error details stay in the log.
		return ErrBadGateway
	}
	if err := parseAction(action); err != nil {
		session.log.Info("rspam rejected RCPT TO", zap.Error(err))
		return err
	}

	session.to = append(session.to, to)
	session.log.Debug("accepted RCPT TO")
	return nil
}

// Data sets the currently processed message contents and sends it.
//
// r must be consumed before Data returns.
func (session *Session) Data(r io.Reader) error { session.log.Debug("receiving DATA") caster := readcaster.New(r) wg := new(sync.WaitGroup) msg := new(bytes.Buffer) wg.Add(1) go func() { defer wg.Done() n, err := msg.ReadFrom(caster.NewReader()) if err != nil { session.log.Error("failed to read message data into message", zap.Error(err)) return } if n == 0 { session.log.Warn("received empty message") } }() _, action, err := session.rspam.BodyReadFrom(caster.NewReader()) if err != nil { session.log.Error("rspamd failed to process message body", zap.Error(err)) return fmt.Errorf("filter failure: %w", err) } if err := parseAction(action); err != nil { session.log.Info("rspamd rejected mail data", zap.Error(err)) return err } session.log.Debug("message accepted for delivery") // ensure the message has been received to msg wg.Wait() // Forward to LMTP destination conn, err := net.DialTimeout("tcp", session.b.LMTPAddress, defaultLMTPTimeout) if err != nil { session.log.Error("failed to dial LMTP server", zap.String("lmtp", session.b.LMTPAddress), zap.Error(err), ) return errors.New("failed to forward message") } host, _, _ := net.SplitHostPort(session.b.LMTPAddress) lc, err := smtp.NewClientLMTP(conn, host) if err != nil { session.log.Error("failed to construct LMTP client", zap.Error(err)) return errors.New("failed to forward message") } if err := lc.SendMail(session.from, session.to, msg); err != nil { session.log.Error("failed to forward message to LMTP server", zap.Error(err)) return err } session.log.Debug("message forwarded to LMTP server") return nil } func (session *Session) validateDomain(to string) error { toAddr, err := mail.ParseAddress(to) if err != nil { return fmt.Errorf("failed to parse address: %w", err) } pieces := strings.Split(toAddr.Address, "@") if len(pieces) != 2 { return fmt.Errorf("invalid address") } for _, d := range session.b.AllowedDomains { if pieces[2] == d { return nil } } return fmt.Errorf("invalid domain") } func parseAction(action 
*milter.Action) error { switch action.Code { case milter.ActAccept: return nil case milter.ActContinue: return nil case milter.ActDiscard: return milterReject(action) case milter.ActReject: return milterReject(action) case milter.ActTempFail: return milterReject(action) case milter.ActReplyCode: return milterReject(action) case milter.ActSkip: return milterReject(action) } return fmt.Errorf("unhandled action code %q", action.Code) } func milterReject(action *milter.Action) error { return &smtp.SMTPError{ Code: action.SMTPCode, EnhancedCode: [3]int{4, 2, 0}, Message: action.SMTPText, } }