1pub mod model;
8
9use std::sync::Arc;
10
11use arc_swap::ArcSwap;
12use mas_data_model::Ulid;
13use opa_wasm::{
14    Runtime,
15    wasmtime::{Config, Engine, Module, OptLevel, Store},
16};
17use thiserror::Error;
18use tokio::io::{AsyncRead, AsyncReadExt};
19
20pub use self::model::{
21    AuthorizationGrantInput, ClientRegistrationInput, Code as ViolationCode, EmailInput,
22    EvaluationResult, GrantType, RegisterInput, RegistrationMethod, Requester, Violation,
23};
24
25#[derive(Debug, Error)]
26pub enum LoadError {
27    #[error("failed to read module")]
28    Read(#[from] tokio::io::Error),
29
30    #[error("failed to create WASM engine")]
31    Engine(#[source] anyhow::Error),
32
33    #[error("module compilation task crashed")]
34    CompilationTask(#[from] tokio::task::JoinError),
35
36    #[error("failed to compile WASM module")]
37    Compilation(#[source] anyhow::Error),
38
39    #[error("invalid policy data")]
40    InvalidData(#[source] anyhow::Error),
41
42    #[error("failed to instantiate a test instance")]
43    Instantiate(#[source] InstantiateError),
44}
45
46impl LoadError {
47    #[doc(hidden)]
50    #[must_use]
51    pub fn invalid_data_example() -> Self {
52        Self::InvalidData(anyhow::Error::msg("Failed to merge policy data objects"))
53    }
54}
55
56#[derive(Debug, Error)]
57pub enum InstantiateError {
58    #[error("failed to create WASM runtime")]
59    Runtime(#[source] anyhow::Error),
60
61    #[error("missing entrypoint {entrypoint}")]
62    MissingEntrypoint { entrypoint: String },
63
64    #[error("failed to load policy data")]
65    LoadData(#[source] anyhow::Error),
66}
67
/// Names of the OPA entrypoints that the policy module must export.
///
/// Each value is an entrypoint path inside the compiled module, for example
/// `"register/violation"`.
#[derive(Debug, Clone)]
pub struct Entrypoints {
    /// Entrypoint evaluated for user registrations.
    pub register: String,
    /// Entrypoint evaluated for client registrations.
    pub client_registration: String,
    /// Entrypoint evaluated for authorization grants.
    pub authorization_grant: String,
    /// Entrypoint evaluated for email addresses.
    pub email: String,
}

impl Entrypoints {
    /// Borrows every configured entrypoint name, in field declaration order.
    fn all(&self) -> [&str; 4] {
        let Self {
            register,
            client_registration,
            authorization_grant,
            email,
        } = self;

        [register, client_registration, authorization_grant, email].map(String::as_str)
    }
}
87
88#[derive(Debug)]
89pub struct Data {
90    server_name: String,
91
92    rest: Option<serde_json::Value>,
93}
94
95impl Data {
96    #[must_use]
97    pub fn new(server_name: String) -> Self {
98        Self {
99            server_name,
100            rest: None,
101        }
102    }
103
104    #[must_use]
105    pub fn with_rest(mut self, rest: serde_json::Value) -> Self {
106        self.rest = Some(rest);
107        self
108    }
109
110    fn to_value(&self) -> Result<serde_json::Value, anyhow::Error> {
111        let base = serde_json::json!({
112            "server_name": self.server_name,
113        });
114
115        if let Some(rest) = &self.rest {
116            merge_data(base, rest.clone())
117        } else {
118            Ok(base)
119        }
120    }
121}
122
123fn value_kind(value: &serde_json::Value) -> &'static str {
124    match value {
125        serde_json::Value::Object(_) => "object",
126        serde_json::Value::Array(_) => "array",
127        serde_json::Value::String(_) => "string",
128        serde_json::Value::Number(_) => "number",
129        serde_json::Value::Bool(_) => "boolean",
130        serde_json::Value::Null => "null",
131    }
132}
133
134fn merge_data(
135    mut left: serde_json::Value,
136    right: serde_json::Value,
137) -> Result<serde_json::Value, anyhow::Error> {
138    merge_data_rec(&mut left, right)?;
139    Ok(left)
140}
141
142fn merge_data_rec(
143    left: &mut serde_json::Value,
144    right: serde_json::Value,
145) -> Result<(), anyhow::Error> {
146    match (left, right) {
147        (serde_json::Value::Object(left), serde_json::Value::Object(right)) => {
148            for (key, value) in right {
149                if let Some(left_value) = left.get_mut(&key) {
150                    merge_data_rec(left_value, value)?;
151                } else {
152                    left.insert(key, value);
153                }
154            }
155        }
156        (serde_json::Value::Array(left), serde_json::Value::Array(right)) => {
157            left.extend(right);
158        }
159        (serde_json::Value::Number(left), serde_json::Value::Number(right)) => {
161            *left = right;
162        }
163        (serde_json::Value::Bool(left), serde_json::Value::Bool(right)) => {
164            *left = right;
165        }
166        (serde_json::Value::String(left), serde_json::Value::String(right)) => {
167            *left = right;
168        }
169
170        (left, right) if left.is_null() => *left = right,
172
173        (left, right) if right.is_null() => *left = right,
175
176        (left, right) => anyhow::bail!(
177            "Cannot merge a {} into a {}",
178            value_kind(&right),
179            value_kind(left),
180        ),
181    }
182
183    Ok(())
184}
185
186struct DynamicData {
187    version: Option<Ulid>,
188    merged: serde_json::Value,
189}
190
191pub struct PolicyFactory {
192    engine: Engine,
193    module: Module,
194    data: Data,
195    dynamic_data: ArcSwap<DynamicData>,
196    entrypoints: Entrypoints,
197}
198
199impl PolicyFactory {
200    #[tracing::instrument(name = "policy.load", skip(source))]
201    pub async fn load(
202        mut source: impl AsyncRead + std::marker::Unpin,
203        data: Data,
204        entrypoints: Entrypoints,
205    ) -> Result<Self, LoadError> {
206        let mut config = Config::default();
207        config.async_support(true);
208        config.cranelift_opt_level(OptLevel::SpeedAndSize);
209
210        let engine = Engine::new(&config).map_err(LoadError::Engine)?;
211
212        let mut buf = Vec::new();
214        source.read_to_end(&mut buf).await?;
215        let (engine, module) = tokio::task::spawn_blocking(move || {
217            let module = Module::new(&engine, buf)?;
218            anyhow::Ok((engine, module))
219        })
220        .await?
221        .map_err(LoadError::Compilation)?;
222
223        let merged = data.to_value().map_err(LoadError::InvalidData)?;
224        let dynamic_data = ArcSwap::new(Arc::new(DynamicData {
225            version: None,
226            merged,
227        }));
228
229        let factory = Self {
230            engine,
231            module,
232            data,
233            dynamic_data,
234            entrypoints,
235        };
236
237        factory
239            .instantiate()
240            .await
241            .map_err(LoadError::Instantiate)?;
242
243        Ok(factory)
244    }
245
246    pub async fn set_dynamic_data(
259        &self,
260        dynamic_data: mas_data_model::PolicyData,
261    ) -> Result<bool, LoadError> {
262        if self.dynamic_data.load().version == Some(dynamic_data.id) {
265            return Ok(false);
267        }
268
269        let static_data = self.data.to_value().map_err(LoadError::InvalidData)?;
270        let merged = merge_data(static_data, dynamic_data.data).map_err(LoadError::InvalidData)?;
271
272        self.instantiate_with_data(&merged)
274            .await
275            .map_err(LoadError::Instantiate)?;
276
277        self.dynamic_data.store(Arc::new(DynamicData {
279            version: Some(dynamic_data.id),
280            merged,
281        }));
282
283        Ok(true)
284    }
285
286    #[tracing::instrument(name = "policy.instantiate", skip_all)]
287    pub async fn instantiate(&self) -> Result<Policy, InstantiateError> {
288        let data = self.dynamic_data.load();
289        self.instantiate_with_data(&data.merged).await
290    }
291
292    async fn instantiate_with_data(
293        &self,
294        data: &serde_json::Value,
295    ) -> Result<Policy, InstantiateError> {
296        let mut store = Store::new(&self.engine, ());
297        let runtime = Runtime::new(&mut store, &self.module)
298            .await
299            .map_err(InstantiateError::Runtime)?;
300
301        let policy_entrypoints = runtime.entrypoints();
303
304        for e in self.entrypoints.all() {
305            if !policy_entrypoints.contains(e) {
306                return Err(InstantiateError::MissingEntrypoint {
307                    entrypoint: e.to_owned(),
308                });
309            }
310        }
311
312        let instance = runtime
313            .with_data(&mut store, data)
314            .await
315            .map_err(InstantiateError::LoadData)?;
316
317        Ok(Policy {
318            store,
319            instance,
320            entrypoints: self.entrypoints.clone(),
321        })
322    }
323}
324
325pub struct Policy {
326    store: Store<()>,
327    instance: opa_wasm::Policy<opa_wasm::DefaultContext>,
328    entrypoints: Entrypoints,
329}
330
331#[derive(Debug, Error)]
332#[error("failed to evaluate policy")]
333pub enum EvaluationError {
334    Serialization(#[from] serde_json::Error),
335    Evaluation(#[from] anyhow::Error),
336}
337
338impl Policy {
339    #[tracing::instrument(
340        name = "policy.evaluate_email",
341        skip_all,
342        fields(
343            %input.email,
344        ),
345    )]
346    pub async fn evaluate_email(
347        &mut self,
348        input: EmailInput<'_>,
349    ) -> Result<EvaluationResult, EvaluationError> {
350        let [res]: [EvaluationResult; 1] = self
351            .instance
352            .evaluate(&mut self.store, &self.entrypoints.email, &input)
353            .await?;
354
355        Ok(res)
356    }
357
358    #[tracing::instrument(
359        name = "policy.evaluate.register",
360        skip_all,
361        fields(
362            ?input.registration_method,
363            input.username = input.username,
364            input.email = input.email,
365        ),
366    )]
367    pub async fn evaluate_register(
368        &mut self,
369        input: RegisterInput<'_>,
370    ) -> Result<EvaluationResult, EvaluationError> {
371        let [res]: [EvaluationResult; 1] = self
372            .instance
373            .evaluate(&mut self.store, &self.entrypoints.register, &input)
374            .await?;
375
376        Ok(res)
377    }
378
379    #[tracing::instrument(skip(self))]
380    pub async fn evaluate_client_registration(
381        &mut self,
382        input: ClientRegistrationInput<'_>,
383    ) -> Result<EvaluationResult, EvaluationError> {
384        let [res]: [EvaluationResult; 1] = self
385            .instance
386            .evaluate(
387                &mut self.store,
388                &self.entrypoints.client_registration,
389                &input,
390            )
391            .await?;
392
393        Ok(res)
394    }
395
396    #[tracing::instrument(
397        name = "policy.evaluate.authorization_grant",
398        skip_all,
399        fields(
400            %input.scope,
401            %input.client.id,
402        ),
403    )]
404    pub async fn evaluate_authorization_grant(
405        &mut self,
406        input: AuthorizationGrantInput<'_>,
407    ) -> Result<EvaluationResult, EvaluationError> {
408        let [res]: [EvaluationResult; 1] = self
409            .instance
410            .evaluate(
411                &mut self.store,
412                &self.entrypoints.authorization_grant,
413                &input,
414            )
415            .await?;
416
417        Ok(res)
418    }
419}
420
#[cfg(test)]
mod tests {
    use std::time::SystemTime;

    use super::*;

    /// Entrypoint names exported by the compiled policy bundle.
    fn test_entrypoints() -> Entrypoints {
        Entrypoints {
            register: "register/violation".to_owned(),
            client_registration: "client_registration/violation".to_owned(),
            authorization_grant: "authorization_grant/violation".to_owned(),
            email: "email/violation".to_owned(),
        }
    }

    /// Opens `policies/policy.wasm`, located two directories above this crate's manifest
    /// directory, and loads it with `data`.
    async fn load_factory(data: Data) -> PolicyFactory {
        // NOTE(review): `Path` is disallowed by the clippy config, presumably in favour of
        // another path type; the exemption is kept for this test fixture.
        #[allow(clippy::disallowed_types)]
        let wasm_path = std::path::Path::new(env!("CARGO_MANIFEST_DIR"))
            .join("..")
            .join("..")
            .join("policies")
            .join("policy.wasm");

        let wasm = tokio::fs::File::open(wasm_path).await.unwrap();
        PolicyFactory::load(wasm, data, test_entrypoints())
            .await
            .unwrap()
    }

    /// A password registration for the user `hello` using the given email address.
    fn password_registration(email: &str) -> RegisterInput<'_> {
        RegisterInput {
            registration_method: RegistrationMethod::Password,
            username: "hello",
            email: Some(email),
            requester: Requester {
                ip_address: None,
                user_agent: None,
            },
        }
    }

    /// Wraps `data` into policy data with a nil ID, created now.
    fn policy_data(data: serde_json::Value) -> mas_data_model::PolicyData {
        mas_data_model::PolicyData {
            id: Ulid::nil(),
            created_at: SystemTime::now().into(),
            data,
        }
    }

    #[tokio::test]
    async fn test_register() {
        let data = Data::new("example.com".to_owned()).with_rest(serde_json::json!({
            "allowed_domains": ["element.io", "*.element.io"],
            "banned_domains": ["staging.element.io"],
        }));
        let factory = load_factory(data).await;
        let mut policy = factory.instantiate().await.unwrap();

        // (email address, expected validity)
        let cases = [
            ("hello@example.com", false),
            ("hello@foo.element.io", true),
            ("hello@staging.element.io", false),
        ];
        for (email, expected) in cases {
            let res = policy
                .evaluate_register(password_registration(email))
                .await
                .unwrap();
            assert_eq!(res.valid(), expected, "unexpected validity for {email}");
        }
    }

    #[tokio::test]
    async fn test_dynamic_data() {
        let factory = load_factory(Data::new("example.com".to_owned())).await;

        let mut policy = factory.instantiate().await.unwrap();
        let res = policy
            .evaluate_register(password_registration("hello@example.com"))
            .await
            .unwrap();
        assert!(res.valid());

        factory
            .set_dynamic_data(policy_data(serde_json::json!({
                "emails": {
                    "banned_addresses": {
                        "substrings": ["hello"]
                    }
                }
            })))
            .await
            .unwrap();

        // A fresh instance picks up the newly stored dynamic data.
        let mut policy = factory.instantiate().await.unwrap();
        let res = policy
            .evaluate_register(password_registration("hello@example.com"))
            .await
            .unwrap();
        assert!(!res.valid());
    }

    #[tokio::test]
    async fn test_big_dynamic_data() {
        let factory = load_factory(Data::new("example.com".to_owned())).await;

        // 131072 zero-padded five-digit substrings, cycling every 100000 values.
        let substrings: Vec<String> = (0..(1024 * 1024 / 8))
            .map(|i| format!("{:05}", i % 100_000))
            .collect();
        let payload =
            serde_json::json!({ "emails": { "banned_addresses": { "substrings": substrings } } });
        factory
            .set_dynamic_data(policy_data(payload))
            .await
            .unwrap();

        let mut policy = factory.instantiate().await.unwrap();
        let res = policy
            .evaluate_register(password_registration("12345@example.com"))
            .await
            .unwrap();
        assert!(!res.valid());
    }

    #[test]
    fn test_merge() {
        use serde_json::json as j;

        // (left, right, expected merge result)
        let cases = [
            (
                j!({"hello": "world"}),
                j!({"foo": "bar"}),
                j!({"hello": "world", "foo": "bar"}),
            ),
            (j!({"hello": "world"}), j!({"hello": "john"}), j!({"hello": "john"})),
            (j!({"hello": true}), j!({"hello": false}), j!({"hello": false})),
            (j!({"hello": 0}), j!({"hello": 42}), j!({"hello": 42})),
            (
                j!({"hello": ["world"]}),
                j!({"hello": ["john"]}),
                j!({"hello": ["world", "john"]}),
            ),
            (j!({"hello": "world"}), j!({"hello": null}), j!({"hello": null})),
            (j!({"hello": null}), j!({"hello": "world"}), j!({"hello": "world"})),
            (
                j!({"a": {"b": {"c": "d"}}}),
                j!({"a": {"b": {"e": "f"}}}),
                j!({"a": {"b": {"c": "d", "e": "f"}}}),
            ),
        ];
        for (left, right, expected) in cases {
            assert_eq!(merge_data(left, right).unwrap(), expected);
        }

        merge_data(j!({"hello": "world"}), j!({"hello": 123}))
            .expect_err("Can't merge different types");
    }
}