Context Manager in Python
Introduction
External resources such as files, database connections and browser drivers are common in programming, and the programmer must make sure they are closed once they are no longer needed. A resource that is not closed properly keeps holding memory, file descriptors or locks. A lingering helper process can also keep using CPU in the background, slowing down the program and/or the computer. In this project an SQLite client was built as a context manager. This ensures the database connection is closed when the code using it finishes, even if an error is raised.
Background
One can open a file using file = open(file_path, 'r'), run some analysis and then close the file with file.close(). This works, but if an error is raised before file.close() runs, the file is never explicitly closed and stays open until Python happens to clean it up. For a small file this might not waste much memory or CPU. For other tools, such as a Selenium driver that launches a whole browser process, a leaked resource can be costly.
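The sketch below shows what goes wrong with the manual approach and how it was usually fixed before the with statement existed. It assumes a hypothetical analyse() step that fails partway through. The two hypothetical readers differ only in whether close() sits in a finally block.

```python
import os
import tempfile


def analyse(f):
    """Hypothetical analysis step that fails partway through."""
    f.readline()
    raise RuntimeError("analysis failed")


def read_manually(file_path, opened):
    file = open(file_path, "r")
    opened.append(file)  # keep a handle so we can inspect it afterwards
    analyse(file)
    file.close()  # never reached when analyse() raises


def read_with_try_finally(file_path, opened):
    file = open(file_path, "r")
    opened.append(file)
    try:
        analyse(file)
    finally:
        file.close()  # runs on success and on error


# Throwaway file for the demo.
fd, path = tempfile.mkstemp(text=True)
with os.fdopen(fd, "w") as tmp:
    tmp.write("line 1\nline 2\n")

results = {}
for reader in (read_manually, read_with_try_finally):
    opened = []
    try:
        reader(path, opened)
    except RuntimeError:
        pass
    results[reader.__name__] = opened[0].closed
    opened[0].close()  # tidy up the leaked handle (no-op if already closed)

os.remove(path)
print(results)  # {'read_manually': False, 'read_with_try_finally': True}
```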
Best practice is to use the with statement:

```python
with open(file_path, 'r') as file:
    # run some code
    ...
```
This is the context manager functionality, which ensures that the file is closed when the with block exits, including when an error is raised inside it. There are different ways to create a context manager that works with the with statement; one of them is shown below.
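This project uses a class with __enter__ and __exit__. The standard library's contextlib.contextmanager decorator offers another route: it turns a generator into a context manager. Below is a minimal sketch for an SQLite connection, where sqlite_connection is a hypothetical helper name, not part of the project.

```python
import sqlite3
from contextlib import contextmanager


@contextmanager
def sqlite_connection(db_name, timeout=10):
    """Hypothetical helper: open a connection and always close it afterwards."""
    conn = sqlite3.connect(db_name, timeout=timeout)
    try:
        yield conn  # the value bound to the name after `as`
    finally:
        # Runs when the with block exits, even on error. Note: close() does not
        # commit, so call conn.commit() inside the block to persist changes.
        conn.close()


with sqlite_connection(":memory:") as conn:
    conn.execute("CREATE TABLE users (id INTEGER PRIMARY KEY, username TEXT)")
    conn.execute("INSERT INTO users (username) VALUES (?)", ("Test 1",))
    rows = conn.execute("SELECT * FROM users").fetchall()

print(rows)  # [(1, 'Test 1')]
```

Everything before the yield plays the role of __enter__, and the finally block plays the role of __exit__.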
Class Methods
There are two methods that turn a class into a context manager: __enter__ and __exit__. The return value of __enter__ is assigned to the target after as in the with statement. __exit__ is invoked when the block ends, whether it finished normally or because an exception was raised, and that is where the connection is closed.
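To make the order of calls visible, here is a small hypothetical Tracker class (not part of the project) that records when each method runs. It also shows two more details. First, __exit__ receives the exception details. Second, returning a falsy value lets the exception keep propagating; SQLiteClient.__exit__ behaves the same way by returning None.

```python
class Tracker:
    """Hypothetical context manager that records the protocol calls."""

    def __init__(self):
        self.events = []

    def __enter__(self):
        self.events.append("enter")
        return "value from __enter__"  # bound to the `as` target

    def __exit__(self, exc_type, exc_val, exc_tb):
        # All three arguments are None when the block finished normally.
        self.events.append(("exit", exc_type.__name__ if exc_type else None))
        return False  # falsy: do not suppress the exception


tracker = Tracker()
with tracker as target:
    print(target)  # value from __enter__

try:
    with tracker:
        raise KeyError("boom")
except KeyError:
    pass  # the exception still propagates, but only after __exit__ ran

print(tracker.events)
# ['enter', ('exit', None), 'enter', ('exit', 'KeyError')]
```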
First, let's define a new class called SQLiteClient:
```python
'''SQLite client'''
import sqlite3
from typing import Tuple, Any, List, Optional
import re


class SQLiteClient:
    def __new__(cls, db_name: str = "", timeout: int = 10):
        """
        Ensure parameters are valid prior to memory allocation.

        Parameters
        ----------
        db_name: str. Name of SQLite database. Can also be a path to a
            database in another directory.
        timeout: int. Number of seconds to wait for a locked database
            before raising sqlite3.OperationalError.
        """
        if not db_name:
            raise ValueError("Database Name is Missing!")
        if not isinstance(db_name, str):
            raise TypeError("Database name must be a string!")
        if not isinstance(timeout, int):
            raise TypeError("Timeout must be an integer!")
        if timeout < 0:
            raise ValueError("Timeout must be 0 or greater")
        return super().__new__(cls)

    def __init__(self, db_name: str = "", timeout: int = 10):
        """
        Executed after __new__ (and therefore only when all parameters are
        valid); initializes the attributes.

        Parameters
        ----------
        db_name: str. Name of SQLite database. Can also be a path to a
            database in another directory.
        timeout: int. Number of seconds to wait for a locked database
            before raising sqlite3.OperationalError.
        """
        # Ensure the .db file extension exists.
        self.name = db_name
        if not db_name.endswith(".db"):
            self.name += ".db"
        self.timeout = timeout
        # Name-mangled to _SQLiteClient__conn; opened later by __enter__.
        self.__conn = None
```
Note that the __new__ method is used to check whether the arguments are valid. If any argument is invalid, an error is raised before super().__new__(cls) is called, so no object is ever created. Otherwise, the new object is created and the __init__ method is invoked to initialize its attributes.
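The following minimal sketch demonstrates that ordering. It uses a hypothetical Validated class that mirrors the validation pattern above and counts how many times __init__ actually runs.

```python
class Validated:
    """Hypothetical class mirroring the __new__ validation pattern."""

    init_calls = 0  # counts how often __init__ actually ran

    def __new__(cls, db_name: str = "", timeout: int = 10):
        if not isinstance(db_name, str) or not db_name:
            raise ValueError("Database name must be a non-empty string!")
        if not isinstance(timeout, int) or timeout < 0:
            raise ValueError("Timeout must be a non-negative integer!")
        return super().__new__(cls)  # only reached when arguments are valid

    def __init__(self, db_name: str = "", timeout: int = 10):
        Validated.init_calls += 1
        self.name = db_name if db_name.endswith(".db") else db_name + ".db"
        self.timeout = timeout


ok = Validated("test")
print(ok.name)  # test.db

try:
    Validated("test", timeout=-1)
except ValueError as err:
    print(err)  # Timeout must be a non-negative integer!

print(Validated.init_calls)  # 1: __init__ never ran for the invalid call
```

The same checks could live in __init__. Putting them in __new__ simply means an invalid call never produces an instance at all.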
At this point an object can be created by simply running SQLiteClient(arguments). Adding the methods below turns the client into a context manager:
```python
    def __enter__(self):
        """
        Provides context manager functionality: triggered by the 'with'
        statement, it opens the connection to the database. For example:
        `with SQLiteClient(name, 10) as client: #code.....`
        """
        self.__conn = sqlite3.connect(self.name, timeout=self.timeout)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """
        Part of the context manager functionality. Ensures the connection to
        the database is closed when the with block ends, even if an error is
        raised. Returning None (falsy) lets that error propagate.
        """
        if self.__conn is not None:
            self.__conn.close()
        self.__conn = None
```
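Note that __enter__ establishes the connection to the SQLite database using sqlite3.connect() and returns the client itself, so client in "with SQLiteClient(...) as client" is the SQLiteClient object. __exit__ ensures the connection is closed by calling self.__conn.close().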
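The explicit close() in __exit__ matters because a sqlite3.Connection can itself be used in a with statement. In that case, however, it only commits the transaction (or rolls it back on error) and leaves the connection open. Here is a short sketch with in-memory databases that also shows the standard library's contextlib.closing, for when closing is what you actually want:

```python
import sqlite3
from contextlib import closing

# A Connection as its own context manager: commits on success, rolls back
# on error, but leaves the connection open afterwards.
conn = sqlite3.connect(":memory:")
with conn:
    conn.execute("CREATE TABLE t (x INTEGER)")
    conn.execute("INSERT INTO t VALUES (?)", (1,))
still_open = conn.execute("SELECT COUNT(*) FROM t").fetchone()[0]
print(still_open)  # 1 -> the connection is still usable
conn.close()

# contextlib.closing calls close() when the block exits.
with closing(sqlite3.connect(":memory:")) as conn2:
    conn2.execute("CREATE TABLE t (x INTEGER)")

try:
    conn2.execute("SELECT 1")
    is_closed = False
except sqlite3.ProgrammingError:  # raised when using a closed connection
    is_closed = True
print(is_closed)  # True
```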
Additional functionality can be added to allow executing CRUD (Create, Read, Update and Delete) operations:
```python
    def __exec(self, query: str, parameters: Tuple[Any, ...],
               rows: int = 0) -> Optional[List[Tuple[Any, ...]]]:
        """
        Private method (name starts with __, so Python name-mangles it), meant
        to be called only by methods defined within the SQLiteClient class.

        Execute an SQL query containing question marks `?` and replace them
        with parameters (based on their order). The rows variable is used to
        distinguish between CIUD (CREATE, INSERT, UPDATE, DELETE) operations
        and SELECT statements (fetchone, fetchall, fetchmany).

        Note: the method is private so that queries go through ciud() and
        read(), which validate them first. Protection against SQL injection
        comes from binding values through ? placeholders rather than
        formatting them into the query string.

        Parameters
        ----------
        query: str. SQL statement to execute. Values are represented with a
            question mark (?). For example:
            `SELECT * FROM users WHERE name = ? and pass = ?`
        parameters: Tuple[Any, ...]. Values that are safely bound into the
            SQL statement (replacing each ?). For example: (username, pass).
            A single value needs a trailing comma: (value,).
            With no values the tuple is empty: ().
        rows: int. 0 means a CIUD operation (the change is committed).
            Otherwise a SELECT statement: rows = -1 returns all rows
            (fetchall), rows = 1 returns only one row (fetchone) and a larger
            number limits the number of rows returned (fetchmany).
        """
        cur = self.__conn.cursor()
        try:
            cur.execute(query, parameters)
            # CREATE, INSERT, UPDATE, DELETE operations
            if rows == 0:
                self.__conn.commit()
                return None
            # SELECT - fetchall
            if rows == -1:
                return cur.fetchall()
            # SELECT - fetchone
            if rows == 1:
                return [cur.fetchone()]
            # SELECT - a limited number of rows
            if rows > 1:
                return cur.fetchmany(rows)
            raise ValueError("rows must be -1, 0 or a positive integer")
        finally:
            cur.close()

    def is_conn(func):
        """Decorator: raise ValueError if there is no open connection."""
        def check_connection(self, query, parameters, *args, **kwargs):
            if self.__conn is None:
                raise ValueError("No connection to database")
            return func(self, query, parameters, *args, **kwargs)
        return check_connection

    def check_args(func):
        """Decorator: validate query and parameters before running func."""
        def wrapper(self, query, parameters, *args, **kwargs):
            # Type checks
            if not isinstance(query, str):
                raise TypeError("Query must be a string!")
            if not isinstance(parameters, tuple):
                raise TypeError("Parameters variable must be a tuple!")
            # Number of ? in query vs. length of parameters
            variables = re.findall(r"\?", query)
            if len(variables) > len(parameters):
                raise ValueError("Some variables are missing!")
            if len(variables) < len(parameters):
                raise ValueError("Too many values in parameters. This may raise "
                                 "sqlite3.ProgrammingError. Please check your "
                                 "parameters and try again.")
            # Checks specific to each decorated method
            keyword = query.lstrip()[:6].lower()
            if func.__name__ == "ciud":
                if keyword not in ["create", "insert", "update", "delete"]:
                    raise ValueError("Query must be CIUD (CREATE, INSERT, UPDATE or DELETE)")
            if func.__name__ == "read":
                if keyword != "select":
                    raise ValueError("Query must start with `SELECT ....`")
                rows = args[0] if args else kwargs.get("rows")
                if rows == 0:
                    raise ValueError("Cannot commit SELECT statement. "
                                     "Please set `rows` to a number other than 0")
            return func(self, query, parameters, *args, **kwargs)
        return wrapper

    @is_conn
    @check_args
    def ciud(self, query: str, parameters: Tuple[Any, ...]) -> None:
        """
        Execute a CREATE, INSERT, UPDATE or DELETE operation.

        Parameters
        ----------
        query: str. Query that starts with a CIUD keyword, otherwise an error
            is raised. Values are marked with a question mark (?).
        parameters: Tuple[Any, ...]. Values that replace the question marks
            in the order they appear. The length of parameters must match the
            number of ? in `query`, otherwise an error is raised.
        """
        self.__exec(query, parameters)

    @is_conn
    @check_args
    def read(self, query: str, parameters: Tuple[Any, ...],
             rows: int) -> List[Tuple[Any, ...]]:
        """
        Execute a SELECT statement.

        Parameters
        ----------
        query: str. Query that starts with SELECT, otherwise an error is
            raised. Values are marked with a question mark (?).
        parameters: Tuple[Any, ...]. Values that replace the question marks
            in the order they appear. The length of parameters must match the
            number of ? in `query`, otherwise an error is raised.
        rows: int. Any integer other than 0: rows = -1 returns all rows
            (fetchall), rows = 1 returns only one row (fetchone) and a larger
            number limits the number of rows returned (fetchmany).
        """
        return self.__exec(query, parameters, rows)
```
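The rows argument of __exec maps onto the three DB-API fetch methods. To see that dispatch in isolation, the sketch below runs it directly against an in-memory database. The function fetch is a hypothetical stand-alone version of that logic, not part of the client.

```python
import sqlite3
from typing import Any, List, Optional, Tuple


def fetch(conn: sqlite3.Connection, query: str, parameters: Tuple[Any, ...],
          rows: int) -> Optional[List[Tuple[Any, ...]]]:
    """Hypothetical stand-alone version of the rows dispatch in __exec."""
    cur = conn.cursor()
    try:
        cur.execute(query, parameters)
        if rows == 0:      # CREATE/INSERT/UPDATE/DELETE: commit, no result
            conn.commit()
            return None
        if rows == -1:     # every row
            return cur.fetchall()
        if rows == 1:      # first row only, wrapped in a list
            return [cur.fetchone()]
        if rows > 1:       # at most `rows` rows
            return cur.fetchmany(rows)
        raise ValueError("rows must be -1, 0 or a positive integer")
    finally:
        cur.close()


conn = sqlite3.connect(":memory:")
fetch(conn, "CREATE TABLE users (id INTEGER PRIMARY KEY, username TEXT)", (), 0)
for name in ("Test 1", "Test 2", "Test 3"):
    fetch(conn, "INSERT INTO users VALUES (NULL, ?)", (name,), 0)

all_rows = fetch(conn, "SELECT * FROM users", (), -1)
first = fetch(conn, "SELECT * FROM users", (), 1)
two = fetch(conn, "SELECT * FROM users", (), 2)
print(all_rows)  # [(1, 'Test 1'), (2, 'Test 2'), (3, 'Test 3')]
print(first)     # [(1, 'Test 1')]
print(two)       # [(1, 'Test 1'), (2, 'Test 2')]
conn.close()
```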
The final code appears at the bottom of this page.
Tests
Test 1: Initialize an object
```python
test1 = SQLiteClient("test.db")
print(test1.name)
test1.ciud("CREATE TABLE IF NOT EXISTS users (id INTEGER PRIMARY KEY, username text, hash text)", ())
```
Output:
```
test.db
Traceback (most recent call last):
  ...
ValueError: No connection to database
```
The object was initialized and its name was set to test.db. However, there is no connection, because the object was not used in a with statement and __enter__ never ran.
Test 2: Using with clause
```python
with SQLiteClient("another_test") as client:
    # Create table 'users' with columns: id, username, hash
    client.ciud("CREATE TABLE IF NOT EXISTS users (id INTEGER PRIMARY KEY, username text, hash text)", ())
    # Insert values
    client.ciud("INSERT INTO users VALUES (NULL,?,?)", ("Test 1", "a"))
    client.ciud("INSERT INTO users VALUES (NULL,?,?)", ("Test 2", "b"))
    client.ciud("INSERT INTO users VALUES (NULL,?,?)", ("Test 3", "c"))
    # Read all columns and all rows
    print(client.read("SELECT * FROM users", (), rows=-1))
    # Read all columns of the first 2 rows
    print(client.read("SELECT * FROM users", (), rows=2))
```
A database file (another_test.db) was created in the current working directory. The output of the SELECT statements on a fresh database appears below. Running the script again appends three more rows, so the first list grows.
Output:
```
[(1, 'Test 1', 'a'), (2, 'Test 2', 'b'), (3, 'Test 3', 'c')]
[(1, 'Test 1', 'a'), (2, 'Test 2', 'b')]
```
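Both tests pass values through ? placeholders. That binding, performed by sqlite3 itself, is what stops user input from being interpreted as SQL; the private __exec method only limits which code paths reach execute(). The sketch below contrasts string formatting with parameter binding on an in-memory copy of the users table. The malicious input string is made up for illustration.

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE users (id INTEGER PRIMARY KEY, username TEXT, hash TEXT)")
conn.executemany("INSERT INTO users VALUES (NULL, ?, ?)",
                 [("Test 1", "a"), ("Test 2", "b"), ("Test 3", "c")])

malicious = "nobody' OR '1'='1"  # made-up hostile input

# Unsafe: the input is spliced into the SQL text, so the OR clause becomes SQL.
unsafe_query = f"SELECT * FROM users WHERE username = '{malicious}'"
unsafe_rows = conn.execute(unsafe_query).fetchall()

# Safe: the input is bound as a value and compared as a literal string.
safe_rows = conn.execute("SELECT * FROM users WHERE username = ?",
                         (malicious,)).fetchall()

print(len(unsafe_rows), len(safe_rows))  # 3 0
conn.close()
```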
Final Code
```python
'''SQLite client'''
import sqlite3
from typing import Tuple, Any, List, Optional
import re


class SQLiteClient:
    def __new__(cls, db_name: str = "", timeout: int = 10):
        """
        Ensure parameters are valid prior to memory allocation.

        Parameters
        ----------
        db_name: str. Name of SQLite database. Can also be a path to a
            database in another directory.
        timeout: int. Number of seconds to wait for a locked database
            before raising sqlite3.OperationalError.
        """
        if not db_name:
            raise ValueError("Database Name is Missing!")
        if not isinstance(db_name, str):
            raise TypeError("Database name must be a string!")
        if not isinstance(timeout, int):
            raise TypeError("Timeout must be an integer!")
        if timeout < 0:
            raise ValueError("Timeout must be 0 or greater")
        return super().__new__(cls)

    def __init__(self, db_name: str = "", timeout: int = 10):
        """
        Executed after __new__ (and therefore only when all parameters are
        valid); initializes the attributes.

        Parameters
        ----------
        db_name: str. Name of SQLite database. Can also be a path to a
            database in another directory.
        timeout: int. Number of seconds to wait for a locked database
            before raising sqlite3.OperationalError.
        """
        # Ensure the .db file extension exists.
        self.name = db_name
        if not db_name.endswith(".db"):
            self.name += ".db"
        self.timeout = timeout
        # Name-mangled to _SQLiteClient__conn; opened later by __enter__.
        self.__conn = None

    def __enter__(self):
        """
        Provides context manager functionality: triggered by the 'with'
        statement, it opens the connection to the database. For example:
        `with SQLiteClient(name, 10) as client: #code.....`
        """
        self.__conn = sqlite3.connect(self.name, timeout=self.timeout)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """
        Part of the context manager functionality. Ensures the connection to
        the database is closed when the with block ends, even if an error is
        raised. Returning None (falsy) lets that error propagate.
        """
        if self.__conn is not None:
            self.__conn.close()
        self.__conn = None

    def __exec(self, query: str, parameters: Tuple[Any, ...],
               rows: int = 0) -> Optional[List[Tuple[Any, ...]]]:
        """
        Private method (name-mangled), meant to be called only by methods
        defined within the SQLiteClient class.

        Execute an SQL query containing question marks `?` and replace them
        with parameters (based on their order). rows = 0 means a CIUD
        (CREATE, INSERT, UPDATE, DELETE) operation, which is committed.
        Otherwise the query is a SELECT statement: rows = -1 returns all rows
        (fetchall), rows = 1 returns only one row (fetchone) and a larger
        number limits the number of rows returned (fetchmany).

        Note: the method is private so that queries go through ciud() and
        read(), which validate them first. Protection against SQL injection
        comes from binding values through ? placeholders. A single value
        needs a trailing comma: (value,); no values means an empty tuple: ().
        """
        cur = self.__conn.cursor()
        try:
            cur.execute(query, parameters)
            # CREATE, INSERT, UPDATE, DELETE operations
            if rows == 0:
                self.__conn.commit()
                return None
            # SELECT - fetchall
            if rows == -1:
                return cur.fetchall()
            # SELECT - fetchone
            if rows == 1:
                return [cur.fetchone()]
            # SELECT - a limited number of rows
            if rows > 1:
                return cur.fetchmany(rows)
            raise ValueError("rows must be -1, 0 or a positive integer")
        finally:
            cur.close()

    def is_conn(func):
        """Decorator: raise ValueError if there is no open connection."""
        def check_connection(self, query, parameters, *args, **kwargs):
            if self.__conn is None:
                raise ValueError("No connection to database")
            return func(self, query, parameters, *args, **kwargs)
        return check_connection

    def check_args(func):
        """Decorator: validate query and parameters before running func."""
        def wrapper(self, query, parameters, *args, **kwargs):
            # Type checks
            if not isinstance(query, str):
                raise TypeError("Query must be a string!")
            if not isinstance(parameters, tuple):
                raise TypeError("Parameters variable must be a tuple!")
            # Number of ? in query vs. length of parameters
            variables = re.findall(r"\?", query)
            if len(variables) > len(parameters):
                raise ValueError("Some variables are missing!")
            if len(variables) < len(parameters):
                raise ValueError("Too many values in parameters. This may raise "
                                 "sqlite3.ProgrammingError. Please check your "
                                 "parameters and try again.")
            # Checks specific to each decorated method
            keyword = query.lstrip()[:6].lower()
            if func.__name__ == "ciud":
                if keyword not in ["create", "insert", "update", "delete"]:
                    raise ValueError("Query must be CIUD (CREATE, INSERT, UPDATE or DELETE)")
            if func.__name__ == "read":
                if keyword != "select":
                    raise ValueError("Query must start with `SELECT ....`")
                rows = args[0] if args else kwargs.get("rows")
                if rows == 0:
                    raise ValueError("Cannot commit SELECT statement. "
                                     "Please set `rows` to a number other than 0")
            return func(self, query, parameters, *args, **kwargs)
        return wrapper

    @is_conn
    @check_args
    def ciud(self, query: str, parameters: Tuple[Any, ...]) -> None:
        """
        Execute a CREATE, INSERT, UPDATE or DELETE operation.

        query must start with a CIUD keyword, with values marked by ?.
        The length of parameters must match the number of ? in query,
        otherwise an error is raised.
        """
        self.__exec(query, parameters)

    @is_conn
    @check_args
    def read(self, query: str, parameters: Tuple[Any, ...],
             rows: int) -> List[Tuple[Any, ...]]:
        """
        Execute a SELECT statement.

        query must start with SELECT, with values marked by ?. The length of
        parameters must match the number of ? in query. rows must be an
        integer other than 0: -1 returns all rows (fetchall), 1 returns one
        row (fetchone) and a larger number limits the rows (fetchmany).
        """
        return self.__exec(query, parameters, rows)
```