Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 22 additions & 20 deletions src/masoniteorm/query/QueryBuilder.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,7 @@ def begin(self):
Returns:
self
"""
return self.new_connection().begin()
return self.get_connection().begin()

def begin_transaction(self, *args, **kwargs):
return self.begin(*args, **kwargs)
Expand Down Expand Up @@ -462,7 +462,7 @@ def add_select(self, alias, callable):
def statement(self, query, bindings=None):
if bindings is None:
bindings = []
result = self.new_connection().query(query, bindings)
result = self.get_connection().query(query, bindings)
return self.prepare_result(result)

def select_raw(self, query):
Expand Down Expand Up @@ -504,7 +504,7 @@ def bulk_create(
if model:
model = model.hydrate(self._creates)
if not self.dry:
connection = self.new_connection()
connection = self.get_connection()
query_result = connection.query(
self.to_qmark(), self._bindings, results=1
)
Expand Down Expand Up @@ -561,7 +561,7 @@ def create(
self._creates.update(model.get_dirty_attributes())

if not self.dry:
connection = self.new_connection()
connection = self.get_connection()

query_result = connection.query(
self.to_qmark(), self._bindings, results=1
Expand Down Expand Up @@ -616,7 +616,7 @@ def delete(self, column=None, value=None, query=False):
self.where(model.get_primary_key(), model.get_primary_key_value())
self.observe_events(model, "deleting")

connection = self.new_connection()
connection = self.get_connection()

connection.query(self.to_qmark(), self._bindings)

Expand Down Expand Up @@ -973,7 +973,7 @@ def or_where_null(self, column):
return self

def chunk(self, chunk_amount):
chunk_connection = self.new_connection()
chunk_connection = self.get_connection()
for result in chunk_connection.select_many(
self.to_sql(), (), chunk_amount
):
Expand Down Expand Up @@ -1574,7 +1574,7 @@ def update(
return self

additional.update(updates)
connection = self.new_connection()
connection = self.get_connection()

connection.query(self.to_qmark(), self._bindings)
if model:
Expand Down Expand Up @@ -1640,7 +1640,7 @@ def increment(self, column, value=1, dry=False):
if dry or self.dry:
return self

results = self.new_connection().query(self.to_qmark(), self._bindings)
results = self.get_connection().query(self.to_qmark(), self._bindings)
processed_results = self.get_processor().get_column_value(
self, column, results, id_key, id_value
)
Expand Down Expand Up @@ -1684,7 +1684,7 @@ def decrement(self, column, value=1, dry=False):
if dry or self.dry:
return self

result = self.new_connection().query(self.to_qmark(), self._bindings)
result = self.get_connection().query(self.to_qmark(), self._bindings)
processed_results = self.get_processor().get_column_value(
self, column, result, id_key, id_value
)
Expand Down Expand Up @@ -1725,7 +1725,7 @@ def count(self, column=None, dry=False):
return self

if not column:
result = self.new_connection().query(
result = self.get_connection().query(
self.to_qmark(), self._bindings, results=1
)

Expand Down Expand Up @@ -1845,7 +1845,7 @@ def first(self, fields=None, query=False):
if query:
return self

result = self.new_connection().query(
result = self.get_connection().query(
self.to_qmark(), self._bindings, results=1
)

Expand Down Expand Up @@ -1912,7 +1912,7 @@ def last(self, column=None, query=False):
if query:
return self

result = self.new_connection().query(
result = self.get_connection().query(
self.to_qmark(),
self._bindings,
results=1,
Expand Down Expand Up @@ -2132,7 +2132,7 @@ def all(self, selects=[], query=False):
return self

result = (
self.new_connection().query(self.to_qmark(), self._bindings) or []
self.get_connection().query(self.to_qmark(), self._bindings) or []
)

return self.prepare_result(result, collection=True)
Expand All @@ -2144,24 +2144,26 @@ def get(self, selects=[]):
self
"""
self.select(*selects)
result = self.new_connection().query(self.to_qmark(), self._bindings)
result = self.get_connection().query(self.to_qmark(), self._bindings)

return self.prepare_result(result, collection=True)

def new_connection(self):
if self._connection:
return self._connection

self._connection = (
"""Create a new connection"""
return (
self.connection_class(
**self.get_connection_information(), name=self.connection
)
.set_schema(self._schema)
.make_connection()
)
return self._connection

def get_connection(self):
"""Get the current connection"""
if self._connection:
return self._connection

self._connection = self.new_connection()
return self._connection

def without_eager(self):
Expand Down Expand Up @@ -2386,7 +2388,7 @@ def truncate(self, foreign_keys=False, dry=False):
if dry or self.dry:
return sql

return self.new_connection().query(sql, ())
return self.get_connection().query(sql, ())

def exists(self):
"""Determine if rows exist for the current query.
Expand Down
2 changes: 1 addition & 1 deletion src/masoniteorm/query/processors/MSSQLPostProcessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def process_insert_get_id(self, builder, results, id_key):
dictionary: Should return the modified dictionary.
"""

last_id = builder.new_connection().query(
last_id = builder.get_connection().query(
"SELECT @@Identity as [id]", results=1
)

Expand Down
52 changes: 38 additions & 14 deletions src/masoniteorm/schema/Schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ def __init__(
self._dry = dry
self.connection = connection
self.connection_class = connection_class
self._connection_driver = None
self._connection = None
self.grammar = grammar
self.platform = platform
Expand Down Expand Up @@ -132,7 +133,7 @@ def create(self, table):

self._blueprint = Blueprint(
self.grammar,
connection=self.new_connection(),
connection=self.get_connection(),
table=Table(table),
action="create",
platform=self.platform,
Expand All @@ -148,7 +149,7 @@ def create_table_if_not_exists(self, table):

self._blueprint = Blueprint(
self.grammar,
connection=self.new_connection(),
connection=self.get_connection(),
table=Table(table),
action="create_table_if_not_exists",
platform=self.platform,
Expand All @@ -174,7 +175,7 @@ def table(self, table):

self._blueprint = Blueprint(
self.grammar,
connection=self.new_connection(),
connection=self.get_connection(),
table=TableDiff(table),
action="alter",
platform=self.platform,
Expand Down Expand Up @@ -212,15 +213,25 @@ def get_connection_information(self):
}

def new_connection(self):
"""Explicitly creates a new connection."""
if self._dry:
return

self._connection = (
return (
self.connection_class(**self.get_connection_information())
.set_schema(self.schema)
.make_connection()
)

def get_connection(self):
"""Returns the cached connection, creating (and caching) one if needed."""
if self._dry:
return

if self._connection:
return self._connection

self._connection = self.new_connection()
return self._connection

def has_column(self, table, column, query_only=False):
Expand All @@ -238,11 +249,11 @@ def has_column(self, table, column, query_only=False):
self._sql = sql
return sql

return bool(self.new_connection().query(sql, ()))
return bool(self.get_connection().query(sql, ()))

def get_columns(self, table, dict=True):
table = self.platform().get_current_schema(
self.new_connection(), table, schema=self.get_schema()
self.get_connection(), table, schema=self.get_schema()
)
result = {}
if dict:
Expand All @@ -264,7 +275,7 @@ def drop_table(self, table, query_only=False):
self._sql = sql
return sql

return bool(self.new_connection().query(sql, ()))
return bool(self.get_connection().query(sql, ()))

def drop(self, *args, **kwargs):
return self.drop_table(*args, **kwargs)
Expand All @@ -276,7 +287,7 @@ def drop_table_if_exists(self, table, exists=False, query_only=False):
self._sql = sql
return sql

return bool(self.new_connection().query(sql, ()))
return bool(self.get_connection().query(sql, ()))

def rename(self, table, new_name):
sql = self.platform().compile_rename_table(table, new_name)
Expand All @@ -285,7 +296,7 @@ def rename(self, table, new_name):
self._sql = sql
return sql

return bool(self.new_connection().query(sql, ()))
return bool(self.get_connection().query(sql, ()))

def truncate(self, table, foreign_keys=False):
sql = self.platform().compile_truncate(
Expand All @@ -296,7 +307,7 @@ def truncate(self, table, foreign_keys=False):
self._sql = sql
return sql

return bool(self.new_connection().query(sql, ()))
return bool(self.get_connection().query(sql, ()))

def get_schema(self):
"""Gets the schema set on the migration class"""
Expand All @@ -315,7 +326,7 @@ def get_all_tables(self):
self._sql = sql
return sql

result = self.new_connection().query(sql, ())
result = self.get_connection().query(sql, ())

return (
list(map(lambda t: list(t.values())[0], result)) if result else []
Expand All @@ -338,7 +349,7 @@ def has_table(self, table, query_only=False):
self._sql = sql
return sql

return bool(self.new_connection().query(sql, ()))
return bool(self.get_connection().query(sql, ()))

def enable_foreign_key_constraints(self):
sql = self.platform().enable_foreign_key_constraints()
Expand All @@ -347,7 +358,7 @@ def enable_foreign_key_constraints(self):
self._sql = sql
return sql

return bool(self.new_connection().query(sql, ()))
return bool(self.get_connection().query(sql, ()))

def disable_foreign_key_constraints(self):
sql = self.platform().disable_foreign_key_constraints()
Expand All @@ -356,4 +367,17 @@ def disable_foreign_key_constraints(self):
self._sql = sql
return sql

return bool(self.new_connection().query(sql, ()))
return bool(self.get_connection().query(sql, ()))

def query_builder(self):
"""Get a query builder for the schema connection"""
from ..query import QueryBuilder

return QueryBuilder(
connection=self.connection,
connection_class=self.connection_class,
connection_driver=self._connection_driver,
connection_details=self.connection_details,
schema=self.schema,
dry=self._dry,
)
35 changes: 35 additions & 0 deletions tests/query/test_querybuilder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
import unittest

from src.masoniteorm.query import QueryBuilder
from tests.integrations.config.database import DATABASES
from tests.utils import MockSQLiteConnection


class TestQueryBuilder(unittest.TestCase):
maxDiff = None

def get_builder(self):
return QueryBuilder(
connection="dev",
connection_class=MockSQLiteConnection,
connection_details=DATABASES,
)

def test_returned_connection_is_same(self):
builder = self.get_builder()
first_connection = builder.get_connection()
self.assertIsNotNone(first_connection)
second_connection = builder.get_connection()
self.assertIsNotNone(second_connection)
self.assertEqual(first_connection, second_connection)

def test_new_connection_does_not_change_existing_connection(self):
builder = self.get_builder()
first_connection = builder.get_connection()
self.assertIsNotNone(first_connection)
new_connection = builder.new_connection()
self.assertIsNotNone(new_connection)
second_connection = builder.get_connection()
self.assertIsNotNone(second_connection)
self.assertEqual(first_connection, second_connection)
self.assertNotEqual(new_connection, first_connection)
Loading
Loading