import json
import logging
import threading
from urllib.parse import urljoin, urlparse
from http.server import SimpleHTTPRequestHandler, HTTPServer
import ray.cloudpickle as cloudpickle
from ray.tune import TuneError
from ray.tune.suggest import BasicVariantGenerator
from ray.utils import binary_to_hex, hex_to_binary
logger = logging.getLogger(__name__)

# Within this module `requests` is used only by TuneClient's HTTP calls, so an
# import failure is tolerated (the name is bound to None) and merely logged,
# letting the rest of the module, e.g. TuneServer, still be imported.
try:
    import requests  # `requests` is not part of stdlib.
except ImportError:
    requests = None
    logger.exception(
        "Couldn't import `requests` library. "
        "Be sure to install it on the client side.")
class TuneClient:
    """Client to interact with an ongoing Tune experiment.

    Requires a TuneServer to have started running.

    Args:
        tune_address (str): Address of running TuneServer
            (exposed as ``server_address``).
        port_forward (int): Port number of running TuneServer
            (exposed as ``server_port``).
    """

    def __init__(self, tune_address, port_forward):
        self._tune_address = tune_address
        self._port_forward = port_forward
        self._path = "http://{}:{}".format(tune_address, port_forward)

    def get_all_trials(self, timeout=None):
        """Returns a list of all trials' information."""
        return self._request("get", "trials", timeout=timeout)

    def get_trial(self, trial_id, timeout=None):
        """Returns trial information by trial_id."""
        return self._request(
            "get", "trials/{}".format(trial_id), timeout=timeout)

    def add_trial(self, name, specification, timeout=None):
        """Adds a trial by name and specification (dict).

        Args:
            name (str): Experiment name the specification is registered under.
            specification (dict): Trial specification sent as ``spec``.
            timeout (float|None): Request timeout in seconds; None waits
                indefinitely.
        """
        payload = {"name": name, "spec": specification}
        return self._request("post", "trials", json=payload, timeout=timeout)

    def stop_trial(self, trial_id, timeout=None):
        """Requests to stop trial by trial_id.

        ``timeout`` (seconds, None = wait indefinitely) bounds the request.
        """
        return self._request(
            "put", "trials/{}".format(trial_id), timeout=timeout)

    def stop_experiment(self, timeout=None):
        """Requests to stop the entire experiment.

        ``timeout`` (seconds, None = wait indefinitely) bounds the request.
        """
        return self._request("put", "stop_experiment", timeout=timeout)

    @property
    def server_address(self):
        return self._tune_address

    @property
    def server_port(self):
        return self._port_forward

    def _url(self, endpoint):
        """Returns the absolute URL of ``endpoint`` on the TuneServer."""
        return urljoin(self._path, endpoint)

    def _request(self, method, endpoint, **kwargs):
        """Sends an HTTP request to ``endpoint`` and deserializes the reply.

        Raises:
            ImportError: If the optional `requests` library is unavailable.
        """
        # The module binds `requests` to None when the import fails; fail
        # with an actionable message instead of an AttributeError on None.
        if requests is None:
            raise ImportError(
                "TuneClient requires the `requests` library; "
                "please install it.")
        response = requests.request(method, self._url(endpoint), **kwargs)
        return self._deserialize(response)

    def _load_trial_info(self, trial_info):
        """Decodes the hex-encoded, cloudpickled fields of one trial in place."""
        # NOTE: unpickling can execute arbitrary code, so only point a
        # TuneClient at a TuneServer you trust.
        trial_info["config"] = cloudpickle.loads(
            hex_to_binary(trial_info["config"]))
        trial_info["result"] = cloudpickle.loads(
            hex_to_binary(trial_info["result"]))

    def _deserialize(self, response):
        """Parses a JSON response and decodes any embedded trial payloads."""
        parsed = response.json()
        if "trial" in parsed:
            self._load_trial_info(parsed["trial"])
        elif "trials" in parsed:
            for trial_info in parsed["trials"]:
                self._load_trial_info(trial_info)
        return parsed
def RunnerHandler(runner):
    """Builds a request-handler class bound to ``runner``.

    ``HTTPServer`` constructs handler instances itself, so the runner is
    captured by closure rather than passed to the handler's constructor.

    Args:
        runner (TrialRunner): Runner whose trials are queried and modified.

    Returns:
        type: A ``SimpleHTTPRequestHandler`` subclass for ``HTTPServer``.
    """

    class Handler(SimpleHTTPRequestHandler):
        """A Handler is a custom handler for TuneServer.

        Handles all requests and responses coming into and from
        the TuneServer.
        """

        def _do_header(self, response_code=200, headers=None):
            """Sends the header portion of the HTTP response.

            Parameters:
                response_code (int): Standard HTTP response code
                headers (list[tuples]): Standard HTTP response headers;
                    defaults to a JSON content type.
            """
            if headers is None:
                headers = [("Content-type", "application/json")]
            self.send_response(response_code)
            for key, value in headers:
                self.send_header(key, value)
            self.end_headers()

        def _send_error_text(self, response_code, message):
            """Sends ``message`` as a plain-text error response."""
            # Error bodies are bare strings, not JSON, so label them as such.
            self._do_header(
                response_code=response_code,
                headers=[("Content-type", "text/plain")])
            self.wfile.write(message.encode())

        def do_HEAD(self):
            """HTTP HEAD handler method."""
            self._do_header()

        def do_GET(self):
            """HTTP GET handler: returns one trial or all trials as JSON.

            Responds 404 with a plain-text message when a ``TuneError`` is
            raised while resolving or serializing the trials.
            """
            try:
                result = self._get_trial_by_url(self.path)
                resource = {}
                if result:
                    if isinstance(result, list):
                        resource["trials"] = [
                            self._trial_info(t) for t in result
                        ]
                    else:
                        resource["trial"] = self._trial_info(result)
                message = json.dumps(resource)
            except TuneError as e:
                self._send_error_text(404, str(e))
                return
            self._do_header()
            self.wfile.write(message.encode())

        def do_PUT(self):
            """HTTP PUT handler: requests that the addressed trials stop.

            A path ending in ``stop_experiment`` additionally requests the
            whole experiment to stop and targets every trial; any other path
            is resolved like GET. Each targeted trial is asked to stop and
            reported back. Responds 404 (plain text) on ``TuneError``.
            """
            try:
                resource = {}
                if self.path.endswith("stop_experiment"):
                    runner.request_stop_experiment()
                    trials = list(runner.get_trials())
                else:
                    trials = self._get_trial_by_url(self.path)
                if trials:
                    if not isinstance(trials, list):
                        trials = [trials]
                    for t in trials:
                        runner.request_stop_trial(t)
                    resource["trials"] = [self._trial_info(t) for t in trials]
                message = json.dumps(resource)
            except TuneError as e:
                self._send_error_text(404, str(e))
                return
            self._do_header()
            self.wfile.write(message.encode())

        def do_POST(self):
            """HTTP POST handler: adds trials described by a JSON body.

            Expects ``{"name": ..., "spec": ...}``. Responds 201 with the
            created trials, or 400 (plain text) when the body is malformed.
            """
            try:
                name, spec = self._parse_add_request(self._read_body())
            except ValueError as e:
                self._send_error_text(400, str(e))
                return
            resource = self._add_trials(name, spec)
            # NOTE(review): "/trials/" resolves via _get_trial_by_url to trial
            # id "", not to the trial listing; confirm the intended Location.
            headers = [("Content-type", "application/json"),
                       ("Location", "/trials/")]
            self._do_header(response_code=201, headers=headers)
            self.wfile.write(json.dumps(resource).encode())

        def _read_body(self):
            """Reads the request body; a missing Content-Length means empty.

            Raises:
                ValueError: If Content-Length is not a non-negative integer.
            """
            # The default belongs to ``headers.get``; ``int`` parses base 10.
            content_len = int(self.headers.get("Content-Length", 0))
            if content_len < 0:
                # A negative size would make rfile.read() read until EOF.
                raise ValueError("Content-Length must be non-negative.")
            return self.rfile.read(content_len)

        def _parse_add_request(self, raw_body):
            """Extracts ``(name, spec)`` from a POST body.

            Raises:
                ValueError: If the body is not UTF-8 JSON or is not an object
                    with ``name`` and ``spec`` keys.
            """
            try:
                parsed = json.loads(raw_body.decode())
                return parsed["name"], parsed["spec"]
            except (KeyError, TypeError) as e:
                raise ValueError(
                    "Request body must be a JSON object with "
                    "'name' and 'spec' keys.") from e

        def _trial_info(self, trial):
            """Returns trial information as a JSON-serializable dict.

            ``config`` and ``result`` are cloudpickled and hex-encoded.
            """
            result = trial.last_result.copy() if trial.last_result else None
            return {
                "id": trial.trial_id,
                "trainable_name": trial.trainable_name,
                "config": binary_to_hex(cloudpickle.dumps(trial.config)),
                "status": trial.status,
                "result": binary_to_hex(cloudpickle.dumps(result)),
            }

        def _get_trial_by_url(self, url):
            """Parses url to get either all trials or trial by trial_id.

            The path ``/trials`` yields a list of all trials; any other path
            is treated as ``.../<trial_id>`` and yields whatever
            ``runner.get_trial`` returns for that id.
            """
            path = urlparse(url).path
            if path == "/trials":
                return list(runner.get_trials())
            trial_id = path.split("/")[-1]
            return runner.get_trial(trial_id)

        def _add_trials(self, name, spec):
            """Add trial by invoking TrialRunner.

            Expands ``spec`` with a BasicVariantGenerator, adds every
            generated trial to the runner, and returns their info.
            """
            resource = {"trials": []}
            trial_generator = BasicVariantGenerator()
            trial_generator.add_configurations({name: spec})
            while not trial_generator.is_finished():
                trial = trial_generator.next_trial()
                if not trial:
                    break
                runner.add_trial(trial)
                resource["trials"].append(self._trial_info(trial))
            return resource

    return Handler
class TuneServer(threading.Thread):
    """A TuneServer is a thread that initializes and runs a HTTPServer.

    The server handles requests from a TuneClient. It binds to ``localhost``
    and starts serving as soon as it is constructed.

    Args:
        runner (TrialRunner): Runner that modifies and accesses trials.
        port (int): Port number of TuneServer; any falsy value (``None`` or
            ``0``) selects ``DEFAULT_PORT``.
    """

    DEFAULT_PORT = 4321

    def __init__(self, runner, port=None):
        """Initialize HTTPServer and serve forever by invoking self.run()"""
        super().__init__()
        self._port = port if port else self.DEFAULT_PORT
        address = ("localhost", self._port)
        logger.info("Starting Tune Server...")
        self._server = HTTPServer(address, RunnerHandler(runner))
        # Daemon thread: the serving loop must not keep the process alive.
        self.daemon = True
        self.start()

    def run(self):
        """Thread body: serve requests until shutdown() is called."""
        self._server.serve_forever()

    def shutdown(self):
        """Stop the serving loop and release the listening socket."""
        # HTTPServer.shutdown() only ends serve_forever(); server_close() is
        # what actually closes the socket so the port is freed.
        self._server.shutdown()
        self._server.server_close()