Source code for ray.tune.web_server

import json
import logging
import threading
from typing import Tuple, List, TYPE_CHECKING

from urllib.parse import urljoin, urlparse
from http.server import SimpleHTTPRequestHandler, HTTPServer

import ray.cloudpickle as cloudpickle
from ray.tune import TuneError
from ray.tune.search import BasicVariantGenerator
from ray._private.utils import binary_to_hex, hex_to_binary
from ray.util.annotations import DeveloperAPI

if TYPE_CHECKING:
    from ray.tune.execution.trial_runner import TrialRunner

logger = logging.getLogger(__name__)

# `requests` is an optional dependency: it is only referenced by TuneClient's
# HTTP calls; the server-side code in this file does not use it. If the import
# fails, the name is bound to None so this module can still be imported, and
# the failure is logged (with traceback) at import time.
try:
    import requests  # `requests` is not part of stdlib.
except ImportError:
    requests = None
    logger.exception(
        "Couldn't import `requests` library. "
        "Be sure to install it on the client side."
    )


@DeveloperAPI
class TuneClient:
    """Client to interact with an ongoing Tune experiment.

    Requires a TuneServer to have started running. Every request method
    returns the server's JSON response parsed into a dict; any ``"trial"``
    or ``"trials"`` entries have their ``config`` and ``result`` fields
    decoded back into Python objects.

    Attributes:
        tune_address: Address of running TuneServer
        port_forward: Port number of running TuneServer
    """

    def __init__(self, tune_address: str, port_forward: int):
        self._tune_address = tune_address
        self._port_forward = port_forward
        self._path = "http://{}:{}".format(tune_address, port_forward)

    def get_all_trials(self, timeout=None):
        """Returns a list of all trials' information.

        Args:
            timeout: Passed through to ``requests``; None means no timeout.
        """
        response = self._requests().get(
            urljoin(self._path, "trials"), timeout=timeout
        )
        return self._deserialize(response)

    def get_trial(self, trial_id, timeout=None):
        """Returns trial information by trial_id.

        Args:
            trial_id: Identifier of the trial to look up.
            timeout: Passed through to ``requests``; None means no timeout.
        """
        response = self._requests().get(
            urljoin(self._path, "trials/{}".format(trial_id)), timeout=timeout
        )
        return self._deserialize(response)

    def add_trial(self, name, specification, timeout=None):
        """Adds a trial by name and specification (dict).

        Args:
            name: Name sent to the server as the ``"name"`` field.
            specification: Dict sent to the server as the ``"spec"`` field.
            timeout: Passed through to ``requests``; None means no timeout.
        """
        payload = {"name": name, "spec": specification}
        response = self._requests().post(
            urljoin(self._path, "trials"), json=payload, timeout=timeout
        )
        return self._deserialize(response)

    def stop_trial(self, trial_id, timeout=None):
        """Requests to stop trial by trial_id.

        Args:
            trial_id: Identifier of the trial to stop.
            timeout: Passed through to ``requests``; None means no timeout.
        """
        response = self._requests().put(
            urljoin(self._path, "trials/{}".format(trial_id)), timeout=timeout
        )
        return self._deserialize(response)

    def stop_experiment(self, timeout=None):
        """Requests to stop the entire experiment.

        Args:
            timeout: Passed through to ``requests``; None means no timeout.
        """
        response = self._requests().put(
            urljoin(self._path, "stop_experiment"), timeout=timeout
        )
        return self._deserialize(response)

    @property
    def server_address(self):
        return self._tune_address

    @property
    def server_port(self):
        return self._port_forward

    @staticmethod
    def _requests():
        """Return the ``requests`` module, or fail with a clear error.

        Raises:
            ImportError: If ``requests`` could not be imported when this
                module was loaded (the module-level name is then None).
        """
        if requests is None:
            raise ImportError(
                "TuneClient requires the `requests` library. "
                "Be sure to install it on the client side."
            )
        return requests

    def _load_trial_info(self, trial_info):
        """Decode the hex-encoded pickled fields of ``trial_info`` in place."""
        # SECURITY NOTE: this unpickles bytes supplied by the server, and
        # unpickling can execute arbitrary code. Only connect to a TuneServer
        # you trust.
        trial_info["config"] = cloudpickle.loads(hex_to_binary(trial_info["config"]))
        trial_info["result"] = cloudpickle.loads(hex_to_binary(trial_info["result"]))

    def _deserialize(self, response):
        """Parse the JSON body of ``response`` and decode any trial entries."""
        parsed = response.json()

        if "trial" in parsed:
            self._load_trial_info(parsed["trial"])
        elif "trials" in parsed:
            for trial_info in parsed["trials"]:
                self._load_trial_info(trial_info)

        return parsed
@DeveloperAPI
def RunnerHandler(runner):
    """Build the HTTP request handler class used by TuneServer.

    Args:
        runner: Object the handler delegates to. The handler calls its
            ``get_trials``, ``get_trial``, ``add_trial``,
            ``request_stop_trial`` and ``request_stop_experiment`` methods.

    Returns:
        A ``SimpleHTTPRequestHandler`` subclass closed over ``runner``.
    """

    class Handler(SimpleHTTPRequestHandler):
        """A Handler is a custom handler for TuneServer.

        Handles all requests and responses coming into and from
        the TuneServer.
        """

        def _do_header(self, response_code: int = 200, headers: List[Tuple] = None):
            """Sends the header portion of the HTTP response.

            Parameters:
                response_code: Standard HTTP response code
                headers: Standard HTTP response headers; defaults to a
                    single JSON ``Content-type`` header when None.
            """
            if headers is None:
                headers = [("Content-type", "application/json")]

            self.send_response(response_code)
            for key, value in headers:
                self.send_header(key, value)
            self.end_headers()

        def _respond(self, response_code, message):
            """Send the default headers followed by ``message`` as the body."""
            self._do_header(response_code=response_code)
            self.wfile.write(message.encode())

        def do_HEAD(self):
            """HTTP HEAD handler method."""
            self._do_header()

        def do_GET(self):
            """HTTP GET handler method.

            Replies with ``{"trials": [...]}`` or ``{"trial": {...}}`` for
            the requested path, ``{}`` if the lookup returned nothing, or a
            404 carrying the error text when a TuneError is raised.
            """
            response_code = 200
            message = ""
            try:
                result = self._get_trial_by_url(self.path)
                resource = {}
                if result:
                    if isinstance(result, list):
                        resource["trials"] = [self._trial_info(t) for t in result]
                    else:
                        resource["trial"] = self._trial_info(result)
                message = json.dumps(resource)
            except TuneError as e:
                response_code = 404
                message = str(e)

            self._respond(response_code, message)

        def do_PUT(self):
            """HTTP PUT handler method.

            A path ending in ``stop_experiment`` requests a stop of the whole
            experiment; any other path requests a stop of the trial(s) the
            path resolves to. Replies with the affected trials' information.
            """
            response_code = 200
            message = ""
            try:
                resource = {}

                if self.path.endswith("stop_experiment"):
                    runner.request_stop_experiment()
                    trials = list(runner.get_trials())
                else:
                    found = self._get_trial_by_url(self.path)
                    # A failed lookup yields a falsy value; report an empty
                    # list instead of crashing while iterating over it.
                    if not found:
                        trials = []
                    elif isinstance(found, list):
                        trials = found
                    else:
                        trials = [found]
                    for t in trials:
                        runner.request_stop_trial(t)

                resource["trials"] = [self._trial_info(t) for t in trials]
                message = json.dumps(resource)
            except TuneError as e:
                response_code = 404
                message = str(e)

            self._respond(response_code, message)

        def do_POST(self):
            """HTTP POST handler method.

            Expects a JSON object body with ``"name"`` and ``"spec"`` keys and
            replies 201 with the created trials' information. A body that
            cannot be read or parsed that way gets a 400 reply instead.
            """
            try:
                # The default belongs to .get(); passing 0 as int()'s second
                # argument would make it the numeric *base* and leave a
                # missing header as int(None, 0) -> TypeError.
                content_len = int(self.headers.get("Content-Length", 0))
                if content_len < 0:
                    # read() with a negative size would block until EOF.
                    raise ValueError("negative Content-Length")
                raw_body = self.rfile.read(content_len)
                parsed_input = json.loads(raw_body.decode())
                name = parsed_input["name"]
                spec = parsed_input["spec"]
            except (ValueError, KeyError, TypeError) as e:
                # ValueError also covers JSON and Unicode decode errors;
                # KeyError/TypeError cover a body of the wrong shape.
                self._respond(400, "Invalid request body: {}".format(e))
                return

            resource = self._add_trials(name, spec)

            headers = [("Content-type", "application/json"), ("Location", "/trials/")]
            self._do_header(response_code=201, headers=headers)
            self.wfile.write(json.dumps(resource).encode())

        def _trial_info(self, trial):
            """Returns trial information as a JSON-serializable dict.

            ``config`` and ``result`` are pickled and hex-encoded so they can
            be embedded in the JSON response.
            """
            result = trial.last_result.copy() if trial.last_result else None
            return {
                "id": trial.trial_id,
                "trainable_name": trial.trainable_name,
                "config": binary_to_hex(cloudpickle.dumps(trial.config)),
                "status": trial.status,
                "result": binary_to_hex(cloudpickle.dumps(result)),
            }

        def _get_trial_by_url(self, url):
            """Parses url to get either all trials or trial by trial_id."""
            path = urlparse(url).path

            if path == "/trials":
                return list(runner.get_trials())

            # Anything else: treat the last path segment as the trial id.
            trial_id = path.split("/")[-1]
            return runner.get_trial(trial_id)

        def _add_trials(self, name, spec):
            """Add trial by invoking TrialRunner.

            Returns:
                ``{"trials": [...]}`` with one info dict per trial added.
            """
            resource = {"trials": []}
            trial_generator = BasicVariantGenerator()
            trial_generator.add_configurations({name: spec})
            while not trial_generator.is_finished():
                trial = trial_generator.next_trial()
                if not trial:
                    break
                runner.add_trial(trial)
                resource["trials"].append(self._trial_info(trial))
            return resource

    return Handler


@DeveloperAPI
class TuneServer(threading.Thread):
    """A TuneServer is a thread that initializes and runs a HTTPServer.

    The server handles requests from a TuneClient.

    Attributes:
        runner: Runner that modifies and accesses trials.
        port: Port number of TuneServer.
    """

    DEFAULT_PORT = 4321

    def __init__(self, runner: "TrialRunner", port: int = None):
        """Initialize HTTPServer and serve forever by invoking self.run()

        Args:
            runner: Runner handed to the request handler.
            port: Port to bind on localhost; a falsy value selects
                DEFAULT_PORT.
        """
        super().__init__()
        self._port = port if port else self.DEFAULT_PORT
        address = ("localhost", self._port)
        logger.info("Starting Tune Server...")
        self._server = HTTPServer(address, RunnerHandler(runner))
        self.daemon = True
        # The thread starts itself, so constructing a TuneServer begins serving.
        self.start()

    def run(self):
        """Thread body: serve requests until shutdown() is called."""
        self._server.serve_forever()

    def shutdown(self):
        """Shutdown the underlying server."""
        self._server.shutdown()
        # shutdown() only stops the serve_forever() loop; server_close()
        # releases the listening socket so the port can be reused.
        self._server.server_close()