Source code for gpt_graph.utils.utils

# -*- coding: utf-8 -*-
"""
Created on Tue Mar 12 16:53:41 2024

@author: User
"""

import inspect
import collections
import dataclasses


def serialize_json_recursively(data, ignored_keys=None, included_keys=None):
    """
    Recursively convert ``data`` into JSON-friendly values.

    Mappings, lists and dataclass instances are rebuilt with their contents
    converted recursively; ``str``, ``int``, ``float``, ``bool`` and ``None``
    are returned unchanged; anything else (tuples, sets, dataclass *classes*,
    arbitrary objects) is converted with ``str()``.

    Args:
        data: The value to convert.
        ignored_keys: Keys dropped from every mapping / dataclass instance, at
            every nesting level. Defaults to no keys.
        included_keys: If given, only these keys are kept -- but only in the
            first mapping / dataclass instance reached (the top-level object,
            or the elements of top-level lists). Values nested inside it are
            not filtered by ``included_keys``.

    Returns:
        A structure of dicts, lists and JSON scalars (plus ``str()`` fallbacks).
    """
    if ignored_keys is None:
        ignored_keys = []

    def should_include(key):
        # ``included_keys`` narrows the allowed keys; ``ignored_keys`` always wins.
        if included_keys is not None:
            return key in included_keys and key not in ignored_keys
        return key not in ignored_keys

    if isinstance(data, collections.abc.Mapping):
        # Values are recursed WITHOUT included_keys: the whitelist applies to
        # this level only.
        return {
            serialize_json_recursively(key): serialize_json_recursively(
                value, ignored_keys
            )
            for key, value in data.items()
            if should_include(key)
        }
    if isinstance(data, list):
        return [
            serialize_json_recursively(item, ignored_keys, included_keys)
            for item in data
        ]
    if isinstance(data, (str, int, float, bool)) or data is None:
        return data
    # is_dataclass() is also True for dataclass *types*, but asdict() only
    # accepts instances -- let classes fall through to str().
    if dataclasses.is_dataclass(data) and not isinstance(data, type):
        # Same filtering rule as the Mapping branch: included_keys applies to
        # this object's own fields only.
        return {
            key: serialize_json_recursively(value, ignored_keys)
            for key, value in dataclasses.asdict(data).items()
            if should_include(key)
        }
    return str(data)
# %% # def truncate_text(text, max_words): # import re # # Regular expression to identify Chinese characters # chinese_char_pattern = re.compile(r"[\u4e00-\u9fff]") # word_count = 0 # current_length = 0 # # Iterate through each character in the text # for i, char in enumerate(text): # if chinese_char_pattern.match(char): # # If the character is Chinese, increment the word count # if word_count == max_words: # break # word_count += 1 # elif char == " ": # # If the character is a space, increment the word count for the previous English word # if word_count == max_words: # break # word_count += 1 # elif i == len(text) - 1: # # If it's the last character in the text and not a space, increment for the last word # if word_count == max_words: # break # word_count += 1 # # Add the character to the current length to be included in the final string # current_length += 1 # # Return the truncated text # return text[:current_length]
def truncate_text(text, max_count, if_by_tokens=True, encoding_name="cl100k_base"):
    """
    Truncate ``text`` to at most ``max_count`` tokens or words.

    Args:
        text: The text to truncate.
        max_count: Maximum number of tokens (``if_by_tokens=True``) or words
            (``if_by_tokens=False``) to keep.
        if_by_tokens: If True, count tiktoken tokens. Otherwise count words:
            each CJK ideograph (U+4E00..U+9FFF) is one word, and every other
            run of non-whitespace characters is one word.
        encoding_name: tiktoken encoding used in token mode.

    Returns:
        ``text`` unchanged if it is within the limit. Otherwise, in token mode,
        the decoded first ``max_count`` tokens; in word mode, the prefix of
        ``text`` ending right after the ``max_count``-th word.
    """
    if if_by_tokens:
        # Imported lazily so word mode works without tiktoken installed.
        import tiktoken

        encoding = tiktoken.get_encoding(encoding_name)
        encoded = encoding.encode(text)
        if len(encoded) <= max_count:
            return text
        return encoding.decode(encoded[:max_count])

    import re

    # One match per "word": a single CJK character, or a run of other
    # non-whitespace characters.
    word_pattern = re.compile(r"[\u4e00-\u9fff]|[^\s\u4e00-\u9fff]+")
    kept_end = 0
    for word_number, match in enumerate(word_pattern.finditer(text), start=1):
        if word_number > max_count:
            # Cut right after the last kept word (separator not included).
            return text[:kept_end]
        kept_end = match.end()
    return text
# Example usage:
if __name__ == "__main__":
    text = "This is a test."
    max_words = 7
    truncated_text = truncate_text(text, max_words)
    # Token mode (the default): the text is shorter than 7 tokens, so it is
    # printed unchanged: 'This is a test.'
    print(truncated_text)

# %%
# Used by sanitize_filename below.
import string
def sanitize_filename(filename):
    """
    Strip every character that is not on a conservative filename whitelist.

    Allowed characters are ASCII letters, ASCII digits, space and ``-_.()``;
    everything else (path separators, other punctuation, non-ASCII text) is
    removed. The result may be empty, or a name such as ``"."`` / ``".."``,
    so callers that build paths from it may want to check for those.
    """
    allowed = "-_.() " + string.ascii_letters + string.digits
    return "".join(filter(allowed.__contains__, filename))
# %%
def get_func_params(
    func, input_fields=("nodes",), ignore_fields=("self", "kwargs", "args")
):
    """
    Map each parameter name of ``func`` to a placeholder or its default value.

    Args:
        func: Any callable accepted by ``inspect.signature``.
        input_fields: Parameter names reported as ``"<INPUT>"`` regardless of
            their default. ``None`` or empty means none.
        ignore_fields: Parameter names omitted from the result. ``None`` or
            empty means none. Matching is by name only, so variadic parameters
            spelled other than ``args`` / ``kwargs`` are not skipped.

    Returns:
        dict, in signature order, mapping each remaining parameter name to
        ``"<INPUT>"``, its default value, or ``"<EMPTY>"`` if it has no default.
    """
    # Tuples as defaults avoid the shared-mutable-default pitfall; callers may
    # still pass any iterable of names (lists included).
    input_fields = set(input_fields or ())
    ignore_fields = set(ignore_fields or ())

    params_defaults = {}
    for param in inspect.signature(func).parameters.values():
        if param.name in ignore_fields:
            continue
        if param.name in input_fields:
            params_defaults[param.name] = "<INPUT>"
        elif param.default is not inspect.Parameter.empty:
            params_defaults[param.name] = param.default
        else:
            params_defaults[param.name] = "<EMPTY>"
    return params_defaults
if __name__ == "__main__":
    g = get_func_params(lambda x, y=3: x)
    print(g)  # {'x': '<EMPTY>', 'y': 3}

# %%
def num_tokens_from_string(string: str, encoding_name: str = "cl100k_base") -> int:
    """
    Return how many tokens ``string`` encodes to.

    Args:
        string: Text to measure.
        encoding_name: Name of the tiktoken encoding to use.

    Returns:
        Length of the token list produced by the encoding.
    """
    import tiktoken

    token_ids = tiktoken.get_encoding(encoding_name).encode(string)
    return len(token_ids)
# %%
def split_text_by_token_count(text, max_token_count, chunk_size):
    """
    Split ``text`` into pieces whose estimated token count stays within a budget.

    The text is split on whitespace and consumed ``chunk_size`` words at a
    time. Chunks are appended to the current piece until adding the next one
    would exceed ``max_token_count``; then a new piece is started. A single
    chunk larger than the budget still becomes its own (oversized) piece.

    A piece's token count is the sum of its chunks' counts from
    ``num_tokens_from_string``, which may differ slightly from the count of
    the joined text.

    Args:
        text: Text to split. Whitespace runs end up as single spaces.
        max_token_count: Token budget per piece.
        chunk_size: Number of words per chunk; must be positive.

    Returns:
        List of non-empty strings; an empty list for blank ``text``.
    """
    words = text.split()
    pieces = []
    current_chunks = []
    current_token_count = 0
    for start in range(0, len(words), chunk_size):
        chunk_text = " ".join(words[start : start + chunk_size])
        chunk_token_count = num_tokens_from_string(chunk_text)
        # Only close the current piece if it has content; otherwise an
        # oversized first chunk would emit an empty leading piece.
        if current_chunks and current_token_count + chunk_token_count > max_token_count:
            pieces.append(" ".join(current_chunks))
            current_chunks = []
            current_token_count = 0
        current_chunks.append(chunk_text)
        current_token_count += chunk_token_count
    if current_chunks:
        pieces.append(" ".join(current_chunks))
    return pieces
if __name__ == "__main__":
    # Demo: split a short nonsense string with a 10-token budget, measured in
    # 2-word chunks. ``s`` is reused by the grouping demo later in the file.
    s = split_text_by_token_count(
        "eaf wf.feafeaeafafeafafeafffaefaef tewta.--f eraefea wfa ,wr, tetea,wra ,re wtf ,et af ",
        10,
        2,
    )
    print(s)

# %%
# Earlier version of group_strings_by_token_count (without the
# min_compression_ratio merging step), kept commented out for reference:
# def group_strings_by_token_count(nodes_or_str, max_token_count):
#     # Initialize variables for grouping
#     groups = []
#     current_group = []
#     current_token_count = 0
#     for node in nodes_or_str:
#         if isinstance(node, str):
#             node_text = node
#         else:
#             node_text = node['content']
#         token_count = num_tokens_from_string(node_text)
#         #print("c",token_count)
#         if current_token_count + token_count > max_token_count and current_group:
#             groups.append(current_group)
#             current_group = [node]
#             current_token_count = token_count
#         else:
#             current_group.append(node)
#             current_token_count += token_count
#     if current_group:
#         groups.append(current_group)
#     #print("groups:",groups)
#     return groups

# Used by group_strings_by_token_count below.
import math
def group_strings_by_token_count(
    nodes_or_str, max_token_count, min_compression_ratio=2
):
    """
    Group consecutive strings / nodes into lists under a token budget.

    Items are packed greedily, in order, into groups whose summed token count
    stays within ``max_token_count`` (an item larger than the budget gets a
    group of its own). Then, while there are more groups than
    ``max(1, ceil(len(nodes_or_str) / min_compression_ratio))``, the adjacent
    pair with the smallest combined token count (leftmost on ties) is merged.
    Merged groups may therefore exceed ``max_token_count``.

    Args:
        nodes_or_str: Sequence of items; each is a ``str`` or an object whose
            ``node["content"]`` holds the text.
        max_token_count: Token budget for the greedy packing step.
        min_compression_ratio: Minimum ratio of input items to output groups,
            e.g. 2 means 5 items yield at most 3 groups. ``None`` disables the
            merging step.

    Returns:
        List of groups (lists of the original items, order preserved).
    """

    def _text(node):
        # Plain strings are used directly; other items carry their text under "content".
        return node if isinstance(node, str) else node["content"]

    if min_compression_ratio is None:
        target_group_count = len(nodes_or_str)
    else:
        target_group_count = max(
            1, math.ceil(len(nodes_or_str) / min_compression_ratio)
        )

    groups = []
    # Token total of each entry in ``groups``, kept in sync so the merge loop
    # never has to re-tokenize anything.
    group_tokens = []
    current_group = []
    current_token_count = 0
    for node in nodes_or_str:
        token_count = num_tokens_from_string(_text(node))
        if current_token_count + token_count > max_token_count and current_group:
            groups.append(current_group)
            group_tokens.append(current_token_count)
            current_group = [node]
            current_token_count = token_count
        else:
            current_group.append(node)
            current_token_count += token_count
    if current_group:
        groups.append(current_group)
        group_tokens.append(current_token_count)

    # target_group_count >= 1 here, so the loop only runs with >= 2 groups and
    # there is always at least one adjacent pair to merge.
    while len(groups) > target_group_count:
        best = min(
            range(len(groups) - 1),
            key=lambda i: group_tokens[i] + group_tokens[i + 1],
        )
        groups[best] += groups.pop(best + 1)
        group_tokens[best] += group_tokens.pop(best + 1)
    return groups
if __name__ == "__main__":
    # Group the pieces produced by the split_text_by_token_count demo (``s``)
    # with a 2-token budget and show the result.
    grouped = group_strings_by_token_count(s, max_token_count=2)
    print(grouped)