diff --git a/src/masoniteorm/query/QueryBuilder.py b/src/masoniteorm/query/QueryBuilder.py index fb1a2adc..27564290 100644 --- a/src/masoniteorm/query/QueryBuilder.py +++ b/src/masoniteorm/query/QueryBuilder.py @@ -257,7 +257,7 @@ def begin(self): Returns: self """ - return self.new_connection().begin() + return self.get_connection().begin() def begin_transaction(self, *args, **kwargs): return self.begin(*args, **kwargs) @@ -462,7 +462,7 @@ def add_select(self, alias, callable): def statement(self, query, bindings=None): if bindings is None: bindings = [] - result = self.new_connection().query(query, bindings) + result = self.get_connection().query(query, bindings) return self.prepare_result(result) def select_raw(self, query): @@ -504,7 +504,7 @@ def bulk_create( if model: model = model.hydrate(self._creates) if not self.dry: - connection = self.new_connection() + connection = self.get_connection() query_result = connection.query( self.to_qmark(), self._bindings, results=1 ) @@ -561,7 +561,7 @@ def create( self._creates.update(model.get_dirty_attributes()) if not self.dry: - connection = self.new_connection() + connection = self.get_connection() query_result = connection.query( self.to_qmark(), self._bindings, results=1 @@ -616,7 +616,7 @@ def delete(self, column=None, value=None, query=False): self.where(model.get_primary_key(), model.get_primary_key_value()) self.observe_events(model, "deleting") - connection = self.new_connection() + connection = self.get_connection() connection.query(self.to_qmark(), self._bindings) @@ -973,7 +973,7 @@ def or_where_null(self, column): return self def chunk(self, chunk_amount): - chunk_connection = self.new_connection() + chunk_connection = self.get_connection() for result in chunk_connection.select_many( self.to_sql(), (), chunk_amount ): @@ -1574,7 +1574,7 @@ def update( return self additional.update(updates) - connection = self.new_connection() + connection = self.get_connection() connection.query(self.to_qmark(), self._bindings) if model: @@ -1640,7 +1640,7 @@ def increment(self, column, value=1, dry=False): if dry or self.dry: return self - results = self.new_connection().query(self.to_qmark(), self._bindings) + results = self.get_connection().query(self.to_qmark(), self._bindings) processed_results = self.get_processor().get_column_value( self, column, results, id_key, id_value ) @@ -1684,7 +1684,7 @@ def decrement(self, column, value=1, dry=False): if dry or self.dry: return self - result = self.new_connection().query(self.to_qmark(), self._bindings) + result = self.get_connection().query(self.to_qmark(), self._bindings) processed_results = self.get_processor().get_column_value( self, column, result, id_key, id_value ) @@ -1725,7 +1725,7 @@ def count(self, column=None, dry=False): return self if not column: - result = self.new_connection().query( + result = self.get_connection().query( self.to_qmark(), self._bindings, results=1 ) @@ -1845,7 +1845,7 @@ def first(self, fields=None, query=False): if query: return self - result = self.new_connection().query( + result = self.get_connection().query( self.to_qmark(), self._bindings, results=1 ) @@ -1912,7 +1912,7 @@ def last(self, column=None, query=False): if query: return self - result = self.new_connection().query( + result = self.get_connection().query( self.to_qmark(), self._bindings, results=1, @@ -2132,7 +2132,7 @@ def all(self, selects=[], query=False): return self result = ( - self.new_connection().query(self.to_qmark(), self._bindings) or [] + self.get_connection().query(self.to_qmark(), self._bindings) or [] ) return self.prepare_result(result, collection=True) @@ -2144,24 +2144,26 @@ def get(self, selects=[]): self """ self.select(*selects) - result = self.new_connection().query(self.to_qmark(), self._bindings) + result = self.get_connection().query(self.to_qmark(), self._bindings) return self.prepare_result(result, collection=True) def new_connection(self): - if self._connection: - return self._connection - - self._connection = ( + """Create a new connection""" + return ( self.connection_class( **self.get_connection_information(), name=self.connection ) .set_schema(self._schema) .make_connection() ) - return self._connection def get_connection(self): + """Get the current connection""" + if self._connection: + return self._connection + + self._connection = self.new_connection() return self._connection def without_eager(self): @@ -2386,7 +2388,7 @@ def truncate(self, foreign_keys=False, dry=False): if dry or self.dry: return sql - return self.new_connection().query(sql, ()) + return self.get_connection().query(sql, ()) def exists(self): """Determine if rows exist for the current query. diff --git a/src/masoniteorm/query/processors/MSSQLPostProcessor.py b/src/masoniteorm/query/processors/MSSQLPostProcessor.py index ecc7847d..5f450679 100644 --- a/src/masoniteorm/query/processors/MSSQLPostProcessor.py +++ b/src/masoniteorm/query/processors/MSSQLPostProcessor.py @@ -24,7 +24,7 @@ def process_insert_get_id(self, builder, results, id_key): dictionary: Should return the modified dictionary. """ - last_id = builder.new_connection().query( + last_id = builder.get_connection().query( "SELECT @@Identity as [id]", results=1 ) diff --git a/src/masoniteorm/schema/Schema.py b/src/masoniteorm/schema/Schema.py index eaedabf8..e0b34b6c 100644 --- a/src/masoniteorm/schema/Schema.py +++ b/src/masoniteorm/schema/Schema.py @@ -62,6 +62,7 @@ def __init__( self._dry = dry self.connection = connection self.connection_class = connection_class + self._connection_driver = None self._connection = None self.grammar = grammar self.platform = platform @@ -132,7 +133,7 @@ def create(self, table): self._blueprint = Blueprint( self.grammar, - connection=self.new_connection(), + connection=self.get_connection(), table=Table(table), action="create", platform=self.platform, @@ -148,7 +149,7 @@ def create_table_if_not_exists(self, table): self._blueprint = Blueprint( self.grammar, - connection=self.new_connection(), + connection=self.get_connection(), table=Table(table), action="create_table_if_not_exists", platform=self.platform, @@ -174,7 +175,7 @@ def table(self, table): self._blueprint = Blueprint( self.grammar, - connection=self.new_connection(), + connection=self.get_connection(), table=TableDiff(table), action="alter", platform=self.platform, @@ -212,15 +213,25 @@ def get_connection_information(self): } def new_connection(self): + """Explicitly creates a new connection.""" if self._dry: return - self._connection = ( + return ( self.connection_class(**self.get_connection_information()) .set_schema(self.schema) .make_connection() ) + def get_connection(self): + """Returns the cached connection, creating (and caching) one if needed.""" + if self._dry: + return + + if self._connection: + return self._connection + + self._connection = self.new_connection() return self._connection def has_column(self, table, column, query_only=False): @@ -238,11 +249,11 @@ def has_column(self, table, column, query_only=False): self._sql = sql return sql - return bool(self.new_connection().query(sql, ())) + return bool(self.get_connection().query(sql, ())) def get_columns(self, table, dict=True): table = self.platform().get_current_schema( - self.new_connection(), table, schema=self.get_schema() + self.get_connection(), table, schema=self.get_schema() ) result = {} if dict: @@ -264,7 +275,7 @@ def drop_table(self, table, query_only=False): self._sql = sql return sql - return bool(self.new_connection().query(sql, ())) + return bool(self.get_connection().query(sql, ())) def drop(self, *args, **kwargs): return self.drop_table(*args, **kwargs) @@ -276,7 +287,7 @@ def drop_table_if_exists(self, table, exists=False, query_only=False): self._sql = sql return sql - return bool(self.new_connection().query(sql, ())) + return bool(self.get_connection().query(sql, ())) def rename(self, table, new_name): sql = self.platform().compile_rename_table(table, new_name) @@ -285,7 +296,7 @@ def rename(self, table, new_name): self._sql = sql return sql - return bool(self.new_connection().query(sql, ())) + return bool(self.get_connection().query(sql, ())) def truncate(self, table, foreign_keys=False): sql = self.platform().compile_truncate( @@ -296,7 +307,7 @@ def truncate(self, table, foreign_keys=False): self._sql = sql return sql - return bool(self.new_connection().query(sql, ())) + return bool(self.get_connection().query(sql, ())) def get_schema(self): """Gets the schema set on the migration class""" @@ -315,7 +326,7 @@ def get_all_tables(self): self._sql = sql return sql - result = self.new_connection().query(sql, ()) + result = self.get_connection().query(sql, ()) return ( list(map(lambda t: list(t.values())[0], result)) if result else [] @@ -338,7 +349,7 @@ def has_table(self, table, query_only=False): self._sql = sql return sql - return bool(self.new_connection().query(sql, ())) + return bool(self.get_connection().query(sql, ())) def enable_foreign_key_constraints(self): sql = self.platform().enable_foreign_key_constraints() @@ -347,7 +358,7 @@ def enable_foreign_key_constraints(self): self._sql = sql return sql - return bool(self.new_connection().query(sql, ())) + return bool(self.get_connection().query(sql, ())) def disable_foreign_key_constraints(self): sql = self.platform().disable_foreign_key_constraints() @@ -356,4 +367,17 @@ def disable_foreign_key_constraints(self): self._sql = sql return sql - return bool(self.new_connection().query(sql, ())) + return bool(self.get_connection().query(sql, ())) + + def query_builder(self): + """Get a query builder for the schema connection""" + from ..query import QueryBuilder + + return QueryBuilder( + connection=self.connection, + connection_class=self.connection_class, + connection_driver=self._connection_driver, + connection_details=self.connection_details, + schema=self.schema, + dry=self._dry, + ) diff --git a/tests/query/test_querybuilder.py b/tests/query/test_querybuilder.py new file mode 100644 index 00000000..cba9ca70 --- /dev/null +++ b/tests/query/test_querybuilder.py @@ -0,0 +1,35 @@ +import unittest + +from src.masoniteorm.query import QueryBuilder +from tests.integrations.config.database import DATABASES +from tests.utils import MockSQLiteConnection + + +class TestQueryBuilder(unittest.TestCase): + maxDiff = None + + def get_builder(self): + return QueryBuilder( + connection="dev", + connection_class=MockSQLiteConnection, + connection_details=DATABASES, + ) + + def test_returned_connection_is_same(self): + builder = self.get_builder() + first_connection = builder.get_connection() + self.assertIsNotNone(first_connection) + second_connection = builder.get_connection() + self.assertIsNotNone(second_connection) + self.assertEqual(first_connection, second_connection) + + def test_new_connection_does_not_change_existing_connection(self): + builder = self.get_builder() + first_connection = builder.get_connection() + self.assertIsNotNone(first_connection) + new_connection = builder.new_connection() + self.assertIsNotNone(new_connection) + second_connection = builder.get_connection() + self.assertIsNotNone(second_connection) + self.assertEqual(first_connection, second_connection) + self.assertNotEqual(new_connection, first_connection) diff --git a/tests/schema/test_schema.py b/tests/schema/test_schema.py new file mode 100644 index 00000000..6d3c73a4 --- /dev/null +++ b/tests/schema/test_schema.py @@ -0,0 +1,43 @@ +import unittest + +from src.masoniteorm.query import QueryBuilder +from src.masoniteorm.schema import Schema +from src.masoniteorm.schema.platforms import SQLitePlatform +from tests.integrations.config.database import DATABASES +from tests.utils import MockSQLiteConnection + + +class TestSchema(unittest.TestCase): + maxDiff = None + + def get_schema(self): + return Schema( + connection="dev", + connection_class=MockSQLiteConnection, + connection_details=DATABASES, + platform=SQLitePlatform, + ) + + def test_connection_is_cached(self): + schema = self.get_schema() + first_connection = schema.get_connection() + self.assertIsNotNone(first_connection) + second_connection = schema.get_connection() + self.assertIsNotNone(second_connection) + self.assertEqual(first_connection, second_connection) + + def test_new_connection_is_not_cached(self): + schema = self.get_schema() + first_connection = schema.get_connection() + self.assertIsNotNone(first_connection) + new_connection = schema.new_connection() + self.assertIsNotNone(new_connection) + second_connection = schema.get_connection() + self.assertIsNotNone(second_connection) + self.assertEqual(first_connection, second_connection) + self.assertNotEqual(new_connection, first_connection) + + def test_can_get_query_builder(self): + schema = self.get_schema() + builder = schema.query_builder() + self.assertIsInstance(builder, QueryBuilder)