Go模型服务化: 加载ONNX模型
Go模型服务化: 加载ONNX模型
Go ONNX模型加载教程
环境准备
安装Go环境(建议version < 1.20, 发现1.20运行有问题, 本实例使用v1.17.13)
安装必要的依赖包:
1 2
go get github.com/owulveryck/onnx-go go get gorgonia.org/gorgonia
项目结构
1
2
3
4
5
6
project/
├── main.go
├── model/
│ └── model.onnx
└── go.mod
基础代码实现
Python将权重&网络转换为ONNX
1
2
3
4
5
6
7
def save2onnx():
if distributed and not is_master:
return
model = policy.module if distributed else policy
dummy_input = torch.randn(1, state_dim).to(device)
torch.onnx.export(model, dummy_input, path, verbose=True)
使用Golang从onnx加载模型
创建main.go文件,实现基本的模型加载功能:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
package main
import (
"fmt"
"log"
"os"
"time"
"github.com/owulveryck/onnx-go"
"github.com/owulveryck/onnx-go/backend/x/gorgonnx"
"gorgonia.org/gorgonia"
"gorgonia.org/tensor"
)
func init() {
// 确保使用正确的BLAS实现
gorgonia.UseNonStable()
}
func main() {
Example_gorgonia()
}
func Example_gorgonia() {
// Create a backend receiver
backend := gorgonnx.NewGraph()
// Create a model and set the execution backend
model := onnx.NewModel(backend)
// read the onnx model
b, err := os.ReadFile("model.onnx")
if err != nil {
log.Fatalf("读取模型文件失败: %v", err)
}
// Decode it into the model
err = model.UnmarshalBinary(b)
if err != nil {
log.Fatalf("解析模型失败: %v", err)
}
// 多轮测试
for i := 0; i < 10; i++ {
nt := time.Now()
// 构造输入数据,这里以9个 float32 数值为例
inputData := []float32{1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, float32(i)}
// 构造一个张量,如果模型要求输入形状为 [1,9](即1个样本,9个特征)
input := tensor.New(
tensor.WithShape(1, 9),
tensor.WithBacking(inputData),
)
// Set the first input, the number depends of the model
model.SetInput(0, input)
// 运行模型前确保BLAS已正确设置
err = backend.Run()
if err != nil {
log.Fatalf("运行模型失败: %v", err)
}
// 获取模型输出
output, err := model.GetOutputTensors()
if err != nil {
log.Fatalf("获取输出失败: %v", err)
}
// 打印输出结果
if len(output) > 0 {
fmt.Printf("预测结果 [%vs]: %v\n", time.Since(nt).Seconds(), output[0])
} else {
fmt.Println("未获取到输出结果")
}
}
}
This post is licensed under CC BY 4.0 by the author.