crypto-sm/sm2.go

128 lines
3.6 KiB
Go
Raw Normal View History

package sm
import (
stdcrypto "crypto"
"crypto/ecdsa"
"crypto/rand"
"encoding/asn1"
"errors"
"math/big"
"apigo.cc/go/crypto"
"apigo.cc/go/safe"
"github.com/emmansun/gmsm/sm2"
"github.com/emmansun/gmsm/smx509"
)
// SM2Algorithm implements the SM2 public-key operations (sign, verify,
// encrypt, decrypt, key parsing) on top of github.com/emmansun/gmsm.
// It carries no state; use the shared SM2 value.
type SM2Algorithm struct{}

// SM2 is the shared SM2Algorithm instance handed to the crypto.Asymmetric
// constructors by NewSM2, NewSM2AndEraseKey and NewSM2WithoutEraseKey.
var SM2 = new(SM2Algorithm)
// NewSM2 builds an SM2 crypto.Asymmetric from key material that is already
// held in safe.SafeBuf containers. Presumably the buffers hold DER keys in the
// formats ParsePrivateKey/ParsePublicKey accept (PKCS#8 / PKIX) — TODO confirm
// against crypto.NewAsymmetric.
func NewSM2(safePrivateKeyBuf, safePublicKeyBuf *safe.SafeBuf) (*crypto.Asymmetric, error) {
	asym, err := crypto.NewAsymmetric(SM2, safePrivateKeyBuf, safePublicKeyBuf)
	return asym, err
}
// NewSM2AndEraseKey builds an SM2 crypto.Asymmetric from raw DER key bytes
// via crypto.NewAsymmetricAndEraseKey. As that name indicates, the input
// slices are expected to be wiped by the callee, so callers should not reuse
// them afterwards (NOTE(review): confirm erase semantics in package crypto).
func NewSM2AndEraseKey(privateKey, publicKey []byte) (*crypto.Asymmetric, error) {
	asym, err := crypto.NewAsymmetricAndEraseKey(SM2, privateKey, publicKey)
	return asym, err
}
// NewSM2WithoutEraseKey builds an SM2 crypto.Asymmetric from raw DER key bytes
// via crypto.NewAsymmetricWithoutEraseKey, leaving the caller's slices intact
// (per the callee's name — NOTE(review): confirm in package crypto).
func NewSM2WithoutEraseKey(privateKey, publicKey []byte) (*crypto.Asymmetric, error) {
	// NOTE(review): the meaning of the trailing false flag is defined by
	// crypto.NewAsymmetricWithoutEraseKey and is not visible here.
	asym, err := crypto.NewAsymmetricWithoutEraseKey(SM2, privateKey, publicKey, false)
	return asym, err
}
// NewSM2ByPassword deterministically derives an SM2 key pair from password
// and salt, so identical inputs always produce identical keys.
//
// A 32-byte seed from crypto.DeriveKey is mapped into the valid private-scalar
// range [1, n-1] as d = seed mod (n-1) + 1. The keys are then DER-encoded
// (PKCS#8 private, PKIX public) and handed to NewSM2AndEraseKey.
//
// Intermediate secret material is wiped before returning. This covers the
// seed, the fixed-width scalar bytes and the words of d. On the error path it
// also covers the private-key DER, which would otherwise never reach the
// erasing constructor.
func NewSM2ByPassword(password, salt []byte) (*crypto.Asymmetric, error) {
	seed := crypto.DeriveKey(password, salt, 32)
	defer safe.ZeroMemory(seed)

	curve := sm2.P256()
	params := curve.Params()

	// Derive D: seed % (n-1) + 1, because an SM2 private key d must be in
	// [1, n-1]. The reduction must stay exactly like this so keys derived by
	// earlier versions remain reproducible.
	d := new(big.Int).SetBytes(seed)
	defer func() {
		// Best-effort wipe of d's current backing words (shared with priv.D).
		words := d.Bits()
		for i := range words {
			words[i] = 0
		}
	}()
	nMinusOne := new(big.Int).Sub(params.N, big.NewInt(1))
	d.Mod(d, nMinusOne)
	d.Add(d, big.NewInt(1))

	// Fixed-width big-endian scalar in a buffer we own, so it can be wiped
	// (d.Bytes() would leave an unreachable, unwiped copy of the key).
	// d < n, so it always fits in the curve's byte length.
	scalar := d.FillBytes(make([]byte, (params.BitSize+7)/8))
	defer safe.ZeroMemory(scalar)

	priv := new(sm2.PrivateKey)
	priv.Curve = curve
	priv.D = d
	priv.PublicKey.X, priv.PublicKey.Y = curve.ScalarBaseMult(scalar)

	privateKey, err := smx509.MarshalPKCS8PrivateKey(priv)
	if err != nil {
		return nil, err
	}
	publicKey, err := smx509.MarshalPKIXPublicKey(&priv.PublicKey)
	if err != nil {
		// NewSM2AndEraseKey is not reached on this path, so wipe here.
		safe.ZeroMemory(privateKey)
		return nil, err
	}
	return NewSM2AndEraseKey(privateKey, publicKey)
}
// GenerateSM2KeyPair creates a new random SM2 key pair using crypto/rand.
//
// It returns the private key as PKCS#8 DER and the public key as PKIX DER.
// On success the caller owns both slices and is responsible for wiping the
// private key, for example by passing it to NewSM2AndEraseKey. On failure no
// key bytes are returned, and any private-key DER already produced is wiped.
func GenerateSM2KeyPair() ([]byte, []byte, error) {
	privKey, err := sm2.GenerateKey(rand.Reader)
	if err != nil {
		return nil, nil, err
	}
	privateKey, err := smx509.MarshalPKCS8PrivateKey(privKey)
	if err != nil {
		return nil, nil, err
	}
	publicKey, err := smx509.MarshalPKIXPublicKey(&privKey.PublicKey)
	if err != nil {
		// The private DER never reaches the caller here; don't leave it lying
		// around in memory.
		safe.ZeroMemory(privateKey)
		return nil, nil, err
	}
	return privateKey, publicKey, nil
}
// ParsePrivateKey decodes a PKCS#8 DER private key and returns it as an
// *sm2.PrivateKey.
//
// Keys of any other type (RSA, NIST ECDSA, Ed25519, ...) are rejected at parse
// time, mirroring ParsePublicKey. Sign and Decrypt only accept
// *sm2.PrivateKey, so such keys would otherwise fail later with a less
// specific error.
func (a *SM2Algorithm) ParsePrivateKey(der []byte) (any, error) {
	keyAny, err := smx509.ParsePKCS8PrivateKey(der)
	if err != nil {
		return nil, err
	}
	privKey, ok := keyAny.(*sm2.PrivateKey)
	if !ok {
		return nil, errors.New("not an SM2 private key")
	}
	return privKey, nil
}
// ParsePublicKey decodes a PKIX (SubjectPublicKeyInfo) DER public key and
// returns it as *ecdsa.PublicKey, the type gmsm uses for SM2 public keys.
// Non-ECDSA keys are rejected. Note that the curve is not checked: any
// *ecdsa.PublicKey is accepted.
func (a *SM2Algorithm) ParsePublicKey(der []byte) (any, error) {
	parsed, err := smx509.ParsePKIXPublicKey(der)
	if err != nil {
		return nil, err
	}
	switch key := parsed.(type) {
	case *ecdsa.PublicKey:
		return key, nil
	default:
		return nil, errors.New("not an SM2 public key")
	}
}
// Sign produces an SM2 signature over data using the key's SignWithSM2.
// Randomness comes from crypto/rand, and uid is nil, which gmsm treats as its
// default user ID. The hash argument is ignored.
func (a *SM2Algorithm) Sign(privateKeyObj any, data []byte, hash ...stdcrypto.Hash) ([]byte, error) {
	if privKey, ok := privateKeyObj.(*sm2.PrivateKey); ok {
		return privKey.SignWithSM2(rand.Reader, nil, data)
	}
	return nil, errors.New("invalid SM2 private key")
}
// Verify checks an ASN.1 DER SM2 signature, SEQUENCE { r, s INTEGER }, over
// data using sm2.VerifyWithSM2 with a nil (default) user ID.
//
// It returns (false, err) when the key is not an *ecdsa.PublicKey or when the
// signature is not exactly one well-formed DER value. Bytes trailing the
// encoded SEQUENCE count as malformed. Otherwise it returns the verification
// result with a nil error. The hash argument is ignored.
func (a *SM2Algorithm) Verify(publicKeyObj any, data []byte, signature []byte, hash ...stdcrypto.Hash) (bool, error) {
	pubKey, ok := publicKeyObj.(*ecdsa.PublicKey)
	if !ok {
		return false, errors.New("invalid SM2 public key")
	}
	var sm2Sig struct{ R, S *big.Int }
	rest, err := asn1.Unmarshal(signature, &sm2Sig)
	if err != nil {
		return false, err
	}
	// Without this check, appending arbitrary bytes to a valid signature
	// would still verify, making signatures malleable.
	if len(rest) != 0 {
		return false, errors.New("invalid SM2 signature: trailing data")
	}
	return sm2.VerifyWithSM2(pubKey, nil, data, sm2Sig.R, sm2Sig.S), nil
}
// Encrypt encrypts data for the given SM2 public key using crypto/rand.
// The ciphertext uses the plain (non-ASN.1) C1||C3||C2 layout, with C1
// encoded as an uncompressed point.
func (a *SM2Algorithm) Encrypt(publicKeyObj any, data []byte) ([]byte, error) {
	pubKey, isPub := publicKeyObj.(*ecdsa.PublicKey)
	if !isPub {
		return nil, errors.New("invalid SM2 public key")
	}
	opts := sm2.NewPlainEncrypterOpts(sm2.MarshalUncompressed, sm2.C1C3C2)
	return sm2.Encrypt(rand.Reader, pubKey, data, opts)
}
// Decrypt decrypts an SM2 ciphertext produced by Encrypt. The input is in the
// plain (non-ASN.1) C1||C3||C2 layout.
func (a *SM2Algorithm) Decrypt(privateKeyObj any, data []byte) ([]byte, error) {
	privKey, ok := privateKeyObj.(*sm2.PrivateKey)
	if !ok {
		return nil, errors.New("invalid SM2 private key")
	}
	// PrivateKey.Decrypt only honours *sm2.DecrypterOpts. Passing
	// *EncrypterOpts was silently ignored and worked only because gmsm's
	// default happens to be plain C1C3C2. State the layout explicitly.
	return privKey.Decrypt(nil, data, sm2.NewPlainDecrypterOpts(sm2.C1C3C2))
}