import logging
from typing import Dict, List
import requests
from requests_toolbelt import sessions
from galahad.server import server
from galahad.server.dataclasses import ClassifierInfo, Document
logger = logging.getLogger("galahad.client")
class HTTPError(Exception):
    """Error raised for a 4xx/5xx response from the Galahad server.

    The message follows the format of ``requests.Response.raise_for_status``
    and, when the response has a body, appends that body on a new line so
    server-side error details are visible to the caller.

    :param response: The failed response. It is also kept on the exception as
        ``self.response`` so callers can inspect e.g. ``status_code``.
    """

    def __init__(self, response: "requests.Response"):
        self.response = response

        reason = response.reason
        if isinstance(reason, bytes):
            # We attempt to decode utf-8 first because some servers
            # choose to localize their reason strings. If the string
            # isn't utf-8, we fall back to iso-8859-1 for all other
            # encodings. (See PR #3538)
            try:
                reason = reason.decode("utf-8")
            except UnicodeDecodeError:
                reason = reason.decode("iso-8859-1")

        status_code = response.status_code
        error_msg = ""
        if 400 <= status_code < 500:
            error_msg = f"{status_code} Client Error: {reason} for url: {response.url}"
        elif 500 <= status_code < 600:
            error_msg = f"{status_code} Server Error: {reason} for url: {response.url}"

        body = response.content
        # `content` may be None (e.g. a response without a raw stream); a plain
        # truthiness test covers both None and b"".
        if body:
            # errors="replace": a body that is not valid UTF-8 must not turn
            # into a UnicodeDecodeError while this exception is being built.
            error_msg += "\n" + body.decode("utf-8", errors="replace")
        super().__init__(error_msg)
def create_error_message(**kwargs) -> str:
    """Describe the given keyword arguments for inclusion in an error message.

    Each argument is rendered as ``"name" : "value"`` and the pairs are joined
    with ``and``, in the order the arguments were passed.

    :returns: The description, or an empty string when no arguments are given.
    """
    if not kwargs:
        return ""
    described_pairs = [f'"{name}" : "{value}"' for name, value in kwargs.items()]
    return " ".join(["The problem appeared with the variables ", " and ".join(described_pairs)])
def check_response(response: requests.Response):
    """Raise this module's HTTPError if ``response`` has a 4xx/5xx status.

    The original ``requests`` exception is chained as the cause.
    """
    try:
        response.raise_for_status()
    except requests.exceptions.HTTPError as original_error:
        raise HTTPError(response) from original_error
# TODO: would it be better for performance to validate names client-side before sending the request to the server?
def check_naming_is_ok(given_status: int, **kwargs):
    """Raise ValueError when the server rejected a name (HTTP status 422).

    :param given_status: Status code of the server response.
    :param kwargs: The named values sent with the request; they are listed in
        the error message to help locate the invalid one.
    :raises ValueError: if ``given_status`` is 422.
    """
    if given_status != 422:
        return
    details = create_error_message(**kwargs)
    raise ValueError(
        "Naming for one of the variables is invalid. "
        f"Please look at the documentation for correct naming. {details}"
    )
class GalahadClient:
    """HTTP client for a Galahad server.

    All requests go through a ``requests_toolbelt`` ``BaseUrlSession`` rooted at
    ``endpoint_url``. Request paths are written relative to the endpoint,
    without a leading slash.
    """

    def __init__(self, endpoint_url: str):
        """Create a client for the server at ``endpoint_url`` and open a session.

        :param endpoint_url: Base URL of the server; trailing slashes are ignored.
        """
        self.endpoint_url = endpoint_url.rstrip("/")
        self._session = self._build_session()

    def start_session(self) -> requests.Session:
        """Replace the current HTTP session with a new one and return it."""
        self._session = self._build_session()
        return self._session

    def _build_session(self) -> requests.Session:
        # BaseUrlSession resolves each request path against the base URL with
        # urljoin. A base ending in "/" plus paths that do NOT start with "/"
        # keeps any path prefix of endpoint_url (e.g. "http://host/galahad");
        # a leading "/" on the request path would replace that prefix.
        return sessions.BaseUrlSession(self.endpoint_url + "/")

    def is_connected(self) -> bool:
        """Return True if the server answers ``ping`` with ``{"ping": "pong"}``.

        Connection failures, timeouts, non-200 statuses and non-JSON or
        unexpected bodies are logged and reported as False instead of raising.
        """
        try:
            response = self._session.get("ping")
        except (requests.exceptions.ConnectionError, requests.exceptions.Timeout):
            logger.info("ConnectionError")
            return False
        if response.status_code != 200:
            logger.info("StatusCodeError")
            return False
        try:
            payload = response.json()
        except ValueError:
            # requests' JSON decode errors are ValueError subclasses.
            logger.info("ResponseError")
            return False
        if payload != {"ping": "pong"}:
            logger.info("ResponseError")
            return False
        return True

    def list_datasets(self) -> List[str]:
        """Return the names of all datasets; the output is sorted by dataset name.

        :raises HTTPError: if the server answers with a 4xx/5xx status.
        """
        response = self._session.get("dataset")
        check_response(response)
        return response.json()["names"]

    def contains_dataset(self, dataset_id: str) -> bool:
        """Return whether a dataset named ``dataset_id`` exists on the server.

        The name is first passed to ``server.check_naming_is_ok_regex``
        (presumably raising on an invalid name — verify in galahad.server).
        """
        server.check_naming_is_ok_regex(dataset_id)
        return dataset_id in self.list_datasets()

    def create_dataset(self, dataset_id: str):
        """Create an empty dataset named ``dataset_id``.

        An already existing dataset (HTTP 409) is only logged.

        :raises ValueError: if the server rejects the name (HTTP 422).
        :raises HTTPError: for any other 4xx/5xx status.
        """
        response = self._session.put(f"dataset/{dataset_id}", {})
        check_naming_is_ok(response.status_code, dataset_id=dataset_id)
        if response.status_code == 409:
            logger.info(f'Dataset with id "{dataset_id}" already exists')
            return None
        check_response(response)

    def delete_dataset(self, dataset_id: str):
        """Delete the dataset ``dataset_id``; a missing dataset (HTTP 404) is only logged.

        :raises ValueError: if the server rejects the name (HTTP 422).
        :raises HTTPError: for any other 4xx/5xx status.
        """
        response = self._session.delete(f"dataset/{dataset_id}")
        check_naming_is_ok(response.status_code, dataset_id=dataset_id)
        if response.status_code == 404:
            logger.info(f'Dataset with id "{dataset_id}" does not exist')
            return None
        check_response(response)

    def delete_datasets(self, dataset_ids: List[str]):
        """Delete each dataset in ``dataset_ids`` (see ``delete_dataset``)."""
        for dataset_id in dataset_ids:
            self.delete_dataset(dataset_id)

    def delete_all_datasets(self):
        """Delete every dataset currently listed by the server."""
        self.delete_datasets(self.list_datasets())

    def create_document_in_dataset(
        self, dataset_id: str, document_id: str, document: Document, auto_create_dataset=False
    ):
        """Upload ``document`` as ``document_id`` into ``dataset_id``.

        A new document with the same id overrides an existing one.

        :param auto_create_dataset: If True and the server answers 404, create the
            dataset and retry the upload once.
        :raises ValueError: if a name is rejected (HTTP 422), or on 404 when
            ``auto_create_dataset`` is False.
        :raises HTTPError: for any other 4xx/5xx status.
        """
        path = f"dataset/{dataset_id}/{document_id}"
        payload = document.dict()
        response = self._session.put(path, json=payload)
        check_naming_is_ok(response.status_code, dataset_id=dataset_id, document_id=document_id)
        if response.status_code == 404:
            if not auto_create_dataset:
                raise ValueError(
                    f'The dataset for the given id: "{dataset_id}" does not exist. To create it, '
                    'set the optional parameter "auto_create_dataset" to True'
                )
            self.create_dataset(dataset_id)
            response = self._session.put(path, json=payload)
        check_response(response)

    def list_documents_in_dataset(self, dataset_id: str) -> Dict[str, int]:
        """Map each document id in ``dataset_id`` to its version; the result is sorted by document id.

        :raises ValueError: if the server rejects the name (HTTP 422).
        :raises HTTPError: for any other 4xx/5xx status.
        """
        response = self._session.get(f"dataset/{dataset_id}")
        check_naming_is_ok(response.status_code, dataset_id=dataset_id)
        check_response(response)
        # Parse the body once instead of once per key.
        body = response.json()
        return dict(zip(body["names"], body["versions"]))

    def dataset_contains_document(self, dataset_id: str, document_id: str) -> bool:
        """Return whether ``document_id`` is one of the documents in ``dataset_id``."""
        server.check_naming_is_ok_regex(document_id)
        return document_id in self.list_documents_in_dataset(dataset_id)

    def _delete_document(self, dataset_id: str, document_id: str):
        # Send the DELETE for one document and validate the response, without
        # first checking that the document exists.
        response = self._session.delete(f"dataset/{dataset_id}/{document_id}")
        check_naming_is_ok(response.status_code, dataset_id=dataset_id, document_id=document_id)
        check_response(response)

    def delete_document_in_dataset(self, dataset_id: str, document_id: str):
        """Delete ``document_id`` from ``dataset_id``; a missing document is only logged."""
        if self.dataset_contains_document(dataset_id, document_id):
            self._delete_document(dataset_id, document_id)
        else:
            logger.info(f'Document with id "{document_id}" does not exist in dataset with id "{dataset_id}"')

    def delete_all_documents_in_dataset(self, dataset_id: str):
        """Delete every document currently listed in ``dataset_id``."""
        # The ids come straight from the listing, so deleting them directly avoids
        # re-fetching the whole listing once per document.
        for document_id in self.list_documents_in_dataset(dataset_id):
            self._delete_document(dataset_id, document_id)

    def delete_all_documents(self):
        """Delete all documents in all datasets; the datasets themselves are kept."""
        for dataset_id in self.list_datasets():
            self.delete_all_documents_in_dataset(dataset_id)

    def list_all_classifiers(self) -> List[ClassifierInfo]:
        """Return a ``ClassifierInfo`` for each classifier reported by the server."""
        response = self._session.get("classifier")
        check_response(response)
        return [ClassifierInfo.parse_obj(classifier) for classifier in response.json()]

    def get_classifier_info(self, classifier_id: str) -> ClassifierInfo:
        """Return the ``ClassifierInfo`` of ``classifier_id``.

        :raises ValueError: if the server rejects the name (HTTP 422).
        :raises HTTPError: for any other 4xx/5xx status.
        """
        response = self._session.get(f"classifier/{classifier_id}")
        check_naming_is_ok(response.status_code, classifier_id=classifier_id)
        check_response(response)
        return ClassifierInfo.parse_obj(response.json())

    def train_on_dataset(self, classifier_id: str, model_id: str, dataset_id: str) -> bool:
        """Ask the server to train ``model_id`` of ``classifier_id`` on ``dataset_id``.

        :returns: True if training has started; False if training had already
            started (HTTP 429) and the call had no effect.
        :raises ValueError: if a name is rejected (HTTP 422).
        :raises HTTPError: for any other 4xx/5xx status.
        """
        response = self._session.post(f"classifier/{classifier_id}/{model_id}/train/{dataset_id}")
        check_naming_is_ok(response.status_code, classifier_id=classifier_id, model_id=model_id, dataset_id=dataset_id)
        if response.status_code == 429:
            logger.info(
                "Training has already started! %s",
                create_error_message(classifier_id=classifier_id, model_id=model_id, dataset_id=dataset_id),
            )
            return False
        check_response(response)
        return True

    def predict_on_document(self, classifier_id: str, model_id: str, document: Document) -> Dict:
        """Run ``model_id`` of ``classifier_id`` on ``document``.

        :returns: The server's JSON response as decoded by ``response.json()``
            (a plain dict, not a ``Document`` instance).
        :raises ValueError: if a name is rejected (HTTP 422).
        :raises HTTPError: for any other 4xx/5xx status.
        """
        response = self._session.post(f"classifier/{classifier_id}/{model_id}/predict", json=document.dict())
        check_naming_is_ok(response.status_code, classifier_id=classifier_id, model_id=model_id)
        check_response(response)
        return response.json()