You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 
 
 

247 lines
5.8 KiB

// Package tls provides a TLS/ACME transport for the relay.
package tls
import (
	"context"
	"crypto/tls"
	"crypto/x509"
	"fmt"
	"net"
	"net/http"
	"os"
	"path/filepath"
	"strings"
	"sync"

	"golang.org/x/crypto/acme/autocert"
	"lol.mleku.dev/chk"
	"lol.mleku.dev/log"
)
// Config describes where the TLS transport gets its certificates from and
// which handler it serves once a connection is established.
type Config struct {
	// Domains lists the host names for which ACME (autocert) certificates
	// may be obtained.
	Domains []string

	// Certs lists manually supplied certificates as path prefixes without an
	// extension; "<prefix>.pem" and "<prefix>.key" are loaded for each entry.
	Certs []string

	// DataDir is the base data directory; the autocert cache is created in
	// an "autocert" subdirectory beneath it.
	DataDir string

	// Handler is the http.Handler served on the HTTPS listener.
	Handler http.Handler
}
// Transport serves HTTPS using automatically issued (ACME) or manually
// supplied TLS certificates. It runs two servers: HTTPS on :443 and plain
// HTTP on :80, the latter answering ACME challenges.
type Transport struct {
	cfg *Config

	// Created by Start and shut down by Stop.
	tlsServer  *http.Server
	httpServer *http.Server

	// mu serializes Start and Stop.
	mu sync.Mutex
}
// New returns a Transport bound to cfg. The config pointer is kept as-is,
// not copied, and nothing is started until Start is called.
func New(cfg *Config) *Transport {
	t := new(Transport)
	t.cfg = cfg
	return t
}
// Name reports the identifier of this transport, which is always "tls".
func (t *Transport) Name() string {
	const transportName = "tls"
	return transportName
}
// Start validates the configuration, prepares the autocert cache, binds the
// HTTPS (:443) and HTTP (:80) listeners, and then serves both in background
// goroutines.
//
// Both ports are bound before Start returns. An address-in-use or permission
// error is therefore returned to the caller, rather than only being logged
// from a goroutine after Start has already reported success. If :80 cannot
// be bound, the :443 listener is closed again so a retry can succeed. Errors
// that occur later, while serving, are logged. The ctx parameter is
// currently unused; call Stop to shut the servers down.
func (t *Transport) Start(ctx context.Context) error {
	t.mu.Lock()
	defer t.mu.Unlock()
	if err := ValidateConfig(t.cfg.Domains, t.cfg.Certs); err != nil {
		return fmt.Errorf("invalid TLS configuration: %w", err)
	}
	// Create cache directory for autocert (owner-only permissions).
	cacheDir := filepath.Join(t.cfg.DataDir, "autocert")
	if err := os.MkdirAll(cacheDir, 0700); err != nil {
		return fmt.Errorf("failed to create autocert cache directory: %w", err)
	}
	// Set up the autocert manager; it only issues certificates for the
	// configured domains.
	m := &autocert.Manager{
		Prompt:     autocert.AcceptTOS,
		Cache:      autocert.DirCache(cacheDir),
		HostPolicy: autocert.HostWhitelist(t.cfg.Domains...),
	}
	// TLS server on port 443.
	tlsServer := &http.Server{
		Addr:      ":443",
		Handler:   t.cfg.Handler,
		TLSConfig: tlsConfig(m, t.cfg.Certs...),
	}
	// HTTP server for ACME challenges and redirects on port 80.
	httpServer := &http.Server{
		Addr:    ":80",
		Handler: m.HTTPHandler(nil),
	}
	// Bind synchronously so bind failures reach the caller. This is the same
	// net.Listen + Serve sequence ListenAndServe(TLS) performs internally.
	tlsLn, err := net.Listen("tcp", tlsServer.Addr)
	if err != nil {
		return fmt.Errorf("failed to listen on %s: %w", tlsServer.Addr, err)
	}
	httpLn, err := net.Listen("tcp", httpServer.Addr)
	if err != nil {
		// Release :443 so a later Start can bind it again; the bind error
		// for :80 is the one worth reporting.
		_ = tlsLn.Close()
		return fmt.Errorf("failed to listen on %s: %w", httpServer.Addr, err)
	}
	t.tlsServer, t.httpServer = tlsServer, httpServer
	log.I.F("TLS enabled for domains: %v", t.cfg.Domains)
	// Serve HTTPS; ServeTLS closes tlsLn when it returns.
	go func() {
		log.I.F("starting TLS listener on https://:443")
		if err := tlsServer.ServeTLS(tlsLn, "", ""); err != nil && err != http.ErrServerClosed {
			log.E.F("TLS server error: %v", err)
		}
	}()
	// Serve HTTP for ACME challenges; Serve closes httpLn when it returns.
	go func() {
		log.I.F("starting HTTP listener on http://:80 for ACME challenges")
		if err := httpServer.Serve(httpLn); err != nil && err != http.ErrServerClosed {
			log.E.F("HTTP server error: %v", err)
		}
	}()
	return nil
}
// Stop gracefully shuts down the HTTPS server and then the HTTP server, each
// bounded by ctx. Servers that were never created are skipped. Both shutdowns
// are always attempted; every failure is logged, and the first one
// encountered is returned (nil if both succeed).
func (t *Transport) Stop(ctx context.Context) error {
	t.mu.Lock()
	defer t.mu.Unlock()

	var result error

	if srv := t.tlsServer; srv != nil {
		err := srv.Shutdown(ctx)
		switch {
		case err != nil:
			log.E.F("TLS server shutdown error: %v", err)
			result = err
		default:
			log.I.F("TLS server shutdown completed")
		}
	}

	if srv := t.httpServer; srv != nil {
		err := srv.Shutdown(ctx)
		switch {
		case err != nil:
			log.E.F("HTTP server shutdown error: %v", err)
			// Keep the TLS error if there was one; it happened first.
			if result == nil {
				result = err
			}
		default:
			log.I.F("HTTP server shutdown completed")
		}
	}

	return result
}
// Addresses returns the public WebSocket URLs ("wss://<domain>/") for the
// configured domains, in configuration order. Empty domain entries, which
// ValidateConfig skips, are skipped here too; otherwise they would produce the
// malformed URL "wss:///". The result is nil when no usable domain is
// configured.
func (t *Transport) Addresses() []string {
	var addrs []string
	for _, domain := range t.cfg.Domains {
		if domain == "" {
			continue
		}
		addrs = append(addrs, "wss://"+domain+"/")
	}
	return addrs
}
// ValidateConfig checks if the TLS configuration is valid.
// ValidateConfig checks if the TLS configuration is valid.
//
// At least one non-empty domain is required. Empty entries are ignored. A
// domain containing a space or any ASCII control character (tab, CR, LF,
// NUL, DEL, ...) is rejected. The certs argument is currently not inspected.
func ValidateConfig(domains []string, certs []string) error {
	// Reject spaces, C0 controls and DEL. None of these can appear in a host
	// name, and they would corrupt logs or URLs built from the domain.
	invalid := func(r rune) bool { return r <= ' ' || r == 0x7f }
	usable := 0
	for _, domain := range domains {
		if domain == "" {
			continue
		}
		if strings.IndexFunc(domain, invalid) >= 0 {
			return fmt.Errorf("invalid domain name: %q", domain)
		}
		usable++
	}
	// A list that is empty or holds only empty strings names no host at all.
	if usable == 0 {
		return fmt.Errorf("no TLS domains specified")
	}
	return nil
}
// tlsConfig returns a TLS configuration that works with LetsEncrypt automatic
// SSL cert issuer as well as any provided certificate files.
//
// Certs are provided as path prefixes: "<path>.pem" holds the certificate and
// "<path>.key" the private key. Each loaded certificate is indexed under its
// subject CommonName and every DNS SAN. Entries that fail to load or parse
// are logged and skipped.
//
// During the handshake, a manually loaded certificate that matches the
// requested server name is used first. A match is either exact
// (case-insensitive) or via a "*." wildcard covering exactly one leftmost
// label. If nothing matches and m is nil, the handshake fails; otherwise the
// request falls through to m.GetCertificate.
func tlsConfig(m *autocert.Manager, certs ...string) *tls.Config {
	// certMap is fully built here and never modified afterward, so the
	// GetCertificate closures can read it concurrently without locking.
	certMap := loadManualCerts(certs)
	if m == nil {
		return &tls.Config{
			GetCertificate: func(hello *tls.ClientHelloInfo) (*tls.Certificate, error) {
				if cert := lookupManualCert(certMap, hello.ServerName); cert != nil {
					return cert, nil
				}
				return nil, fmt.Errorf("no certificate found for %s", hello.ServerName)
			},
		}
	}
	tc := m.TLSConfig()
	tc.GetCertificate = func(hello *tls.ClientHelloInfo) (*tls.Certificate, error) {
		if cert := lookupManualCert(certMap, hello.ServerName); cert != nil {
			return cert, nil
		}
		return m.GetCertificate(hello)
	}
	return tc
}

// loadManualCerts loads each "<path>.pem"/"<path>.key" pair in certs and
// indexes the certificate by its lower-cased CommonName and DNS SANs.
func loadManualCerts(certs []string) map[string]*tls.Certificate {
	certMap := make(map[string]*tls.Certificate)
	for _, certPath := range certs {
		if certPath == "" {
			continue
		}
		c, err := tls.LoadX509KeyPair(certPath+".pem", certPath+".key")
		if chk.E(err) {
			log.E.F("failed to load certificate from %s: %v", certPath, err)
			continue
		}
		if len(c.Certificate) == 0 {
			continue
		}
		leaf, err := x509.ParseCertificate(c.Certificate[0])
		if err != nil {
			// A pair that loads but cannot be parsed has no names to index
			// and could never be selected, so report it.
			log.E.F("failed to parse certificate from %s: %v", certPath, err)
			continue
		}
		cert := &c // c is scoped to this iteration, so each entry gets its own pointer
		if cn := leaf.Subject.CommonName; cn != "" {
			certMap[strings.ToLower(cn)] = cert
			log.I.F("loaded certificate for domain: %s", cn)
		}
		for _, san := range leaf.DNSNames {
			if san != "" {
				certMap[strings.ToLower(san)] = cert
				log.I.F("loaded certificate for SAN domain: %s", san)
			}
		}
	}
	return certMap
}

// lookupManualCert returns the certificate registered for serverName, or nil.
// It tries an exact (case-insensitive) match first, then a wildcard entry.
// "*.example.com" matches "a.example.com" but neither "example.com",
// "a.b.example.com", nor "badexample.com".
func lookupManualCert(certMap map[string]*tls.Certificate, serverName string) *tls.Certificate {
	if len(certMap) == 0 || serverName == "" {
		return nil
	}
	name := strings.ToLower(serverName)
	if cert, ok := certMap[name]; ok {
		return cert
	}
	// Replace the leftmost (non-empty) label with "*" and look that up: a
	// single O(1) probe with a deterministic result.
	if i := strings.IndexByte(name, '.'); i > 0 && i < len(name)-1 {
		if cert, ok := certMap["*"+name[i:]]; ok {
			return cert
		}
	}
	return nil
}