{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "## Multi-fidelity optimization" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Some experiments are very expensive. They can be supplemented by cheaper alternatives, such as simplified experiments or high-throughput calculations. These alternatives yield measurements of lower *fidelity*, which the planner can use to guide the high-fidelity optimization.\n", "\n", "The same idea applies to virtual screening. Expensive quantum chemistry calculations can be supplemented by faster semi-empirical methods. Another example is the virtual screening of compounds for drug activity, where high-fidelity free-energy perturbation calculations are approximated by faster, lower-fidelity docking calculations." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/home/garyk/mambaforge/envs/atlas/lib/python3.9/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", " from .autonotebook import tqdm as notebook_tqdm\n" ] } ], "source": [ "import json\n", "import pickle\n", "import numpy as np\n", "import pandas as pd\n", "from copy import deepcopy\n", "\n", "from olympus.datasets import Dataset\n", "from olympus.objects import (\n", "\tParameterContinuous,\n", "\tParameterDiscrete,\n", "\tParameterCategorical,\n", "\tParameterVector\n", ")\n", "from olympus.campaigns import ParameterSpace, Campaign\n", "\n", "from atlas.planners.multi_fidelity.planner import MultiFidelityPlanner # planner designed specifically for multi-fidelity optimization" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "For this example, we will perform a screening of the bandgap of perovskites. 
There are two measurement fidelities: one using GGA (*low*) and one using HSE06 (*high*). You can assign any cost to each fidelity, but here we treat a GGA query as 10 times cheaper than an HSE06 query." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "COST_BUDGET = 50 # this time the budget is a cost\n", "NUM_INIT_DESIGN = 10\n", "NUM_CHEAP = 8 # one high-fidelity query every NUM_CHEAP iterations; the rest are low fidelity" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Here we create an additional fidelity parameter `s`, which can only take the permitted fidelity values. The `MultiFidelityPlanner` is allowed to vary this parameter, so it optimizes over the perovskite components plus one constrained *fidelity* parameter." ] }, { "cell_type": "code", "execution_count": 25, "metadata": {}, "outputs": [], "source": [ "dataset = Dataset(kind='perovskites')\n", "\n", "# build parameter space\n", "param_space = ParameterSpace()\n", "\n", "# fidelity param\n", "param_space.add(ParameterDiscrete(name='s', options=[0.1, 1.0], low=0.1, high=1.0))\n", "for param in dataset.param_space: # add perovskite component parameters ('organic', 'cation', and 'anion')\n", "\tparam_space.add(param)\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# lower-fidelity data calculated using GGA is available in the examples folder,\n", "# so we load it here and use it to build the measurement function\n", "# fill in the ATLAS_PATH\n", "ATLAS_PATH = '.'\n", "with open(f'{ATLAS_PATH}/examples/multi_fidelity/perovskites/lookup/lookup_table.pkl', 'rb') as f:\n", "\tLOOKUP = pickle.load(f)\n", "\n", "def measure(params, s):\n", "\t# high-fidelity is hse06, low-fidelity is gga\n", "\tif s == 1.0:\n", "\t\tmeasurement = np.amin(\n", "\t\t\tLOOKUP[params.organic.capitalize()][params.cation][params.anion]['bandgap_hse06']\n", "\t\t)\n", "\telif s == 0.1:\n", "\t\tmeasurement = np.amin(\n", 
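"\t\t\tLOOKUP[params.organic.capitalize()][params.cation][params.anion]['bandgap_gga']\n", "\t\t)\n", "\telse:\n", "\t\t# guard against fidelities other than the two permitted values\n", "\t\traise ValueError(f'unknown fidelity: {s}')\n", "\treturn measurement\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [ "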
───────────────────────────────────────────────────────────────────────────────────────────────────────────────────\n",
       "
\n" ], "text/plain": [ "\u001b[38;2;5;25;35m───────────────────────────────────────────────────────────────────────────────────────────────────────────────────\u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
                                                                                                                   \n",
       "                                                 Welcome to ATLAS!                                                 \n",
       "
\n" ], "text/plain": [ "\u001b[1;38;2;5;166;251m \u001b[0m\u001b[1;38;2;5;166;251m \u001b[0m\u001b[1;38;2;5;166;251m \u001b[0m\n", "\u001b[1;38;2;5;166;251m \u001b[0m\u001b[1;38;2;5;166;251mWelcome to ATLAS!\u001b[0m\u001b[1;38;2;5;166;251m \u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
                                                Made with 💕 in 🇨🇦                                                 \n",
       "                                                                                                                   \n",
       "
\n" ], "text/plain": [ "\u001b[38;2;0;100;148m \u001b[0m\u001b[38;2;0;100;148mMade with 💕 in 🇨🇦\u001b[0m\u001b[38;2;0;100;148m \u001b[0m\n", "\u001b[38;2;0;100;148m \u001b[0m\u001b[38;2;0;100;148m \u001b[0m\u001b[38;2;0;100;148m \u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
───────────────────────────────────────────────────────────────────────────────────────────────────────────────────\n",
       "
\n" ], "text/plain": [ "\u001b[38;2;5;25;35m───────────────────────────────────────────────────────────────────────────────────────────────────────────────────\u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
───────────────────────────── Initial design phase ─────────────────────────────\n",
       "
\n" ], "text/plain": [ "\u001b[38;2;5;166;251m───────────────────────────── Initial design phase ─────────────────────────────\u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "campaign = Campaign()\n", "campaign.set_param_space(param_space)\n", "\n", "planner = MultiFidelityPlanner(\n", "    goal='minimize',\n", "    init_design_strategy='random',\n", "    num_init_design=NUM_INIT_DESIGN,\n", "    use_descriptors=True,\n", "    batch_size=1,\n", "    acquisition_optimizer_kind='pymoo', # this is required\n", "    fidelity_params=0, # this dimension is the fidelity parameter (we use the first one)\n", "    fidelities=[0.1, 1.], # these are the possible fidelities (GGA = 0.1, and HSE = 1.0)\n", ")\n", "\n", "planner.set_param_space(param_space)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# lowest hse06 bandgap over all candidates, used as the stopping target\n", "# (assumes every combination of options is present in LOOKUP)\n", "options = {p.name: p.options for p in dataset.param_space}\n", "min_hse06_bandgap = min(\n", "    np.amin(LOOKUP[o.capitalize()][c][a]['bandgap_hse06'])\n", "    for o in options['organic']\n", "    for c in options['cation']\n", "    for a in options['anion']\n", ")\n", "\n", "# accumulated cost, the budget is also cost\n", "COST = 0.\n", "\n", "target_rec_measurements = []\n", "iter_ = 0\n", "while COST < COST_BUDGET:\n", "    print(f'\\nITER : {iter_+1}\\tCOST : {COST}\\n')\n", "\n", "    # every NUM_CHEAP-th iteration queries the high fidelity,\n", "    # all other iterations query the low fidelity\n", "    if iter_ % NUM_CHEAP == 0:\n", "        planner.set_ask_fidelity(1.0)\n", "    else:\n", "        planner.set_ask_fidelity(0.1)\n", "\n", "    samples = planner.recommend(campaign.observations)\n", "    for sample in samples:\n", "        measurement = measure(sample, sample.s)\n", "        campaign.add_observation(sample, measurement)\n", "        # the fidelity value doubles as the query cost (GGA = 0.1, HSE06 = 1.0)\n", "        COST += sample.s\n", "\n", "        print('SAMPLE : ', sample)\n", "        print('MEASUREMENT : ', measurement)\n", "\n", "    iter_+=1\n", "\n", "    # do a check to see if model will find the optimal\n", "    if campaign.num_obs > NUM_INIT_DESIGN:\n", "        # make greedy recommendation on the target fidelity\n", "        # use this to make a high-fidelity measurement\n", "        rec_sample = planner.recommend_target_fidelity(batch_size=1)[0]\n", "        rec_measurement = measure(rec_sample, rec_sample.s)\n", "        print('REC SAMPLE : ', rec_sample)\n", "        print('REC MEASUREMENT : ', rec_measurement)\n", "\n", "        
target_rec_measurements.append(rec_measurement)\n", "        # kill the run if we have found the lowest hse06 bandgap\n", "        # on the most recent high-fidelity measurement\n", "        if rec_measurement == min_hse06_bandgap:\n", "            print('found the min hse06 bandgap!')\n", "            break\n", "    else:\n", "        target_rec_measurements.append(measurement)\n", "        if measurement == min_hse06_bandgap and samples[0].s == 1.:\n", "            print('found the min hse06 bandgap!')\n", "            break" ] } ], "metadata": { "kernelspec": { "display_name": "atlas", "language": "python", "name": "atlas" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.20" }, "orig_nbformat": 4 }, "nbformat": 4, "nbformat_minor": 2 }