spark实现item2Vec算法-附scala代码
本文记录了使用spark实现item2vec算法的相关内容,欢迎做相关工作的同学与我联系zhaoliang19960421@outlook.com
/**
* 本代码以做脱敏处理,与原公司、原业务无关,特此声明
/
package *
import *.SparkContextUtils.createSparkSession
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.ml.feature.{Word2Vec, Word2VecModel}
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
import org.apache.spark.sql.expressions.UserDefinedFunction
import org.apache.spark.sql.functions._
import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
import org.apache.spark.ml.linalg._
/**
* @author zhaoliang6 on 20220406
* 基于word2vec算法构造item2vec
* 生成item向量,用户侧使用average pooling 构造user向量
*/
object Item2Vec {
def main(args: Array[String]): Unit = {
val Array(locale: String, startDate: String, endDate: String) = args
val sparkSession: SparkSession = createSparkSession(this.getClass.getSimpleName)
val userItemSeqDf = getUserItemSeq(sparkSession, startDate, endDate)
val model = getWord2VecModel(userItemSeqDf, "usage_seq", "vector")
val itemVec = getItemVec(model)
val userVec = getUserVec(sparkSession, userItemSeqDf, itemVec)
}
/**
* 给定的item下的最相似的前topN个结果
*/
def getItemSim(model: Word2VecModel, item: String, topN: Int): Unit = {
try {
println(s"$item 最相似的前${topN}个结果是:")
model.findSynonyms(item, topN).show(truncate = false)
} catch {
case ex: Exception => println(s"$item 不存在")
}
}
/**
* 将用户序列下的item求平均得到用户向量
*/
def getUserVec(sparkSession: SparkSession, orgDf: DataFrame, itemVec: DataFrame): DataFrame = {
val arrayDefaultVec = new Array[Double](200)
def itemVecAagPoolingUDF(map: scala.collection.Map[String, Array[Double]]): UserDefinedFunction = udf((seq: mutable.WrappedArray[String]) => {
val res = ArrayBuffer[Array[Double]]()
res.appendAll(seq.map(map.getOrElse(_, arrayDefaultVec)))
val tmp: (Array[Double], Int) = res.map(e => (e, 1)).reduce((x, y) => {
(x._1.zip(y._1).map(a => a._1 a._2), x._2 y._2)
})
if (tmp._2 > 0) tmp._1.map(e => e / tmp._2)
else arrayDefaultVec
})
val itemVecBC = sparkSession.sparkContext.broadcast(itemVec.rdd.map(r => (r.getString(0), r.getSeq[Double](1).toArray)).collectAsMap())
val userVecDf = orgDf
.withColumn("vector", itemVecAagPoolingUDF(itemVecBC.value)(col("usage_seq")))
.select("gaid", "vector")
userVecDf
}
/**
* 基于w2v 得到item向量
*/
def getItemVec(model: Word2VecModel): DataFrame = {
def vector2ArrayUDF(): UserDefinedFunction = udf((vec: Vector) => {
val norm = Vectors.norm(vec, 2)
vec.toArray.map(e => if (norm != 0) e / norm else 0.0)
})
val itemVec = model.getVectors
.select(col("word").as("pkg"), col("vector").as("org_vector"))
.withColumn("vectorArray", vector2ArrayUDF()(col("vector")))
.selectExpr("word as item", "vectorArray")
itemVec
}
/**
* 获得user-itemSeq
*/
def getUserItemSeq(sparkSession: SparkSession, startDate: String, endDate: String): DataFrame = {
def getSeqUDF(): UserDefinedFunction = udf((seq: mutable.WrappedArray[GenericRowWithSchema]) => {
val listSeq = ArrayBuffer[String]()
seq.sortBy(e => e.getAs[Long]("timestamp"))
var pkg = seq.head.getAs[String]("pkg")
var open = seq.head.getAs[Long]("timestamp")
var dura = seq.head.getAs[Double]("duration")
listSeq.append(pkg)
seq.drop(0).foreach(e => {
val tmp_pkg = e.getAs[String]("pkg")
val tmp_open = e.getAs[Long]("timestamp")
val tmp_dura = e.getAs[Double]("duration")
if (!tmp_pkg.equals(pkg) || (tmp_pkg.equals(pkg) && ((tmp_open - open) / 1000 - dura > 10)))
listSeq.append(tmp_pkg)
pkg = tmp_pkg
open = tmp_open
dura = tmp_dura
})
listSeq
})
val dfAppUsage = sparkSession.read.parquet("hdfs://***")
.where(s"date between $startDate and $endDate")
.groupBy("gaid")
.agg(collect_list(struct("pkg", "timestamp", "duration")).as("seq"))
.withColumn("usage_seq", getSeqUDF()(col("seq")))
.withColumn("seq_len", size(col("usage_seq")))
.where("seq_len > 10") // 最短的路径
.selectExpr("gaid", "usage_seq")
dfAppUsage
}
/**
* 获得word2vec模型
*/
def getWord2VecModel(orgDf: DataFrame, inputCol: String, outputCol: String): Word2VecModel = {
val model: Word2VecModel = new Word2Vec()
.setInputCol(inputCol)
.setOutputCol(outputCol)
.setSeed(1024)
.setMaxIter(10)
.setMinCount(5)
.setVectorSize(200)
.setWindowSize(5)
.setNumPartitions(1000)
.setMaxSentenceLength(100)
.fit(orgDf)
model
}
}
这篇好文章是转载于:学新通技术网
- 版权申明: 本站部分内容来自互联网,仅供学习及演示用,请勿用于商业和其他非法用途。如果侵犯了您的权益请与我们联系,请提供相关证据及您的身份证明,我们将在收到邮件后48小时内删除。
- 本站站名: 学新通技术网
- 本文地址: /boutique/detail/tanhfhkchg
系列文章
更多
同类精品
更多
-
photoshop保存的图片太大微信发不了怎么办
PHP中文网 06-15 -
word里面弄一个表格后上面的标题会跑到下面怎么办
PHP中文网 06-20 -
《学习通》视频自动暂停处理方法
HelloWorld317 07-05 -
Android 11 保存文件到外部存储,并分享文件
Luke 10-12 -
photoshop扩展功能面板显示灰色怎么办
PHP中文网 06-14 -
微信公众号没有声音提示怎么办
PHP中文网 03-31 -
excel下划线不显示怎么办
PHP中文网 06-23 -
excel打印预览压线压字怎么办
PHP中文网 06-22 -
TikTok加速器哪个好免费的TK加速器推荐
TK小达人 10-01 -
怎样阻止微信小程序自动打开
PHP中文网 06-13