crypto-sm/sm4.go

88 lines
2.4 KiB
Go

package sm
import (
"crypto/cipher"
"errors"
"apigo.cc/go/crypto"
"apigo.cc/go/safe"
"github.com/emmansun/gmsm/sm4"
)
// SM4Cipher performs SM4 encryption and decryption. Instances are passed to
// crypto.NewSymmetric and its variants by the NewSM4* constructors in this
// file. The zero value (useGCM == false) selects CBC mode.
type SM4Cipher struct {
// useGCM selects authenticated GCM mode when true; when false, CBC mode with
// crypto.Pkcs5Padding / crypto.Pkcs5UnPadding is used.
useGCM bool
}
var (
	// SM4CBC is the shared SM4 cipher configured for CBC mode with
	// PKCS#5-style padding (unauthenticated).
	SM4CBC = &SM4Cipher{useGCM: false}

	// SM4GCM is the shared SM4 cipher configured for authenticated GCM mode.
	SM4GCM = &SM4Cipher{useGCM: true}
)
// --- Factory functions: build crypto.Symmetric instances backed by SM4-CBC or SM4-GCM ---
// NewSM4CBC returns a crypto.Symmetric that uses SM4 in CBC mode, with the
// key and IV taken from the given safe buffers.
func NewSM4CBC(safeKeyBuf, safeIvBuf *safe.SafeBuf) (*crypto.Symmetric, error) {
	sym, err := crypto.NewSymmetric(SM4CBC, safeKeyBuf, safeIvBuf)
	return sym, err
}
// NewSM4CBCAndEraseKey returns a crypto.Symmetric that uses SM4 in CBC mode.
// It delegates to crypto.NewSymmetricAndEraseKey, which (going by its name)
// wipes the caller's key/iv slices after copying them — confirm in that package.
func NewSM4CBCAndEraseKey(key, iv []byte) (*crypto.Symmetric, error) {
	sym, err := crypto.NewSymmetricAndEraseKey(SM4CBC, key, iv)
	return sym, err
}
// NewSM4CBCWithOutEraseKey returns a crypto.Symmetric that uses SM4 in CBC
// mode. It delegates to crypto.NewSymmetricWithOutEraseKey, which (going by
// its name) leaves the caller's key/iv slices untouched — confirm in that package.
func NewSM4CBCWithOutEraseKey(key, iv []byte) (*crypto.Symmetric, error) {
	sym, err := crypto.NewSymmetricWithOutEraseKey(SM4CBC, key, iv)
	return sym, err
}
// NewSM4GCM returns a crypto.Symmetric that uses SM4 in GCM mode, with the
// key and nonce (IV) taken from the given safe buffers.
func NewSM4GCM(safeKeyBuf, safeIvBuf *safe.SafeBuf) (*crypto.Symmetric, error) {
	sym, err := crypto.NewSymmetric(SM4GCM, safeKeyBuf, safeIvBuf)
	return sym, err
}
// NewSM4GCMAndEraseKey returns a crypto.Symmetric that uses SM4 in GCM mode.
// It delegates to crypto.NewSymmetricAndEraseKey, which (going by its name)
// wipes the caller's key/iv slices after copying them — confirm in that package.
func NewSM4GCMAndEraseKey(key, iv []byte) (*crypto.Symmetric, error) {
	sym, err := crypto.NewSymmetricAndEraseKey(SM4GCM, key, iv)
	return sym, err
}
// NewSM4GCMWithOutEraseKey returns a crypto.Symmetric that uses SM4 in GCM
// mode. It delegates to crypto.NewSymmetricWithOutEraseKey, which (going by
// its name) leaves the caller's key/iv slices untouched — confirm in that package.
func NewSM4GCMWithOutEraseKey(key, iv []byte) (*crypto.Symmetric, error) {
	sym, err := crypto.NewSymmetricWithOutEraseKey(SM4GCM, key, iv)
	return sym, err
}
// Encrypt encrypts data with SM4 under key.
//
// GCM mode: the first NonceSize() bytes of iv (12 for cipher.NewGCM) are used
// as the nonce, and the result is the ciphertext with the authentication tag
// appended (as produced by cipher.AEAD.Seal). No additional data is used.
//
// CBC mode: data is padded with crypto.Pkcs5Padding to the SM4 block size
// (always 16 bytes) and encrypted with the first 16 bytes of iv.
//
// Extra trailing iv bytes are ignored in both modes. An error is returned if
// key is not a valid SM4 key, or if iv is shorter than the mode requires.
// Previously, a short iv caused a slice-out-of-range panic.
func (s *SM4Cipher) Encrypt(data []byte, key []byte, iv []byte) ([]byte, error) {
	block, err := sm4.NewCipher(key)
	if err != nil {
		return nil, err
	}

	if s.useGCM {
		aead, err := cipher.NewGCM(block)
		if err != nil {
			return nil, err
		}
		// SM4-GCM uses the standard 12-byte nonce via cipher.NewGCM.
		nonceSize := aead.NonceSize()
		if len(iv) < nonceSize {
			return nil, errors.New("sm4: iv is shorter than the GCM nonce size")
		}
		return aead.Seal(nil, iv[:nonceSize], data, nil), nil
	}

	// The SM4 block size is fixed at 16 bytes.
	blockSize := block.BlockSize()
	if len(iv) < blockSize {
		return nil, errors.New("sm4: iv is shorter than the CBC block size")
	}
	padded := crypto.Pkcs5Padding(data, blockSize)
	out := make([]byte, len(padded))
	cipher.NewCBCEncrypter(block, iv[:blockSize]).CryptBlocks(out, padded)
	return out, nil
}
// Decrypt reverses Encrypt: it decrypts data with SM4 under key.
//
// GCM mode: the first NonceSize() bytes of iv are the nonce. data must be
// ciphertext with the tag appended. Authentication failures are reported by
// cipher.AEAD.Open as an error.
//
// CBC mode: the first 16 bytes of iv are the IV. data must be a non-empty
// multiple of the 16-byte block size. After decryption the trailing PKCS#5/7
// padding is checked and then removed with crypto.Pkcs5UnPadding.
// NOTE(review): CBC is unauthenticated. A distinguishable padding error
// (like any padding behaviour) can act as a padding oracle. Prefer GCM for
// untrusted input.
//
// An error is returned for an invalid key, an iv shorter than the mode
// requires, an empty or misaligned CBC ciphertext, or malformed padding.
// Previously, the last three cases could panic or yield garbage.
func (s *SM4Cipher) Decrypt(data []byte, key []byte, iv []byte) ([]byte, error) {
	block, err := sm4.NewCipher(key)
	if err != nil {
		return nil, err
	}

	if s.useGCM {
		aead, err := cipher.NewGCM(block)
		if err != nil {
			return nil, err
		}
		nonceSize := aead.NonceSize()
		if len(iv) < nonceSize {
			return nil, errors.New("sm4: iv is shorter than the GCM nonce size")
		}
		return aead.Open(nil, iv[:nonceSize], data, nil)
	}

	blockSize := block.BlockSize()
	if len(iv) < blockSize {
		return nil, errors.New("sm4: iv is shorter than the CBC block size")
	}
	// Padding always adds at least one byte, so a valid ciphertext holds at
	// least one full block. Empty input would otherwise reach the unpadding
	// step with no last byte to read.
	if len(data) == 0 {
		return nil, errors.New("sm4: ciphertext is empty")
	}
	if len(data)%blockSize != 0 {
		return nil, errors.New("ciphertext is not a multiple of block size")
	}

	plainText := make([]byte, len(data))
	cipher.NewCBCDecrypter(block, iv[:blockSize]).CryptBlocks(plainText, data)
	if !sm4PaddingValid(plainText, blockSize) {
		return nil, errors.New("sm4: invalid padding")
	}
	return crypto.Pkcs5UnPadding(plainText), nil
}

// sm4PaddingValid reports whether buf ends in well-formed PKCS#7 padding for
// blockSize: a final byte n with 1 <= n <= blockSize, repeated n times.
func sm4PaddingValid(buf []byte, blockSize int) bool {
	if len(buf) == 0 {
		return false
	}
	n := int(buf[len(buf)-1])
	if n == 0 || n > blockSize || n > len(buf) {
		return false
	}
	for _, b := range buf[len(buf)-n:] {
		if int(b) != n {
			return false
		}
	}
	return true
}