package crypto import ( "crypto/rand" "crypto/rsa" "crypto/x509" "encoding/pem" "fmt" "os" "path/filepath" ) const ( DefaultKeyBits = 2048 DefaultKeyDir = "./keys" ) type KeyPair struct { PublicKey *rsa.PublicKey PrivateKey *rsa.PrivateKey } func GenerateKeyPair(bits int) (*KeyPair, error) { priv, err := rsa.GenerateKey(rand.Reader, bits) if err != nil { return nil, fmt.Errorf("generating RSA key: %w", err) } return &KeyPair{ PublicKey: &priv.PublicKey, PrivateKey: priv, }, nil } func (k *KeyPair) PublicKeyPEM() ([]byte, error) { pubBytes, err := x509.MarshalPKIXPublicKey(k.PublicKey) if err != nil { return nil, fmt.Errorf("marshaling public key: %w", err) } block := &pem.Block{Type: "PUBLIC KEY", Bytes: pubBytes} return pem.EncodeToMemory(block), nil } func (k *KeyPair) PrivateKeyPEM() ([]byte, error) { block := &pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(k.PrivateKey)} return pem.EncodeToMemory(block), nil } func LoadKeyPair(pubPath, privPath string) (*KeyPair, error) { pubData, err := os.ReadFile(pubPath) if err != nil { return nil, fmt.Errorf("reading public key: %w", err) } privData, err := os.ReadFile(privPath) if err != nil { return nil, fmt.Errorf("reading private key: %w", err) } block, _ := pem.Decode(pubData) if block == nil { return nil, fmt.Errorf("invalid public key PEM") } pub, err := x509.ParsePKIXPublicKey(block.Bytes) if err != nil { return nil, fmt.Errorf("parsing public key: %w", err) } rsaPub, ok := pub.(*rsa.PublicKey) if !ok { return nil, fmt.Errorf("not an RSA public key") } block, _ = pem.Decode(privData) if block == nil { return nil, fmt.Errorf("invalid private key PEM") } priv, err := x509.ParsePKCS1PrivateKey(block.Bytes) if err != nil { return nil, fmt.Errorf("parsing private key: %w", err) } return &KeyPair{ PublicKey: rsaPub, PrivateKey: priv, }, nil } func LoadOrCreateKeyPair(keyDir string) (*KeyPair, error) { if keyDir == "" { keyDir = DefaultKeyDir } pubPath := filepath.Join(keyDir, "public.pem") privPath := filepath.Join(keyDir, "private.pem") if _, err := os.Stat(pubPath); err == nil { if _, err := os.Stat(privPath); err == nil { return LoadKeyPair(pubPath, privPath) } } if err := os.MkdirAll(keyDir, 0700); err != nil { return nil, fmt.Errorf("creating key directory: %w", err) } keyPair, err := GenerateKeyPair(DefaultKeyBits) if err != nil { return nil, err } pubPEM, err := keyPair.PublicKeyPEM() if err != nil { return nil, err } if err := os.WriteFile(pubPath, pubPEM, 0644); err != nil { return nil, fmt.Errorf("saving public key: %w", err) } privPEM, err := keyPair.PrivateKeyPEM() if err != nil { return nil, err } if err := os.WriteFile(privPath, privPEM, 0600); err != nil { return nil, fmt.Errorf("saving private key: %w", err) } return keyPair, nil } func Encrypt(plain []byte, pub *rsa.PublicKey) ([]byte, error) { return rsa.EncryptPKCS1v15(rand.Reader, pub, plain) } func Decrypt(cipher []byte, priv *rsa.PrivateKey) ([]byte, error) { return rsa.DecryptPKCS1v15(rand.Reader, priv, cipher) }