VerifactuMidAPI/internal/crypto/crypto.go

140 lines
3.2 KiB
Go
Raw Permalink Normal View History

package crypto
import (
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"encoding/pem"
"fmt"
"os"
"path/filepath"
)
// Defaults applied by LoadOrCreateKeyPair when it has to locate or provision
// a key pair on its own.
const (
	// DefaultKeyBits is the RSA modulus size, in bits, used for newly
	// generated key pairs.
	DefaultKeyBits = 2048
	// DefaultKeyDir is the directory searched for public.pem and private.pem
	// when the caller passes an empty directory. It is a relative path, so it
	// resolves against the process working directory.
	DefaultKeyDir = "./keys"
)
// KeyPair bundles an RSA private key with an RSA public key.
//
// GenerateKeyPair sets PublicKey to point at PrivateKey.PublicKey, while
// LoadKeyPair fills the two fields from two separate PEM files.
type KeyPair struct {
	// PublicKey is the key serialized by PublicKeyPEM.
	PublicKey  *rsa.PublicKey
	// PrivateKey is the key serialized by PrivateKeyPEM and used by the
	// Decrypt method.
	PrivateKey *rsa.PrivateKey
}
// GenerateKeyPair creates a new RSA key with a modulus of the requested size,
// using crypto/rand as the entropy source. The returned pair's PublicKey
// points into its PrivateKey, so the two halves always correspond. Errors
// from rsa.GenerateKey are wrapped with context.
func GenerateKeyPair(bits int) (*KeyPair, error) {
	key, err := rsa.GenerateKey(rand.Reader, bits)
	if err != nil {
		return nil, fmt.Errorf("generating RSA key: %w", err)
	}
	kp := new(KeyPair)
	kp.PrivateKey = key
	kp.PublicKey = &key.PublicKey
	return kp, nil
}
// PublicKeyPEM returns the public key as a PEM block of type "PUBLIC KEY"
// whose payload is the DER-encoded PKIX (SubjectPublicKeyInfo) form.
//
// A nil KeyPair or a nil PublicKey is reported as an error. Without this
// check, x509.MarshalPKIXPublicKey would dereference the nil key and panic.
func (k *KeyPair) PublicKeyPEM() ([]byte, error) {
	if k == nil || k.PublicKey == nil {
		return nil, fmt.Errorf("marshaling public key: missing RSA public key")
	}
	der, err := x509.MarshalPKIXPublicKey(k.PublicKey)
	if err != nil {
		return nil, fmt.Errorf("marshaling public key: %w", err)
	}
	return pem.EncodeToMemory(&pem.Block{Type: "PUBLIC KEY", Bytes: der}), nil
}
// PrivateKeyPEM returns the private key as an unencrypted PEM block of type
// "RSA PRIVATE KEY" whose payload is the PKCS #1 DER encoding. The result is
// raw key material and should only be written to restrictively-permissioned
// storage.
//
// A nil KeyPair or a nil PrivateKey is reported as an error. Without this
// check, x509.MarshalPKCS1PrivateKey would dereference the nil key and panic.
func (k *KeyPair) PrivateKeyPEM() ([]byte, error) {
	if k == nil || k.PrivateKey == nil {
		return nil, fmt.Errorf("marshaling private key: missing RSA private key")
	}
	der := x509.MarshalPKCS1PrivateKey(k.PrivateKey)
	return pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: der}), nil
}
// LoadKeyPair reads an RSA key pair from two PEM files and verifies that the
// two keys actually belong together.
//
// Accepted formats:
//   - The public key file may hold a PKIX "PUBLIC KEY" block, the format
//     PublicKeyPEM writes, or a PKCS #1 "RSA PUBLIC KEY" block.
//   - The private key file may hold a PKCS #1 "RSA PRIVATE KEY" block, the
//     format PrivateKeyPEM writes, or an unencrypted PKCS #8 "PRIVATE KEY"
//     block, e.g. as produced by `openssl genpkey`.
//
// Only the first PEM block of each file is used. An error is returned if
// either file cannot be read or parsed, if either key is not RSA, or if the
// public key does not match the private key.
func LoadKeyPair(pubPath, privPath string) (*KeyPair, error) {
	pubData, err := os.ReadFile(pubPath)
	if err != nil {
		return nil, fmt.Errorf("reading public key: %w", err)
	}
	privData, err := os.ReadFile(privPath)
	if err != nil {
		return nil, fmt.Errorf("reading private key: %w", err)
	}
	pub, err := parseRSAPublicKeyPEM(pubData)
	if err != nil {
		return nil, err
	}
	priv, err := parseRSAPrivateKeyPEM(privData)
	if err != nil {
		return nil, err
	}
	// A mismatched pair would otherwise load cleanly. Ciphertexts produced
	// with PublicKey could then never be decrypted with PrivateKey, so fail
	// early instead.
	if !priv.PublicKey.Equal(pub) {
		return nil, fmt.Errorf("public key %q does not match private key %q", pubPath, privPath)
	}
	return &KeyPair{
		PublicKey:  pub,
		PrivateKey: priv,
	}, nil
}

// parseRSAPublicKeyPEM decodes the first PEM block in data as an RSA public
// key. It tries PKIX first and falls back to PKCS #1.
func parseRSAPublicKeyPEM(data []byte) (*rsa.PublicKey, error) {
	block, _ := pem.Decode(data)
	if block == nil {
		return nil, fmt.Errorf("invalid public key PEM")
	}
	key, err := x509.ParsePKIXPublicKey(block.Bytes)
	if err != nil {
		// Also accept a bare PKCS #1 encoding. If that fails too, report the
		// PKIX error, since PKIX is the primary format.
		if rsaPub, pkcs1Err := x509.ParsePKCS1PublicKey(block.Bytes); pkcs1Err == nil {
			return rsaPub, nil
		}
		return nil, fmt.Errorf("parsing public key: %w", err)
	}
	rsaPub, ok := key.(*rsa.PublicKey)
	if !ok {
		return nil, fmt.Errorf("not an RSA public key")
	}
	return rsaPub, nil
}

// parseRSAPrivateKeyPEM decodes the first PEM block in data as an RSA private
// key. It tries PKCS #1 first and falls back to unencrypted PKCS #8.
func parseRSAPrivateKeyPEM(data []byte) (*rsa.PrivateKey, error) {
	block, _ := pem.Decode(data)
	if block == nil {
		return nil, fmt.Errorf("invalid private key PEM")
	}
	priv, err := x509.ParsePKCS1PrivateKey(block.Bytes)
	if err == nil {
		return priv, nil
	}
	key, pkcs8Err := x509.ParsePKCS8PrivateKey(block.Bytes)
	if pkcs8Err != nil {
		// Report the PKCS #1 error, since PKCS #1 is the format this package writes.
		return nil, fmt.Errorf("parsing private key: %w", err)
	}
	rsaPriv, ok := key.(*rsa.PrivateKey)
	if !ok {
		return nil, fmt.Errorf("not an RSA private key")
	}
	return rsaPriv, nil
}
// LoadOrCreateKeyPair returns the RSA key pair stored as public.pem and
// private.pem in keyDir, or in DefaultKeyDir when keyDir is empty. When no
// private key exists yet, it generates a new DefaultKeyBits-bit pair and
// writes it to disk.
//
// Behavior by which files exist:
//   - Both files exist: the pair is loaded with LoadKeyPair.
//   - Only private.pem exists: an error is returned. Regenerating would
//     overwrite, and so destroy, the only copy of the private key.
//   - Only public.pem exists: the file is useless without its private half,
//     so it is replaced by the newly generated key.
//
// Errors from os.Stat other than "does not exist" (e.g. permission denied)
// are returned instead of being mistaken for a missing key.
//
// Newly created files get mode 0644 (public) and 0600 (private), subject to
// the umask. The directory is created with mode 0700 if it does not exist.
func LoadOrCreateKeyPair(keyDir string) (*KeyPair, error) {
	if keyDir == "" {
		keyDir = DefaultKeyDir
	}
	pubPath := filepath.Join(keyDir, "public.pem")
	privPath := filepath.Join(keyDir, "private.pem")

	pubExists, err := keyFileExists(pubPath)
	if err != nil {
		return nil, fmt.Errorf("checking public key: %w", err)
	}
	privExists, err := keyFileExists(privPath)
	if err != nil {
		return nil, fmt.Errorf("checking private key: %w", err)
	}
	switch {
	case pubExists && privExists:
		return LoadKeyPair(pubPath, privPath)
	case privExists:
		return nil, fmt.Errorf("private key %q exists but public key %q is missing; refusing to overwrite the private key", privPath, pubPath)
	}

	if err := os.MkdirAll(keyDir, 0700); err != nil {
		return nil, fmt.Errorf("creating key directory: %w", err)
	}
	keyPair, err := GenerateKeyPair(DefaultKeyBits)
	if err != nil {
		return nil, err
	}
	// Encode both halves before touching the disk, so an encoding failure
	// leaves no partial state behind.
	pubPEM, err := keyPair.PublicKeyPEM()
	if err != nil {
		return nil, err
	}
	privPEM, err := keyPair.PrivateKeyPEM()
	if err != nil {
		return nil, err
	}
	// Write the public key first. If the private-key write then fails, the
	// leftover public.pem is simply regenerated on the next call. A leftover
	// private.pem without its public half would instead make this function
	// return an error.
	if err := os.WriteFile(pubPath, pubPEM, 0644); err != nil {
		return nil, fmt.Errorf("saving public key: %w", err)
	}
	if err := os.WriteFile(privPath, privPEM, 0600); err != nil {
		return nil, fmt.Errorf("saving private key: %w", err)
	}
	return keyPair, nil
}

// keyFileExists reports whether path exists. A "does not exist" error from
// os.Stat maps to (false, nil); any other Stat error is returned to the caller.
func keyFileExists(path string) (bool, error) {
	_, err := os.Stat(path)
	switch {
	case err == nil:
		return true, nil
	case os.IsNotExist(err):
		return false, nil
	default:
		return false, err
	}
}
// Encrypt encrypts plain for the holder of pub using RSA with PKCS #1 v1.5
// padding (RSAES-PKCS1-v1_5). Padding randomness comes from crypto/rand.
//
// plain may be at most pub.Size()-11 bytes; rsa.EncryptPKCS1v15 returns an
// error for longer input. A nil pub is reported as an error instead of
// causing a nil-pointer panic.
//
// NOTE(review): PKCS #1 v1.5 encryption is a legacy scheme. Its decryption
// side is exposed to padding-oracle (Bleichenbacher) attacks, and
// rsa.EncryptOAEP is the recommended replacement. Switching changes the
// ciphertext format and must be coordinated with every party that decrypts,
// so it is deliberately not done here.
func Encrypt(plain []byte, pub *rsa.PublicKey) ([]byte, error) {
	if pub == nil {
		return nil, fmt.Errorf("encrypting: nil RSA public key")
	}
	return rsa.EncryptPKCS1v15(rand.Reader, pub, plain)
}
// Decrypt decrypts an RSAES-PKCS1-v1_5 ciphertext, such as one produced by
// Encrypt, with priv. A nil priv is reported as an error instead of causing
// a nil-pointer panic.
//
// NOTE(review): callers decrypting untrusted input should not reveal to the
// sender whether decryption failed, e.g. through distinct error responses
// or timing. Doing so turns PKCS #1 v1.5 into a padding oracle.
// rsa.DecryptOAEP is the recommended replacement once both peers can switch.
func Decrypt(cipher []byte, priv *rsa.PrivateKey) ([]byte, error) {
	if priv == nil {
		return nil, fmt.Errorf("decrypting: nil RSA private key")
	}
	return rsa.DecryptPKCS1v15(rand.Reader, priv, cipher)
}
// Decrypt decrypts a PKCS #1 v1.5 ciphertext with the pair's private key. It
// delegates to the package-level Decrypt, so both entry points share a single
// code path. A nil KeyPair or a nil PrivateKey is reported as an error
// instead of causing a nil-pointer panic.
func (k *KeyPair) Decrypt(cipher []byte) ([]byte, error) {
	if k == nil || k.PrivateKey == nil {
		return nil, fmt.Errorf("decrypting: nil RSA private key")
	}
	return Decrypt(cipher, k.PrivateKey)
}