import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
from textwrap import dedent
from huggingface_hub import login
import os
from dotenv import load_dotenv


# Expects HF_TOKEN in the environment (e.g., via a local .env file).
load_dotenv()
login(
    token=os.environ["HF_TOKEN"],
)


MODEL_LIST = [
    "EmergentMethods/Phi-3-mini-4k-instruct-graph",
    "EmergentMethods/Phi-3-mini-128k-instruct-graph",
    # Restored entry: _get_messages below branches on exactly this model name,
    # which could otherwise never pass the MODEL_LIST check in __init__.
    "EmergentMethods/Phi-3-medium-128k-instruct-graph",
]


torch.random.manual_seed(0)


class Phi3InstructGraph:
    def __init__(self, model="EmergentMethods/Phi-3-mini-4k-instruct-graph"):
        if model not in MODEL_LIST:
            raise ValueError(f"model must be one of {MODEL_LIST}")

        self.model_path = model
        self.model = AutoModelForCausalLM.from_pretrained(
            self.model_path,
            device_map="cuda",
            torch_dtype="auto",
            trust_remote_code=True,
        )
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_path)
        self.pipe = pipeline(
            "text-generation",
            model=self.model,
            tokenizer=self.tokenizer,
        )

    def _generate(self, messages):
        generation_args = {
            "max_new_tokens": 2000,
            "return_full_text": False,
            # Greedy decoding keeps the extraction reproducible; the original
            # "temperature": 0.1 is dropped here because transformers ignores
            # temperature (and emits a warning) when do_sample=False.
            "do_sample": False,
        }

        return self.pipe(messages, **generation_args)

    def _get_messages(self, text):
        context = dedent("""\n
        A chat between a curious user and an artificial intelligence Assistant. The Assistant is an expert at identifying entities and relationships in text. The Assistant responds in JSON output only.

        The User provides text in the format:

        -------Text begin-------
        <User provided text>
        -------Text end-------

        The Assistant follows the following steps before replying to the User:

        1. **identify the most important entities** The Assistant identifies the most important entities in the text. These entities are listed in the JSON output under the key "nodes", they follow the structure of a list of dictionaries where each dict is:

        "nodes":[{"id": <entity N>, "type": <type>, "detailed_type": <detailed type>}, ...]

        where "type": <type> is a broad categorization of the entity. "detailed type": <detailed_type> is a very descriptive categorization of the entity.

        2. **determine relationships** The Assistant uses the text between -------Text begin------- and -------Text end------- to determine the relationships between the entities identified in the "nodes" list defined above. These relationships are called "edges" and they follow the structure of:

        "edges":[{"from": <entity 1>, "to": <entity 2>, "label": <relationship>}, ...]

        The <entity N> must correspond to the "id" of an entity in the "nodes" list.

        The Assistant never repeats the same node twice. The Assistant never repeats the same edge twice.
        The Assistant responds to the User in JSON only, according to the following JSON schema:

        {"type":"object","properties":{"nodes":{"type":"array","items":{"type":"object","properties":{"id":{"type":"string"},"type":{"type":"string"},"detailed_type":{"type":"string"}},"required":["id","type","detailed_type"],"additionalProperties":false}},"edges":{"type":"array","items":{"type":"object","properties":{"from":{"type":"string"},"to":{"type":"string"},"label":{"type":"string"}},"required":["from","to","label"],"additionalProperties":false}}},"required":["nodes","edges"],"additionalProperties":false}
        """)

        user_message = dedent(f"""\n
        -------Text begin-------
        {text}
        -------Text end-------
        """)

        if self.model_path == "EmergentMethods/Phi-3-medium-128k-instruct-graph":
            # The medium model takes the instructions inline in the user turn
            # rather than as a separate system message.
            messages = [
                {
                    "role": "user",
                    "content": f"{context}\n Input: {user_message}",
                }
            ]
            return messages
        else:
            messages = [
                {
                    "role": "system",
                    "content": context
                },
                {
                    "role": "user",
                    "content": user_message
                }
            ]
            return messages

    def extract(self, text):
        messages = self._get_messages(text)
        pipe_output = self._generate(messages)

        return pipe_output[0]["generated_text"]
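

# --- Usage sketch (illustrative, not part of the original script) ---
# A minimal example of calling the extractor and reading back the graph.
# The sample sentence and the json.loads step are assumptions: the model is
# instructed to answer in JSON only, but real output should still be
# validated (e.g., wrapped in try/except) before use.
if __name__ == "__main__":
    import json

    extractor = Phi3InstructGraph(model="EmergentMethods/Phi-3-mini-4k-instruct-graph")
    raw = extractor.extract("Sarah works at Microsoft in Seattle.")

    graph = json.loads(raw)  # expected shape: {"nodes": [...], "edges": [...]}
    for node in graph["nodes"]:
        print(f'{node["id"]} ({node["type"]} / {node["detailed_type"]})')
    for edge in graph["edges"]:
        print(f'{edge["from"]} -[{edge["label"]}]-> {edge["to"]}')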