LogRel.PERTactics: Ltac2 tactics to manipulate partial equivalence relations.


From Coq Require Import CRelationClasses.
From LogRel Require Import Utils.
From Ltac2 Require Import Ltac2 Printf.
From Ltac2 Require FMap FSet Constr.
Import FSet.Tags.

(* Prerequisites *)

Ltac2 array_extend (v : 'a array) (default : 'a) :=
  let old_size := Array.length v in
  let size := Int.add 1 (Int.mul old_size 2) in
  Control.assert_bounds "UFBase.extend" (Int.le old_size size) ;
  let table := Array.make size default in
  Array.lowlevel_blit v 0 table 0 old_size ;
  table.

Ltac2 unif (c : constr) (id : ident) : bool :=
  let c' := Control.hyp id in
  Control.once_plus
    (fun () => Unification.unify_with_full_ts c c' ; true)
    (fun _ => false).

Ltac2 mem_unif (c : constr) (ids : ident list) : ident option :=
  List.find_opt (unif c) ids.

Ltac2 ident_of_constr (c : constr) : ident :=
  match Constr.Unsafe.kind c with
  | Constr.Unsafe.Var h => h
  | _ =>
    let h := Fresh.in_goal @pt in
    Std.pose (Some h) c ;
    h
  end.

Ltac2 Type exn ::= [Panic(string)].

Ltac2 default_panic (msg : string) (x : 'a option) : 'a :=
  match x with
  | Some a => a
  | None => Control.throw (Panic msg)
  end.

Ltac2 match_rel (r : constr -> constr -> constr) c :=
  Control.once_plus (fun () =>
    let a := open_constr:(_) in
    let b := open_constr:(_) in
    Unification.unify_with_full_ts c (r a b) ;
    Some (a,b))
    (fun _ => None).

Module BiMap.
  Ltac2 Type ('a,'b) t := ('a * 'b) list.

  Ltac2 empty : ('a,'b) t := [].

  Ltac2 assoc (eqa : 'a -> 'a -> bool) (a : 'a) (m : ('a,'b) t) : 'b option :=
    Option.map snd (List.find_opt (fun (a', _) => eqa a a') m).

  Ltac2 assoc_inv (eqb : 'b -> 'b -> bool) (b : 'b) (m : ('a,'b) t) : 'a option :=
    Option.map fst (List.find_opt (fun (_, b') => eqb b b') m).

  Ltac2 assoc_dflt (eqa : 'a -> 'a -> bool) dflt (a : 'a) (m : ('a,'b) t) : 'b :=
    Option.default dflt (assoc eqa a m).

  Ltac2 assoc_inv_dflt (eqb : 'b -> 'b -> bool) dflt (b : 'b) (m : ('a,'b) t) : 'a :=
    Option.default dflt (assoc_inv eqb b m).

  Ltac2 mem (eqa : 'a -> 'a -> bool) (a : 'a) (m : ('a,'b) t) : bool :=
    match (assoc eqa a m) with | Some _ => true | _ => false end.

  Ltac2 mem_inv (eqb : 'b -> 'b -> bool) (b : 'b) (m : ('a,'b) t) : bool :=
    match (assoc_inv eqb b m) with | Some _ => true | _ => false end.

  Ltac2 push (a : 'a) (b : 'b) (m : ('a,'b) t) : ('a,'b) t :=
    (a,b) :: m.

  Ltac2 domain (m : ('a,'b) t) : 'a list := List.map fst m.
  Ltac2 image (m : ('a,'b) t) : 'b list := List.map snd m.

End BiMap.

Module UF.

(* Representation of the elements of the disjoint-union structure
  should always be >= 0, -1 representing an absence of data/root node *)

Ltac2 Type elt := int.

(* Representation of witnesses of relatedness between elements x -> y
  cannot be 0, negative values correspond to the reverse witness y -> x *)

Ltac2 Type witness := int.

Ltac2 Type st := {
    mutable last : int ;
    mutable table : (elt * witness list) array ;
    mutable numwits : int ;
  }.

Ltac2 make (size : int) : st :=
  Control.assert_valid_argument "union find size should be positive" (Int.ge size 0) ;
  { last := 0; numwits := 0; table := Array.make size (-1, []) }.

Ltac2 extend (x : st) : unit :=
  x.(table) := array_extend (x.(table)) (-1, []).

Ltac2 add (x : st) : elt :=
  (if Int.equal (x.(last)) (Array.length (x.(table)))
  then extend x else ()) ;
  let r := x.(last) in
  x.(last) := Int.add r 1 ;
  r.

Ltac2 rec root (x : st) (k : int) : elt * witness list :=
  let (i, wski) := Array.get (x.(table)) k in
  if Int.equal i (-1)
  then
    Control.assert_bounds "BaseUF.root non-empty witness list" (Int.equal (List.length wski) 0) ;
    k, []
  else
    let (r, wsir) := root x i in
    let b := Int.equal r i in
    let wskr := if b then wski else List.append wsir wski in
    (if b then () else Array.set (x.(table)) k (r, wskr)) ;
    r, wskr.

Ltac2 inverse_path (l : witness list) : witness list :=
  List.rev_map Int.neg l.

Ltac2 unify (x : st) k l : witness option :=
  let (r1, wskr1) := root x k in
  let (r2, wslr2) := root x l in
  if Int.equal r1 r2
  then None
  else
    let wkl := Int.add (x.(numwits)) 1 in
    x.(numwits) := wkl ;
    let wsr1r2 := List.append wslr2 (wkl :: inverse_path wskr1) in
    Array.set (x.(table)) r1 (r2, wsr1r2) ;
    Some wkl.

Ltac2 rec drop_prefix p q :=
  match p, q with
  | x :: xs, y :: ys =>
    if Int.equal x y
    then drop_prefix xs ys
    else (p, q)
  | _, _ => (p, q)
  end.

(*normalizes a path from k -> l given paths k -> r and l -> r *)
Ltac2 normalize_path (wskr : witness list) (wslr : witness list) : witness list :=
  let (wski, wsli) := drop_prefix wskr wslr in
  List.append (inverse_path wsli) wski.

Ltac2 conv (x : st) (k : elt) (l : elt) : witness list option :=
  if Int.equal k l
  then Some []
  else
    let (r1, wskr1) := root x k in
    let (r2, wslr2) := root x l in
    if Int.equal r1 r2
    then Some (normalize_path wskr1 wslr2)
    else None.

End UF.

Module PER.

Ltac2 lrfl c := preterm:(lrefl $preterm:c).
Ltac2 urfl c := preterm:(urefl $preterm:c).
Ltac2 trans c1 c2 := preterm:(transitivity $preterm:c1 $preterm:c2).
Ltac2 symm c := preterm:(symmetry $preterm:c).

Ltac2 Type st := {
  mutable pts_id_bimap : (ident, UF.elt) BiMap.t ;
  mutable id_qrefl : (ident, preterm) FMap.t ;
  mutable id_to_wits : (UF.witness, preterm) FMap.t ;
  st : UF.st ;
}.

Ltac2 id_of_pt (x : st) (h : ident) : UF.elt :=
  BiMap.assoc_dflt Ident.equal (-1) h (x.(pts_id_bimap)).

Ltac2 pt_of_id (x : st) (pt : UF.elt) : ident :=
  BiMap.assoc_inv_dflt Int.equal (@foo) pt (x.(pts_id_bimap)).

(* Ltac2 qrefl_of_id (x : st) (i : ident) : preterm :=
  Option.get (FMap.find_opt i (x.(id_qrefl))). *)


Ltac2 witness_of_id (x : st) (i : UF.witness) : preterm :=
  Option.get (FMap.find_opt i (x.(id_to_wits))).

Ltac2 make n := {
  pts_id_bimap := BiMap.empty ;
  id_qrefl := FMap.empty FSet.Tags.ident_tag ;
  id_to_wits := FMap.empty FSet.Tags.int_tag ;
  st := UF.make n ;
}.

Ltac2 add_pt (x : st) (h : ident) : unit :=
  if BiMap.mem Ident.equal h (x.(pts_id_bimap)) then ()
  else x.(pts_id_bimap) := BiMap.push h (UF.add (x.(st))) (x.(pts_id_bimap)).

Ltac2 get_pt_cstr (x : st) (c : constr) : ident option :=
  (* shortcut the case where c = Var h ? *)
  mem_unif c (BiMap.domain (x.(pts_id_bimap))).

Ltac2 add_pt_cstr (x : st) (c : constr) : ident :=
  match get_pt_cstr x c with
  | Some h => h
  | None =>
    let h := ident_of_constr c in
    add_pt x h;
    h
  end.

Ltac2 add_qrefl (x : st) (h : ident) (pf : preterm) :=
  if FMap.mem h (x.(id_qrefl)) then () else x.(id_qrefl) := FMap.add h pf (x.(id_qrefl)).

Ltac2 add_witness (x : st) c1 c2 pf :=
  let h1 := add_pt_cstr x c1 in
  let h2 := add_pt_cstr x c2 in
  let k1 := id_of_pt x h1 in
  let k2 := id_of_pt x h2 in
  add_qrefl x h1 (lrfl pf) ;
  add_qrefl x h2 (urfl pf) ;
  match UF.unify (x.(st)) k1 k2 with
  | None => ()
  | Some n =>
    x.(id_to_wits) := FMap.add n pf (x.(id_to_wits))
  end.

Ltac2 build_witness (x : st) h l :=
  let map i :=
    if Int.le i 0
    then symm (witness_of_id x (Int.neg i))
    else witness_of_id x i
  in
  (* printf "building witness from *)
  match List.map map l with
  | [] => default_panic "build_witness/qrefl" (FMap.find_opt h (x.(id_qrefl)))
  | [c] => c
  | c :: cs => List.fold_left (fun c1 c2 => trans c2 c1) c cs
  end.

Ltac2 equiv (x : st) cx cy :=
  Option.default false
    (Option.bind (get_pt_cstr x cx) (fun hx =>
      Option.bind (get_pt_cstr x cy) (fun hy =>
        Option.map (fun _ => true)
           (UF.conv (x.(st)) (id_of_pt x hx) (id_of_pt x hy))))).

Ltac2 get_witness (x : st) h1 h2 :=
  match BiMap.assoc Ident.equal h1 (x.(pts_id_bimap)),
    BiMap.assoc Ident.equal h2 (x.(pts_id_bimap)) with
  | Some k1, Some k2 =>
    let postprocess w := (* Constr.pretype *) (build_witness x h1 w) in
    Option.map postprocess (UF.conv (x.(st)) k1 k2)
    (* Option.map postprocess (Control.time (Some "Conv:") (fun () => UF.conv (x.(st)) k1 k2)) *)
  | _, _ => None
  end.

Ltac2 get_witness_cstr st (cx, cy) :=
  Option.bind (get_pt_cstr st cx) (fun hx =>
    Option.bind (get_pt_cstr st cy) (fun hy =>
      get_witness st hx hy)).

Ltac2 repr (x : st) (c : constr) : (constr * preterm) option :=
  Option.bind (get_pt_cstr x c) (fun h =>
    let (hr, w) := UF.root (x.(st)) (id_of_pt x h) in
    Some (Control.hyp (pt_of_id x hr), build_witness x h w)).

Ltac2 qrefl st (c : constr) :=
  (* Option.map Constr.pretype *) (Option.bind (get_pt_cstr st c) (fun i =>
    FMap.find_opt i (st.(id_qrefl)))).

Ltac2 add_rel (st : st) (f : ident -> constr -> (constr * constr * preterm) list) (hyp : ident * constr option * constr) : unit :=
  let (hpf, _, c) := hyp in
  List.iter (fun (a,b, pf) => add_witness st a b pf) (f hpf c).

Ltac2 init_with extractor (n : int) : constr * constr -> constr option :=
  let st := make n in
  List.iter (add_rel st extractor) (Control.hyps ()) ;
  fun cs => Option.map Constr.pretype (get_witness_cstr st cs).

Ltac2 solve_with extractor matcher n :=
  let solver := init_with extractor n in
  match matcher (Control.goal ()) with
  | None => Control.throw (Tactic_failure (Some (fprintf "Goal does not match.")))
  | Some (cx, cy) =>
    match solver (cx, cy) with
    | None => Control.throw (Tactic_failure (Some (fprintf "No equivalence found between %t and %t." cx cy)))
    | Some w => Control.refine (fun _ => w)
    end
  end.

Ltac2 lift_matcher matcher h c :=
  match matcher c with | None => [] | Some (cx, cy) => [cx, cy, let c := Control.hyp h in preterm:($c)] end.

Ltac2 solve0 n :=
  match! goal with
  | [|- ?r _ _ ] =>
    let matchr := match_rel (fun cx cy => open_constr:($r $cx $cy)) in
    solve_with (lift_matcher matchr) matchr n
  end.

Ltac2 solve1 extractor n :=
  match! goal with
  | [|- ?r _ _ ] =>
    let matchr := match_rel (fun cx cy => open_constr:($r $cx $cy)) in
    solve_with extractor matchr n
  end.

Module Notations.
  Ltac2 Notation "per" "with" extractor(tactic(6)) :=
    solve1 extractor 5.
  Ltac2 Notation "per" := solve0 5.
End Notations.

End PER.