Golang玩转TensorFlow深度学习模型

  • 文章来源:微博   作者:陈

简介

TensorFlow是目前最流行的深度学习框架,主要支持Python和C++,最近还加入了对Java、Rust和Golang的支持。Golang也是非常流行的服务端编程语言,让Golang应用也能访问深度学习模型,对于服务端编程和智能应用带来很大的想象空间。

但现在使用Python API来构建graph和训练模型更加简单,而且大部分paper的模型也是使用Python API实现的,Golang直接访问这些TensorFlow模型文件比较困难。我们基于TensorFlow Serving和gRPC很好地解决了这个问题,基于一个跨模型的通用访问接口,实现了golang predict client,让Golang开发者可以像Python或C++应用一样“玩转”TensorFlow深度学习模型了。

关注Github的 https://github.com/tensorflow/tensorflow/issues/10 可能知道,TensorFlow已经支持Golang binding可以直接访问graph结构,但不在本文讨论范围内,想了解更多内容欢迎继续关注。


TensorFlow模型

首先我们需要编写应用来生成TensorFlow模型,所谓模型就是在TensorFlow中的graph,还有训练过程中得到的参数。这样的模型可以通过保存checkpoint得到,因为checkpoint就包含graph和参数信息,方便下次运行时加载参数继续训练。

但checkpoint格式的模型文件并不能给TensorFlow Serving直接加载,TensorFlow Serving是Google开源的高性能gRPC服务,使用C++加载TensorFlow模型文件来预测,并提供一个跨模型的通用inference接口。使用TensorFlow Serving好处就是高性能,并且提供跨语言的RPC接口,这样无论是Java、Golang甚至是Ruby都可以直接访问这些模型了。



Golang 客户端

Golang的gRPC客户端代码需要我们自己实现,通过TensorFlow Serving提供的model.proto、predict.proto和prediction_service.proto来生成protobuf和gRPC文件,而它还依赖TensorFlow项目中的TensorProto等proto定义,我们在deep_recommend_system项目中也提供了一个Shell脚本来生成。



用户需要根据不同模型的输入格式,生成对应的TensorProto对象,也就是inference时需要的多维数组格式,目前我们已经支持稠密的Tensor输入和稀疏的SparseTensor输入,基本涵盖了大部分深度学习模型。对于用户自定义的模型,可以通过exporter指定输入和输出op,不需要修改Serving源码就可以直接加载,gRPC接口和数据交换格式都是通用的,只需要在构建inference数据时传入不同的shape和预测数据即可。



目前代码已经开放到 https://github.com/tobegit3hub/deep_recommend_system/tree/master/golang_predict_client ,感谢小米网Golang团队和Github用户 @sparklxb 高质量的pull-request。如果你的业务正在使用Golang,同时也在使用TensorFlow和深度学习模型,对golang predict client有任何反馈欢迎在Github上交流。

Golang binding

TensorFlow对Golang的支持远不止提供gRPC client,如果你想在graph中对某个tensor进行保存和加载,可以参考 https://gist.github.com/helinwang/7782c6b2815c334c77653fc0e52b6069 ,最新的tensorflow package已经可以在godoc中找到 https://godoc.org/github.com/tensorflow/tensorflow/tensorflow/go 。



总结

最后总结回答一个问题,使用Golang能否玩转TensorFlow和深度学习模型?答案是Yes,但目前直接使用Golang binding去训练深度学习模型还有很多不足(构建graph的灵活性和数值计算op数量还远不如Python),但通过我们提供的golang predict client,Golang应用也可以非常灵活并且轻易地访问预训练好的TensorFlow模型了。

deep_recommend_system是一个TensorFlow模板应用,已经支持通用的CSV稠密数据和LIBSVM稀疏数据的模型训练,实现了LR、MLP、CNN和Wide and deep模型,支持checkpoint、tensorboard、dropout、learning rate decay和batch normalization等特性,还有Python、Java、Scala、Spark和Golang等客户端实例,欢迎继续关注和star https://github.com/tobegit3hub/deep_recommend_system 。

来源: 微信

作者: 全球人工智能

文章来自互联网如果对个人或单位有侵犯其著作权行为,请您联系我们:lk_qmail@foxmail.com

在线交流