Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions backend/ml/churnModel.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,25 @@ def predict_churn(self, subscriber_address: str, user_data: Dict) -> Dict:
"risk_factors": top_factors,
"recommended_action": self._get_recommended_action(risk_level, top_factors)
}

def explain_churn(self, user_data: Dict) -> Dict:
"""Return per-feature attributions approximating SHAP values for this linear-style model.

This is a lightweight approximation: contribution = feature_value * weight.
"""
features = self._extract_features(user_data)
contributions = {}
for feat, val in features.items():
w = self.feature_weights.get(feat, 0.0)
contributions[feat] = round(val * w, 6)

# base value is the model bias; since this model is sum of contributions, base=0
base_value = 0.0
return {
"base_value": round(base_value, 6),
"attributions": contributions,
"approx_method": "linear_contribution"
}

def _get_recommended_action(self, risk_level: str, top_factors: List[Dict]) -> str:
if risk_level == "Low":
Expand Down
28 changes: 28 additions & 0 deletions backend/ml/pricingModel.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,34 @@ def calculate_optimal_price(self, subscription_id: str, context: Dict) -> Dict:
"recommendation": "Increase" if optimal_price > current_price else "Decrease" if optimal_price < current_price else "Maintain"
}

def explain_price(self, context: Dict) -> Dict:
"""Return per-factor attributions for the computed target price.

This provides an approximate explanation by exposing the weighted components used
in the target price calculation.
"""
current_price = context.get("current_price", 10.0)
competitor_avg = context.get("competitor_avg", current_price)
demand = context.get("current_demand", 1.0)
wtp_estimate = self.estimate_willingness_to_pay(context.get("usage_data", {}))

# contributions mirror the weighted average used for target_price
contrib_wtp = round(wtp_estimate * 0.4, 6)
contrib_competitor = round(competitor_avg * 0.4, 6)
contrib_current = round(current_price * demand * 0.2, 6)
total = round(contrib_wtp + contrib_competitor + contrib_current, 6)

return {
"base_value": 0.0,
"attributions": {
"willingness_to_pay": contrib_wtp,
"competitor_benchmark": contrib_competitor,
"current_price_demand_adjusted": contrib_current,
},
"target_price_sum": total,
"approx_method": "weighted_components"
}

def get_price_recommendations(self, plan_id: str, historical_data: List[Dict]) -> List[Dict]:
"""
Returns a range of price recommendations for a specific plan.
Expand Down
66 changes: 66 additions & 0 deletions ml-service/model_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,12 @@ def __init__(self):
self._models: Dict[str, Any] = {}
self._meta: Dict[str, ModelMeta] = {}
self._versions = self._load_version_file()
# Explanation storage and aggregates
self._explanations_file = os.path.join(os.path.dirname(__file__), "explanations.json")
# in-memory aggregates: {model_name: {feature: {"sum_abs": float, "count": int}}}
self._explanation_aggregates: Dict[str, Dict[str, Dict[str, float]]] = {}
# segment profiles: {model_name: {segment_key: {feature: avg_value}}}
self._segment_profiles: Dict[str, Dict[str, Dict[str, float]]] = {}

def _load_version_file(self) -> Dict:
if os.path.exists(REGISTRY_FILE):
Expand All @@ -74,6 +80,66 @@ def load_all(self):

logger.info(f"Loaded models: {list(self._models.keys())}")

def _append_explanation_file(self, record: Dict):
try:
data = []
if os.path.exists(self._explanations_file):
with open(self._explanations_file, "r") as f:
try:
data = json.load(f)
except Exception:
data = []
data.append(record)
with open(self._explanations_file, "w") as f:
json.dump(data, f)
except Exception as e:
logger.exception("Failed to write explanation record: %s", e)

def record_explanation(self, model_name: str, subscriber: str, input_features: Dict, attributions: Dict, segment: Optional[str] = None):
"""Store an explanation audit record and update aggregates and segment profiles."""
ts = time.time()
record = {
"timestamp": ts,
"model": model_name,
"subscriber": subscriber,
"segment": segment,
"input_features": input_features,
"attributions": attributions,
}
# append to file (audit trail)
self._append_explanation_file(record)

# update aggregates
agg = self._explanation_aggregates.setdefault(model_name, {})
for feat, val in (attributions or {}).items():
entry = agg.setdefault(feat, {"sum_abs": 0.0, "count": 0})
entry["sum_abs"] += abs(float(val))
entry["count"] += 1

# update segment profiles (simple running avg of attributions)
if segment:
segmap = self._segment_profiles.setdefault(model_name, {})
profile = segmap.setdefault(segment, {})
for feat, val in (attributions or {}).items():
prev = profile.get(feat, {"sum": 0.0, "count": 0})
prev["sum"] += float(val)
prev["count"] += 1
profile[feat] = prev

def get_global_feature_importance(self, model_name: str) -> Dict[str, float]:
"""Return average absolute attribution per feature for a model."""
agg = self._explanation_aggregates.get(model_name, {})
out = {}
for feat, v in agg.items():
if v["count"]:
out[feat] = round(v["sum_abs"] / v["count"], 6)
return out

def get_segment_profile(self, model_name: str, segment: str) -> Dict[str, float]:
segmap = self._segment_profiles.get(model_name, {})
profile = segmap.get(segment, {})
return {feat: round(vals["sum"] / vals["count"], 6) for feat, vals in profile.items() if vals["count"]}

def get(self, name: str) -> Any:
model = self._models.get(name)
if model is None:
Expand Down
29 changes: 26 additions & 3 deletions ml-service/routers/churn.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,21 +33,33 @@ class ForecastRequest(BaseModel):


@router.post("/predict")
def predict_churn(req: ChurnRequest):
def predict_churn(req: ChurnRequest, explain: bool = False):
from main import registry
try:
model = registry.get("churn")
meta = registry.meta("churn")
result = model.predict_churn(req.subscriber, req.user_data.model_dump())
meta.record_prediction()
# optionally compute explanations and record audit
if explain:
try:
expl = model.explain_churn(req.user_data.model_dump())
except Exception:
expl = {"error": "explanation_failed"}
# store audit trail
try:
registry.record_explanation("churn", req.subscriber, req.user_data.model_dump(), expl.get("attributions", {}), segment=None)
except Exception:
pass
return {"model_version": meta.version, **result, "explanation": expl}
return {"model_version": meta.version, **result}
except Exception as e:
registry.meta("churn").record_error()
raise HTTPException(status_code=500, detail=str(e))


@router.post("/predict/batch")
def predict_churn_batch(req: BatchChurnRequest):
def predict_churn_batch(req: BatchChurnRequest, explain: bool = False):
from main import registry
results = []
model = registry.get("churn")
Expand All @@ -56,7 +68,18 @@ def predict_churn_batch(req: BatchChurnRequest):
try:
result = model.predict_churn(item.subscriber, item.user_data.model_dump())
meta.record_prediction()
results.append({"ok": True, **result})
if explain:
try:
expl = model.explain_churn(item.user_data.model_dump())
except Exception:
expl = {"error": "explanation_failed"}
try:
registry.record_explanation("churn", item.subscriber, item.user_data.model_dump(), expl.get("attributions", {}), segment=None)
except Exception:
pass
results.append({"ok": True, **result, "explanation": expl})
else:
results.append({"ok": True, **result})
except Exception as e:
meta.record_error()
results.append({"ok": False, "subscriber": item.subscriber, "error": str(e)})
Expand Down
12 changes: 11 additions & 1 deletion ml-service/routers/pricing.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ class ABTestRequest(BaseModel):


@router.post("/optimize")
def optimize_price(req: PricingRequest):
def optimize_price(req: PricingRequest, explain: bool = False):
from main import registry
try:
model = registry.get("pricing")
Expand All @@ -45,6 +45,16 @@ def optimize_price(req: PricingRequest):
ctx["price_ceiling"] = ctx["current_price"] * 1.5
result = model.calculate_optimal_price(req.subscription_id, ctx)
meta.record_prediction()
if explain:
try:
expl = model.explain_price(ctx)
except Exception:
expl = {"error": "explanation_failed"}
try:
registry.record_explanation("pricing", req.subscription_id, ctx, expl.get("attributions", {}), segment=None)
except Exception:
pass
return {"model_version": meta.version, **result, "explanation": expl}
return {"model_version": meta.version, **result}
except Exception as e:
registry.meta("pricing").record_error()
Expand Down
29 changes: 26 additions & 3 deletions ml-service/routers/recommendations.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,22 +30,34 @@ class FeedbackRequest(BaseModel):


@router.post("/predict")
def get_recommendations(req: RecommendationRequest):
def get_recommendations(req: RecommendationRequest, explain: bool = False):
from main import registry
try:
model = registry.get("recommendations")
meta = registry.meta("recommendations")
ctx = req.context.model_dump() if req.context else {}
result = model.get_recommendations(req.subscriber, ctx)
meta.record_prediction()
if explain:
# optional explain hook on recommendation models
expl = {}
try:
expl = getattr(model, "explain_recommendations", lambda s, c: {})(req.subscriber, ctx)
except Exception:
expl = {"error": "explanation_failed"}
try:
registry.record_explanation("recommendations", req.subscriber, ctx, expl.get("attributions", {}), segment=None)
except Exception:
pass
return {"model_version": meta.version, "recommendations": result, "explanation": expl}
return {"model_version": meta.version, "recommendations": result}
except Exception as e:
registry.meta("recommendations").record_error()
raise HTTPException(status_code=500, detail=str(e))


@router.post("/predict/batch")
def get_recommendations_batch(req: BatchRecommendationRequest):
def get_recommendations_batch(req: BatchRecommendationRequest, explain: bool = False):
from main import registry
model = registry.get("recommendations")
meta = registry.meta("recommendations")
Expand All @@ -55,7 +67,18 @@ def get_recommendations_batch(req: BatchRecommendationRequest):
ctx = item.context.model_dump() if item.context else {}
recs = model.get_recommendations(item.subscriber, ctx)
meta.record_prediction()
results.append({"ok": True, "subscriber": item.subscriber, "recommendations": recs})
if explain:
try:
expl = getattr(model, "explain_recommendations", lambda s, c: {})(item.subscriber, ctx)
except Exception:
expl = {"error": "explanation_failed"}
try:
registry.record_explanation("recommendations", item.subscriber, ctx, expl.get("attributions", {}), segment=None)
except Exception:
pass
results.append({"ok": True, "subscriber": item.subscriber, "recommendations": recs, "explanation": expl})
else:
results.append({"ok": True, "subscriber": item.subscriber, "recommendations": recs})
except Exception as e:
meta.record_error()
results.append({"ok": False, "subscriber": item.subscriber, "error": str(e)})
Expand Down