在 Go 语言中 Patch 非导出函数
TLDR; 使用 https://github.com/cch123/supermonkey 可以 patch 任意导出/非导出函数。
Update: 目前在 Linux 直接运行 go test 生成的二进制文件没有符号表,所以如果在 test 中使用需要先用 go test -c 生成带符号表的二进制文件,然后再运行,略麻烦。
目前在 Go 语言里写测试还是比较麻烦的。
除了传统的 test double,也可以通过把一个现成的对象的成员方法 Patch 掉,以达成测试执行时的特殊目的。
举个例子,我的业务逻辑是从远端获取一段数据,在测试环节没有网络,所以我需要把和网络交互的环节 mock 掉:
func LoadConfig() string {
jsonBytes, err := redis.Get("xxxx")
return string(jsonBytes)
}
这里的 redis.Get 中有网络操作,写测试时,我们的目的是为了验证 Get 之后的逻辑是否正常,所以我们可以把这个 Get 替换为直接返回内容,不走网络,社区中有 monkey patch 来达成这个目的:
monkey.Patch(redis.Get, func(input string) ([]byte, error) {
return []byte("{"key" : 12345}"), nil
})
Patch 之后,redis.Get 就会按照我们替换之后的函数来执行了,还是比较方便的。
monkey patch 的基本原理不复杂,就是把进程中 .text 段中的代码(你可以理解成 byte 数组)替换为用户提供的替换函数。
读取 target 的地址使用了 reflect.ValueOf(funcVal).Pointer() 获取函数的虚拟地址,然后把替换函数的内容以 []byte 的形式覆盖进去。
一方面是因为 reflect 本身没有办法读取非导出函数,一方面是从 Go 的语法上来讲,我们没法在包外部以字面量对非导出函数进行引用。所以目前开源的 monkey patch 是没有办法 patch 那些非导出函数的。
如果我们想要 patch 那些非导出函数,理论上并不需要对这个函数进行引用,只要能找到这个函数的虚拟地址就可以了,在这里提供一个思路,可以使用 nm 来找到我们想要 patch 的函数地址:
NM(1) GNU Development Tools NM(1)
NAME
nm - list symbols from object files
nm 可以查看一个二进制文件中的所有符号的名字、虚拟地址、大小。还是举个例子:
$cat hello.go
package main
func say() {
println("yyyy")
}
func main() {
say()
}
build 需要带 -l 的 gcflags,防止内联优化:
go build -gcflags="-l" hello.go
用 nm 找找这个 say 的地址
$nm hello | grep main
000000000044e3f0 T main
0000000000401070 T main.init
00000000004d5620 B main.initdone.
0000000000401050 T main.main
0000000000401000 T main.say ------> 这里
0000000000423620 T runtime.main
0000000000488c78 R runtime.main.f
0000000000442740 T runtime.main.func1
0000000000488c60 R runtime.main.func1.f
0000000000442780 T runtime.main.func2
0000000000488c68 R runtime.main.func2.f
00000000004b1e70 B runtime.main_init_done
0000000000488c70 R runtime.mainPC
有了虚拟地址,也就有了拷贝的 target。
在 monkey 代码的基础上,再结合 nm 命令得到的符号地址,组合一下就是下面这样的 demo:
package main
import (
"os"
"os/exec"
"reflect"
"strconv"
"strings"
"syscall"
"unsafe"
)
//go:noinline
func HeiHeiHei() {
println("hei")
}
//go:noinline
func heiheiPrivate() {
println("oh no")
}
func Replace() {
println("fake")
}
func generateFuncName2PtrDict() map[string]uintptr {
fileFullPath := os.Args[0]
cmd := exec.Command("nm", fileFullPath)
contentBytes, err := cmd.Output()
if err != nil {
println(err)
return nil
}
var result = map[string]uintptr{}
content := string(contentBytes)
lines := strings.Split(content, "\n")
for _, line := range lines {
arr := strings.Split(line, " ")
if len(arr) < 3 {
continue
}
funcSymbol, addr := arr[2], arr[0]
addrUint, _ := strconv.ParseUint(addr, 16, 64)
result[funcSymbol] = uintptr(addrUint)
}
return result
}
func main() {
m := generateFuncName2PtrDict()
heiheiPrivate()
replaceFunction(m["_main.heiheiPrivate"], (uintptr)(getPtr(reflect.ValueOf(Replace))))
heiheiPrivate()
}
type value struct {
_ uintptr
ptr unsafe.Pointer
}
func getPtr(v reflect.Value) unsafe.Pointer {
return (*value)(unsafe.Pointer(&v)).ptr
}
// from is a pointer to the actual function
// to is a pointer to a go funcvalue
func replaceFunction(from, to uintptr) (original []byte) {
jumpData := jmpToFunctionValue(to)
f := rawMemoryAccess(from, len(jumpData))
original = make([]byte, len(f))
copy(original, f)
copyToLocation(from, jumpData)
return
}
// Assembles a jump to a function value
func jmpToFunctionValue(to uintptr) []byte {
return []byte{
0x48, 0xBA,
byte(to),
byte(to >> 8),
byte(to >> 16),
byte(to >> 24),
byte(to >> 32),
byte(to >> 40),
byte(to >> 48),
byte(to >> 56), // movabs rdx,to
0xFF, 0x22, // jmp QWORD PTR [rdx]
}
}
func rawMemoryAccess(p uintptr, length int) []byte {
return *(*[]byte)(unsafe.Pointer(&reflect.SliceHeader{
Data: p,
Len: length,
Cap: length,
}))
}
func mprotectCrossPage(addr uintptr, length int, prot int) {
pageSize := syscall.Getpagesize()
for p := pageStart(addr); p < addr+uintptr(length); p += uintptr(pageSize) {
page := rawMemoryAccess(p, pageSize)
err := syscall.Mprotect(page, prot)
if err != nil {
panic(err)
}
}
}
// this function is super unsafe
// aww yeah
// It copies a slice to a raw memory location, disabling all memory protection before doing so.
func copyToLocation(location uintptr, data []byte) {
f := rawMemoryAccess(location, len(data))
mprotectCrossPage(location, len(data), syscall.PROT_READ|syscall.PROT_WRITE|syscall.PROT_EXEC)
copy(f, data[:])
mprotectCrossPage(location, len(data), syscall.PROT_READ|syscall.PROT_EXEC)
}
func pageStart(ptr uintptr) uintptr {
return ptr & ^(uintptr(syscall.Getpagesize() - 1))
}
go run -gcflags="-l" yourfile.go
该思路已被封装至 https://github.com/cch123/supermonkey 中。
评论链接: