# NOTE: removed non-Python extraction residue ("Spaces: Sleeping" page header).
import contextlib
import io
import json
import os
import sys
import time
from collections import defaultdict

import pandas as pd
from bm25s import BM25, tokenize
from FlagEmbedding import FlagReranker
from langchain_core.documents import Document
from tqdm import tqdm

from util.Embeddings import TextEmb3LargeEmbedding
from util.vector_base import EmbeddingFunction, get_or_create_vector_base
def rrf(rankings, k=60):
    """Reciprocal Rank Fusion score: sum of 1 / (rank + k) over all rank positions.

    Args:
        rankings: iterable of 1-based rank positions from different retrievers.
        k: smoothing constant (standard RRF default of 60).

    Returns:
        The fused score; 0 for an empty ranking list.
    """
    return sum(1 / (rank + k) for rank in rankings)
def retriever(requirement, PO, safeguard_vector_store, reranker_model, using_reranker=False, using_BM25=False, using_chroma=True, k=20, if_split_po=True):
    """Retrieve candidate safeguards for a requirement.

    When if_split_po is True, retrieval runs once per PO (top-k each) and the
    per-PO results are concatenated; otherwise a single retrieval is run over
    all POs at once.
    """
    shared_kwargs = dict(
        using_reranker=using_reranker,
        using_BM25=using_BM25,
        using_chroma=using_chroma,
        k=k,
    )
    if not if_split_po:
        return multiretriever(requirement, PO, safeguard_vector_store, reranker_model, **shared_kwargs)
    merged = []
    for po in PO:
        merged.extend(multiretriever(requirement, [po], safeguard_vector_store, reranker_model, **shared_kwargs))
    return merged
def multiretriever(requirement, PO, safeguard_vector_store, reranker_model, using_reranker=True, using_BM25=False, using_chroma=True, k=20):
    """Rank candidate safeguards for one requirement and fuse rankings with RRF.

    Candidates are fetched from the vector store filtered by PO, then ranked by
    up to three sources (cross-encoder reranker, BM25, Chroma similarity
    search); each source appends a 1-based rank per safeguard, and the ranks
    are fused with Reciprocal Rank Fusion.

    Args:
        requirement: free-text requirement to match against safeguards.
        PO: list of privacy-objective names used to filter candidates.
        safeguard_vector_store: Chroma-style store exposing .get() and
            .similarity_search_with_score().
        reranker_model: cross-encoder exposing .compute_score(pairs); only
            used when using_reranker is True.
        using_reranker: include cross-encoder reranking in the fusion.
        using_BM25: include BM25 lexical ranking in the fusion.
        using_chroma: include vector-similarity ranking in the fusion.
        k: number of fused results to return.

    Returns:
        Top-k list of (rrf_score, safeguard_number, safeguard_text, po)
        tuples, sorted by fused RRF score descending.
    """
    po_list = [po.lower().rstrip() for po in PO if po]
    # "young users" alone has no retrievable safeguards; bail out early.
    if po_list == ["young users"]:
        return []
    candidate_safeguards = safeguard_vector_store.get(where={"po": {"$in": po_list}})
    safeguard_dict, safeguard_content = {}, []
    # NOTE: ids are not needed here (the original bound them to an unused
    # variable shadowing the builtin `id`).
    for content, metadata in zip(candidate_safeguards['documents'], candidate_safeguards['metadatas']):
        safeguard_dict[content] = {
            "metadata": metadata,
            "rank": [],
            "rrf_score": 0
        }
        safeguard_content.append(content)
    # Reranker: cross-encoder scores each (requirement, safeguard) pair.
    if using_reranker:
        content_pairs = [[requirement, safeguard] for safeguard in safeguard_content]
        safeguard_rerank_scores = reranker_model.compute_score(content_pairs)
        reranking_results = sorted(
            zip(safeguard_content, safeguard_rerank_scores),
            key=lambda item: item[1],
            reverse=True,
        )
        # enumerate replaces the original O(n^2) list.index() lookup and
        # assigns distinct ranks even when (text, score) pairs collide.
        for position, (safeguard, _score) in enumerate(reranking_results, start=1):
            safeguard_dict[safeguard]['rank'].append(position)
    # BM25 lexical ranking (bm25s prints progress; silence it).
    if using_BM25:
        with contextlib.redirect_stdout(io.StringIO()):
            bm25_retriever = BM25(corpus=safeguard_content)
            bm25_retriever.index(tokenize(safeguard_content))
            bm25_results, _scores = bm25_retriever.retrieve(tokenize(requirement), k=len(safeguard_content))
        for position, safeguard in enumerate(bm25_results[0], start=1):
            safeguard_dict[safeguard]['rank'].append(position)
    # Chroma vector-similarity ranking over the same PO-filtered candidates.
    if using_chroma:
        retrieved_safeguards = safeguard_vector_store.similarity_search_with_score(
            query=requirement,
            k=len(candidate_safeguards['ids']),
            filter={"po": {"$in": po_list}},
        )
        for position, (document, _score) in enumerate(retrieved_safeguards, start=1):
            safeguard_dict[document.page_content]['rank'].append(position)
    # Fuse all collected ranks per safeguard and sort by RRF score.
    final_result = []
    for safeguard in safeguard_content:
        entry = safeguard_dict[safeguard]
        entry['rrf_score'] = rrf(entry['rank'])
        final_result.append((entry['rrf_score'], entry['metadata']['safeguard_number'], safeguard, entry['metadata']['po']))
    final_result.sort(key=lambda item: item[0], reverse=True)
    return final_result[:k]
| if __name__=="__main__": | |
| embeddingmodel = TextEmb3LargeEmbedding(max_qpm=58) | |
| embedding = EmbeddingFunction(embeddingmodel) | |
| safeguard_vector_store = get_or_create_vector_base('safeguard_database', embedding) | |
| reranker_model = FlagReranker( | |
| '/root/PTR-LLM/tasks/pcf/model/bge-reranker-v2-m3', | |
| use_fp16=True, | |
| devices=["cpu"], | |
| ) | |
| requirement = """ | |
| Data Minimization Consent for incompatible purposes: Require consent for additional use of personal information not reasonably necessary to or incompatible with original purpose disclosure. | |
| """ | |
| PO = ["Data Minimization & Purpose Limitation", "Transparency"] | |
| final_result = retriever( | |
| requirement, | |
| PO, | |
| safeguard_vector_store, | |
| reranker_model, | |
| using_reranker=True, | |
| using_BM25=False, | |
| using_chroma=True, | |
| k=10 | |
| ) | |
| print(final_result) |