diff --git a/backend/ml/churnModel.py b/backend/ml/churnModel.py
index 853184a2..d5b27b06 100644
--- a/backend/ml/churnModel.py
+++ b/backend/ml/churnModel.py
@@ -69,6 +69,25 @@ def predict_churn(self, subscriber_address: str, user_data: Dict) -> Dict:
             "risk_factors": top_factors,
             "recommended_action": self._get_recommended_action(risk_level, top_factors)
         }
+
+    def explain_churn(self, user_data: Dict) -> Dict:
+        """Return per-feature attributions approximating SHAP values for this linear-style model.
+
+        This is a lightweight approximation: contribution = feature_value * weight.
+        """
+        features = self._extract_features(user_data)
+        contributions = {}
+        for feat, val in features.items():
+            w = self.feature_weights.get(feat, 0.0)
+            contributions[feat] = round(val * w, 6)
+
+        # base value is the model bias; since this model is sum of contributions, base=0
+        base_value = 0.0
+        return {
+            "base_value": round(base_value, 6),
+            "attributions": contributions,
+            "approx_method": "linear_contribution"
+        }
 
     def _get_recommended_action(self, risk_level: str, top_factors: List[Dict]) -> str:
         if risk_level == "Low":
diff --git a/backend/ml/pricingModel.py b/backend/ml/pricingModel.py
index 5a5058a9..f917b794 100644
--- a/backend/ml/pricingModel.py
+++ b/backend/ml/pricingModel.py
@@ -58,6 +58,34 @@ def calculate_optimal_price(self, subscription_id: str, context: Dict) -> Dict:
             "recommendation": "Increase" if optimal_price > current_price else "Decrease" if optimal_price < current_price else "Maintain"
         }
 
+    def explain_price(self, context: Dict) -> Dict:
+        """Return per-factor attributions for the computed target price.
+
+        This provides an approximate explanation by exposing the weighted components used
+        in the target price calculation.
+        """
+        current_price = context.get("current_price", 10.0)
+        competitor_avg = context.get("competitor_avg", current_price)
+        demand = context.get("current_demand", 1.0)
+        wtp_estimate = self.estimate_willingness_to_pay(context.get("usage_data", {}))
+
+        # contributions mirror the weighted average used for target_price
+        contrib_wtp = round(wtp_estimate * 0.4, 6)
+        contrib_competitor = round(competitor_avg * 0.4, 6)
+        contrib_current = round(current_price * demand * 0.2, 6)
+        total = round(contrib_wtp + contrib_competitor + contrib_current, 6)
+
+        return {
+            "base_value": 0.0,
+            "attributions": {
+                "willingness_to_pay": contrib_wtp,
+                "competitor_benchmark": contrib_competitor,
+                "current_price_demand_adjusted": contrib_current,
+            },
+            "target_price_sum": total,
+            "approx_method": "weighted_components"
+        }
+
     def get_price_recommendations(self, plan_id: str, historical_data: List[Dict]) -> List[Dict]:
         """
         Returns a range of price recommendations for a specific plan.
diff --git a/ml-service/model_registry.py b/ml-service/model_registry.py
index 77d12a92..dde43896 100644
--- a/ml-service/model_registry.py
+++ b/ml-service/model_registry.py
@@ -48,6 +48,12 @@ def __init__(self):
         self._models: Dict[str, Any] = {}
         self._meta: Dict[str, ModelMeta] = {}
         self._versions = self._load_version_file()
+        # Explanation storage and aggregates
+        self._explanations_file = os.path.join(os.path.dirname(__file__), "explanations.json")
+        # in-memory aggregates: {model_name: {feature: {"sum_abs": float, "count": int}}}
+        self._explanation_aggregates: Dict[str, Dict[str, Dict[str, float]]] = {}
+        # segment profiles: {model_name: {segment_key: {feature: {"sum": float, "count": int}}}}
+        self._segment_profiles: Dict[str, Dict[str, Dict[str, Dict[str, float]]]] = {}
 
     def _load_version_file(self) -> Dict:
         if os.path.exists(REGISTRY_FILE):
@@ -74,6 +80,79 @@ def load_all(self):
 
         logger.info(f"Loaded models: {list(self._models.keys())}")
 
+    def _append_explanation_file(self, record: Dict):
+        """Append one audit record to the JSON explanations file (best-effort).
+
+        The file holds a single JSON list, so every call re-reads and rewrites it.
+        Failures are logged and never propagated to the caller.
+        """
+        try:
+            data = []
+            if os.path.exists(self._explanations_file):
+                with open(self._explanations_file, "r") as f:
+                    try:
+                        data = json.load(f)
+                    except ValueError:
+                        data = []
+                # a non-list payload cannot be appended to; start a fresh trail
+                if not isinstance(data, list):
+                    data = []
+            data.append(record)
+            # serialize before touching the file so an unserializable record
+            # cannot truncate the existing audit trail, then swap in atomically
+            payload = json.dumps(data)
+            tmp_path = self._explanations_file + ".tmp"
+            with open(tmp_path, "w") as f:
+                f.write(payload)
+            os.replace(tmp_path, self._explanations_file)
+        except Exception as e:
+            logger.exception("Failed to write explanation record: %s", e)
+
+    def record_explanation(self, model_name: str, subscriber: str, input_features: Dict, attributions: Dict, segment: Optional[str] = None):
+        """Store an explanation audit record and update aggregates and segment profiles."""
+        ts = time.time()
+        record = {
+            "timestamp": ts,
+            "model": model_name,
+            "subscriber": subscriber,
+            "segment": segment,
+            "input_features": input_features,
+            "attributions": attributions,
+        }
+        # append to file (audit trail)
+        self._append_explanation_file(record)
+
+        # update aggregates
+        agg = self._explanation_aggregates.setdefault(model_name, {})
+        for feat, val in (attributions or {}).items():
+            entry = agg.setdefault(feat, {"sum_abs": 0.0, "count": 0})
+            entry["sum_abs"] += abs(float(val))
+            entry["count"] += 1
+
+        # update segment profiles (running sum/count; averaged on read)
+        if segment:
+            segmap = self._segment_profiles.setdefault(model_name, {})
+            profile = segmap.setdefault(segment, {})
+            for feat, val in (attributions or {}).items():
+                entry = profile.setdefault(feat, {"sum": 0.0, "count": 0})
+                entry["sum"] += float(val)
+                entry["count"] += 1
+
+    def get_global_feature_importance(self, model_name: str) -> Dict[str, float]:
+        """Return average absolute attribution per feature for a model."""
+        agg = self._explanation_aggregates.get(model_name, {})
+        out = {}
+        for feat, v in agg.items():
+            if v["count"]:
+                out[feat] = round(v["sum_abs"] / v["count"], 6)
+        return out
+
+    def get_segment_profile(self, model_name: str, segment: str) -> Dict[str, float]:
+        """Return the average signed attribution per feature for one segment of a model."""
+        segmap = self._segment_profiles.get(model_name, {})
+        profile = segmap.get(segment, {})
+        return {feat: round(vals["sum"] / vals["count"], 6) for feat, vals in profile.items() if vals["count"]}
+
     def get(self, name: str) -> Any:
         model = self._models.get(name)
         if model is None:
diff --git a/ml-service/routers/churn.py b/ml-service/routers/churn.py
index a7098bff..41229a48 100644
--- a/ml-service/routers/churn.py
+++ b/ml-service/routers/churn.py
@@ -33,13 +33,25 @@ class ForecastRequest(BaseModel):
 
 
 @router.post("/predict")
-def predict_churn(req: ChurnRequest):
+def predict_churn(req: ChurnRequest, explain: bool = False):
     from main import registry
     try:
         model = registry.get("churn")
         meta = registry.meta("churn")
         result = model.predict_churn(req.subscriber, req.user_data.model_dump())
         meta.record_prediction()
+        # optionally compute explanations and record audit
+        if explain:
+            try:
+                expl = model.explain_churn(req.user_data.model_dump())
+            except Exception:
+                expl = {"error": "explanation_failed"}
+            # store audit trail
+            try:
+                registry.record_explanation("churn", req.subscriber, req.user_data.model_dump(), expl.get("attributions", {}), segment=None)
+            except Exception:
+                pass
+            return {"model_version": meta.version, **result, "explanation": expl}
         return {"model_version": meta.version, **result}
     except Exception as e:
         registry.meta("churn").record_error()
@@ -47,7 +59,7 @@ def predict_churn(req: ChurnRequest):
 
 
 @router.post("/predict/batch")
-def predict_churn_batch(req: BatchChurnRequest):
+def predict_churn_batch(req: BatchChurnRequest, explain: bool = False):
     from main import registry
     results = []
     model = registry.get("churn")
@@ -56,7 +68,18 @@ def predict_churn_batch(req: BatchChurnRequest):
         try:
             result = model.predict_churn(item.subscriber, item.user_data.model_dump())
             meta.record_prediction()
-            results.append({"ok": True, **result})
+            if explain:
+                try:
+                    expl = model.explain_churn(item.user_data.model_dump())
+                except Exception:
+                    expl = {"error": "explanation_failed"}
+                try:
+                    registry.record_explanation("churn", item.subscriber, item.user_data.model_dump(), expl.get("attributions", {}), segment=None)
+                except Exception:
+                    pass
+                results.append({"ok": True, **result, "explanation": expl})
+            else:
+                results.append({"ok": True, **result})
         except Exception as e:
             meta.record_error()
             results.append({"ok": False, "subscriber": item.subscriber, "error": str(e)})
diff --git a/ml-service/routers/pricing.py b/ml-service/routers/pricing.py
index bde15f9d..5d7743d3 100644
--- a/ml-service/routers/pricing.py
+++ b/ml-service/routers/pricing.py
@@ -31,7 +31,7 @@ class ABTestRequest(BaseModel):
 
 
 @router.post("/optimize")
-def optimize_price(req: PricingRequest):
+def optimize_price(req: PricingRequest, explain: bool = False):
     from main import registry
     try:
         model = registry.get("pricing")
@@ -45,6 +45,16 @@ def optimize_price(req: PricingRequest):
             ctx["price_ceiling"] = ctx["current_price"] * 1.5
         result = model.calculate_optimal_price(req.subscription_id, ctx)
         meta.record_prediction()
+        if explain:
+            try:
+                expl = model.explain_price(ctx)
+            except Exception:
+                expl = {"error": "explanation_failed"}
+            try:
+                registry.record_explanation("pricing", req.subscription_id, ctx, expl.get("attributions", {}), segment=None)
+            except Exception:
+                pass
+            return {"model_version": meta.version, **result, "explanation": expl}
         return {"model_version": meta.version, **result}
     except Exception as e:
         registry.meta("pricing").record_error()
diff --git a/ml-service/routers/recommendations.py b/ml-service/routers/recommendations.py
index 0e4c480b..931da5eb 100644
--- a/ml-service/routers/recommendations.py
+++ b/ml-service/routers/recommendations.py
@@ -30,7 +30,7 @@ class FeedbackRequest(BaseModel):
 
 
 @router.post("/predict")
-def get_recommendations(req: RecommendationRequest):
+def get_recommendations(req: RecommendationRequest, explain: bool = False):
     from main import registry
     try:
         model = registry.get("recommendations")
@@ -38,6 +38,18 @@ def get_recommendations(req: RecommendationRequest):
         ctx = req.context.model_dump() if req.context else {}
         result = model.get_recommendations(req.subscriber, ctx)
         meta.record_prediction()
+        if explain:
+            # optional explain hook on recommendation models
+            expl = {}
+            try:
+                expl = getattr(model, "explain_recommendations", lambda s, c: {})(req.subscriber, ctx)
+            except Exception:
+                expl = {"error": "explanation_failed"}
+            try:
+                registry.record_explanation("recommendations", req.subscriber, ctx, expl.get("attributions", {}), segment=None)
+            except Exception:
+                pass
+            return {"model_version": meta.version, "recommendations": result, "explanation": expl}
         return {"model_version": meta.version, "recommendations": result}
     except Exception as e:
         registry.meta("recommendations").record_error()
@@ -45,7 +57,7 @@ def get_recommendations(req: RecommendationRequest):
 
 
 @router.post("/predict/batch")
-def get_recommendations_batch(req: BatchRecommendationRequest):
+def get_recommendations_batch(req: BatchRecommendationRequest, explain: bool = False):
     from main import registry
     model = registry.get("recommendations")
     meta = registry.meta("recommendations")
@@ -55,7 +67,18 @@ def get_recommendations_batch(req: BatchRecommendationRequest):
             ctx = item.context.model_dump() if item.context else {}
             recs = model.get_recommendations(item.subscriber, ctx)
             meta.record_prediction()
-            results.append({"ok": True, "subscriber": item.subscriber, "recommendations": recs})
+            if explain:
+                try:
+                    expl = getattr(model, "explain_recommendations", lambda s, c: {})(item.subscriber, ctx)
+                except Exception:
+                    expl = {"error": "explanation_failed"}
+                try:
+                    registry.record_explanation("recommendations", item.subscriber, ctx, expl.get("attributions", {}), segment=None)
+                except Exception:
+                    pass
+                results.append({"ok": True, "subscriber": item.subscriber, "recommendations": recs, "explanation": expl})
+            else:
+                results.append({"ok": True, "subscriber": item.subscriber, "recommendations": recs})
         except Exception as e:
             meta.record_error()
             results.append({"ok": False, "subscriber": item.subscriber, "error": str(e)})