import json
import logging
import math
import re
from typing import Any, Dict, List

import networkx as nx
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer


class EndpointHandler:
    """Extract (head, type, tail) relations from text with the REBEL seq2seq model.

    Instances are callable: ``handler({"inputs": "some text"})`` returns
    ``{"results": [{"relations": [...]}, ...]}`` with one entry per input text,
    or ``{"error": "<message>"}`` if any step fails.
    """

    # Used to record the full traceback of requests that fail in __call__.
    _logger = logging.getLogger(__name__)

    def __init__(self, path: str = ""):
        """Load the REBEL tokenizer and model and move the model to GPU if available.

        Args:
            path: Accepted for the handler constructor signature but not used;
                weights are always loaded from ``self.model_name``.
                NOTE(review): confirm that ignoring ``path`` is intended.
        """
        self.model_name = "Babelscape/rebel-large"
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(self.model_name)

        # Decoding below keeps special tokens (skip_special_tokens=False),
        # presumably so the triplet markers survive; pattern1 then strips the
        # padding / BOS / EOS tokens that are not needed for parsing.
        self.pattern1 = re.compile(r'<pad>|<s>|</s>')
        # pattern2 isolates the triplet markers so str.split() sees them as
        # standalone tokens even when glued to neighbouring text.
        self.pattern2 = re.compile(r'(<obj>|<subj>|<triplet>)')

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = self.model.to(self.device)

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Handle one request.

        Args:
            data: Request payload. ``data["inputs"]`` (popped from ``data``) is
                a string or a list of strings; if the key is missing, ``data``
                itself is treated as the input.

        Returns:
            ``{"results": [{"relations": [...]}, ...]}`` with one entry per
            input text, or ``{"error": message}`` if any step raises.
        """
        try:
            inputs = data.pop("inputs", data)
            if not isinstance(inputs, list):
                inputs = [inputs]

            results = []
            for text in inputs:
                if not isinstance(text, str):
                    raise TypeError(
                        f"each input must be a string, got {type(text).__name__}"
                    )
                graph = self.text_to_graph(text)
                results.append({"relations": self.graph_to_relations(graph)})

            return {"results": results}

        except Exception as e:
            # Top-level request boundary: keep the error-dict contract, but do
            # not lose the traceback.
            self._logger.exception("Relation extraction request failed")
            return {"error": str(e)}

    @staticmethod
    def _compute_span_boundaries(num_tokens: int, span_length: int) -> List[List[int]]:
        """Return ``[start, end]`` token windows of width ``span_length`` covering all tokens.

        If ``num_tokens <= span_length``, a single window ``[0, span_length]``
        is returned; slicing clamps it to the real length. Otherwise
        ``ceil(num_tokens / span_length)`` equal-width windows are spread
        evenly, so the first starts at 0, the last ends exactly at
        ``num_tokens``, and consecutive windows overlap or touch. No token is
        dropped, and every window has the same width, which ``torch.stack``
        requires.

        Raises:
            ValueError: If ``span_length`` is not positive.
        """
        if span_length <= 0:
            raise ValueError(f"span_length must be positive, got {span_length}")
        if num_tokens <= 0:
            return []
        if num_tokens <= span_length:
            return [[0, span_length]]

        num_spans = math.ceil(num_tokens / span_length)
        last_start = num_tokens - span_length
        # Integer interpolation: start 0 for the first window and exactly
        # last_start for the last; steps never exceed span_length, so no gaps.
        starts = [i * last_start // (num_spans - 1) for i in range(num_spans)]
        return [[start, start + span_length] for start in starts]

    def text_to_graph(self, text: str, span_length: int = 128) -> "nx.DiGraph":
        """Run REBEL over ``text`` and collect the extracted relations in a graph.

        The tokenized text is cut into overlapping windows of ``span_length``
        tokens, which are generated as one batch. Every decoded output
        (``num_return_sequences`` per window) is parsed into relations.

        Args:
            text: Input text.
            span_length: Window width in tokens; must be positive.

        Returns:
            An ``nx.MultiDiGraph`` (a subclass of ``nx.DiGraph``). Nodes are
            entity strings; each edge is keyed by its relation type and carries
            ``relation`` (the type) and ``spans`` (token windows where it was
            found) attributes.

        Raises:
            ValueError: If ``span_length`` is not positive.
        """
        encoded = self.tokenizer([text], return_tensors="pt")
        input_ids = encoded["input_ids"][0]
        attention_mask = encoded["attention_mask"][0]

        spans_boundaries = self._compute_span_boundaries(len(input_ids), span_length)

        # Keyed multigraph so that different relation types between the same
        # head/tail pair coexist instead of overwriting each other.
        graph = nx.MultiDiGraph()
        if not spans_boundaries:
            return graph

        # Slice on the original device, stack, then move to the model device once.
        model_inputs = {
            "input_ids": torch.stack(
                [input_ids[start:end] for start, end in spans_boundaries]
            ).to(self.device),
            "attention_mask": torch.stack(
                [attention_mask[start:end] for start, end in spans_boundaries]
            ).to(self.device),
        }

        num_return_sequences = 3
        gen_kwargs = {
            "max_length": 256,
            "length_penalty": 0,
            "num_beams": 3,
            "num_return_sequences": num_return_sequences,
        }

        with torch.no_grad():
            generated_tokens = self.model.generate(**model_inputs, **gen_kwargs)

        decoded_preds = self.tokenizer.batch_decode(
            generated_tokens, skip_special_tokens=False
        )

        # Assumes generate() returns the num_return_sequences outputs of each
        # window contiguously, so integer division recovers the window index.
        for i, sentence_pred in enumerate(decoded_preds):
            boundary = spans_boundaries[i // num_return_sequences]
            for relation in self.extract_relations_from_model_output(sentence_pred):
                relation["meta"] = {"spans": [list(boundary)]}
                self.add_relation_to_graph(graph, relation)

        return graph

    def extract_relations_from_model_output(self, text: str) -> List[Dict[str, str]]:
        """Parse one decoded REBEL output string into relation dicts.

        After stripping ``<pad>``, ``<s>`` and ``</s>``, the parser reads three
        markers: text after ``<triplet>`` is the head, text after ``<subj>`` is
        the tail, and text after ``<obj>`` is the relation type.

        - A new ``<subj>`` starts another triple that reuses the current head.
        - Tokens before the first marker are ignored.
        - Triples with an empty head, type or tail are dropped.

        Returns:
            A list of ``{"head", "type", "tail"}`` dicts with stripped values.
        """
        relations: List[Dict[str, str]] = []
        subject, relation, object_ = '', '', ''
        current = None

        def flush() -> None:
            # Emit the pending triple only when all three parts were collected.
            if subject and relation and object_:
                relations.append({
                    'head': subject.strip(),
                    'type': relation.strip(),
                    'tail': object_.strip(),
                })

        text_replaced = self.pattern1.sub('', text.strip())
        # Raw string: '\g' in a plain literal is an invalid escape sequence.
        text_replaced = self.pattern2.sub(r' \g<1> ', text_replaced)

        for token in text_replaced.split():
            if token == "<triplet>":
                flush()
                subject, relation, object_ = '', '', ''
                current = 'subj'
            elif token == "<subj>":
                flush()
                relation, object_ = '', ''
                current = 'obj'
            elif token == "<obj>":
                current = 'rel'
            elif current == 'subj':
                subject += ' ' + token
            elif current == 'rel':
                relation += ' ' + token
            elif current == 'obj':
                object_ += ' ' + token

        flush()
        return relations

    def add_relation_to_graph(self, graph: "nx.DiGraph", relation: Dict[str, Any]) -> None:
        """Add ``relation`` as an edge, merging spans if the same relation already exists.

        On a multigraph, edges are keyed by relation type, so several relation
        types between the same head/tail pair are kept side by side. On a
        simple ``DiGraph``, only one edge per pair can exist. A matching
        relation type merges spans; a different one replaces the edge.

        Args:
            graph: Graph to update in place.
            relation: Dict with ``head``, ``type`` and ``tail`` keys and an
                optional ``meta["spans"]`` list.
        """
        head, tail = relation['head'], relation['tail']
        relation_type = relation['type']
        spans = relation.get('meta', {}).get('spans', [])

        if graph.is_multigraph():
            edge_data = (
                graph[head][tail][relation_type]
                if graph.has_edge(head, tail, key=relation_type)
                else None
            )
        else:
            edge_data = graph[head][tail] if graph.has_edge(head, tail) else None
            if edge_data is not None and edge_data.get('relation') != relation_type:
                edge_data = None

        if edge_data is None:
            # Copy spans so the stored list is not aliased with relation["meta"].
            edge_attrs = {'relation': relation_type, 'spans': list(spans)}
            if graph.is_multigraph():
                graph.add_edge(head, tail, key=relation_type, **edge_attrs)
            else:
                graph.add_edge(head, tail, **edge_attrs)
            return

        existing_spans = edge_data.setdefault('spans', [])
        existing_spans.extend([s for s in spans if s not in existing_spans])

    def graph_to_relations(self, graph: "nx.DiGraph") -> List[Dict[str, str]]:
        """Flatten ``graph`` into ``{"head", "type", "tail"}`` dicts, one per edge.

        Edges are read in ``graph.edges`` iteration order, and the type comes
        from each edge's ``relation`` attribute.
        """
        return [
            {"head": head, "type": data["relation"], "tail": tail}
            for head, tail, data in graph.edges(data=True)
        ]
|
|