在飞桨第三期黑客松活动中,湖北大学计算机与信息工程学院的研究生韩凡宇(队长)和王勇森组建了“源力觉醒”小队,为飞桨新增了GumbelAPI。
本文将由王勇森分享为飞桨新增GumbelAPI的经验。
01
任务介绍
▎任务背景
耿贝尔(Gumbel)分布是一种极值型分布。Gumbel分布理论认为,最大值分布的潜在适用性与极值理论有关,如果基础样本数据的分布是正态或者指数类型,Gumbel分布就是有用的。Gumbel分布适用于海洋、水文、气象领域,用来计算不同重现期的极端高(低)潮位。而在概率论和统计学中,Gumbel分布常被用来模拟不同分布的样本的最大(或最小)分布。
Gumbel概率密度图像
飞桨框架目前还未集成Gumbel分布,本任务的目标是对飞桨框架中现有概率分布方案进行扩展,新增GumbelAPI。增加此API能够扩大飞桨的应用处理范围,对飞桨来说是非常必要的。
▎设计思路
GumbelAPI的设计需要做多个方面的知识储备:
详细了解Gumbel分布背后的数学原理以及应用场景;
深刻了解飞桨和业界概率分布设计实现的方法和技巧。
注:“充分学习Gumbel背后的数学原理”这一点很重要,避免在开发时违背定理。印象最深的是在开发之初,我们没有了解到分布的mean如何计算,直接使用了一个不清楚的计算方式,导致在最开始的版本中,mean的计算方式错误。最后,通过查询资料得知标准Gumbel分布的mean为负欧拉常数,于是纠正。
接下来,我将详细描述GumbelAPI的设计思路。
命名与参数设计
API的名称直接使用Gumbel分布的名称,参数保持Gumbel分布最原生的参数,包括“位置参数loc”以及“尺度参数scale”。预期GumbelAPI的形式为:
paddle.distribution.gumbel.Gumbel(loc,scale)
Gumbel分布类的初始化方法
类初始化过程中,一方面要严格控制参数loc和scale的形状和数据类型。另一方面还要借助基础分布Uniform以及transforms初始化父类TransformedDistribution。
GumbelAPI的功能
该API部分功能继承于TransformedDistribution,包括mean均值、variance方差、sample随机采样、rsample重参数化采样、prob概率密度、log_prob对数概率密度、entropy熵计算等。除了官方任务要求外,我们还添加了一些其他的方法,比如stddev标准差和cdf累积分布函数等。
▎类初始化方法
■数据类型
首先我们需要判断loc和scale的数据类型是否是飞桨支持的标量数据类型。如果是飞桨支持的标量,需要将其转为飞桨支持的tensor类型。
上述判断实现如下:
ifnotisinstance(loc,(numbers.Real,framework.Variable)):(抛出数据类型错误)ifnotisinstance(scale,(numbers.Real,framework.Variable)):(抛出数据类型错误)ifisinstance(loc,numbers.Real):(转为paddle类型的tensor)ifisinstance(scale,numbers.Real):(转为paddle类型的tensor)
此外,还要统一loc和scale的形状类型,我们选择使用paddle.broadcast_tensors()广播机制来进行统一。
ifloc.shape!=scale.shape:self.loc,self.scale=paddle.broadcast_tensors([loc,scale])else:self.loc,self.scale=loc,scale
■父类调用
因为初始化方法中对基础分布进行一系列transform的操作,我们选择继承父类TransformedDistribution。但在实际开发过程中,由于遇到了一些问题,我们并未在TransformedDistribution类中对Uniform进行变换,而是选择在rsample中进行变换。
▎API伪代码实现
在经过以上准备,确定设计思路后,我们给出paddle.distribution.gumbel.Gumbel中实现的属性、方法的伪代码。
mean:均值
loc+scale*γ
variance:方差
pow(scale,2)*pi*pi/6
stddev:标准差
sqrt(variance)
cdf(value):累积分布函数
exp(-exp(-(value-loc)/scale))
rsample(shape):重参数化采样
ExpTransform()AffineTransform(0,-ones_like(scale))AffineTransform(loc,-scale)chain=ChainTransform(ExpTransform(),AffineTransform(0,-ones_like(scale)),AffineTransform(loc,-scale))chain.forward(base_distribute)
02
代码开发
本节介绍代码开发的过程,着重介绍在开发中遇到困难的两个部分,包括类初始化方法(__init__)和重参数化采样方法(rsample)。最后,再介绍开发过程中遇到的问题以及如何解决该问题。其他属性方法仅是将1.4节中的伪代码使用飞桨框架实现,我们将在本节末尾给出各个方法属性的实现。
▎类初始化方法:__init__
在进行此方法的开发中,我们着重