Spark ALS实现的步骤是什么

81次阅读
没有评论

共计 4600 个字符,预计需要花费 12 分钟才能阅读完成。

这篇文章主要讲解了“Spark ALS 实现的步骤是什么”,文中的讲解内容简单清晰,易于学习与理解,下面请大家跟着丸趣 TV 小编的思路慢慢深入,一起来研究和学习“Spark ALS 实现的步骤是什么”吧!

spark ALS 算法是做个性推荐用的,它所需要的数据集是类似用户对商品的打分表之类的数据集。实现步骤主要以下几步:

1、定义输入数据

2、输入数据转换成评分数据格式,如 case class Rating(user: Int, movie: Int, rating: Float)

3、设计 ALS 模型训练数据

4、计算推荐数据,存储起来供业务系统直接使用。

下面看看具体的代码:

package recommend
import org.apache.spark.sql.SparkSession
import java.util.Properties
import org.apache.spark.rdd.RDD
import org.apache.spark.ml.evaluation.RegressionEvaluator
import org.apache.spark.ml.recommendation.ALS
import org.apache.spark.ml.feature.StringIndexer
import org.apache.spark.sql.Dataset
import org.apache.spark.sql.Row
import org.apache.spark.ml.feature.IndexToString
import scala.collection.mutable.ArrayBuffer
import org.apache.spark.TaskContext
import org.apache.spark.ml.Pipeline
import org.apache.spark.sql.SaveMode
 *  个性化推荐 ALS 算法
 *  用户对资源的点击率作为评分
 *
 */
object Recommend { case class Rating(user: Int, movie: Int, rating: Float)
 
 
 def main(args: Array[String]): Unit = { val spark = SparkSession.builder().appName(Java Spark MYSQL Recommend)
 .master(local)
 .config(es.nodes ,  127.0.0.1)
 .config(es.port ,  9200)
 .config(es.mapping.date.rich ,  false) // 不解析日期类型
 .getOrCreate()
 trainModel(spark)
 spark.close()
 }
 def trainModel(spark: SparkSession): Unit = {
 import spark.implicits._
 val MAX = 3 //  最大推荐数目
 val rank = 10 //  向量大小,默认 10
 val iterations = 10 //  迭代次数,默认 10

 val url =  jdbc:mysql://127.0.0.1:3306/test?useUnicode=true characterEncoding=utf8  val table =  clicks  val user =  root  val pass =  123456  val props = new Properties()  props.setProperty(user , user) //  设置用户名  props.setProperty(password , pass) //  设置密码  val clicks = spark.read.jdbc(url, table, props).repartition(4)  clicks.createOrReplaceGlobalTempView(clicks)  val agg = spark.sql(SELECT userId ,resId ,COUNT(id) AS clicks FROM global_temp.clicks GROUP BY userId,resId )    val userIndexer = new StringIndexer()  .setInputCol(userId)  .setOutputCol(userIndex)  val resIndexer = new StringIndexer()  .setInputCol(resId)  .setOutputCol(resIndex)  val indexed1 = userIndexer.fit(agg).transform(agg)  val indexed2 = resIndexer.fit(indexed1).transform(indexed1)  indexed2.show()  val ratings = indexed2.map(x =  Rating(x.getDouble(3).toInt, x.getDouble(4).toInt, x.getLong(2).toFloat))  ratings.show()  val Array(training, test) = ratings.randomSplit(Array(0.9, 0.1))  println(training:)  training.show()  println(test:)  test.show()  // 隐性反馈和显示反馈  val als = new ALS()  .setMaxIter(iterations)  .setRegParam(0.01)  .setImplicitPrefs(false)  .setUserCol(user)  .setItemCol(movie)  .setRatingCol(rating)  val model = als.fit(ratings)  // Evaluate the model by computing the RMSE on the test data  // Note we set cold start strategy to  drop  to ensure we don t get NaN evaluation metrics  model.setColdStartStrategy(drop)  val predictions = model.transform(test)  val r2 = model.recommendForAllUsers(MAX)  println(r2.schema)  val result = r2.rdd.flatMap(row =  { val userId = row.getInt(0)  val arrayPredict: Seq[Row] = row.getSeq(1)  var result = ArrayBuffer[Rating]()  arrayPredict.foreach(rowPredict =  { val p = rowPredict(0).asInstanceOf[Int]  val score = rowPredict(1).asInstanceOf[Float]  val sql =  insert into recommends(userId,resId,score) values (  +  userId +  ,  +  rowPredict(0) +  ,  +  rowPredict(1) +   )  println(sql:  + sql)  result.append(Rating(userId, p, score))  })  for (i  - result) yield {  i  }  })  println(推荐结果 RDD 已展开)  result.toDF().show()  // 资源 id 隐射  val resInt2Index = new IndexToString()  .setInputCol(movie)  .setOutputCol(resId)  .setLabels(resIndexer.fit(indexed1).labels)  //userId 映射  val userInt2Index = new IndexToString()  .setInputCol(user)  .setOutputCol(userId)  .setLabels(userIndexer.fit(agg).labels)  val rc = userInt2Index.transform(resInt2Index.transform(result.toDF()))  rc.show()  rc.withColumnRenamed(rating , score).select(userId ,  resId , score).write.mode(SaveMode.Overwrite)  .format(jdbc)  .option(url , url)  .option(dbtable ,  recommends)  .option(user , user)  .option(password , pass)  .option(batchsize ,  5000)  .option(truncate ,  true)  .save  println(finished!!!)  } }

DataFrame 写入 mysql 还有另一种写法,就是原生写入:

 // 分区写推荐结果到 mysql
 r2.foreachPartition(p =  {
 @transient val conn = ConnectionPool.getConnection
 p.foreach(row =  { val userId = row.getInt(0)
 val arrayPredict: Seq[Row] = row.getSeq(1)
 arrayPredict.foreach(rowPredict =  { println(rowPredict(0) +  @  + rowPredict(1))
 val sql =  insert into recommends(userId,resId,score) values (  +
 userId+ ,  +
 rowPredict(0)+ , +
 rowPredict(1) +
  ) 
 println(sql: +sql)
 val stmt = conn.createStatement
 stmt.executeUpdate(sql)
 })
 })
 ConnectionPool.returnConnection(conn)
 })

感谢各位的阅读,以上就是“Spark ALS 实现的步骤是什么”的内容了,经过本文的学习后,相信大家对 Spark ALS 实现的步骤是什么这一问题有了更深刻的体会,具体使用情况还需要大家实践验证。这里是丸趣 TV,丸趣 TV 小编将为大家推送更多相关知识点的文章,欢迎关注!

正文完
 
丸趣
版权声明:本站原创文章,由 丸趣 2023-08-25发表,共计4600字。
转载说明:除特殊说明外本站除技术相关以外文章皆由网络搜集发布,转载请注明出处。
评论(没有评论)