Source code for diva.monitoring.api_service_reliability

from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import JSONResponse
from slowapi import Limiter, _rate_limit_exceeded_handler
from slowapi.util import get_remote_address
from slowapi.errors import RateLimitExceeded

from pydantic import BaseModel
import threading
import uvicorn
from diva import parameters
import streamlit as st

import warnings
# The endpoints below read and write ``st.session_state`` from the uvicorn
# thread, i.e. outside any Streamlit script run, which presumably triggers
# Streamlit's "missing ScriptRunContext" warning. That warning is expected
# in this setup, so it is silenced.
warnings.filterwarnings(
    action="ignore",
    message="missing ScriptRunContext! This warning can be ignored when running in bare mode.",
)


# FastAPI application serving the reliability endpoints defined below.
app = FastAPI()

# Rate limiting (slowapi): RateLimitExceeded errors are turned into responses
# by slowapi's stock handler, and limits are keyed on the caller's remote
# address. The limiter must be reachable as ``app.state.limiter``.
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
limiter = Limiter(key_func=get_remote_address)
app.state.limiter = limiter

# ---------------- Middleware to block ip ----------------
@app.middleware("http")
async def restrict_ip_middleware(request: Request, call_next):
    """Reject every HTTP request whose client IP is not in ``parameters.authorized_IP``.

    Unauthorized callers receive a 403 JSON response shaped like FastAPI's
    HTTPException output (``{"detail": ...}``). The response is returned
    directly rather than raising ``HTTPException``: exceptions raised inside an
    ``@app.middleware("http")`` function are not handled by FastAPI's exception
    handlers and would reach the client as a 500 Internal Server Error.

    Args:
        request: The incoming request.
        call_next: Callable forwarding the request to the rest of the app.

    Returns:
        The downstream response for authorized clients, a 403 response otherwise.
    """
    # ``request.client`` can be None when the server does not report a peer
    # address; treat that as unauthorized instead of crashing on ``.host``.
    client_host = request.client.host if request.client is not None else None
    if client_host not in parameters.authorized_IP:
        return JSONResponse(
            status_code=403,
            content={"detail": "Forbidden: Unauthorized IP"},
        )
    return await call_next(request)
# ---------------- endpoint /llm_prompt ----------------
class PromptRequest(BaseModel):
    """JSON body accepted by ``POST /llm_prompt``."""

    # Text handed as-is to the LLM generator.
    prompt: str
@app.post("/llm_prompt")
@limiter.limit(parameters.limit_requests_fastapi)
async def llm_prompt(request: Request, payload: PromptRequest):
    """Generate text for ``payload.prompt`` with ``diva.llm.llms.generator``.

    ``request`` is not used in the body, but it must stay in the signature:
    slowapi's ``limiter.limit`` decorator needs the endpoint to accept a
    ``request: Request`` argument.

    Returns:
        A 200 JSON response ``{"response": <generated text>}``.

    Raises:
        HTTPException: 500 carrying the error message if importing the LLM
            module, generating, or building the response fails. The original
            traceback is logged and chained.
    """
    import logging

    try:
        from diva.llm import llms

        # NOTE(review): this synchronous call runs on the event loop and blocks
        # every other request until it returns. Consider
        # ``fastapi.concurrency.run_in_threadpool`` once the generator is
        # confirmed to be thread-safe.
        generated_text = llms.generator(payload.prompt)
        return JSONResponse(content={"response": generated_text}, status_code=200)
    except Exception as e:
        logging.getLogger(__name__).exception("llm_prompt failed")
        raise HTTPException(status_code=500, detail=str(e)) from e
# ---------------- endpoint /get_data ----------------
class DataRequest(BaseModel):
    """JSON body accepted by ``POST /get_data``.

    The three location fields are bundled into a single location dict for
    ``DataCollection.apply_masks``. The two dates bound the sampled time
    interval.
    """

    location_name_orig: str
    location_name: str
    addresstype: str
    start_date: str  # parsed with format "%Y-%m-%d"
    end_date: str  # parsed with format "%Y-%m-%d"
@app.post("/get_data")
@limiter.limit(parameters.limit_requests_fastapi)
async def get_data(request: Request, payload: DataRequest):
    """Return spatially aggregated ``'t2m'`` values for one location and date range.

    The pipeline is: load the ``parameters.cache_b_collections_normal``
    collection, sample ``[start_date, end_date]``, mask to the requested
    location, aggregate spatially, and return ``str()`` of the resulting
    values. ``request`` is required by slowapi's limiter decorator.

    Returns:
        A 200 JSON response ``{"data": <stringified values>}``.

    Raises:
        HTTPException: 400 if a date is not a valid ``YYYY-MM-DD`` string;
            500 carrying the error message for any other failure (the
            traceback is logged and chained).
    """
    import logging

    try:
        import pandas as pd
        from diva.data.dataset import DataCollection
        from diva.logging_.logger import Process

        # Validate client input first so a bad date is reported as a client
        # error (400) instead of an internal error (500).
        try:
            start = pd.to_datetime(payload.start_date, format="%Y-%m-%d")
            end = pd.to_datetime(payload.end_date, format="%Y-%m-%d")
        except (ValueError, TypeError) as e:
            raise HTTPException(
                status_code=400,
                detail=f"Invalid date, expected YYYY-MM-DD: {e}",
            ) from e

        if "process_logging" not in st.session_state:
            st.session_state.process_logging = Process()

        locs = [{
            "location_name_orig": payload.location_name_orig,
            "location_name": payload.location_name,
            "addresstype": payload.addresstype,
        }]

        dc = DataCollection(parameters.cache_b_collections_normal, 't2m')
        dc = dc.sample_time([[start, end]])
        dc = dc.apply_masks(locs)
        dc = dc.spatial_aggregation()
        raw_vals = dc.get_values()
        return JSONResponse(content={"data": str(raw_vals)}, status_code=200)
    except HTTPException:
        # Already a deliberate HTTP error (the 400 above); do not re-wrap as 500.
        raise
    except Exception as e:
        logging.getLogger(__name__).exception("get_data failed")
        raise HTTPException(status_code=500, detail=str(e)) from e
# ---------------- endpoint /generate_graph ----------------
class GraphRequest(BaseModel):
    """JSON body accepted by ``POST /generate_graph``.

    Each field is copied verbatim into the ``params`` dict passed to
    ``ServiceGeneratePlotlyGraph``. The values it accepts are presumably
    defined by that class (not visible here).
    """

    starttime: str
    endtime: str
    location: str
    elementofinterest: str
    graph_type: str
    aggreg_type: str
@app.post("/generate_graph")
@limiter.limit(parameters.limit_requests_fastapi)
async def generate_graph(request: Request, payload: GraphRequest):
    """Build a temperature Plotly graph from ``payload`` and return its trace data.

    Ensures the ``st.session_state`` keys used by the graph service exist,
    then runs ``ServiceGeneratePlotlyGraph`` (English, "normal" user) without
    displaying it. ``request`` is required by slowapi's limiter decorator.

    Returns:
        A 200 JSON response ``{"data": str(fig.data)}``.

    Raises:
        HTTPException: 500 carrying the error message on any failure, matching
            the other endpoints. The traceback is logged server-side so it is
            not lost.
    """
    import logging

    try:
        from diva.logging_.logger import Process
        from diva.graphs.service_graph_generation import ServiceGeneratePlotlyGraph

        if "process_logging" not in st.session_state:
            st.session_state.process_logging = Process()
        if "plots_history" not in st.session_state:
            st.session_state.plots_history = {}
        if "messages" not in st.session_state:
            st.session_state.messages = {}

        params = {
            "starttime": payload.starttime,
            "endtime": payload.endtime,
            "location": payload.location,
            "elementofinterest": payload.elementofinterest,
            "graph_type": payload.graph_type,
            "aggreg_type": payload.aggreg_type,
            "climate_variable": "temperature",
        }
        gg = ServiceGeneratePlotlyGraph(
            params=params, langage="English", user_type="normal")
        gg.generate(show=False)
        return JSONResponse(content={"data": str(gg.fig.data)}, status_code=200)
    except Exception as e:
        logging.getLogger(__name__).exception("generate_graph failed")
        raise HTTPException(status_code=500, detail=str(e)) from e


# ---------------- run api ----------------
def run_api():
    """Serve ``app`` with uvicorn on the host and port configured in ``parameters``.

    The call blocks for as long as the server runs.
    """
    host = parameters.url_reliability
    port = parameters.port_reliability
    uvicorn.run(app, host=host, port=port)
# Launch the API in a background thread as a side effect of importing this
# module. daemon=True means this thread does not keep the interpreter alive
# once the other non-daemon threads have finished.
api_thread = threading.Thread(target=run_api, daemon=True)
api_thread.start()

# Standalone development alternative (instead of the import-time thread):
#   uvicorn api_service_reliability:app --host 0.0.0.0 --port 8602 --reload