Post

Go模型服务化: 加载ONNX模型

Go模型服务化: 加载ONNX模型

Go ONNX模型加载教程

环境准备

  • 安装Go环境(建议version < 1.20, 发现1.20运行有问题, 本实例使用v1.17.13)

  • 安装必要的依赖包:

    1
    2
    
    go get github.com/owulveryck/onnx-go
    go get gorgonia.org/gorgonia
    

项目结构

1
2
3
4
5
6
project/
├── main.go
├── model/
   └── model.onnx
└── go.mod

基础代码实现

Python将权重&网络转换为ONNX

1
2
3
4
5
6
7
def save2onnx():
		if distributed and not is_master:
		    return

		model = policy.module if distributed else policy
		dummy_input = torch.randn(1, state_dim).to(device)
		torch.onnx.export(model, dummy_input, path, verbose=True)

使用Golang从onnx加载模型

创建main.go文件,实现基本的模型加载功能:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
package main

import (
	"fmt"
	"log"
	"os"
	"time"

	"github.com/owulveryck/onnx-go"
	"github.com/owulveryck/onnx-go/backend/x/gorgonnx"
	"gorgonia.org/gorgonia"
	"gorgonia.org/tensor"
)

func init() {
	// 确保使用正确的BLAS实现
	gorgonia.UseNonStable()
}

func main() {
	Example_gorgonia()
}

func Example_gorgonia() {
	// Create a backend receiver
	backend := gorgonnx.NewGraph()
	// Create a model and set the execution backend
	model := onnx.NewModel(backend)

	// read the onnx model
	b, err := os.ReadFile("model.onnx")
	if err != nil {
		log.Fatalf("读取模型文件失败: %v", err)
	}

	// Decode it into the model
	err = model.UnmarshalBinary(b)
	if err != nil {
		log.Fatalf("解析模型失败: %v", err)
	}

	// 多轮测试
	for i := 0; i < 10; i++ {
		nt := time.Now()
		// 构造输入数据,这里以9个 float32 数值为例
		inputData := []float32{1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, float32(i)}

		// 构造一个张量,如果模型要求输入形状为 [1,9](即1个样本,9个特征)
		input := tensor.New(
			tensor.WithShape(1, 9),
			tensor.WithBacking(inputData),
		)

		// Set the first input, the number depends of the model
		model.SetInput(0, input)

		// 运行模型前确保BLAS已正确设置
		err = backend.Run()
		if err != nil {
			log.Fatalf("运行模型失败: %v", err)
		}

		// 获取模型输出
		output, err := model.GetOutputTensors()
		if err != nil {
			log.Fatalf("获取输出失败: %v", err)
		}

		// 打印输出结果
		if len(output) > 0 {
			fmt.Printf("预测结果 [%vs]: %v\n", time.Since(nt).Seconds(), output[0])
		} else {
			fmt.Println("未获取到输出结果")
		}
	}
}

embed

This post is licensed under CC BY 4.0 by the author.