Compare commits
12 Commits
5008becd15
...
option_pri
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
cfb8b9d3b9 | ||
|
|
9e9eef21a5 | ||
|
|
b2427eaf9d | ||
|
|
1dd5d8657a | ||
|
|
73641b7e5b | ||
|
|
23a28c6776 | ||
|
|
3dacc0a418 | ||
|
|
b3663258e4 | ||
|
|
e9b3a4aac3 | ||
|
|
087a2f0d74 | ||
|
|
61df0b425d | ||
|
|
ff30a3e1ce |
14
.env.example
Normal file
14
.env.example
Normal file
@@ -0,0 +1,14 @@
|
|||||||
|
DB_HOST=localhost
|
||||||
|
DB_PORT=5432
|
||||||
|
DB_NAME=options_db
|
||||||
|
DB_USER=quant_user
|
||||||
|
DB_PASSWORD=change_me
|
||||||
|
PIPELINE_SYMBOLS=SPY
|
||||||
|
|
||||||
|
# For scripts/setup_postgres.py when creating role/database:
|
||||||
|
# Use a superuser/admin account that can CREATE ROLE and CREATE DATABASE.
|
||||||
|
POSTGRES_ADMIN_USER=postgres
|
||||||
|
POSTGRES_ADMIN_PASSWORD=postgres
|
||||||
|
POSTGRES_ADMIN_HOST=localhost
|
||||||
|
POSTGRES_ADMIN_PORT=5432
|
||||||
|
POSTGRES_ADMIN_DB=postgres
|
||||||
31
.gitignore
vendored
Normal file
31
.gitignore
vendored
Normal file
@@ -0,0 +1,31 @@
|
|||||||
|
.venv/
|
||||||
|
__pycache__/
|
||||||
|
*.py[cod]
|
||||||
|
*.pyo
|
||||||
|
|
||||||
|
.ipynb_checkpoints/
|
||||||
|
|
||||||
|
# Built Python extension dropped next to qengine/__init__.py for local dev
|
||||||
|
/qengine/*.so
|
||||||
|
/qengine/*.dylib
|
||||||
|
/qengine/__pycache__/
|
||||||
|
|
||||||
|
/skbuild-build/
|
||||||
|
|
||||||
|
/build/
|
||||||
|
/.idea/
|
||||||
|
**/__pycache__/
|
||||||
|
/docs/html/
|
||||||
|
/docs/latex/
|
||||||
|
|
||||||
|
# Local reference tree (optional clone)
|
||||||
|
/CPP-design-pattern-derivatives-pricing/
|
||||||
|
|
||||||
|
# Local environment and secrets
|
||||||
|
.env
|
||||||
|
.env.*
|
||||||
|
!.env.example
|
||||||
|
|
||||||
|
# Local tooling caches
|
||||||
|
/.pycache/
|
||||||
|
/.mplconfig/
|
||||||
@@ -4,28 +4,53 @@ project(QuantEngine)
|
|||||||
set(CMAKE_CXX_STANDARD 20)
|
set(CMAKE_CXX_STANDARD 20)
|
||||||
set(CMAKE_CXX_FLAGS "-O3 -march=native")
|
set(CMAKE_CXX_FLAGS "-O3 -march=native")
|
||||||
|
|
||||||
|
option(BUILD_TESTING "Build GoogleTest target and tests" ON)
|
||||||
|
|
||||||
|
set(PYBIND11_FINDPYTHON ON)
|
||||||
|
find_package(Python3 REQUIRED COMPONENTS Interpreter Development.Module)
|
||||||
find_package(Eigen3 REQUIRED)
|
find_package(Eigen3 REQUIRED)
|
||||||
|
find_package(pybind11 CONFIG REQUIRED)
|
||||||
|
#find_package(PostgreSQL REQUIRED)
|
||||||
|
#find_package(PkgConfig REQUIRED)
|
||||||
|
#pkg_check_modules(PQXX REQUIRED IMPORTED_TARGET libpqxx)
|
||||||
|
|
||||||
add_subdirectory(src)
|
add_subdirectory(cpp)
|
||||||
|
|
||||||
# Testing
|
find_package(Doxygen OPTIONAL_COMPONENTS dot)
|
||||||
enable_testing()
|
if(DOXYGEN_FOUND)
|
||||||
|
add_custom_target(
|
||||||
|
docs
|
||||||
|
COMMAND ${DOXYGEN_EXECUTABLE} ${CMAKE_SOURCE_DIR}/docs/Doxyfile
|
||||||
|
WORKING_DIRECTORY ${CMAKE_SOURCE_DIR}
|
||||||
|
COMMENT "Generating API documentation (HTML in docs/html)"
|
||||||
|
VERBATIM)
|
||||||
|
endif()
|
||||||
|
|
||||||
include(FetchContent)
|
install(FILES "${CMAKE_SOURCE_DIR}/qengine/__init__.py" DESTINATION qengine)
|
||||||
|
install(TARGETS qengine_cpp
|
||||||
|
LIBRARY DESTINATION qengine
|
||||||
|
RUNTIME DESTINATION qengine)
|
||||||
|
|
||||||
FetchContent_Declare(
|
if(BUILD_TESTING)
|
||||||
googletest
|
enable_testing()
|
||||||
URL https://github.com/google/googletest/archive/refs/tags/v1.14.0.zip
|
|
||||||
DOWNLOAD_EXTRACT_TIMESTAMP TRUE
|
|
||||||
)
|
|
||||||
|
|
||||||
FetchContent_MakeAvailable(googletest)
|
include(FetchContent)
|
||||||
|
|
||||||
add_executable(qengine_tests
|
FetchContent_Declare(
|
||||||
tests/test_black_scholes.cpp
|
googletest
|
||||||
tests/stubs/FlatYieldCurve.cpp
|
URL https://github.com/google/googletest/archive/refs/tags/v1.14.0.zip
|
||||||
tests/stubs/FlatVolatilitySurface.cpp)
|
DOWNLOAD_EXTRACT_TIMESTAMP TRUE
|
||||||
|
)
|
||||||
|
|
||||||
target_link_libraries(qengine_tests qengine GTest::gtest_main)
|
FetchContent_MakeAvailable(googletest)
|
||||||
include(GoogleTest)
|
|
||||||
gtest_discover_tests(qengine_tests)
|
add_executable(qengine_tests
|
||||||
|
tests/test_black_scholes.cpp
|
||||||
|
tests/stubs/FlatYieldCurve.cpp
|
||||||
|
tests/stubs/FlatVolatilitySurface.cpp)
|
||||||
|
|
||||||
|
target_include_directories(qengine_tests PRIVATE ${CMAKE_SOURCE_DIR}/tests)
|
||||||
|
target_link_libraries(qengine_tests qengine_core GTest::gtest_main)
|
||||||
|
include(GoogleTest)
|
||||||
|
gtest_discover_tests(qengine_tests)
|
||||||
|
endif()
|
||||||
|
|||||||
80
README.md
80
README.md
@@ -1,5 +1,79 @@
|
|||||||
# pricing
|
# option_pricing
|
||||||
|
|
||||||
Monte Carlo pricing of European options under Black–Scholes
|
C++/Python quantitative finance engine for option pricing, implied-volatility analysis, and market-data ingestion.
|
||||||
|
|
||||||
### Project structure
|
## What is included
|
||||||
|
|
||||||
|
- `cpp/`: core C++ pricing library (Monte Carlo + Black-Scholes closed form), DB ingestion hooks, and pybind bindings.
|
||||||
|
- `qengine/`: Python package exposing the native extension (`import qengine`).
|
||||||
|
- `src/ImpliedVolatility/`: SVI calibration and implied-volatility tooling.
|
||||||
|
- `src/data/`: data ingestion, SQL schema, and analytics helpers.
|
||||||
|
- `tests/`: C++ unit tests (GoogleTest).
|
||||||
|
- `scripts/`: operational scripts, including PostgreSQL setup.
|
||||||
|
- `docs/`: Doxygen configuration and generated API docs (ignored in git for publication).
|
||||||
|
|
||||||
|
## Quickstart
|
||||||
|
|
||||||
|
### 1) Clone and create a Python environment
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python3 -m venv .venv
|
||||||
|
source .venv/bin/activate
|
||||||
|
pip install --upgrade pip
|
||||||
|
pip install -e .
|
||||||
|
pip install pandas yfinance sqlalchemy psycopg2-binary matplotlib scipy
|
||||||
|
```
|
||||||
|
|
||||||
|
### 2) Configure environment variables
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cp .env.example .env
|
||||||
|
```
|
||||||
|
|
||||||
|
Then edit `.env` with your local database credentials.
|
||||||
|
|
||||||
|
### 3) Create database and schema
|
||||||
|
|
||||||
|
Use the idempotent setup script:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
source .env
|
||||||
|
python scripts/setup_postgres.py
|
||||||
|
```
|
||||||
|
|
||||||
|
This script creates/updates:
|
||||||
|
- database role (`DB_USER`)
|
||||||
|
- database (`DB_NAME`)
|
||||||
|
- tables/indexes from `src/data/sql/schema.sql`
|
||||||
|
|
||||||
|
### 4) Build C++ extension and run tests
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cmake -S . -B build
|
||||||
|
cmake --build build -j
|
||||||
|
ctest --test-dir build --output-on-failure
|
||||||
|
```
|
||||||
|
|
||||||
|
### 5) Run Yahoo options ingestion
|
||||||
|
|
||||||
|
```bash
|
||||||
|
source .env
|
||||||
|
python src/data/ingestion/ingest_yahoo_options.py
|
||||||
|
```
|
||||||
|
|
||||||
|
`PIPELINE_SYMBOLS` in `.env` controls which symbols are ingested (comma-separated, e.g. `SPY,AAPL,QQQ`).
|
||||||
|
|
||||||
|
## Security and publication notes
|
||||||
|
|
||||||
|
- No credentials are stored in source code.
|
||||||
|
- `.env` files are git-ignored; only `.env.example` is committed.
|
||||||
|
- Before publishing, rotate any credentials that were ever committed in the past.
|
||||||
|
- Prefer least-privilege DB users for runtime ingestion jobs.
|
||||||
|
|
||||||
|
## Generating C++ API docs
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cmake --build build --target docs
|
||||||
|
```
|
||||||
|
|
||||||
|
Generated output goes to `docs/html/` and is ignored in version control.
|
||||||
|
|||||||
0
__init__.py
Normal file
0
__init__.py
Normal file
49
cpp/BSWrapper.cpp
Normal file
49
cpp/BSWrapper.cpp
Normal file
@@ -0,0 +1,49 @@
|
|||||||
|
//
|
||||||
|
// Created by David Doebel on 27.03.2026.
|
||||||
|
//
|
||||||
|
|
||||||
|
#include "BSWrapper.hpp"
|
||||||
|
|
||||||
|
#include "BlackScholesClosedFormEngine.hpp"
|
||||||
|
#include "BlackScholesProcess.hpp"
|
||||||
|
#include "Instrument.hpp"
|
||||||
|
#include "Option.hpp"
|
||||||
|
#include "FlatVolatilitySurface.hpp"
|
||||||
|
#include "FlatYieldCurve.hpp"
|
||||||
|
#include <cassert>
|
||||||
|
#include <iostream>
|
||||||
|
|
||||||
|
class FlatYieldCurve;
|
||||||
|
|
||||||
|
double BSWrapper::bs_price_wrapper(double S, double K, double T, double r, double sigma,
|
||||||
|
bool is_call) {
|
||||||
|
std::shared_ptr<FlatYieldCurve> flat_curve = std::make_shared<FlatYieldCurve>(r);
|
||||||
|
auto flat_vol_surface = std::make_shared<FlatVolatilitySurface>(sigma);
|
||||||
|
MarketData data(S,flat_curve, flat_vol_surface);
|
||||||
|
std::unique_ptr<BlackScholesProcess> process = std::make_unique<BlackScholesProcess>(data);
|
||||||
|
std::unique_ptr<BlackScholesClosedFormEngine> pricing_engine =
|
||||||
|
std::make_unique<BlackScholesClosedFormEngine>(std::move(process));
|
||||||
|
std::unique_ptr<Payoff> payoff;
|
||||||
|
if (is_call)
|
||||||
|
payoff = std::make_unique<CallPayoff>(K);
|
||||||
|
else payoff = std::make_unique<PutPayoff>(K);
|
||||||
|
EuropeanExercise exercise(T);
|
||||||
|
VanillaOption option(T,std::make_unique<EuropeanExercise>(exercise),
|
||||||
|
std::move(payoff),std::move(pricing_engine));
|
||||||
|
return option.price();
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<double> BSWrapper::batch_bs_price_wrapper(const std::vector<double> &S, const std::vector<double> &K,
|
||||||
|
const std::vector<double> &T, const std::vector<double> &r, const std::vector<double> &sigma,
|
||||||
|
const std::vector<bool> &is_call) {
|
||||||
|
assert(K.size() == S.size() && K.size() == T.size() && K.size() == r.size() && K.size() ==
|
||||||
|
sigma.size() && K.size() == is_call.size());
|
||||||
|
std::size_t n = K.size();
|
||||||
|
std::vector<double> result(n);
|
||||||
|
for (std::size_t i = 0; i < n; ++i) {
|
||||||
|
result[i] = bs_price_wrapper(S[i], K[i], T[i], r[i], sigma[i], is_call[i]);
|
||||||
|
if (i % 100 == 0)
|
||||||
|
std::cout << "i = " << i << " result = " << result[i] << std::endl; // ( i % 1000 == 0)
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
}
|
||||||
24
cpp/BSWrapper.hpp
Normal file
24
cpp/BSWrapper.hpp
Normal file
@@ -0,0 +1,24 @@
|
|||||||
|
/**
|
||||||
|
* @file BSWrapper.hpp
|
||||||
|
* @brief Black–Scholes vanilla price (closed form; used from Python / implied vol).
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef QUANTENGINE_BSWRAPPER_HPP
|
||||||
|
#define QUANTENGINE_BSWRAPPER_HPP
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Static helpers wrapping scalar and batch pricing.
|
||||||
|
*/
|
||||||
|
class BSWrapper {
|
||||||
|
public:
|
||||||
|
BSWrapper() = delete;
|
||||||
|
static double bs_price_wrapper(double S, double K, double T, double r, double sigma, bool is_call);
|
||||||
|
static std::vector<double> batch_bs_price_wrapper(const std::vector<double>& S, const std::vector<double>& K,
|
||||||
|
const std::vector<double>& T, const std::vector<double>& r, const std::vector<double>& sigma,
|
||||||
|
const std::vector<bool>& is_call);
|
||||||
|
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
#endif //QUANTENGINE_BSWRAPPER_HPP
|
||||||
69
cpp/BlackScholesClosedFormEngine.cpp
Normal file
69
cpp/BlackScholesClosedFormEngine.cpp
Normal file
@@ -0,0 +1,69 @@
|
|||||||
|
/**
|
||||||
|
* @file BlackScholesClosedFormEngine.cpp
|
||||||
|
* @brief Black–Scholes closed-form pricing (calls, puts, cash-or-nothing digital).
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include "BlackScholesClosedFormEngine.hpp"
|
||||||
|
#include "Instrument.hpp"
|
||||||
|
#include "Payoff.hpp"
|
||||||
|
#include <cmath>
|
||||||
|
#include <stdexcept>
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
double norm_cdf(double x) {
|
||||||
|
return 0.5 * (1.0 + std::erf(x / std::sqrt(2.0)));
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
double BlackScholesClosedFormEngine::calculate(const Instrument &instrument) const {
|
||||||
|
if (instrument.exerciseType() != Exercise::Type::European) {
|
||||||
|
throw std::invalid_argument("BlackScholesClosedFormEngine: European exercise only");
|
||||||
|
}
|
||||||
|
|
||||||
|
const double T = instrument.maturity();
|
||||||
|
const MarketData &md = process_->data();
|
||||||
|
const double S = md.spot();
|
||||||
|
double K = instrument.payoff().strike();
|
||||||
|
const PayoffKind pk = instrument.payoff().kind();
|
||||||
|
|
||||||
|
if (T <= 0.0) {
|
||||||
|
return instrument.payoff()(S);
|
||||||
|
}
|
||||||
|
|
||||||
|
const double r = md.yield_curve().zeroRate(T);
|
||||||
|
const double sigma = md.volatility_surface().sigma(K, T);
|
||||||
|
if (sigma <= 0.0) {
|
||||||
|
throw std::invalid_argument("BlackScholesClosedFormEngine: volatility must be positive");
|
||||||
|
}
|
||||||
|
|
||||||
|
const double disc = md.yield_curve().discount(T);
|
||||||
|
const double sqrtT = std::sqrt(T);
|
||||||
|
const double sig_sqrtT = sigma * sqrtT;
|
||||||
|
|
||||||
|
if (sig_sqrtT < 1e-15) {
|
||||||
|
const double forward = S * std::exp(r * T);
|
||||||
|
switch (pk) {
|
||||||
|
case PayoffKind::Call:
|
||||||
|
return disc * std::max(0.0, forward - K);
|
||||||
|
case PayoffKind::Put:
|
||||||
|
return disc * std::max(0.0, K - forward);
|
||||||
|
case PayoffKind::Digital:
|
||||||
|
return (forward > K) ? disc : 0.0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const double d1 = (std::log(S / K) + (r + 0.5 * sigma * sigma) * T) / sig_sqrtT;
|
||||||
|
const double d2 = d1 - sig_sqrtT;
|
||||||
|
|
||||||
|
switch (pk) {
|
||||||
|
case PayoffKind::Call:
|
||||||
|
return S * norm_cdf(d1) - K * disc * norm_cdf(d2);
|
||||||
|
case PayoffKind::Put:
|
||||||
|
return K * disc * norm_cdf(-d2) - S * norm_cdf(-d1);
|
||||||
|
case PayoffKind::Digital:
|
||||||
|
return disc * norm_cdf(d2);
|
||||||
|
}
|
||||||
|
throw std::logic_error("BlackScholesClosedFormEngine: unhandled PayoffKind");
|
||||||
|
}
|
||||||
22
cpp/BlackScholesClosedFormEngine.hpp
Normal file
22
cpp/BlackScholesClosedFormEngine.hpp
Normal file
@@ -0,0 +1,22 @@
|
|||||||
|
/**
|
||||||
|
* @file BlackScholesClosedFormEngine.hpp
|
||||||
|
* @brief Risk-neutral Black–Scholes formula for European payoffs under GBM (flat or surface inputs via @ref MarketData).
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef QUANTENGINE_BLACKSCHOLESCLOSEDFORMENGINE_HPP
|
||||||
|
#define QUANTENGINE_BLACKSCHOLESCLOSEDFORMENGINE_HPP
|
||||||
|
|
||||||
|
#include "PricingEngine.hpp"
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Analytic European vanilla / digital prices using @f$r@f$ and @f$\sigma(K,T)@f$ from the embedded process’s @ref MarketData.
|
||||||
|
*/
|
||||||
|
class BlackScholesClosedFormEngine : public PricingEngine {
|
||||||
|
public:
|
||||||
|
explicit BlackScholesClosedFormEngine(std::unique_ptr<StochasticProcess> process)
|
||||||
|
: PricingEngine(std::move(process)) {}
|
||||||
|
|
||||||
|
double calculate(const Instrument &instrument) const override;
|
||||||
|
};
|
||||||
|
|
||||||
|
#endif // QUANTENGINE_BLACKSCHOLESCLOSEDFORMENGINE_HPP
|
||||||
@@ -1,6 +1,7 @@
|
|||||||
//
|
/**
|
||||||
// Created by David Doebel on 06.03.2026.
|
* @file BlackScholesProcess.cpp
|
||||||
//
|
* @brief Black–Scholes GBM drift, diffusion, and step.
|
||||||
|
*/
|
||||||
|
|
||||||
#include "BlackScholesProcess.hpp"
|
#include "BlackScholesProcess.hpp"
|
||||||
|
|
||||||
@@ -1,12 +1,15 @@
|
|||||||
//
|
/**
|
||||||
// Created by David Doebel on 06.03.2026.
|
* @file BlackScholesProcess.hpp
|
||||||
//
|
* @brief Geometric Brownian motion with yield and volatility surfaces.
|
||||||
|
*/
|
||||||
|
|
||||||
#ifndef QUANTENGINE_BLACKSCHOLESPROCESS_HPP
|
#ifndef QUANTENGINE_BLACKSCHOLESPROCESS_HPP
|
||||||
#define QUANTENGINE_BLACKSCHOLESPROCESS_HPP
|
#define QUANTENGINE_BLACKSCHOLESPROCESS_HPP
|
||||||
#include "StochasticProcess.hpp"
|
#include "StochasticProcess.hpp"
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief GBM: drift @f$r_t S@f$, diffusion @f$\sigma(S,t) S@f$, exact log-step.
|
||||||
|
*/
|
||||||
class BlackScholesProcess : public StochasticProcess{
|
class BlackScholesProcess : public StochasticProcess{
|
||||||
public:
|
public:
|
||||||
explicit BlackScholesProcess(MarketData data) : StochasticProcess(std::move(data)){}
|
explicit BlackScholesProcess(MarketData data) : StochasticProcess(std::move(data)){}
|
||||||
65
cpp/CMakeLists.txt
Normal file
65
cpp/CMakeLists.txt
Normal file
@@ -0,0 +1,65 @@
|
|||||||
|
add_library(qengine_core
|
||||||
|
Instrument.cpp
|
||||||
|
Instrument.hpp
|
||||||
|
Payoff.cpp
|
||||||
|
Payoff.hpp
|
||||||
|
Option.cpp
|
||||||
|
Option.hpp
|
||||||
|
PricingEngine.cpp
|
||||||
|
PricingEngine.hpp
|
||||||
|
MonteCarloEngine.cpp
|
||||||
|
MonteCarloEngine.hpp
|
||||||
|
StochasticProcess.cpp
|
||||||
|
StochasticProcess.hpp
|
||||||
|
Exercise.cpp
|
||||||
|
Exercise.hpp
|
||||||
|
MarketData.cpp
|
||||||
|
MarketData.hpp
|
||||||
|
YieldCurve.cpp
|
||||||
|
YieldCurve.hpp
|
||||||
|
VolatilitySurface.cpp
|
||||||
|
VolatilitySurface.hpp
|
||||||
|
RandomGenerator.cpp
|
||||||
|
RandomGenerator.hpp
|
||||||
|
Statistics.cpp
|
||||||
|
Statistics.hpp
|
||||||
|
BlackScholesClosedFormEngine.cpp
|
||||||
|
BlackScholesClosedFormEngine.hpp
|
||||||
|
BlackScholesProcess.cpp
|
||||||
|
BlackScholesProcess.hpp
|
||||||
|
DBIngest.cpp
|
||||||
|
DBIngest.hpp
|
||||||
|
BSWrapper.cpp
|
||||||
|
BSWrapper.hpp
|
||||||
|
NewtonSolver.cpp
|
||||||
|
NewtonSolver.hpp
|
||||||
|
)
|
||||||
|
|
||||||
|
target_include_directories(qengine_core PUBLIC ${CMAKE_CURRENT_SOURCE_DIR})
|
||||||
|
target_include_directories(qengine_core PRIVATE
|
||||||
|
/opt/homebrew/include
|
||||||
|
)
|
||||||
|
|
||||||
|
find_library(PQXX_LIB pqxx PATHS /opt/homebrew/lib /usr/local/lib /usr/lib)
|
||||||
|
find_library(PQ_LIB pq PATHS /opt/homebrew/opt/libpq/lib /opt/homebrew/lib /usr/local/lib /usr/lib)
|
||||||
|
if(NOT PQXX_LIB OR NOT PQ_LIB)
|
||||||
|
message(FATAL_ERROR "Could not find libpqxx and/or libpq (install via Homebrew: brew install libpqxx libpq)")
|
||||||
|
endif()
|
||||||
|
|
||||||
|
target_link_libraries(qengine_core Eigen3::Eigen)
|
||||||
|
target_link_libraries(qengine_core ${PQXX_LIB} ${PQ_LIB})
|
||||||
|
|
||||||
|
# Python import path: package qengine, extension submodule qengine (file qengine/qengine*.so)
|
||||||
|
pybind11_add_module(qengine_cpp MODULE ImpliedVolatility/Pybind.cpp)
|
||||||
|
set_target_properties(qengine_cpp PROPERTIES OUTPUT_NAME qengine)
|
||||||
|
target_link_libraries(qengine_cpp PRIVATE qengine_core)
|
||||||
|
|
||||||
|
# Place the module next to qengine/__init__.py so `import qengine` works from the repo root
|
||||||
|
set(_qengine_py_pkg "${CMAKE_SOURCE_DIR}/qengine")
|
||||||
|
set_target_properties(qengine_cpp PROPERTIES
|
||||||
|
LIBRARY_OUTPUT_DIRECTORY "${_qengine_py_pkg}"
|
||||||
|
LIBRARY_OUTPUT_DIRECTORY_RELEASE "${_qengine_py_pkg}"
|
||||||
|
LIBRARY_OUTPUT_DIRECTORY_DEBUG "${_qengine_py_pkg}"
|
||||||
|
RUNTIME_OUTPUT_DIRECTORY "${_qengine_py_pkg}"
|
||||||
|
RUNTIME_OUTPUT_DIRECTORY_RELEASE "${_qengine_py_pkg}"
|
||||||
|
RUNTIME_OUTPUT_DIRECTORY_DEBUG "${_qengine_py_pkg}")
|
||||||
64
cpp/DBIngest.cpp
Normal file
64
cpp/DBIngest.cpp
Normal file
@@ -0,0 +1,64 @@
|
|||||||
|
/**
|
||||||
|
* @file DBIngest.cpp
|
||||||
|
* @brief Database connection and placeholder update routines.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include "DBIngest.hpp"
|
||||||
|
|
||||||
|
#include <cstdlib>
|
||||||
|
#include <iostream>
|
||||||
|
#include <sstream>
|
||||||
|
|
||||||
|
bool DBIngest::connect() {
|
||||||
|
const char* db_name = std::getenv("DB_NAME");
|
||||||
|
const char* db_user = std::getenv("DB_USER");
|
||||||
|
const char* db_password = std::getenv("DB_PASSWORD");
|
||||||
|
const char* db_host = std::getenv("DB_HOST");
|
||||||
|
const char* db_port = std::getenv("DB_PORT");
|
||||||
|
|
||||||
|
std::ostringstream conn_str;
|
||||||
|
conn_str
|
||||||
|
<< "dbname=" << (db_name ? db_name : "options_db")
|
||||||
|
<< " user=" << (db_user ? db_user : "quant_user")
|
||||||
|
<< " host=" << (db_host ? db_host : "localhost")
|
||||||
|
<< " port=" << (db_port ? db_port : "5432")
|
||||||
|
<< " password=" << (db_password ? db_password : "");
|
||||||
|
|
||||||
|
connection_ = pqxx::connection(conn_str.str());
|
||||||
|
|
||||||
|
if(connection_.is_open()) {
|
||||||
|
std::cout << "Connected\n";
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
std::cout << "Not connected\n";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool DBIngest::disconnect() {
|
||||||
|
connection_.close();
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool DBIngest::update(VolatilitySurface &surface) {
|
||||||
|
std::string vol_surface_query = "SELECT c.strike, c.expiration_date, q.mid, u.price "
|
||||||
|
"FROM option_quotes q"
|
||||||
|
"JOIN option_contracts c "
|
||||||
|
"ON q.contract_id = c.id "
|
||||||
|
"JOIN underlying_prices u"
|
||||||
|
"ON u.underlying_id = c.underlying_id"
|
||||||
|
"WHERE q.timestamp = ("
|
||||||
|
"SELECT MAX(timestamp) FROM option_quotes"
|
||||||
|
")";
|
||||||
|
pqxx::work work(connection_);
|
||||||
|
pqxx::result result = work.exec(vol_surface_query);
|
||||||
|
for (auto row : result) {
|
||||||
|
std::cout << row[0] << " " << row[1] << " " << row[2] << " " << row[3] << std::endl;
|
||||||
|
}
|
||||||
|
(void)surface;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool DBIngest::update(YieldCurve &yield_curve) {
|
||||||
|
(void)yield_curve;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
28
cpp/DBIngest.hpp
Normal file
28
cpp/DBIngest.hpp
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
/**
|
||||||
|
* @file DBIngest.hpp
|
||||||
|
* @brief PostgreSQL helpers to load market objects (work in progress).
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef QUANTENGINE_DBINGEST_HPP
|
||||||
|
#define QUANTENGINE_DBINGEST_HPP
|
||||||
|
|
||||||
|
#include <pqxx/pqxx>
|
||||||
|
|
||||||
|
#include "VolatilitySurface.hpp"
|
||||||
|
#include "YieldCurve.hpp"
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Connects to Postgres via libpqxx and queries quotes for surface building.
|
||||||
|
*/
|
||||||
|
class DBIngest {
|
||||||
|
|
||||||
|
bool connect();
|
||||||
|
bool disconnect();
|
||||||
|
bool update(VolatilitySurface& surface);
|
||||||
|
bool update(YieldCurve& yield_curve);
|
||||||
|
private:
|
||||||
|
pqxx::connection connection_;
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
#endif //QUANTENGINE_DBINGEST_HPP
|
||||||
6
cpp/Exercise.cpp
Normal file
6
cpp/Exercise.cpp
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
/**
|
||||||
|
* @file Exercise.cpp
|
||||||
|
* @brief @ref Exercise translation unit (interface only).
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include "Exercise.hpp"
|
||||||
@@ -1,11 +1,15 @@
|
|||||||
//
|
/**
|
||||||
// Created by David Doebel on 05.03.2026.
|
* @file Exercise.hpp
|
||||||
//
|
* @brief Exercise style (European, American, Bermudan) and exercise times.
|
||||||
|
*/
|
||||||
|
|
||||||
#ifndef QUANTENGINE_EXERCISE_HPP
|
#ifndef QUANTENGINE_EXERCISE_HPP
|
||||||
#define QUANTENGINE_EXERCISE_HPP
|
#define QUANTENGINE_EXERCISE_HPP
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Describes when the holder may exercise (metadata for pricing engines).
|
||||||
|
*/
|
||||||
class Exercise {
|
class Exercise {
|
||||||
public:
|
public:
|
||||||
Exercise() = default;
|
Exercise() = default;
|
||||||
@@ -22,7 +26,9 @@ protected:
|
|||||||
|
|
||||||
};
|
};
|
||||||
|
|
||||||
|
/** @brief Single exercise at maturity. */
|
||||||
class EuropeanExercise : public Exercise {
|
class EuropeanExercise : public Exercise {
|
||||||
|
public:
|
||||||
EuropeanExercise() : type_(Type::European) {};
|
EuropeanExercise() : type_(Type::European) {};
|
||||||
EuropeanExercise(double maturity) : type_(Type::European){
|
EuropeanExercise(double maturity) : type_(Type::European){
|
||||||
exercise_times_.push_back(maturity);
|
exercise_times_.push_back(maturity);
|
||||||
@@ -35,7 +41,9 @@ private:
|
|||||||
Type type_;
|
Type type_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
/** @brief Continuous American exercise from @f$t=0@f$ to maturity (placeholder grid). */
|
||||||
class AmericanExercise : public Exercise{
|
class AmericanExercise : public Exercise{
|
||||||
|
public:
|
||||||
AmericanExercise() : type_(Type::American) {};
|
AmericanExercise() : type_(Type::American) {};
|
||||||
AmericanExercise(double maturity) : type_(Type::American) {
|
AmericanExercise(double maturity) : type_(Type::American) {
|
||||||
exercise_times_.push_back(0);
|
exercise_times_.push_back(0);
|
||||||
5
cpp/FlatVolatilitySurface.cpp
Normal file
5
cpp/FlatVolatilitySurface.cpp
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
/**
|
||||||
|
* @file FlatVolatilitySurface.cpp
|
||||||
|
* @brief Ensures link visibility for @ref FlatVolatilitySurface.
|
||||||
|
*/
|
||||||
|
#include "FlatVolatilitySurface.hpp"
|
||||||
21
cpp/FlatVolatilitySurface.hpp
Normal file
21
cpp/FlatVolatilitySurface.hpp
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
/**
|
||||||
|
* @file FlatVolatilitySurface.hpp
|
||||||
|
* @brief Constant implied volatility surface.
|
||||||
|
*/
|
||||||
|
#ifndef QUANTENGINE_FLATVOLATILITYSURFACE_HPP
|
||||||
|
#define QUANTENGINE_FLATVOLATILITYSURFACE_HPP
|
||||||
|
#include "VolatilitySurface.hpp"
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief @f$\sigma(K,T)\equiv\sigma_0@f$.
|
||||||
|
*/
|
||||||
|
class FlatVolatilitySurface : public VolatilitySurface {
|
||||||
|
public:
|
||||||
|
explicit FlatVolatilitySurface(double sigma = 0.2) : sigma_(sigma) {}
|
||||||
|
|
||||||
|
double sigma(double K, double T) const override {return sigma_;}
|
||||||
|
|
||||||
|
private:
|
||||||
|
double sigma_;
|
||||||
|
};
|
||||||
|
#endif
|
||||||
5
cpp/FlatYieldCurve.cpp
Normal file
5
cpp/FlatYieldCurve.cpp
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
/**
|
||||||
|
* @file FlatYieldCurve.cpp
|
||||||
|
* @brief Ensures link visibility for @ref FlatYieldCurve (inline methods in header).
|
||||||
|
*/
|
||||||
|
#include "FlatYieldCurve.hpp"
|
||||||
22
cpp/FlatYieldCurve.hpp
Normal file
22
cpp/FlatYieldCurve.hpp
Normal file
@@ -0,0 +1,22 @@
|
|||||||
|
/**
|
||||||
|
* @file FlatYieldCurve.hpp
|
||||||
|
* @brief Constant zero rate yield curve.
|
||||||
|
*/
|
||||||
|
#ifndef QUANTENGINE_FLATYIELDCURVE_HPP
|
||||||
|
#define QUANTENGINE_FLATYIELDCURVE_HPP
|
||||||
|
#include "YieldCurve.hpp"
|
||||||
|
#include <cmath>
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief @f$P(t)=e^{-r t}@f$, @f$f(t)\equiv r@f$.
|
||||||
|
*/
|
||||||
|
class FlatYieldCurve : public YieldCurve{
|
||||||
|
public:
|
||||||
|
explicit FlatYieldCurve(double rate = 0.01) : rate_(rate) {}
|
||||||
|
|
||||||
|
double discount(double t) const override {return std::exp(-rate_ * t); };
|
||||||
|
double zeroRate(double t) const override {return rate_; }
|
||||||
|
private:
|
||||||
|
double rate_ = 0.01;
|
||||||
|
};
|
||||||
|
#endif
|
||||||
93
cpp/ImpliedVolatility/Pybind.cpp
Normal file
93
cpp/ImpliedVolatility/Pybind.cpp
Normal file
@@ -0,0 +1,93 @@
|
|||||||
|
/**
|
||||||
|
* @file Pybind.cpp
|
||||||
|
* @brief pybind11 module @c qengine exposing @ref BSWrapper::bs_price_wrapper overloads.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include <pybind11/numpy.h>
|
||||||
|
#include <pybind11/pybind11.h>
|
||||||
|
#include <pybind11/stl.h>
|
||||||
|
|
||||||
|
#include <cstdint>
|
||||||
|
#include <stdexcept>
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "BSWrapper.hpp"
|
||||||
|
|
||||||
|
namespace py = pybind11;
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
std::vector<double> to_vector_double(const py::array_t<double> &a) {
|
||||||
|
py::buffer_info info = a.request();
|
||||||
|
if (info.ndim != 1) {
|
||||||
|
throw std::runtime_error("expected 1-D ndarray for S, K, T, r, sigma");
|
||||||
|
}
|
||||||
|
const auto *p = static_cast<const double *>(info.ptr);
|
||||||
|
const ssize_t n = info.shape[0];
|
||||||
|
return std::vector<double>(p, p + n);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<bool> to_vector_bool_1d(const py::array_t<bool> &a) {
|
||||||
|
py::buffer_info info = a.request();
|
||||||
|
if (info.ndim != 1) {
|
||||||
|
throw std::runtime_error("expected 1-D ndarray for is_call");
|
||||||
|
}
|
||||||
|
if (info.itemsize != 1) {
|
||||||
|
throw std::runtime_error("is_call: expected a boolean ndarray (dtype=bool)");
|
||||||
|
}
|
||||||
|
const ssize_t n = info.shape[0];
|
||||||
|
const auto *p = static_cast<const std::uint8_t *>(info.ptr);
|
||||||
|
std::vector<bool> out(static_cast<size_t>(n));
|
||||||
|
for (ssize_t i = 0; i < n; ++i) {
|
||||||
|
out[static_cast<size_t>(i)] = (p[i] != 0);
|
||||||
|
}
|
||||||
|
return out;
|
||||||
|
}
|
||||||
|
|
||||||
|
void check_same_length(size_t n, size_t k, const char *name) {
|
||||||
|
if (n != k) {
|
||||||
|
throw std::runtime_error(std::string("length mismatch for ") + name);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
PYBIND11_MODULE(qengine, m) {
|
||||||
|
m.doc() = "Binding for the Black Scholes model";
|
||||||
|
|
||||||
|
m.def(
|
||||||
|
"bs_price",
|
||||||
|
[](double S, double K, double T, double r, double sigma, bool is_call) {
|
||||||
|
return BSWrapper::bs_price_wrapper(S, K, T, r, sigma, is_call);
|
||||||
|
},
|
||||||
|
py::arg("S"), py::arg("K"), py::arg("T"), py::arg("r"), py::arg("sigma"), py::arg("is_call"));
|
||||||
|
|
||||||
|
m.def(
|
||||||
|
"bs_price",
|
||||||
|
[](py::array_t<double> S, py::array_t<double> K, py::array_t<double> T, py::array_t<double> r,
|
||||||
|
py::array_t<double> sigma, py::array_t<bool> is_call) {
|
||||||
|
std::vector<double> vS = to_vector_double(S);
|
||||||
|
std::vector<double> vK = to_vector_double(K);
|
||||||
|
std::vector<double> vT = to_vector_double(T);
|
||||||
|
std::vector<double> vr = to_vector_double(r);
|
||||||
|
std::vector<double> vsig = to_vector_double(sigma);
|
||||||
|
std::vector<bool> vC = to_vector_bool_1d(is_call);
|
||||||
|
const size_t n = vS.size();
|
||||||
|
check_same_length(n, vK.size(), "K");
|
||||||
|
check_same_length(n, vT.size(), "T");
|
||||||
|
check_same_length(n, vr.size(), "r");
|
||||||
|
check_same_length(n, vsig.size(), "sigma");
|
||||||
|
check_same_length(n, vC.size(), "is_call");
|
||||||
|
return BSWrapper::batch_bs_price_wrapper(vS, vK, vT, vr, vsig, vC);
|
||||||
|
},
|
||||||
|
py::arg("S"), py::arg("K"), py::arg("T"), py::arg("r"), py::arg("sigma"), py::arg("is_call"));
|
||||||
|
|
||||||
|
m.def(
|
||||||
|
"bs_price",
|
||||||
|
[](const std::vector<double> &S, const std::vector<double> &K, const std::vector<double> &T,
|
||||||
|
const std::vector<double> &r, const std::vector<double> &sigma, const std::vector<bool> &is_call) {
|
||||||
|
return BSWrapper::batch_bs_price_wrapper(S, K, T, r, sigma, is_call);
|
||||||
|
},
|
||||||
|
py::arg("S"), py::arg("K"), py::arg("T"), py::arg("r"), py::arg("sigma"), py::arg("is_call"));
|
||||||
|
}
|
||||||
@@ -1,6 +1,7 @@
|
|||||||
//
|
/**
|
||||||
// Created by David Doebel on 05.03.2026.
|
* @file Instrument.cpp
|
||||||
//
|
* @brief @ref Instrument implementation.
|
||||||
|
*/
|
||||||
|
|
||||||
#include "Instrument.hpp"
|
#include "Instrument.hpp"
|
||||||
|
|
||||||
@@ -1,15 +1,20 @@
|
|||||||
//
|
/**
|
||||||
// Created by David Doebel on 05.03.2026.
|
* @file Instrument.hpp
|
||||||
//
|
* @brief Generic derivative instrument: payoff plus pricing engine.
|
||||||
|
*/
|
||||||
|
|
||||||
#ifndef QUANTENGINE_INSTRUMENT_HPP
|
#ifndef QUANTENGINE_INSTRUMENT_HPP
|
||||||
#define QUANTENGINE_INSTRUMENT_HPP
|
#define QUANTENGINE_INSTRUMENT_HPP
|
||||||
|
#include "Exercise.hpp"
|
||||||
#include "Payoff.hpp"
|
#include "Payoff.hpp"
|
||||||
#include "PricingEngine.hpp"
|
#include "PricingEngine.hpp"
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
|
||||||
class PricingEngine;
|
class PricingEngine;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Represents a tradeable claim priced via a @ref PricingEngine.
|
||||||
|
*/
|
||||||
class Instrument {
|
class Instrument {
|
||||||
public:
|
public:
|
||||||
Instrument() = default;
|
Instrument() = default;
|
||||||
@@ -24,6 +29,9 @@ public:
|
|||||||
return *payoff_;
|
return *payoff_;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/** @brief Base @ref Instrument is treated as European unless overridden by @ref Option. */
|
||||||
|
[[nodiscard]] virtual Exercise::Type exerciseType() const { return Exercise::Type::European; }
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
double maturity_;
|
double maturity_;
|
||||||
std::unique_ptr<Payoff> payoff_;
|
std::unique_ptr<Payoff> payoff_;
|
||||||
@@ -1,6 +1,7 @@
|
|||||||
//
|
/**
|
||||||
// Created by David Doebel on 06.03.2026.
|
* @file MarketData.cpp
|
||||||
//
|
* @brief @ref MarketData accessors.
|
||||||
|
*/
|
||||||
|
|
||||||
#include "MarketData.hpp"
|
#include "MarketData.hpp"
|
||||||
|
|
||||||
@@ -1,6 +1,7 @@
|
|||||||
//
|
/**
|
||||||
// Created by David Doebel on 06.03.2026.
|
* @file MarketData.hpp
|
||||||
//
|
* @brief Spot, discount curve, and volatility surface bundle.
|
||||||
|
*/
|
||||||
|
|
||||||
#ifndef QUANTENGINE_MARKETDATA_HPP
|
#ifndef QUANTENGINE_MARKETDATA_HPP
|
||||||
#define QUANTENGINE_MARKETDATA_HPP
|
#define QUANTENGINE_MARKETDATA_HPP
|
||||||
@@ -8,6 +9,9 @@
|
|||||||
#include "VolatilitySurface.hpp"
|
#include "VolatilitySurface.hpp"
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Immutable snapshot of inputs needed to simulate or price.
|
||||||
|
*/
|
||||||
class MarketData {
|
class MarketData {
|
||||||
public:
|
public:
|
||||||
MarketData() = delete;
|
MarketData() = delete;
|
||||||
@@ -1,6 +1,7 @@
|
|||||||
//
|
/**
|
||||||
// Created by David Doebel on 05.03.2026.
|
* @file MonteCarloEngine.cpp
|
||||||
//
|
* @brief Monte Carlo mean estimator with discounting.
|
||||||
|
*/
|
||||||
|
|
||||||
#include "MonteCarloEngine.hpp"
|
#include "MonteCarloEngine.hpp"
|
||||||
#include <iostream>
|
#include <iostream>
|
||||||
@@ -1,13 +1,16 @@
|
|||||||
//
|
/**
|
||||||
// Created by David Doebel on 05.03.2026.
|
* @file MonteCarloEngine.hpp
|
||||||
//
|
* @brief Monte Carlo pricing using a @ref StochasticProcess and @ref RandomGenerator.
|
||||||
|
*/
|
||||||
|
|
||||||
#ifndef QUANTENGINE_MONTECARLOENGINE_HPP
|
#ifndef QUANTENGINE_MONTECARLOENGINE_HPP
|
||||||
#define QUANTENGINE_MONTECARLOENGINE_HPP
|
#define QUANTENGINE_MONTECARLOENGINE_HPP
|
||||||
#include "PricingEngine.hpp"
|
#include "PricingEngine.hpp"
|
||||||
#include "RandomGenerator.hpp"
|
#include "RandomGenerator.hpp"
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Simple path simulation: one Euler/exact step to horizon, average discounted payoff.
|
||||||
|
*/
|
||||||
class MonteCarloEngine : public PricingEngine{
|
class MonteCarloEngine : public PricingEngine{
|
||||||
public:
|
public:
|
||||||
MonteCarloEngine() = default;
|
MonteCarloEngine() = default;
|
||||||
8
cpp/NewtonSolver.cpp
Normal file
8
cpp/NewtonSolver.cpp
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
/**
|
||||||
|
* @file NewtonSolver.cpp
|
||||||
|
* @brief Placeholder translation unit for @ref NewtonSolver.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include "NewtonSolver.hpp"
|
||||||
|
|
||||||
|
|
||||||
30
cpp/NewtonSolver.hpp
Normal file
30
cpp/NewtonSolver.hpp
Normal file
@@ -0,0 +1,30 @@
|
|||||||
|
/**
|
||||||
|
* @file NewtonSolver.hpp
|
||||||
|
* @brief Generic Newton iteration helper (incomplete / reserved for solvers).
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef QUANTENGINE_GAUSSSOLVER_HPP
|
||||||
|
#define QUANTENGINE_GAUSSSOLVER_HPP
|
||||||
|
|
||||||
|
#include <functional>
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Template Newton step loop with relative/absolute tolerances.
|
||||||
|
*/
|
||||||
|
class NewtonSolver {
|
||||||
|
template<typename F, typename DFinv, typename T>
|
||||||
|
bool solve(F&& func, DFinv&& dfinv,T x0 , double rtol, double atol) {
|
||||||
|
T x = x0;
|
||||||
|
int i = 0;
|
||||||
|
T increment;
|
||||||
|
do {
|
||||||
|
increment = dfinv(x) * func(x);
|
||||||
|
x -= increment;
|
||||||
|
++i;
|
||||||
|
} while (i < 1000 && std::abs(increment)/ std::abs(x) > rtol && std::abs(increment) > atol);
|
||||||
|
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
#endif //QUANTENGINE_GAUSSSOLVER_HPP
|
||||||
@@ -1,6 +1,7 @@
|
|||||||
//
|
/**
|
||||||
// Created by David Doebel on 05.03.2026.
|
* @file Option.cpp
|
||||||
//
|
* @brief @ref Option implementation.
|
||||||
|
*/
|
||||||
|
|
||||||
#include "Option.hpp"
|
#include "Option.hpp"
|
||||||
|
|
||||||
@@ -1,12 +1,16 @@
|
|||||||
//
|
/**
|
||||||
// Created by David Doebel on 05.03.2026.
|
* @file Option.hpp
|
||||||
//
|
* @brief Option instrument with exercise style (@ref Exercise).
|
||||||
|
*/
|
||||||
|
|
||||||
#ifndef QUANTENGINE_OPTION_HPP
|
#ifndef QUANTENGINE_OPTION_HPP
|
||||||
#define QUANTENGINE_OPTION_HPP
|
#define QUANTENGINE_OPTION_HPP
|
||||||
#include "Instrument.hpp"
|
#include "Instrument.hpp"
|
||||||
#include "Exercise.hpp"
|
#include "Exercise.hpp"
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Extends @ref Instrument with exercise schedule / style metadata.
|
||||||
|
*/
|
||||||
class Option : public Instrument{
|
class Option : public Instrument{
|
||||||
public:
|
public:
|
||||||
Option() = default;
|
Option() = default;
|
||||||
@@ -17,10 +21,13 @@ public:
|
|||||||
return *exercise_;
|
return *exercise_;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
[[nodiscard]] Exercise::Type exerciseType() const override { return exercise_->type(); }
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
std::unique_ptr<Exercise> exercise_;
|
std::unique_ptr<Exercise> exercise_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
/** @brief Plain-vanilla option using the base @ref Option constructor. */
|
||||||
class VanillaOption : public Option {
|
class VanillaOption : public Option {
|
||||||
public:
|
public:
|
||||||
using Option::Option;
|
using Option::Option;
|
||||||
@@ -1,6 +1,7 @@
|
|||||||
//
|
/**
|
||||||
// Created by David Doebel on 05.03.2026.
|
* @file Payoff.cpp
|
||||||
//
|
* @brief Payoff function implementations.
|
||||||
|
*/
|
||||||
|
|
||||||
#include "Payoff.hpp"
|
#include "Payoff.hpp"
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
@@ -1,11 +1,19 @@
|
|||||||
//
|
/**
|
||||||
// Created by David Doebel on 05.03.2026.
|
* @file Payoff.hpp
|
||||||
//
|
* @brief Payoff interface and standard European payoffs (call, put, digital).
|
||||||
|
*/
|
||||||
|
|
||||||
#ifndef QUANTENGINE_PAYOFF_HPP
|
#ifndef QUANTENGINE_PAYOFF_HPP
|
||||||
#define QUANTENGINE_PAYOFF_HPP
|
#define QUANTENGINE_PAYOFF_HPP
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Standard payoff shapes for routing (e.g. analytic vs Monte Carlo).
|
||||||
|
*/
|
||||||
|
enum class PayoffKind { Call, Put, Digital };
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Terminal payoff as a function of spot @f$S_T@f$.
|
||||||
|
*/
|
||||||
class Payoff {
|
class Payoff {
|
||||||
public:
|
public:
|
||||||
|
|
||||||
@@ -14,35 +22,42 @@ public:
|
|||||||
virtual ~Payoff() = default;
|
virtual ~Payoff() = default;
|
||||||
virtual double operator()(double S) = 0;
|
virtual double operator()(double S) = 0;
|
||||||
virtual double strike() = 0;
|
virtual double strike() = 0;
|
||||||
|
[[nodiscard]] virtual PayoffKind kind() const = 0;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
/** @brief Standard European call @f$\max(S-K,0)@f$. */
|
||||||
class CallPayoff : public Payoff {
|
class CallPayoff : public Payoff {
|
||||||
public:
|
public:
|
||||||
CallPayoff() = default;
|
CallPayoff() = default;
|
||||||
CallPayoff(double strike) : strike_(strike) {}
|
CallPayoff(double strike) : strike_(strike) {}
|
||||||
double operator()(double S) override;
|
double operator()(double S) override;
|
||||||
double strike() override {return strike_;}
|
double strike() override {return strike_;}
|
||||||
|
[[nodiscard]] PayoffKind kind() const override { return PayoffKind::Call; }
|
||||||
|
|
||||||
private:
|
private:
|
||||||
double strike_;
|
double strike_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
/** @brief Standard European put @f$\max(K-S,0)@f$. */
|
||||||
class PutPayoff : public Payoff {
|
class PutPayoff : public Payoff {
|
||||||
public:
|
public:
|
||||||
PutPayoff() = default;
|
PutPayoff() = default;
|
||||||
PutPayoff(double strike) : strike_(strike) {}
|
PutPayoff(double strike) : strike_(strike) {}
|
||||||
double operator()(double S) override;
|
double operator()(double S) override;
|
||||||
double strike() override {return strike_;}
|
double strike() override {return strike_;}
|
||||||
|
[[nodiscard]] PayoffKind kind() const override { return PayoffKind::Put; }
|
||||||
private:
|
private:
|
||||||
double strike_;
|
double strike_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
/** @brief Digital (cash-or-nothing style) payoff @f$1_{S>K}@f$. */
|
||||||
class DigitalPayoff : public Payoff {
|
class DigitalPayoff : public Payoff {
|
||||||
public:
|
public:
|
||||||
DigitalPayoff() = default;
|
DigitalPayoff() = default;
|
||||||
DigitalPayoff(double strike) : strike_(strike) {}
|
DigitalPayoff(double strike) : strike_(strike) {}
|
||||||
double operator()(double S) override;
|
double operator()(double S) override;
|
||||||
double strike() override {return strike_;}
|
double strike() override {return strike_;}
|
||||||
|
[[nodiscard]] PayoffKind kind() const override { return PayoffKind::Digital; }
|
||||||
private:
|
private:
|
||||||
double strike_;
|
double strike_;
|
||||||
};
|
};
|
||||||
6
cpp/PricingEngine.cpp
Normal file
6
cpp/PricingEngine.cpp
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
/**
|
||||||
|
* @file PricingEngine.cpp
|
||||||
|
* @brief @ref PricingEngine translation unit (interface only).
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include "PricingEngine.hpp"
|
||||||
@@ -1,6 +1,7 @@
|
|||||||
//
|
/**
|
||||||
// Created by David Doebel on 05.03.2026.
|
* @file PricingEngine.hpp
|
||||||
//
|
* @brief Abstract pricer for @ref Instrument given a stochastic model.
|
||||||
|
*/
|
||||||
|
|
||||||
#ifndef QUANTENGINE_PRICINGENGINE_HPP
|
#ifndef QUANTENGINE_PRICINGENGINE_HPP
|
||||||
#define QUANTENGINE_PRICINGENGINE_HPP
|
#define QUANTENGINE_PRICINGENGINE_HPP
|
||||||
@@ -10,6 +11,9 @@
|
|||||||
|
|
||||||
class Instrument;
|
class Instrument;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Computes model price of an instrument (e.g. Monte Carlo, PDE, closed form).
|
||||||
|
*/
|
||||||
class PricingEngine {
|
class PricingEngine {
|
||||||
public:
|
public:
|
||||||
PricingEngine() = default;
|
PricingEngine() = default;
|
||||||
@@ -1,6 +1,7 @@
|
|||||||
//
|
/**
|
||||||
// Created by David Doebel on 06.03.2026.
|
* @file RandomGenerator.cpp
|
||||||
//
|
* @brief @ref MersenneTwister implementation.
|
||||||
|
*/
|
||||||
|
|
||||||
#include "RandomGenerator.hpp"
|
#include "RandomGenerator.hpp"
|
||||||
|
|
||||||
@@ -1,11 +1,13 @@
|
|||||||
//
|
/**
|
||||||
// Created by David Doebel on 06.03.2026.
|
* @file RandomGenerator.hpp
|
||||||
//
|
* @brief Random numbers for Monte Carlo (Gaussian draws).
|
||||||
|
*/
|
||||||
|
|
||||||
#ifndef QUANTENGINE_RANDOMGENERATOR_HPP
|
#ifndef QUANTENGINE_RANDOMGENERATOR_HPP
|
||||||
#define QUANTENGINE_RANDOMGENERATOR_HPP
|
#define QUANTENGINE_RANDOMGENERATOR_HPP
|
||||||
#include <random>
|
#include <random>
|
||||||
|
|
||||||
|
/** @brief Interface for standard normal variates. */
|
||||||
class RandomGenerator {
|
class RandomGenerator {
|
||||||
public:
|
public:
|
||||||
RandomGenerator() = default;
|
RandomGenerator() = default;
|
||||||
@@ -14,6 +16,7 @@ public:
|
|||||||
virtual std::vector<double> nextGaussianVector(std::size_t n) = 0;
|
virtual std::vector<double> nextGaussianVector(std::size_t n) = 0;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
/** @brief @c std::mt19937 with normal distribution. */
|
||||||
class MersenneTwister : public RandomGenerator {
|
class MersenneTwister : public RandomGenerator {
|
||||||
public:
|
public:
|
||||||
MersenneTwister() = default;
|
MersenneTwister() = default;
|
||||||
@@ -1,6 +1,7 @@
|
|||||||
//
|
/**
|
||||||
// Created by David Doebel on 06.03.2026.
|
* @file Statistics.cpp
|
||||||
//
|
* @brief Streaming moment and extrema updates.
|
||||||
|
*/
|
||||||
|
|
||||||
#include "Statistics.hpp"
|
#include "Statistics.hpp"
|
||||||
|
|
||||||
@@ -1,12 +1,15 @@
|
|||||||
//
|
/**
|
||||||
// Created by David Doebel on 06.03.2026.
|
* @file Statistics.hpp
|
||||||
//
|
* @brief Online sample moments for Monte Carlo diagnostics.
|
||||||
|
*/
|
||||||
|
|
||||||
#ifndef QUANTENGINE_STATISTICS_HPP
|
#ifndef QUANTENGINE_STATISTICS_HPP
|
||||||
#define QUANTENGINE_STATISTICS_HPP
|
#define QUANTENGINE_STATISTICS_HPP
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Accumulates count, mean/variance-related sums, and running min/max.
|
||||||
|
*/
|
||||||
class Statistics {
|
class Statistics {
|
||||||
public:
|
public:
|
||||||
Statistics() : moments_({0., 0., 0.}), n(0), max_(0.), min_(0.) {}
|
Statistics() : moments_({0., 0., 0.}), n(0), max_(0.), min_(0.) {}
|
||||||
6
cpp/StochasticProcess.cpp
Normal file
6
cpp/StochasticProcess.cpp
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
/**
|
||||||
|
* @file StochasticProcess.cpp
|
||||||
|
* @brief @ref StochasticProcess translation unit (interface only).
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include "StochasticProcess.hpp"
|
||||||
@@ -1,12 +1,16 @@
|
|||||||
//
|
/**
|
||||||
// Created by David Doebel on 05.03.2026.
|
* @file StochasticProcess.hpp
|
||||||
//
|
* @brief Interface for SDE drift, diffusion, and time stepping.
|
||||||
|
*/
|
||||||
|
|
||||||
#ifndef QUANTENGINE_STOCHASTICPROCESS_HPP
|
#ifndef QUANTENGINE_STOCHASTICPROCESS_HPP
|
||||||
#define QUANTENGINE_STOCHASTICPROCESS_HPP
|
#define QUANTENGINE_STOCHASTICPROCESS_HPP
|
||||||
#include "MarketData.hpp"
|
#include "MarketData.hpp"
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Stochastic model for the underlying, driven by @ref MarketData.
|
||||||
|
*/
|
||||||
class StochasticProcess {
|
class StochasticProcess {
|
||||||
public:
|
public:
|
||||||
StochasticProcess() = delete;
|
StochasticProcess() = delete;
|
||||||
6
cpp/VolatilitySurface.cpp
Normal file
6
cpp/VolatilitySurface.cpp
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
/**
|
||||||
|
* @file VolatilitySurface.cpp
|
||||||
|
* @brief @ref VolatilitySurface translation unit (interface only).
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include "VolatilitySurface.hpp"
|
||||||
28
cpp/VolatilitySurface.hpp
Normal file
28
cpp/VolatilitySurface.hpp
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
/**
|
||||||
|
* @file VolatilitySurface.hpp
|
||||||
|
* @brief Implied volatility as a function of strike and expiry.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef QUANTENGINE_VOLATILITYSURFACE_HPP
|
||||||
|
#define QUANTENGINE_VOLATILITYSURFACE_HPP
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Local/vol surface @f$\sigma(K,T)@f$ used by simulation.
|
||||||
|
*/
|
||||||
|
class VolatilitySurface {
|
||||||
|
public:
|
||||||
|
virtual ~VolatilitySurface() = default;
|
||||||
|
virtual double sigma(double K, double T) const = 0;
|
||||||
|
private:
|
||||||
|
|
||||||
|
};
|
||||||
|
|
||||||
|
class SVI : public VolatilitySurface {
|
||||||
|
public:
|
||||||
|
SVI() = default;
|
||||||
|
SVI(std::vector<double> K, std::vector<double> rho, std::vector<double> S, std::vector<double> T);
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
#endif //QUANTENGINE_VOLATILITYSURFACE_HPP
|
||||||
6
cpp/YieldCurve.cpp
Normal file
6
cpp/YieldCurve.cpp
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
/**
|
||||||
|
* @file YieldCurve.cpp
|
||||||
|
* @brief @ref YieldCurve translation unit (interface only).
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include "YieldCurve.hpp"
|
||||||
@@ -1,11 +1,14 @@
|
|||||||
//
|
/**
|
||||||
// Created by David Doebel on 06.03.2026.
|
* @file YieldCurve.hpp
|
||||||
//
|
* @brief Abstract yield curve: discount factors and zero rates.
|
||||||
|
*/
|
||||||
|
|
||||||
#ifndef QUANTENGINE_YIELDCURVE_HPP
|
#ifndef QUANTENGINE_YIELDCURVE_HPP
|
||||||
#define QUANTENGINE_YIELDCURVE_HPP
|
#define QUANTENGINE_YIELDCURVE_HPP
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Risk-free rate term structure for discounting and risk-neutral drift.
|
||||||
|
*/
|
||||||
class YieldCurve {
|
class YieldCurve {
|
||||||
public:
|
public:
|
||||||
YieldCurve() = default;
|
YieldCurve() = default;
|
||||||
50
docs/Doxyfile
Normal file
50
docs/Doxyfile
Normal file
@@ -0,0 +1,50 @@
|
|||||||
|
# Doxygen configuration for QuantEngine (option_pricing).
|
||||||
|
# Run from repo root: doxygen docs/Doxyfile
|
||||||
|
# Or: cmake --build build --target docs
|
||||||
|
|
||||||
|
PROJECT_NAME = QuantEngine
|
||||||
|
PROJECT_BRIEF = "Monte Carlo option pricing, market data abstractions, and Python bindings"
|
||||||
|
|
||||||
|
OUTPUT_DIRECTORY = docs/html
|
||||||
|
CREATE_SUBDIRS = NO
|
||||||
|
ALLOW_UNICODE_NAMES = YES
|
||||||
|
|
||||||
|
JAVADOC_AUTOBRIEF = YES
|
||||||
|
QT_AUTOBRIEF = NO
|
||||||
|
OPTIMIZE_OUTPUT_FOR_CPLUSPLUS = YES
|
||||||
|
|
||||||
|
FULL_PATH_NAMES = YES
|
||||||
|
STRIP_FROM_PATH =
|
||||||
|
|
||||||
|
QUIET = NO
|
||||||
|
WARNINGS = YES
|
||||||
|
WARN_IF_UNDOCUMENTED = NO
|
||||||
|
WARN_NO_PARAMDOC = NO
|
||||||
|
|
||||||
|
INPUT = cpp
|
||||||
|
INPUT_ENCODING = UTF-8
|
||||||
|
FILE_PATTERNS = *.cpp *.hpp *.h
|
||||||
|
RECURSIVE = YES
|
||||||
|
|
||||||
|
EXCLUDE_PATTERNS =
|
||||||
|
EXCLUDE_SYMBOLS =
|
||||||
|
|
||||||
|
GENERATE_HTML = YES
|
||||||
|
HTML_OUTPUT = .
|
||||||
|
HTML_COLORSTYLE_HUE = 220
|
||||||
|
GENERATE_LATEX = NO
|
||||||
|
|
||||||
|
SEARCHENGINE = YES
|
||||||
|
|
||||||
|
SOURCE_BROWSER = YES
|
||||||
|
REFERENCED_BY_RELATION = YES
|
||||||
|
REFERENCES_RELATION = YES
|
||||||
|
|
||||||
|
ALPHABETICAL_INDEX = YES
|
||||||
|
ENABLE_PREPROCESSING = YES
|
||||||
|
MACRO_EXPANSION = NO
|
||||||
|
|
||||||
|
CLASS_DIAGRAMS = YES
|
||||||
|
HAVE_DOT = NO
|
||||||
|
|
||||||
|
PREDEFINED = DOXYGEN
|
||||||
27
docs/SECURITY.md
Normal file
27
docs/SECURITY.md
Normal file
@@ -0,0 +1,27 @@
|
|||||||
|
# Security Checklist
|
||||||
|
|
||||||
|
## Secrets handling
|
||||||
|
|
||||||
|
- Never commit `.env` or any file containing credentials.
|
||||||
|
- Use `.env.example` for non-sensitive defaults only.
|
||||||
|
- Set DB credentials through environment variables.
|
||||||
|
- Rotate credentials if they have ever appeared in git history.
|
||||||
|
|
||||||
|
## Database hardening
|
||||||
|
|
||||||
|
- Use a dedicated runtime user with least required privileges.
|
||||||
|
- Keep administrative users separate from ingestion users.
|
||||||
|
- Restrict DB network access to trusted hosts/VPC/private network.
|
||||||
|
- Enable SSL/TLS for non-local database connections.
|
||||||
|
|
||||||
|
## Publication readiness
|
||||||
|
|
||||||
|
Before making the repository public:
|
||||||
|
|
||||||
|
1. Confirm `git status` has no secret files staged.
|
||||||
|
2. Search for potential secret patterns:
|
||||||
|
- passwords
|
||||||
|
- API keys
|
||||||
|
- tokens
|
||||||
|
3. Verify `.gitignore` includes local secret files (`.env*`).
|
||||||
|
4. Regenerate credentials used during development.
|
||||||
60
docs/SETUP.md
Normal file
60
docs/SETUP.md
Normal file
@@ -0,0 +1,60 @@
|
|||||||
|
# Setup Guide
|
||||||
|
|
||||||
|
This guide describes a clean local setup for development and reproducible runs.
|
||||||
|
|
||||||
|
## Prerequisites
|
||||||
|
|
||||||
|
- Python 3.10+
|
||||||
|
- CMake 3.16+
|
||||||
|
- A C++20 compiler
|
||||||
|
- PostgreSQL 14+ (or Docker)
|
||||||
|
- On macOS, Homebrew packages for C++ DB support:
|
||||||
|
- `libpq`
|
||||||
|
- `libpqxx`
|
||||||
|
- `eigen`
|
||||||
|
- `pybind11`
|
||||||
|
|
||||||
|
## Python dependencies
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python3 -m venv .venv
|
||||||
|
source .venv/bin/activate
|
||||||
|
pip install --upgrade pip
|
||||||
|
pip install -e .
|
||||||
|
pip install pandas yfinance sqlalchemy psycopg2-binary matplotlib scipy
|
||||||
|
```
|
||||||
|
|
||||||
|
## Environment configuration
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cp .env.example .env
|
||||||
|
```
|
||||||
|
|
||||||
|
Edit `.env` and set:
|
||||||
|
|
||||||
|
- `DB_HOST`, `DB_PORT`, `DB_NAME`, `DB_USER`, `DB_PASSWORD`
|
||||||
|
- `PIPELINE_SYMBOLS`
|
||||||
|
- admin credentials used only by setup script (`POSTGRES_ADMIN_*`)
|
||||||
|
|
||||||
|
## Database bootstrap
|
||||||
|
|
||||||
|
```bash
|
||||||
|
source .env
|
||||||
|
python scripts/setup_postgres.py
|
||||||
|
```
|
||||||
|
|
||||||
|
The script is idempotent and safe to rerun.
|
||||||
|
|
||||||
|
## Build and test C++
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cmake -S . -B build
|
||||||
|
cmake --build build -j
|
||||||
|
ctest --test-dir build --output-on-failure
|
||||||
|
```
|
||||||
|
|
||||||
|
## Generate Doxygen docs
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cmake --build build --target docs
|
||||||
|
```
|
||||||
164
electricity_price_predictor/README.md
Normal file
164
electricity_price_predictor/README.md
Normal file
@@ -0,0 +1,164 @@
|
|||||||
|
# Electricity Price Predictor
|
||||||
|
|
||||||
|
Standalone module for ENTSO-E ingestion and feature-store creation in `quant_db`.
|
||||||
|
|
||||||
|
## Documentation
|
||||||
|
|
||||||
|
- High-level architecture: `docs/architecture.md`
|
||||||
|
- Detailed developer documentation (file-by-file + UML): `docs/developer_guide.md`
|
||||||
|
|
||||||
|
## What it builds
|
||||||
|
|
||||||
|
- Input columns:
|
||||||
|
- `day_ahead_price`
|
||||||
|
- `load_forecast`
|
||||||
|
- `wind_forecast`
|
||||||
|
- `solar_forecast`
|
||||||
|
- Derived columns:
|
||||||
|
- `residual_load`
|
||||||
|
- `lagged_price` (`t-1` ... `t-24`, stored as array length 24)
|
||||||
|
- `lagged_residual_load` (`t-1` ... `t-24`, stored as array length 24)
|
||||||
|
- `hour_of_day_sin`, `hour_of_day_cos`
|
||||||
|
- `weekday_sin`, `weekday_cos`
|
||||||
|
- `month_sin`, `month_cos`
|
||||||
|
|
||||||
|
## Column units (data dictionary)
|
||||||
|
|
||||||
|
All timestamps are hourly and stored in UTC (`delivery_start`).
|
||||||
|
|
||||||
|
- `day_ahead_price`
|
||||||
|
- Unit: `EUR/MWh` (euros per megawatt-hour).
|
||||||
|
- `load_forecast`
|
||||||
|
- Unit: `MW` (megawatts).
|
||||||
|
- `wind_forecast`
|
||||||
|
- Unit: `MW` (megawatts).
|
||||||
|
- `solar_forecast`
|
||||||
|
- Unit: `MW` (megawatts).
|
||||||
|
- `residual_load = load_forecast - wind_forecast - solar_forecast`
|
||||||
|
- Unit: `MW` (megawatts).
|
||||||
|
- `lagged_price` (array of 24 values for `t-1..t-24`)
|
||||||
|
- Unit: `EUR/MWh`.
|
||||||
|
- `lagged_residual_load` (array of 24 values for `t-1..t-24`)
|
||||||
|
- Unit: `MW`.
|
||||||
|
- `hour_of_day_sin`, `hour_of_day_cos`
|
||||||
|
- Unit: dimensionless in `[-1, 1]`.
|
||||||
|
- `weekday_sin`, `weekday_cos`
|
||||||
|
- Unit: dimensionless in `[-1, 1]`.
|
||||||
|
- `month_sin`, `month_cos`
|
||||||
|
- Unit: dimensionless in `[-1, 1]`.
|
||||||
|
|
||||||
|
## Missing-data semantics
|
||||||
|
|
||||||
|
The pipeline intentionally distinguishes **missing** from **measured zero**:
|
||||||
|
|
||||||
|
- Forecast columns (`load_forecast`, `wind_forecast`, `solar_forecast`) remain `NaN` when ENTSO-E has no value.
|
||||||
|
- `residual_load` is computed directly from source columns and remains `NaN` when any source component is missing.
|
||||||
|
- Feature engineering drops rows only for lag warmup requirements (`day_ahead_price` and `lagged_price_1..24`), not for every nullable forecast column.
|
||||||
|
- During DB persistence to `electricity_price_features`, rows with nulls in NOT NULL core columns are skipped (to satisfy schema constraints) while still being available in the returned in-memory DataFrame.
|
||||||
|
|
||||||
|
## Data contracts
|
||||||
|
|
||||||
|
### In-memory contract (`run_feature_pipeline(...)` return value)
|
||||||
|
|
||||||
|
- Index:
|
||||||
|
- Type: timezone-aware `DatetimeIndex`
|
||||||
|
- Granularity: hourly
|
||||||
|
- Timezone: UTC
|
||||||
|
- Uniqueness: unique timestamps expected
|
||||||
|
- Columns:
|
||||||
|
- Base signals: `day_ahead_price`, `load_forecast`, `wind_forecast`, `solar_forecast`
|
||||||
|
- Derived: `residual_load`
|
||||||
|
- Lag vectors (expanded): `lagged_price_1..24`, `lagged_residual_load_1..24`
|
||||||
|
- Cyclical: `hour_of_day_sin`, `hour_of_day_cos`, `weekday_sin`, `weekday_cos`, `month_sin`, `month_cos`
|
||||||
|
- Nullability:
|
||||||
|
- `day_ahead_price`: required for returned rows.
|
||||||
|
- `lagged_price_1..24`: required for returned rows.
|
||||||
|
- `load_forecast`, `wind_forecast`, `solar_forecast`, `residual_load`, and `lagged_residual_load_*`: nullable (`NaN` allowed).
|
||||||
|
|
||||||
|
### Persistence contract (`electricity_price_features`)
|
||||||
|
|
||||||
|
- Persisted row key:
|
||||||
|
- (`country_code`, `delivery_start`, `feature_version`)
|
||||||
|
- NOT NULL core columns in schema:
|
||||||
|
- `day_ahead_price`, `load_forecast`, `wind_forecast`, `solar_forecast`, `residual_load`
|
||||||
|
- `lagged_price`, `lagged_residual_load`
|
||||||
|
- cyclical columns (`hour_of_day_*`, `weekday_*`, `month_*`)
|
||||||
|
- Write behavior:
|
||||||
|
- The persistence layer filters out non-conforming rows (nulls in NOT NULL core columns) before UPSERT.
|
||||||
|
- Persisted lag arrays are fixed length 24 and map to `t-1..t-24`.
|
||||||
|
|
||||||
|
### Raw observations contract (`electricity_market_observations`)
|
||||||
|
|
||||||
|
- Indexing/key semantics:
|
||||||
|
- one row per (`country_code`, `delivery_start`)
|
||||||
|
- Update semantics:
|
||||||
|
- partial refreshes do not overwrite existing non-null values with nulls (`COALESCE` merge policy)
|
||||||
|
|
||||||
|
## Country-code and API behavior notes
|
||||||
|
|
||||||
|
- Use ENTSO-E bidding-zone identifiers (for example `DE_LU` rather than `DE`) when querying across all endpoints.
|
||||||
|
- The pipeline includes a bidding-zone resolver for common country aliases:
|
||||||
|
- `DE -> DE_LU`
|
||||||
|
- `IT -> IT_NORD`
|
||||||
|
- Persistence uses the resolved bidding-zone code so DB keys match the queried market zone.
|
||||||
|
- Some ENTSO-E endpoints may return no matches for specific windows/countries. These are handled as empty hourly frames for that endpoint rather than hard-failing the whole fetch.
|
||||||
|
- Wind/solar responses can include duplicate semantic columns after normalization; the service coalesces duplicates by taking the first non-null value per timestamp.
|
||||||
|
|
||||||
|
## Database objects
|
||||||
|
|
||||||
|
`sql/001_electricity_price_schema.sql` creates:
|
||||||
|
|
||||||
|
- `entsoe_api_cache`: generic decorator cache table (pickled payloads, TTL support)
|
||||||
|
- `electricity_market_observations`: raw hourly ENTSO-E observations
|
||||||
|
- `electricity_price_features`: model-ready feature store
|
||||||
|
|
||||||
|
## Setup
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cd electricity_price_predictor
|
||||||
|
python -m venv .venv
|
||||||
|
source .venv/bin/activate
|
||||||
|
pip install -r requirements.txt
|
||||||
|
```
|
||||||
|
|
||||||
|
Set env vars:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
export ENTSOE_API_KEY="your_entsoe_key"
|
||||||
|
export QUANT_DB_HOST="localhost"
|
||||||
|
export QUANT_DB_PORT="5432"
|
||||||
|
export QUANT_DB_NAME="quant_db"
|
||||||
|
export QUANT_DB_USER="quant_user"
|
||||||
|
export QUANT_DB_PASSWORD="strong_password"
|
||||||
|
```
|
||||||
|
|
||||||
|
## Initialize schema
|
||||||
|
|
||||||
|
```bash
|
||||||
|
PYTHONPATH=src python scripts/init_db.py
|
||||||
|
```
|
||||||
|
|
||||||
|
## Build feature store
|
||||||
|
|
||||||
|
```bash
|
||||||
|
PYTHONPATH=src python scripts/build_feature_store.py \
|
||||||
|
--country-code DE_LU \
|
||||||
|
--start 2026-01-01T00:00:00Z \
|
||||||
|
--end 2026-02-01T00:00:00Z \
|
||||||
|
--cache-ttl-hours 24
|
||||||
|
```
|
||||||
|
|
||||||
|
## Decorator caching behavior
|
||||||
|
|
||||||
|
The `cache_to_db` decorator:
|
||||||
|
|
||||||
|
- hashes function name + arguments into deterministic `cache_key`
|
||||||
|
- checks `entsoe_api_cache` first
|
||||||
|
- returns cached payload if key exists and not expired
|
||||||
|
- otherwise uses `electricity_market_observations` as secondary cache at timestamp level
|
||||||
|
- only calls ENTSO-E for missing hourly intervals, then upserts those rows
|
||||||
|
- stores final returned object in `entsoe_api_cache`
|
||||||
|
|
||||||
|
This gives a two-layer cache:
|
||||||
|
1) fast function-result cache (`entsoe_api_cache`) and
|
||||||
|
2) canonical timestamp cache (`electricity_market_observations`).
|
||||||
73
electricity_price_predictor/docs/architecture.md
Normal file
73
electricity_price_predictor/docs/architecture.md
Normal file
@@ -0,0 +1,73 @@
|
|||||||
|
# Architecture Notes
|
||||||
|
|
||||||
|
This document is the quick architecture map. For full file-by-file implementation details, see `docs/developer_guide.md`.
|
||||||
|
|
||||||
|
## End-to-end data flow
|
||||||
|
|
||||||
|
1. `scripts/build_feature_store.py` parses CLI arguments and validates env vars.
|
||||||
|
2. It calls `pipeline.run_feature_pipeline(...)`.
|
||||||
|
3. `EntsoeDataService.fetch_inputs(...)` loads:
|
||||||
|
- `day_ahead_price`
|
||||||
|
- `load_forecast`
|
||||||
|
- `wind_forecast`
|
||||||
|
- `solar_forecast`
|
||||||
|
4. Each ENTSO-E call is wrapped by `cache_to_db(...)` and either:
|
||||||
|
- serves a hit from `entsoe_api_cache`, or
|
||||||
|
- falls back to `electricity_market_observations` for already-known timestamps,
|
||||||
|
- and performs API calls only for missing hourly intervals.
|
||||||
|
5. Missing intervals returned from API are upserted into `electricity_market_observations`.
|
||||||
|
6. The final merged result is cached in `entsoe_api_cache`.
|
||||||
|
7. Raw merged series are upserted to `electricity_market_observations`.
|
||||||
|
8. `features.build_feature_frame(...)` computes:
|
||||||
|
- `residual_load`
|
||||||
|
- lagged arrays (24 values each)
|
||||||
|
- cyclical encodings for hour/weekday/month
|
||||||
|
- preserves `NaN` for missing forecast-derived values.
|
||||||
|
9. `pipeline.persist_feature_frame(...)` upserts model-ready rows to `electricity_price_features`.
|
||||||
|
- filters out rows that violate feature-table NOT NULL constraints.
|
||||||
|
|
||||||
|
## Process diagram
|
||||||
|
|
||||||
|
```mermaid
|
||||||
|
flowchart TD
|
||||||
|
A[build_feature_store.py CLI] --> B[run_feature_pipeline]
|
||||||
|
B --> C[EntsoeDataService.fetch_inputs]
|
||||||
|
C --> D{Hit in entsoe_api_cache?}
|
||||||
|
D -->|Yes| E[Load payload from entsoe_api_cache]
|
||||||
|
D -->|No| F[Read electricity_market_observations]
|
||||||
|
F --> G{Missing hourly timestamps?}
|
||||||
|
G -->|No| H[Reuse DB observation rows]
|
||||||
|
G -->|Yes| I[Call ENTSO-E only for missing ranges]
|
||||||
|
I --> I2{NoMatchingDataError?}
|
||||||
|
I2 -->|Yes| I3[Use empty hourly frame for endpoint]
|
||||||
|
I2 -->|No| I4[Normalize payload]
|
||||||
|
I4 --> I5[Coalesce duplicate columns by first non-null]
|
||||||
|
I3 --> J[Upsert missing rows to electricity_market_observations]
|
||||||
|
I5 --> J
|
||||||
|
H --> K[Build merged input DataFrame]
|
||||||
|
J --> K
|
||||||
|
K --> L[Store payload in entsoe_api_cache]
|
||||||
|
E --> M[Use cached input DataFrame]
|
||||||
|
L --> N[Upsert electricity_market_observations]
|
||||||
|
M --> N
|
||||||
|
N --> O[build_feature_frame]
|
||||||
|
O --> P[Create lags + cyclical features]
|
||||||
|
P --> P2[Preserve NaN in forecast-derived columns]
|
||||||
|
P2 --> P3[Drop rows missing day_ahead_price or lagged_price_1..24]
|
||||||
|
P3 --> Q[Upsert persistable subset into electricity_price_features]
|
||||||
|
```
|
||||||
|
|
||||||
|
## Key design reasons
|
||||||
|
|
||||||
|
- DB cache avoids repeated ENTSO-E calls during iterative model work.
|
||||||
|
- Observation-table fallback avoids re-fetching timestamps already persisted once.
|
||||||
|
- Pickled payloads preserve exact pandas object shape and index information.
|
||||||
|
- Feature table stores fixed-size lag arrays so one row corresponds to one prediction timestamp.
|
||||||
|
- Missing forecasts are kept as `NaN` in analysis outputs, avoiding misleading zero-imputation.
|
||||||
|
- Persistence layer enforces schema compatibility by skipping rows with nulls in NOT NULL feature columns.
|
||||||
|
|
||||||
|
## Extension points
|
||||||
|
|
||||||
|
- Add label/target tables (`t+1`, `t+24`, etc.).
|
||||||
|
- Add training metadata + model registry tables.
|
||||||
|
- Add partitioning strategy for multi-year production-scale data.
|
||||||
362
electricity_price_predictor/docs/developer_guide.md
Normal file
362
electricity_price_predictor/docs/developer_guide.md
Normal file
@@ -0,0 +1,362 @@
|
|||||||
|
# Developer Guide (Deep Dive)
|
||||||
|
|
||||||
|
This guide explains each file in the module, execution order, control flow, and data/state transitions so you can reason about behavior without reading source code.
|
||||||
|
|
||||||
|
## 1) Directory map and responsibilities
|
||||||
|
|
||||||
|
### Top-level
|
||||||
|
|
||||||
|
- `requirements.txt`
|
||||||
|
- Python dependencies for ingestion and DB persistence.
|
||||||
|
- `README.md`
|
||||||
|
- Operator-focused setup and run commands.
|
||||||
|
- `sql/001_electricity_price_schema.sql`
|
||||||
|
- DDL for cache, raw observations, and feature store.
|
||||||
|
- `scripts/init_db.py`
|
||||||
|
- Applies the SQL schema to `quant_db`.
|
||||||
|
- `scripts/build_feature_store.py`
|
||||||
|
- CLI entrypoint for data fetch + feature persistence.
|
||||||
|
- `docs/architecture.md`
|
||||||
|
- High-level architecture summary.
|
||||||
|
- `docs/developer_guide.md`
|
||||||
|
- This detailed developer-facing explanation.
|
||||||
|
|
||||||
|
### Python package (`src/electricity_price_predictor`)
|
||||||
|
|
||||||
|
- `__init__.py`
|
||||||
|
- Public package exports (`get_engine`, `EntsoeDataService`, `build_feature_frame`).
|
||||||
|
- `db.py`
|
||||||
|
- Builds DB URL from env vars and creates SQLAlchemy `Engine`.
|
||||||
|
- `cache.py`
|
||||||
|
- Implements decorator-based DB cache with deterministic keying.
|
||||||
|
- `entsoe_api.py`
|
||||||
|
- Wraps ENTSO-E API calls, normalizes data, and writes raw observations.
|
||||||
|
- `features.py`
|
||||||
|
- Pure feature engineering logic (residual load, lags, cyclical encoding).
|
||||||
|
- `pipeline.py`
|
||||||
|
- Orchestration layer for end-to-end fetch -> raw persist -> feature build -> feature persist.
|
||||||
|
|
||||||
|
## 2) Runtime execution path (step-by-step)
|
||||||
|
|
||||||
|
When you run:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
PYTHONPATH=src python3 scripts/build_feature_store.py --country-code ... --start ... --end ...
|
||||||
|
```
|
||||||
|
|
||||||
|
Execution sequence:
|
||||||
|
|
||||||
|
1. **Argument parsing**
|
||||||
|
- `build_feature_store.py` reads country code/time range/TTL.
|
||||||
|
2. **Credential/connection bootstrap**
|
||||||
|
- checks `ENTSOE_API_KEY`.
|
||||||
|
- calls `get_engine()` from `db.py`.
|
||||||
|
3. **Pipeline orchestration**
|
||||||
|
- `run_feature_pipeline(...)` in `pipeline.py` starts.
|
||||||
|
4. **API service creation**
|
||||||
|
- initializes `EntsoePandasClient`.
|
||||||
|
- creates `EntsoeDataService(client, engine, cache_ttl_hours)`.
|
||||||
|
5. **Decorator wrapping**
|
||||||
|
- in `EntsoeDataService.__post_init__`, API methods are wrapped by `cache_to_db(...)`.
|
||||||
|
6. **Data retrieval**
|
||||||
|
- `fetch_inputs(...)` calls:
|
||||||
|
- `get_day_ahead_prices(...)`
|
||||||
|
- `get_load_forecast(...)`
|
||||||
|
- `get_wind_solar_forecast(...)`
|
||||||
|
- country aliases are normalized to bidding zones before queries (currently `DE -> DE_LU`, `IT -> IT_NORD`).
|
||||||
|
7. **Cache check/compute loop (per call)**
|
||||||
|
- decorator computes hash key from function + args.
|
||||||
|
- if non-expired row exists in `entsoe_api_cache`: returns payload.
|
||||||
|
- else: reads `electricity_market_observations` for requested timestamps.
|
||||||
|
- if timestamps are missing there, only missing hourly ranges are requested from ENTSO-E.
|
||||||
|
- `NoMatchingDataError` from ENTSO-E is converted to an empty hourly frame for that endpoint/range.
|
||||||
|
- normalized responses coalesce duplicate semantic columns (for example multiple wind/solar columns) via first non-null-per-row.
|
||||||
|
- missing rows are upserted into `electricity_market_observations`.
|
||||||
|
- final merged dataset is stored in `entsoe_api_cache` and returned.
|
||||||
|
8. **Raw persistence**
|
||||||
|
- merged inputs are upserted to `electricity_market_observations`.
|
||||||
|
9. **Feature engineering**
|
||||||
|
- `build_feature_frame(...)` computes:
|
||||||
|
- `residual_load = load - wind - solar`
|
||||||
|
- `lagged_price_1..24`
|
||||||
|
- `lagged_residual_load_1..24`
|
||||||
|
- `hour_of_day_sin/cos`, `weekday_sin/cos`, `month_sin/cos`
|
||||||
|
- preserves source missingness as `NaN` (no 0.0 imputation).
|
||||||
|
- drops rows only when `day_ahead_price` / `lagged_price_1..24` are missing (lag warmup requirement).
|
||||||
|
10. **Feature-store persistence**
|
||||||
|
- lags are materialized into PostgreSQL arrays (`DOUBLE PRECISION[]`, length 24).
|
||||||
|
- rows violating NOT NULL core feature constraints are filtered out before upsert.
|
||||||
|
- persistable rows are upserted to `electricity_price_features`.
|
||||||
|
11. **CLI completion**
|
||||||
|
- prints persisted row count.
|
||||||
|
|
||||||
|
## 3) UML diagrams
|
||||||
|
|
||||||
|
## 3.1 Component diagram
|
||||||
|
|
||||||
|
```mermaid
|
||||||
|
flowchart LR
|
||||||
|
CLI[scripts/build_feature_store.py] --> PIPE[pipeline.run_feature_pipeline]
|
||||||
|
PIPE --> DBMOD[db.get_engine]
|
||||||
|
PIPE --> SERVICE[EntsoeDataService]
|
||||||
|
SERVICE --> CACHEDEC[cache_to_db decorator]
|
||||||
|
SERVICE --> ENTSOE[EntsoePandasClient]
|
||||||
|
SERVICE --> SECONDARY[electricity_market_observations secondary cache]
|
||||||
|
PIPE --> FEAT[features.build_feature_frame NaN-preserving]
|
||||||
|
FEAT --> PERSIST[pipeline.persist_feature_frame null-filtered]
|
||||||
|
CACHEDEC --> DB[(quant_db.entsoe_api_cache)]
|
||||||
|
SECONDARY --> RAW[(quant_db.electricity_market_observations)]
|
||||||
|
PERSIST --> STORE[(quant_db.electricity_price_features)]
|
||||||
|
```
|
||||||
|
|
||||||
|
## 3.2 Class diagram (logical)
|
||||||
|
|
||||||
|
```mermaid
|
||||||
|
classDiagram
|
||||||
|
class EntsoeDataService {
|
||||||
|
+client: EntsoePandasClient
|
||||||
|
+engine: Engine
|
||||||
|
+cache_ttl_hours: Optional[int]
|
||||||
|
+fetch_inputs(country_code, start, end) DataFrame
|
||||||
|
+upsert_raw_data(country_code, frame) None
|
||||||
|
-_get_day_ahead_prices_impl(country_code, start, end) Series
|
||||||
|
-_get_load_forecast_impl(country_code, start, end) Series
|
||||||
|
-_get_wind_solar_forecast_impl(country_code, start, end) DataFrame
|
||||||
|
}
|
||||||
|
|
||||||
|
class CacheDecorator {
|
||||||
|
+cache_to_db(engine, namespace, ttl_hours) decorator
|
||||||
|
-_build_cache_key(function_name, args, kwargs) str
|
||||||
|
}
|
||||||
|
|
||||||
|
class FeatureBuilder {
|
||||||
|
+build_feature_frame(inputs, max_lag=24) DataFrame
|
||||||
|
-_cyclical_encode(values, period, prefix) DataFrame
|
||||||
|
}
|
||||||
|
|
||||||
|
class Pipeline {
|
||||||
|
+run_feature_pipeline(engine, entsoe_api_key, country_code, start, end, cache_ttl_hours) DataFrame
|
||||||
|
+persist_feature_frame(engine, country_code, feature_frame) None
|
||||||
|
}
|
||||||
|
|
||||||
|
Pipeline --> EntsoeDataService : uses
|
||||||
|
Pipeline --> FeatureBuilder : uses
|
||||||
|
EntsoeDataService --> CacheDecorator : wraps methods
|
||||||
|
```
|
||||||
|
|
||||||
|
## 3.3 Sequence diagram (single API method with cache)
|
||||||
|
|
||||||
|
```mermaid
|
||||||
|
sequenceDiagram
|
||||||
|
participant Caller as fetch_inputs()
|
||||||
|
participant Decorator as cache_to_db wrapper
|
||||||
|
participant CacheTable as entsoe_api_cache (L1)
|
||||||
|
participant ObsTable as electricity_market_observations (L2)
|
||||||
|
participant API as ENTSO-E API
|
||||||
|
|
||||||
|
Caller->>Decorator: get_day_ahead_prices(country, start, end)
|
||||||
|
Decorator->>CacheTable: SELECT by cache_key and expires_at
|
||||||
|
alt L1 cache hit
|
||||||
|
CacheTable-->>Decorator: payload
|
||||||
|
Decorator-->>Caller: unpickled pandas object
|
||||||
|
else L1 cache miss/expired
|
||||||
|
Decorator->>ObsTable: SELECT existing timestamps
|
||||||
|
alt L2 fully covers range
|
||||||
|
ObsTable-->>Decorator: pandas-compatible rows
|
||||||
|
else L2 has gaps
|
||||||
|
Decorator->>API: query only missing ranges
|
||||||
|
alt API returns data
|
||||||
|
API-->>Decorator: missing rows
|
||||||
|
Decorator->>Decorator: normalize columns + coalesce duplicates
|
||||||
|
Decorator->>ObsTable: UPSERT missing rows
|
||||||
|
else NoMatchingDataError
|
||||||
|
Decorator->>Decorator: synthesize empty hourly frame
|
||||||
|
end
|
||||||
|
end
|
||||||
|
Decorator->>CacheTable: INSERT/UPSERT merged payload
|
||||||
|
Decorator-->>Caller: fresh result
|
||||||
|
end
|
||||||
|
```
|
||||||
|
|
||||||
|
## 3.4 State diagram (cache entry lifecycle)
|
||||||
|
|
||||||
|
```mermaid
|
||||||
|
stateDiagram-v2
|
||||||
|
[*] --> L1Missing
|
||||||
|
L1Missing --> L2Check: cache miss/expiry
|
||||||
|
L2Check --> Fresh: observation table fully covers range
|
||||||
|
L2Check --> Partial: observation table has gaps
|
||||||
|
Partial --> Fresh: fetch missing ranges, upsert L2, upsert L1
|
||||||
|
Fresh --> Fresh: reused before expiry
|
||||||
|
Fresh --> Expired: TTL passes for L1 entry
|
||||||
|
Expired --> L2Check: next call
|
||||||
|
Fresh --> Overwritten: Same key, new payload upsert
|
||||||
|
Overwritten --> Fresh
|
||||||
|
```
|
||||||
|
|
||||||
|
## 3.5 ER diagram (database schema)
|
||||||
|
|
||||||
|
```mermaid
|
||||||
|
erDiagram
|
||||||
|
entsoe_api_cache {
|
||||||
|
text cache_key PK
|
||||||
|
text namespace
|
||||||
|
text function_name
|
||||||
|
jsonb args_json
|
||||||
|
bytea payload
|
||||||
|
timestamptz created_at
|
||||||
|
timestamptz expires_at
|
||||||
|
}
|
||||||
|
|
||||||
|
electricity_market_observations {
|
||||||
|
text country_code PK
|
||||||
|
timestamptz delivery_start PK
|
||||||
|
float day_ahead_price
|
||||||
|
float load_forecast
|
||||||
|
float wind_forecast
|
||||||
|
float solar_forecast
|
||||||
|
timestamptz ingested_at
|
||||||
|
}
|
||||||
|
|
||||||
|
electricity_price_features {
|
||||||
|
text country_code PK
|
||||||
|
timestamptz delivery_start PK
|
||||||
|
text feature_version PK
|
||||||
|
float day_ahead_price
|
||||||
|
float load_forecast
|
||||||
|
float wind_forecast
|
||||||
|
float solar_forecast
|
||||||
|
float residual_load
|
||||||
|
float[] lagged_price
|
||||||
|
float[] lagged_residual_load
|
||||||
|
float hour_of_day_sin
|
||||||
|
float hour_of_day_cos
|
||||||
|
float weekday_sin
|
||||||
|
float weekday_cos
|
||||||
|
float month_sin
|
||||||
|
float month_cos
|
||||||
|
timestamptz created_at
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## 4) How files collaborate
|
||||||
|
|
||||||
|
## 4.1 `db.py` + scripts
|
||||||
|
|
||||||
|
- Scripts never hardcode DB URI; they call `get_engine()`.
|
||||||
|
- `get_engine()` centralizes environment-driven connectivity.
|
||||||
|
|
||||||
|
## 4.2 `cache.py` + `entsoe_api.py`
|
||||||
|
|
||||||
|
- `cache_to_db()` is generic and independent of ENTSO-E specifics.
|
||||||
|
- `EntsoeDataService.__post_init__` binds that generic decorator to each API-fetch method.
|
||||||
|
- Result: all expensive API calls automatically become cache-aware without changing call sites.
|
||||||
|
|
||||||
|
## 4.3 `entsoe_api.py` + `features.py`
|
||||||
|
|
||||||
|
- `entsoe_api.py` guarantees normalized timestamp index and expected source columns.
|
||||||
|
- `features.py` assumes these columns and transforms them to model features only (no DB side effects).
|
||||||
|
|
||||||
|
## 4.4 `features.py` + `pipeline.py`
|
||||||
|
|
||||||
|
- `build_feature_frame()` returns wide DataFrame with `lagged_*_1..24`.
|
||||||
|
- `persist_feature_frame()` converts those to PostgreSQL arrays so table rows stay compact and versioned.
|
||||||
|
|
||||||
|
## 5) Important implementation details
|
||||||
|
|
||||||
|
- **Cache keys are deterministic**
|
||||||
|
- Built from JSON of function name + args + kwargs with stable sorting.
|
||||||
|
- **Cache payload type**
|
||||||
|
- `pickle` stored in `BYTEA` to preserve pandas objects.
|
||||||
|
- **TTL logic**
|
||||||
|
- `expires_at IS NULL` means never expires.
|
||||||
|
- Otherwise must be greater than current UTC time to be considered valid.
|
||||||
|
- **Two-layer cache order**
|
||||||
|
- Layer 1: `entsoe_api_cache` (function-result cache).
|
||||||
|
- Layer 2: `electricity_market_observations` (timestamp-level raw cache).
|
||||||
|
- API calls happen only for Layer-2 gaps.
|
||||||
|
- **Upsert strategy**
|
||||||
|
- Raw and feature tables use `ON CONFLICT ... DO UPDATE` for idempotent reruns.
|
||||||
|
- Raw upsert uses `COALESCE(EXCLUDED.col, existing.col)` to avoid null-overwriting previously stored values during partial refreshes.
|
||||||
|
- Feature upsert operates on a filtered persistable subset where core NOT NULL columns are present.
|
||||||
|
- **Missingness semantics**
|
||||||
|
- Forecast and derived residual columns preserve `NaN` in memory.
|
||||||
|
- No zero-imputation is performed for missing forecast values.
|
||||||
|
- **Bidding-zone normalization**
|
||||||
|
- `resolve_bidding_zone_code(...)` maps common country aliases to ENTSO-E zone codes.
|
||||||
|
- Pipeline persistence uses the resolved code, ensuring DB keys match actual queried zones.
|
||||||
|
- **Timezone handling**
|
||||||
|
- API index is normalized to UTC to avoid DST ambiguity in lag features.
|
||||||
|
- **Feature warmup**
|
||||||
|
- Rows missing `day_ahead_price` or any `lagged_price_1..24` are dropped because lag history is incomplete.
|
||||||
|
|
||||||
|
## 6) Failure modes and expected behavior
|
||||||
|
|
||||||
|
- Missing `ENTSOE_API_KEY` -> CLI raises early runtime error.
|
||||||
|
- Missing required input columns -> feature builder raises `ValueError`.
|
||||||
|
- Duplicate normalized columns from ENTSO-E payloads -> coalesced before reindexing to avoid pandas duplicate-label reindex errors.
|
||||||
|
- ENTSO-E no-data responses for an endpoint/range -> transformed to empty hourly frames and merged safely.
|
||||||
|
- Empty data frame -> raw/feature persistence functions no-op safely.
|
||||||
|
- Repeated identical request -> cache hit (no API roundtrip).
|
||||||
|
- Expired L1 cache row + full L2 coverage -> no API call required.
|
||||||
|
- Expired L1 cache row + partial L2 coverage -> API called only for missing ranges.
|
||||||
|
|
||||||
|
## 7) Data contracts
|
||||||
|
|
||||||
|
### 7.1 In-memory features contract
|
||||||
|
|
||||||
|
Producer: `run_feature_pipeline(...)` return value (`pd.DataFrame`).
|
||||||
|
|
||||||
|
- **Index contract**
|
||||||
|
- hourly UTC `DatetimeIndex`, sorted ascending.
|
||||||
|
- unique timestamps expected after deduplication.
|
||||||
|
- **Column contract**
|
||||||
|
- base: `day_ahead_price`, `load_forecast`, `wind_forecast`, `solar_forecast`
|
||||||
|
- derived: `residual_load`
|
||||||
|
- lag columns: `lagged_price_1..24`, `lagged_residual_load_1..24`
|
||||||
|
- cyclical: `hour_of_day_sin/cos`, `weekday_sin/cos`, `month_sin/cos`
|
||||||
|
- **Nullability contract**
|
||||||
|
- required non-null in returned rows: `day_ahead_price`, `lagged_price_1..24`
|
||||||
|
- nullable: `load_forecast`, `wind_forecast`, `solar_forecast`, `residual_load`, and `lagged_residual_load_*`
|
||||||
|
- rationale: preserve upstream missingness semantics for analysis and QC.
|
||||||
|
|
||||||
|
### 7.2 Feature-store persistence contract
|
||||||
|
|
||||||
|
Consumer: `electricity_price_features` table.
|
||||||
|
|
||||||
|
- **Primary key contract**
|
||||||
|
- (`country_code`, `delivery_start`, `feature_version`)
|
||||||
|
- **Schema constraint contract**
|
||||||
|
- core numeric columns are `NOT NULL`.
|
||||||
|
- lag arrays are `DOUBLE PRECISION[]` and expected length 24.
|
||||||
|
- **Write-time contract**
|
||||||
|
- `persist_feature_frame(...)` filters rows that violate NOT NULL core columns before UPSERT.
|
||||||
|
- retained rows are idempotently upserted via `ON CONFLICT ... DO UPDATE`.
|
||||||
|
|
||||||
|
### 7.3 Raw-observation contract
|
||||||
|
|
||||||
|
Consumer: `electricity_market_observations` table.
|
||||||
|
|
||||||
|
- **Primary key contract**
|
||||||
|
- (`country_code`, `delivery_start`)
|
||||||
|
- **Merge contract**
|
||||||
|
- upsert uses `COALESCE(EXCLUDED.col, existing.col)` to avoid null-overwriting prior known values.
|
||||||
|
- **Coverage contract**
|
||||||
|
- secondary cache guarantees fetched payloads are aligned to expected hourly index for the requested `[start, end)` range.
|
||||||
|
|
||||||
|
## 8) Practical debugging checklist
|
||||||
|
|
||||||
|
1. Run `scripts/init_db.py` and ensure tables exist.
|
||||||
|
2. Run one short-range fetch window (1-2 days) first.
|
||||||
|
3. Verify cache growth:
|
||||||
|
- `SELECT namespace, function_name, COUNT(*) FROM entsoe_api_cache GROUP BY 1,2;`
|
||||||
|
4. Verify raw persistence:
|
||||||
|
- `SELECT COUNT(*) FROM electricity_market_observations WHERE country_code = '...';`
|
||||||
|
5. Verify feature persistence:
|
||||||
|
- check lag array sizes are 24 and row count is lower than raw by about 24.
|
||||||
|
|
||||||
|
## 9) Suggested next developer docs to add
|
||||||
|
|
||||||
|
- Data quality rules (acceptable missingness, clipping policy, anomaly handling).
|
||||||
|
- Training-set contract (target definition, split strategy, leakage constraints).
|
||||||
|
- Backfill/replay policy for reprocessing historical periods.
|
||||||
4
electricity_price_predictor/requirements.txt
Normal file
4
electricity_price_predictor/requirements.txt
Normal file
@@ -0,0 +1,4 @@
|
|||||||
|
entsoe-py
|
||||||
|
pandas
|
||||||
|
sqlalchemy
|
||||||
|
psycopg2-binary
|
||||||
41
electricity_price_predictor/scripts/build_feature_store.py
Normal file
41
electricity_price_predictor/scripts/build_feature_store.py
Normal file
@@ -0,0 +1,41 @@
|
|||||||
|
import argparse
|
||||||
|
import os
|
||||||
|
|
||||||
|
import pandas as pd
|
||||||
|
|
||||||
|
from electricity_price_predictor.db import get_engine
|
||||||
|
from electricity_price_predictor.pipeline import run_feature_pipeline
|
||||||
|
|
||||||
|
|
||||||
|
def parse_args() -> argparse.Namespace:
|
||||||
|
parser = argparse.ArgumentParser(description="Fetch ENTSO-E inputs and build feature store.")
|
||||||
|
parser.add_argument("--country-code", required=True, help="ENTSO-E bidding zone code, e.g. DE_LU")
|
||||||
|
parser.add_argument("--start", required=True, help="Inclusive start datetime, e.g. 2026-01-01T00:00:00Z")
|
||||||
|
parser.add_argument("--end", required=True, help="Exclusive end datetime, e.g. 2026-02-01T00:00:00Z")
|
||||||
|
parser.add_argument("--cache-ttl-hours", type=int, default=24, help="Decorator cache TTL in hours")
|
||||||
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> None:
|
||||||
|
args = parse_args()
|
||||||
|
api_key = os.getenv("ENTSOE_API_KEY")
|
||||||
|
if not api_key:
|
||||||
|
raise RuntimeError("ENTSOE_API_KEY environment variable is required.")
|
||||||
|
|
||||||
|
engine = get_engine()
|
||||||
|
start = pd.Timestamp(args.start, tz="UTC")
|
||||||
|
end = pd.Timestamp(args.end, tz="UTC")
|
||||||
|
|
||||||
|
features = run_feature_pipeline(
|
||||||
|
engine=engine,
|
||||||
|
entsoe_api_key=api_key,
|
||||||
|
country_code=args.country_code,
|
||||||
|
start=start,
|
||||||
|
end=end,
|
||||||
|
cache_ttl_hours=args.cache_ttl_hours,
|
||||||
|
)
|
||||||
|
print(f"Persisted {len(features)} feature rows for {args.country_code}.")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
23
electricity_price_predictor/scripts/init_db.py
Normal file
23
electricity_price_predictor/scripts/init_db.py
Normal file
@@ -0,0 +1,23 @@
|
|||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from sqlalchemy import text
|
||||||
|
|
||||||
|
from electricity_price_predictor.db import get_engine
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> None:
|
||||||
|
engine = get_engine()
|
||||||
|
schema_path = Path(__file__).resolve().parents[1] / "sql" / "001_electricity_price_schema.sql"
|
||||||
|
sql = schema_path.read_text(encoding="utf-8")
|
||||||
|
|
||||||
|
with engine.begin() as conn:
|
||||||
|
for statement in sql.split(";"):
|
||||||
|
stmt = statement.strip()
|
||||||
|
if stmt:
|
||||||
|
conn.execute(text(stmt))
|
||||||
|
|
||||||
|
print("Schema initialized for electricity price predictor.")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
@@ -0,0 +1,55 @@
|
|||||||
|
CREATE TABLE IF NOT EXISTS entsoe_api_cache (
|
||||||
|
cache_key TEXT PRIMARY KEY,
|
||||||
|
namespace TEXT NOT NULL,
|
||||||
|
function_name TEXT NOT NULL,
|
||||||
|
args_json JSONB NOT NULL,
|
||||||
|
payload BYTEA NOT NULL,
|
||||||
|
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||||
|
expires_at TIMESTAMPTZ
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_entsoe_api_cache_namespace_fn
|
||||||
|
ON entsoe_api_cache(namespace, function_name);
|
||||||
|
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_entsoe_api_cache_expires_at
|
||||||
|
ON entsoe_api_cache(expires_at);
|
||||||
|
|
||||||
|
CREATE TABLE IF NOT EXISTS electricity_market_observations (
|
||||||
|
country_code TEXT NOT NULL,
|
||||||
|
delivery_start TIMESTAMPTZ NOT NULL,
|
||||||
|
day_ahead_price DOUBLE PRECISION,
|
||||||
|
load_forecast DOUBLE PRECISION,
|
||||||
|
wind_forecast DOUBLE PRECISION,
|
||||||
|
solar_forecast DOUBLE PRECISION,
|
||||||
|
ingested_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||||
|
PRIMARY KEY (country_code, delivery_start)
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_electricity_market_observations_delivery
|
||||||
|
ON electricity_market_observations(delivery_start);
|
||||||
|
|
||||||
|
CREATE TABLE IF NOT EXISTS electricity_price_features (
|
||||||
|
country_code TEXT NOT NULL,
|
||||||
|
delivery_start TIMESTAMPTZ NOT NULL,
|
||||||
|
day_ahead_price DOUBLE PRECISION NOT NULL,
|
||||||
|
load_forecast DOUBLE PRECISION NOT NULL,
|
||||||
|
wind_forecast DOUBLE PRECISION NOT NULL,
|
||||||
|
solar_forecast DOUBLE PRECISION NOT NULL,
|
||||||
|
residual_load DOUBLE PRECISION NOT NULL,
|
||||||
|
lagged_price DOUBLE PRECISION[] NOT NULL,
|
||||||
|
lagged_residual_load DOUBLE PRECISION[] NOT NULL,
|
||||||
|
hour_of_day_sin DOUBLE PRECISION NOT NULL,
|
||||||
|
hour_of_day_cos DOUBLE PRECISION NOT NULL,
|
||||||
|
weekday_sin DOUBLE PRECISION NOT NULL,
|
||||||
|
weekday_cos DOUBLE PRECISION NOT NULL,
|
||||||
|
month_sin DOUBLE PRECISION NOT NULL,
|
||||||
|
month_cos DOUBLE PRECISION NOT NULL,
|
||||||
|
feature_version TEXT NOT NULL DEFAULT 'v1',
|
||||||
|
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||||
|
PRIMARY KEY (country_code, delivery_start, feature_version),
|
||||||
|
CONSTRAINT chk_lagged_price_len CHECK (CARDINALITY(lagged_price) = 24),
|
||||||
|
CONSTRAINT chk_lagged_residual_load_len CHECK (CARDINALITY(lagged_residual_load) = 24)
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_electricity_price_features_delivery
|
||||||
|
ON electricity_price_features(delivery_start);
|
||||||
320
electricity_price_predictor/src/data_analysis/analyze_data.ipynb
Normal file
320
electricity_price_predictor/src/data_analysis/analyze_data.ipynb
Normal file
File diff suppressed because one or more lines are too long
@@ -0,0 +1,7 @@
|
|||||||
|
"""Electricity price forecasting data pipeline package."""
|
||||||
|
|
||||||
|
from .db import get_engine
|
||||||
|
from .entsoe_api import EntsoeDataService
|
||||||
|
from .features import build_feature_frame
|
||||||
|
|
||||||
|
__all__ = ["EntsoeDataService", "build_feature_frame", "get_engine"]
|
||||||
@@ -0,0 +1,112 @@
|
|||||||
|
import functools
|
||||||
|
import hashlib
|
||||||
|
import json
|
||||||
|
import pickle
|
||||||
|
from datetime import datetime, timedelta, timezone
|
||||||
|
from typing import Any, Callable, Optional
|
||||||
|
|
||||||
|
from sqlalchemy import text
|
||||||
|
from sqlalchemy.engine import Engine
|
||||||
|
|
||||||
|
|
||||||
|
def _json_fallback_serializer(value: Any) -> str:
|
||||||
|
"""Serializer for values that aren't directly JSON serializable."""
|
||||||
|
if hasattr(value, "isoformat"):
|
||||||
|
return value.isoformat()
|
||||||
|
return repr(value)
|
||||||
|
|
||||||
|
|
||||||
|
def _build_cache_key(function_name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> str:
|
||||||
|
payload = json.dumps(
|
||||||
|
{"function_name": function_name, "args": args, "kwargs": kwargs},
|
||||||
|
sort_keys=True,
|
||||||
|
default=_json_fallback_serializer,
|
||||||
|
)
|
||||||
|
return hashlib.sha256(payload.encode("utf-8")).hexdigest()
|
||||||
|
|
||||||
|
|
||||||
|
def cache_to_db(
|
||||||
|
engine: Engine,
|
||||||
|
namespace: str,
|
||||||
|
ttl_hours: Optional[int] = None,
|
||||||
|
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
|
||||||
|
"""
|
||||||
|
Cache function output in quant_db.entsoe_api_cache table.
|
||||||
|
|
||||||
|
Notes:
|
||||||
|
- TTL is optional. If omitted, cached values do not expire.
|
||||||
|
- Cached payload uses pickle so pandas objects can be restored losslessly.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
|
||||||
|
@functools.wraps(func)
|
||||||
|
def wrapper(*args: Any, **kwargs: Any) -> Any:
|
||||||
|
cache_key = _build_cache_key(f"{namespace}.{func.__name__}", args, kwargs)
|
||||||
|
now = datetime.now(timezone.utc)
|
||||||
|
|
||||||
|
with engine.begin() as conn:
|
||||||
|
result = conn.execute(
|
||||||
|
text(
|
||||||
|
"""
|
||||||
|
SELECT payload
|
||||||
|
FROM entsoe_api_cache
|
||||||
|
WHERE cache_key = :cache_key
|
||||||
|
AND (expires_at IS NULL OR expires_at > :now_utc)
|
||||||
|
"""
|
||||||
|
),
|
||||||
|
{"cache_key": cache_key, "now_utc": now},
|
||||||
|
).fetchone()
|
||||||
|
|
||||||
|
if result:
|
||||||
|
return pickle.loads(result[0])
|
||||||
|
|
||||||
|
data = func(*args, **kwargs)
|
||||||
|
expires_at = None
|
||||||
|
if ttl_hours is not None:
|
||||||
|
expires_at = now + timedelta(hours=ttl_hours)
|
||||||
|
|
||||||
|
conn.execute(
|
||||||
|
text(
|
||||||
|
"""
|
||||||
|
INSERT INTO entsoe_api_cache (
|
||||||
|
cache_key,
|
||||||
|
namespace,
|
||||||
|
function_name,
|
||||||
|
args_json,
|
||||||
|
payload,
|
||||||
|
created_at,
|
||||||
|
expires_at
|
||||||
|
) VALUES (
|
||||||
|
:cache_key,
|
||||||
|
:namespace,
|
||||||
|
:function_name,
|
||||||
|
CAST(:args_json AS JSONB),
|
||||||
|
:payload,
|
||||||
|
:created_at,
|
||||||
|
:expires_at
|
||||||
|
)
|
||||||
|
ON CONFLICT (cache_key) DO UPDATE
|
||||||
|
SET payload = EXCLUDED.payload,
|
||||||
|
created_at = EXCLUDED.created_at,
|
||||||
|
expires_at = EXCLUDED.expires_at,
|
||||||
|
args_json = EXCLUDED.args_json
|
||||||
|
"""
|
||||||
|
),
|
||||||
|
{
|
||||||
|
"cache_key": cache_key,
|
||||||
|
"namespace": namespace,
|
||||||
|
"function_name": func.__name__,
|
||||||
|
"args_json": json.dumps(
|
||||||
|
{"args": args, "kwargs": kwargs},
|
||||||
|
default=_json_fallback_serializer,
|
||||||
|
),
|
||||||
|
"payload": pickle.dumps(data),
|
||||||
|
"created_at": now,
|
||||||
|
"expires_at": expires_at,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
return data
|
||||||
|
|
||||||
|
return wrapper
|
||||||
|
|
||||||
|
return decorator
|
||||||
@@ -0,0 +1,24 @@
|
|||||||
|
import os
|
||||||
|
|
||||||
|
from sqlalchemy import create_engine
|
||||||
|
from sqlalchemy.engine import Engine
|
||||||
|
|
||||||
|
|
||||||
|
def get_database_url() -> str:
|
||||||
|
"""Build database URL from env or fallback defaults."""
|
||||||
|
explicit_url = os.getenv("QUANT_DB_URL")
|
||||||
|
if explicit_url:
|
||||||
|
return explicit_url
|
||||||
|
|
||||||
|
host = os.getenv("QUANT_DB_HOST", "localhost")
|
||||||
|
port = os.getenv("QUANT_DB_PORT", "5432")
|
||||||
|
database = os.getenv("QUANT_DB_NAME", "quant_db")
|
||||||
|
user = os.getenv("QUANT_DB_USER", "quant_user")
|
||||||
|
password = os.getenv("QUANT_DB_PASSWORD", "strong_password")
|
||||||
|
|
||||||
|
return f"postgresql+psycopg2://{user}:{password}@{host}:{port}/{database}"
|
||||||
|
|
||||||
|
|
||||||
|
def get_engine(echo: bool = False) -> Engine:
|
||||||
|
"""Create SQLAlchemy engine for quant_db."""
|
||||||
|
return create_engine(get_database_url(), future=True, echo=echo)
|
||||||
@@ -0,0 +1,393 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import pandas as pd
|
||||||
|
from entsoe import EntsoePandasClient
|
||||||
|
from entsoe.exceptions import NoMatchingDataError
|
||||||
|
from sqlalchemy import text
|
||||||
|
from sqlalchemy.engine import Engine
|
||||||
|
|
||||||
|
from .cache import cache_to_db
|
||||||
|
|
||||||
|
OBSERVATION_COLUMNS = {
|
||||||
|
"day_ahead_price",
|
||||||
|
"load_forecast",
|
||||||
|
"wind_forecast",
|
||||||
|
"solar_forecast",
|
||||||
|
}
|
||||||
|
|
||||||
|
BIDDING_ZONE_ALIASES = {
|
||||||
|
# ENTSO-E often expects bidding-zone EIC aliases instead of plain country codes.
|
||||||
|
"DE": "DE_LU",
|
||||||
|
"IT": "IT_NORD",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _as_utc_index(series_or_df: pd.Series | pd.DataFrame) -> pd.Series | pd.DataFrame:
|
||||||
|
if series_or_df.index.tz is None:
|
||||||
|
series_or_df.index = series_or_df.index.tz_localize("UTC")
|
||||||
|
else:
|
||||||
|
series_or_df.index = series_or_df.index.tz_convert("UTC")
|
||||||
|
return series_or_df.sort_index()
|
||||||
|
|
||||||
|
|
||||||
|
def _safe_float(value) -> Optional[float]:
|
||||||
|
if pd.isna(value):
|
||||||
|
return None
|
||||||
|
return float(value)
|
||||||
|
|
||||||
|
|
||||||
|
def resolve_bidding_zone_code(country_code: str) -> str:
|
||||||
|
code = str(country_code).strip().upper()
|
||||||
|
return BIDDING_ZONE_ALIASES.get(code, code)
|
||||||
|
|
||||||
|
|
||||||
|
def _coerce_single_column_frame(
|
||||||
|
data: pd.Series | pd.DataFrame,
|
||||||
|
target_column: str,
|
||||||
|
preferred_tokens: tuple[str, ...] = (),
|
||||||
|
) -> pd.DataFrame:
|
||||||
|
"""
|
||||||
|
Normalize ENTSO-E responses that can be either a Series or DataFrame.
|
||||||
|
"""
|
||||||
|
if isinstance(data, pd.Series):
|
||||||
|
series = _as_utc_index(data)
|
||||||
|
series.name = target_column
|
||||||
|
return series.to_frame()
|
||||||
|
|
||||||
|
frame = _as_utc_index(data.copy())
|
||||||
|
if target_column in frame.columns:
|
||||||
|
return frame[[target_column]]
|
||||||
|
|
||||||
|
if len(frame.columns) == 1:
|
||||||
|
return frame.rename(columns={frame.columns[0]: target_column})[[target_column]]
|
||||||
|
|
||||||
|
def _best_column(candidates: list) -> Optional[str]:
|
||||||
|
if not candidates:
|
||||||
|
return None
|
||||||
|
# Prefer the candidate with most available points.
|
||||||
|
return max(candidates, key=lambda col: int(frame[col].notna().sum()))
|
||||||
|
|
||||||
|
lowered = {col: str(col).lower() for col in frame.columns}
|
||||||
|
preferred_candidates = []
|
||||||
|
for token in preferred_tokens:
|
||||||
|
preferred_candidates.extend([col for col, col_lc in lowered.items() if token in col_lc])
|
||||||
|
|
||||||
|
preferred_col = _best_column(preferred_candidates)
|
||||||
|
if preferred_col is not None:
|
||||||
|
return frame.rename(columns={preferred_col: target_column})[[target_column]]
|
||||||
|
|
||||||
|
any_col = _best_column(list(frame.columns))
|
||||||
|
if any_col is not None:
|
||||||
|
return frame.rename(columns={any_col: target_column})[[target_column]]
|
||||||
|
|
||||||
|
first_col = frame.columns[0]
|
||||||
|
return frame.rename(columns={first_col: target_column})[[target_column]]
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize_utc_bounds(start: pd.Timestamp, end: pd.Timestamp) -> tuple[pd.Timestamp, pd.Timestamp]:
|
||||||
|
start_ts = pd.Timestamp(start)
|
||||||
|
end_ts = pd.Timestamp(end)
|
||||||
|
if start_ts.tz is None:
|
||||||
|
start_ts = start_ts.tz_localize("UTC")
|
||||||
|
else:
|
||||||
|
start_ts = start_ts.tz_convert("UTC")
|
||||||
|
if end_ts.tz is None:
|
||||||
|
end_ts = end_ts.tz_localize("UTC")
|
||||||
|
else:
|
||||||
|
end_ts = end_ts.tz_convert("UTC")
|
||||||
|
return start_ts, end_ts
|
||||||
|
|
||||||
|
|
||||||
|
def _empty_hourly_frame(
|
||||||
|
columns: list[str], start: pd.Timestamp, end: pd.Timestamp
|
||||||
|
) -> pd.DataFrame:
|
||||||
|
start_ts, end_ts = _normalize_utc_bounds(start, end)
|
||||||
|
idx = pd.date_range(start=start_ts, end=end_ts, freq="h", inclusive="left")
|
||||||
|
return pd.DataFrame(index=idx, columns=columns)
|
||||||
|
|
||||||
|
|
||||||
|
def _coalesce_duplicate_columns(frame: pd.DataFrame) -> pd.DataFrame:
|
||||||
|
"""
|
||||||
|
Collapse duplicate column labels by taking the first non-null per row.
|
||||||
|
"""
|
||||||
|
if frame.columns.is_unique:
|
||||||
|
return frame
|
||||||
|
|
||||||
|
merged: dict[str, pd.Series] = {}
|
||||||
|
for col in frame.columns.unique():
|
||||||
|
same_name = frame.loc[:, frame.columns == col]
|
||||||
|
if isinstance(same_name, pd.Series):
|
||||||
|
merged[str(col)] = same_name
|
||||||
|
else:
|
||||||
|
merged[str(col)] = same_name.bfill(axis=1).iloc[:, 0]
|
||||||
|
return pd.DataFrame(merged, index=frame.index)
|
||||||
|
|
||||||
|
|
||||||
|
def _missing_ranges(missing_index: pd.DatetimeIndex) -> list[tuple[pd.Timestamp, pd.Timestamp]]:
|
||||||
|
if missing_index.empty:
|
||||||
|
return []
|
||||||
|
missing_index = missing_index.sort_values().unique()
|
||||||
|
ranges: list[tuple[pd.Timestamp, pd.Timestamp]] = []
|
||||||
|
current_start = missing_index[0]
|
||||||
|
prev = missing_index[0]
|
||||||
|
step = pd.Timedelta(hours=1)
|
||||||
|
|
||||||
|
for ts in missing_index[1:]:
|
||||||
|
if ts - prev != step:
|
||||||
|
ranges.append((current_start, prev + step))
|
||||||
|
current_start = ts
|
||||||
|
prev = ts
|
||||||
|
ranges.append((current_start, prev + step))
|
||||||
|
return ranges
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class EntsoeDataService:
|
||||||
|
client: EntsoePandasClient
|
||||||
|
engine: Engine
|
||||||
|
cache_ttl_hours: Optional[int] = 24
|
||||||
|
|
||||||
|
def __post_init__(self) -> None:
|
||||||
|
self.get_day_ahead_prices = cache_to_db(
|
||||||
|
self.engine, "entsoe", ttl_hours=self.cache_ttl_hours
|
||||||
|
)(self._get_day_ahead_prices_impl)
|
||||||
|
self.get_load_forecast = cache_to_db(
|
||||||
|
self.engine, "entsoe", ttl_hours=self.cache_ttl_hours
|
||||||
|
)(self._get_load_forecast_impl)
|
||||||
|
self.get_wind_solar_forecast = cache_to_db(
|
||||||
|
self.engine, "entsoe", ttl_hours=self.cache_ttl_hours
|
||||||
|
)(self._get_wind_solar_forecast_impl)
|
||||||
|
|
||||||
|
def resolve_country_code(self, country_code: str) -> str:
|
||||||
|
return resolve_bidding_zone_code(country_code)
|
||||||
|
|
||||||
|
def _get_day_ahead_prices_impl(
|
||||||
|
self, country_code: str, start: pd.Timestamp, end: pd.Timestamp
|
||||||
|
) -> pd.Series:
|
||||||
|
df = self._fetch_inputs_with_secondary_cache(
|
||||||
|
country_code=country_code,
|
||||||
|
start=start,
|
||||||
|
end=end,
|
||||||
|
required_columns=["day_ahead_price"],
|
||||||
|
api_fetcher=self._query_day_ahead_prices,
|
||||||
|
)
|
||||||
|
return df["day_ahead_price"]
|
||||||
|
|
||||||
|
def _get_load_forecast_impl(
|
||||||
|
self, country_code: str, start: pd.Timestamp, end: pd.Timestamp
|
||||||
|
) -> pd.Series:
|
||||||
|
df = self._fetch_inputs_with_secondary_cache(
|
||||||
|
country_code=country_code,
|
||||||
|
start=start,
|
||||||
|
end=end,
|
||||||
|
required_columns=["load_forecast"],
|
||||||
|
api_fetcher=self._query_load_forecast,
|
||||||
|
)
|
||||||
|
return df["load_forecast"]
|
||||||
|
|
||||||
|
def _get_wind_solar_forecast_impl(
|
||||||
|
self, country_code: str, start: pd.Timestamp, end: pd.Timestamp
|
||||||
|
) -> pd.DataFrame:
|
||||||
|
return self._fetch_inputs_with_secondary_cache(
|
||||||
|
country_code=country_code,
|
||||||
|
start=start,
|
||||||
|
end=end,
|
||||||
|
required_columns=["wind_forecast", "solar_forecast"],
|
||||||
|
api_fetcher=self._query_wind_solar_forecast,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _query_day_ahead_prices(
|
||||||
|
self, country_code: str, start: pd.Timestamp, end: pd.Timestamp
|
||||||
|
) -> pd.DataFrame:
|
||||||
|
try:
|
||||||
|
raw = self.client.query_day_ahead_prices(country_code, start=start, end=end)
|
||||||
|
except NoMatchingDataError:
|
||||||
|
return _empty_hourly_frame(["day_ahead_price"], start, end)
|
||||||
|
return _coerce_single_column_frame(
|
||||||
|
raw,
|
||||||
|
target_column="day_ahead_price",
|
||||||
|
preferred_tokens=("price", "ahead"),
|
||||||
|
)
|
||||||
|
|
||||||
|
def _query_load_forecast(
|
||||||
|
self, country_code: str, start: pd.Timestamp, end: pd.Timestamp
|
||||||
|
) -> pd.DataFrame:
|
||||||
|
try:
|
||||||
|
raw = self.client.query_load_forecast(country_code, start=start, end=end)
|
||||||
|
except NoMatchingDataError:
|
||||||
|
return _empty_hourly_frame(["load_forecast"], start, end)
|
||||||
|
return _coerce_single_column_frame(
|
||||||
|
raw,
|
||||||
|
target_column="load_forecast",
|
||||||
|
preferred_tokens=("load",),
|
||||||
|
)
|
||||||
|
|
||||||
|
def _query_wind_solar_forecast(
|
||||||
|
self, country_code: str, start: pd.Timestamp, end: pd.Timestamp
|
||||||
|
) -> pd.DataFrame:
|
||||||
|
try:
|
||||||
|
df = self.client.query_wind_and_solar_forecast(country_code, start=start, end=end)
|
||||||
|
except NoMatchingDataError:
|
||||||
|
return _empty_hourly_frame(["wind_forecast", "solar_forecast"], start, end)
|
||||||
|
df = _as_utc_index(df)
|
||||||
|
if isinstance(df, pd.Series):
|
||||||
|
df = df.to_frame()
|
||||||
|
|
||||||
|
renamed = {}
|
||||||
|
for column in df.columns:
|
||||||
|
lc = str(column).lower()
|
||||||
|
if "wind" in lc:
|
||||||
|
renamed[column] = "wind_forecast"
|
||||||
|
elif "solar" in lc:
|
||||||
|
renamed[column] = "solar_forecast"
|
||||||
|
df = df.rename(columns=renamed)
|
||||||
|
df = _coalesce_duplicate_columns(df)
|
||||||
|
|
||||||
|
if "wind_forecast" not in df.columns:
|
||||||
|
df["wind_forecast"] = None
|
||||||
|
if "solar_forecast" not in df.columns:
|
||||||
|
df["solar_forecast"] = None
|
||||||
|
|
||||||
|
return df[["wind_forecast", "solar_forecast"]]
|
||||||
|
|
||||||
|
def _load_observations(
|
||||||
|
self,
|
||||||
|
country_code: str,
|
||||||
|
start: pd.Timestamp,
|
||||||
|
end: pd.Timestamp,
|
||||||
|
columns: list[str],
|
||||||
|
) -> pd.DataFrame:
|
||||||
|
invalid = [col for col in columns if col not in OBSERVATION_COLUMNS]
|
||||||
|
if invalid:
|
||||||
|
raise ValueError(f"Unsupported observation columns requested: {invalid}")
|
||||||
|
|
||||||
|
sql_columns = ", ".join(columns)
|
||||||
|
query = text(
|
||||||
|
f"""
|
||||||
|
SELECT delivery_start, {sql_columns}
|
||||||
|
FROM electricity_market_observations
|
||||||
|
WHERE country_code = :country_code
|
||||||
|
AND delivery_start >= :start_ts
|
||||||
|
AND delivery_start < :end_ts
|
||||||
|
ORDER BY delivery_start
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
with self.engine.begin() as conn:
|
||||||
|
rows = conn.execute(
|
||||||
|
query,
|
||||||
|
{"country_code": country_code, "start_ts": start, "end_ts": end},
|
||||||
|
).fetchall()
|
||||||
|
|
||||||
|
if not rows:
|
||||||
|
return pd.DataFrame(columns=columns, index=pd.DatetimeIndex([], tz="UTC"))
|
||||||
|
|
||||||
|
db_frame = pd.DataFrame(rows, columns=["delivery_start", *columns])
|
||||||
|
db_frame["delivery_start"] = pd.to_datetime(db_frame["delivery_start"], utc=True)
|
||||||
|
db_frame = db_frame.set_index("delivery_start").sort_index()
|
||||||
|
return db_frame
|
||||||
|
|
||||||
|
def _fetch_inputs_with_secondary_cache(
|
||||||
|
self,
|
||||||
|
country_code: str,
|
||||||
|
start: pd.Timestamp,
|
||||||
|
end: pd.Timestamp,
|
||||||
|
required_columns: list[str],
|
||||||
|
api_fetcher,
|
||||||
|
) -> pd.DataFrame:
|
||||||
|
start_ts, end_ts = _normalize_utc_bounds(start, end)
|
||||||
|
expected_index = pd.date_range(start=start_ts, end=end_ts, freq="h", inclusive="left")
|
||||||
|
from_observations = self._load_observations(country_code, start_ts, end_ts, required_columns)
|
||||||
|
|
||||||
|
if from_observations.empty:
|
||||||
|
complete_index = pd.DatetimeIndex([], tz="UTC")
|
||||||
|
else:
|
||||||
|
complete_mask = from_observations[required_columns].notna().all(axis=1)
|
||||||
|
complete_index = from_observations.index[complete_mask]
|
||||||
|
|
||||||
|
missing_index = expected_index.difference(complete_index)
|
||||||
|
|
||||||
|
fetched_parts: list[pd.DataFrame] = []
|
||||||
|
for missing_start, missing_end in _missing_ranges(missing_index):
|
||||||
|
fetched = api_fetcher(country_code, missing_start, missing_end)
|
||||||
|
fetched = fetched.reindex(columns=required_columns)
|
||||||
|
fetched_parts.append(fetched)
|
||||||
|
|
||||||
|
if fetched_parts:
|
||||||
|
fetched_frame = pd.concat(fetched_parts).sort_index()
|
||||||
|
self.upsert_raw_data(country_code=country_code, frame=fetched_frame)
|
||||||
|
else:
|
||||||
|
fetched_frame = pd.DataFrame(columns=required_columns, index=pd.DatetimeIndex([], tz="UTC"))
|
||||||
|
|
||||||
|
combined = pd.concat([from_observations, fetched_frame]).sort_index()
|
||||||
|
combined = combined[~combined.index.duplicated(keep="last")]
|
||||||
|
combined = combined.reindex(expected_index)
|
||||||
|
return combined[required_columns]
|
||||||
|
|
||||||
|
def fetch_inputs(
|
||||||
|
self, country_code: str, start: pd.Timestamp, end: pd.Timestamp
|
||||||
|
) -> pd.DataFrame:
|
||||||
|
resolved_country_code = self.resolve_country_code(country_code)
|
||||||
|
price = _coerce_single_column_frame(
|
||||||
|
self.get_day_ahead_prices(resolved_country_code, start, end),
|
||||||
|
target_column="day_ahead_price",
|
||||||
|
preferred_tokens=("price", "ahead"),
|
||||||
|
)
|
||||||
|
load = _coerce_single_column_frame(
|
||||||
|
self.get_load_forecast(resolved_country_code, start, end),
|
||||||
|
target_column="load_forecast",
|
||||||
|
preferred_tokens=("load",),
|
||||||
|
)
|
||||||
|
wind_solar = self.get_wind_solar_forecast(resolved_country_code, start, end)
|
||||||
|
df = price.join(load, how="outer").join(wind_solar, how="outer").sort_index()
|
||||||
|
return df
|
||||||
|
|
||||||
|
def upsert_raw_data(self, country_code: str, frame: pd.DataFrame) -> None:
|
||||||
|
rows = []
|
||||||
|
for ts, row in frame.iterrows():
|
||||||
|
rows.append(
|
||||||
|
{
|
||||||
|
"country_code": country_code,
|
||||||
|
"delivery_start": ts.to_pydatetime(),
|
||||||
|
"day_ahead_price": _safe_float(row.get("day_ahead_price")),
|
||||||
|
"load_forecast": _safe_float(row.get("load_forecast")),
|
||||||
|
"wind_forecast": _safe_float(row.get("wind_forecast")),
|
||||||
|
"solar_forecast": _safe_float(row.get("solar_forecast")),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
if not rows:
|
||||||
|
return
|
||||||
|
|
||||||
|
with self.engine.begin() as conn:
|
||||||
|
conn.execute(
|
||||||
|
text(
|
||||||
|
"""
|
||||||
|
INSERT INTO electricity_market_observations (
|
||||||
|
country_code,
|
||||||
|
delivery_start,
|
||||||
|
day_ahead_price,
|
||||||
|
load_forecast,
|
||||||
|
wind_forecast,
|
||||||
|
solar_forecast
|
||||||
|
) VALUES (
|
||||||
|
:country_code,
|
||||||
|
:delivery_start,
|
||||||
|
:day_ahead_price,
|
||||||
|
:load_forecast,
|
||||||
|
:wind_forecast,
|
||||||
|
:solar_forecast
|
||||||
|
)
|
||||||
|
ON CONFLICT (country_code, delivery_start) DO UPDATE
|
||||||
|
SET day_ahead_price = COALESCE(EXCLUDED.day_ahead_price, electricity_market_observations.day_ahead_price),
|
||||||
|
load_forecast = COALESCE(EXCLUDED.load_forecast, electricity_market_observations.load_forecast),
|
||||||
|
wind_forecast = COALESCE(EXCLUDED.wind_forecast, electricity_market_observations.wind_forecast),
|
||||||
|
solar_forecast = COALESCE(EXCLUDED.solar_forecast, electricity_market_observations.solar_forecast),
|
||||||
|
ingested_at = NOW()
|
||||||
|
"""
|
||||||
|
),
|
||||||
|
rows,
|
||||||
|
)
|
||||||
@@ -0,0 +1,58 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import math
|
||||||
|
|
||||||
|
import pandas as pd
|
||||||
|
|
||||||
|
|
||||||
|
def _cyclical_encode(values: pd.Series, period: int, prefix: str) -> pd.DataFrame:
|
||||||
|
angle = 2.0 * math.pi * values / period
|
||||||
|
return pd.DataFrame(
|
||||||
|
{f"{prefix}_sin": angle.apply(math.sin), f"{prefix}_cos": angle.apply(math.cos)},
|
||||||
|
index=values.index,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def build_feature_frame(inputs: pd.DataFrame, max_lag: int = 24) -> pd.DataFrame:
|
||||||
|
"""
|
||||||
|
Build feature set for electricity price forecasting.
|
||||||
|
|
||||||
|
Included:
|
||||||
|
- day_ahead_price, load_forecast, wind_forecast, solar_forecast
|
||||||
|
- residual_load
|
||||||
|
- lagged_price(t-1..t-24)
|
||||||
|
- lagged_residual_load(t-1..t-24)
|
||||||
|
- hour/week_day/month cyclical encodings
|
||||||
|
"""
|
||||||
|
df = inputs.copy().sort_index()
|
||||||
|
required = {"day_ahead_price", "load_forecast", "wind_forecast", "solar_forecast"}
|
||||||
|
missing = required.difference(df.columns)
|
||||||
|
if missing:
|
||||||
|
raise ValueError(f"Missing required input columns: {sorted(missing)}")
|
||||||
|
|
||||||
|
# Preserve source missingness semantics for downstream users.
|
||||||
|
# We keep NaNs instead of imputing to 0.0 so missing data is explicit.
|
||||||
|
df["residual_load"] = df["load_forecast"] - df["wind_forecast"] - df["solar_forecast"]
|
||||||
|
|
||||||
|
for lag in range(1, max_lag + 1):
|
||||||
|
df[f"lagged_price_{lag}"] = df["day_ahead_price"].shift(lag)
|
||||||
|
df[f"lagged_residual_load_{lag}"] = df["residual_load"].shift(lag)
|
||||||
|
|
||||||
|
time_index = df.index
|
||||||
|
if time_index.tz is None:
|
||||||
|
time_index = time_index.tz_localize("UTC")
|
||||||
|
else:
|
||||||
|
time_index = time_index.tz_convert("UTC")
|
||||||
|
|
||||||
|
cyclical = [
|
||||||
|
_cyclical_encode(pd.Series(time_index.hour, index=df.index), 24, "hour_of_day"),
|
||||||
|
_cyclical_encode(pd.Series(time_index.weekday, index=df.index), 7, "weekday"),
|
||||||
|
_cyclical_encode(pd.Series(time_index.month - 1, index=df.index), 12, "month"),
|
||||||
|
]
|
||||||
|
for cyc in cyclical:
|
||||||
|
df = df.join(cyc)
|
||||||
|
|
||||||
|
# Only enforce warmup/history constraints for price lags.
|
||||||
|
# Other feature columns can remain NaN when source data is missing.
|
||||||
|
required_for_row = ["day_ahead_price", *[f"lagged_price_{lag}" for lag in range(1, max_lag + 1)]]
|
||||||
|
return df.dropna(subset=required_for_row).sort_index()
|
||||||
@@ -0,0 +1,146 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
|
import pandas as pd
|
||||||
|
from entsoe import EntsoePandasClient
|
||||||
|
from sqlalchemy import text
|
||||||
|
from sqlalchemy.engine import Engine
|
||||||
|
|
||||||
|
from .entsoe_api import EntsoeDataService
|
||||||
|
from .features import build_feature_frame
|
||||||
|
|
||||||
|
|
||||||
|
def _safe_float(value) -> float | None:
|
||||||
|
if pd.isna(value):
|
||||||
|
return None
|
||||||
|
return float(value)
|
||||||
|
|
||||||
|
|
||||||
|
def _validate_inputs_have_signal(inputs: pd.DataFrame, country_code: str) -> None:
|
||||||
|
required = ["day_ahead_price", "load_forecast", "wind_forecast", "solar_forecast"]
|
||||||
|
non_null_counts = {col: int(inputs[col].notna().sum()) for col in required if col in inputs.columns}
|
||||||
|
if non_null_counts and all(count == 0 for count in non_null_counts.values()):
|
||||||
|
raise ValueError(
|
||||||
|
"No ENTSO-E data available for "
|
||||||
|
f"'{country_code}' in the requested time window. "
|
||||||
|
"Try another bidding zone/time range, and ensure the API key has access."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def persist_feature_frame(engine: Engine, country_code: str, feature_frame: pd.DataFrame) -> None:
|
||||||
|
# Feature-store schema expects core numeric fields to be non-null.
|
||||||
|
# Keep NaNs in the returned DataFrame, but skip incomplete rows on DB write.
|
||||||
|
persistable = feature_frame.dropna(
|
||||||
|
subset=["day_ahead_price", "load_forecast", "wind_forecast", "solar_forecast", "residual_load"]
|
||||||
|
)
|
||||||
|
|
||||||
|
rows = []
|
||||||
|
for ts, row in persistable.iterrows():
|
||||||
|
lag_price = [float(row[f"lagged_price_{lag}"]) for lag in range(1, 25)]
|
||||||
|
lag_residual = [float(row[f"lagged_residual_load_{lag}"]) for lag in range(1, 25)]
|
||||||
|
|
||||||
|
rows.append(
|
||||||
|
{
|
||||||
|
"country_code": country_code,
|
||||||
|
"delivery_start": ts.to_pydatetime(),
|
||||||
|
"day_ahead_price": float(row["day_ahead_price"]),
|
||||||
|
"load_forecast": _safe_float(row["load_forecast"]),
|
||||||
|
"wind_forecast": _safe_float(row["wind_forecast"]),
|
||||||
|
"solar_forecast": _safe_float(row["solar_forecast"]),
|
||||||
|
"residual_load": _safe_float(row["residual_load"]),
|
||||||
|
"lagged_price": lag_price,
|
||||||
|
"lagged_residual_load": lag_residual,
|
||||||
|
"hour_of_day_sin": float(row["hour_of_day_sin"]),
|
||||||
|
"hour_of_day_cos": float(row["hour_of_day_cos"]),
|
||||||
|
"weekday_sin": float(row["weekday_sin"]),
|
||||||
|
"weekday_cos": float(row["weekday_cos"]),
|
||||||
|
"month_sin": float(row["month_sin"]),
|
||||||
|
"month_cos": float(row["month_cos"]),
|
||||||
|
"feature_version": "v1",
|
||||||
|
"created_at": datetime.now(timezone.utc),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
if not rows:
|
||||||
|
return
|
||||||
|
|
||||||
|
with engine.begin() as conn:
|
||||||
|
conn.execute(
|
||||||
|
text(
|
||||||
|
"""
|
||||||
|
INSERT INTO electricity_price_features (
|
||||||
|
country_code,
|
||||||
|
delivery_start,
|
||||||
|
day_ahead_price,
|
||||||
|
load_forecast,
|
||||||
|
wind_forecast,
|
||||||
|
solar_forecast,
|
||||||
|
residual_load,
|
||||||
|
lagged_price,
|
||||||
|
lagged_residual_load,
|
||||||
|
hour_of_day_sin,
|
||||||
|
hour_of_day_cos,
|
||||||
|
weekday_sin,
|
||||||
|
weekday_cos,
|
||||||
|
month_sin,
|
||||||
|
month_cos,
|
||||||
|
feature_version,
|
||||||
|
created_at
|
||||||
|
) VALUES (
|
||||||
|
:country_code,
|
||||||
|
:delivery_start,
|
||||||
|
:day_ahead_price,
|
||||||
|
:load_forecast,
|
||||||
|
:wind_forecast,
|
||||||
|
:solar_forecast,
|
||||||
|
:residual_load,
|
||||||
|
:lagged_price,
|
||||||
|
:lagged_residual_load,
|
||||||
|
:hour_of_day_sin,
|
||||||
|
:hour_of_day_cos,
|
||||||
|
:weekday_sin,
|
||||||
|
:weekday_cos,
|
||||||
|
:month_sin,
|
||||||
|
:month_cos,
|
||||||
|
:feature_version,
|
||||||
|
:created_at
|
||||||
|
)
|
||||||
|
ON CONFLICT (country_code, delivery_start, feature_version) DO UPDATE
|
||||||
|
SET day_ahead_price = EXCLUDED.day_ahead_price,
|
||||||
|
load_forecast = EXCLUDED.load_forecast,
|
||||||
|
wind_forecast = EXCLUDED.wind_forecast,
|
||||||
|
solar_forecast = EXCLUDED.solar_forecast,
|
||||||
|
residual_load = EXCLUDED.residual_load,
|
||||||
|
lagged_price = EXCLUDED.lagged_price,
|
||||||
|
lagged_residual_load = EXCLUDED.lagged_residual_load,
|
||||||
|
hour_of_day_sin = EXCLUDED.hour_of_day_sin,
|
||||||
|
hour_of_day_cos = EXCLUDED.hour_of_day_cos,
|
||||||
|
weekday_sin = EXCLUDED.weekday_sin,
|
||||||
|
weekday_cos = EXCLUDED.weekday_cos,
|
||||||
|
month_sin = EXCLUDED.month_sin,
|
||||||
|
month_cos = EXCLUDED.month_cos,
|
||||||
|
created_at = EXCLUDED.created_at
|
||||||
|
"""
|
||||||
|
),
|
||||||
|
rows,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def run_feature_pipeline(
|
||||||
|
engine: Engine,
|
||||||
|
entsoe_api_key: str,
|
||||||
|
country_code: str,
|
||||||
|
start: pd.Timestamp,
|
||||||
|
end: pd.Timestamp,
|
||||||
|
cache_ttl_hours: int = 24,
|
||||||
|
) -> pd.DataFrame:
|
||||||
|
client = EntsoePandasClient(api_key=entsoe_api_key)
|
||||||
|
service = EntsoeDataService(client=client, engine=engine, cache_ttl_hours=cache_ttl_hours)
|
||||||
|
resolved_country_code = service.resolve_country_code(country_code)
|
||||||
|
inputs = service.fetch_inputs(country_code=resolved_country_code, start=start, end=end)
|
||||||
|
_validate_inputs_have_signal(inputs, resolved_country_code)
|
||||||
|
service.upsert_raw_data(country_code=resolved_country_code, frame=inputs)
|
||||||
|
features = build_feature_frame(inputs)
|
||||||
|
persist_feature_frame(engine, country_code=resolved_country_code, feature_frame=features)
|
||||||
|
return features
|
||||||
24
pyproject.toml
Normal file
24
pyproject.toml
Normal file
@@ -0,0 +1,24 @@
|
|||||||
|
[build-system]
|
||||||
|
requires = ["scikit-build-core>=0.5", "pybind11"]
|
||||||
|
build-backend = "scikit_build_core.build"
|
||||||
|
|
||||||
|
[project]
|
||||||
|
name = "qengine"
|
||||||
|
version = "0.1.0"
|
||||||
|
description = "Quant engine with C++ backend"
|
||||||
|
authors = [{name = "David"}]
|
||||||
|
requires-python = ">=3.10"
|
||||||
|
dependencies = [
|
||||||
|
"numpy",
|
||||||
|
"pandas",
|
||||||
|
"sqlalchemy",
|
||||||
|
"psycopg2-binary",
|
||||||
|
"yfinance",
|
||||||
|
]
|
||||||
|
|
||||||
|
[tool.scikit-build]
|
||||||
|
# Keep separate from a local `cmake -B build` tree (different generators: Ninja vs Makefiles).
|
||||||
|
build-dir = "skbuild-build"
|
||||||
|
cmake.version = ">=3.16"
|
||||||
|
cmake.build-type = "Release"
|
||||||
|
cmake.define.BUILD_TESTING = "OFF"
|
||||||
5
qengine/__init__.py
Normal file
5
qengine/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
"""Qengine: quant pricing backend (native extension in qengine.qengine)."""
|
||||||
|
|
||||||
|
from .qengine import bs_price
|
||||||
|
|
||||||
|
__all__ = ["bs_price"]
|
||||||
108
scripts/setup_postgres.py
Normal file
108
scripts/setup_postgres.py
Normal file
@@ -0,0 +1,108 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Idempotent PostgreSQL bootstrap script for the option_pricing project.
|
||||||
|
|
||||||
|
What it does:
|
||||||
|
1) Creates the project role if it does not exist.
|
||||||
|
2) Creates the project database if it does not exist.
|
||||||
|
3) Grants ownership/privileges.
|
||||||
|
4) Applies src/data/sql/schema.sql to the project database.
|
||||||
|
|
||||||
|
Configuration comes from environment variables (see .env.example).
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import psycopg2
|
||||||
|
from psycopg2 import sql
|
||||||
|
|
||||||
|
|
||||||
|
ROOT = Path(__file__).resolve().parents[1]
|
||||||
|
SCHEMA_PATH = ROOT / "src" / "data" / "sql" / "schema.sql"
|
||||||
|
|
||||||
|
|
||||||
|
def _env(name: str, default: str | None = None) -> str:
|
||||||
|
value = os.getenv(name, default)
|
||||||
|
if value is None:
|
||||||
|
raise RuntimeError(f"Missing required environment variable: {name}")
|
||||||
|
return value
|
||||||
|
|
||||||
|
|
||||||
|
def admin_connect(dbname: str):
|
||||||
|
return psycopg2.connect(
|
||||||
|
dbname=dbname,
|
||||||
|
user=_env("POSTGRES_ADMIN_USER", "postgres"),
|
||||||
|
password=_env("POSTGRES_ADMIN_PASSWORD", "postgres"),
|
||||||
|
host=_env("POSTGRES_ADMIN_HOST", "localhost"),
|
||||||
|
port=_env("POSTGRES_ADMIN_PORT", "5432"),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def ensure_role_and_database() -> None:
|
||||||
|
db_user = _env("DB_USER", "quant_user")
|
||||||
|
db_password = _env("DB_PASSWORD", "")
|
||||||
|
db_name = _env("DB_NAME", "options_db")
|
||||||
|
|
||||||
|
admin_db = _env("POSTGRES_ADMIN_DB", "postgres")
|
||||||
|
with admin_connect(admin_db) as conn:
|
||||||
|
conn.autocommit = True
|
||||||
|
with conn.cursor() as cur:
|
||||||
|
cur.execute("SELECT 1 FROM pg_roles WHERE rolname = %s", (db_user,))
|
||||||
|
role_exists = cur.fetchone() is not None
|
||||||
|
if not role_exists:
|
||||||
|
cur.execute(
|
||||||
|
sql.SQL("CREATE ROLE {} WITH LOGIN PASSWORD %s").format(
|
||||||
|
sql.Identifier(db_user)
|
||||||
|
),
|
||||||
|
(db_password,),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
cur.execute(
|
||||||
|
sql.SQL("ALTER ROLE {} WITH LOGIN PASSWORD %s").format(
|
||||||
|
sql.Identifier(db_user)
|
||||||
|
),
|
||||||
|
(db_password,),
|
||||||
|
)
|
||||||
|
|
||||||
|
cur.execute("SELECT 1 FROM pg_database WHERE datname = %s", (db_name,))
|
||||||
|
db_exists = cur.fetchone() is not None
|
||||||
|
if not db_exists:
|
||||||
|
cur.execute(
|
||||||
|
sql.SQL("CREATE DATABASE {} OWNER {}").format(
|
||||||
|
sql.Identifier(db_name),
|
||||||
|
sql.Identifier(db_user),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
cur.execute(
|
||||||
|
sql.SQL("ALTER DATABASE {} OWNER TO {}").format(
|
||||||
|
sql.Identifier(db_name),
|
||||||
|
sql.Identifier(db_user),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def apply_schema() -> None:
|
||||||
|
if not SCHEMA_PATH.exists():
|
||||||
|
raise FileNotFoundError(f"Schema file not found: {SCHEMA_PATH}")
|
||||||
|
|
||||||
|
schema_sql = SCHEMA_PATH.read_text(encoding="utf-8")
|
||||||
|
with admin_connect(_env("DB_NAME", "options_db")) as conn:
|
||||||
|
conn.autocommit = True
|
||||||
|
with conn.cursor() as cur:
|
||||||
|
cur.execute(schema_sql)
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> None:
|
||||||
|
print("Ensuring role/database exist...")
|
||||||
|
ensure_role_and_database()
|
||||||
|
print("Applying schema...")
|
||||||
|
apply_schema()
|
||||||
|
print("Database setup complete.")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
65
scripts/test_qengine_bindings.py
Normal file
65
scripts/test_qengine_bindings.py
Normal file
@@ -0,0 +1,65 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""Smoke test: use an installed `qengine` package (pip install .) or a dev build (cmake -> qengine/*.so)."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import math
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
# Running `python scripts/this.py` puts `scripts/` on sys.path, not the repo root
|
||||||
|
_REPO_ROOT = Path(__file__).resolve().parents[1]
|
||||||
|
if str(_REPO_ROOT) not in sys.path:
|
||||||
|
sys.path.insert(0, str(_REPO_ROOT))
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> int:
|
||||||
|
try:
|
||||||
|
import qengine
|
||||||
|
except ImportError as e:
|
||||||
|
print(
|
||||||
|
f"Import failed ({e}). Install the package (pip install .) or build with CMake so "
|
||||||
|
"qengine/qengine.*.so exists next to qengine/__init__.py.",
|
||||||
|
file=sys.stderr,
|
||||||
|
)
|
||||||
|
return 1
|
||||||
|
|
||||||
|
call = qengine.bs_price(100.0, 100.0, 1.0, 0.05, 0.2, True)
|
||||||
|
put = qengine.bs_price(100.0, 100.0, 1.0, 0.05, 0.2, False)
|
||||||
|
batch_list = qengine.bs_price(
|
||||||
|
[100.0, 100.0],
|
||||||
|
[100.0, 110.0],
|
||||||
|
[1.0, 1.0],
|
||||||
|
[0.05, 0.05],
|
||||||
|
[0.2, 0.2],
|
||||||
|
[True, False],
|
||||||
|
)
|
||||||
|
|
||||||
|
assert math.isfinite(call) and math.isfinite(put)
|
||||||
|
assert len(batch_list) == 2 and all(math.isfinite(x) for x in batch_list)
|
||||||
|
|
||||||
|
print("qengine.bs_price (call):", call)
|
||||||
|
print("qengine.bs_price (put):", put)
|
||||||
|
print("qengine.bs_price (list batch):", list(batch_list))
|
||||||
|
|
||||||
|
try:
|
||||||
|
import numpy as np
|
||||||
|
except ImportError:
|
||||||
|
print("ok: overloads callable (NumPy not installed; skipped ndarray batch test).")
|
||||||
|
return 0
|
||||||
|
|
||||||
|
s = np.array([100.0, 100.0], dtype=np.float64)
|
||||||
|
k = np.array([100.0, 110.0], dtype=np.float64)
|
||||||
|
t = np.array([1.0, 1.0], dtype=np.float64)
|
||||||
|
r = np.array([0.05, 0.05], dtype=np.float64)
|
||||||
|
sig = np.array([0.2, 0.2], dtype=np.float64)
|
||||||
|
opt = np.array([True, False], dtype=bool)
|
||||||
|
batch_np = qengine.bs_price(s, k, t, r, sig, opt)
|
||||||
|
assert len(batch_np) == 2 and all(math.isfinite(float(x)) for x in batch_np)
|
||||||
|
print("qengine.bs_price (ndarray batch):", [float(x) for x in batch_np])
|
||||||
|
print("ok: overloads callable.")
|
||||||
|
return 0
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
raise SystemExit(main())
|
||||||
@@ -1,33 +0,0 @@
|
|||||||
add_library(qengine
|
|
||||||
Instrument.cpp
|
|
||||||
Instrument.hpp
|
|
||||||
Payoff.cpp
|
|
||||||
Payoff.hpp
|
|
||||||
Option.cpp
|
|
||||||
Option.hpp
|
|
||||||
PricingEngine.cpp
|
|
||||||
PricingEngine.hpp
|
|
||||||
MonteCarloEngine.cpp
|
|
||||||
MonteCarloEngine.hpp
|
|
||||||
StochasticProcess.cpp
|
|
||||||
StochasticProcess.hpp
|
|
||||||
Exercise.cpp
|
|
||||||
Exercise.hpp
|
|
||||||
MarketData.cpp
|
|
||||||
MarketData.hpp
|
|
||||||
YieldCurve.cpp
|
|
||||||
YieldCurve.hpp
|
|
||||||
VolatilitySurface.cpp
|
|
||||||
VolatilitySurface.hpp
|
|
||||||
RandomGenerator.cpp
|
|
||||||
RandomGenerator.hpp
|
|
||||||
Statistics.cpp
|
|
||||||
Statistics.hpp
|
|
||||||
BlackScholesProcess.cpp
|
|
||||||
BlackScholesProcess.hpp
|
|
||||||
|
|
||||||
|
|
||||||
)
|
|
||||||
|
|
||||||
target_include_directories(qengine PUBLIC ${CMAKE_CURRENT_SOURCE_DIR})
|
|
||||||
target_link_libraries(qengine Eigen3::Eigen)
|
|
||||||
@@ -1,5 +0,0 @@
|
|||||||
//
|
|
||||||
// Created by David Doebel on 05.03.2026.
|
|
||||||
//
|
|
||||||
|
|
||||||
#include "Exercise.hpp"
|
|
||||||
0
src/ImpliedVolatility/__init__.py
Normal file
0
src/ImpliedVolatility/__init__.py
Normal file
49
src/ImpliedVolatility/compute_vls.py
Normal file
49
src/ImpliedVolatility/compute_vls.py
Normal file
@@ -0,0 +1,49 @@
|
|||||||
|
import numpy as np
|
||||||
|
import qengine
|
||||||
|
from scipy.optimize import brentq
|
||||||
|
|
||||||
|
|
||||||
|
def implied_vol(price, S, K, T, r, call):
|
||||||
|
"""
|
||||||
|
Implied vol for each row. Arguments may be scalars or 1-D arrays-like (same length).
|
||||||
|
"""
|
||||||
|
price = np.asarray(price, dtype=np.float64)
|
||||||
|
S = np.asarray(S, dtype=np.float64)
|
||||||
|
K = np.asarray(K, dtype=np.float64)
|
||||||
|
T = np.asarray(T, dtype=np.float64)
|
||||||
|
call = np.asarray(call, dtype=bool)
|
||||||
|
r = float(r)
|
||||||
|
|
||||||
|
scalar_in = price.ndim == 0
|
||||||
|
if scalar_in:
|
||||||
|
price = np.atleast_1d(price)
|
||||||
|
S = np.atleast_1d(S)
|
||||||
|
K = np.atleast_1d(K)
|
||||||
|
T = np.atleast_1d(T)
|
||||||
|
call = np.atleast_1d(call)
|
||||||
|
|
||||||
|
n = price.shape[0]
|
||||||
|
if (S.shape[0] != n or K.shape[0] != n or T.shape[0] != n or call.shape[0] != n):
|
||||||
|
raise ValueError(
|
||||||
|
f"implied_vol: length mismatch price={n}, S={S.shape[0]}, K={K.shape[0]}, "
|
||||||
|
f"T={T.shape[0]}, call={call.shape[0]}"
|
||||||
|
)
|
||||||
|
|
||||||
|
out = np.full(n, np.nan, dtype=np.float64)
|
||||||
|
for i in range(n):
|
||||||
|
p, s, k, t, c = float(price[i]), float(S[i]), float(K[i]), float(T[i]), bool(call[i])
|
||||||
|
if not np.isfinite(p) or not np.isfinite(s) or not np.isfinite(k) or not np.isfinite(t):
|
||||||
|
continue
|
||||||
|
if s <= 0 or k <= 0 or t <= 0:
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
def f(sig: float) -> float:
|
||||||
|
return qengine.bs_price(s, k, t, r, sig, c) - p
|
||||||
|
|
||||||
|
out[i] = brentq(f, 1e-6, 5.0)
|
||||||
|
except (ValueError, RuntimeError):
|
||||||
|
out[i] = np.nan
|
||||||
|
|
||||||
|
if scalar_in:
|
||||||
|
return float(out[0])
|
||||||
|
return out
|
||||||
0
src/ImpliedVolatility/setup.py
Normal file
0
src/ImpliedVolatility/setup.py
Normal file
1023
src/ImpliedVolatility/svi.py
Normal file
1023
src/ImpliedVolatility/svi.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -1,5 +0,0 @@
|
|||||||
//
|
|
||||||
// Created by David Doebel on 05.03.2026.
|
|
||||||
//
|
|
||||||
|
|
||||||
#include "PricingEngine.hpp"
|
|
||||||
@@ -1,5 +0,0 @@
|
|||||||
//
|
|
||||||
// Created by David Doebel on 05.03.2026.
|
|
||||||
//
|
|
||||||
|
|
||||||
#include "StochasticProcess.hpp"
|
|
||||||
@@ -1,5 +0,0 @@
|
|||||||
//
|
|
||||||
// Created by David Doebel on 06.03.2026.
|
|
||||||
//
|
|
||||||
|
|
||||||
#include "VolatilitySurface.hpp"
|
|
||||||
@@ -1,18 +0,0 @@
|
|||||||
//
|
|
||||||
// Created by David Doebel on 06.03.2026.
|
|
||||||
//
|
|
||||||
|
|
||||||
#ifndef QUANTENGINE_VOLATILITYSURFACE_HPP
|
|
||||||
#define QUANTENGINE_VOLATILITYSURFACE_HPP
|
|
||||||
|
|
||||||
|
|
||||||
class VolatilitySurface {
|
|
||||||
public:
|
|
||||||
virtual ~VolatilitySurface() = default;
|
|
||||||
virtual double sigma(double K, double T) const = 0;
|
|
||||||
private:
|
|
||||||
|
|
||||||
};
|
|
||||||
|
|
||||||
|
|
||||||
#endif //QUANTENGINE_VOLATILITYSURFACE_HPP
|
|
||||||
@@ -1,5 +0,0 @@
|
|||||||
//
|
|
||||||
// Created by David Doebel on 06.03.2026.
|
|
||||||
//
|
|
||||||
|
|
||||||
#include "YieldCurve.hpp"
|
|
||||||
0
src/__init__.py
Normal file
0
src/__init__.py
Normal file
0
src/data/__init__.py
Normal file
0
src/data/__init__.py
Normal file
@@ -1,14 +0,0 @@
|
|||||||
DB_CONFIG = {
|
|
||||||
"host": "localhost",
|
|
||||||
"port": 5432,
|
|
||||||
"database": "options_db",
|
|
||||||
"user": "quant_user",
|
|
||||||
"password": "strong_password",
|
|
||||||
}
|
|
||||||
|
|
||||||
PIPELINE_CONFIG = {
|
|
||||||
"symbols": [
|
|
||||||
"SPY"
|
|
||||||
# Example: "SPY"
|
|
||||||
]
|
|
||||||
}
|
|
||||||
@@ -1,13 +1,15 @@
|
|||||||
import psycopg2
|
import pandas as pd
|
||||||
|
|
||||||
conn = psycopg2.connect(
|
from option_pricing.src.data.ingestion.db_connect import db_engine
|
||||||
dbname="options_db",
|
|
||||||
user="quant_user",
|
|
||||||
password="strong_password",
|
|
||||||
host="144.91.73.49",
|
|
||||||
port="5432"
|
|
||||||
)
|
|
||||||
|
|
||||||
cursor = conn.cursor()
|
|
||||||
cursor.execute("SELECT * FROM underlyings;")
|
def fetch_underlyings() -> pd.DataFrame:
|
||||||
print(cursor.fetchall())
|
"""
|
||||||
|
Fetch all entries from the underlyings table using configured DB credentials.
|
||||||
|
"""
|
||||||
|
engine = db_engine()
|
||||||
|
return pd.read_sql("SELECT * FROM underlyings;", engine)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
print(fetch_underlyings())
|
||||||
0
src/data/ingestion/__init__.py
Normal file
0
src/data/ingestion/__init__.py
Normal file
3
src/data/ingestion/config/__init__.py
Normal file
3
src/data/ingestion/config/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
from .settings import DB_CONFIG, PIPELINE_CONFIG
|
||||||
|
|
||||||
|
__all__ = ["DB_CONFIG", "PIPELINE_CONFIG"]
|
||||||
31
src/data/ingestion/config/settings.py
Normal file
31
src/data/ingestion/config/settings.py
Normal file
@@ -0,0 +1,31 @@
|
|||||||
|
import os
|
||||||
|
|
||||||
|
|
||||||
|
def _get_env_int(name: str, default: int) -> int:
|
||||||
|
raw = os.getenv(name)
|
||||||
|
if raw is None:
|
||||||
|
return default
|
||||||
|
try:
|
||||||
|
return int(raw)
|
||||||
|
except ValueError as exc:
|
||||||
|
raise ValueError(f"Environment variable {name} must be an integer, got '{raw}'") from exc
|
||||||
|
|
||||||
|
|
||||||
|
def _get_env_list(name: str, default: list[str]) -> list[str]:
|
||||||
|
raw = os.getenv(name)
|
||||||
|
if not raw:
|
||||||
|
return default
|
||||||
|
return [x.strip() for x in raw.split(",") if x.strip()]
|
||||||
|
|
||||||
|
|
||||||
|
DB_CONFIG = {
|
||||||
|
"host": os.getenv("DB_HOST", "localhost"),
|
||||||
|
"port": _get_env_int("DB_PORT", 5432),
|
||||||
|
"database": os.getenv("DB_NAME", "options_db"),
|
||||||
|
"user": os.getenv("DB_USER", "quant_user"),
|
||||||
|
"password": os.getenv("DB_PASSWORD", ""),
|
||||||
|
}
|
||||||
|
|
||||||
|
PIPELINE_CONFIG = {
|
||||||
|
"symbols": _get_env_list("PIPELINE_SYMBOLS", ["SPY"]),
|
||||||
|
}
|
||||||
13
src/data/ingestion/db_connect.py
Normal file
13
src/data/ingestion/db_connect.py
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
from sqlalchemy import create_engine
|
||||||
|
from option_pricing.src.data.ingestion.config.settings import DB_CONFIG
|
||||||
|
|
||||||
|
def build_db_url() -> str:
|
||||||
|
return (
|
||||||
|
f"postgresql+psycopg2://{DB_CONFIG['user']}:{DB_CONFIG['password']}"
|
||||||
|
f"@{DB_CONFIG['host']}:{DB_CONFIG['port']}/{DB_CONFIG['database']}"
|
||||||
|
)
|
||||||
|
|
||||||
|
def db_engine():
|
||||||
|
db_url = build_db_url()
|
||||||
|
engine = create_engine(db_url, future=True)
|
||||||
|
return engine
|
||||||
4
src/data/ingestion/fred_data_ingestion.py
Normal file
4
src/data/ingestion/fred_data_ingestion.py
Normal file
@@ -0,0 +1,4 @@
|
|||||||
|
from fredapi import Fred
|
||||||
|
fred = Fred(api_key='471be0178bfc20ce10bb93e3fcceee3b')
|
||||||
|
data = fred.get_series_latest_release('DTB3')
|
||||||
|
print(data.tail())
|
||||||
100
src/data/ingestion/ingest_ubs_comparison.py
Normal file
100
src/data/ingestion/ingest_ubs_comparison.py
Normal file
@@ -0,0 +1,100 @@
|
|||||||
|
from datetime import datetime, timedelta
|
||||||
|
import pandas as pd
|
||||||
|
import yfinance as yf
|
||||||
|
|
||||||
|
from db_connect import db_engine
|
||||||
|
from sqlalchemy import create_engine
|
||||||
|
from sqlalchemy.dialects.postgresql import insert
|
||||||
|
from sqlalchemy import MetaData, Table
|
||||||
|
|
||||||
|
# --- CONFIG ---
|
||||||
|
TICKERS = ["UBS", "^GSPC"]
|
||||||
|
DAYS_BACK = 31 # ~3 weeks
|
||||||
|
TABLE_NAME = "prices"
|
||||||
|
|
||||||
|
def fetch_data(tickers, start_date, end_date):
|
||||||
|
data = yf.download(
|
||||||
|
tickers,
|
||||||
|
start=start_date,
|
||||||
|
end=end_date,
|
||||||
|
group_by="ticker",
|
||||||
|
auto_adjust=True,
|
||||||
|
progress=False
|
||||||
|
)
|
||||||
|
return data
|
||||||
|
|
||||||
|
def get_asset_map(engine):
|
||||||
|
query = "SELECT id, ticker FROM assets"
|
||||||
|
df = pd.read_sql(query, engine)
|
||||||
|
return dict(zip(df["ticker"], df["id"]))
|
||||||
|
|
||||||
|
|
||||||
|
def transform_data(raw_data):
|
||||||
|
frames = []
|
||||||
|
|
||||||
|
for ticker in raw_data.columns.levels[0]:
|
||||||
|
df = raw_data[ticker].copy()
|
||||||
|
df["ticker"] = ticker
|
||||||
|
df = df.reset_index()
|
||||||
|
|
||||||
|
# Keep only what we need
|
||||||
|
df = df[["Date", "ticker", "Close", "Volume"]]
|
||||||
|
|
||||||
|
df.rename(columns={
|
||||||
|
"Date": "date",
|
||||||
|
"Close": "close",
|
||||||
|
"Volume": "volume"
|
||||||
|
}, inplace=True)
|
||||||
|
|
||||||
|
# Compute daily returns
|
||||||
|
df["return"] = df["close"].pct_change()
|
||||||
|
|
||||||
|
frames.append(df)
|
||||||
|
return pd.concat(frames, ignore_index=True)
|
||||||
|
|
||||||
|
|
||||||
|
def load_to_postgres(df, engine):
|
||||||
|
asset_map = get_asset_map(engine)
|
||||||
|
df["asset_id"] = df["ticker"].map(asset_map)
|
||||||
|
df = df.drop(columns=["ticker"])
|
||||||
|
|
||||||
|
metadata = MetaData()
|
||||||
|
prices = Table(TABLE_NAME, metadata, autoload_with=engine)
|
||||||
|
|
||||||
|
with engine.begin() as conn:
|
||||||
|
for _, row in df.iterrows():
|
||||||
|
stmt = insert(prices).values({
|
||||||
|
"asset_id": row["asset_id"],
|
||||||
|
"date": row["date"],
|
||||||
|
"close": row["close"],
|
||||||
|
"volume": row["volume"],
|
||||||
|
"return": row["return"]
|
||||||
|
})
|
||||||
|
|
||||||
|
stmt = stmt.on_conflict_do_update(
|
||||||
|
index_elements=["asset_id", "date"],
|
||||||
|
set_={
|
||||||
|
"close": stmt.excluded.close,
|
||||||
|
"volume": stmt.excluded.volume,
|
||||||
|
"return": stmt.excluded["return"], # important change
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
conn.execute(stmt)
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
end_date = datetime.utcnow()
|
||||||
|
start_date = end_date - timedelta(days=DAYS_BACK)
|
||||||
|
|
||||||
|
raw = fetch_data(TICKERS, start_date, end_date)
|
||||||
|
df = transform_data(raw)
|
||||||
|
|
||||||
|
engine = db_engine()
|
||||||
|
load_to_postgres(df, engine)
|
||||||
|
|
||||||
|
print("Ingestion complete.")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
@@ -1,11 +1,11 @@
|
|||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
from decimal import Decimal, InvalidOperation
|
|
||||||
|
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
import yfinance as yf
|
import yfinance as yf
|
||||||
from sqlalchemy import create_engine, text
|
from sqlalchemy import text
|
||||||
|
|
||||||
from config.settings import DB_CONFIG, PIPELINE_CONFIG
|
from option_pricing.src.data.ingestion.config import DB_CONFIG, PIPELINE_CONFIG
|
||||||
|
from db_connect import db_engine
|
||||||
|
|
||||||
|
|
||||||
def build_db_url() -> str:
|
def build_db_url() -> str:
|
||||||
@@ -113,15 +113,15 @@ def get_or_create_contract(
|
|||||||
return result[0]
|
return result[0]
|
||||||
|
|
||||||
|
|
||||||
def insert_underlying_price(conn, underlying_id: int, price_timestamp: datetime, price: float):
|
def insert_underlying_price(conn, underlying_id: int, timestamp: datetime, price: float):
|
||||||
query = text("""
|
query = text("""
|
||||||
INSERT INTO underlying_prices (underlying_id, price_timestamp, price)
|
INSERT INTO underlying_prices (underlying_id, timestamp, price)
|
||||||
VALUES (:underlying_id, :price_timestamp, :price)
|
VALUES (:underlying_id, :timestamp, :price)
|
||||||
ON CONFLICT (underlying_id, price_timestamp) DO NOTHING
|
ON CONFLICT (underlying_id, timestamp) DO NOTHING
|
||||||
""")
|
""")
|
||||||
conn.execute(query, {
|
conn.execute(query, {
|
||||||
"underlying_id": underlying_id,
|
"underlying_id": underlying_id,
|
||||||
"price_timestamp": price_timestamp,
|
"timestamp": timestamp,
|
||||||
"price": price,
|
"price": price,
|
||||||
})
|
})
|
||||||
|
|
||||||
@@ -129,7 +129,7 @@ def insert_underlying_price(conn, underlying_id: int, price_timestamp: datetime,
|
|||||||
def insert_option_quote(
|
def insert_option_quote(
|
||||||
conn,
|
conn,
|
||||||
contract_id: int,
|
contract_id: int,
|
||||||
quote_timestamp: datetime,
|
timestamp: datetime,
|
||||||
bid,
|
bid,
|
||||||
ask,
|
ask,
|
||||||
mid,
|
mid,
|
||||||
@@ -140,19 +140,19 @@ def insert_option_quote(
|
|||||||
):
|
):
|
||||||
query = text("""
|
query = text("""
|
||||||
INSERT INTO option_quotes (
|
INSERT INTO option_quotes (
|
||||||
contract_id, quote_timestamp, bid, ask, mid,
|
contract_id, timestamp, bid, ask, mid,
|
||||||
last_price, implied_vol, volume, open_interest
|
last_price, implied_vol, volume, open_interest
|
||||||
)
|
)
|
||||||
VALUES (
|
VALUES (
|
||||||
:contract_id, :quote_timestamp, :bid, :ask, :mid,
|
:contract_id, :timestamp, :bid, :ask, :mid,
|
||||||
:last_price, :implied_vol, :volume, :open_interest
|
:last_price, :implied_vol, :volume, :open_interest
|
||||||
)
|
)
|
||||||
ON CONFLICT (contract_id, quote_timestamp) DO NOTHING
|
ON CONFLICT (contract_id, timestamp) DO NOTHING
|
||||||
""")
|
""")
|
||||||
|
|
||||||
conn.execute(query, {
|
conn.execute(query, {
|
||||||
"contract_id": contract_id,
|
"contract_id": contract_id,
|
||||||
"quote_timestamp": quote_timestamp,
|
"timestamp": timestamp,
|
||||||
"bid": bid,
|
"bid": bid,
|
||||||
"ask": ask,
|
"ask": ask,
|
||||||
"mid": mid,
|
"mid": mid,
|
||||||
@@ -163,7 +163,7 @@ def insert_option_quote(
|
|||||||
})
|
})
|
||||||
|
|
||||||
|
|
||||||
def process_option_dataframe(conn, df: pd.DataFrame, underlying_id: int, option_type: str, symbol: str, expiration_date, quote_timestamp: datetime):
|
def process_option_dataframe(conn, df: pd.DataFrame, underlying_id: int, option_type: str, symbol: str, expiration_date, timestamp: datetime):
|
||||||
style = infer_option_style(symbol)
|
style = infer_option_style(symbol)
|
||||||
|
|
||||||
for _, row in df.iterrows():
|
for _, row in df.iterrows():
|
||||||
@@ -194,7 +194,7 @@ def process_option_dataframe(conn, df: pd.DataFrame, underlying_id: int, option_
|
|||||||
insert_option_quote(
|
insert_option_quote(
|
||||||
conn=conn,
|
conn=conn,
|
||||||
contract_id=contract_id,
|
contract_id=contract_id,
|
||||||
quote_timestamp=quote_timestamp,
|
timestamp=timestamp,
|
||||||
bid=bid,
|
bid=bid,
|
||||||
ask=ask,
|
ask=ask,
|
||||||
mid=mid,
|
mid=mid,
|
||||||
@@ -215,7 +215,7 @@ def ingest_symbol(symbol: str, engine):
|
|||||||
print(f"No options found for {symbol}")
|
print(f"No options found for {symbol}")
|
||||||
return
|
return
|
||||||
|
|
||||||
quote_timestamp = datetime.now(timezone.utc)
|
timestamp = datetime.now(timezone.utc)
|
||||||
|
|
||||||
# Try to get spot price
|
# Try to get spot price
|
||||||
info = {}
|
info = {}
|
||||||
@@ -235,7 +235,7 @@ def ingest_symbol(symbol: str, engine):
|
|||||||
insert_underlying_price(
|
insert_underlying_price(
|
||||||
conn=conn,
|
conn=conn,
|
||||||
underlying_id=underlying_id,
|
underlying_id=underlying_id,
|
||||||
price_timestamp=quote_timestamp,
|
timestamp=timestamp,
|
||||||
price=float(spot_price),
|
price=float(spot_price),
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -252,7 +252,7 @@ def ingest_symbol(symbol: str, engine):
|
|||||||
option_type="call",
|
option_type="call",
|
||||||
symbol=symbol,
|
symbol=symbol,
|
||||||
expiration_date=expiration_date,
|
expiration_date=expiration_date,
|
||||||
quote_timestamp=quote_timestamp,
|
timestamp=timestamp,
|
||||||
)
|
)
|
||||||
|
|
||||||
process_option_dataframe(
|
process_option_dataframe(
|
||||||
@@ -262,15 +262,14 @@ def ingest_symbol(symbol: str, engine):
|
|||||||
option_type="put",
|
option_type="put",
|
||||||
symbol=symbol,
|
symbol=symbol,
|
||||||
expiration_date=expiration_date,
|
expiration_date=expiration_date,
|
||||||
quote_timestamp=quote_timestamp,
|
timestamp=timestamp,
|
||||||
)
|
)
|
||||||
|
|
||||||
print(f"Finished ingestion for {symbol}.")
|
print(f"Finished ingestion for {symbol}.")
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
db_url = build_db_url()
|
engine = db_engine()
|
||||||
engine = create_engine(db_url, future=True)
|
|
||||||
|
|
||||||
for symbol in PIPELINE_CONFIG["symbols"]:
|
for symbol in PIPELINE_CONFIG["symbols"]:
|
||||||
ingest_symbol(symbol, engine)
|
ingest_symbol(symbol, engine)
|
||||||
|
|||||||
436
src/data/load_data.py
Normal file
436
src/data/load_data.py
Normal file
@@ -0,0 +1,436 @@
|
|||||||
|
import numpy as np
|
||||||
|
import pandas as pd
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
|
||||||
|
from option_pricing.src.data.ingestion.db_connect import db_engine
|
||||||
|
from option_pricing.src.ImpliedVolatility.compute_vls import implied_vol
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize_quote_timestamp(df: pd.DataFrame) -> pd.DataFrame:
|
||||||
|
if "timestamp" not in df.columns and "quote_timestamp" in df.columns:
|
||||||
|
return df.rename(columns={"quote_timestamp": "timestamp"})
|
||||||
|
return df
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize_price_timestamp(df: pd.DataFrame) -> pd.DataFrame:
|
||||||
|
if "timestamp" not in df.columns and "price_timestamp" in df.columns:
|
||||||
|
return df.rename(columns={"price_timestamp": "timestamp"})
|
||||||
|
return df
|
||||||
|
|
||||||
|
|
||||||
|
def load_data():
|
||||||
|
engine = db_engine()
|
||||||
|
underlyings = pd.read_sql("SELECT * FROM underlyings;", engine)
|
||||||
|
underlying_prices = _normalize_price_timestamp(
|
||||||
|
pd.read_sql("SELECT * FROM underlying_prices;", engine)
|
||||||
|
)
|
||||||
|
option_quotes = _normalize_quote_timestamp(pd.read_sql("SELECT * FROM option_quotes;", engine))
|
||||||
|
option_contracts = pd.read_sql("SELECT * FROM option_contracts;", engine)
|
||||||
|
return underlyings, underlying_prices, option_quotes, option_contracts
|
||||||
|
|
||||||
|
|
||||||
|
def clean_data(data: pd.DataFrame):
|
||||||
|
data.dropna(inplace=True)
|
||||||
|
data = data[data["volume"] > 0]
|
||||||
|
data = data[data["open_interest"] > 10]
|
||||||
|
data["spread"] = data["ask"] - data["bid"]
|
||||||
|
#data = data[data["spread"] / data["mid"] < 1]
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
def merge_quotes_contracts(option_quotes: pd.DataFrame, option_contracts: pd.DataFrame):
|
||||||
|
if "timestamp" not in option_quotes.columns:
|
||||||
|
raise KeyError("option_quotes needs a quote time column ('timestamp' or 'quote_timestamp')")
|
||||||
|
|
||||||
|
option_quotes = option_quotes.groupby(["contract_id", "timestamp"], as_index=False).agg(
|
||||||
|
{
|
||||||
|
"bid": "mean",
|
||||||
|
"ask": "mean",
|
||||||
|
"mid": "mean",
|
||||||
|
"last_price": "mean",
|
||||||
|
"implied_vol": "mean",
|
||||||
|
"volume": "sum",
|
||||||
|
"open_interest": "sum",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
option_quotes = option_quotes.merge(
|
||||||
|
option_contracts, left_on="contract_id", right_on="id", how="left"
|
||||||
|
)
|
||||||
|
option_quotes["timestamp"] = pd.to_datetime(option_quotes["timestamp"])
|
||||||
|
option_quotes["expiration_date"] = pd.to_datetime(option_quotes["expiration_date"])
|
||||||
|
option_quotes["T"] = (
|
||||||
|
option_quotes["expiration_date"] - option_quotes["timestamp"]
|
||||||
|
).dt.total_seconds() / (365 * 24 * 3600)
|
||||||
|
return option_quotes
|
||||||
|
|
||||||
|
|
||||||
|
def compute_iv(option_quotes_contracts, underlying_prices):
|
||||||
|
df = option_quotes_contracts.copy()
|
||||||
|
up = _normalize_price_timestamp(underlying_prices.copy())
|
||||||
|
|
||||||
|
up["timestamp"] = pd.to_datetime(up["timestamp"])
|
||||||
|
up = up.sort_values("timestamp").drop_duplicates(
|
||||||
|
["underlying_id", "timestamp"], keep="last"
|
||||||
|
)
|
||||||
|
|
||||||
|
mask = df["T"] > 0
|
||||||
|
if not mask.any():
|
||||||
|
df["iv"] = np.nan
|
||||||
|
return df
|
||||||
|
|
||||||
|
sub = df.loc[mask].copy()
|
||||||
|
sub["_idx"] = sub.index
|
||||||
|
|
||||||
|
merged = sub.merge(
|
||||||
|
up[["underlying_id", "timestamp", "price"]],
|
||||||
|
on=["underlying_id", "timestamp"],
|
||||||
|
how="left",
|
||||||
|
validate="many_to_one",
|
||||||
|
)
|
||||||
|
|
||||||
|
# assign back using explicit index
|
||||||
|
df["spot"] = np.nan
|
||||||
|
df.loc[merged["_idx"], "spot"] = merged["price"].to_numpy()
|
||||||
|
|
||||||
|
price = merged["mid"].to_numpy(dtype=np.float64)
|
||||||
|
S = merged["price"].to_numpy(dtype=np.float64)
|
||||||
|
K = merged["strike"].to_numpy(dtype=np.float64)
|
||||||
|
T = merged["T"].to_numpy(dtype=np.float64)
|
||||||
|
call = (merged["option_type"] == "call").to_numpy()
|
||||||
|
|
||||||
|
|
||||||
|
df["iv"] = np.nan
|
||||||
|
df.loc[sub.index, "iv"] = implied_vol(price, S, K, T, 0.05, call)
|
||||||
|
return df
|
||||||
|
|
||||||
|
def fit_ivsimle(option_quotes_contracts):
|
||||||
|
from scipy.interpolate import UnivariateSpline
|
||||||
|
sort = option_quotes_contracts.sort_values("log_moneyness").dropna()
|
||||||
|
x = sort["log_moneyness"]
|
||||||
|
y = sort["iv"]
|
||||||
|
y_yahoo = sort["implied_vol"]
|
||||||
|
print(x,y,y_yahoo)
|
||||||
|
f = UnivariateSpline(x, y, s=None)
|
||||||
|
f_yahoo = UnivariateSpline(x, y_yahoo, s=None)
|
||||||
|
# plot the smile
|
||||||
|
x_lin = np.linspace(x.min(), x.max(), 200)
|
||||||
|
plt.plot(x_lin, f(x_lin), label="iv smile")
|
||||||
|
plt.plot(x_lin, f_yahoo(x_lin), label="yahoo iv smile")
|
||||||
|
plt.legend()
|
||||||
|
plt.savefig("iv_smile_fit.pdf")
|
||||||
|
|
||||||
|
|
||||||
|
return f
|
||||||
|
|
||||||
|
def calibrate_svi_surface(option_quotes_contracts: pd.DataFrame, r: float = 0.05, **kwargs):
|
||||||
|
"""
|
||||||
|
Fit SVI per expiry on ``iv`` from :func:`compute_iv` and plot diagnostics.
|
||||||
|
|
||||||
|
See :func:`option_pricing.src.ImpliedVolatility.svi.calibrate_from_option_frame`.
|
||||||
|
"""
|
||||||
|
from option_pricing.src.ImpliedVolatility.svi import calibrate_from_option_frame
|
||||||
|
|
||||||
|
return calibrate_from_option_frame(option_quotes_contracts, r=r, **kwargs)
|
||||||
|
|
||||||
|
def clean_before_svi(option_quotes_contracts: pd.DataFrame):
|
||||||
|
option_quotes_contracts = option_quotes_contracts[option_quotes_contracts["T"] > 0.05]
|
||||||
|
return option_quotes_contracts
|
||||||
|
|
||||||
|
|
||||||
|
def plot_smoothed_svi_surface(prep: pd.DataFrame, params: pd.DataFrame, r: float = 0.05):
|
||||||
|
"""
|
||||||
|
Plot independent slice fits after maturity smoothing.
|
||||||
|
|
||||||
|
Outputs:
|
||||||
|
- svi_smoothed_surface.pdf
|
||||||
|
- svi_calendar_violation_heatmap.pdf
|
||||||
|
"""
|
||||||
|
from option_pricing.src.ImpliedVolatility.svi import (
|
||||||
|
calendar_violation_matrix,
|
||||||
|
smooth_svi_parameters,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Build smoothed maturity-parameter curves from calibrated slice parameters
|
||||||
|
curves = smooth_svi_parameters(
|
||||||
|
params,
|
||||||
|
T_col="T_mean",
|
||||||
|
smooth_factor_a=1e-4,
|
||||||
|
smooth_factor_m=1e-4,
|
||||||
|
smooth_factor_others=0.0,
|
||||||
|
min_T=0.05,
|
||||||
|
weight_col="n_points",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Overlay market points and smoothed model by maturity
|
||||||
|
plot_df = prep.copy()
|
||||||
|
if "T" not in plot_df.columns or "total_var" not in plot_df.columns:
|
||||||
|
raise KeyError("prep must include columns 'T' and 'total_var'")
|
||||||
|
|
||||||
|
T_grid = np.sort(params.loc[params["success"], "T_mean"].to_numpy(dtype=np.float64))
|
||||||
|
if T_grid.size < 2:
|
||||||
|
return
|
||||||
|
k_grid = np.linspace(
|
||||||
|
float(plot_df["log_moneyness"].quantile(0.02)),
|
||||||
|
float(plot_df["log_moneyness"].quantile(0.98)),
|
||||||
|
180,
|
||||||
|
)
|
||||||
|
|
||||||
|
plt.figure(figsize=(11, 7))
|
||||||
|
cmap = plt.colormaps["viridis"]
|
||||||
|
nT = max(len(T_grid), 1)
|
||||||
|
for i, Ti in enumerate(T_grid):
|
||||||
|
color = cmap(i / max(nT - 1, 1)) if nT > 1 else cmap(0.5)
|
||||||
|
near = np.isclose(plot_df["T"].to_numpy(dtype=np.float64), Ti, rtol=0.03, atol=2e-3)
|
||||||
|
sub = plot_df.loc[near]
|
||||||
|
if sub.empty:
|
||||||
|
continue
|
||||||
|
# market IV points
|
||||||
|
iv_mkt = np.sqrt(
|
||||||
|
np.maximum(sub["total_var"].to_numpy(dtype=np.float64), 0.0)
|
||||||
|
/ np.maximum(Ti, 1e-12)
|
||||||
|
)
|
||||||
|
plt.scatter(
|
||||||
|
sub["log_moneyness"].to_numpy(dtype=np.float64),
|
||||||
|
iv_mkt,
|
||||||
|
s=10,
|
||||||
|
alpha=0.35,
|
||||||
|
color=color,
|
||||||
|
)
|
||||||
|
# smoothed curve IV
|
||||||
|
w_model = curves.total_var(k_grid, np.array([Ti], dtype=np.float64))[0]
|
||||||
|
iv_model = np.sqrt(np.maximum(w_model, 0.0) / np.maximum(Ti, 1e-12))
|
||||||
|
plt.plot(k_grid, iv_model, color=color, lw=2, label=f"T={Ti:.3f}")
|
||||||
|
|
||||||
|
plt.xlabel("log moneyness log(K/F)")
|
||||||
|
plt.ylabel("implied vol")
|
||||||
|
plt.title("SVI surface: market points vs smoothed maturity curves")
|
||||||
|
plt.grid(alpha=0.3)
|
||||||
|
plt.legend(fontsize=8, ncol=2)
|
||||||
|
plt.tight_layout()
|
||||||
|
plt.savefig("svi_smoothed_surface.pdf", bbox_inches="tight")
|
||||||
|
plt.clf()
|
||||||
|
|
||||||
|
# Calendar diagnostics from smoothed surface
|
||||||
|
cal_diff = calendar_violation_matrix(curves, T_grid, k_grid)
|
||||||
|
# diff shape: (len(T_grid)-1, len(k_grid)) where negative is violation
|
||||||
|
plt.figure(figsize=(11, 4))
|
||||||
|
im = plt.imshow(
|
||||||
|
cal_diff,
|
||||||
|
aspect="auto",
|
||||||
|
origin="lower",
|
||||||
|
cmap="coolwarm",
|
||||||
|
vmin=-0.02,
|
||||||
|
vmax=0.02,
|
||||||
|
extent=[k_grid.min(), k_grid.max(), 0, cal_diff.shape[0]],
|
||||||
|
)
|
||||||
|
plt.colorbar(im, label="w(T_{j+1},k)-w(T_j,k)")
|
||||||
|
plt.xlabel("log moneyness")
|
||||||
|
plt.ylabel("maturity step j")
|
||||||
|
plt.title("Calendar diagnostic heatmap (negative = violation)")
|
||||||
|
plt.tight_layout()
|
||||||
|
plt.savefig("svi_calendar_violation_heatmap.pdf", bbox_inches="tight")
|
||||||
|
plt.clf()
|
||||||
|
|
||||||
|
|
||||||
|
def _fit_slice_with_svi_py_model(
|
||||||
|
model: object,
|
||||||
|
model_name: str,
|
||||||
|
k: np.ndarray,
|
||||||
|
w: np.ndarray,
|
||||||
|
T: float,
|
||||||
|
*,
|
||||||
|
theta_ref: float,
|
||||||
|
prev_params: dict | None,
|
||||||
|
k_eval: np.ndarray,
|
||||||
|
) -> tuple[np.ndarray, dict]:
|
||||||
|
"""Fit one slice with a specific pysvi model and evaluate total variance on k_eval."""
|
||||||
|
T = float(T)
|
||||||
|
k = np.asarray(k, dtype=np.float64)
|
||||||
|
w = np.asarray(w, dtype=np.float64)
|
||||||
|
k_eval = np.asarray(k_eval, dtype=np.float64)
|
||||||
|
|
||||||
|
# ATM total variance proxy for models requiring theta
|
||||||
|
theta = float(np.interp(0.0, np.sort(k), w[np.argsort(k)]))
|
||||||
|
theta = max(theta, 1e-8)
|
||||||
|
|
||||||
|
kwargs: dict = {}
|
||||||
|
if model_name == "ssvi":
|
||||||
|
kwargs["theta"] = theta
|
||||||
|
elif model_name == "essvi":
|
||||||
|
kwargs["theta"] = theta
|
||||||
|
kwargs["theta_ref"] = max(float(theta_ref), 1e-8)
|
||||||
|
elif model_name in {"jumpwings", "jw"}:
|
||||||
|
kwargs["T"] = max(T, 1e-8)
|
||||||
|
|
||||||
|
# Option B: calendar penalty uses pysvi internal 200-point grid per current slice.
|
||||||
|
# Build w_prev on that exact grid to avoid shape mismatch.
|
||||||
|
if prev_params is not None:
|
||||||
|
k_cal = np.linspace(float(k.min()) - 0.5, float(k.max()) + 0.5, 200)
|
||||||
|
kwargs["w_prev"] = np.asarray(model.total_variance(k_cal, prev_params), dtype=np.float64)
|
||||||
|
|
||||||
|
params = model.calibrate(k, w, **kwargs)
|
||||||
|
if params is None:
|
||||||
|
raise RuntimeError(f"pysvi {model_name} calibration failed")
|
||||||
|
w_eval = model.total_variance(k_eval, params)
|
||||||
|
return np.asarray(w_eval, dtype=np.float64), params
|
||||||
|
|
||||||
|
|
||||||
|
def compare_vs_svi_py(prep: pd.DataFrame, params: pd.DataFrame):
|
||||||
|
"""
|
||||||
|
Compare in-house SVI fit against pysvi models with explicit no-arbitrage flags.
|
||||||
|
|
||||||
|
Outputs:
|
||||||
|
- svi_vs_pysvi_<model>_comparison.pdf for model in {svi, ssvi, essvi, jumpwings}
|
||||||
|
- svi_vs_pysvi_metrics.csv
|
||||||
|
"""
|
||||||
|
from option_pricing.src.ImpliedVolatility.svi import SVIParams
|
||||||
|
from pysvi import ArbitrageFreedom, get_model
|
||||||
|
|
||||||
|
ok_params = params[params["success"]].copy()
|
||||||
|
if ok_params.empty:
|
||||||
|
print("compare_vs_svi_py: no successful in-house slices; skipping.")
|
||||||
|
return
|
||||||
|
|
||||||
|
k_min = float(prep["log_moneyness"].quantile(0.02))
|
||||||
|
k_max = float(prep["log_moneyness"].quantile(0.98))
|
||||||
|
k_grid = np.linspace(k_min, k_max, 180)
|
||||||
|
|
||||||
|
models = ["svi", "ssvi", "essvi", "jumpwings"]
|
||||||
|
rows: list[dict] = []
|
||||||
|
|
||||||
|
# reference theta for eSSVI from in-house successful slices
|
||||||
|
theta_ref = float(np.median(ok_params["T_mean"].to_numpy(dtype=np.float64) * 0 + 1.0))
|
||||||
|
# Better theta_ref proxy from observed market ATM if available
|
||||||
|
theta_vals = []
|
||||||
|
for _, row in ok_params.iterrows():
|
||||||
|
Ti = float(row["T_mean"])
|
||||||
|
near = np.isclose(prep["T"].to_numpy(dtype=np.float64), Ti, rtol=0.03, atol=2e-3)
|
||||||
|
sub = prep.loc[near].sort_values("log_moneyness")
|
||||||
|
if len(sub) < 10:
|
||||||
|
continue
|
||||||
|
ks = sub["log_moneyness"].to_numpy(dtype=np.float64)
|
||||||
|
ws = sub["total_var"].to_numpy(dtype=np.float64)
|
||||||
|
theta_vals.append(float(np.interp(0.0, np.sort(ks), ws[np.argsort(ks)])))
|
||||||
|
if theta_vals:
|
||||||
|
theta_ref = float(np.median(theta_vals))
|
||||||
|
|
||||||
|
sorted_rows = list(ok_params.sort_values("T_mean").iterrows())
|
||||||
|
for model_name in models:
|
||||||
|
flags = ArbitrageFreedom.NO_BUTTERFLY | ArbitrageFreedom.NO_CALENDAR
|
||||||
|
model = get_model(model_name, flags)
|
||||||
|
plt.figure(figsize=(11, 7))
|
||||||
|
cmap = plt.colormaps["tab20"]
|
||||||
|
prev_params = None
|
||||||
|
n_used = 0
|
||||||
|
for _, row in sorted_rows:
|
||||||
|
Ti = float(row["T_mean"])
|
||||||
|
near = np.isclose(prep["T"].to_numpy(dtype=np.float64), Ti, rtol=0.03, atol=2e-3)
|
||||||
|
sub = prep.loc[near].sort_values("log_moneyness")
|
||||||
|
if len(sub) < 10:
|
||||||
|
continue
|
||||||
|
k = sub["log_moneyness"].to_numpy(dtype=np.float64)
|
||||||
|
w = sub["total_var"].to_numpy(dtype=np.float64)
|
||||||
|
|
||||||
|
p_ours = SVIParams(
|
||||||
|
float(row["a"]), float(row["b"]), float(row["rho"]), float(row["m"]), float(row["sigma"])
|
||||||
|
)
|
||||||
|
w_ours = p_ours.total_var(k_grid)
|
||||||
|
rmse_ours = float(np.sqrt(np.mean((p_ours.total_var(k) - w) ** 2)))
|
||||||
|
|
||||||
|
try:
|
||||||
|
w_ext, ext_params = _fit_slice_with_svi_py_model(
|
||||||
|
model,
|
||||||
|
model_name,
|
||||||
|
k,
|
||||||
|
w,
|
||||||
|
Ti,
|
||||||
|
theta_ref=theta_ref,
|
||||||
|
prev_params=prev_params,
|
||||||
|
k_eval=k_grid,
|
||||||
|
)
|
||||||
|
rmse_ext = float(np.sqrt(np.mean((np.interp(k, k_grid, w_ext) - w) ** 2)))
|
||||||
|
rows.append(
|
||||||
|
{
|
||||||
|
"model": model_name,
|
||||||
|
"T_mean": Ti,
|
||||||
|
"rmse_ours": rmse_ours,
|
||||||
|
"rmse_pysvi": rmse_ext,
|
||||||
|
"delta_rmse": rmse_ext - rmse_ours,
|
||||||
|
"ext_params": str(ext_params),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
color = cmap(n_used % 20)
|
||||||
|
n_used += 1
|
||||||
|
plt.plot(k_grid, np.sqrt(np.maximum(w_ours, 0) / max(Ti, 1e-12)), color=color, lw=2, alpha=0.9)
|
||||||
|
plt.plot(k_grid, np.sqrt(np.maximum(w_ext, 0) / max(Ti, 1e-12)), color=color, lw=1.5, ls="--", alpha=0.9)
|
||||||
|
prev_params = ext_params
|
||||||
|
except Exception as exc:
|
||||||
|
print(f"compare_vs_svi_py[{model_name}]: skipping T={Ti:.4f}, reason: {exc}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
if n_used == 0:
|
||||||
|
plt.close()
|
||||||
|
continue
|
||||||
|
|
||||||
|
plt.xlabel("log moneyness")
|
||||||
|
plt.ylabel("implied vol")
|
||||||
|
plt.title(f"In-house SVI (solid) vs pysvi {model_name} (dashed)")
|
||||||
|
plt.grid(alpha=0.3)
|
||||||
|
plt.tight_layout()
|
||||||
|
plt.savefig(f"svi_vs_pysvi_{model_name}_comparison.pdf", bbox_inches="tight")
|
||||||
|
plt.clf()
|
||||||
|
|
||||||
|
out = pd.DataFrame(rows)
|
||||||
|
if out.empty:
|
||||||
|
print("compare_vs_svi_py: no slices compared (pysvi unavailable or incompatible).")
|
||||||
|
return
|
||||||
|
out = out.sort_values(["model", "T_mean"])
|
||||||
|
out.to_csv("svi_vs_pysvi_metrics.csv", index=False)
|
||||||
|
print(out.groupby("model")[["rmse_ours", "rmse_pysvi", "delta_rmse"]].mean())
|
||||||
|
|
||||||
|
|
||||||
|
def plot_ivsmile(option_quotes_contracts):
|
||||||
|
option_quotes_contracts = option_quotes_contracts.sort_values("strike")
|
||||||
|
option_quotes_contracts["log_moneyness"] = np.log(
|
||||||
|
option_quotes_contracts["spot"] * np.exp(0.05 * option_quotes_contracts["T"])/option_quotes_contracts["strike"]
|
||||||
|
)
|
||||||
|
option_quotes_contracts = option_quotes_contracts[option_quotes_contracts["log_moneyness"].abs() < 0.2]
|
||||||
|
#option_quotes_contracts = option_quotes_contracts[option_quotes_contracts["mid"] > 0.2]
|
||||||
|
plt.plot(option_quotes_contracts["strike"], option_quotes_contracts["iv"], label="iv smile")
|
||||||
|
plt.plot(option_quotes_contracts["strike"], option_quotes_contracts["implied_vol"], label="i. vol")
|
||||||
|
plt.legend()
|
||||||
|
plt.savefig("iv_smile.pdf")
|
||||||
|
plt.xlabel("iv")
|
||||||
|
plt.ylabel("strike price")
|
||||||
|
plt.clf()
|
||||||
|
return option_quotes_contracts
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
underlyings, underlying_prices, option_quotes, option_contracts = load_data()
|
||||||
|
option_quotes_contracts = merge_quotes_contracts(option_quotes, option_contracts)
|
||||||
|
option_quotes_contracts = clean_data(option_quotes_contracts)
|
||||||
|
option_quotes_contracts = compute_iv(option_quotes_contracts, underlying_prices)
|
||||||
|
mask = option_quotes_contracts["iv"].notna()
|
||||||
|
print(option_quotes_contracts)
|
||||||
|
print(option_quotes_contracts.columns)
|
||||||
|
#plt.plot(option_quotes_contracts["contract_id"][mask], option_quotes_contracts["implied_vol"][mask], label="i. iv")
|
||||||
|
#plt.plot(option_quotes_contracts["contract_id"][mask],option_quotes_contracts["iv"][mask], label="comp. iv")
|
||||||
|
#plt.legend()
|
||||||
|
#plt.show()
|
||||||
|
option_quotes_contracts = plot_ivsmile(option_quotes_contracts)
|
||||||
|
fit_ivsimle(option_quotes_contracts)
|
||||||
|
prep, svi_fit, params = calibrate_svi_surface(
|
||||||
|
clean_before_svi(option_quotes_contracts),
|
||||||
|
r=0.05,
|
||||||
|
plot_backend="matplotlib",
|
||||||
|
finplot_show=True,
|
||||||
|
# optionally: plot_path=None to avoid matplotlib PDF output
|
||||||
|
)
|
||||||
|
print(svi_fit)
|
||||||
|
plot_smoothed_svi_surface(prep, params, r=0.05)
|
||||||
|
compare_vs_svi_py(prep, params)
|
||||||
|
|
||||||
@@ -0,0 +1,6 @@
|
|||||||
|
This folder is intentionally self-contained.
|
||||||
|
|
||||||
|
- No imports from the parent option_pricing package (no qengine, src/, cpp bindings).
|
||||||
|
- Third-party dependencies: numpy, matplotlib (see requirements.txt).
|
||||||
|
- Run: python run_experiment.py [--out lv_rmse.png]
|
||||||
|
- Safe to copy elsewhere or run in isolation.
|
||||||
Binary file not shown.
|
After Width: | Height: | Size: 148 KiB |
Binary file not shown.
|
After Width: | Height: | Size: 96 KiB |
Binary file not shown.
|
After Width: | Height: | Size: 143 KiB |
@@ -0,0 +1,108 @@
|
|||||||
|
"""
|
||||||
|
Gatheral local variance in total-variance / log-moneyness form (practitioner's guide).
|
||||||
|
|
||||||
|
sigma^2 = (d_T w) / ( 1 - (y/w) d_y w
|
||||||
|
+ (1/4)(-1/4 - 1/w + y^2/w^2) (d_y w)^2
|
||||||
|
+ (1/2) d_yy w )
|
||||||
|
|
||||||
|
where w = omega is total implied variance, y is log-moneyness (convention as in the note).
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
def local_variance_from_derivatives(
|
||||||
|
y: np.ndarray,
|
||||||
|
w: np.ndarray,
|
||||||
|
dy_w: np.ndarray,
|
||||||
|
dyy_w: np.ndarray,
|
||||||
|
dT_w: np.ndarray,
|
||||||
|
*,
|
||||||
|
eps: float = 1e-14,
|
||||||
|
) -> np.ndarray:
|
||||||
|
"""Vectorized Gatheral formula. Invalid / near-singular points become nan."""
|
||||||
|
y = np.asarray(y, dtype=float)
|
||||||
|
w = np.asarray(w, dtype=float)
|
||||||
|
dy_w = np.asarray(dy_w, dtype=float)
|
||||||
|
dyy_w = np.asarray(dyy_w, dtype=float)
|
||||||
|
dT_w = np.asarray(dT_w, dtype=float)
|
||||||
|
|
||||||
|
out = np.full_like(y, np.nan, dtype=float)
|
||||||
|
ok = np.isfinite(w) & (np.abs(w) > eps) & np.isfinite(dy_w) & np.isfinite(dyy_w) & np.isfinite(dT_w)
|
||||||
|
|
||||||
|
denom = np.empty_like(w)
|
||||||
|
denom[ok] = (
|
||||||
|
1.0
|
||||||
|
- (y[ok] / w[ok]) * dy_w[ok]
|
||||||
|
+ 0.25 * (-0.25 - 1.0 / w[ok] + (y[ok] ** 2) / (w[ok] ** 2)) * (dy_w[ok] ** 2)
|
||||||
|
+ 0.5 * dyy_w[ok]
|
||||||
|
)
|
||||||
|
|
||||||
|
ok2 = ok & (np.abs(denom) > eps)
|
||||||
|
out[ok2] = dT_w[ok2] / denom[ok2]
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
def quadratic_total_variance(
|
||||||
|
y: np.ndarray,
|
||||||
|
alpha: float,
|
||||||
|
beta: float,
|
||||||
|
gamma: float,
|
||||||
|
T: float,
|
||||||
|
) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
|
||||||
|
"""
|
||||||
|
w(y,T) = T * (alpha + beta*y + gamma*y^2), with derivatives as in the note:
|
||||||
|
|
||||||
|
d_T w = alpha + beta*y + gamma*y^2
|
||||||
|
d_y w = T * (beta + 2*gamma*y)
|
||||||
|
d_yy w = 2*gamma*T
|
||||||
|
"""
|
||||||
|
y = np.asarray(y, dtype=float)
|
||||||
|
f = alpha + beta * y + gamma * y ** 2
|
||||||
|
w = T * f
|
||||||
|
dT_w = f
|
||||||
|
dy_w = T * (beta + 2.0 * gamma * y)
|
||||||
|
dyy_w = np.full_like(y, 2.0 * gamma * T)
|
||||||
|
return w, dT_w, dy_w, dyy_w
|
||||||
|
|
||||||
|
|
||||||
|
def analytic_local_variance_quadratic(
|
||||||
|
y: np.ndarray,
|
||||||
|
alpha: float,
|
||||||
|
beta: float,
|
||||||
|
gamma: float,
|
||||||
|
T: float,
|
||||||
|
) -> np.ndarray:
|
||||||
|
"""Closed form from the note (equivalent to plugging derivatives into Gatheral)."""
|
||||||
|
y = np.asarray(y, dtype=float)
|
||||||
|
w, dT_w, dy_w, dyy_w = quadratic_total_variance(y, alpha, beta, gamma, T)
|
||||||
|
return local_variance_from_derivatives(y, w, dy_w, dyy_w, dT_w)
|
||||||
|
|
||||||
|
|
||||||
|
def central_first_derivative_uniform(w: np.ndarray, h: float) -> np.ndarray:
|
||||||
|
"""Interior (w[i+1]-w[i-1])/(2h); endpoints nan."""
|
||||||
|
w = np.asarray(w, dtype=float)
|
||||||
|
out = np.full_like(w, np.nan)
|
||||||
|
out[1:-1] = (w[2:] - w[:-2]) / (2.0 * h)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
def second_derivative_uniform(w: np.ndarray, h: float) -> np.ndarray:
|
||||||
|
"""Interior second difference / h^2; endpoints nan."""
|
||||||
|
w = np.asarray(w, dtype=float)
|
||||||
|
out = np.full_like(w, np.nan)
|
||||||
|
out[1:-1] = (w[2:] - 2.0 * w[1:-1] + w[:-2]) / (h ** 2)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
def add_multiplicative_noise(
|
||||||
|
w: np.ndarray,
|
||||||
|
sigma_noise: float,
|
||||||
|
rng: np.random.Generator,
|
||||||
|
) -> np.ndarray:
|
||||||
|
"""tilde w(y_i) = w(y_i) * (1 + eps), eps ~ N(0, sigma_noise^2)."""
|
||||||
|
w = np.asarray(w, dtype=float)
|
||||||
|
eps = rng.normal(0.0, sigma_noise, size=w.shape)
|
||||||
|
return w * (1.0 + eps)
|
||||||
Binary file not shown.
|
After Width: | Height: | Size: 154 KiB |
@@ -0,0 +1,2 @@
|
|||||||
|
numpy>=1.20
|
||||||
|
matplotlib>=3.5
|
||||||
@@ -0,0 +1,309 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Local-volatility instability experiment (Gatheral total variance in log-moneyness).
|
||||||
|
|
||||||
|
We compare the analytic local variance σ²(y) from a quadratic total variance
|
||||||
|
w(y,T) = T(α + βy + γy²) to σ² reconstructed from a noisy discrete surface
|
||||||
|
w̃(y_i) = w(y_i)(1 + ε_i) using finite differences in y, for several levels of
|
||||||
|
multiplicative noise σ_noise. This script only produces the figure: RMSE of the
|
||||||
|
FD reconstruction vs σ_noise (log–log), with a y = σ reference line of slope 1.
|
||||||
|
|
||||||
|
Dependencies: numpy, matplotlib only (see INDEPENDENT_STANDALONE.txt).
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
from typing import Literal
|
||||||
|
|
||||||
|
# Prevent accidental imports from the parent repository
|
||||||
|
_REPO_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))
|
||||||
|
if _REPO_ROOT in sys.path:
|
||||||
|
sys.path.remove(_REPO_ROOT)
|
||||||
|
|
||||||
|
import matplotlib as mpl
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from gatheral_local_vol import (
|
||||||
|
add_multiplicative_noise,
|
||||||
|
analytic_local_variance_quadratic,
|
||||||
|
central_first_derivative_uniform,
|
||||||
|
local_variance_from_derivatives,
|
||||||
|
quadratic_total_variance,
|
||||||
|
second_derivative_uniform,
|
||||||
|
)
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Defaults (quadratic total variance, positive w on y ∈ [-0.5, 0.5])
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
ALPHA = 0.04
|
||||||
|
BETA = 0.0
|
||||||
|
GAMMA = 0.1
|
||||||
|
T_MATURITY = 1.0
|
||||||
|
Y_MIN = -0.5
|
||||||
|
Y_MAX = 0.5
|
||||||
|
N_GRID = 201
|
||||||
|
|
||||||
|
|
||||||
|
def ensure_parent_dir(path: str) -> None:
|
||||||
|
parent = os.path.dirname(os.path.abspath(path))
|
||||||
|
if parent:
|
||||||
|
os.makedirs(parent, exist_ok=True)
|
||||||
|
|
||||||
|
|
||||||
|
def log_uniform_sigma_grid(n_points: int, sigma_min: float, sigma_max: float) -> np.ndarray:
|
||||||
|
"""
|
||||||
|
Return `n_points` values of σ_noise with log₁₀(σ) equally spaced.
|
||||||
|
|
||||||
|
This is the correct sampling for a log–log RMSE plot; it is not linspace(σ_min, σ_max).
|
||||||
|
"""
|
||||||
|
n_points = max(4, n_points)
|
||||||
|
if sigma_min <= 0 or sigma_max <= 0 or sigma_max < sigma_min:
|
||||||
|
raise ValueError("Require 0 < sigma_min <= sigma_max.")
|
||||||
|
return np.logspace(np.log10(sigma_min), np.log10(sigma_max), n_points)
|
||||||
|
|
||||||
|
|
||||||
|
def relative_pointwise_error(
|
||||||
|
sigma2_analytic: np.ndarray, sigma2_fd: np.ndarray, eps: float = 1e-12
|
||||||
|
) -> np.ndarray:
|
||||||
|
return (sigma2_fd - sigma2_analytic) / np.maximum(np.abs(sigma2_analytic), eps)
|
||||||
|
|
||||||
|
|
||||||
|
def rmse_absolute(
|
||||||
|
sigma2_analytic: np.ndarray,
|
||||||
|
sigma2_fd: np.ndarray,
|
||||||
|
interior: slice,
|
||||||
|
) -> float:
|
||||||
|
"""RMSE of (σ²_FD − σ²_analytic) on interior indices."""
|
||||||
|
sa = np.asarray(sigma2_analytic, dtype=float)[interior]
|
||||||
|
sf = np.asarray(sigma2_fd, dtype=float)[interior]
|
||||||
|
m = np.isfinite(sa) & np.isfinite(sf)
|
||||||
|
if not np.any(m):
|
||||||
|
return float("nan")
|
||||||
|
d = sf[m] - sa[m]
|
||||||
|
return float(np.sqrt(np.mean(d * d)))
|
||||||
|
|
||||||
|
|
||||||
|
def rmse_relative(
|
||||||
|
sigma2_analytic: np.ndarray,
|
||||||
|
sigma2_fd: np.ndarray,
|
||||||
|
interior: slice,
|
||||||
|
eps: float = 1e-12,
|
||||||
|
) -> float:
|
||||||
|
"""RMSE over grid points of relative error (σ²_FD − σ²_analytic) / |σ²_analytic|."""
|
||||||
|
re = relative_pointwise_error(sigma2_analytic, sigma2_fd, eps=eps)[interior]
|
||||||
|
m = np.isfinite(re)
|
||||||
|
if not np.any(m):
|
||||||
|
return float("nan")
|
||||||
|
return float(np.sqrt(np.mean(re[m] ** 2)))
|
||||||
|
|
||||||
|
|
||||||
|
def local_variance_one_draw(
|
||||||
|
y: np.ndarray,
|
||||||
|
h: float,
|
||||||
|
alpha: float,
|
||||||
|
beta: float,
|
||||||
|
gamma: float,
|
||||||
|
T: float,
|
||||||
|
sigma_noise: float,
|
||||||
|
rng: np.random.Generator,
|
||||||
|
dT_mode: Literal["exact", "noisy_ratio"],
|
||||||
|
) -> tuple[np.ndarray, np.ndarray]:
|
||||||
|
"""One noisy surface and FD local variance; returns (σ²_analytic, σ²_FD)."""
|
||||||
|
w_true, dT_w_true, _, _ = quadratic_total_variance(y, alpha, beta, gamma, T)
|
||||||
|
sigma2_a = analytic_local_variance_quadratic(y, alpha, beta, gamma, T)
|
||||||
|
|
||||||
|
w_tilde = add_multiplicative_noise(w_true, sigma_noise, rng)
|
||||||
|
dy = central_first_derivative_uniform(w_tilde, h)
|
||||||
|
dyy = second_derivative_uniform(w_tilde, h)
|
||||||
|
|
||||||
|
if dT_mode == "exact":
|
||||||
|
dT = dT_w_true
|
||||||
|
elif dT_mode == "noisy_ratio":
|
||||||
|
dT = w_tilde / T
|
||||||
|
else:
|
||||||
|
raise ValueError(dT_mode)
|
||||||
|
|
||||||
|
sigma2_fd = local_variance_from_derivatives(y, w_tilde, dy, dyy, dT)
|
||||||
|
return sigma2_a, sigma2_fd
|
||||||
|
|
||||||
|
|
||||||
|
def rmse_curves_averaged(
|
||||||
|
y: np.ndarray,
|
||||||
|
h: float,
|
||||||
|
alpha: float,
|
||||||
|
beta: float,
|
||||||
|
gamma: float,
|
||||||
|
T: float,
|
||||||
|
sigma_grid: np.ndarray,
|
||||||
|
rng: np.random.Generator,
|
||||||
|
dT_mode: Literal["exact", "noisy_ratio"],
|
||||||
|
interior: slice,
|
||||||
|
trials_per_sigma: int,
|
||||||
|
) -> tuple[np.ndarray, np.ndarray]:
|
||||||
|
"""
|
||||||
|
For each σ in `sigma_grid`, average RMSE (relative and absolute) over
|
||||||
|
`trials_per_sigma` independent noise draws.
|
||||||
|
"""
|
||||||
|
rel: list[float] = []
|
||||||
|
abs_: list[float] = []
|
||||||
|
trials_per_sigma = max(1, trials_per_sigma)
|
||||||
|
|
||||||
|
for sig in sigma_grid:
|
||||||
|
tr: list[float] = []
|
||||||
|
ta: list[float] = []
|
||||||
|
for _ in range(trials_per_sigma):
|
||||||
|
sa, sf = local_variance_one_draw(
|
||||||
|
y, h, alpha, beta, gamma, T, float(sig), rng, dT_mode
|
||||||
|
)
|
||||||
|
tr.append(rmse_relative(sa, sf, interior))
|
||||||
|
ta.append(rmse_absolute(sa, sf, interior))
|
||||||
|
rel.append(float(np.nanmean(tr)))
|
||||||
|
abs_.append(float(np.nanmean(ta)))
|
||||||
|
|
||||||
|
return np.asarray(rel, dtype=float), np.asarray(abs_, dtype=float)
|
||||||
|
|
||||||
|
|
||||||
|
def plot_rmse_vs_noise(
|
||||||
|
sigma_grid: np.ndarray,
|
||||||
|
rmse_rel: np.ndarray,
|
||||||
|
rmse_abs: np.ndarray,
|
||||||
|
*,
|
||||||
|
h: float,
|
||||||
|
T: float,
|
||||||
|
dT_mode: str,
|
||||||
|
trials_per_sigma: int,
|
||||||
|
) -> mpl.figure.Figure:
|
||||||
|
"""
|
||||||
|
Log–log plot: RMSE (relative and absolute in σ²) vs σ_noise, reference y = σ.
|
||||||
|
"""
|
||||||
|
fig, ax = plt.subplots(figsize=(5.8, 3.8), constrained_layout=True)
|
||||||
|
|
||||||
|
x = np.asarray(sigma_grid, dtype=float)
|
||||||
|
pos = x > 0
|
||||||
|
n = len(x)
|
||||||
|
ms = 3.5 if n > 50 else 4.5
|
||||||
|
|
||||||
|
ax.loglog(
|
||||||
|
x[pos],
|
||||||
|
rmse_rel[pos],
|
||||||
|
"o-",
|
||||||
|
ms=ms,
|
||||||
|
lw=1.25,
|
||||||
|
label=r"RMSE of relative error $(\sigma^2_{\mathrm{FD}}-\sigma^2_{\mathrm{nat}})/|\sigma^2_{\mathrm{nat}}|$",
|
||||||
|
zorder=3,
|
||||||
|
)
|
||||||
|
ax.loglog(
|
||||||
|
x[pos],
|
||||||
|
rmse_abs[pos],
|
||||||
|
"s--",
|
||||||
|
ms=ms - 1,
|
||||||
|
lw=1.0,
|
||||||
|
alpha=0.9,
|
||||||
|
label=r"RMSE of $\sigma^2$ error $|\sigma^2_{\mathrm{FD}}-\sigma^2_{\mathrm{nat}}|$",
|
||||||
|
zorder=2,
|
||||||
|
)
|
||||||
|
|
||||||
|
s_lo, s_hi = float(x[pos].min()), float(x[pos].max())
|
||||||
|
ax.loglog([s_lo, s_hi], [s_lo, s_hi], ":", color="0.4", lw=2.0, zorder=1, label=r"reference slope 1: $y=\sigma_{\mathrm{noise}}$")
|
||||||
|
|
||||||
|
ax.set_xlabel(r"$\sigma_{\mathrm{noise}}$ (multiplicative noise on $\tilde{w}$)")
|
||||||
|
ax.set_ylabel("RMSE (interior $y$)")
|
||||||
|
subtitle = f"$T={T}$, $h={h:.4f}$, $\\partial_T w$: {dT_mode}"
|
||||||
|
if trials_per_sigma > 1:
|
||||||
|
subtitle += f", mean over {trials_per_sigma} draws per $\\sigma$"
|
||||||
|
ax.set_title("FD local variance: RMSE vs noise\n" + subtitle, fontsize=10)
|
||||||
|
ax.grid(True, which="both", alpha=0.35)
|
||||||
|
ax.legend(loc="best", fontsize=8, framealpha=0.95)
|
||||||
|
|
||||||
|
return fig
|
||||||
|
|
||||||
|
|
||||||
|
def configure_matplotlib_style() -> None:
|
||||||
|
"""Conservative defaults suitable for print."""
|
||||||
|
mpl.rcParams.update(
|
||||||
|
{
|
||||||
|
"figure.dpi": 120,
|
||||||
|
"savefig.dpi": 300,
|
||||||
|
"font.size": 10,
|
||||||
|
"axes.labelsize": 10,
|
||||||
|
"axes.titlesize": 10,
|
||||||
|
"legend.fontsize": 8,
|
||||||
|
"axes.grid": True,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> None:
|
||||||
|
configure_matplotlib_style()
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
description="RMSE of finite-difference local variance vs multiplicative noise (single figure).",
|
||||||
|
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
|
||||||
|
)
|
||||||
|
parser.add_argument("--seed", type=int, default=42, help="RNG seed.")
|
||||||
|
parser.add_argument(
|
||||||
|
"--out",
|
||||||
|
type=str,
|
||||||
|
default="lv_rmse.png",
|
||||||
|
help="Output image path.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--dT-mode",
|
||||||
|
choices=("exact", "noisy_ratio"),
|
||||||
|
default="exact",
|
||||||
|
help="Treatment of ∂_T w when w is replaced by noisy w̃ on the grid.",
|
||||||
|
)
|
||||||
|
parser.add_argument("--rmse-points", type=int, default=35, help="Number of σ_noise values (log-uniform).")
|
||||||
|
parser.add_argument("--rmse-sigma-min", type=float, default=1e-5, help="Smallest σ_noise.")
|
||||||
|
parser.add_argument("--rmse-sigma-max", type=float, default=5e-4, help="Largest σ_noise.")
|
||||||
|
parser.add_argument(
|
||||||
|
"--rmse-trials",
|
||||||
|
type=int,
|
||||||
|
default=50,
|
||||||
|
help="Independent noisy surfaces per σ_noise; RMSE is averaged.",
|
||||||
|
)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
rng = np.random.default_rng(args.seed)
|
||||||
|
y = np.linspace(Y_MIN, Y_MAX, N_GRID)
|
||||||
|
h = float(y[1] - y[0])
|
||||||
|
interior = slice(1, -1)
|
||||||
|
|
||||||
|
sigma_grid = log_uniform_sigma_grid(args.rmse_points, args.rmse_sigma_min, args.rmse_sigma_max)
|
||||||
|
rmse_rel, rmse_abs = rmse_curves_averaged(
|
||||||
|
y,
|
||||||
|
h,
|
||||||
|
ALPHA,
|
||||||
|
BETA,
|
||||||
|
GAMMA,
|
||||||
|
T_MATURITY,
|
||||||
|
sigma_grid,
|
||||||
|
rng,
|
||||||
|
args.dT_mode,
|
||||||
|
interior,
|
||||||
|
args.rmse_trials,
|
||||||
|
)
|
||||||
|
|
||||||
|
fig = plot_rmse_vs_noise(
|
||||||
|
sigma_grid,
|
||||||
|
rmse_rel,
|
||||||
|
rmse_abs,
|
||||||
|
h=h,
|
||||||
|
T=T_MATURITY,
|
||||||
|
dT_mode=args.dT_mode,
|
||||||
|
trials_per_sigma=args.rmse_trials,
|
||||||
|
)
|
||||||
|
|
||||||
|
ensure_parent_dir(args.out)
|
||||||
|
fig.savefig(args.out, bbox_inches="tight")
|
||||||
|
print(f"Wrote {args.out}")
|
||||||
|
plt.close(fig)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
@@ -3,6 +3,7 @@
|
|||||||
//
|
//
|
||||||
|
|
||||||
#include <gtest/gtest.h>
|
#include <gtest/gtest.h>
|
||||||
|
#include "BlackScholesClosedFormEngine.hpp"
|
||||||
#include "BlackScholesProcess.hpp"
|
#include "BlackScholesProcess.hpp"
|
||||||
#include "MonteCarloEngine.hpp"
|
#include "MonteCarloEngine.hpp"
|
||||||
#include "Instrument.hpp"
|
#include "Instrument.hpp"
|
||||||
@@ -51,3 +52,28 @@ TEST(BlackScholesProcess, ExpectedValue) {
|
|||||||
ASSERT_NEAR(callPrice, callGT, tol);
|
ASSERT_NEAR(callPrice, callGT, tol);
|
||||||
ASSERT_NEAR(putPrice, putGT, tol);
|
ASSERT_NEAR(putPrice, putGT, tol);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(BlackScholesClosedForm, MatchesReference) {
|
||||||
|
const double K = 100.0;
|
||||||
|
const double T = 1.0;
|
||||||
|
|
||||||
|
const MarketData marketData(
|
||||||
|
100.0,
|
||||||
|
std::make_shared<FlatYieldCurve>(0.01),
|
||||||
|
std::make_shared<FlatVolatilitySurface>(0.2));
|
||||||
|
|
||||||
|
auto processCall = std::make_unique<BlackScholesProcess>(marketData);
|
||||||
|
auto processPut = std::make_unique<BlackScholesProcess>(marketData);
|
||||||
|
|
||||||
|
auto analyticCall = std::make_unique<BlackScholesClosedFormEngine>(std::move(processCall));
|
||||||
|
auto analyticPut = std::make_unique<BlackScholesClosedFormEngine>(std::move(processPut));
|
||||||
|
|
||||||
|
Instrument callInstr(T, std::make_unique<CallPayoff>(K), std::move(analyticCall));
|
||||||
|
Instrument putInstr(T, std::make_unique<PutPayoff>(K), std::move(analyticPut));
|
||||||
|
|
||||||
|
const double callGT = 8.4333186901;
|
||||||
|
const double putGT = 7.4383020650;
|
||||||
|
|
||||||
|
ASSERT_NEAR(callInstr.price(), callGT, 1e-9);
|
||||||
|
ASSERT_NEAR(putInstr.price(), putGT, 1e-9);
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user