--- /dev/null Thu Jan 01 00:00:00 1970 +0000
+++ b/src/main/scala/net/tz/lift/util/DB.scala Thu May 19 14:33:26 2011 +0200
@@ -0,0 +1,166 @@
+/*
+ * Fix for ProtoDBVendor/StandardDBVendor connection leak.
+ *
+ * Original code is:
+ *
+ * Copyright 2006-2011 WorldWide Conferencing, LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ * All fixes are under the same license.
+ */
+package net.tz.lift.util
+
+import java.sql.{Connection, DriverManager}
+import net.liftweb.common._
+import net.liftweb.db.{ConnectionIdentifier, ConnectionManager}
+import net.liftweb.util._
+import net.liftweb.util.Helpers._
+
+/**
+ * The standard DB vendor.
+ * @param driverName the name of the database driver
+ * @param dbUrl the URL for the JDBC data connection
+ * @param dbUser the optional username
+ * @param dbPassword the optional db password
+ */
+class StandardDBVendor(driverName: String,
+ dbUrl: String,
+ dbUser: Box[String],
+ dbPassword: Box[String]) extends ProtoDBVendor {
+ protected def createOne: Box[Connection] = try {
+ Class.forName(driverName)
+
+ val dm = (dbUser, dbPassword) match {
+ case (Full(user), Full(pwd)) =>
+ DriverManager.getConnection(dbUrl, user, pwd)
+
+ case _ => DriverManager.getConnection(dbUrl)
+ }
+
+ Full(dm)
+ } catch {
+ case e: Exception => e.printStackTrace; Empty
+ }
+}
+
+trait ProtoDBVendor extends ConnectionManager {
+ private val logger = Logger(classOf[ProtoDBVendor])
+ private var pool: List[Connection] = Nil
+ private var poolSize = 0
+ private var tempMaxSize = maxPoolSize
+
+ /**
+ * Override and set to false if the maximum pool size can temporarilly be expanded to avoid pool starvation
+ */
+ protected def allowTemporaryPoolExpansion = true
+
+ /**
+ * Override this method if you want something other than
+ * 4 connections in the pool
+ */
+ protected def maxPoolSize = 4
+
+ /**
+ * The absolute maximum that this pool can extend to
+ * The default is 20. Override this method to change.
+ */
+ protected def doNotExpandBeyond = 20
+
+ /**
+ * The logic for whether we can expand the pool beyond the current size. By
+ * default, the logic tests allowTemporaryPoolExpansion && poolSize <= doNotExpandBeyond
+ */
+ protected def canExpand_? : Boolean = allowTemporaryPoolExpansion && poolSize <= doNotExpandBeyond
+
+ /**
+ * How is a connection created?
+ */
+ protected def createOne: Box[Connection]
+
+ /**
+ * Test the connection. By default, setAutoCommit(false),
+ * but you can do a real query on your RDBMS to see if the connection is alive
+ */
+ protected def testConnection(conn: Connection) {
+ conn.setAutoCommit(false)
+ }
+
+ def newConnection(name: ConnectionIdentifier): Box[Connection] =
+ synchronized {
+ pool match {
+ case Nil if poolSize < tempMaxSize =>
+ val ret = createOne
+ ret foreach { c =>
+ c.setAutoCommit(false)
+ poolSize = poolSize + 1
+ logger.debug("Created new pool entry. name=%s, poolSize=%d".format(name, poolSize))
+ }
+ ret
+
+ case Nil =>
+ val curSize = poolSize
+ logger.trace("No connection left in pool, waiting...")
+ wait(50L)
+ // if we've waited 50 ms and the pool is still empty, temporarily expand it
+ if (pool.isEmpty && poolSize == curSize && canExpand_?) {
+ tempMaxSize += 1
+ logger.debug("Temporarily expanding pool. name=%s, tempMaxSize=%d".format(name, tempMaxSize))
+ }
+ newConnection(name)
+
+ case x :: xs =>
+ logger.trace("Found connection in pool, name=%s".format(name))
+ pool = xs
+ try {
+ this.testConnection(x)
+ Full(x)
+ } catch {
+ case e => try {
+ logger.debug("Test connection failed, removing connection from pool, name=%s".format(name))
+ poolSize = poolSize - 1
+ tryo(x.close)
+ newConnection(name)
+ } catch {
+ case e => newConnection(name)
+ }
+ }
+ }
+ }
+
+ def releaseConnection(conn: Connection): Unit = synchronized {
+ if (tempMaxSize > maxPoolSize) {
+ tryo {conn.close()}
+ tempMaxSize -= 1
+ poolSize -= 1
+ } else {
+ pool = conn :: pool
+ }
+ logger.debug("Released connection. poolSize=%d".format(poolSize))
+ notifyAll
+ }
+
+ def closeAllConnections_!(): Unit = synchronized {
+ logger.info("Closing all connections, poolSize=%d".format(poolSize))
+ if (poolSize == 0) ()
+ else {
+ pool.foreach {c => tryo(c.close); poolSize -= 1}
+ pool = Nil
+
+ if (poolSize > 0) wait(250)
+
+ closeAllConnections_!()
+ }
+ }
+}
+// vim: set ts=2 sw=2 et: