Source code for galahad.server.contrib.sentence_classification.sklearn_sentence_classifier

import logging
from typing import List, Optional

try:
    from sklearn.feature_extraction.text import (CountVectorizer,
                                                 TfidfTransformer)
    from sklearn.naive_bayes import MultinomialNB
    from sklearn.pipeline import Pipeline
except ImportError as error:
    print("Could not import 'sklearn', please install it manually via 'pip install scikit-learn'")

from galahad.formats import build_sentence_classification_document
from galahad.server.annotations import Annotations
from galahad.server.classifier import (AnnotationFeatures, AnnotationTypes,
                                       Classifier)
from galahad.server.dataclasses import Document

logger = logging.getLogger(__name__)


class SklearnSentenceClassifier(Classifier):
    """Sentence classifier backed by a scikit-learn text pipeline.

    The model is a ``Pipeline`` of ``CountVectorizer`` -> ``TfidfTransformer``
    -> ``MultinomialNB``. Trained models are stored and retrieved through the
    ``_save_model`` / ``_load_model`` methods inherited from ``Classifier``.
    """

    def __init__(self):
        super().__init__()

        # Annotation type names and the feature holding the label, resolved
        # once so that train/predict/consumes/produces all agree.
        self._sentence_type = AnnotationTypes.SENTENCE.value
        self._sentence_annotation_type = AnnotationTypes.ANNOTATION.value
        self._target_feature = AnnotationFeatures.VALUE.value

    def train(self, model_id: str, documents: List[Document]):
        """Fit a new pipeline on the labeled annotations in ``documents`` and save it.

        For every sentence annotation in every document, each annotation of
        the sentence-annotation type covered by that sentence contributes one
        training example: its covered text and the value of the target
        feature. Annotations whose target feature is ``None`` are skipped.

        If no training example is found, nothing is fitted or saved.

        Args:
            model_id: Identifier under which the fitted model is saved.
            documents: Documents providing text and annotations to train on.
        """
        texts = []
        labels = []

        for document in documents:
            annotations = Annotations.from_dict(document.text, document.annotations)

            for sentence in annotations.select(self._sentence_type):
                for sentence_label in annotations.select_covered(self._sentence_annotation_type, sentence):
                    label = sentence_label.features.get(self._target_feature)
                    if label is None:
                        # Unlabeled annotation: nothing to learn from.
                        continue

                    texts.append(annotations.get_covered_text(sentence_label))
                    labels.append(label)

        if not texts:
            # Fitting the vectorizer on an empty corpus raises a ValueError,
            # so actually skip training as the log message says.
            logger.debug("Empty training set, skipping!")
            return

        model = Pipeline(
            [
                ("vect", CountVectorizer()),
                ("tfidf", TfidfTransformer()),
                ("clf", MultinomialNB()),
            ]
        )
        model.fit(texts, labels)

        logger.debug("Training finished for model with id [%s]", model_id)

        self._save_model(model_id, model)

    def predict(self, model_id: str, document: Document) -> Optional[Document]:
        """Predict a label for every sentence annotation in ``document``.

        Args:
            model_id: Identifier of the model to load.
            document: Document whose sentence annotations are classified.

        Returns:
            The result of ``build_sentence_classification_document`` for the
            sentence texts and their predicted labels, or ``None`` when no
            model could be loaded for ``model_id``.
        """
        model: Optional[Pipeline] = self._load_model(model_id)

        if model is None:
            logger.debug("No trained model ready yet!")
            return None

        annotations = Annotations.from_dict(document.text, document.annotations)
        texts = [annotations.get_covered_text(sentence) for sentence in annotations.select(self._sentence_type)]

        # NOTE(review): a document without sentence annotations yields an empty
        # ``texts`` list here; confirm how the fitted pipeline and the document
        # builder are expected to behave in that case.
        predicted_labels = model.predict(texts)

        return build_sentence_classification_document(texts, predicted_labels)

    def consumes(self) -> List[str]:
        """Return the annotation type names this classifier reads."""
        return [self._sentence_type, self._sentence_annotation_type]

    def produces(self) -> List[str]:
        """Return the name of the feature this classifier targets."""
        return [self._target_feature]