src/main/scala/net/tz/lift/util/DB.scala
changeset 5 993582ca8d2e
equal deleted inserted replaced
4:bb4478e1cff7 5:993582ca8d2e
       
     1 /*
       
     2  * Fix for ProtoDBVendor/StandardDBVendor connection leak.
       
     3  *
       
     4  * Original code is:
       
     5  *
       
     6  * Copyright 2006-2011 WorldWide Conferencing, LLC
       
     7  *
       
     8  * Licensed under the Apache License, Version 2.0 (the "License");
       
     9  * you may not use this file except in compliance with the License.
       
    10  * You may obtain a copy of the License at
       
    11  *
       
    12  *     http://www.apache.org/licenses/LICENSE-2.0
       
    13  *
       
    14  * Unless required by applicable law or agreed to in writing, software
       
    15  * distributed under the License is distributed on an "AS IS" BASIS,
       
    16  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
       
    17  * See the License for the specific language governing permissions and
       
    18  * limitations under the License.
       
    19  *
       
    20  * All fixes are under the same license.
       
    21  */
       
    22 package net.tz.lift.util
       
    23 
       
    24 import java.sql.{Connection, DriverManager}
       
    25 import net.liftweb.common._
       
    26 import net.liftweb.db.{ConnectionIdentifier, ConnectionManager}
       
    27 import net.liftweb.util._
       
    28 import net.liftweb.util.Helpers._
       
    29 
       
    30 /**
       
    31  * The standard DB vendor.
       
    32  * @param driverName the name of the database driver
       
    33  * @param dbUrl the URL for the JDBC data connection
       
    34  * @param dbUser the optional username
       
    35  * @param dbPassword the optional db password
       
    36  */
       
    37 class StandardDBVendor(driverName: String,
       
    38                        dbUrl: String,
       
    39                        dbUser: Box[String],
       
    40                        dbPassword: Box[String]) extends ProtoDBVendor {
       
    41   protected def createOne: Box[Connection] = try {
       
    42     Class.forName(driverName)
       
    43 
       
    44     val dm = (dbUser, dbPassword) match {
       
    45       case (Full(user), Full(pwd)) =>
       
    46         DriverManager.getConnection(dbUrl, user, pwd)
       
    47 
       
    48       case _ => DriverManager.getConnection(dbUrl)
       
    49     }
       
    50 
       
    51     Full(dm)
       
    52   } catch {
       
    53     case e: Exception => e.printStackTrace; Empty
       
    54   }
       
    55 }
       
    56 
       
    57 trait ProtoDBVendor extends ConnectionManager {
       
    58   private val logger = Logger(classOf[ProtoDBVendor])
       
    59   private var pool: List[Connection] = Nil
       
    60   private var poolSize = 0
       
    61   private var tempMaxSize = maxPoolSize
       
    62 
       
    63   /**
       
    64    * Override and set to false if the maximum pool size can temporarilly be expanded to avoid pool starvation
       
    65    */
       
    66   protected def allowTemporaryPoolExpansion = true
       
    67 
       
    68   /**
       
    69    *  Override this method if you want something other than
       
    70    * 4 connections in the pool
       
    71    */
       
    72   protected def maxPoolSize = 4
       
    73 
       
    74   /**
       
    75    * The absolute maximum that this pool can extend to
       
    76    * The default is 20.  Override this method to change.
       
    77    */
       
    78   protected def doNotExpandBeyond = 20
       
    79 
       
    80   /**
       
    81    * The logic for whether we can expand the pool beyond the current size.  By
       
    82    * default, the logic tests allowTemporaryPoolExpansion && poolSize <= doNotExpandBeyond
       
    83    */
       
    84   protected def canExpand_? : Boolean = allowTemporaryPoolExpansion && poolSize <= doNotExpandBeyond
       
    85 
       
    86   /**
       
    87    *   How is a connection created?
       
    88    */
       
    89   protected def createOne: Box[Connection]
       
    90 
       
    91   /**
       
    92    * Test the connection.  By default, setAutoCommit(false),
       
    93    * but you can do a real query on your RDBMS to see if the connection is alive
       
    94    */
       
    95   protected def testConnection(conn: Connection) {
       
    96     conn.setAutoCommit(false)
       
    97   }
       
    98 
       
    99   def newConnection(name: ConnectionIdentifier): Box[Connection] =
       
   100     synchronized {
       
   101       pool match {
       
   102         case Nil if poolSize < tempMaxSize =>
       
   103           val ret = createOne
       
   104           ret foreach { c =>
       
   105             c.setAutoCommit(false)
       
   106             poolSize = poolSize + 1
       
   107             logger.debug("Created new pool entry. name=%s, poolSize=%d".format(name, poolSize))
       
   108           }
       
   109           ret
       
   110 
       
   111         case Nil =>
       
   112           val curSize = poolSize
       
   113           logger.trace("No connection left in pool, waiting...")
       
   114           wait(50L)
       
   115           // if we've waited 50 ms and the pool is still empty, temporarily expand it
       
   116           if (pool.isEmpty && poolSize == curSize && canExpand_?) {
       
   117             tempMaxSize += 1
       
   118             logger.debug("Temporarily expanding pool. name=%s, tempMaxSize=%d".format(name, tempMaxSize))
       
   119           }
       
   120           newConnection(name)
       
   121 
       
   122         case x :: xs =>
       
   123           logger.trace("Found connection in pool, name=%s".format(name))
       
   124           pool = xs
       
   125           try {
       
   126             this.testConnection(x)
       
   127             Full(x)
       
   128           } catch {
       
   129             case e => try {
       
   130               logger.debug("Test connection failed, removing connection from pool, name=%s".format(name))
       
   131               poolSize = poolSize - 1
       
   132               tryo(x.close)
       
   133               newConnection(name)
       
   134             } catch {
       
   135               case e => newConnection(name)
       
   136             }
       
   137           }
       
   138       }
       
   139     }
       
   140 
       
   141   def releaseConnection(conn: Connection): Unit = synchronized {
       
   142     if (tempMaxSize > maxPoolSize) {
       
   143       tryo {conn.close()}
       
   144       tempMaxSize -= 1
       
   145       poolSize -= 1
       
   146     } else {
       
   147       pool = conn :: pool
       
   148     }
       
   149     logger.debug("Released connection. poolSize=%d".format(poolSize))
       
   150     notifyAll
       
   151   }
       
   152 
       
   153   def closeAllConnections_!(): Unit = synchronized {
       
   154     logger.info("Closing all connections, poolSize=%d".format(poolSize))
       
   155     if (poolSize == 0) ()
       
   156     else {
       
   157       pool.foreach {c => tryo(c.close); poolSize -= 1}
       
   158       pool = Nil
       
   159       
       
   160       if (poolSize > 0) wait(250)
       
   161 
       
   162       closeAllConnections_!()
       
   163     }
       
   164   }
       
   165 }
       
   166 // vim: set ts=2 sw=2 et: