#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# @Author: José Sánchez-Gallego (gallegoj@uw.edu)
# @Date: 2023-03-10
# @Filename: tools.py
# @License: BSD 3-clause (http://www.opensource.org/licenses/BSD-3-Clause)
from __future__ import annotations
import asyncio
import concurrent.futures
import datetime
import functools
import hashlib
import logging
import os
import pathlib
import re
import subprocess
import tempfile
import warnings
from contextlib import asynccontextmanager, contextmanager, suppress
from functools import partial, wraps
from typing import (
TYPE_CHECKING,
Any,
AsyncGenerator,
Awaitable,
Callable,
Coroutine,
Generator,
Literal,
Sequence,
TypedDict,
TypeVar,
cast,
overload,
)
import httpx
import numpy
import peewee
import polars
import redis
from astropy import units as uu
from astropy.coordinates import angular_separation as astropy_angular_separation
from redis import asyncio as aioredis
from clu import AMQPClient
from sdsstools import get_sjd
from gort import config
from gort.exceptions import ErrorCode, GortError
if TYPE_CHECKING:
from clu import AMQPClient, AMQPReply
from gort.devices.telescope import FibSel
from gort.gort import Gort
__all__ = [
"get_valid_variable_name",
"get_next_tile_id",
"get_calibrators",
"get_next_tile_id_sync",
"get_calibrators_sync",
"register_observation",
"get_ccd_frame_path",
"move_mask_interval",
"angular_separation",
"get_db_connection",
"redis_client_async",
"redis_client_sync",
"run_in_executor",
"is_interactive",
"is_notebook",
"cancel_task",
"get_ephemeris_summary",
"get_ephemeris_summary_sync",
"get_temporary_file_path",
"insert_to_database",
"get_md5sum_file",
"get_md5sum_from_spectro",
"get_md5sum",
"mark_exposure_bad",
"handle_signals",
"get_lvmapi_route",
"run_lvmapi_task",
"GuiderMonitor",
"overwatcher_is_running",
"get_by_source_id",
"is_actor_running",
"check_overwatcher_not_running",
"kubernetes_restart_deployment",
"kubernetes_list_deployments",
"get_gort_client",
"add_night_log_comment",
"async_noop",
"LogNamespace",
"decap",
"ensure_period",
"try_or_pass",
"record_overheads",
"OverheadDict",
"ping_host",
"get_exposure_list",
]
AnyPath = str | os.PathLike
CAMERAS = [
"sci.west",
"sci.east",
"skye.west",
"skye.east",
"skyw.west",
"skyw.east",
"spec.east",
]
[docs]
def get_valid_variable_name(var_name: str):
"""Converts a string to a valid variable name."""
return re.sub(r"\W|^(?=\d)", "_", var_name)
[docs]
def get_next_tile_id_sync() -> dict:
"""Retrieves the next ``tile_id`` from the scheduler API. Synchronous version."""
sch_config = config["services"]["scheduler"]
host = sch_config["host"]
port = sch_config["port"]
with httpx.Client(base_url=f"http://{host}:{port}/") as client:
resp = client.get("next_tile")
if resp.status_code != 200:
raise httpx.RequestError("Failed request to /next_tile")
tile_id_data = resp.json()
return tile_id_data
[docs]
async def get_next_tile_id() -> dict:
"""Retrieves the next ``tile_id`` from the scheduler API."""
sch_config = config["services"]["scheduler"]
host = sch_config["host"]
port = sch_config["port"]
async with httpx.AsyncClient(base_url=f"http://{host}:{port}/") as client:
resp = await client.get("next_tile")
if resp.status_code != 200:
raise httpx.RequestError("Failed request to /next_tile")
tile_id_data = resp.json()
return tile_id_data
[docs]
def get_calibrators_sync(
tile_id: int | None = None,
ra: float | None = None,
dec: float | None = None,
) -> dict:
"""Get calibrators for a ``tile_id`` or science pointing. Synchronous version."""
sch_config = config["services"]["scheduler"]
host = sch_config["host"]
port = sch_config["port"]
with httpx.Client(base_url=f"http://{host}:{port}/") as client:
if tile_id:
resp = client.get("cals", params={"tile_id": tile_id})
elif ra is not None and dec is not None:
resp = client.get("cals", params={"ra": ra, "dec": dec})
else:
raise ValueError("ra and dec are required.")
if resp.status_code != 200:
raise httpx.RequestError("Failed request to /cals")
return resp.json()
[docs]
async def get_calibrators(
tile_id: int | None = None,
ra: float | None = None,
dec: float | None = None,
):
"""Get calibrators for a ``tile_id`` or science pointing."""
sch_config = config["services"]["scheduler"]
host = sch_config["host"]
port = sch_config["port"]
async with httpx.AsyncClient(base_url=f"http://{host}:{port}/") as client:
if tile_id:
resp = await client.get("cals", params={"tile_id": tile_id})
elif ra is not None and dec is not None:
resp = await client.get("cals", params={"ra": ra, "dec": dec})
else:
raise ValueError("ra and dec are required.")
if resp.status_code != 200:
raise httpx.RequestError("Failed request to /cals")
return resp.json()
[docs]
async def register_observation(payload: dict):
"""Registers an observation with the scheduler."""
sch_config = config["services"]["scheduler"]
host = sch_config["host"]
port = sch_config["port"]
async with httpx.AsyncClient() as client:
resp = await client.put(
f"http://{host}:{port}/register_observation",
json=payload,
follow_redirects=True,
)
if resp.status_code != 200 or not resp.json()["success"]:
raise RuntimeError(f"Failed registering observation: {resp.text}.")
[docs]
def mark_exposure_bad(tile_id: int, dither_position: int = 0):
"""Marks a registered tile/dither as bad."""
db = get_db_connection()
completion_status = peewee.Table("completion_status", schema="lvmopsdb").bind(db)
dither = peewee.Table("dither", schema="lvmopsdb").bind(db)
dither_pk = (
dither.select(dither.c.pk)
.where(
dither.c.tile_id == tile_id,
dither.c.position == dither_position,
)
.namedtuples()
)
if len(dither_pk) == 0:
raise ValueError("No matching tile-position.")
completion_status.update(done=False).where(
completion_status.c.pk == dither_pk[0].pk
).execute()
async def set_tile_status(tile_id: int, note: str | None = None):
"""Enables/disables a tile in the database."""
sch_config = config["services"]["scheduler"]
host = sch_config["host"]
port = sch_config["port"]
async with httpx.AsyncClient() as client:
resp = await client.put(
f"http://{host}:{port}/tile_status",
params={"tile_id": tile_id, "note": note},
json={},
follow_redirects=True,
)
if resp.status_code != 200 or not resp.json()["success"]:
raise RuntimeError(f"Failed setting tile status: {resp.text}.")
[docs]
def is_notebook() -> bool:
"""Returns :obj:`True` if the code is run inside a Jupyter Notebook.
https://stackoverflow.com/questions/15411967/how-can-i-check-if-code-is-executed-in-the-ipython-notebook
"""
try:
shell = get_ipython().__class__.__name__ # type: ignore
if shell == "ZMQInteractiveShell":
return True # Jupyter notebook or qtconsole
elif shell == "TerminalInteractiveShell":
return False # Terminal running IPython
else:
return False # Other type (?)
except NameError:
return False # Probably standard Python interpreter
[docs]
def is_interactive():
"""Returns :obj:`True` is we are in an interactive session."""
import __main__ as main
return not hasattr(main, "__file__")
[docs]
def get_ccd_frame_path(
frame_id: int,
sjd: int | None = None,
cameras: str | list[str] | None = None,
spectro_path="/data/spectro",
) -> list[str]:
"""Returns the paths for the files for a spectrograph frame.
Parameters
----------
frame_id
The spectrograph frame for which the paths are searched.
mjd
The SJD in which the frames where taken. If not provided, all the
directories under ``spectro_path`` are searched.
cameras
The cameras to be returned. If :obj:`None`, all cameras found are returned.
spectro_path
The path to the ``spectro`` directory where spectrograph files are
stored under an SJD structure.
Returns
-------
paths
The list of paths to CCD frames that match ``frame_id``.
"""
if isinstance(cameras, str):
cameras = [cameras]
base_path = pathlib.Path(spectro_path)
recursive = True
if sjd:
base_path /= str(sjd)
recursive = False
# Glob all files that match the frame_id.
globp = f"*{frame_id}.fits.*"
if recursive:
globp = f"**/{globp}"
files = [str(path) for path in base_path.glob(globp)]
if cameras is None:
return files
files_camera = []
for camera in cameras:
for file in files:
if f"-{camera}-" in file:
files_camera.append(file)
return files_camera
[docs]
async def move_mask_interval(
gort: Gort,
positions: str | list[str] = "P1-*",
order_by_steps: bool = False,
total_time: float | None = None,
time_per_position: float | None = None,
notifier: Callable[[str], None] | Callable[[str], Coroutine] | None = None,
):
"""Moves the fibre mask in the spectrophotometric telescope at intervals.
Parameters
----------
gort
The instance of :obj:`.Gort` to communicate with the actor system.
positions
The positions to iterate over. It can be a string in which case it will
be treated as a regular expression and any mask position that matches the
value will be iterated, in alphabetic order. Alternative it can be a list
of positions to move to which will be executed in that order.
order_by_steps
If :obj:`True`, the positions are iterated in order of smaller to larger
number of step motors.
total_time
The total time to spend iterating over positions, in seconds. Each position
will be visited for an equal amount of time. The time required to move the
mask will not be taken into account, which means the total execution
time will be longer than ``total_time``.
time_per_position
The time to spend on each mask position, in seconds. The total execution
time will be ``len(positions)*total_time+overhead`` where ``overhead`` is
the time required to move the mask between positions.
notifier
A function or coroutine to call every time a new position is reached.
If it's a coroutine, it is scheduled as a task. If it is a normal
callback it should run quickly to not perceptibly affect the total
execution time.
"""
try:
fibsel: FibSel = gort.telescopes.spec.fibsel
except Exception as err:
raise RuntimeError(f"Cannot find fibre selector: {err}")
if total_time is not None and time_per_position is not None:
raise ValueError("Only one of total_time or time_per_position can be used.")
if total_time is None and time_per_position is None:
raise ValueError("One of total_time or time_per_position needs to be passed.")
mask_config = gort.config["telescopes"]["mask_positions"]
all_positions = list(mask_config)
if isinstance(positions, str):
regex = positions
all_positions = fibsel.list_positions()
positions = [pos for pos in all_positions if re.match(regex, pos)]
if order_by_steps:
positions = sorted(positions, key=lambda p: mask_config[p])
fibsel.write_to_log(f"Iterating over positions {positions}.")
if total_time:
time_per_position = total_time / len(positions)
assert time_per_position is not None
for position in positions:
await fibsel.move_to_position(position)
# Notify.
if notifier is not None:
if asyncio.iscoroutinefunction(notifier):
asyncio.create_task(notifier(position))
else:
notifier(position)
await asyncio.sleep(time_per_position)
[docs]
def angular_separation(lon1: float, lat1: float, lon2: float, lat2: float):
"""A wrapper around astropy's ``angular_separation``.
Returns the separation between two sets of coordinates. All units must
be degrees and the returned values is also the separation in degrees.
"""
separation = astropy_angular_separation(
lon1 * uu.degree,
lat1 * uu.degree,
lon2 * uu.degree,
lat2 * uu.degree,
)
return separation.to("deg").value
[docs]
def get_db_connection():
"""Returns a DB connection from the configuration file parameters."""
conn = peewee.PostgresqlDatabase(**config["services"]["database"]["connection"])
assert conn.connect(), "Database connection failed."
return conn
[docs]
@asynccontextmanager
async def redis_client_async() -> AsyncGenerator[aioredis.Redis, None]:
"""Returns a Redis connection from the configuration file parameters."""
redis_config = config["services"]["redis"]
client = aioredis.from_url(redis_config["url"], decode_responses=True)
yield client
await client.aclose()
[docs]
@contextmanager
def redis_client_sync() -> Generator[redis.Redis, None]:
"""Returns a Redis connection from the configuration file parameters."""
redis_config = config["services"]["redis"]
client = redis.from_url(redis_config["url"], decode_responses=True)
yield client
client.close()
[docs]
async def run_in_executor(fn, *args, catch_warnings=False, executor="thread", **kwargs):
"""Runs a function in an executor.
In addition to streamlining the use of the executor, this function
catches any warning issued during the execution and reissues them
after the executor is done. This is important when using the
actor log handler since inside the executor there is no loop that
CLU can use to output the warnings.
In general, note that the function must not try to do anything with
the actor since they run on different loops.
"""
fn = partial(fn, *args, **kwargs)
if executor == "thread":
executor = concurrent.futures.ThreadPoolExecutor
elif executor == "process":
executor = concurrent.futures.ProcessPoolExecutor
else:
raise ValueError("Invalid executor name.")
if catch_warnings:
with warnings.catch_warnings(record=True) as records:
with executor() as pool:
result = await asyncio.get_event_loop().run_in_executor(pool, fn)
for ww in records:
warnings.warn(ww.message, ww.category)
else:
with executor() as pool:
result = await asyncio.get_running_loop().run_in_executor(pool, fn)
return result
[docs]
async def cancel_task(task: asyncio.Future | None):
"""Safely cancels a task."""
if task is None or task.done():
return
task.cancel()
with suppress(asyncio.CancelledError):
await task
[docs]
def get_temporary_file_path(*args, create_parents: bool = False, **kwargs):
"""Returns a valid path to a temporary file.
`args` and `kwargs` are directly passed to `tempfile.NamedTemporaryFile`.
If `create_parents`, the parent directories are created if they don't
exist.
"""
tmp_log_file = tempfile.NamedTemporaryFile(*args, **kwargs)
tmp_log_file.close()
tmp_path = pathlib.Path(tmp_log_file.name)
if tmp_path.exists():
tmp_path.unlink()
elif create_parents:
tmp_path.parent.mkdir(parents=True, exist_ok=True)
return tmp_path
[docs]
def insert_to_database(
table_name: str,
payload: Sequence[dict[str, Any]],
columns: Sequence[str] | None = None,
):
"""Inserts data into the database.
Parameters
----------
table_name
The table in the database where to insert the data. Can be in the format
``schema.table_name``.
payload
The data to ingest, as a list of dictionaries in which each dictionary
is a mapping of column name in ``table`` to the value to ingest.
columns
A list of table columns. If not passed, the column names are inferred from
the first element in the payload. In this case you must ensure that all the
elements in the payload contain entries for all the columns (use :obj:`None`
to fill missing data).
"""
if len(payload) == 0:
return
columns = columns or list(payload[0].keys())
conn = get_db_connection()
schema: str | None
if "." in table_name:
schema, table_name = table_name.split(".")
else:
schema = None
table_name = table_name
table = peewee.Table(table_name, schema=schema, columns=columns)
table.bind(conn)
table.insert(payload).execute()
[docs]
def get_md5sum_file(file: AnyPath):
"""Returns the path to the MD5 file for the spectro files."""
file = pathlib.Path(file).absolute()
mjd = file.parts[-2]
md5sum = file.parent / f"{mjd}.md5sum"
return md5sum if md5sum.exists() else None
[docs]
def get_md5sum_from_spectro(file: AnyPath):
"""Returns the MD5 checksum for a file from the spectro checksum file."""
file = pathlib.Path(file).absolute()
basename = file.name
md5sum_file = get_md5sum_file(file)
if not md5sum_file:
return None
data = open(md5sum_file).read()
match = re.search(rf"([0-9a-f]+)\s+{basename}", data)
if not match:
return None
return match.groups(1)[0]
[docs]
def get_md5sum(file: AnyPath):
"""Returns the MD5 checksum for a file."""
data = open(file, "rb").read()
return hashlib.md5(data).hexdigest()
[docs]
def handle_signals(
signals: Sequence[int],
callback: Callable[[], Any] | None = None,
cancel: bool = True,
):
"""Runs a callback when a signal is received during the execution of a task.
This function is meant to decorate coroutines. If a signal matching ``signals``
is received, the callback is run and the coroutine (which is executed as a task)
is cancelled if ``cancel=True``.
"""
def _handle_signal(task: asyncio.Task):
if cancel:
task.cancel()
if callback:
asyncio.get_running_loop().call_soon(callback)
def _outter_wrapper(coro_func):
@functools.wraps(coro_func)
async def _inner_wrapper(*args, **kwargs):
task = asyncio.create_task(coro_func(*args, **kwargs))
handler = partial(_handle_signal, task)
for sgn in signals:
asyncio.get_running_loop().add_signal_handler(sgn, handler)
try:
try:
return await task
except asyncio.CancelledError:
raise KeyboardInterrupt()
finally:
for sgn in signals:
asyncio.get_running_loop().remove_signal_handler(sgn)
return _inner_wrapper
return _outter_wrapper
[docs]
def get_by_source_id(source_id: int) -> dict | None:
"""Returns Gaia DR3 information for a source ID."""
db = get_db_connection()
gaia_dr3 = peewee.Table("gaia_dr3_source", schema="catalogdb").bind(db)
query = gaia_dr3.select(gaia_dr3.star).where(gaia_dr3.c.source_id == source_id)
data = query.dicts()
if len(data) == 0:
return None
return data[0]
[docs]
async def get_ephemeris_summary(sjd: int | None = None) -> dict:
"""Returns the ephemeris summary from ``lvmapi``."""
host = config["services"]["lvmapi"]["host"]
port = config["services"]["lvmapi"]["port"]
url = f"http://{host}:{port}/ephemeris/summary"
sjd = sjd or get_sjd("LCO")
async with httpx.AsyncClient() as client:
resp = await client.get(url, params={"sjd": sjd})
if resp.status_code != 200:
raise httpx.RequestError("Failed request to /ephemeris/summary")
return resp.json()
[docs]
def get_ephemeris_summary_sync(sjd: int | None = None) -> dict:
"""Returns the ephemeris summary from ``lvmapi``."""
host = config["services"]["lvmapi"]["host"]
port = config["services"]["lvmapi"]["port"]
url = f"http://{host}:{port}/ephemeris/summary"
sjd = sjd or get_sjd("LCO")
with httpx.Client() as client:
resp = client.get(url, params={"sjd": sjd})
if resp.status_code != 200:
raise httpx.RequestError("Failed request to /ephemeris/summary")
return resp.json()
[docs]
async def get_lvmapi_route(route: str, params: dict = {}, timeout: float = 5, **kwargs):
"""Gets an ``lvmapi`` route."""
params.update(kwargs)
host, port = config["services"]["lvmapi"].values()
if "/" in host:
host, base_route = host.split("/")
if base_route != "":
base_route += "/"
else:
base_route = ""
if route.startswith("/"):
route = route[1:]
if "?" in route:
raise ValueError("Query parameters should be passed as a dictionary.")
async with httpx.AsyncClient(
base_url=f"http://{host}:{port}/{base_route}",
follow_redirects=True,
timeout=timeout,
) as client:
response = await client.get(route, params=params, timeout=timeout)
if (code := response.status_code) != 200:
raise ValueError(f"Route /{route} failed with error {code}.")
return response.json()
[docs]
async def run_lvmapi_task(
route: str,
params: dict = {},
wait: bool = True,
check_interval: float = 3,
timeout: float | None = None,
**kwargs,
):
"""Runs an LVM API task route.
Task routes immediately start the task in the background and return a task ID.
The API can then be queried until the task is completed.
Parameters
----------
route
The route to the task.
params
The parameters to pass to the route.
wait
If :obj:`True`, waits until the task is completed and returns the result.
check_interval
The interval at which to check the task status. Ignored if ``wait=False``.
timeout
The timeout for the task execution. Ignored if ``wait=False``.
Returns
-------
result
The task ID if ``wait=False`` or the result of the task if ``wait=True``.
"""
task_id = await get_lvmapi_route(route, params=params, **kwargs)
if not wait:
return task_id
elapsed: float = 0
while True:
await asyncio.sleep(check_interval)
elapsed += check_interval
ready = await get_lvmapi_route(f"/tasks/{task_id}/ready", timeout=10)
if ready is True:
return await get_lvmapi_route(f"/tasks/{task_id}/result", timeout=10)
if timeout and elapsed > timeout:
raise asyncio.TimeoutError(f"Task {task_id} did not complete in time.")
[docs]
class GuiderMonitor:
"""A tool to monitor guider outputs and store them in a dataframe."""
def __init__(self, gort: Gort, actor: str | None = None):
self.gort = gort
self.actor = actor
self._data: dict[tuple[int, str], dict[str, Any]] = {}
self._schema = {
"frameno": polars.Int32(),
"telescope": polars.String(),
"time": polars.Datetime(time_zone="UTC"),
"n_sources": polars.Int32(),
"focus_position": polars.Float32(),
"fwhm": polars.Float32(),
"ra": polars.Float64(),
"dec": polars.Float64(),
"ra_offset": polars.Float32(),
"dec_offset": polars.Float32(),
"separation": polars.Float32(),
"pa": polars.Float32(),
"pa_offset": polars.Float32(),
"zero_point": polars.Float32(),
"mode": polars.String(),
"ax0_applied": polars.Float32(),
"ax1_applied": polars.Float32(),
"rot_applied": polars.Float32(),
}
[docs]
def reset(self):
"""Resets the internal state."""
self.__init__(self.gort, self.actor)
[docs]
def start_monitoring(self):
"""Starts monitoring the guider outputs."""
if self._handle_guider_reply not in self.gort._callbacks:
self.gort.add_reply_callback(self._handle_guider_reply)
self.reset()
[docs]
def stop_monitoring(self):
"""Stops monitoring the guider outputs."""
if self._handle_guider_reply in self.gort._callbacks:
self.gort.remove_reply_callback(self._handle_guider_reply)
def __del__(self):
self.stop_monitoring()
[docs]
def get_dataframe(self) -> polars.DataFrame | None:
"""Returns the collected data as a dataframe."""
if len(self._data) == 0:
return
values = self._data.values()
dfs = [polars.DataFrame([dv], schema=self._schema) for dv in values]
return polars.concat(dfs).sort(["frameno", "telescope"])
async def _handle_guider_reply(self, reply: AMQPReply):
"""Processes an actor reply and stores the collected data."""
if self.actor is not None:
if self.actor not in str(reply.sender):
return
else:
if ".guider" not in str(reply.sender):
return
body = reply.body
telescope = str(reply.sender).split(".")[1]
frameno: int | None = None
new_data: dict[str, Any] = {}
try:
if "frame" in body:
frame = body["frame"]
frameno = frame["seqno"]
new_data = {
"frameno": frameno,
"telescope": telescope,
"time": datetime.datetime.now(datetime.timezone.utc),
"n_sources": frame["n_sources"],
"focus_position": frame["focus_position"],
"fwhm": frame["fwhm"],
}
elif "measured_pointing" in body:
measured_pointing = body["measured_pointing"]
frameno = measured_pointing["frameno"]
new_data = {
"frameno": frameno,
"telescope": telescope,
"ra": measured_pointing["ra"],
"dec": measured_pointing["dec"],
"ra_offset": measured_pointing["radec_offset"][0],
"dec_offset": measured_pointing["radec_offset"][1],
"separation": measured_pointing["separation"],
"pa": measured_pointing.get("pa", numpy.nan),
"pa_offset": measured_pointing.get("pa_offset", numpy.nan),
"zero_point": measured_pointing.get("zero_point", numpy.nan),
"mode": measured_pointing["mode"],
}
elif "correction_applied" in body:
correction_applied = body["correction_applied"]
frameno = correction_applied["frameno"]
new_data = {
"frameno": frameno,
"telescope": telescope,
"ax0_applied": correction_applied["motax_applied"][0],
"ax1_applied": correction_applied["motax_applied"][1],
"rot_applied": correction_applied.get("rot_applied", 0.0),
}
else:
return
if not isinstance(frameno, int):
return
key = (frameno, telescope)
if key in self._data:
self._data[key].update(new_data)
else:
self._data[key] = new_data
except Exception as err:
self.gort.log.warning(f"Error processing guider reply: {err}")
[docs]
async def is_actor_running(actor: str, client: AMQPClient | None = None):
"""Returns :obj:`True` if the actor is running.
This assumes that the actor implements a `ping` command.
"""
close_client: bool = False
if not client:
client = await AMQPClient().start()
close_client = True
# TODO: this can fail if the event loop has closed the client connection.
ping = await client.send_command(actor, "ping")
if close_client:
await client.stop()
if ping.status.did_succeed:
return True
return False
[docs]
async def overwatcher_is_running(client: AMQPClient | None = None):
"""Returns :obj:`True` if the overwatcher is running."""
overwatcher_actor_name = config["overwatcher"]["actor"]["name"]
return await is_actor_running(overwatcher_actor_name, client=client)
async def overwatcher_is_enabled(client: AMQPClient | None = None):
"""Returns :obj:`True` if the overwatcher is enabled."""
close_client: bool = False
if not client:
client = await AMQPClient().start()
close_client = True
if not (await overwatcher_is_running(client)):
return False
overwatcher_actor_name = config["overwatcher"]["actor"]["name"]
status = await client.send_command(overwatcher_actor_name, "status")
if close_client:
await client.stop()
if status.status.did_succeed:
return status.replies.get("status")["enabled"]
raise GortError("Cannot determine if the Overwatcher is running.")
[docs]
def check_overwatcher_not_running(coro):
"""Decorator that fails a coroutine if the overwatcher is running."""
@wraps(coro)
async def wrapper(*args, **kwargs):
if await overwatcher_is_running():
raise GortError(
f"Overwatcher is running. Cannot execute {coro.__name__}.",
ErrorCode.OVERATCHER_RUNNING,
)
return await coro(*args, **kwargs)
return wrapper
[docs]
async def kubernetes_list_deployments():
"""Retrieves the Kubernetes deployments from the API."""
return await get_lvmapi_route("/kubernetes/deployments/list")
[docs]
async def kubernetes_restart_deployment(name: str):
"""Restarts a Kubernetes deployment via the API."""
return await get_lvmapi_route(f"/kubernetes/deployments/{name}/restart")
[docs]
@asynccontextmanager
async def get_gort_client(override_overwatcher: bool | None = None):
"""Returns a GORT client."""
from gort import Gort
override_envvar = os.environ.get("GORT_OVERRIDE_OVERWATCHER", "0")
if override_overwatcher is None:
override_overwatcher = False if override_envvar == "0" else True
gort = await Gort(
override_overwatcher=override_overwatcher,
verbosity="debug",
).init()
yield gort
await gort.stop()
NightLogCategories = Literal["weather", "issues", "other", "observers", "overwatcher"]
[docs]
async def async_noop(*args, **kwargs):
"""A no-op coroutine."""
return None
[docs]
class LogNamespace:
"""A namespace for log messages."""
def __init__(self, logger: logging.Logger, header: str = "") -> None:
self.logger = logger
self.header = header
def debug(self, message: str, *args, **kwargs):
self.logger.log(logging.DEBUG, self._get_message(message), *args, **kwargs)
def info(self, message: str, *args, **kwargs):
self.logger.log(logging.INFO, self._get_message(message), *args, **kwargs)
def warning(self, message: str, *args, **kwargs):
self.logger.log(logging.WARNING, self._get_message(message), *args, **kwargs)
def error(self, message: str, *args, **kwargs):
self.logger.log(logging.ERROR, self._get_message(message), *args, **kwargs)
def critical(self, message: str, *args, **kwargs):
self.logger.log(logging.CRITICAL, self._get_message(message), *args, **kwargs)
def exception(self, *args, **kwargs):
self.logger.exception(*args, **kwargs)
def _get_message(self, message: str):
return self.header.format(**locals()) + message
[docs]
def decap(string: Any, add_period: bool = False) -> str:
"""Decapitalises the first letter of a string."""
if not isinstance(string, str):
string = str(string)
if len(string) == 0:
return ""
first_word = string.split(" ")[0]
if len(first_word) > 1 and first_word.isupper():
# Do not decap acronyms.
decapped = string
else:
decapped = str(string[0].lower() + string[1:])
if not add_period:
return decapped
return ensure_period(decapped)
[docs]
def ensure_period(string: str):
"""Ensures that a string ends with a period."""
if not string.strip().endswith("."):
string += "."
return string
T = TypeVar("T", bound=Any)
[docs]
async def try_or_pass(coro: Awaitable[T]) -> T | None:
"""Decorator that wraps a coroutine in a try/except block."""
try:
return await coro
except Exception:
pass
[docs]
class OverheadDict(TypedDict):
"""TypedDict for recording overhead information."""
observer_id: int | None
tile_id: int | None
dither_position: int | None
stage: str | None
start_time: float
end_time: float
duration: float
[docs]
def record_overheads(payload: Sequence[OverheadDict] | OverheadDict):
"""Decorator that records the overhead of a function."""
if not isinstance(payload, Sequence):
payload = [payload]
payload_dict = cast(Sequence[dict[str, Any]], payload)
table_name = config["services.database.tables.overheads"]
insert_to_database(table_name, payload_dict)
[docs]
async def ping_host(host: str):
"""Pings a host."""
cmd = await asyncio.create_subprocess_exec(
*["ping", "-c", "1", "-W", "5", host],
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
)
return_code = await cmd.wait()
return return_code == 0
@overload
async def get_exposure_list(
mjd: int | None = None,
as_dataframe: bool = True,
) -> polars.DataFrame: ...
@overload
async def get_exposure_list(
mjd: int | None = None,
as_dataframe: bool = False,
) -> dict: ...
[docs]
async def get_exposure_list(
mjd: int | None = None,
as_dataframe: bool = True,
) -> dict | polars.DataFrame:
"""Retrieves the list of exposures for a given MJD from the LVM API."""
mjd = mjd or get_sjd("LCO")
data: dict = await get_lvmapi_route(f"/logs/exposures/data/{mjd}", timeout=15)
if not as_dataframe:
return data
data_df: list[dict[str, Any]] = []
for exposure in data.values():
lamps = exposure.pop("lamps")
data_df.append({**exposure, **lamps})
schema = {
"exposure_no": polars.Int32,
"mjd": polars.Int32,
"obstime": polars.String,
"image_type": polars.String,
"exposure_time": polars.Float32,
"ra": polars.Float64,
"dec": polars.Float64,
"airmass": polars.Float32,
"n_standards": polars.Int16,
"n_cameras": polars.Int16,
"object": polars.String,
"dpos": polars.Int16,
"Argon": polars.Boolean,
"Neon": polars.Boolean,
"LDLS": polars.Boolean,
"Quartz": polars.Boolean,
"HgNe": polars.Boolean,
"Xenon": polars.Boolean,
}
if len(data_df) == 0:
return polars.DataFrame(schema=schema)
return polars.DataFrame(data_df, schema=schema)