mirror of
https://github.com/RolandWH/SIPPCompare.git
synced 2025-06-28 07:11:16 +01:00
begin implementation of sqlite database
This commit is contained in:
106
src/db_handler.py
Normal file
106
src/db_handler.py
Normal file
@ -0,0 +1,106 @@
|
||||
import os
|
||||
import sqlite3
|
||||
|
||||
|
||||
class DBHandler:
|
||||
def __init__(self):
|
||||
def create_tables():
|
||||
self.cur.execute("""
|
||||
CREATE TABLE "tblPlatforms" (
|
||||
"PlatformID" INTEGER NOT NULL UNIQUE,
|
||||
"PlatformName" TEXT NOT NULL,
|
||||
PRIMARY KEY("PlatformID")
|
||||
)
|
||||
""")
|
||||
|
||||
self.cur.execute("""
|
||||
CREATE TABLE "tblFlatPlatFees" (
|
||||
"PlatformID" INTEGER NOT NULL,
|
||||
"SharePlatFee" REAL NOT NULL,
|
||||
"SharePlatMaxFee" REAL,
|
||||
PRIMARY KEY("PlatformID"),
|
||||
FOREIGN KEY("PlatformID") REFERENCES "tblPlatforms"("PlatformID")
|
||||
)
|
||||
""")
|
||||
|
||||
self.cur.execute("""
|
||||
CREATE TABLE "tblFlatDealFees" (
|
||||
"PlatformID" INTEGER NOT NULL,
|
||||
"FundDealFee" REAL,
|
||||
"ShareDealFee" REAL NOT NULL,
|
||||
"ShareDealReduceTrades" REAL,
|
||||
"ShareDealReduceAmount" REAL,
|
||||
PRIMARY KEY("PlatformID"),
|
||||
FOREIGN KEY("PlatformID") REFERENCES "tblPlatforms"("PlatformID")
|
||||
)
|
||||
""")
|
||||
|
||||
self.cur.execute("""
|
||||
CREATE TABLE "tblFundPlatFee" (
|
||||
"PlatformID" INTEGER NOT NULL,
|
||||
"Band" REAL NOT NULL,
|
||||
"Fee" REAL NOT NULL,
|
||||
PRIMARY KEY("PlatformID","Band","Fee"),
|
||||
FOREIGN KEY("PlatformID") REFERENCES "tblPlatforms"("PlatformID")
|
||||
)
|
||||
""")
|
||||
|
||||
self.cur.execute("""
|
||||
CREATE TABLE "tblUserDetails" (
|
||||
"PensionValue" REAL,
|
||||
"SliderValue" INTEGER,
|
||||
"ShareTrades" INTEGER,
|
||||
"FundTrades" INTEGER
|
||||
)
|
||||
""")
|
||||
|
||||
if not os.path.exists("SIPPCompare.db"):
|
||||
db_exists = False
|
||||
else:
|
||||
db_exists = True
|
||||
|
||||
self.conn = sqlite3.connect("SIPPCompare.db")
|
||||
self.cur = self.conn.cursor()
|
||||
if not db_exists:
|
||||
create_tables()
|
||||
|
||||
def retrieve_plat_list(self) -> list:
|
||||
res = self.cur.execute("SELECT PlatformName FROM tblPlatforms")
|
||||
res_list = res.fetchall()
|
||||
plat_name_list = []
|
||||
for platform in res_list:
|
||||
plat_name_list.append(platform[0])
|
||||
|
||||
return plat_name_list
|
||||
|
||||
def write_user_details(self, pension_val: float, slider_val: int, share_trades: int, fund_trades: int):
|
||||
user_details_data = (pension_val, slider_val, share_trades, fund_trades)
|
||||
|
||||
res = self.cur.execute("SELECT EXISTS(SELECT 1 FROM tblUserDetails)").fetchone()
|
||||
if res[0] == 0:
|
||||
self.cur.execute("INSERT INTO tblUserDetails VALUES (?, ?, ?, ?)", user_details_data)
|
||||
else:
|
||||
self.cur.execute("""
|
||||
UPDATE tblUserDetails SET
|
||||
PensionValue = ?,
|
||||
SliderValue = ?,
|
||||
ShareTrades = ?,
|
||||
FundTrades = ?
|
||||
""", user_details_data)
|
||||
self.conn.commit()
|
||||
|
||||
def retrieve_user_details(self) -> dict:
|
||||
res = self.cur.execute("SELECT EXISTS(SELECT 1 FROM tblUserDetails)").fetchone()
|
||||
if res[0] == 0:
|
||||
return {"NO_RECORD": None}
|
||||
|
||||
res = self.cur.execute("SELECT * FROM tblUserDetails")
|
||||
res_tuple = res.fetchone()
|
||||
user_details_dict = {
|
||||
"pension_val": res_tuple[0],
|
||||
"slider_val": res_tuple[1],
|
||||
"share_trades": res_tuple[2],
|
||||
"fund_trades": res_tuple[3]
|
||||
}
|
||||
|
||||
return user_details_dict
|
@ -3,7 +3,9 @@ from PyQt6.QtWidgets import QMainWindow, QWidget
|
||||
from PyQt6 import uic
|
||||
|
||||
import output_window
|
||||
import platform_list
|
||||
import resource_finder
|
||||
import db_handler
|
||||
|
||||
|
||||
class SIPPCompare(QMainWindow):
|
||||
@ -33,13 +35,16 @@ class SIPPCompare(QMainWindow):
|
||||
self.share_deal_fees = 0.0
|
||||
|
||||
# Create window objects
|
||||
self.db = db_handler.DBHandler()
|
||||
self.platform_win = plat_edit_win
|
||||
self.platform_list_win = platform_list.PlatformList(self.db)
|
||||
self.output_win = output_window.OutputWindow()
|
||||
|
||||
# Handle events
|
||||
self.calc_but.clicked.connect(self.calculate_fees)
|
||||
# Menu bar entry (File -> Edit Platforms)
|
||||
self.actionEdit_Platforms.triggered.connect(self.show_platform_edit)
|
||||
#self.actionEdit_Platforms.triggered.connect(self.show_platform_edit)
|
||||
self.actionList_Platforms.triggered.connect(self.show_platform_list)
|
||||
# Update percentage mix label when slider moved
|
||||
self.mix_slider.valueChanged.connect(self.update_slider_lab)
|
||||
self.value_input.valueChanged.connect(self.check_valid)
|
||||
@ -50,6 +55,15 @@ class SIPPCompare(QMainWindow):
|
||||
self.share_trades_combo.setValidator(QIntValidator(0, 999))
|
||||
self.fund_trades_combo.setValidator(QIntValidator(0, 99))
|
||||
|
||||
# Restore last session
|
||||
prev_session_data = self.db.retrieve_user_details()
|
||||
if "NO_RECORD" not in prev_session_data:
|
||||
self.value_input.setValue(prev_session_data["pension_val"])
|
||||
self.mix_slider.setValue(prev_session_data["slider_val"])
|
||||
self.share_trades_combo.setCurrentText(str(prev_session_data["share_trades"]))
|
||||
self.fund_trades_combo.setCurrentText(str(prev_session_data["fund_trades"]))
|
||||
self.calc_but.setFocus()
|
||||
|
||||
# Display slider position as mix between two nums (funds/shares)
|
||||
def update_slider_lab(self):
|
||||
slider_val = self.mix_slider.value()
|
||||
@ -100,13 +114,18 @@ class SIPPCompare(QMainWindow):
|
||||
# Calculate fees
|
||||
def calculate_fees(self):
|
||||
self.init_variables()
|
||||
|
||||
# Set to zero each time to avoid persistence
|
||||
self.fund_plat_fees = 0
|
||||
|
||||
# Get user input
|
||||
value_num = float(self.value_input.value())
|
||||
# Funds/shares mix
|
||||
slider_val: int = self.mix_slider.value()
|
||||
funds_value = (slider_val / 100) * value_num
|
||||
fund_trades_num = int(self.fund_trades_combo.currentText())
|
||||
share_trades_num = int(self.share_trades_combo.currentText())
|
||||
|
||||
# Funds/shares mix
|
||||
funds_value = (slider_val / 100) * value_num
|
||||
if self.fund_deal_fee is not None:
|
||||
self.fund_deal_fees = fund_trades_num * self.fund_deal_fee
|
||||
|
||||
@ -130,13 +149,13 @@ class SIPPCompare(QMainWindow):
|
||||
else:
|
||||
self.share_plat_fees = self.share_plat_fee * shares_value
|
||||
|
||||
share_trades_num = int(self.share_trades_combo.currentText())
|
||||
if self.share_deal_reduce_trades is not None:
|
||||
if (share_trades_num / 12) >= self.share_deal_reduce_trades:
|
||||
self.share_deal_fees = self.share_deal_reduce_amount * share_trades_num
|
||||
else:
|
||||
self.share_deal_fees = self.share_deal_fee * share_trades_num
|
||||
|
||||
self.db.write_user_details(value_num, slider_val, share_trades_num, fund_trades_num)
|
||||
self.show_output_win()
|
||||
|
||||
# Show the output window - this func is called from calculate_fee()
|
||||
@ -151,3 +170,6 @@ class SIPPCompare(QMainWindow):
|
||||
# Show the platform editor window (currently run-time only)
|
||||
def show_platform_edit(self):
|
||||
self.platform_win.show()
|
||||
|
||||
def show_platform_list(self):
|
||||
self.platform_list_win.show()
|
||||
|
@ -93,7 +93,7 @@ class PlatformEdit(QWidget):
|
||||
QRegularExpressionValidator(QRegularExpression("\\w*"))
|
||||
)
|
||||
|
||||
def create_plat_fee_struct(self):
|
||||
def create_plat_fee_struct(self) -> list:
|
||||
plat_fee_struct = [[0], [0]]
|
||||
plat_fee_struct[0].append(self.first_tier_box.value())
|
||||
plat_fee_struct[1].append(self.first_tier_fee_box.value())
|
||||
|
44
src/platform_list.py
Normal file
44
src/platform_list.py
Normal file
@ -0,0 +1,44 @@
|
||||
from PyQt6.QtWidgets import QWidget, QListWidgetItem
|
||||
from PyQt6.QtGui import QIcon, QRegularExpressionValidator
|
||||
from PyQt6.QtCore import QRegularExpression
|
||||
from PyQt6 import uic
|
||||
|
||||
import resource_finder
|
||||
|
||||
|
||||
class PlatformRename(QWidget):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
# Import Qt Designer UI XML file
|
||||
uic.loadUi(resource_finder.get_res_path("gui/dialogs/platform_rename.ui"), self)
|
||||
self.setWindowIcon(QIcon(resource_finder.get_res_path("icon2.ico")))
|
||||
|
||||
# Set validators
|
||||
# Regex accepts any characters that match [a-Z], [0-9] or _
|
||||
self.rename_plat_box.setValidator(
|
||||
QRegularExpressionValidator(QRegularExpression("\\w*"))
|
||||
)
|
||||
|
||||
|
||||
class PlatformList(QWidget):
|
||||
def __init__(self, db):
|
||||
super().__init__()
|
||||
# Import Qt Designer UI XML file
|
||||
uic.loadUi(resource_finder.get_res_path("gui/platform_list.ui"), self)
|
||||
self.setWindowIcon(QIcon(resource_finder.get_res_path("icon2.ico")))
|
||||
|
||||
self.plat_list_dialog = PlatformRename()
|
||||
self.db = db
|
||||
self.plat_name_list = self.db.retrieve_plat_list()
|
||||
print(self.plat_name_list)
|
||||
|
||||
for platform in self.plat_name_list:
|
||||
item = QListWidgetItem()
|
||||
item.setText(platform)
|
||||
self.platListWidget.addItem(item)
|
||||
|
||||
# Handle events
|
||||
self.add_plat_but.clicked.connect(self.add_platform)
|
||||
|
||||
def add_platform(self):
|
||||
self.plat_list_dialog.show()
|
@ -10,4 +10,4 @@ def get_res_path(relative_path):
|
||||
except AttributeError:
|
||||
base_path = os.path.abspath(".")
|
||||
|
||||
return os.path.join(base_path, relative_path)
|
||||
return os.path.join(base_path, relative_path)
|
||||
|
Reference in New Issue
Block a user