-
Notifications
You must be signed in to change notification settings - Fork 55
Expand file tree
/
Copy pathfasttext_wrapper.py
More file actions
97 lines (80 loc) · 2.91 KB
/
fasttext_wrapper.py
File metadata and controls
97 lines (80 loc) · 2.91 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
"""
FastText wrapper for MLflow.
"""
from typing import Tuple, Optional, Dict, Any, List
import fasttext
import mlflow
import pandas as pd
from preprocessor import Preprocessor
from constants import TEXT_FEATURE, LABEL_PREFIX
class FastTextWrapper(mlflow.pyfunc.PythonModel):
    """
    Class to wrap and use FastText models with MLflow.

    The wrapper cleans incoming texts with a ``Preprocessor``, runs them
    through a FastText model and returns the predictions as a nested
    dictionary.
    """

    def __init__(self):
        """
        Construct a FastTextWrapper object.

        The FastText model is not loaded here: ``self.model`` stays ``None``
        until ``load_context`` is called.
        """
        self.model = None
        self.preprocessor = Preprocessor()

    def load_context(self, context: mlflow.pyfunc.PythonModelContext) -> None:
        """
        Load the FastText model from an MLflow model artifact.

        This method is called when loading an MLflow model with
        pyfunc.load_model(), as soon as the PythonModel is constructed.

        Args:
            context (mlflow.pyfunc.PythonModelContext): MLflow context where
                the model artifact is stored. It should contain the following
                artifacts:
                    - "model_path": path to the FastText model file.
                    - "config_path": path to the configuration file.
                Only "model_path" is read by this method.
        """
        self.model = fasttext.load_model(context.artifacts["model_path"])

    def predict(
        self,
        context: mlflow.pyfunc.PythonModelContext,
        model_input: List[str],
        params: Optional[Dict[str, Any]] = None,
    ) -> Dict[int, Dict[int, Dict[str, Any]]]:
        """
        Predicts the most likely codes for a list of texts.

        Args:
            context (mlflow.pyfunc.PythonModelContext): The MLflow model
                context.
            model_input (List): A list of text observations.
            params (Optional[Dict[str, Any]]): Keyword arguments forwarded
                to the FastText model's ``predict`` call (e.g. ``k``). When
                ``None`` or empty, the model's own defaults are used.

        Returns:
            A dictionary keyed by the position of each text in
            ``model_input``. Each value is a dictionary keyed by the rank of
            the prediction (starting at 1) whose values hold the predicted
            code under "nace" (with the label prefix removed) and its
            score under "probability".
        """
        # ``params`` is optional in the signature, so it must not be
        # unpacked or indexed directly: fall back to no extra arguments.
        predict_kwargs = params or {}

        df = self.preprocessor.clean_text(
            pd.DataFrame(model_input, columns=[TEXT_FEATURE]),
            text_feature=TEXT_FEATURE,
        )
        texts = df.apply(self._format_item, axis=1).to_list()

        labels, probabilities = self.model.predict(texts, **predict_kwargs)

        # Iterate over what the model actually returned instead of
        # range(params["k"]): this works when "k" is absent and when the
        # model yields fewer than k labels for a given text.
        return {
            i: {
                rank: {
                    "nace": label.replace(LABEL_PREFIX, ""),
                    "probability": float(probability),
                }
                for rank, (label, probability) in enumerate(
                    zip(text_labels, text_probabilities), start=1
                )
            }
            for i, (text_labels, text_probabilities) in enumerate(
                zip(labels, probabilities)
            )
        }

    def _format_item(self, row: pd.Series) -> str:
        """
        Formats a row of data into a string.

        Args:
            row (pandas.Series): A pandas series containing the row data.

        Returns:
            The value stored in the text feature column of the row.
        """
        formatted_item = row[TEXT_FEATURE]
        return formatted_item