spark 写 gp/tpg 效率优化 —— 写入 237w 行数据耗时从 77 分钟到 34 秒

摘自内部分享,有删减。

具体到我们这次的场景中,我们用的是 gp,gp 全称是 greenplum,是一个 mpp 版本的 postgresql,可以参考这个简介 http://www.infoq.com/cn/news/2… ,协议上兼容 postgresql,我们可以用普通能连 postgresql 的方式去连 gp,并且把 gp 看成一个黑盒的集群版本的 postgresql 来使用

然后这次的优化的手段也很简单,就是从原来的 jdbc 连接拼 sql 改成用 org.postgresql.copy.CopyManager,类似 postgresql 命令行下的 \copy 命令,所以一句话就能说完,而写这个文章的点主要是分享一下这个过程中的一些思路历程和细节

对比图

1501056179_100_w367_h360

作为对比,我们原先的数据写入方式是 jdbc 连上之后拼 insert 语句,应该说这种方式在 OLTP 场景下是很适用的,但是在 OLAP 场景下效率问题就开始显现出来了,耗时不仅仅产生在写入端拼 query string 的开销上,更重的是在 db server 端去 parse query 的耗时成本,以及附带衍生的事务,回滚日志等开销成本。

那么 gp 作为一个立足于大量数据处理的 RDBMS,肯定要对数据的 IO 有一个解决方案的,官方是怎么来解决这个问题的呢,看到这里,https://gpdb.docs.pivotal.io/4… ,官方主要提供了几种方案:

  • External Tables enable accessing external files as if they are regular database tables.
  • gpload provides an interface to the Greenplum Database parallel loader.
  • COPY is the standard PostgreSQL non-parallel data loading tool.

其中,外部表的方式可以通过以下几种实现达成

  • gpfdist: points to a directory on the file host and serves external data files to all Greenplum Database segments in parallel.
  • gpfdists: the secure version of gpfdist.
  • file:// accesses external data files on a segment host that the Greenplum superuser (gpadmin) can access.
  • gphdfs: accesses files on a Hadoop Distributed File System (HDFS).

gpfdist 可以把一个外部机器上的数据文件让所有 seg 节点能访问到,因而就可以并行的载入数据,gpfdists 是一个安全版本的 gpfdist,但是这种方式存在一些问题考量:

  1. 需要在 seg 节点上额外安装部署程序
  2. 不兼容 tpg(tpg 本身就没有 seg 节点的概念,囧)

所以没有选用,而 gphdfs 这种方式,能够让 gp 连上 hdfs 去读数据,也是并行的,但是出于以下考量也最终没有采用

  1. 我们的 tdw 的 hdfs 带了自定义的鉴权
  2. 我们在 hive 表中的存储格式并不是平坦的二维表,由于指标的值稀疏,我们使用的是类似 postgresql 的 hstore 的存储格式,而这种形式并不利于直接表对表的拷到 gp 中成为一张平坦的表

而 gpload 其实就是一个外部表的载入界面封装

The gpload data loading utility is the interface to Greenplum’s external table parallel loading feature.

所以最终我们的选择就落在了 copy 上,用 copy 的好处主要是

  1. 是 postgresql 的标准工具,无缝兼容 gp 与 tpg,一次干活到处使用
  2. 不需要额外的依赖与安装部署,对目标 server 没有特殊要求

虽然,在官方的介绍中说了 copy 是一个非并行的工具,但是,实测下来,copy 的效率并不低

用 copy 有两种方式,一种是在命令行上用,参考 https://www.postgresql.org/doc… ,另外一种,是引入 jar 包,在代码中用,参考 https://jdbc.postgresql.org/do… ,可以看到他在函数的 Javadoc 上说了

Use COPY FROM STDIN for very fast copying from an InputStream into a database table.

这也说明这个工具的作者是很自信的(笑)

可以看到这个函数有两个重载

public long copyIn(String sql,
                   Reader from)
            throws SQLException,
                   IOException

public long copyIn(String sql,
                   InputStream from)
            throws SQLException,
                   IOException

那么自然的问题就是,从 Reader 中读取和从 InputStream 中读取有什么区别?

参考这个文章, http://blog.sina.com.cn/s/blog… ,其实,二者的主要区别在于:

InputStream提供的是字节流的读取,而非文本读取,这是和Reader类的根本区别。
即用Reader读取出来的是char数组或者String ,使用InputStream读取出来的是byte数组。

也就是说,他们一个是面向字节的,一个是面向字符的,而面向字符的自然就要面临一个问题就是字符的编码方式的选择问题,以及解码和编码的开销成本问题,所以从效率上来说,我们应该是选用面向字节的方式

去看他的源码实现也可以发现

    public long copyIn(final String sql, InputStream from, int bufferSize) throws SQLException, IOException {
        byte[] buf = new byte[bufferSize];
        int len;
        CopyIn cp = copyIn(sql);
        try {
            while( (len = from.read(buf)) > 0 ) {
                cp.writeToCopy(buf, 0, len);
            }
            return cp.endCopy();
        } finally { // see to it that we do not leave the connection locked
            if(cp.isActive())
                cp.cancelCopy();
        }
    }

对于 InputStream 的方式,直接读进来就可以用了

    public long copyIn(final String sql, Reader from, int bufferSize) throws SQLException, IOException {
        char[] cbuf = new char[bufferSize];
        int len;
        CopyIn cp = copyIn(sql);
        try {
            while ( (len = from.read(cbuf)) > 0) {
                byte[] buf = encoding.encode(new String(cbuf, 0, len));
                cp.writeToCopy(buf, 0, buf.length);
            }
            return cp.endCopy();
        } finally { // see to it that we do not leave the connection locked
            if(cp.isActive())
                cp.cancelCopy();
        }
    }

而对于 Reader 的方式,读进来之后还要 encoding 处理一下,进一步验证了我们的想法

接下来,就可以得出一个使用方式的 demo 代码了,我们的根本需求是把计算结果写入,所以应该是 RDD[T] ,但是,由于需要转成一个 InputStream,所以我们需要转而接受一个 Array[Array[String]] 的入参

所以到目前为止可以得出代码如下

  def copyIn(data: Array[Array[String]], tblName: String): Long = {
    var con: Connection = null
    try {      
      Class.forName("org.postgresql.Driver")
      println("connecting to database with url " + url)
      con = DriverManager.getConnection(url, user, password)
      val cm = new CopyManager(con.asInstanceOf[BaseConnection])
      val COPY_CMD = s"COPY $tblName from STDIN"
      val start = System.currentTimeMillis()
      val affectedRowCount = cm.copyIn(COPY_CMD, genInputStream(data))
      val finish = System.currentTimeMillis()
      println("copy operation completed successfully in " + (finish-start)/1000.0 + " seconds, affectedRowCount " + affectedRowCount)
      con.close()
      affectedRowCount
    } catch {
      case ex: SQLException => println("Failed to copy data: " + ex.getMessage()); 0
    } finally {
      try {
        if (con != null) con.close()
      } catch {
        case ex: SQLException => println(ex.getMessage())
      }
    }
  }
  
  def genInputStream(arr: Array[Array[String]]): InputStream = {    
    val stringBuilder = new StringBuilder
    println("input data has " + arr.length + " rows")
    if (arr.length != 0) {
      val rowcount = arr.length;
      val columncount = arr(0).length
      for (i <- 0 to rowcount-1; j <- 0 to columncount-1) {
        stringBuilder.append(arr(i)(j) + (if (j == columncount-1) "\r\n" else "\t"))
      }
    }
    val str = stringBuilder.toString
    println("input data " + arr.length + " rows total " + str.length + " bytes")
    new ByteArrayInputStream(str.getBytes(StandardCharsets.UTF_8))
  }

以上这个代码能够正常跑出文章开头的性能测试的结果,但是很显然,埋了一个大坑:会爆内存

且不说为了能够满足 API 的要求,我们需要把输入数据组织成一个 Array[Array[String]] 并拉到 driver 节点(这一步骤本身就违背了 driver 节点不干具体活的宗旨),就说我们把 RDD collect 到 driver 之后,还需要转变出一个 InputStream 这种形式,中途还需要通过 StringBuilder 去 Build 一个大 string,这何止是奢侈,简直就是奢侈,Array[Array[String]] 占一份内存,toString 之后又占一份内存

于是尝试使用 PipedOutputStream 和 PipedInputStream 来解决,这是一个基于管道的流式读写,我们可以起一个单独的线程,来往这个 PipedOutputStream 写入数据,由于缓冲区大小有限,他就会阻塞在缓冲区满的状态下,然后读取端从 PipedInputStream 去读,一边读一边写入到网络上去,jvm 顿时轻松很多,但是,动手之前,有一个问题是,怎么来确认我们的这些改动是真的有效呢,由此需要引入 java 运行时的内存监控工具

我们可以发现,在安装 jdk 的时候,附送了这么一个东西

1501063231_39_w214_h48

他其实是一个运行时的监控,简单的 CPU 内存监控可以不需要说明书就能上手

1501063311_9_w929_h882

他可以绘制出 jvm 运行时的 cpu 和内存曲线图,并带有仪表盘

另外,我们也通过 runtime 来获取使用的内存,参考这里, http://viralpatel.net/blogs/ge…  可以加入打印函数如下

  def printMem(currentMoment: String) {
    println(s"=====$currentMoment=========")
    val mb = 1024*1024
    val runtime = Runtime.getRuntime()
    println("Used Memory:" + (runtime.totalMemory() - runtime.freeMemory()) / mb)
    println("Free Memory:" + runtime.freeMemory() / mb)
    println("Total Memory:" + runtime.totalMemory() / mb)
    println("===============")
  }

然后在原有的函数上打点

  def genInputStream(arr: Array[Array[String]]): InputStream = {    
    printMem("before gen string")
    val stringBuilder = new StringBuilder
    println("input data has " + arr.length + " rows")
    if (arr.length != 0) {
      val rowcount = arr.length;
      val columncount = arr(0).length
      for (i <- 0 to rowcount-1; j <- 0 to columncount-1) {
        stringBuilder.append(arr(i)(j) + (if (j == columncount-1) "\r\n" else "\t"))
      }
    }
    val str = stringBuilder.toString
    printMem("after gen string")
    println("input data " + arr.length + " rows total " + str.length + " bytes")
    new ByteArrayInputStream(str.getBytes(StandardCharsets.UTF_8))
  }

并生成 1kw 行测试数据

  def main(args: Array[String]): Unit = {
    //var data = Array(Array("P1","PenDrive","50","US"), Array("P1","PenDrive","300","US"))
    printMem("before gen array")
    val data = Array.fill(100*10000*10)(Array("P1","PenDrive","50","US"))
    printMem("after gen array")
    copyIn(data, "test.product")
  }

我们使用一个测试用的表,直接在 eclipse 中跑一下,可以得到输出如下

=====before gen array=========
Used Memory:3
Free Memory:241
Total Memory:245
===============
=====after gen array=========
Used Memory:345
Free Memory:290
Total Memory:635
===============
=====before gen string=========
Used Memory:352
Free Memory:305
Total Memory:658
===============
input data has 10000000 rows
=====after gen string=========
Used Memory:1989
Free Memory:479
Total Memory:2469
===============
input data 10000000 rows total 190000000 bytes
copy operation completed successfully in 69.951 seconds, affectedRowCount 10000000

可以看到,通过 stringbuilder 来生成 inputstream 的方式,耗用的内存,远比一倍要多

那么接下来就可以尝试改成 PipedOutputStream 和 PipedInputStream 的方式了

把生成 InputStream 的类改成如下

  def genPipedInputStream(arr: Array[Array[String]]): InputStream = {
    printMem("before gen inputstream")
    val out = new PipedOutputStream
    (new Thread(){
      override def run {
        println("input data has " + arr.length + " rows")
        if (arr.length != 0) {
          val rowcount = arr.length;
          val columncount = arr(0).length
          for (i <- 0 to rowcount-1; j <- 0 to columncount-1) {
            out.write((arr(i)(j) + (if (j == columncount-1) "\r\n" else "\t")).getBytes(StandardCharsets.UTF_8))
          }
        }        
        out.close()
        println("PipedOutputStream closed")
      }
    }).start()
    val in = new PipedInputStream
    in.connect(out)
    printMem("after gen inputstream")
    in    
  }

(这里其实隐含了一个问题就是是否需要 CountDownLatch)

可以看到输出如下,耗时有所增加,不过内存控制住了

=====before gen array=========
Used Memory:3
Free Memory:241
Total Memory:245
===============
=====after gen array=========
Used Memory:345
Free Memory:295
Total Memory:641
===============
=====before gen inputstream=========
Used Memory:352
Free Memory:288
Total Memory:641
===============
=====after gen inputstream=========
Used Memory:352
Free Memory:286
input data has 10000000 rows
Total Memory:641
===============
PipedOutputStream closed
copy operation completed successfully in 97.917 seconds, affectedRowCount 10000000

并且过程中的内存曲线基本平稳

1501069503_38_w929_h882

接下来,浮现出来的一个问题就是:是否真的需要把 RDD[T] collect 到 driver 上来?

答案其实是可以不需要,我们有 mapPartitions 这个算子,可以写成如下

val start = System.currentTimeMillis()
dataGpFlatten.mapPartitions(x => {
  GPCopyMgr.copyIn(x.toArray, "xxxxx")
  x
}).count
val finish = System.currentTimeMillis()
println("operation completed successfully in " + (finish-start)/1000.0 + " seconds")

需要注意的是,mapPartitions 并不是 action,而是一个 transform,所以我们需要在后面给他跟上一个 action,例如 count,来触发执行

主节点再无写入数据的动作,并且总的耗时比文章开头的耗时还要下降了 5s,不过基本在一个量级,可以认为是实验误差范围内

通过这种 mapPartitions 的方式,需要注意的问题有

  1. partition 数量的选择,过多容易造成同时连接 db 的连接数过多,而且每个分区小了,其实吞吐性能不利
  2. 如果需要 re-partition,需要意识到 re-partition 也是有开销成本的
  3. 最后别忘了跟一个 action

至此,基本就完结了,剩下就是一些工程化方面的工作,例如

  1. 在写入数据之前删除分区,以避免脏数据
  2. 在写入数据之后校验写入行数是否相符,以免某个 partition 写的过程中出异常了(这里其实引申出来一个问题,如果某个 executor 在写到一半的时候挂了,怎么办,是否只能整个 lz 任务重跑来清理现场?)
  3. 加强日志的可读性

以上动作都是工程化方面的工作,其实还是避免自己给自己挖坑,哈哈

Leave a Reply

Your email address will not be published. Required fields are marked *