解读libFM

写在前面

由于之前的比赛中用到了这个工具,所以顺带对矩阵分解以及FM深入学习了下。本文将结合其算法原理和Cpp源码,说说自己的使用心得,另外讲讲如何将Cpp源码分别用Python和Java改写。

问题的本质

对于推荐系统一类问题来说,最核心的就是衡量用户对未接触target的感兴趣程度。CF一类算法的思想是通过相似度计算来直接估计感兴趣程度,矩阵分解一类的思想则是借助隐变量的映射间接得到对target的感兴趣程度。那么,FM呢?

个人理解是,FM和词向量一类的做法有相通之处。通过分别对用户和商品等构建一个向量,训练结束后根据用户的向量和商品向量之间的内积,估计出用户对商品的感兴趣程度。OK,借用论文[Factorization Machines]中的一个例子来说明下:

假设观测到的数据集如下:

(A, TI, 5)
(A, NH, 3)
(A, SW, 1)
(B, SW, 4)
(B, ST, 5)
(C, TI, 1)
(C, SW, 5)

现在我们要估计A对ST的感兴趣程度,显然,如果用传统的分类算法,由于训练集中没有出现(A, ST)的pair,所以得到的感兴趣程度是0.但是FM衡量的是\((v_A, v_{ST})\)之间的相似度,具体来说,是这么做的:由观测数据(B, SW, 4)和(C, SW, 5)可以得到\(v_B\)和\(v_C\)之间的相似度较高,此外,根据(A, TI, 5)和(C, TI, 1)可以推测\(v_A\)和\(v_C\)之间的相似度较低。而根据(B, SW, 4)和(B, ST, 5)可以发现,\(v_{SW}\)和\(v_{ST}\)之间的相似度较高。计算\(v_A\)和\(v_{ST}\)便可得知,A对ST的感兴趣程度较低。从这个角度来看,FM似乎是借用一种向量化表示来融合了基于用户和基于商品的协同过滤。

理解其数学模型

根据前面的描述,给定一个用户商品(U,I)pair后,我们只需计算下式即可得到估计值:

\[\hat{y}(x)=\langle v_u, v_I\rangle\]

但我们这么做的实际上是有几个潜在假设的: 1.默认特征矩阵中只有User和Item两类特征; 2.User和Item维度的特征值为1;

实际问题中除了User和Item,还通常有Time,UserContext和ItemContext等等维度的特征。此外,不同User和Item的权值并不一样,所以并不都是1。将上式改写下便得到了FM的原型:

\[\hat{y}(x)=w_0+\sum_{i=1}^{n}w_i x_i + \sum_{i=1}^{n}\sum_{j=i+1}^{n}\langle v_i, v_j\rangle x_i x_j\]

其中\(w_0\)是全局bias,\(w_i\)是每维特征的权重, \(\langle v_i, v_j \rangle\)可以看做是交互特征的权重。

计算过程优化

显然,上式的计算复杂度为\(O(kn^2)\),利用二次项展开式可以将计算复杂度降低到线性复杂度:

\[ \begin{split} &\sum_{i=1}^{n} \sum_{j=i+1}^{n} \langle v_i, v_j \rangle x_i x_j\\ &=\frac{1}{2}\sum_{i=1}^{n} \sum_{j=1}^{n} \langle v_i, v_j \rangle x_i x_j - \frac{1}{2} \sum_{i=1}^{n}\langle v_i,v_j \rangle x_i x_i \\ &=\frac{1}{2}\left(\sum_{i=1}^{n}\sum_{j=1}^{n}\sum_{f=1}^{k} v_{i,f}v_{j,f} x_i x_j - \sum_{i=1}^{n}\sum_{f=1}^{k}v_{i,f}v_i x_i x_i\right)\\ &=\frac{1}{2}\sum_{f=1}^{k}\left(\left( \sum_{i=1}^{n} v_{i,f} x_i \right) \left( \sum_{j=1}^{n} v_{j,f} x_j \right) - \sum_{i=1}^{n} v_{i,f}^2 x_i^2\right) \\ &=\frac{1}{2}\sum_{f=1}^{k}\left(\left( \sum_{i=1}^{n}v_{i,f} x_i \right)^2 - \sum_{i=1}^{n}v_{i,f}^{2} x_i^2\right) \end{split} \]

当v的维度从2维扩展到d维时,上式也可以做对应的扩展。对上式中的变量(\(w_0, w_i, v_{i,f}\))分别求导可以得到下式:

\[ \frac{\partial}{\partial \theta} \hat{y}(x) = \left\{ \begin{aligned} &1,\hspace{3in}if\ \theta \ is\ w_0\\ &x_i, \hspace{2.9in}if\ \theta \ is\ w_i\\ &x_i\Sigma_{j=1}^{n}v_{j,f}x_j - v_{i,f}x_i^2, \hspace{1.15in} if\ \theta \ is\ v_{i,f}\\ \end{aligned} \right. \]

C+ +源码解析

下面以libfm默认的mcmc方法用于分类任务为例,对其C+ +实现做一个简单分析。

概括地来说,训练的过程分为两步:一是计算误差,二是更新系数(\(w_i\)和\(v_{i,f}\))。具体流程如下:

  1. 读入参数,并初始化。对于mcmc方法,重要的参数只有三个,一是-iter,即迭代次数;二是-dim,用于确定上面的v的维度;三是init_stdev;用于控制v和w的初始值(用正态分布随机初始化,由fm.init()完成)。

  2. 接下来调用_learn(train, test)来完成整个训练过程。其中cache是一个eqterm的对象,该对象中的e用于存储最后的误差项,q可以看做一个缓存,存储一些中间计算结果(不得不说,作者把内存用到了极致!!!各种省内存的做法啊)。把_learn函数的核心部分剥离出来就是下面的内容:

    void _learn(Data& train, Data& test){
     // ...
     predict_data_and_write_to_eterms(data, cache); // 根据v和w分别计算在训练集和测试集上的估计值,并保存到cache中
     //根据训练集的target值,计算训练集上的偏差
     for (uint c = 0; c < train.num_cases; c++) {
     cache[c].e = cache[c].e - train.target(c);
     }
    
     // 迭代更新
     for (uint i = num_complete_iter; i < num_iter; i++) {
     //...
     draw_all(train); //根据偏差(cache中的e部分)和训练集数据更新v和w
     predict_data_and_write_to_eterms(data, cache); // 用更新后的v和w继续计算训练集和测试集上的估计值
     }
  3. 先看一下predict_data_and_write_to_eterms函数,其实就是求解\(\hat{y}\)的一个过程。将其分解为2个部分,即\(\sum_{i=1}^{n}w_i x_i\)和\(\sum_{i=1}^{n}\sum_{j=i+1}^{n}\langle v_i, v_j\rangle x_i x_j\)(对mcmc方法而言\(w_0\)项为0),后者的线性时间简化版可以分成\(\frac{1}{2}\sum_{f=1}^{k}\left(\left( \sum_{i=1}^{n}v_{i,f} x_i \right)^2\right)\) 和\(\frac{1}{2}\sum_{f=1}^{k}\left(\sum_{i=1}^{n}v_{i,f}^{2} x_i^2\right)\)这两块。这样子,predict_data_and_write_to_eterms函数对应的源码就好理解了。需要注意的一点是,源码中的Data数据结构实际上是CSC格式的稀疏矩阵,因此C++源码计算过程中是转置后按行取值计算的。理解这一点后,再用python或java改写该部分时,便可直接使用矩阵计算的方式,避免for循环带来的开销。

  4. draw_all函数稍微复杂点。先假设没有meta文件(特征的分组信息),前面我们已经得到了v,w参数的梯度,mcmc的做法是计算梯度并更新之后,对结果分别加入了正态分布和gamma分布的采样过程,与此同时,作者边更新v和w,根据新旧v和w的变化程度同步更新error(这个部分理解得不是很透彻,一般而言会在v和w完全更新后再更新e)。最后,对训练集上的目标值做一个截断高斯分布采样,利用采样值对error进一步更新。

  5. meta信息的加入,相当于对每类特征内部多了一个正则项,在我使用的过程中,加不加正则项对结果的影响还是蛮大的。主要原因可能是我不同特征组的scale差异有点大,如果采用全局的正则项,难以精确描述不同特征组之间的差异。

改写成python和java版本

由于比赛平台是java的,需要将C++代码改写成java,前期验证功能性的时候,为了快速分析中间变量,先用python写了一遍。整体框架都一致,核心是实现predict_data_and_write_to_etermsdraw_all这两个函数。其中predict_data_and_write_to_eterms的实现比较简单,分别借用python下的scipy中的csc矩阵和java下matrix-toolkits-java包中的FlexCompColMatrix可以将C++源码中的for循环大大简化。但是对于draw_all函数却不太轻松,最核心的问题是前面提到的,C++源码中在更新w和v的同时还在更新error项,导致无法一次通过矩阵运算后再更新error(事实上我一开始这么干过,但是跑出来的效果差太远),无奈,只能和C++一样用循环迭代更新,效率上慢了将近一个数量级。